The apply_over_axes() method applies a function repeatedly over multiple axes of an array, one axis at a time.
Example
import numpy as np
# create a 3D array
arr = np.array([
    [[1, 2, 3],
     [4, 5, 6]],
    
    [[7, 8, 9],
     [10, 11, 12]]
])
# define a function that sums x along a single axis
def axis_sum(x, axis):
    # apply_over_axes() calls this as axis_sum(x, axis)
    return np.sum(x, axis=axis)
# apply axis_sum over the first axis, then over the third axis
result = np.apply_over_axes(axis_sum, arr, axes=(0, 2))
print(result)
'''
Output:
[[[30]
  [48]]]
'''
apply_over_axes() Syntax
The syntax of apply_over_axes() is:
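numpy.apply_over_axes(func, a, axes)
apply_over_axes() Arguments
The apply_over_axes() method takes the following arguments:
- func - the function to apply; it is called as func(a, axis)
- a - the input array
- axes - the axes over which func is applied, in order (a single int or a sequence of ints)
Note: func must take two arguments, an input array and an axis. It must return an array with either the same number of dimensions as its input or one fewer. If a dimension is dropped, apply_over_axes() reinserts it with length 1 at that axis.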
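The following is a minimal sketch of what apply_over_axes() does step by step. It assumes the behavior described in the NumPy documentation: for each axis in turn, call func(a, axis), and if the result has one fewer dimension, reinsert that axis with np.expand_dims. The helper name apply_over_axes_sketch is hypothetical and is not part of NumPy. The sketch is only meant to show the mechanics.

```python
import numpy as np

def apply_over_axes_sketch(func, a, axes):
    """Hypothetical re-implementation, for illustration only."""
    val = np.asarray(a)
    n = val.ndim
    # accept a single int as well as a sequence of axes
    if np.ndim(axes) == 0:
        axes = (axes,)
    for axis in axes:
        # normalize negative axes, e.g. -1 becomes n - 1
        if axis < 0:
            axis += n
        res = func(val, axis)
        if res.ndim == val.ndim:
            # func kept the dimension, so use the result as-is
            val = res
        elif res.ndim == val.ndim - 1:
            # func dropped the axis, so put it back with length 1
            val = np.expand_dims(res, axis)
        else:
            raise ValueError("func must return the same or one fewer dimension")
    return val

arr = np.arange(8).reshape(2, 2, 2)
print(apply_over_axes_sketch(np.sum, arr, (0, 2)))
# compare against the real NumPy function
print(np.array_equal(apply_over_axes_sketch(np.sum, arr, (0, 2)),
                     np.apply_over_axes(np.sum, arr, (0, 2))))
```

The loop shows why the order of axes matters for functions that are not order-independent. Each call sees the output of the previous one.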
apply_over_axes() Return Value
The apply_over_axes() method returns the output array. It has the same number of dimensions as a, and each axis that func reduced is kept with length 1.
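The short check below illustrates both cases from the return-value rule. A reducing function such as np.sum keeps each reduced axis with length 1. A shape-preserving function such as np.cumsum returns a result of the same shape, and that result is used directly. The 2 x 3 x 4 array is an arbitrary example.

```python
import numpy as np

arr = np.arange(24).reshape(2, 3, 4)

# reducing function: each summed axis survives with length 1
summed = np.apply_over_axes(np.sum, arr, axes=(0, 2))
print(summed.shape)              # (1, 3, 1)
print(summed.ndim == arr.ndim)   # True

# shape-preserving function: cumulative sums along axis 0, then along axis 2
cumulative = np.apply_over_axes(np.cumsum, arr, axes=(0, 2))
print(cumulative.shape)          # (2, 3, 4)
```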
Example 1: Apply a Function Along Multiple Axes
import numpy as np
# create a 3D array
arr = np.arange(8).reshape(2, 2, 2)
print('Original Array:\n', arr)
# sum over axis 0, then over axis 1
# one total remains for each index along axis 2
result = np.apply_over_axes(np.sum, arr, axes=(0, 1))
print('Sum along axes (0, 1):\n', result)
# sum over axis 0, then over axis 2
# one total remains for each index along axis 1
result = np.apply_over_axes(np.sum, arr, axes=(0, 2))
print('Sum along axes (0, 2):\n', result)
Output
Original Array:
 [[[0 1]
  [2 3]]

 [[4 5]
  [6 7]]]
Sum along axes (0, 1):
 [[[12 16]]]
Sum along axes (0, 2):
 [[[10]
  [18]]]
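For np.sum specifically, these results match a single call that passes a tuple of axes with keepdims=True. This is a handy cross-check. apply_over_axes() is mainly useful when func accepts only one integer axis at a time.

```python
import numpy as np

arr = np.arange(8).reshape(2, 2, 2)

via_apply = np.apply_over_axes(np.sum, arr, axes=(0, 1))
# np.sum accepts a tuple of axes; keepdims=True keeps the reduced axes with length 1
via_keepdims = np.sum(arr, axis=(0, 1), keepdims=True)

print(via_keepdims)                             # [[[12 16]]]
print(np.array_equal(via_apply, via_keepdims))  # True
```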
Example 2: Apply a lambda Function Along an Axis
func does not have to be a NumPy function. Any callable that takes an array and an axis, such as a lambda, works.
import numpy as np
# create a 2D array
arr = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
# the lambda receives the array and an axis, and sums along that axis
# axes=1 sums each row of the 2D array
result = np.apply_over_axes(lambda a, axis: np.sum(a, axis=axis), arr, axes=1)
print(result)
Output
[[ 6]
 [15]
 [24]]
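As another sketch with the same 2D array, the lambda below computes the range (largest minus smallest value) of each column by reducing over axis 0. The choice of a range function is only an illustration of a custom (array, axis) callable.

```python
import numpy as np

arr = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

# range of each column: max along axis 0 minus min along axis 0
col_range = np.apply_over_axes(
    lambda a, axis: np.max(a, axis=axis) - np.min(a, axis=axis), arr, axes=0)
print(col_range)  # [[6 6 6]]
```

The reduced axis 0 is reinserted with length 1, so the result is a 1 x 3 row rather than a flat array.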
