Possible bug(?) in xarray.groupby('time.month') operations

I’ve come across a problem using some of xarray’s .groupby('time.month') operations that I don’t understand. This might not be the forum for it (should I post on the xarray github?), but I thought I’d post here in case anyone else has noticed this.

I’m trying to calculate a climatology of the standard deviation of temperature and salinity as a function of depth and time (month). I’m using the conda 24.04 kernel. When I calculate the standard deviation using Dask, I get a noisy field and some missing values for salinity (the temperature calculation is fine!), despite there being no missing values in the array that I’m calculating the standard deviation of. If I first load the data then the standard deviation looks fine (see below). Perhaps the problem is related to the flox warnings that show up? I don’t really understand what these warnings are telling me.

I find it pretty concerning that the .groupby('time.month').std('time') function gives different results depending on whether the data is loaded or not.

Example
Loading data:

Calculating and plotting standard deviation using Dask:


Calculating and plotting standard deviation after loading into memory:

If I use an older kernel (e.g. 23.01) the same problem occurs except that the missing cells are not consistent (different cells show up as missing).

Does anyone know why this occurs and how to avoid it? I won’t always be able to load data to run the calculation. Thanks in advance for any advice.

2 Likes

hi @hrsdawson, thanks for reporting!

Could you please post the code snippets as text, e.g. enclosed within triple backticks ```? That way it'll be much easier for somebody to reproduce the error and investigate, instead of having to read the code off a screenshot.

2 Likes

A couple of things maybe worth highlighting:

  • @hrsdawson found that doing the same calculation with dask and without dask gives different answers! And while the dask version gives warnings, it produces a result that looks wrong at first glance.
  • Doing the same calculation with the 3d temperature field (which is very similar … it's calculated at all the same grid points) appears to work fine, both with dask and without.

Do you get the same inconsistency if you use the raw functions dask.array.nanstd and numpy.nanstd?
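
Something like this would do it for one month (just a sketch, assuming the variable and dimension names from your screenshots):

import numpy as np
import dask.array

# standard deviation over time for (say) January only, bypassing groupby entirely
jan = ds_piC.salt.where(ds_piC.time.dt.month == 1, drop=True)
np.nanstd(jan.values, axis=0)                   # eager numpy
dask.array.nanstd(jan.data, axis=0).compute()   # lazy dask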

If you plot as a timeseries are there differences between the cells that are blank and the cells with values?

If flox is involved ( or even if it’s not ) then Deepak is your guy. He seems to be improving flox regularly and I’m sure would welcome an issue over on GitHub - xarray-contrib/flox: Fast & furious GroupBy operations for dask.array.

There appears to be lots of activity and closed PRs, so it might be worth checking version and dependency issues first?

1 Like

Yes, sure thing. Here are the code snippets.

# define latitude and depth range
lat_range = slice(-75, -65)
depth_range = slice(0, 600) 

# create list of files to open
files = sorted(glob('/scratch/e14/hd4873/archive/dk158/history/ocn/' + 'ocean_month.nc-*'))

# define a preprocessing function to read in the ocean variables of interest
# note: this is not a weighted average in space, but at this stage I don't mind. 
def preprocess(ds):
    return ds.get(['pot_temp', 'salt']).sel(st_ocean=depth_range).sel(yt_ocean=lat_range).mean(['yt_ocean', 'xt_ocean'], skipna=True) 

ds_piC = xr.open_mfdataset(files, preprocess=preprocess, parallel=True)
# plot monthly std climatology without loading the data
ds_piC.salt.groupby('time.month').std('time', skipna=True).T.plot(vmin=0, vmax=0.04);
plt.gca().invert_yaxis();

Which results in:

# now load the data and then plot again
ds_piC = ds_piC.load()
ds_piC.salt.groupby('time.month').std('time', skipna=True).T.plot(vmin=0, vmax=0.04);
plt.gca().invert_yaxis();

Which gives:

1 Like

We tried using numpy yesterday and it looked fine. I’ve just tried using the dask.array.nanstd function and it also looks fine. Note, I couldn’t figure out how to do this (without the .groupby() function) without chucking it into a loop. Happy to be informed if there’s a better way :slight_smile:

# select one year of the dataset and copy to new array
std_array = ds_piC.sel(time=slice('0950','0950')).copy(deep=True)

# loop through months
for month in range(1,13,1):
    
    # calculate standard deviation of that particular month
    monthly_std_values = dask.array.nanstd(ds_piC.where(ds_piC.time.dt.month == month, drop=True).salt, axis=0)

    # save to array
    std_array['salt'][month-1,:] = monthly_std_values

# plot
std_array.salt.T.plot(vmin=0, vmax=0.04);
plt.gca().invert_yaxis();

Which gives a normal looking result:

Here’s a timeseries of two cells in the array. Red = bottom cell which ends up with nans. Blue = cell that doesn’t end up with nans. The values look fine to me.

ds_piC.salt.isel(st_ocean=-1).plot(c='r');
ds_piC.salt.isel(st_ocean=-4).plot(c='b');

But here’s a timeseries (climatology) of the array after I’ve calculated the standard deviation, showing the missing values.

ds_piC.salt.groupby('time.month').std('time', skipna=True).isel(st_ocean=-1).plot(c='r');
ds_piC.salt.groupby('time.month').std('time', skipna=True).isel(st_ocean=-4).plot(c='b');

Is this what you meant?

That looks ok - it must be something in xarray then. It could be something like flox, as Thomas mentioned, or it could be that xarray is calling dask.array.std instead of dask.array.nanstd (they behave differently when NaNs are in the data).
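
For example (a toy array, nothing to do with the model data), the two treat NaNs very differently:

import numpy as np
import dask.array

x = dask.array.from_array(np.array([1.0, 2.0, np.nan, 4.0]), chunks=2)

dask.array.std(x).compute()     # nan - std propagates missing values
dask.array.nanstd(x).compute()  # ~1.25 - nanstd ignores the missing value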

1 Like

Just a heads-up @hrsdawson, this has been tagged as outofscope for ACCESS-NRI support because it is not a product or release we are responsible for. That doesn’t stop anyone from the team, or indeed anyone else, suggesting possible remedies, and thanks for pointing out there may be an issue affecting others.

If it does look like a bug it would be good to craft a minimal complete reproducible example (this can also be useful for your own debugging). See here for some tips:

Good luck!

Thanks @Aidan for the heads up. Anton mentioned that was probably the case.

I’ll have a go at making a minimal reproducible example tomorrow. If I’m successful, I’ll probably raise an issue on the repo that @Thomas-Moore linked, unless there’s more updates here.

1 Like

If you can make an MCRE I'll try to give it a whirl. I couldn't test your code as I don't have access to e14.

If this is genuinely something wrong with xarray.groupby and/or flox, IMO the devs will really appreciate your MCRE and any issue you raise.

2 Likes

I think it's a problem with the precision of the calculation, and it's showing up in salinity and not temperature because the variance there is much smaller relative to the mean state. @hrsdawson does the problem go away if you subtract 1036 or some other reference value?
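
Something along these lines should be enough to test it (a sketch using the variable names from your snippets; subtracting a constant doesn't change the true standard deviation, so the plot limits can stay the same):

ref = 1036   # or whatever constant is close to the mean of the field
(ds_piC.salt - ref).groupby('time.month').std('time', skipna=True).T.plot(vmin=0, vmax=0.04);
plt.gca().invert_yaxis();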

This is my best attempt to replicate, without access to the data on e14. Does it look like your problem, Hannah? I think the NaNs are coming from trying to take the square root of negative variances (which shouldn't happen, but might if there are very small numbers and precision problems, as is implied by the flox RuntimeWarning).

import numpy as np
import xarray as xr

# In an imaginary universe with two months, and very small variability compared to the magnitude of numbers
test_data = xr.DataArray(
    np.arange(8) / 10000 + 1000, dims=('time',)
).assign_coords({'month': xr.DataArray(np.arange(8) % 2, dims=('time',))})

# using loaded data
test_data.groupby('month').std('time')
# Output [0.00022361, 0.00022361]

# using lazy computation/dask
test_data.chunk({'time':1}).groupby('month').std('time').load()
# Output [0.000224  , 0.00022374]

#[Output edited from original post to be internally consistent with test_data. I accidentally had the output and test_data copied from notebook cells of different iterations, and while the overall point was correct the numbers weren't internally consistent]

If this doesn’t fix the problem, I’m happy to have a poke at your data (if you can put a sample on x77 or /scratch/public?) or be corrected on an aspect I’ve missed. If it does address the issue, I’d be interested in spending some more time burrowing into why this is happening and raising it with the devs

2 Likes

Thanks for the offers @Thomas-Moore and @jemmajeffree. Adding an offset does not solve the problem. I think your example could be related, though as an aside I get different output values from you. Is that expected?

array([0.00022361, 0.00022361])
# then
array([0.000224  , 0.00022374])

I have been thinking along the same lines as you re the variance, but that doesn't explain (at least to me) why the calculation would work when loaded into memory. Do calculations involving dask use a different precision compared to calculations on data loaded into memory?

I have not been able to reproduce the errors in my attempts to create a minimal example. :frowning: I have tried creating small arrays with the same distribution as the original salinity fields, tried reading in these saved arrays using dask etc but the error doesn’t show up for me.

@jemmajeffree if you do feel inclined, here is a link to a gist that provides a short example (though not minimal and contingent on Gadi access) using the model data. I have copied two files needed to reproduce the error to the /scratch/public/hd4873/ directory.

It looks to me like an offset of -34 fixes the problem? I chose that value because the mean salinity is 34.19, so I'm trying to make the variances larger as a fraction of the mean.

Re your different answers, I think I just changed one cell and forgot to rerun the latter two. My answers now match yours, and I’ve updated my example to be internally consistent. Oops.

2 Likes

It does fix it! Does that mean the solution is to add appropriate offsets when calculating the standard deviation of an array that has small variance? Or should that behaviour not happen at all?

That behaviour definitely should not happen. I’d guess a generic fix is to alter your function to be:

(ds.salt.groupby('time.month')-ds.salt.groupby('time.month').mean('time')).groupby('time.month').std('time', skipna=True).T.plot(vmin=0, vmax=0.03);

and basically make sure you always subtract the mean before taking a standard deviation.
If I were to posit wilder guesses, maybe loaded/basic xarray is already doing this under the hood and flox is not? There'd be very few situations in which it actually makes a difference. But I haven't actually looked at the code and don't have the evidence to back this up, so I'm just speculating at the moment.

2 Likes

Good to know, thanks Jemma!!!

Are we sure this is right? I would expect the standard deviation to vary much more smoothly than this with depth?

Update for those following along at home:

@anton I can’t speak for the physical validity, but it’s consistent with calculating std without dask, and Hannah says it’s bound to be a bit noisy with only two months of data.

I’ve got an expanded working example that produces NaNs:

import numpy as np
import xarray as xr

l = 12000
np.random.seed(1)
test_data = xr.DataArray(
    np.random.uniform(0, 1, l) / 100 + 1000000, dims=('time',)
).assign_coords({'month': xr.DataArray(np.arange(l) % 12, dims=('time',))})

# xarray groupby and std
test_data.groupby('month').std('time')
# array([0.00283648, 0.00281895, 0.00287791, 0.00287652, 0.00287337,
#        0.00287037, 0.00289802, 0.00289441, 0.00285839, 0.00296478,
#        0.00284787, 0.00292089])

# using lazy computation/dask
test_data.chunk({'time':100}).groupby('month').std('time').load()
# array([0.01118034, 0.01118034, 0.01118034, 0.01581139, 0.        ,
#        0.01581139, 0.01118034, 0.01118034, 0.01118034,        nan,
#               nan, 0.        ])

As best I can tell, the issue only occurs with the use of groupby, and not std over the same values selected with isel. I’m planning on spinning up a separate environment to test whether it’s flox specifically or more generally groupby.
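
For reference, this is the kind of check I mean, selecting the same values by position with isel (using the test_data defined above):

# pick out the time steps belonging to month 0 by position, no groupby involved
month0 = np.flatnonzero(test_data.month.values == 0)
test_data.chunk({'time': 100}).isel(time=month0).std('time').load()
# comes out consistent with the eager groupby result, unlike the chunked groupby above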

I'm planning on raising the bug with the xarray/flox devs once I've dug a little deeper into the causes.

Edit:

Plus a summary of earlier points:

My best guess as to what's happening is that calling std after groupby on a dask array uses lower precision than on a numpy array, which produces noisy variances, and occasionally negative variances that give NaNs when square-rooted.
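
To illustrate the suspected mechanism (a toy example, not code taken from xarray or flox): a one-pass variance, var = mean(x**2) - mean(x)**2, loses almost all of its precision when the spread is tiny compared to the mean, and rounding error can even push it negative, whereas subtracting the mean first stays well behaved:

import numpy as np

x = 1000000 + np.random.uniform(0, 1, 1000) / 100   # huge mean, tiny spread

two_pass = np.mean((x - x.mean())**2)     # subtract the mean first: ~8.3e-06
one_pass = np.mean(x**2) - np.mean(x)**2  # naive one-pass formula: swamped by
                                          # rounding error, can even be negative
np.sqrt(one_pass)                         # RuntimeWarning and nan if negative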

The current workaround is to manually subtract the mean before calculating the standard deviation, so that less precision is needed in the standard deviation itself:

(ds.salt.groupby('time.month')-ds.salt.groupby('time.month').mean('time')).groupby('time.month').std('time', skipna=True)
5 Likes