Hi there,
This question is very much xarray-related. However, it does not feel like a fit for the GitHub issue tracker, since I suspect the problem lies with me (I am not smart enough!) rather than with xarray. I also have a similar post on Stack Overflow, but no response so far. After 10+ failed attempts to solve the issue myself, I decided it was time to ask for help. Maybe someone has had a similar experience…
I am developing a relatively large model using xarray and therefore want to make use of chunks. Most of my operations run a lot faster when chunked, but there is one that keeps running (a lot) slower than its unchunked equivalent.
Take this DataArray as an example (the real one is a lot bigger):
import xarray as xr
import numpy as np

da = xr.DataArray(
    np.random.randint(low=-5, high=5, size=(10, 100, 100)),
    coords=[range(10), range(100), range(100)],
    dims=["time", "x", "y"],
).chunk(chunks={"x": 10, "y": 10})  # comment this out to opt out of chunking

my_max = 10

def sum_storage(arr):
    arr_new = xr.zeros_like(arr)
    for idx in range(1, len(arr)):
        # running sum: this step = previous step + input at this step
        arr_new[dict(time=idx)] = arr_new[dict(time=idx - 1)] + arr[dict(time=idx)]
        # cap the running sum at my_max
        arr_new[dict(time=idx)] = arr_new[dict(time=idx)].where(
            arr_new[dict(time=idx)] <= my_max, my_max
        )
    return arr_new

%time arr_storage = sum_storage(da)
I ran this unchunked and chunked. CPU times:
Unchunked: 0 ns. Chunked: 6.72 s
I have tried .rolling and np.apply_along_axis following other suggestions (e.g., Applying custom functions to rolling · pydata/xarray · Discussion #6247 · GitHub), as well as applying a ufunc to the dask array (Handling dask arrays); a sketch of that attempt is shown after the next snippet. I also called a function using xr.concat, e.g.:
xr.concat([my_function(da[dict(time=i)]) for i in my_list], dim=da.time)
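The apply_ufunc attempt looked roughly like this (simplified; capped_cumsum stands in for my actual helper, and the exact keyword arguments here follow the pattern from the docs rather than my real code):

import numpy as np
import xarray as xr

def capped_cumsum(a, cap=10):
    # apply_ufunc moves core dims to the last axis, so step along axis -1
    out = np.zeros_like(a)
    for t in range(1, a.shape[-1]):
        out[..., t] = np.minimum(out[..., t - 1] + a[..., t], cap)
    return out

result = xr.apply_ufunc(
    capped_cumsum,
    da,  # the chunked DataArray from above; "time" stays a single chunk
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    dask="parallelized",
    output_dtypes=[da.dtype],
)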
but all the solutions I looked at iterate over only one DataArray, whereas my example has two DataArrays that need to be iterated over at the same time. On top of that, one of the arrays has to be accessed at both idx and idx-1, and the value at idx can only be computed once the value at idx-1 is known.
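To spell out that dependency, this is the plain-numpy equivalent of what the loop computes (a minimal sketch, names mine), which any chunked solution would have to reproduce:

import numpy as np

def reference(a, cap):
    # out[0] = 0; out[t] = min(out[t - 1] + a[t], cap), stepping along "time" (axis 0)
    out = np.zeros_like(a)
    for t in range(1, a.shape[0]):
        out[t] = np.minimum(out[t - 1] + a[t], cap)
    return out

expected = reference(da.values, my_max)  # da and my_max as defined above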
Any suggestions?