A reference for making dask work (faster)

Many of us use xarray and dask to analyse model output. When I was first introduced to dask a few years ago, I was given a snippet of code to set up a client, and I had absolutely no idea what it was doing. The general wisdom was that there weren’t any hard and fast rules about how to make dask work, and you just developed a gut feeling for what was a good idea over time.

I think it would be great to build a collection of experiences and advice about how best to use dask, both to share with incoming dask users and to help each other run more efficient computations. Some starting contributions from Claire Yung and me are below. Please add your wisdom!

For dask newbies:

  • From the dask dashboard (third tab on the left in jupyter lab), open “Task Stream”, “Progress” and “Worker Memory”. These can tell you how efficiently time is being used, what is taking up the majority of your computation, and whether the reason something is slow or crashing is memory related. I think that even having them open in the background lets you develop a feel for what a “normal” computation should look like, and therefore when to look for problems.

  • For very large, memory-intensive computations, the dask task graph can get very large if all operations are applied lazily. When the computation is submitted, the scheduler can appear to hang and not do anything. It can be useful to force dask to execute some operations and save the results as intermediate netCDF files using .to_netcdf(): see API Calculation with Xarray + Dask — CLEX CMS Blog for an example.

  • Chunking is nominally important (see Choosing good chunk sizes in Dask), but I’ve found that the time it takes to rechunk is usually much greater than the time you save.

  • When using dask in .py scripts submitted to the PBS scheduler on Gadi, you need to put the main part of the script inside a main block (if __name__ == '__main__':) - see Batch Scripts — Parallel Programming in Climate and Weather for details and an example. You should also specify a working directory that dask can spill data to. Ideally this should be the local jobfs disk on the compute nodes; the above link includes an example that does this using climtas.nci.GadiClient() (you could also use the /scratch disk on Gadi, but that tends to be much slower). You will also need to add a PBS resource request for jobfs disk space, e.g. #PBS -l jobfs=200gb - see the same link for an example.

  • Sometimes it’s faster not to use dask. Often if you can fit something into memory you should. Experiment.
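As a minimal sketch of the batch-script pattern from the Gadi bullet above, assuming dask.distributed is installed: the main block stops the worker processes (which dask starts by re-importing the script) from each trying to start a cluster of their own, and local_directory points dask's spill space at jobfs. Reading the jobfs path from the PBS_JOBFS environment variable is an assumption about the Gadi environment, and the fallback to the system temp directory only exists so the sketch also runs outside PBS. The helper name jobfs_dir and the worker counts are placeholders.

```python
import os
import tempfile

import dask.array as da
from dask.distributed import Client, LocalCluster


def jobfs_dir():
    """Return a scratch directory for dask to spill to (hypothetical helper).

    Assumption: on Gadi, PBS sets PBS_JOBFS to the job's local jobfs path.
    Outside PBS we fall back to the system temp directory.
    """
    return os.environ.get("PBS_JOBFS", tempfile.gettempdir())


def main():
    # Placeholder sizes: match n_workers * threads_per_worker to your PBS ncpus request.
    with LocalCluster(
        n_workers=2,
        threads_per_worker=1,
        local_directory=jobfs_dir(),
    ) as cluster, Client(cluster) as client:
        print("Dashboard:", client.dashboard_link)
        x = da.ones((2000, 2000), chunks=(500, 500))
        print("Sum:", x.sum().compute())


# Without this guard, each worker process re-importing the script would
# re-run the module top level and try to start a cluster of its own.
if __name__ == "__main__":
    main()
```

In a real job the PBS script would simply run python on this file; the computation inside main() is a stand-in for your own analysis.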

For those of us trying to get things to work a little better:

  • Recently I was opening a large number of ACCESS-OM2-01 files to access data at one latitude for an entire model run. Dale Roberts suggested switching from a client with 4 workers and 7 threads each to a client with 28 single-threaded workers, which sped up the process by a factor of around four.
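A sketch of the two client layouts being compared, assuming dask.distributed and a single 28-core node; make_client is a hypothetical helper, not part of dask. One plausible reason the single-threaded layout helped here is that xarray serialises netCDF/HDF5 reads with a lock inside each process, so threads in the same worker queue up behind each other when opening many files, while separate processes each have their own lock.

```python
from dask.distributed import Client, LocalCluster


def make_client(n_workers, threads_per_worker, processes=True, **kwargs):
    """Start a LocalCluster with the given layout and return a connected Client."""
    cluster = LocalCluster(
        n_workers=n_workers,
        threads_per_worker=threads_per_worker,
        processes=processes,
        **kwargs,
    )
    return Client(cluster)


# The two layouts compared above (both use 28 cores in total):
#   few multi-threaded workers:   make_client(n_workers=4, threads_per_worker=7)
#   many single-threaded workers: make_client(n_workers=28, threads_per_worker=1)
```

The trade-off is that more worker processes means more data passed between workers, so which layout is faster is worth testing per workload.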

My 2c worth on chunking:

There are typically two types of chunking to consider when you’re using netCDF (or zarr) data with xarray+dask: file chunking and dask chunking:

file chunking

  • netCDF data are “chunked” within a netCDF file if the file is compressed. See the NCAR guide on chunking for a detailed description.
  • Zarr-formatted data are also generally chunked
  • file chunks are the minimum size you can read from the file at any time. Reading less than a single chunk can be wasteful if the unused data from that chunk will be read at a later time in the same calculation
  • chunking can be very beneficial and massively speed-up some data access patterns compared to contiguous storage. Chunk dimensions will generally be a trade off to avoid some access patterns from being particularly slow, at the cost of none of them being incredibly fast
  • some chunking you have no control over: if your dataset is split into separate files along a coordinate, typically time, the maximum file chunksize along that coordinate will be limited to the size of the dimension of that coordinate in each file. e.g. monthly data stored with 3 months per file will have a maximum chunk size in time of 3.
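To make the trade-off concrete, here is a toy calculation (not a real I/O measurement) of how many bytes must be read when every file chunk touched by a selection has to be read in full. The array shape, chunk shape and 4-byte float32 item size are made up for illustration, and each dimension is assumed to be an exact multiple of its chunk size.

```python
import math


def bytes_read(file_chunks, selection, itemsize=4):
    """Bytes read from disk when every chunk overlapping `selection` is read whole.

    file_chunks: chunk length along each dimension.
    selection:   (start, stop) half-open index range along each dimension.
    """
    n_chunks = 1
    for (start, stop), chunk in zip(selection, file_chunks):
        first, last = start // chunk, (stop - 1) // chunk
        n_chunks *= last - first + 1
    return n_chunks * math.prod(file_chunks) * itemsize


def bytes_requested(selection, itemsize=4):
    """Bytes actually wanted by the selection."""
    return math.prod(stop - start for start, stop in selection) * itemsize


# A (time=120, lat=300, lon=360) float32 variable stored in (12, 100, 120) chunks.
chunks = (12, 100, 120)
one_map = ((0, 1), (0, 300), (0, 360))            # a single time step, all space
one_series = ((0, 120), (150, 151), (200, 201))   # all times at a single point

for name, sel in [("one map", one_map), ("one time series", one_series)]:
    ratio = bytes_read(chunks, sel) / bytes_requested(sel)
    print(f"{name}: read {bytes_read(chunks, sel):,} bytes for "
          f"{bytes_requested(sel):,} wanted ({ratio:,.0f}x)")
```

With these chunks a single map reads 12 times the bytes it needs and a single-point time series reads 12,000 times. In this toy model the map's cost grows with the time chunk length and the series' cost grows with the spatial chunk area; shrinking either multiplies the number of chunks that must be opened and decompressed, which is the other side of the trade-off.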

dask chunking

  • dask chunking is specified in the xarray.open_dataset function using the chunks argument.
  • dask chunking effectively specifies how much data is processed in each dask task
  • the size of the task graph for a computation is inversely proportional to the dask chunk size. The smaller the dask chunks the larger the task graph, and vice-versa
  • there is usually no point having a dask chunk size smaller than the file chunk size
  • if your dask chunk size is too small your task graph can become too large. This can create issues with excessive memory usage, causing the calculation to error. When starting a calculation the dask dashboards can take a while before they begin to show the calculation proceeding. Often this is the time taken to generate the task graph. If this time is excessive, or there is an out of memory error before the calculations begin, it is likely your task graph is too large. In this case you should increase the dask chunk size and try again.
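A small sketch, assuming only dask is installed, of the graph growing as chunks shrink, using a synthetic array rather than model output. The counts are for the unoptimised graph and vary between dask versions, so no exact numbers are given. aligned_chunk is a hypothetical helper that rounds a desired dask chunk length up to a whole number of file chunks, following the point above that dask chunks smaller than the file chunk gain nothing.

```python
import math

import dask.array as da


def n_tasks(chunk):
    """Number of tasks in the (unoptimised) graph for a simple reduction."""
    x = da.ones((4000, 4000), chunks=chunk)
    return len((x + 1).mean().__dask_graph__())


def aligned_chunk(file_chunk, target):
    """Round `target` up to a whole multiple of `file_chunk` (at least one)."""
    return max(1, math.ceil(target / file_chunk)) * file_chunk


for c in (2000, 500, 100):
    print(f"chunks of {c}x{c}: {n_tasks((c, c)):,} tasks")

print(aligned_chunk(file_chunk=12, target=50))  # 60: five 12-step file chunks
```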

Also see my post from just last week here, indicating the following resources:

A few other bits and pieces:

A key point here is that if chunks is not specified, no chunking will be done (if open_mfdataset is used then the chunk size will be the file size).

You can see the native chunking of variables in a netcdf file using ncdump -hs <filename>.

I’ve found it useful and efficient to save some intermediate results that are expensive to calculate to .zarr stores using a for loop over sections of the datasets (.isel) and appending to the file as described here.


And just adding a personal gripe: I find that the way dask/xarray/jupyter/ARE fails when you do something wrong while experimenting is uninformative and frustrating. The system will often print hundreds or thousands of lines of error messages (that don’t really tell you what’s going on), sometimes continuing long after you’ve shut down and restarted the kernel. These messages also can't be suppressed with the standard warnings.simplefilter('ignore'). Sometimes restarting the kernel is not enough and you have to restart the whole ARE session.

If anyone has hints about how these kinds of things can be addressed that would be super useful!
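One thing that may help with the volume of output: much of what dask.distributed prints is emitted through Python's logging module rather than the warnings module, which would explain why warnings.simplefilter has no effect on it. A minimal sketch, using only the standard library, that raises the threshold of the dask and distributed loggers (and any child loggers already created) so that only errors get through; quiet_dask_logs is a hypothetical helper. LocalCluster also accepts a silence_logs argument (a logging level) for a similar purpose. This won't hide genuine exception tracebacks, and whether it tames ARE's output in particular is untested.

```python
import logging


def quiet_dask_logs(level=logging.ERROR):
    """Only let messages at `level` or above through from dask's loggers."""
    # Collect the names first, because getLogger() can add entries to loggerDict.
    names = ["distributed", "dask"] + [
        name
        for name in logging.root.manager.loggerDict
        if name.startswith(("distributed.", "dask."))
    ]
    for name in names:
        logging.getLogger(name).setLevel(level)
```

Call it after importing dask.distributed, so that the child loggers it creates already exist.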


Thanks to Andrew Hicks at the Bureau for these hints: " I’ve found using distributed instead of the default dask scheduler helps to keep memory under control, as it will spill to disk if it uses more than is available:

from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=2, memory_limit="3GiB")
client = Client(cluster)

This limits each of the two workers to 3 GiB, so it shouldn't use more than about 6 GiB of memory during computation.

Any dask calls after that will use the cluster. You also get better metrics from the dashboard, whose address you can find by running:

cluster.dashboard_link

More info on distributed and LocalCluster params here: Python API — Dask documentation

I have also found plotting tries to access the data multiple times, as though it was a normal numpy array, which triggers re-reading and re-calculating multiple times, so I would call ds1_LTMD_mon.sm_pct.load() before you try to plot it."
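A small illustration of that last point, assuming xarray and dask; the array is synthetic and the name monthly_mean is a made-up stand-in for ds1_LTMD_mon.sm_pct. .load() computes the dask graph once and keeps the values in memory, so plotting code that touches the data repeatedly doesn't trigger the reading and computation again. With a distributed cluster, .persist() is the alternative when you want the result to stay chunked in the workers' memory.

```python
import dask.array as da
import numpy as np
import xarray as xr

# A lazy (dask-backed) result standing in for an expensive calculation.
lazy = xr.DataArray(
    da.random.random((120, 50, 60), chunks=(12, 50, 60)),
    dims=("time", "lat", "lon"),
)
monthly_mean = lazy.mean("time")
print(monthly_mean.chunks is not None)  # True: still lazy

# Compute once and hold the values in memory before plotting.
monthly_mean = monthly_mean.load()
print(monthly_mean.chunks)  # None: now numpy-backed

# Plotting calls such as monthly_mean.plot() now reuse the in-memory values.
```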


Furthermore, if you’re running really compute intensive stuff you can launch dynamic dask clusters on multiple Gadi nodes (and have them only active for the compute intensive part of the computation) as documented by NCI here: https://opus.nci.org.au/display/DAE/Setting+up+a+Dask+cluster+on+the+NCI+ARE


CLEX blog on the use of preprocess for making open_mfdataset work faster/better: More efficent use of xarray.open_mfdataset() — CLEX CMS Blog
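A sketch of the preprocess pattern, assuming xarray with dask and a netCDF backend; the file names, the variable t, the unused variable and the latitude band are made up. preprocess is applied to each file's dataset before the files are combined, so dropping variables or regions you don't need there keeps every per-file piece small before the combine step.

```python
import glob
import os
import tempfile

import numpy as np
import pandas as pd
import xarray as xr

# Write three made-up files, each holding 3 months and two variables.
tmp = tempfile.mkdtemp()
for i, start in enumerate(["2000-01", "2000-04", "2000-07"]):
    time = pd.date_range(start, periods=3, freq="MS")
    xr.Dataset(
        {
            "t": (("time", "lat"), np.random.rand(3, 90)),
            "unused": (("time", "lat"), np.random.rand(3, 90)),
        },
        coords={"time": time, "lat": np.linspace(-89, 89, 90)},
    ).to_netcdf(os.path.join(tmp, f"file_{i}.nc"))


def preprocess(ds):
    """Keep one variable and a latitude band from each file before combining."""
    return ds[["t"]].sel(lat=slice(-30, 30))


ds = xr.open_mfdataset(
    sorted(glob.glob(os.path.join(tmp, "file_*.nc"))),
    preprocess=preprocess,
    combine="by_coords",
)
print(list(ds.data_vars), ds.sizes["time"])  # ['t'] 9
```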


This looks like a great new resource for dask:

https://projectpythia.org/dask-cookbook/README.html


Linking in this very useful comment from @dougiesquire on chunking during and after open_mfdataset: Dask remove time chunks for Fourier transforms - #8 by dougiesquire


Hi all.

Did this a while ago but forgot to link it here: Introducing the dask-optimiser module — CLEX CMS Blog. It does some performance tuning stuff (worker settings, process binding, etc.), but most importantly, it correctly sets the link to the dask dashboard on ARE.


[EDIT: The code below works, but installing flox (conda install -c conda-forge flox) achieves the same end, probably in a more generic manner and definitely with less code. This package is already installed in analysis3-unstable, so others may not have run into the issue I was encountering anyway]

Subtracting climatologies from a dask array with data = data.groupby('time.month')-data.groupby('time.month').mean() changes the chunking to 1 in the time dimension, which can vastly slow down the performance of this supposedly simple procedure. The attached code wraps the whole operation up in xr.map_blocks, and subtracts a climatology with a single chunk in the time dimension. In some of my test cases, I got a speed-up by a factor of 100.

import xarray as xr


def strip_climatology(ds,
                      clim=None,
                      time_dim='time',
                      seasonal_dim='month',
                      ):
    '''
    Removes the climatology/seasonal variation of a dataset. This function should
    return the same as ds.groupby('time.month')-ds.groupby('time.month').mean(),
    but it does so without rechunking in the time dimension, and in some instances
    can be 100x faster.

    Parameters
    ----------
    ds : xr.DataArray
        dataarray to remove climatology from
        (probably works with an xr.Dataset, but I haven't tested this functionality rigorously)
    clim : None or xr.DataArray
        climatology of dataset, if already calculated. It must already be loaded
        into memory (e.g. with .compute()), because xr.map_blocks does not accept
        dask-backed objects in kwargs
    time_dim : str
        dimension over which to calculate the climatology
    seasonal_dim : str
        variable over which to calculate climatologies
        Together, time_dim and seasonal_dim (i.e. 'time.month') form what goes into the groupby

    Returns
    -------
    ds_anomaly : xr.DataArray
        ds, with the climatology removed
        should be chunked in the same way as ds (except for a single chunk along
        time_dim), and not expand the dask graph too much

    With MANY thanks to @rabernat from github for bringing this solution to my attention;
    this code is adapted from
    https://nbviewer.org/gist/rabernat/30e7b747f0e3583b5b776e4093266114
    '''

    def calculate_anomaly(ds, clim, time_dim, seasonal_dim):
        # Each block holds the full time series for part of the domain,
        # so the groupby mean in each block uses every time step.
        gb = ds.groupby(time_dim + '.' + seasonal_dim)
        if clim is None:
            clim = gb.mean()
        return gb - clim

    return xr.map_blocks(calculate_anomaly,
                         ds.chunk({time_dim: -1}),  # one chunk along the time dimension
                         kwargs={'clim': clim, 'time_dim': time_dim, 'seasonal_dim': seasonal_dim},
                         template=ds.assign_coords({seasonal_dim: ds[time_dim + '.' + seasonal_dim]})
                         )
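For reference, a sketch of the flox route mentioned in the EDIT, using synthetic monthly data (the array sizes and chunking are made up). xr.set_options(use_flox=True) asks xarray to use flox for groupby reductions when the package is installed, which is already the default in recent xarray versions, so in environments like analysis3-unstable the plain groupby expression may be enough. Whether the subtraction step then keeps the input's time chunking is worth checking on your own data, for example by printing .chunks as below.

```python
import dask.array as da
import pandas as pd
import xarray as xr

time = pd.date_range("2000-01-01", periods=120, freq="MS")  # 10 years of monthly data
data = xr.DataArray(
    da.random.random((120, 40, 50), chunks=(60, 20, 25)),
    dims=("time", "lat", "lon"),
    coords={"time": time},
)

with xr.set_options(use_flox=True):
    gb = data.groupby("time.month")
    anomaly = gb - gb.mean()

print(data.chunks[0])     # time chunks of the input
print(anomaly.chunks[0])  # time chunks after subtracting the climatology

# Each calendar month's anomaly should average to (numerically) zero.
# The "month" coordinate added by the subtraction is dropped before regrouping.
check = anomaly.drop_vars("month", errors="ignore").groupby("time.month").mean().compute()
print(float(abs(check).max()))
```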