How can I improve the performance of the following function by reducing the number of ‘for loops’ in it while keeping memory usage to a minimum? Currently, the function takes about 15 minutes for a 365-day, 2.5-degree global dataset. I believe parallel processing could significantly reduce the computation time.
Any help or suggestions on improving its performance would be greatly appreciated.

The function reads a 4D xarray (ocean temperature) with dimensions [time x depth x lat x lon] and returns a 3D array [time x lat x lon] of thermocline depth.

The function was originally written by @Dietmar_Dommenget in MATLAB, which also takes a similar time. It is now translated into Python as the following:

```python
import itertools
import numpy as np

def thermocl(var):
    dz = 0.1
    # Index of the first level deeper than 500 m
    ii = np.where(var.lev > 500)[0][0]
    # Refined vertical grid and the native levels above the cut-off
    xzh = np.arange(var.lev[0],var.lev[ii],dz)
    zz = var.lev[0:ii]
    tcline = np.full([len(var.time),len(var.latitude),len(var.longitude)],np.nan,dtype='float')
    for il, jl in itertools.product(range(len(var.longitude)), range(len(var.latitude))):
        if not np.isnan(var[:,:,jl,il]).all():
            for it in range(len(var.time)):
                temp = np.array(var[it,0:len(zz),jl,il])
                xth = np.interp(xzh,zz,temp)
                dtdzh = np.zeros_like(xth)
                dtdzh[1:len(xzh)] = np.diff(xth)/dz
                idx = dtdzh.argmin() # min -ve values (max gradient) achieved
                tcline[it,jl,il] = xzh[idx]
        else:
            tcline[:,jl,il] = np.nan
    return tcline
```

Hi Abhik,
I don’t think general code optimisation is in scope for us, but others might be able to make some suggestions.
You could look into multiprocessing or similar, and xarray's own functions might speed things up (I’m too new to xarray to suggest anything more precise). In my experience with numpy, the best way to improve things is to let numpy vectorisation do the work of the loop, using index expressions instead of loops, before resorting to multiprocessing or other parallel methods.
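
To illustrate the vectorisation point, here is a minimal sketch (not taken from the thread). It assumes the profiles have already been interpolated onto a regular vertical grid with spacing dz, and the function names are made up for the example. Both functions return the index of the most negative gradient for each profile, but the second lets numpy work along an axis instead of looping over profiles.

```python
import numpy as np

def steepest_index_loop(profiles, dz):
    """profiles: [n_profiles, n_levels]; one Python iteration per profile."""
    n, nz = profiles.shape
    out = np.empty(n, dtype=int)
    for i in range(n):
        grad = np.zeros(nz)
        grad[1:] = np.diff(profiles[i]) / dz
        out[i] = grad.argmin()
    return out

def steepest_index_vectorised(profiles, dz):
    """Same calculation, with the loop replaced by axis-wise operations."""
    grad = np.zeros_like(profiles, dtype=float)
    grad[:, 1:] = np.diff(profiles, axis=1) / dz
    return grad.argmin(axis=1)
```

The vectorised version holds the full gradient array in memory at once, which is the trade-off to keep in mind for large inputs.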

A hint: when posting code you can use code blocks to display it correctly formatted (see the </> icon on the toolbar, or enter three backticks at the start and end of the code block).
E.g. (hopefully I’ve got the indentation correct):

```python
import itertools
import numpy as np

def thermocl(var):
    dz = 0.1
    # Index of the first level deeper than 500 m
    ii = np.where(var.lev > 500)[0][0]
    # Refined vertical grid and the native levels above the cut-off
    xzh = np.arange(var.lev[0],var.lev[ii],dz)
    zz = var.lev[0:ii]
    tcline = np.full([len(var.time),len(var.latitude),len(var.longitude)],np.nan,dtype='float')
    for il, jl in itertools.product(range(len(var.longitude)), range(len(var.latitude))):
        if not np.isnan(var[:,:,jl,il]).all():
            for it in range(len(var.time)):
                temp = np.array(var[it,0:len(zz),jl,il])
                xth = np.interp(xzh,zz,temp)
                dtdzh = np.zeros_like(xth)
                dtdzh[1:len(xzh)] = np.diff(xth)/dz
                idx = dtdzh.argmin() # min -ve values (max gradient) achieved
                tcline[it,jl,il] = xzh[idx]
        else:
            tcline[:,jl,il] = np.nan
    return tcline
```

Hi @OwKal, thanks for your reply and for helping me to post the code in a suitable format.
The vectorization consumes a huge amount of memory, and my Jupyter kernel runs out of memory even for a ‘large’ job. A middle path that limits memory use while still speeding up the task would be useful. Any suggestions?

Here’s a starting idea. Note that rather than using loops, it uses operations that work on the whole array at once. I would guess that the interpolation is going to be the most expensive operation on a large dataset, so have a think about whether you can reduce or eliminate it.

Note that I’ve not tested whether this gives results equivalent to your original function; this is just a starting point.

```python
import numpy as np

def thermocl_xarray(ds):
    """
    Calculate the thermocline using whole-array operations
    Args:
        ds: xarray.DataArray with dimensions [time, lev, latitude, longitude]
    Returns:
        Array with dimensions [time, latitude, longitude] with the thermocline depth at that point
    """
    ii = np.where(ds.lev > 500)[0][0]
    dz = 0.1
    # Refined vertical levels
    Z = np.arange(ds.lev[0], ds.lev[ii], dz)
    # Interpolate data to refined levels
    ds_interp = ds[:,:ii,:,:].interp(lev=Z)
    # Calculate the derivative
    dxdz = ds_interp.diff('lev') / dz
    # Find the level where the derivative is minimized
    iso = dxdz.argmin('lev')
    # Return the depths at the minimum derivative
    return Z[iso]
```
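
One possible way to act on the suggestion to reduce or eliminate the interpolation, as a sketch only. Both np.interp and xarray's default interp are linear, so the gradient of the interpolated profile is constant within each layer between two native levels. The minimum gradient on the refined grid should therefore fall inside the native layer with the most negative finite-difference gradient, and that layer can be found without interpolating at all. The function name, the max_depth argument, and the choice to return the top of that layer are assumptions for illustration. The result agrees with the refined-grid version only to within the thickness of that layer, and NaN handling is simplified (any gradient involving a NaN is ignored).

```python
import numpy as np

def steepest_layer_top(temp, lev, max_depth=500.0):
    """
    temp: array [time, lev, lat, lon]; lev: 1-D depths in metres, increasing.
    Returns [time, lat, lon]: depth of the top of the native layer with the
    most negative vertical temperature gradient (NaN where no gradient exists).
    """
    lev = np.asarray(lev, dtype=float)
    # Same cut-off as in the thread: first level deeper than max_depth
    ii = np.where(lev > max_depth)[0][0]
    t = np.asarray(temp, dtype=float)[:, :ii]
    # Gradient of each native layer, shape [time, ii-1, lat, lon]
    thickness = np.diff(lev[:ii]).reshape(1, -1, 1, 1)
    grad = np.diff(t, axis=1) / thickness
    # Columns with no valid gradient at all (e.g. land points)
    no_data = np.isnan(grad).all(axis=1)
    # Replace NaN by +inf so argmin never meets an all-NaN column
    j = np.where(np.isnan(grad), np.inf, grad).argmin(axis=1)
    return np.where(no_data, np.nan, lev[j])
```

The intermediate array here has one entry per native layer, instead of one per 0.1 m step of the refined grid, so it should need far less memory than the interpolated version.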

Thanks, your function appears to be faster. However, I get an "All-NaN slice encountered" error from the .argmin statement in the second-to-last line.
My input xarray (ocean temperature) contains NaN values over land points. I tried the following but without success.

Finally, I managed to solve this as follows. Perhaps restricting the computation to only the valid grid points could make it even faster (it now takes about a minute for a 12-month, 2.5-degree dataset). Many thanks, @Scott.

```python
import numpy as np

def thermocl_xarray(ds):
    """
    Calculate the thermocline using whole-array operations
    Args:
        ds: xarray.DataArray with dimensions [time, lev, latitude, longitude]
    Returns:
        NumPy masked array with dimensions [time, latitude, longitude] with the thermocline depth at that point
    """
    ii = np.where(ds.lev > 500)[0][0]
    dz = 0.1
    # Create a mask, then replace nan in the input dataset by 0
    # (note: where valid water sits above filled values, the fill itself
    # can introduce a steep gradient)
    ocean = ds[:,0,:,:]
    ds = ds.fillna(0)
    # Refined vertical levels
    Z = np.arange(ds.lev[0], ds.lev[ii], dz)
    # Interpolate data to refined levels
    ds_interp = ds[:,:ii,:,:].interp(lev=Z)
    # Calculate the derivative
    dxdz = ds_interp.diff('lev') / dz
    # Find the level where the derivative is minimized
    iso = dxdz.argmin('lev')
    # Return the depths at the minimum derivative
    return np.ma.masked_where(ocean.isnull(),Z[iso])
```
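
On restricting the computation to valid grid points: a minimal NumPy sketch, assuming land columns are NaN at every time and depth. The helper name apply_on_ocean_points and the func argument are hypothetical; func stands in for whatever per-column calculation is needed and must map a [time, lev, npoints] array to a [time, npoints] array.

```python
import numpy as np

def apply_on_ocean_points(func, temp):
    """
    temp: array [time, lev, lat, lon] with NaN over land.
    func: callable mapping [time, lev, npoints] -> [time, npoints].
    Returns [time, lat, lon], NaN wherever a column has no finite value.
    """
    nt, nz, ny, nx = temp.shape
    # A column counts as ocean if it has at least one finite value
    ocean = np.isfinite(temp).any(axis=(0, 1))   # [lat, lon]
    # Boolean indexing collapses lat/lon into one axis of ocean points
    columns = temp[:, :, ocean]                  # [time, lev, npoints]
    out = np.full((nt, ny, nx), np.nan)
    out[:, ocean] = func(columns)
    return out
```

Only the ocean columns are passed to func, so both the work and the size of any intermediates scale with the number of ocean points instead of the full grid.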


clairecarouge
(Claire Carouge, ACCESS-NRI Land Modelling Team Lead)

@abhik Do you mind if I mark your last post as the solution? It’s for people who may have a similar issue and are looking at it in the future so they see quickly if there is a solution.

Yes, you can mark the last post as the solution. The function works well but it can be improved further.
Thanks for the input from the ACCESS Hive community, especially Scott.

The above function works perfectly for datasets with 2.5-degree horizontal resolution, but the PBS job crashed for a higher-resolution (1-degree) dataset because it exceeded its memory allocation, even with a memory request of 192 GB and ncpus = 8. Is there any way to reduce the memory load here? Thanks.
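
One simple option to consider before any parallel framework (a sketch, not something tested on this dataset) is to process the time axis in blocks, so that the large interpolated intermediate only exists for one block at a time. process_in_time_blocks is a hypothetical helper; func is any function that maps a [time, lev, lat, lon] block to a [time, lat, lon] array.

```python
import numpy as np

def process_in_time_blocks(func, temp, block=12):
    """
    Apply func to consecutive slices of `block` time steps and join the
    results along the time axis. Intermediates created inside func are
    released after each block, which bounds peak memory.
    """
    nt = temp.shape[0]
    pieces = [func(temp[start:start + block]) for start in range(0, nt, block)]
    return np.concatenate(pieces, axis=0)
```

If func returns NumPy masked arrays, np.ma.concatenate would be needed instead of np.concatenate to keep the masks.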

Aidan
(Aidan Heerdegen, ACCESS-NRI Release Team Lead)

One solution is to use dask, which can reduce memory usage by cutting the job into a number of smaller tasks that are run in parallel.
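
As a rough illustration of cutting the job into smaller tasks, here is a sketch using dask.array.map_blocks on a small synthetic array. It requires the dask package. The function min_gradient_level, the synthetic data, and the chunk size of 12 time steps are assumptions for the example, and the per-block calculation is deliberately simplified (native-level gradients, no NaN handling). No distributed client is set up, so this runs on dask's default local scheduler.

```python
import numpy as np
import dask.array as da

def min_gradient_level(block, lev):
    """block: NumPy array [time, lev, lat, lon] -> [time, lat, lon]."""
    grad = np.diff(block, axis=1) / np.diff(lev).reshape(1, -1, 1, 1)
    # Depth at the top of the layer with the most negative gradient
    return lev[grad.argmin(axis=1)]

# Small synthetic stand-in for the temperature data
rng = np.random.default_rng(0)
lev = np.array([5.0, 15.0, 30.0, 60.0, 120.0])
temp = rng.random((24, lev.size, 4, 5))

# Chunk along time only: each task sees 12 time steps and all levels
lazy = da.from_array(temp, chunks=(12, -1, -1, -1))
result = lazy.map_blocks(
    min_gradient_level,
    lev=lev,
    drop_axis=1,                       # the lev axis is removed by the function
    dtype=float,
    meta=np.array((), dtype=float),    # tell dask the output is a NumPy float array
)
tcline = result.compute()              # nothing is computed until this call
```

For data on disk, opening the file with the chunks argument of xr.open_dataset gives dask-backed arrays from the start, so the full dataset need not be loaded at once.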

There is an existing topic about improving the performance of xarray and dask that might be helpful.

For 100 years of monthly data (temp[1200,:,:,:]), I am planning to parallelize each year’s computation using dask. But something goes wrong in the following lines (please note I have never used dask) and the job gets killed.
Could you please advise where I am going wrong? Thanks.

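```python
import dask
import numpy as np
import xarray as xr

# temp (the 4D temperature DataArray) and thermocl_xarray are defined earlier
delayed_results = []
it = 0
for iyr in range(100):
    thcline = dask.delayed(thermocl_xarray)(temp[it:it+12,:,:,:])
    it += 12
    delayed_results.append(thcline)
results = dask.compute(*delayed_results)
# results is a tuple of 100 masked arrays of shape [12, lat, lon];
# join them along time so the shape matches the coordinates
tcline = xr.DataArray(np.ma.concatenate(results, axis=0).filled(np.nan),
                      coords=[temp.time,temp.lat,temp.lon],
                      dims=['time','lat','lon'],
                      attrs = {
                          '_FillValue': -9.e+33,
                          'long_name' : 'thermocline depth',
                          'units' : 'm'
                      })
```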

Aidan
(Aidan Heerdegen, ACCESS-NRI Release Team Lead)

It doesn’t look like you are using dask in parallel because you’re not using a dask.distributed client.

Take a look at an example COSIMA Recipe to see what is required to properly parallelise xarray computation with dask, e.g.