Guide and example on how to use nested grids with DataTrees

This example shows how to use PyDDA’s ability to handle nested grids with xarray DataTrees. We load pre-generated Cf/Radial grids from two radars, each at two resolutions. The fine grids are higher-resolution grids contained within the coarser grids.

The DataTree structure that PyDDA follows is:

::

root
  |---nest_0/radar_1
  |---nest_0/radar_2
  |---nest_0/radar_n
  |---nest_1/radar_1
  |---nest_1/radar_2
  |---nest_1/radar_m

Each member of this tree is itself a DataTree. PyDDA recognizes a node as containing radar data when the node's name begins with ``radar_``. The root node of each grid level (in this example, nest_0 and nest_1) holds, as attributes, the keyword arguments passed to :code:`pydda.retrieval.get_dd_wind_field`. PyDDA uses the attributes at each level as the arguments for the retrieval at that level, so the user can vary the coefficients by grid level.
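The naming convention can be illustrated without PyDDA itself. The sketch below uses plain Python containers to mimic how node paths split into nest levels and how names beginning with ``radar_`` mark radar nodes. The helper ``group_radars_by_nest``, the extra ``notes`` node, and the modified ``Cx``/``Cy`` values are hypothetical and only serve the illustration; PyDDA performs its own traversal of the tree.

```python
from collections import defaultdict


def group_radars_by_nest(paths):
    """Group node paths such as '/nest_0/radar_ktlx' by nest level.

    Hypothetical helper for illustration only; not part of PyDDA.
    """
    nests = defaultdict(list)
    for path in paths:
        parts = path.strip("/").split("/")
        if len(parts) != 2:
            continue  # only handle <nest>/<node> paths in this sketch
        nest, node = parts
        # Only nodes whose name begins with "radar_" count as radar grids
        if node.startswith("radar_"):
            nests[nest].append(node)
    return dict(nests)


paths = [
    "/nest_0/radar_ktlx",
    "/nest_0/radar_kict",
    "/nest_1/radar_ktlx",
    "/nest_1/radar_kict",
    "/nest_1/notes",  # hypothetical non-radar node, ignored below
]
radars = group_radars_by_nest(paths)

# Per-level keyword arguments; the nest_1 overrides are assumed values
base_kwargs = dict(Cm=256.0, Co=1e-2, Cx=50.0, Cy=50.0, Cz=50.0)
level_kwargs = {
    "nest_0": dict(base_kwargs),
    "nest_1": dict(base_kwargs, Cx=25.0, Cy=25.0),
}

for nest in sorted(radars):
    print(nest, radars[nest], level_kwargs[nest]["Cx"])
```

Keeping one dictionary per level mirrors the idea of storing the retrieval arguments as attributes on each level's node, so a coefficient can change for one nest without affecting the others.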

Calling :code:`pydda.retrieval.get_dd_wind_field_nested` makes PyDDA perform the retrieval on the 0th grid first. It then performs the retrieval on each subsequent grid level, using the previous nest as both the horizontal boundary conditions and the initialization for the retrieval in the next nest. Finally, PyDDA updates the winds in the first grid by nearest-neighbor interpolation of the inner grid into the region where the inner and outer grid levels overlap.
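To make the final update step concrete, here is a minimal NumPy sketch of a nearest-neighbor update on a 2D field. It is not PyDDA's implementation: the function ``update_coarse_from_fine``, the grid spacings, and the field values are all assumptions chosen for illustration.

```python
import numpy as np


def update_coarse_from_fine(coarse, xc, yc, fine, xf, yf):
    """Overwrite coarse values inside the fine domain with the nearest fine value.

    coarse has shape (len(yc), len(xc)); fine has shape (len(yf), len(xf)).
    Illustrative sketch only; not PyDDA's implementation.
    """
    updated = coarse.copy()
    # Coarse coordinates that fall inside the extent of the fine grid
    in_x = (xc >= xf.min()) & (xc <= xf.max())
    in_y = (yc >= yf.min()) & (yc <= yf.max())
    # Index of the nearest fine coordinate for every coarse coordinate
    ix = np.abs(xf[None, :] - xc[:, None]).argmin(axis=1)
    iy = np.abs(yf[None, :] - yc[:, None]).argmin(axis=1)
    rows = np.where(in_y)[0]
    cols = np.where(in_x)[0]
    # Copy the nearest fine values into the overlapping coarse cells
    updated[np.ix_(rows, cols)] = fine[np.ix_(iy[rows], ix[cols])]
    return updated


xc = np.arange(0.0, 50.0, 10.0)  # coarse x: 0, 10, 20, 30, 40
yc = np.arange(0.0, 50.0, 10.0)
xf = np.arange(10.0, 31.0, 5.0)  # fine x: 10, 15, 20, 25, 30
yf = np.arange(10.0, 31.0, 5.0)

coarse_u = np.zeros((yc.size, xc.size))
fine_u = np.ones((yf.size, xf.size))
new_u = update_coarse_from_fine(coarse_u, xc, yc, fine_u, xf, yf)
print(new_u)
```

Only the coarse cells inside the fine domain change; cells outside the overlap keep the coarse retrieval's values.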

PyDDA then returns the retrieved wind fields as the “u”, “v”, and “w” DataArrays inside the root node of each level, in this case nest_0 and nest_1.

## Do imports
import pydda
import matplotlib.pyplot as plt
import warnings
from xarray import DataTree

warnings.filterwarnings("ignore")

"""
We will load pregenerated grids for this case.
"""
test_coarse0 = pydda.io.read_grid(pydda.tests.get_sample_file("test_coarse0.nc"))
test_coarse1 = pydda.io.read_grid(pydda.tests.get_sample_file("test_coarse1.nc"))
test_fine0 = pydda.io.read_grid(pydda.tests.get_sample_file("test_fine0.nc"))
test_fine1 = pydda.io.read_grid(pydda.tests.get_sample_file("test_fine1.nc"))

"""
Initialize with a zero wind field. HRRR data have already been generated for this case and are
included in the example data files to provide a model constraint.
"""
test_coarse0 = pydda.initialization.make_constant_wind_field(
    test_coarse0, (0.0, 0.0, 0.0)
)

"""
Specify the retrieval parameters at each level
"""
kwargs_dict = dict(
    Cm=256.0,
    Co=1e-2,
    Cx=50.0,
    Cy=50.0,
    Cz=50.0,
    Cmod=1e-5,
    model_fields=["hrrr"],
    refl_field="DBZ",
    wind_tol=0.5,
    max_iterations=150,
    engine="scipy",
)

"""
Enforce equal times for each grid. This is required for the DataTree structure since time is an
inherited dimension.
"""
test_coarse1["time"] = test_coarse0["time"]
test_fine0["time"] = test_coarse0["time"]
test_fine1["time"] = test_coarse1["time"]
"""

Provide the overlying grid structure as specified above.
"""
tree_dict = {
    "/nest_0/radar_ktlx": test_coarse0,
    "/nest_0/radar_kict": test_coarse1,
    "/nest_1/radar_ktlx": test_fine0,
    "/nest_1/radar_kict": test_fine1,
}

tree = DataTree.from_dict(tree_dict)
tree["/nest_0/"].attrs = kwargs_dict
tree["/nest_1/"].attrs = kwargs_dict

"""
Perform the retrieval
"""

grid_tree = pydda.retrieval.get_dd_wind_field_nested(tree)

"""
Plot the coarse grid output and finer grid output
"""

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
pydda.vis.plot_horiz_xsection_quiver(
    grid_tree["nest_0"],
    ax=ax[0],
    level=5,
    cmap="ChaseSpectral",
    vmin=-10,
    vmax=80,
    quiverkey_len=10.0,
    background_field="DBZ",
    bg_grid_no=1,
    w_vel_contours=[1, 2, 5, 10],
    quiver_spacing_x_km=50.0,
    quiver_spacing_y_km=50.0,
    quiverkey_loc="bottom_right",
)
pydda.vis.plot_horiz_xsection_quiver(
    grid_tree["nest_1"],
    ax=ax[1],
    level=5,
    cmap="ChaseSpectral",
    vmin=-10,
    vmax=80,
    quiverkey_len=10.0,
    background_field="DBZ",
    bg_grid_no=1,
    w_vel_contours=[1, 2, 5, 10],
    quiver_spacing_x_km=50.0,
    quiver_spacing_y_km=50.0,
    quiverkey_loc="bottom_right",
)

plt.show()
## You are using the Python ARM Radar Toolkit (Py-ART), an open source
## library for working with weather radar data. Py-ART is partly
## supported by the U.S. Department of Energy as part of the Atmospheric
## Radiation Measurement (ARM) Climate Research Facility, an Office of
## Science user facility.
##
## If you use this software to prepare a publication, please cite:
##
##     JJ Helmus and SM Collis, JORS 2016, doi: 10.5334/jors.119
Welcome to PyDDA 2.4.1
If you are using PyDDA in your publications, please cite:
Jackson et al. (2020) Journal of Open Research Science
Detecting Jax...
Jax/JaxOpt are not installed on your system, unable to use Jax engine.
Detecting TensorFlow...
Unable to load both TensorFlow and tensorflow-probability. TensorFlow engine disabled.
No module named 'tensorflow'
False
Calculating weights for radars 0 and 1
Calculating weights for radars 1 and 0
Calculating weights for models...
Starting solver 
rmsVR = 16.213679825765688
Total points: 412470
The max of w_init is 0.0
Total number of model points: 2342081
Nfeval | Jvel    | Jmass   | Jsmooth |   Jbg   | Jvort   | Jmodel  | Jpoint  | Max w  
      0|4115.2916|   0.0000|   0.0000|   0.0000|   0.0000|11243.5276|   0.0000|   0.0000