Finding a way to iterate using the input of two xarray dataarrays when chunked

Hi there,
This question is very xarray-related. However, it does not quite fit the GitHub issue tracker, since I suspect the problem lies with me (I am not smart enough!) rather than with xarray. I also have a similar post on Stack Overflow, but no response so far. After 10+ failed attempts to solve the issue myself, I decided it was time to ask for some help. Maybe someone has had a similar experience…

I am developing a relatively large model using xarray and therefore want to make use of chunks. Most of my operations run a lot faster when chunked, but there is one that keeps running (a lot) slower than it does unchunked.

Take this dataarray as an example (the real one is a lot bigger)

import xarray as xr
import numpy as np

da = xr.DataArray(
    np.random.randint(low=-5, high=5, size=(10, 100, 100)),
    coords=[range(10), range(100), range(100)],
    dims=["time", "x", "y"],
).chunk(chunks={"x":10, "y":10}) # comment this to opt out of chunking

my_max = 10
def sum_storage(arr):
    arr_new = xr.zeros_like(arr)
    for idx in range(1,len(arr)):
        arr_new[dict(time=idx)] = arr_new[dict(time=idx-1)] + arr[dict(time=idx)]
    return arr_new

%time arr_storage = sum_storage(da)

I ran this unchunked and chunked. CPU times:
Unchunked: 0 ns. Chunked: 6.72 s

I have tried .rolling and np.apply_along_axis following other suggestions (e.g., Applying custom functions to rolling · pydata/xarray · Discussion #6247 · GitHub), and applying a ufunc to the dask array (Handling dask arrays). I also tried calling a function via xr.concat, e.g.:

xr.concat([my_function(da[dict(time=i)]) for i in my_list], dim=da.time)

but all the solutions I have looked at iterate over only one dataarray, whereas my example has two dataarrays that need to be iterated over at the same time. On top of that, one of the arrays has to be accessed at both idx and idx-1, and the value at idx needs to be derived from the value at idx-1 first.

Any suggestions?

Hi Rogier

Does xarray’s cumsum do what you need? It looks like you are trying to do a cumulative sum over time?

From your example, it is expected that unchunked performance would be better than splitting into chunks: the chunks here are too small to be worth the overhead of splitting the operation up. I guess you are running this at a larger scale in practice, and since your operation runs along the time dimension (each value of arr_new depends on the previous time step but not on adjacent values in x/y), the chunks should be as large as feasible in the time direction.
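
For example, something along these lines keeps the time dimension in a single chunk (the spatial chunk sizes here are only placeholders that you would need to tune for your real data):

# -1 means "put the whole dimension into one chunk"
da = da.chunk({"time": -1, "x": 50, "y": 50})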

In the example, arr_storage isn’t actually filled with values yet - it is still a lazy dask object. So the time you measured is the time to create the dask graph - essentially a plan of the work that will be done if the values of arr_storage are requested, e.g. through a .load(). The more important number is therefore the time taken by the .load().
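
For example, the two stages can be timed separately (assuming you are in IPython/Jupyter, as the %time magic suggests):

%time lazy = sum_storage(da)   # only builds the dask graph
%time result = lazy.load()     # actually computes the values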

The function sum_storage appears to operate on the whole dataset, and xarray doesn’t know how to operate on it in a “by-chunk” way. I didn’t work through the details, but I think xarray’s apply_ufunc() or map_blocks() could be used in this case to break the operation up.
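
For example, a rough (untested) sketch using map_blocks, assuming only x and y are chunked so that every block contains the full time axis, and reusing the sum_storage function from the first post:

# Each block is an in-memory DataArray holding the complete time series,
# so sum_storage can run its time loop block by block.
result = xr.map_blocks(sum_storage, da, template=da)
result = result.load()  # triggers the per-block computation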


Hi Rogier,

I would agree with Anton’s suggestion to use cumsum. I think the following code does the same thing as your sum_storage. It’s been tested on your example array, though it’s probably worth double-checking with a subset of your actual array.

def sum_storage_xr(arr):
    # Your sum_storage code discards the first value of da; 
    # not clear if this is deliberate or not but I've imitated this behaviour here using where
    return arr.where(arr.time>0,0).cumsum('time')

Many simple steps in xarray or numpy code will vectorise and use all available cores without needing to be chunked, so if your dataarray fits into memory it’ll be faster without chunks or dask.

The code I’ve listed is slightly slower than what you’ve written, but it works on a chunked dask array without loading if you do need to use chunks.


Thanks Anton and Jemma,
Those are excellent responses. However, I do not think I can use cumsum, because I did not explain things well enough in my initial example. So entirely my fault.

After the summing of the storage, a number of conditions need to be met, such as the value of arr_storage not being allowed to exceed my_max (see below). Hence, the cumsum solution is too simple.

da = xr.DataArray(
    np.random.randint(low=-5, high=5, size=(10, 100, 100)),
    coords=[range(10), range(100), range(100)],
    dims=["time", "x", "y"],
).chunk(chunks={"x":10, "y":10})

my_max = 10
def sum_storage(arr):
    arr_storage = xr.zeros_like(arr)
    for idx in range(1,len(arr)):
        # print(f'idx: {idx}')
        # Add this timestep's input to the running storage...
        arr_storage[dict(time=idx)] = arr_storage[dict(time=idx-1)] + arr[dict(time=idx)]
        # ...and cap the storage at my_max
        arr_storage[dict(time=idx)] = arr_storage[dict(time=idx)].where(arr_storage[dict(time=idx)] <= my_max, my_max)
    return arr_storage

%time arr_storage = sum_storage(da)

I am now thinking of manually chunking the parts myself once they outgrow memory. But that does not sound ideal, i.e. I don’t want to write a lot of manual code when a chunked solution may also be available…

Thanks for your time; I hope you can take another look?

Hi @rogierw,

From your last message, it appears the condition to be met is that all the values of your final object (arr_storage) need to stay at or below a constant value (my_max).

In your last script, you are performing the step to guarantee this condition within the for loop, through:

arr_storage[dict(time=idx)] = arr_storage[dict(time=idx)].where(arr_storage[dict(time=idx)] <= my_max, my_max)

However, since your my_max is a constant, this step can be performed after the for loop (which is basically a cumsum, as previously suggested in other replies to the post), without changing the result.

This means you can first use the cumsum function and then apply the condition:

arr_storage = arr.cumsum('time')
arr_storage = arr_storage.where(arr_storage <= my_max, my_max)

Hi Davide,
Thanks for your help. Unfortunately, your solution works for a simple example where all inputs of da are positive, but not if the numbers are negative.
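
For example, with my_max = 10 and made-up inflows of 8, 5 and -6, capping at every step gives a different result than clipping once at the end:

import numpy as np

inflow = [8, 5, -6]                 # made-up values, my_max = 10
store, capped = 0, []
for x in inflow:
    store = min(store + x, 10)      # cap at every step
    capped.append(store)
print(capped)                              # [8, 10, 4]
print(np.minimum(np.cumsum(inflow), 10))   # [ 8 10  7] -> clipping afterwards gives a different final value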

Again, my example was too simple and did not explain the problem well enough. I have now changed the definition of the array da to random integers between -5 and 5 in all the code above:

np.random.randint(low=-5, high=5, size=(1000))

Explaining the problem with an over-simplified example is becoming a headache! I sure hope that I am not wasting people’s time…

Kind regards,
Rogier

Hi @rogierw,

I am sorry but I might not be completely understanding your final objective.

Would you be able to explain, conceptually (without code), what you are trying to achieve?
What is your expected input data structure like?
What should your output data structure be like?
Please also try to be as general as possible about the input/output data you are expecting.

Thank you

Cheers
Davide

Hi Davide,
Thanks for being so patient and kind to help out.

I am building a rainfall recharge model as a simple water-bucket model. For each timestep, I am filling or draining a soil layer with water. If the soil layer is at field capacity (my_max), any amount over my_max drains to deeper layers, i.e., becomes recharge to groundwater.

My example above tried to keep things simple. It focuses on that layer of soil storage (i.e. arr_storage). I figured that if I have a solution for that, I can easily add the recharge condition later.

Focusing on that soil storage, this means the amount of water can never exceed my_max. So, if after a time step the water would exceed my_max, the excess is taken out. At the next time step, water can then either be taken out of the soil layer (by evaporation) or added to it (by rainfall).
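
Written out for a single cell and a single timestep (just a sketch with illustrative names, not the xarray implementation), the update I have in mind is roughly:

def bucket_step(storage_prev, flux, my_max):
    # flux = rainfall - evaporation for this timestep (illustrative names)
    raw = storage_prev + flux
    storage = min(raw, my_max)        # soil storage cannot exceed field capacity
    recharge = max(raw - my_max, 0)   # any excess drains to deeper layers (groundwater recharge)
    return storage, recharge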

Hope that gives a bit more clarity?

Kind regards,
Rogier

Are you just needing numpy.clip?
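
For the capping step itself, something like this should be equivalent to the where(...) construct (xarray’s DataArray.clip wraps the same operation):

# Should be equivalent to arr_storage.where(arr_storage <= my_max, my_max)
arr_storage = arr_storage.clip(max=my_max)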