Beyond Flat Indices: Extracting True Positions of Maximum Values in Multidimensional Arrays with NumPy
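By default, np.argmax searches the flattened array and returns a single flat index. If you're dealing with multidimensional arrays and want the position within the original shape, you need to convert (unravel) that flat index back into one index per dimension, which is what np.unravel_index does.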
Here's how to achieve this:
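A minimal sketch on a small, made-up 2D array. np.argmax without an axis gives a flat index, np.unravel_index turns it into a (row, column) pair, and np.ravel_multi_index converts the pair back to confirm the round trip:

```python
import numpy as np

# Hypothetical 2D array whose largest value (9) sits at row 1, column 2
arr = np.array([[3, 7, 1],
                [4, 2, 9]])

# With no axis argument, argmax searches the flattened array
flat_idx = np.argmax(arr)
# unravel_index converts the flat index into (row, col) for this shape
row, col = np.unravel_index(flat_idx, arr.shape)
# ravel_multi_index performs the inverse conversion
flat_again = np.ravel_multi_index((row, col), arr.shape)

print("Flat index:", flat_idx)            # 5
print("Position:", (int(row), int(col)))  # (1, 2)
print("Value:", arr[row, col])            # 9
print("Round trip:", flat_again)          # 5
```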
For instance, consider a 3D array arr with shape (2, 2, 3). Calling argmax along the last axis (axis=2) is a different operation: it returns, for each row of the 2D slices, the position of the maximum within that row. Those values are already indices along the last axis, not flat indices, so they should not be unraveled. To locate the overall maximum of the 3D array, call argmax without an axis and unravel the result with the array's full shape.
Here's an example to illustrate this:
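import numpy as np

# Create a 3D array with shape (2, 2, 3)
arr = np.array([[[1, 3, 2], [4, 5, 7]], [[8, 9, 10], [11, 12, 14]]])

# Without an axis, argmax returns an index into the flattened array
flat_idx = np.argmax(arr)
# Unravel the flat index using the FULL shape of the array
coords = np.unravel_index(flat_idx, arr.shape)

# With axis=-1, argmax already returns positions along that axis (no unraveling)
last_axis_idx = np.argmax(arr, axis=-1)

print("Original array:")
print(arr)
print("\nFlat index of the global maximum:")
print(flat_idx)
print("\nUnraveled (block, row, column) position:")
print(tuple(int(c) for c in coords))
print("\nArgmax along the last axis (column index within each row):")
print(last_axis_idx)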
This code prints the original array, the flat index of the global maximum (11), its unraveled position (1, 1, 2), meaning block 1, row 1, column 2 (the value 14), and the per-row argmax along the last axis, [[1, 2], [2, 2]].
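If you want the full (i, j, k) coordinates of the maximum in every row along the last axis, the argmax result can be paired with index grids for the axes that were kept. A sketch, assuming the same example array; np.indices builds those grids:

```python
import numpy as np

arr = np.array([[[1, 3, 2], [4, 5, 7]], [[8, 9, 10], [11, 12, 14]]])

# Position of the maximum along the last axis, shape (2, 2)
k = np.argmax(arr, axis=-1)
# Index grids for the two axes that were kept, each of shape (2, 2)
i, j = np.indices(arr.shape[:-1])

# One (i, j, k) coordinate triple per row of the original array
coords = np.stack([i, j, k], axis=-1).reshape(-1, arr.ndim)
print(coords)
# Indexing with the three grids returns the per-row maxima
print(arr[i, j, k])
```

The first print lists the triples (0, 0, 1), (0, 1, 2), (1, 0, 2) and (1, 1, 2); the second returns [[3, 7], [10, 14]], which matches np.max(arr, axis=-1).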
import numpy as np

# Create a 2D array
arr = np.array([[10, 11, 12], [13, 14, 15], [5, 8, 2]])

# axis=0 collapses the rows: for each column, the ROW index of its maximum
max_row_per_col = np.argmax(arr, axis=0)
# axis=1 collapses the columns: for each row, the COLUMN index of its maximum
max_col_per_row = np.argmax(arr, axis=1)

# Pair each result with the indices of the kept axis to get full (row, col) positions
col_max_positions = (max_row_per_col, np.arange(arr.shape[1]))
row_max_positions = (np.arange(arr.shape[0]), max_col_per_row)

print("Original array:")
print(arr)
print("\nRow index of the maximum in each column (axis=0):")
print(max_row_per_col)
print("\n(row, col) positions of each column's maximum:")
print(col_max_positions)
print("\nColumn index of the maximum in each row (axis=1):")
print(max_col_per_row)
print("\n(row, col) positions of each row's maximum:")
print(row_max_positions)
print("\nValues at those positions:")
print(arr[row_max_positions])
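This code defines a 2D array arr and calls argmax along both axes. With axis=0, argmax collapses the rows and returns, for each column, the row index of that column's maximum ([1, 1, 1] here). With axis=1, it collapses the columns and returns, for each row, the column index of that row's maximum ([2, 2, 1] here). Neither result is a flat index, so there is nothing to unravel; pairing each result with np.arange over the kept axis gives the full (row, col) positions.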
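One way to confirm that per-axis argmax results point at the maxima is to read the values back with np.take_along_axis, which expects an index array with the same number of dimensions as arr; np.expand_dims restores the reduced axis. A sketch using the same example array (the results dict is just a hypothetical container for inspection):

```python
import numpy as np

arr = np.array([[10, 11, 12], [13, 14, 15], [5, 8, 2]])

results = {}
for axis in (0, 1):
    # Re-insert the reduced axis with length 1 so the index array matches arr.ndim
    idx = np.expand_dims(np.argmax(arr, axis=axis), axis=axis)
    # take_along_axis reads arr at those positions along the same axis
    values = np.take_along_axis(arr, idx, axis=axis)
    results[axis] = (idx.ravel(), values.ravel())
    print(f"axis={axis}: indices {idx.ravel()}, values {values.ravel()}")
```

On NumPy 1.22 or newer, np.argmax(arr, axis=axis, keepdims=True) produces the same shaped index array without the np.expand_dims call.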
Looping with conditional statements (less efficient for large arrays):
This method iterates over the columns or rows (depending on the axis) and keeps track of the largest value seen so far and its index in each one.
import numpy as np

def argmax_unraveled(arr, axis):
    """
    Finds the argmax of a 2D array along an axis using explicit Python loops.

    Args:
        arr: 2D NumPy array.
        axis: 0 for the row index of the maximum in each column,
              1 for the column index of the maximum in each row.

    Returns:
        1D NumPy array of indices, matching np.argmax(arr, axis=axis).
    """
    if axis == 0:
        lines = arr.T  # iterate over columns
    elif axis == 1:
        lines = arr    # iterate over rows
    else:
        raise ValueError("Invalid axis value")

    indices = []
    for line in lines:
        max_value = line[0]
        max_idx = 0
        for i, val in enumerate(line):
            # Strict '>' keeps the first occurrence on ties, like np.argmax
            if val > max_value:
                max_value = val
                max_idx = i
        indices.append(max_idx)
    return np.array(indices)

# Example usage
arr = np.array([[10, 11, 12], [13, 14, 15], [5, 8, 2]])
row_positions = argmax_unraveled(arr, axis=0)  # row index of each column's max
col_positions = argmax_unraveled(arr, axis=1)  # column index of each row's max
print("Row positions:", row_positions)
print("Column positions:", col_positions)
Masking and advanced indexing (more efficient):
This method compares the array with its maxima along the chosen axis to build a boolean mask of the maximal elements, then calls np.argmax on the mask, which returns the position of the first True along that axis.
import numpy as np

def argmax_unraveled_mask(arr, axis):
    """
    Finds the argmax along an axis using a boolean mask of the maximum elements.

    Args:
        arr: NumPy array.
        axis: Axis along which to find the maximum.

    Returns:
        NumPy array of indices along `axis`, matching np.argmax(arr, axis=axis).
    """
    # keepdims=True keeps the reduced axis with length 1, so the comparison
    # broadcasts each maximum back along the axis it was taken over
    max_values = np.max(arr, axis=axis, keepdims=True)
    mask = arr == max_values
    # argmax on a boolean array returns the position of the first True
    return np.argmax(mask, axis=axis)

# Example usage
arr = np.array([[10, 11, 12], [13, 14, 15], [5, 8, 2]])
row_positions = argmax_unraveled_mask(arr, axis=0)  # row index of each column's max
col_positions = argmax_unraveled_mask(arr, axis=1)  # column index of each row's max
print("Row positions:", row_positions)
print("Column positions:", col_positions)
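Because the mask marks every maximal element rather than only the first, it also answers a related question: where are all the maxima when the largest value repeats? A small sketch with a made-up array, using np.argwhere to list every (row, col) position of the global maximum:

```python
import numpy as np

# Hypothetical array where the maximum value 9 appears twice
arr = np.array([[9, 1, 4],
                [2, 9, 3]])

# argmax only reports the first occurrence in flattened (row-major) order
first = np.unravel_index(np.argmax(arr), arr.shape)
# A mask of every maximal element, converted to (row, col) pairs
all_positions = np.argwhere(arr == arr.max())

print("First maximum:", tuple(int(v) for v in first))  # (0, 0)
print("All maxima:", all_positions.tolist())           # [[0, 0], [1, 1]]
```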
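Both alternatives reproduce what np.argmax with an axis argument already does, and both return the first occurrence when a row or column contains ties. The loop runs in interpreted Python and becomes much slower as arrays grow, and the mask version makes extra passes over the data (one for the maximum, one for the comparison, one for the argmax). Calling np.argmax directly, together with np.unravel_index when you start from a flat index, is generally the faster and simpler choice, especially for larger arrays.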
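To check the efficiency claim on your own machine, a rough benchmark sketch follows. The 300 x 300 random array and the three helpers (argmax_rows_loop, argmax_rows_mask, argmax_rows_builtin) are hypothetical row-wise stand-ins for the methods above; timings depend on hardware and NumPy version, so none are quoted here.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
arr = rng.random((300, 300))

def argmax_rows_loop(a):
    # Pure-Python scan: column index of the maximum in each row
    result = []
    for row in a:
        best_idx, best_val = 0, row[0]
        for i, val in enumerate(row):
            if val > best_val:
                best_idx, best_val = i, val
        result.append(best_idx)
    return np.array(result)

def argmax_rows_mask(a):
    # Mask-based version: compare with each row's max, take the first True
    return np.argmax(a == a.max(axis=1, keepdims=True), axis=1)

def argmax_rows_builtin(a):
    return np.argmax(a, axis=1)

# All three must agree before the timings mean anything
reference = argmax_rows_builtin(arr)
assert np.array_equal(argmax_rows_loop(arr), reference)
assert np.array_equal(argmax_rows_mask(arr), reference)

for fn in (argmax_rows_loop, argmax_rows_mask, argmax_rows_builtin):
    seconds = timeit.timeit(lambda: fn(arr), number=5)
    print(f"{fn.__name__}: {seconds / 5 * 1e3:.3f} ms per call")
```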