Creating an xarray mask based on multiple conditions in xr.where


I want to take an empty 4D array and fill it with 1s and 0s based on certain conditions being met in another array. The way I have coded it is:

da_fuzz = xr.zeros_like(S).expand_dims({'tree_depth':2**tree_depth}).assign_coords({'tree_depth':np.arange(0,2**tree_depth)})

## da_fuzz is a zero-filled array with shape (16, 540, 300, 360)

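for i in range(540):
    for j in range(16):
        da_fuzz[j, i] = xr.where((S.isel(time=i) > partitions[i, j, 0]) &
                                 (S.isel(time=i) <= partitions[i, j, 1]) &
                                 (T.isel(time=i) > partitions[i, j, 2]) &
                                 (T.isel(time=i) <= partitions[i, j, 3]),
                                 1, 0)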

## where S and T are sea surface salinity and temperature (shape = [540x300x360]) and partitions is a set of criteria of shape [540x16x4]

This loop takes 5.5 minutes on Gadi, but surely there’s a more Pythonic way to do this that’s faster and avoids loops? I am hoping to scale up to a mask da_fuzz of size [128x540x300x360], so I would ideally like to optimise this loop. Thanks in advance for any help!

Hi @taimoorsohail. If you make partitions an xarray object (maybe it already is), you should be able to do this in a single line. This should be much faster.

# I don't know appropriate dimension names or coordinate
# values for the 1st and 2nd axes, so I've made some up
partitions_da = xr.DataArray(
    partitions,
    dims=["time", "dim_j", "dim_thresh"],
    coords={
        "time": S["time"],
        "dim_j": range(16),
        "dim_thresh": range(4),
    },
)

da_fuzz = xr.where(
    (S > partitions_da.isel(dim_thresh=0)) &
    (S <= partitions_da.isel(dim_thresh=1)) &
    (T > partitions_da.isel(dim_thresh=2)) &
    (T <= partitions_da.isel(dim_thresh=3)),
    1, 0,
)

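To check that the one-liner reproduces the original double loop, here is a minimal self-contained sketch on small synthetic data. The sizes, the random values and the spatial dimension names `y`/`x` are made up for illustration; `dim_j` and `dim_thresh` are the placeholder names from the reply above.

```python
import numpy as np
import xarray as xr

# Hypothetical small sizes; the real data are (540, 300, 360) with 16 partitions.
ntime, ny, nx, nj = 5, 4, 3, 16
rng = np.random.default_rng(0)
time = np.arange(ntime)
S = xr.DataArray(rng.uniform(33, 37, (ntime, ny, nx)),
                 dims=["time", "y", "x"], coords={"time": time})
T = xr.DataArray(rng.uniform(0, 30, (ntime, ny, nx)),
                 dims=["time", "y", "x"], coords={"time": time})

# Synthetic thresholds: last axis is (S_lower, S_upper, T_lower, T_upper).
partitions = np.concatenate(
    [np.sort(rng.uniform(33, 37, (ntime, nj, 2)), axis=-1),
     np.sort(rng.uniform(0, 30, (ntime, nj, 2)), axis=-1)],
    axis=-1,
)
partitions_da = xr.DataArray(
    partitions,
    dims=["time", "dim_j", "dim_thresh"],
    coords={"time": S["time"], "dim_j": np.arange(nj), "dim_thresh": np.arange(4)},
)

# Vectorised mask: dim_j is broadcast against (time, y, x) by dimension name.
da_fuzz = xr.where(
    (S > partitions_da.isel(dim_thresh=0))
    & (S <= partitions_da.isel(dim_thresh=1))
    & (T > partitions_da.isel(dim_thresh=2))
    & (T <= partitions_da.isel(dim_thresh=3)),
    1,
    0,
).transpose("dim_j", "time", "y", "x")  # match the loop's (j, time, y, x) layout

# Reference: the original double loop, written with plain NumPy.
expected = np.zeros((nj, ntime, ny, nx), dtype=np.int64)
for i in range(ntime):
    for j in range(nj):
        s, t, p = S.values[i], T.values[i], partitions[i, j]
        expected[j, i] = (s > p[0]) & (s <= p[1]) & (t > p[2]) & (t <= p[3])

print(np.array_equal(da_fuzz.values, expected))  # True
```

With `1, 0` as the branch values the result is stored as 64-bit integers. For the full [128x540x300x360] mask (about 7.5 billion elements) that is roughly 60 GB, whereas keeping the boolean condition itself or calling `.astype("int8")` on the result needs about 7.5 GB.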
For readability, you could use four xarray DataArrays for the partitions rather than having the dim_thresh dimension, e.g.

da_fuzz = xr.where(
    (S > S_lower) & (S <= S_upper) &
    (T > T_lower) & (T <= T_upper),
    1, 0,
)

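One way to get those four arrays, sketched under the assumption that `partitions` is a NumPy array whose last axis is ordered (S lower, S upper, T lower, T upper) as in the original loop: slice that axis and wrap each slice with the same made-up `time`/`dim_j` dimensions. The inputs below are small synthetic stand-ins.

```python
import numpy as np
import xarray as xr

# Hypothetical small inputs with the same layout as the real data.
ntime, ny, nx, nj = 3, 2, 2, 4
rng = np.random.default_rng(1)
S = xr.DataArray(rng.uniform(33, 37, (ntime, ny, nx)), dims=["time", "y", "x"])
T = xr.DataArray(rng.uniform(0, 30, (ntime, ny, nx)), dims=["time", "y", "x"])
partitions = np.concatenate(
    [np.sort(rng.uniform(33, 37, (ntime, nj, 2)), axis=-1),
     np.sort(rng.uniform(0, 30, (ntime, nj, 2)), axis=-1)],
    axis=-1,
)

# One (time, dim_j) DataArray per threshold column.
S_lower, S_upper, T_lower, T_upper = (
    xr.DataArray(partitions[..., k], dims=["time", "dim_j"]) for k in range(4)
)

da_fuzz = xr.where(
    (S > S_lower) & (S <= S_upper) &
    (T > T_lower) & (T <= T_upper),
    1, 0,
)
print(da_fuzz.dims)  # ('time', 'y', 'x', 'dim_j')
```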
As a general rule, try to avoid pre-defining empty xarray objects and assigning into them. This is not a good pattern for xarray, and the vast majority of the time there’s an easier/better approach. If you really do need this type of pattern, it’s better to work on unlabelled arrays (e.g. numpy arrays) and then pack everything into an xarray DataArray/Dataset at the end. Or, better still, use xarray.apply_ufunc to make your code that works on unlabelled arrays compatible with xarray.
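A sketch of the `xarray.apply_ufunc` route, with a hypothetical function name `fuzz_mask` and synthetic inputs: the function only ever sees unlabelled NumPy arrays, and `apply_ufunc` takes care of broadcasting and labelling. Declaring `dim_thresh` as a core dimension makes xarray move it to the last axis of the thresholds array, so the function can index it with `[..., k]`.

```python
import numpy as np
import xarray as xr

def fuzz_mask(s, t, bounds):
    """Plain NumPy: bounds[..., 0:4] = (S_lower, S_upper, T_lower, T_upper)."""
    return ((s > bounds[..., 0]) & (s <= bounds[..., 1])
            & (t > bounds[..., 2]) & (t <= bounds[..., 3])).astype(np.int8)

# Hypothetical small inputs.
ntime, ny, nx, nj = 3, 4, 5, 6
rng = np.random.default_rng(2)
S = xr.DataArray(rng.uniform(33, 37, (ntime, ny, nx)), dims=["time", "y", "x"])
T = xr.DataArray(rng.uniform(0, 30, (ntime, ny, nx)), dims=["time", "y", "x"])
partitions_da = xr.DataArray(
    np.concatenate(
        [np.sort(rng.uniform(33, 37, (ntime, nj, 2)), axis=-1),
         np.sort(rng.uniform(0, 30, (ntime, nj, 2)), axis=-1)],
        axis=-1,
    ),
    dims=["time", "dim_j", "dim_thresh"],
)

# apply_ufunc aligns the non-core dims (time, y, x, dim_j) of all inputs,
# inserting length-1 axes where an input lacks a dim, and moves the core dim
# "dim_thresh" to the end of partitions_da before calling fuzz_mask.
da_fuzz = xr.apply_ufunc(
    fuzz_mask, S, T, partitions_da,
    input_core_dims=[[], [], ["dim_thresh"]],
)
print(da_fuzz.dims, da_fuzz.dtype)  # ('time', 'y', 'x', 'dim_j') int8
```

Returning `int8` from the NumPy function is a design choice here to keep the mask small; returning the boolean array directly would work just as well.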

I agree. You could also keep a single partitions array as Taimoor has done, but give it a labelled coordinate that provides more information, e.g.

da_fuzz = xr.where(
    (S > partitions_da.sel(dim_thresh='S_lower')) &
    (S <= partitions_da.sel(dim_thresh='S_upper')) &
    (T > partitions_da.sel(dim_thresh='T_lower')) &
    (T <= partitions_da.sel(dim_thresh='T_upper')),
    1, 0,
)

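A minimal sketch of attaching those labels, assuming `partitions_da` was built with integer `dim_thresh` values 0 to 3 in the order used by the original loop; `assign_coords` replaces them with descriptive strings so that `.sel` can be used. The small array below is a synthetic placeholder.

```python
import numpy as np
import xarray as xr

# Hypothetical small thresholds; dim_thresh starts with integer labels 0..3.
ntime, nj = 2, 3
partitions_da = xr.DataArray(
    np.arange(ntime * nj * 4, dtype=float).reshape(ntime, nj, 4),
    dims=["time", "dim_j", "dim_thresh"],
    coords={"dim_thresh": np.arange(4)},
)

# Swap the integer labels for descriptive ones, keeping the loop's column order.
partitions_da = partitions_da.assign_coords(
    dim_thresh=["S_lower", "S_upper", "T_lower", "T_upper"]
)

# Label-based selection picks out the same slice as the old positional index.
s_lower = partitions_da.sel(dim_thresh="S_lower")
print(bool((s_lower == partitions_da.isel(dim_thresh=0)).all()))  # True
```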
Or you could have a partitions Dataset with S and T variables and label the coordinates, so that it looks like:

da_fuzz = xr.where(
    (S > partitions_ds.S.sel(dim_thresh='lower')) &
    (S <= partitions_ds.S.sel(dim_thresh='upper')) &
    (T > partitions_ds.T.sel(dim_thresh='lower')) &
    (T <= partitions_ds.T.sel(dim_thresh='upper')),
    1, 0,
)

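Building such a Dataset from the original [540x16x4] array could look like the sketch below, assuming again that the last axis is ordered (S lower, S upper, T lower, T upper); the `time`/`dim_j`/`y`/`x` names and small sizes are placeholders. It uses item access (`partitions_ds["S"]`, `partitions_ds["T"]`), which always refers to the data variable, whereas attribute access can in general be shadowed by Dataset methods or properties.

```python
import numpy as np
import xarray as xr

# Hypothetical small inputs in the same layout as the real data.
ntime, ny, nx, nj = 3, 2, 2, 4
rng = np.random.default_rng(3)
S = xr.DataArray(rng.uniform(33, 37, (ntime, ny, nx)), dims=["time", "y", "x"])
T = xr.DataArray(rng.uniform(0, 30, (ntime, ny, nx)), dims=["time", "y", "x"])
partitions = np.concatenate(
    [np.sort(rng.uniform(33, 37, (ntime, nj, 2)), axis=-1),
     np.sort(rng.uniform(0, 30, (ntime, nj, 2)), axis=-1)],
    axis=-1,
)

# One variable per field, each with a (lower, upper) threshold dimension.
dims = ("time", "dim_j", "dim_thresh")
partitions_ds = xr.Dataset(
    {"S": (dims, partitions[..., 0:2]), "T": (dims, partitions[..., 2:4])},
    coords={"dim_thresh": ["lower", "upper"]},
)

da_fuzz = xr.where(
    (S > partitions_ds["S"].sel(dim_thresh="lower")) &
    (S <= partitions_ds["S"].sel(dim_thresh="upper")) &
    (T > partitions_ds["T"].sel(dim_thresh="lower")) &
    (T <= partitions_ds["T"].sel(dim_thresh="upper")),
    1, 0,
)
print(da_fuzz.dims)  # ('time', 'y', 'x', 'dim_j')
```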
Just thought it was worth pointing out the various ways the xarray data model and coordinates can be used to make more readable code.

Thank you both! Dougie, that fix makes sense and really speeds things up! Thank you!
