Optimizing sorting a 3D xarray.DataArray by another 2D xarray.DataArray

@kathy.gunn, I think this is one approach to your problem:

def sort_by_array(arr, arr_sort, dim):
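    """
    Sort array(s) by the values of another array along a dimension.

    Parameters
    ----------
    arr : xarray DataArray or Dataset
        The field(s) to be sorted
    arr_sort : xarray DataArray
        The field containing the values to sort by
    dim : str
        Dimension, present in both arr and arr_sort, to sort along

    Returns
    -------
    xarray DataArray or Dataset
        arr with dim replaced by a new dimension "i", ordered by
        increasing arr_sort
    """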

    SORT_DIM = "i"

    sort_axis = arr_sort.get_axis_num(dim)
    sort_inds = arr_sort.argsort(axis=sort_axis)
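    # Replace dim with a dummy dim and drop its coordinate, since the
    # original coordinate values no longer apply after sorting.
    # drop_vars replaces the deprecated drop; errors="ignore" covers the
    # case where arr_sort has no coordinate along dim.
    sort_inds = sort_inds.rename({dim: SORT_DIM}).drop_vars(SORT_DIM, errors="ignore")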
    
    return arr.isel({dim: sort_inds})

In your case, arr would be your Dataset containing temperature, salinity and density; arr_sort would be your DataArray of dynamic height anomalies; and dim would be "lat". The function replaces the lat dimension with a new dimension i (along which dynamic height increases), and lat becomes a coordinate that depends on both i and time. This is one way to keep track of the correct latitude information after the sort.
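
As a quick check of this behaviour, here is a minimal sketch on synthetic data. The array shapes, the latitude values and the names temp and dyn_height are assumptions for illustration, not your actual data. The function body is repeated in compact form (with drop_vars, the current name for dropping a coordinate) so the snippet runs on its own.

```python
import numpy as np
import xarray as xr


def sort_by_array(arr, arr_sort, dim):
    # Same logic as the function above, condensed for a standalone example.
    sort_inds = arr_sort.argsort(axis=arr_sort.get_axis_num(dim))
    sort_inds = sort_inds.rename({dim: "i"}).drop_vars("i", errors="ignore")
    return arr.isel({dim: sort_inds})


# Synthetic stand-ins (assumed shapes): 2 times, 3 depths, 4 latitudes.
rng = np.random.default_rng(0)
lat = [-60.0, -55.0, -50.0, -45.0]
temp = xr.DataArray(
    rng.normal(size=(2, 3, 4)),
    dims=("time", "depth", "lat"),
    coords={"lat": lat},
    name="temperature",
)
dyn_height = xr.DataArray(
    rng.normal(size=(2, 4)),
    dims=("time", "lat"),
    coords={"lat": lat},
    name="dynamic_height",
)

# Sort the 3D field, and the 2D sorting field itself, along "lat".
sorted_temp = sort_by_array(temp, dyn_height, "lat")
sorted_dh = sort_by_array(dyn_height, dyn_height, "lat")

# "lat" is no longer a dimension; it is now a 2D coordinate.
print(sorted_temp.dims)
print(sorted_temp["lat"].dims)

# Dynamic height should be non-decreasing along "i" at every time step.
steps = sorted_dh.diff("i")
print(bool((steps >= 0).all()))
```

Because the index array carries a time dimension, each time step gets its own ordering of latitudes, which is why the lat coordinate in the result has to be two-dimensional.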
