Optimizing sorting a 3D xarray.DataArray by another 2D xarray.DataArray

Hello!

I would like to sort my 3D (depth, lat, time) fields of temperature, salinity, and density by a 2D (lat, time) array of dynamic height anomaly.

I’ve been able to sort in a for-loop, but I imagine there is a better (i.e. faster) way to do this sorting. I have tried xarray.DataArray.sortby but it does not like the fact that the dynamic height anomaly array is not 1D, so I would have to make another loop for each time step anyway.

Can anyone help me optimize this code please? My working for-loop code is shown below and I have attached a screenshot.

import numpy as np

# dynamic_height_anomaly: 2D (latitude x time) DataArray giving the dynamic height anomaly summed across some depth range at each latitude for each time.
# gridded_gamma: 3D (depth x latitude x time) DataArray of gridded density, i.e. a depth-latitude cross-section varying with time.

# make empty (NaN-filled) arrays for the sorted data
dynamic_height_anomaly_sorted = np.full((gridded_gamma.shape[1], gridded_gamma.shape[2]), np.nan)
gridded_gamma_sorted = gridded_gamma * np.nan

for t in range(gridded_gamma.shape[2]):  # iterate over time (60 time steps)
    # I want the fields of temperature, salinity, and density to be sorted such that dynamic height anomaly increases with latitude
    increasing_D_idxs = dynamic_height_anomaly[:, t].values.argsort()
    dynamic_height_anomaly_sorted[:, t] = dynamic_height_anomaly[increasing_D_idxs, t].values
    gridded_gamma_sorted[:, :, t] = gridded_gamma[:, increasing_D_idxs, t].values  # <- this line is slow

dynamic_height_anomaly_sorted.shape  # 2D: rows = latitudes, cols = times; at each time step, dynamic height anomaly now increases down the rows
gridded_gamma_sorted  # 3D (depth x latitude x time) density field, sorted so that dynamic height anomaly increases with latitude at each time step
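For data that already fit in memory as plain NumPy arrays, the loop can in principle be vectorised: `np.argsort` computes the sort order for every time step at once, and `np.take_along_axis` applies it along latitude. A minimal sketch, assuming small made-up arrays (`dha`, `gamma`) in place of `dynamic_height_anomaly.values` and `gridded_gamma.values`:

```python
import numpy as np

# Made-up stand-ins for the real data (assumption, for illustration only):
# dha is (lat, time) and gamma is (depth, lat, time), like
# dynamic_height_anomaly.values and gridded_gamma.values.
rng = np.random.default_rng(0)
dha = rng.normal(size=(5, 4))
gamma = rng.normal(size=(3, 5, 4))

# Sort order along latitude (axis 0) for every time step at once: shape (lat, time).
order = np.argsort(dha, axis=0)

# Apply the order to the 2D array along latitude.
dha_sorted = np.take_along_axis(dha, order, axis=0)

# Give the indices a length-1 depth axis so they broadcast over depth,
# then gather along the latitude axis (axis 1) of the 3D field.
gamma_sorted = np.take_along_axis(gamma, order[np.newaxis, :, :], axis=1)
```

Note that calling `.values` on a dask-backed DataArray loads it into memory, so this only helps when the fields are small enough to hold in RAM.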

Hey Kathy,

You can include nicely formatted code snippets in your posts, which makes reading them a lot easier. Another option for longer code examples or whole notebooks is to upload them to https://gist.github.com and refer back to that in your post.

Any time you’re using np.empty, np.nan and .values, that’s a warning sign that this could be done in a more xarray way, which is usually also faster.

Cheers

Aidan


Hi @kathy.gunn. I’m a bit confused about what you’re after here. It looks like you want to sort your temperature/salinity/density fields at each time step such that the dynamic height anomaly is sorted along the latitude dimension. This would potentially give a different order of latitude coordinates at each time step (i.e. latitude becomes a function of time), but you are not worrying about that in your solution. Is that right? Apologies in advance if I’m being dense.

Hi @dougiesquire, it’s a good question, and you are right. I don’t mind if the latitude coordinates are different after the sorting (e.g. a temperature profile gets shifted by x kilometres). I want to ensure that the dynamic height anomaly increases with latitude.

In other words, the original coordinates are depth-latitude, and the final coordinates are depth-dynamic height anomaly. In both cases, the x-coord increases as you move northwards.

@aidan, is this how you would make an empty array in a more ‘xarray’ way:

isopycnal_layer_depths = xr.DataArray(
    np.zeros(gridded_salinity.shape),
    dims=['isopycnal_surface', 'latitude', 'time'],
    coords={'isopycnal_surface': density_bins,
            'latitude': gridded_salinity.latitude,
            'time': gridded_salinity.time},
)

isopycnal_layer_depths[:] = np.nan

I’d look into xr.zeros_like (or the slightly more general xr.full_like)
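A minimal sketch of that suggestion, using a made-up template in place of `gridded_gamma`: `xr.full_like` copies the dims, coords and shape of an existing array and fills it with a chosen value, so no `np.empty` or `[:] = np.nan` step is needed. Because it copies the template's dims, building `isopycnal_layer_depths` this way would still need a template that already has the `isopycnal_surface` dimension.

```python
import numpy as np
import xarray as xr

# Hypothetical template standing in for gridded_gamma (depth x latitude x time).
rng = np.random.default_rng(0)
gridded_gamma = xr.DataArray(
    rng.normal(size=(3, 5, 4)),
    dims=["depth", "latitude", "time"],
    coords={
        "depth": [0.0, 10.0, 20.0],
        "latitude": np.linspace(-60.0, -50.0, 5),
        "time": np.arange(4),
    },
)

# NaN-filled copy with the same dims, coords and shape as the template;
# dtype=float makes sure NaN can be stored even for an integer template.
gridded_gamma_sorted = xr.full_like(gridded_gamma, np.nan, dtype=float)

# zeros_like is the special case of full_like with a fill value of 0.
zeros = xr.zeros_like(gridded_gamma)
```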


@angus-g answered how to do this directly in an xarray way, but my rather vague point was that this is often an anti-pattern: it tends to occur when traditional programming patterns are used with xarray, where it is usually easier and better to use built-in methods, e.g. reduce, resample or groupby.

These higher-level methods will return an xarray object with correct coordinates. In your case, as @dougiesquire noticed, your latitude coordinate is going to be incorrect, so it may be that you’re better off getting rid of it altogether and replacing it with something more meaningful.

If you really want the coordinate to be dynamic height anomaly, you could add one, and then just sort on that.
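To make that concrete for a single time step: `sortby` only accepts 1D keys, so one sketch (with made-up `gamma` and `dha` arrays standing in for the real fields) is to select one time, attach dynamic height anomaly as an extra coordinate along `lat`, and sort by it. Repeating this for every time step would still need a loop, because the key is 2D overall.

```python
import numpy as np
import xarray as xr

# Made-up stand-ins: density (depth, lat, time) and dynamic height anomaly (lat, time).
rng = np.random.default_rng(1)
lat = np.linspace(-60.0, -50.0, 5)
gamma = xr.DataArray(rng.normal(size=(3, 5, 4)), dims=["depth", "lat", "time"],
                     coords={"lat": lat})
dha = xr.DataArray(rng.normal(size=(5, 4)), dims=["lat", "time"], coords={"lat": lat})

# sortby needs a 1D key, so work on a single time step.
t = 0
snapshot = gamma.isel(time=t)

# Attach dynamic height anomaly as an extra (non-index) coordinate along "lat" ...
snapshot = snapshot.assign_coords(dha=("lat", dha.isel(time=t).values))

# ... then sort by it: the data, dha and the original lat values are reordered together.
snapshot_sorted = snapshot.sortby("dha")
```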


@kathy.gunn, I think this is one approach to your problem:

def sort_by_array(arr, arr_sort, dim):
    """
    Sort array(s) by the values of another array along a dimension
    
    Parameters
    ----------
    arr : xarray DataArray or Dataset
        The field(s) to be sorted
    arr_sort : xarray DataArray
        The field containing the values to sort by
    dim : str
        Dimension in arr_sort to sort along
    """

    SORT_DIM = "i"

    sort_axis = arr_sort.get_axis_num(dim)
    sort_inds = arr_sort.argsort(axis=sort_axis)
    # Replace dim with dummy dim and drop coordinates since
    # they're no longer correct
    sort_inds = sort_inds.rename({dim: SORT_DIM}).drop_vars(SORT_DIM)
    
    return arr.isel({dim: sort_inds})

In your case, arr would be your Dataset containing temperature, salinity and density; arr_sort would be your DataArray containing dynamic height anomalies; and dim would be "lat". You’ll see that the function replaces the lat dimension with a new dimension i (along which dynamic height increases) and that the lat coordinate now depends on both i and time. This is one way to keep track of the correct latitude information after the sort.
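A small self-contained check of that behaviour, with a made-up Dataset (`ds`) and key (`dha`) standing in for the real fields and the body of sort_by_array written out inline. Under these assumptions, each variable ends up with dims depth, i and time, and lat becomes a coordinate over (i, time):

```python
import numpy as np
import xarray as xr

# Made-up Dataset standing in for the temperature, salinity and density fields.
rng = np.random.default_rng(2)
nz, ny, nt = 3, 5, 4
lat = np.linspace(-60.0, -50.0, ny)
ds = xr.Dataset(
    {name: (("depth", "lat", "time"), rng.normal(size=(nz, ny, nt)))
     for name in ["temp", "salt", "gamma"]},
    coords={"lat": lat, "time": np.arange(nt)},
)
dha = xr.DataArray(rng.normal(size=(ny, nt)), dims=["lat", "time"],
                   coords={"lat": lat, "time": np.arange(nt)})

# The body of sort_by_array(ds, dha, "lat"), written out step by step:
sort_inds = dha.argsort(axis=dha.get_axis_num("lat"))      # integer indices, dims (lat, time)
sort_inds = sort_inds.rename({"lat": "i"}).drop_vars("i")  # dummy dim "i" without a coordinate
ds_sorted = ds.isel(lat=sort_inds)                          # pointwise (vectorised) indexing

print(ds_sorted["temp"].dims, ds_sorted["lat"].dims)
```

Because the indexer shares the time dimension with the data, xarray pairs each time step with its own column of indices, which is what replaces the explicit loop over time.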


Dougie, I am getting this error: NotImplementedError: 'argsort' is not yet a valid method on dask arrays. Is there a simple way to fix that?

That’s annoying. It happens because your dynamic_height_anomaly DataArray wraps a dask array, and dask arrays don’t yet have an argsort method. To fix it, you can compute dynamic_height_anomaly before passing it to the function (i.e. bring the underlying array into memory as a NumPy array, which does have an argsort method):

dynamic_height_anomaly = dynamic_height_anomaly.compute()
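If pulling the whole of dynamic_height_anomaly into memory is a problem, one possible alternative (not something tested in this thread) is to wrap `np.argsort` and `np.take_along_axis` in `xr.apply_ufunc` with `dask="parallelized"`, so the sort is applied lazily, block by block. The helper names `_sort_along_last_axis`, `sort_by_array_lazy` and `new_dim` are hypothetical; `arr` must be a DataArray, and for dask inputs the sort dimension needs to be a single chunk (e.g. via `.chunk({"lat": -1})`). Unlike sort_by_array, this sketch discards the original latitude values instead of keeping them as a 2D coordinate.

```python
import numpy as np
import xarray as xr


def _sort_along_last_axis(field, key):
    """Reorder `field` along its last axis so that `key` increases along it.

    apply_ufunc moves the core (sort) dimension to the last axis of both
    inputs; `key` may have fewer leading axes than `field`.
    """
    order = np.argsort(key, axis=-1)
    # Prepend length-1 axes so `order` has as many dimensions as `field`;
    # take_along_axis then broadcasts it over the remaining axes.
    order = order.reshape((1,) * (field.ndim - order.ndim) + order.shape)
    return np.take_along_axis(field, order, axis=-1)


def sort_by_array_lazy(arr, arr_sort, dim, new_dim="i"):
    """Hypothetical variant of sort_by_array built on xr.apply_ufunc.

    `arr` is a DataArray. The old `dim` coordinate is dropped because it is
    no longer valid, and `dim` is renamed to `new_dim`. For dask-backed
    inputs, `dim` must be a single chunk.
    """
    out = xr.apply_ufunc(
        _sort_along_last_axis,
        arr,
        arr_sort,
        input_core_dims=[[dim], [dim]],
        output_core_dims=[[dim]],
        dask="parallelized",
        output_dtypes=[arr.dtype],
    )
    out = out.rename({dim: new_dim}).drop_vars(new_dim, errors="ignore")
    # apply_ufunc puts the core dimension last; restore the input's order.
    return out.transpose(*[new_dim if d == dim else d for d in arr.dims])
```

For a whole Dataset, one option is `ds.map(lambda da: sort_by_array_lazy(da, dynamic_height_anomaly, "lat"))`, assuming every variable has a lat dimension; with dask inputs the result stays lazy until `.compute()` or `.load()` is called.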

Fab - works like a charm and exactly matches what I was doing with my for loop above. Thank you once again!
