Improving performance of a Python-based function

Hi,

How can I improve the performance of the following function by reducing the number of ‘for loops’ in it while keeping memory usage to a minimum? Currently, this function takes about 15 minutes for a 365-day, 2.5-deg global dataset input. I believe parallel processing could significantly reduce the computation time.
Any help or suggestions on improving its performance would be greatly appreciated.

The function reads a 4D xarray (ocean temperature) with dimensions [time x depth x lat x lon] and returns a 3D array [time x lat x lon] of the thermocline.

The function was originally written by @Dietmar_Dommenget in MATLAB, where it takes a similar amount of time. It has now been translated into Python as follows:

def thermocl(var):
    dz = 0.1
    ii = np.where(var.lev > 500)[0][0]
    xzh = np.arange(var.lev[0],var.lev[ii],dz)
    zz = var.lev[0:ii]
    tcline = np.full([len(var.time),len(var.latitude),len(var.longitude)],np.nan,dtype='float')
    for il, jl in itertools.product(range(len(var.longitude)), range(len(var.latitude))):
        if not np.isnan(var[:,:,jl,il]).all():
            for it in range(len(var.time)):
                temp = np.array(var[it,0:len(zz),jl,il])
                xth = np.interp(xzh,zz,temp)
                dtdzh = np.zeros_like(xth)
                dtdzh[1:len(xzh)] = np.diff(xth)/dz
                idx = dtdzh.argmin() # min -ve values (max gradient) achieved
                tcline[it,jl,il] = xzh[idx]
        else:
            tcline[:,jl,il] = np.nan
    return tcline

Hi Abhik,
I don’t think general code optimisation is in scope for us, but others might be able to make some suggestions.
You could look into using multiprocessing or similar. Using xarray functions might also speed things up (I’m too new to xarray to suggest anything more precise). From my experience in numpy, the best way to improve things was to get numpy vectorisation to do the work of the loop, using index expressions instead of loops, before resorting to multiprocessing or other parallel methods.
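
As one possible illustration of the vectorisation idea (a sketch only, not tested against the original function): every column shares the same depth levels, so the interpolation indices and weights can be computed once and then applied to many columns with a single index expression. The helper names `interp_weights` and `interp_columns` are hypothetical, the data are made up, and the sketch assumes depth is the last axis.

```python
import numpy as np

def interp_weights(z_coarse, z_fine):
    """Bracketing indices and linear weights for interpolating onto z_fine."""
    # Index of the first coarse level strictly greater than each fine level
    hi = np.searchsorted(z_coarse, z_fine, side="right")
    hi = np.clip(hi, 1, len(z_coarse) - 1)
    lo = hi - 1
    w = (z_fine - z_coarse[lo]) / (z_coarse[hi] - z_coarse[lo])
    return lo, hi, w

def interp_columns(temp, z_coarse, z_fine):
    """Linearly interpolate temp (..., depth) onto z_fine for all columns at once."""
    lo, hi, w = interp_weights(z_coarse, z_fine)
    return temp[..., lo] * (1.0 - w) + temp[..., hi] * w

# Tiny made-up example: 3 columns on 5 depth levels
z_coarse = np.array([0.0, 10.0, 20.0, 50.0, 100.0])
z_fine = np.arange(z_coarse[0], z_coarse[-1], 0.5)
temp = np.array([[20.0, 19.5, 19.0, 10.0, 9.0],
                 [25.0, 24.0, 15.0, 14.0, 13.0],
                 [18.0, 17.0, 16.0, 15.0, 5.0]])
fine = interp_columns(temp, z_coarse, z_fine)   # shape (3, len(z_fine))
```

The intermediate array holds (number of columns) x (number of refined levels) values, so for a large grid it would probably need to be applied to one block of columns at a time.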

A hint: when posting code you can use code blocks to get it to display correctly formatted (see the </> icon on the toolbar or enter three back ticks ` at the start and end of the code block)
eg: (hopefully I’ve got the indentation correct)

def thermocl(var):
    dz = 0.1
    ii = np.where(var.lev > 500)[0][0]
    xzh = np.arange(var.lev[0],var.lev[ii],dz)
    zz = var.lev[0:ii]
    tcline = np.full([len(var.time),len(var.latitude),len(var.longitude)],np.nan,dtype='float')
    for il, jl in itertools.product(range(len(var.longitude)), range(len(var.latitude))):
        if not np.isnan(var[:,:,jl,il]).all():
            for it in range(len(var.time)):
                temp = np.array(var[it,0:len(zz),jl,il])
                xth = np.interp(xzh,zz,temp)
                dtdzh = np.zeros_like(xth)
                dtdzh[1:len(xzh)] = np.diff(xth)/dz
                idx = dtdzh.argmin() # min -ve values (max gradient) achieved
                tcline[it,jl,il] = xzh[idx]
        else:
            tcline[:,jl,il] = np.nan
    return tcline

Hi @OwKal, Thanks for your reply and for helping me to post the code in a suitable format.
Vectorization consumes a huge amount of memory, and my Jupyter kernel runs out of memory even for a ‘large’ job. A middle path that balances memory use and speed may be useful. Any suggestions?

Here’s a starting idea. Note that rather than using loops, it uses operations that work on the whole array at once. I would guess that the interpolation is going to be the most expensive operation on a large dataset, so have a think about whether you can reduce or eliminate it.

Note that I haven’t tested whether this gives results equivalent to your original function; it is just a starting point.

def thermocl_xarray(ds):
    """
    Calculate the thermocline using whole-array operations

    Args:
        ds: xarray.DataArray with dimensions [time, lev, latitude, longitude]

    Returns:
        xarray.DataArray with dimensions [time, latitude, longitude] with the thermocline depth at that point
    """
    
    ii = np.where(ds.lev > 500)[0][0]
    dz = 0.1
    
    # Refined vertical levels
    Z = np.arange(ds.lev[0], ds.lev[ii], dz)
    
    # Interpolate data to refined levels
    ds_interp = ds[:,:ii,:,:].interp(lev=Z)
    
    # Calculate the derivative
    dxdz = ds_interp.diff('lev') / dz
    
    # Find the level where the derivative is minimized
    iso = dxdz.argmin('lev')
    
    # Return the depths at the minimum derivative
    return Z[iso]
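
On reducing or eliminating the interpolation, one possible direction: with linear interpolation the gradient between two native levels is constant, so (if I read the original correctly) the steepest gradient on the refined grid always falls inside the native layer with the steepest gradient. The sketch below assumes it is enough to identify that layer and report its mid-depth, which is not the same value the original function returns (a refined level inside that layer). `thermocline_native` is a hypothetical name, the input is a plain NumPy array, and NaN columns are not handled.

```python
import numpy as np

def thermocline_native(temp, lev, max_depth=500.0):
    """
    temp: array (time, lev, lat, lon); lev: 1-D depths in metres, increasing.
    Returns (time, lat, lon) mid-depth of the layer with the most negative dT/dz.
    """
    lev = np.asarray(lev, dtype=float)
    # Number of levels that are no deeper than max_depth
    n = np.searchsorted(lev, max_depth, side="right")
    t = np.asarray(temp)[:, :n]
    thickness = np.diff(lev[:n])
    # Gradient per native layer, broadcasting the thickness over time, lat, lon
    dtdz = np.diff(t, axis=1) / thickness[None, :, None, None]
    k = dtdz.argmin(axis=1)
    mid = 0.5 * (lev[:n - 1] + lev[1:n])
    return mid[k]
```

This never builds the refined array, so memory use stays close to the size of the input.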

Hi @Scott

Thanks, your function appears to be faster. However, I get an "All-NaN slice encountered" error from the .argmin statement in the second-to-last line.
My input xarray (ocean temperature) contains NaN values over the land points. I tried the following, without success.

ds_interp = ds[:,:ii,:,:].interp(lev=Z).where(ds != np.nan)

Those NaN points need to be masked so that the computation is applied only to the non-NaN points. Any suggestions?
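
A note on why the attempt above has no effect: NaN compares unequal to everything, including itself, so `ds != np.nan` is True everywhere and `.where(...)` keeps every value. The usual test is `np.isnan` (or `.notnull()` on an xarray object). A minimal NumPy illustration with made-up values:

```python
import numpy as np

a = np.array([1.0, np.nan, 3.0])

always_true = a != np.nan      # NaN is unequal to everything, itself included
valid = ~np.isnan(a)           # the test that actually finds the non-NaN values
```

Even with a correct mask, `.where` leaves NaN in the masked positions, so the land columns would still be all-NaN and would probably still trigger the same error.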

A quick and dirty fix is to replace the nans with zeros, which I think is

da = da.fillna(0)

You can then re-apply the mask once the computation is done
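
A minimal sketch of that fill-then-re-mask pattern, using plain NumPy and made-up data rather than the real dataset:

```python
import numpy as np

# Made-up data with dimensions (column, depth); the second column is land
temp = np.array([[20.0, 19.0, 10.0, 9.0],
                 [np.nan, np.nan, np.nan, np.nan]])

land = np.isnan(temp).all(axis=1)              # remember where the data were missing
filled = np.where(np.isnan(temp), 0.0, temp)   # NaN -> 0 so argmin has numbers to work on

idx = np.diff(filled, axis=1).argmin(axis=1).astype(float)
idx[land] = np.nan                             # re-apply the mask once the computation is done
```

One thing that may be worth checking: in a column that is only partly NaN (for example below the sea floor), the step from real temperatures down to the filled zeros looks like a very strong negative gradient and could be picked up as the thermocline.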

I finally managed to solve this as follows. Restricting the computation to the valid grid points only could perhaps make it even faster (it now takes about a minute for a 12-month 2.5-deg dataset). Many thanks, @Scott.

def thermocl_xarray(ds):
    """
    Calculate the thermocline using whole-array operations

    Args:
        ds: xarray.DataArray with dimensions [time, lev, latitude, longitude]

    Returns:
        xarray.DataArray with dimensions [time, latitude, longitude] with the thermocline depth at that point
    """
    
    ii = np.where(ds.lev > 500)[0][0]
    dz = 0.1

    # Create a mask, then replace nan in the input dataset by 0
    ocean = ds[:,0,:,:]
    ds    = ds.fillna(0)
    
    # Refined vertical levels
    Z = np.arange(ds.lev[0], ds.lev[ii], dz)
    
    # Interpolate data to refined levels
    ds_interp = ds[:,:ii,:,:].interp(lev=Z)
    
    # Calculate the derivative
    dxdz = ds_interp.diff('lev') / dz
    
    # Find the level where the derivative is minimized
    iso = dxdz.argmin('lev')
    
    # Return the depths at the minimum derivative
    return np.ma.masked_where(ocean.isnull(),Z[iso])
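
The idea of restricting the computation to the valid grid points could look something like the sketch below. It is only an illustration on plain NumPy arrays: `apply_to_ocean_columns` is a hypothetical helper, and it assumes a point is ocean if its surface level is not NaN at every time step.

```python
import numpy as np

def apply_to_ocean_columns(temp, func):
    """
    temp: array (time, lev, lat, lon) with NaN over land.
    func: maps an array (time, lev, n_columns) to (time, n_columns).
    Returns (time, lat, lon), NaN where the surface level is always NaN.
    """
    nt, nz, ny, nx = temp.shape
    ocean = ~np.isnan(temp[:, 0]).all(axis=0)   # (lat, lon) boolean mask
    columns = temp[:, :, ocean]                 # (time, lev, n_ocean): land is dropped here
    out = np.full((nt, ny, nx), np.nan)
    out[:, ocean] = func(columns)               # scatter the results back onto the grid
    return out
```

For example, `apply_to_ocean_columns(temp, lambda c: np.diff(c, axis=1).argmin(axis=1))` returns the index of the steepest layer for ocean points and NaN over land.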


@abhik Do you mind if I mark your last post as the solution? It’s for people who may have a similar issue and are looking at it in the future so they see quickly if there is a solution.

Hi Claire,

Yes, you can mark the last post as the solution. The function works well but it can be improved further.
Thanks for the input from the ACCESS Hive community, especially Scott.


Hi @Scott,

The above function works perfectly for datasets with 2.5 deg horizontal resolution. But the PBS job crashed for higher-resolution (1 deg) datasets because it exceeded its memory allocation, even with a request of 192 GB memory and ncpus = 8. Is there any way to reduce the memory load here? Thanks.

One solution is to use dask, which can reduce memory usage by cutting the job into a number of smaller tasks that are run in parallel.

There is an existing topic about improving the performance of xarray and dask that might be helpful.
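
The "smaller tasks" part of that idea can be sketched by hand, without dask and therefore without the parallelism: process a fixed number of time steps at a time and write each result into a preallocated output. `apply_in_time_blocks` is a hypothetical helper, and it assumes `func` returns a plain array with time as its first dimension.

```python
import numpy as np

def apply_in_time_blocks(data, func, block=12):
    """
    Apply func to `block` time steps at a time and collect the results.
    data: array-like (time, lev, lat, lon); func maps (t, lev, lat, lon) -> (t, lat, lon).
    """
    nt = data.shape[0]
    out = None
    for start in range(0, nt, block):
        stop = min(start + block, nt)
        # Only this slice and its intermediates need to be in memory at once
        result = np.asarray(func(data[start:stop]))
        if out is None:
            out = np.empty((nt,) + result.shape[1:], dtype=result.dtype)
        out[start:stop] = result
    return out
```

Peak memory then scales with the block size rather than with the full length of the time series, at the cost of running the blocks one after another.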

Hi @Aiden

For 100 years of monthly data (temp[1200,:,:,:]), I am planning to parallelize each year’s computation using dask. But something goes wrong in the following lines (please note I have never used dask before) and the job gets killed.
Could you please advise where I am going wrong? Thanks.

import dask

delayed_results = []
it           = 0
for iyr in range(100):
        thcline   = dask.delayed(thermocl_xarray)(temp[it:it+12,:,:,:]) 
        it  += 12
        delayed_results.append(thcline)
results     = dask.compute(*delayed_results)

tcline       = xr.DataArray(results,coords=[temp.time,temp.lat,temp.lon],
                                dims=['time','lat','lon'],
                                attrs  = {
                                '_FillValue': -9.e+33,
                                'long_name' : 'thermocline depth',
                                'units'     : 'm'
                                })
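
Separately from the job being killed, one detail in the snippet above may be worth checking: `dask.compute(*delayed_results)` returns a tuple with one entry per delayed call, so `results` would hold 100 arrays of shape (12, lat, lon) rather than one array of shape (1200, lat, lon), and would need joining along time before it matches `temp.time`. A sketch with made-up stand-in data (the grid size and values are arbitrary):

```python
import numpy as np

# Stand-ins for the per-year results: 100 masked blocks of shape (12, lat, lon)
nlat, nlon = 4, 8
results = tuple(np.ma.masked_invalid(np.full((12, nlat, nlon), float(i)))
                for i in range(100))

# Join the blocks along the time axis; np.ma.concatenate keeps the masks
joined = np.ma.concatenate(results, axis=0)   # shape (1200, nlat, nlon)
```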

It doesn’t look like you are using dask in parallel because you’re not using a dask.distributed client.

Take a look at an example COSIMA Recipe to see what is required to properly parallelise xarray computation with dask, e.g.

https://cosima-recipes.readthedocs.io/en/latest/DocumentedExamples/True_Zonal_Mean.html

If that looks too complex then I would suggest enquiring with NCI when they might provide this training again:

https://opus.nci.org.au/display/Help/Parallel+Python

If you are a member of CLEX then another option would be to attend one of their CodeBreak sessions and get some assistance there.
