Many of us use xarray and dask to analyse model output. When I was first introduced to dask a few years ago, I was given a snippet of code to set up a client, and I had absolutely no idea what it was doing. The general wisdom was that there weren’t any hard and fast rules about how to make dask work, and you just developed a gut feeling for what was a good idea over time.
I think it would be great to build a collection of experiences and advice about how best to use dask, both to share with incoming dask users and to help each other run more efficient computations. Some starting contributions from Claire Yung and me are below. Please add your wisdom.
For dask newbies:
From the dask dashboard (third tab on the left in jupyter lab), open “Task Stream”, “Progress” and “Worker Memory”. These can tell you how efficiently time is being used, which step is taking most of your computation, and whether something is slow or crashing because of memory. I think that even having them open in the background lets you develop a feel for what a “normal” computation should look like, and therefore when to look for problems.
For really big, memory-hungry computations, the dask task graph can get very large if every operation is applied lazily. When you trigger the computation, the scheduler can appear to hang and not do anything. It can be useful to force dask to execute part of the work and save the results as intermediate netcdf files using .to_netcdf(): see API Calculation with Xarray + Dask — CLEX CMS Blog for an example.
Chunking is nominally important (see Choosing good chunk sizes in Dask), but I’ve found that the time it takes to rechunk is usually much greater than the time you save.
When using dask in .py scripts submitted through the PBS scheduler on gadi, you need to put your code inside a main block (if __name__ == '__main__':) - see Batch Scripts — Parallel Programming in Climate and Weather for details and an example. You should also specify a working directory that dask can dump data into. Ideally set it to the local jobfs disk on the compute nodes; the link above has an example that does this using climtas.nci.GadiClient() (you could also use the /scratch disk on gadi, but that tends to be much slower). You will also need to add a PBS resource request for jobfs disk space, e.g. #PBS -l jobfs=200gb - see the same link for an example.
Sometimes it’s faster not to use dask. Often if you can fit something into memory you should. Experiment.
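As a rough illustration of the batch-script advice above, here is a minimal sketch that uses plain dask.distributed rather than climtas. It assumes the jobfs path is exposed through the PBS_JOBFS environment variable and falls back to the system temp directory elsewhere; the helper name spill_directory and the tiny worker counts are made up for the example.

```python
import os
import tempfile

import dask.array as da
from dask.distributed import Client


def spill_directory():
    """Prefer node-local jobfs (assumed to be exposed as $PBS_JOBFS on gadi)."""
    return os.environ.get("PBS_JOBFS", tempfile.gettempdir())


def main():
    # Small worker counts keep this sketch cheap; scale them to your PBS request.
    with Client(
        n_workers=2,
        threads_per_worker=1,
        local_directory=spill_directory(),  # where workers write spilled data
    ) as client:
        x = da.ones((1000, 1000), chunks=(250, 250))
        total = x.sum().compute()  # runs on the workers started above
        print(f"sum = {total}")


# Without this guard, each worker process would re-run the script on start-up.
if __name__ == "__main__":
    main()
```

The matching PBS script would still need the jobfs resource request mentioned above.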
For those of us trying to get things to work a little better:
Recently I was opening a large number of ACCESS-OM2-01 files to access data at one latitude for an entire model run. Dale Roberts suggested switching from a client with 4 workers and 7 threads each to a client with 28 single-threaded workers, which sped up the process by a factor of around four.
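For anyone wanting to try the same switch, the two layouts could be written something like this. It is only a sketch: client_kwargs and start_client are hypothetical helper names, and whether single-threaded workers are faster will depend on your workload.

```python
from dask.distributed import Client


def client_kwargs(total_cores, threads_per_worker):
    """Split a fixed number of cores into workers with the given thread count."""
    if total_cores % threads_per_worker != 0:
        raise ValueError("threads_per_worker must divide total_cores")
    return {
        "n_workers": total_cores // threads_per_worker,
        "threads_per_worker": threads_per_worker,
    }


few_fat_workers = client_kwargs(28, 7)    # 4 workers x 7 threads
many_thin_workers = client_kwargs(28, 1)  # 28 workers x 1 thread
print(few_fat_workers)
print(many_thin_workers)


def start_client(layout):
    """Start a local cluster with the chosen layout (call from a main block in scripts)."""
    return Client(**layout)
```

In a notebook you would then call start_client(many_thin_workers) and watch the dashboard to see whether it helps.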
file chunks are the minimum size you can read from the file at any time. Reading less than a single chunk can be wasteful if the unused data from that chunk will be read at a later time in the same calculation.
chunking can be very beneficial and massively speed up some data access patterns compared to contiguous storage. Chunk dimensions will generally be a trade-off to avoid some access patterns being particularly slow, at the cost of none of them being incredibly fast.
some chunking you have no control over: if your dataset is split into separate files along a coordinate, typically time, the maximum file chunk size along that coordinate will be limited to the size of the dimension of that coordinate in each file, e.g. monthly data stored with 3 months per file will have a maximum chunk size in time of 3.
dask chunking effectively specifies how much data is processed in each dask task.
the size of the task graph for a computation is inversely proportional to the dask chunk size. The smaller the dask chunks the larger the task graph, and vice versa.
there is usually no point having a dask chunk size smaller than the file chunk size.
if your dask chunk size is too small your task graph can become too large. This can create issues with excessive memory usage, causing the calculation to error. When starting a calculation the dask dashboards can take a while before they begin to show the calculation proceeding. Often this is the time taken to generate the task graph. If this time is excessive, or there is an out of memory error before the calculations begin, it is likely your task graph is too large. In this case you should increase the dask chunk size and try again.
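The link between chunk size and task graph size is easy to check on a toy array. The sketch below uses dask.array directly instead of a real model file, and graph_size is a hypothetical helper that simply counts the keys in the graph.

```python
import dask.array as da


def graph_size(collection):
    """Count the tasks (keys) in a dask collection's graph."""
    return len(collection.__dask_graph__())


shape = (2000, 2000)
small_chunks = da.ones(shape, chunks=(100, 100))    # 20 x 20 blocks
large_chunks = da.ones(shape, chunks=(1000, 1000))  # 2 x 2 blocks

# The same lazy calculation on both arrays; nothing is computed here.
small_result = (small_chunks * 2).mean()
large_result = (large_chunks * 2).mean()

print("blocks:", small_chunks.npartitions, "vs", large_chunks.npartitions)
print("tasks: ", graph_size(small_result), "vs", graph_size(large_result))
```

The exact task counts depend on the dask version, but the small-chunk graph should come out much larger.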
Also see my post from just last week here, which points to the following resources:
Parallel computing with Dask - in particular see Optimization tips at the bottom. This contains some suggestions on how to get dask to play nicely with operations such as groupby (for subtracting climatologies).
A key point here is that if chunks is not specified, no chunking will be done (if open_mfdataset is used then the chunk size will be the file size).
You can see the native chunking of variables in a netcdf file using ncdump -hs <filename>.
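To check from Python whether a variable is actually chunked, you can look at its .chunks attribute. The sketch below uses a small in-memory dataset as a stand-in for an opened file (the variable name tas and the sizes are invented); for real files the equivalent is passing chunks= to xr.open_dataset or xr.open_mfdataset.

```python
import numpy as np
import xarray as xr

# Stand-in for a file opened without chunks: plain numpy, no dask involved.
ds = xr.Dataset(
    {"tas": (("time", "lat", "lon"), np.zeros((12, 4, 5)))},
    coords={"time": np.arange(12)},
)
unchunked = ds["tas"].chunks  # None means the variable is not dask-backed

# Ask for chunks of 3 time steps; other dimensions stay as a single chunk.
chunked = ds.chunk({"time": 3})
time_chunks = chunked["tas"].chunks[0]

print("before:", unchunked)
print("after: ", chunked["tas"].chunks)
```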
I’ve found it useful and efficient to save some intermediate results that are expensive to calculate to .zarr stores using a for loop over sections of the datasets (.isel) and appending to the file as described here.
And just adding a personal gripe: I find that the way dask/xarray/jupyter/ARE fails when you do something wrong while experimenting is uninformative and frustrating. The system will often print 100s or 1000s of lines of error messages (that don’t really tell you what’s going on), sometimes continuing long after you’ve shut down and restarted the kernel. I also haven’t been able to suppress these messages using the standard warnings.simplefilter('ignore',*). Sometimes restarting the kernel is not enough and you have to restart the whole ARE session.
If anyone has hints about how these kinds of things can be addressed that would be super useful!
Thanks to Andrew Hicks at the Bureau for these hints: "I’ve found using distributed instead of the default dask scheduler helps to keep memory under control, as it will spill to disk if it uses more than is available.
I have also found plotting tries to access the data multiple times, as though it was a normal numpy array, which triggers re-reading and re-calculating multiple times, so I would call ds1_LTMD_mon.sm_pct.load() before you try to plot it."
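The load-before-plotting hint can be sketched as follows, using a random dask-backed array as a stand-in for ds1_LTMD_mon.sm_pct (the shape and chunking are invented). Note that .load() works in place and also returns the object.

```python
import dask
import dask.array as da
import numpy as np
import xarray as xr

# Stand-in for a lazily loaded variable such as sm_pct.
lazy = xr.DataArray(
    da.random.random((120, 50), chunks=(12, 50)),
    dims=("time", "x"),
    name="sm_pct",
)

series = lazy.mean("x")                     # still lazy: nothing computed yet
was_lazy = dask.is_dask_collection(series)

in_memory = series.load()                   # compute once, keep the numpy result
print(was_lazy, type(in_memory.data))

# Any later plot, e.g. in_memory.plot(), now reuses the values held in memory.
```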