Guide and example on how to use nested grids with DataTrees
This is an example of how to use PyDDA’s ability to handle nested grids using xarray DataTrees. In this example, we load pre-generated Cf/Radial grids for two radars (KTLX and KICT), each at two resolutions. The fine grids are higher-resolution grids that are contained within the coarser grids.
The DataTree structure that PyDDA follows is:
::

    root
    |---nest_0/radar_1
    |---nest_0/radar_2
    |---nest_0/radar_n
    |---nest_1/radar_1
    |---nest_1/radar_2
    |---nest_1/radar_m
Each member of this tree is itself a DataTree. PyDDA treats a node as
radar data when the node's name begins with radar_. The root node of
each grid level, in this example nest_0 and nest_1, stores the keyword
arguments to :code:`pydda.retrieval.get_dd_wind_field` as attributes of
that subtree. PyDDA uses the attributes at each level as the arguments for
that level's retrieval, so the user can vary the coefficients by grid level.
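The coefficients stored in those attributes weight the constraints that the retrieval minimizes. As a schematic summary only, assuming a weighted-sum form and using the coefficient names from the :code:`get_dd_wind_field` keyword arguments (this example sets Co, Cm, Cx, Cy, Cz, and Cmod), the cost function can be written as below. The exact definition of each term lives in PyDDA's cost-function code and is not reproduced here.

```latex
J(u, v, w) = C_o\,J_o + C_m\,J_m + J_s(C_x, C_y, C_z) + C_b\,J_b + C_v\,J_v
           + C_{mod}\,J_{mod} + C_{point}\,J_{point}
```

Here $J_o$ penalizes misfit to the observed radial velocities, $J_m$ the mass continuity residual, $J_s$ roughness of the wind field (weighted internally by $C_x$, $C_y$, $C_z$), $J_b$ departure from a background profile, $J_v$ the vertical vorticity equation residual, $J_{mod}$ departure from model fields such as HRRR, and $J_{point}$ departure from point observations. These correspond to the Jvel, Jmass, Jsmooth, Jbg, Jvort, Jmodel, and Jpoint columns in the solver's progress table.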
Calling :code:`pydda.retrieval.get_dd_wind_field_nested` makes PyDDA
perform the retrieval on the 0th grid (nest_0) first. It then performs the
retrieval on each subsequent grid level, using the previous nest as both the
horizontal boundary condition and the initialization for the retrieval in the
next nest. Finally, PyDDA updates the winds in the first grid by
nearest-neighbor interpolation of the latter grid onto the portion where the
inner and outer grid levels overlap.
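The final nearest-neighbor update can be pictured with a simplified NumPy sketch. It assumes a single horizontal level and regular 1D x/y coordinate axes in the same units for both grids; the function update_outer_from_inner and the toy grids are hypothetical illustrations of the idea, not PyDDA's implementation.

```python
import numpy as np


def update_outer_from_inner(outer, outer_x, outer_y, inner, inner_x, inner_y):
    """Return a copy of a 2D outer-grid field in which every point lying inside
    the inner grid's extent takes the value of the nearest inner-grid point.

    Hypothetical helper: one horizontal level, regular 1D coordinate axes.
    """
    updated = outer.copy()
    # Boolean masks selecting outer columns/rows that fall inside the inner domain.
    in_x = (outer_x >= inner_x.min()) & (outer_x <= inner_x.max())
    in_y = (outer_y >= inner_y.min()) & (outer_y <= inner_y.max())
    # For each selected outer coordinate, index of the nearest inner coordinate.
    ii = np.abs(inner_x[:, None] - outer_x[in_x][None, :]).argmin(axis=0)
    jj = np.abs(inner_y[:, None] - outer_y[in_y][None, :]).argmin(axis=0)
    # Overwrite the overlapping block with the nearest-neighbor inner values.
    updated[np.ix_(in_y, in_x)] = inner[np.ix_(jj, ii)]
    return updated


# Outer grid: 2 km spacing over 0-8 km; inner grid: 1 km spacing over 2-6 km.
outer_x = outer_y = np.arange(0.0, 10.0, 2.0)
inner_x = inner_y = np.arange(2.0, 7.0, 1.0)
outer_u = np.zeros((outer_y.size, outer_x.size))
# Give each inner point a distinct value (10 * y + x) so the mapping is visible.
inner_u = 10.0 * inner_y[:, None] + inner_x[None, :]

updated = update_outer_from_inner(outer_u, outer_x, outer_y, inner_u, inner_x, inner_y)
print(updated)
```

Only the outer points inside the inner domain change; points outside it keep the coarse-grid solution, which is why the inner grid must lie entirely within the outer one.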
PyDDA then returns the retrieved wind fields as the “u”, “v”, and “w” DataArrays inside the root node of each level, in this case nest_0 and nest_1.
## Do imports
import pydda
import matplotlib.pyplot as plt
import warnings
from xarray import DataTree
warnings.filterwarnings("ignore")
"""
We will load pregenerated grids for this case.
"""
test_coarse0 = pydda.io.read_grid(pydda.tests.get_sample_file("test_coarse0.nc"))
test_coarse1 = pydda.io.read_grid(pydda.tests.get_sample_file("test_coarse1.nc"))
test_fine0 = pydda.io.read_grid(pydda.tests.get_sample_file("test_fine0.nc"))
test_fine1 = pydda.io.read_grid(pydda.tests.get_sample_file("test_fine1.nc"))
"""
Initalize with a zero wind field. We have HRRR data already generated for this case inside
the example data files to provide a model constraint.
"""
test_coarse0 = pydda.initialization.make_constant_wind_field(
    test_coarse0, (0.0, 0.0, 0.0)
)
"""
Specify the retrieval parameters at each level
"""
kwargs_dict = dict(
    Cm=256.0,
    Co=1e-2,
    Cx=50.0,
    Cy=50.0,
    Cz=50.0,
    Cmod=1e-5,
    model_fields=["hrrr"],
    refl_field="DBZ",
    wind_tol=0.5,
    max_iterations=150,
    engine="scipy",
)
"""
Enforce equal times for each grid. This is required for the DataTree structure since time is an
inherited dimension.
"""
test_coarse1["time"] = test_coarse0["time"]
test_fine0["time"] = test_coarse0["time"]
test_fine1["time"] = test_coarse1["time"]
"""
Provide the overlying grid structure as specified above.
"""
tree_dict = {
    "/nest_0/radar_ktlx": test_coarse0,
    "/nest_0/radar_kict": test_coarse1,
    "/nest_1/radar_ktlx": test_fine0,
    "/nest_1/radar_kict": test_fine1,
}
tree = DataTree.from_dict(tree_dict)
tree["/nest_0/"].attrs = kwargs_dict
tree["/nest_1/"].attrs = kwargs_dict
"""
Perform the retrieval
"""
grid_tree = pydda.retrieval.get_dd_wind_field_nested(tree)
"""
Plot the coarse grid output and finer grid output
"""
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
pydda.vis.plot_horiz_xsection_quiver(
    grid_tree["nest_0"],
    ax=ax[0],
    level=5,
    cmap="ChaseSpectral",
    vmin=-10,
    vmax=80,
    quiverkey_len=10.0,
    background_field="DBZ",
    bg_grid_no=1,
    w_vel_contours=[1, 2, 5, 10],
    quiver_spacing_x_km=50.0,
    quiver_spacing_y_km=50.0,
    quiverkey_loc="bottom_right",
)
pydda.vis.plot_horiz_xsection_quiver(
    grid_tree["nest_1"],
    ax=ax[1],
    level=5,
    cmap="ChaseSpectral",
    vmin=-10,
    vmax=80,
    quiverkey_len=10.0,
    background_field="DBZ",
    bg_grid_no=1,
    w_vel_contours=[1, 2, 5, 10],
    quiver_spacing_x_km=50.0,
    quiver_spacing_y_km=50.0,
    quiverkey_loc="bottom_right",
)
plt.show()
## You are using the Python ARM Radar Toolkit (Py-ART), an open source
## library for working with weather radar data. Py-ART is partly
## supported by the U.S. Department of Energy as part of the Atmospheric
## Radiation Measurement (ARM) Climate Research Facility, an Office of
## Science user facility.
##
## If you use this software to prepare a publication, please cite:
##
## JJ Helmus and SM Collis, JORS 2016, doi: 10.5334/jors.119
Failed to import TF-Keras. Please note that TF-Keras is not installed by default when you install TensorFlow Probability. This is so that JAX-only users do not have to install TensorFlow or TF-Keras. To use TensorFlow Probability with TensorFlow, please install the tf-keras or tf-keras-nightly package.
This can be be done through installing the tensorflow-probability[tf] extra.
Welcome to PyDDA 2.2
If you are using PyDDA in your publications, please cite:
Jackson et al. (2020) Journal of Open Research Science
Detecting Jax...
Jax/JaxOpt are not installed on your system, unable to use Jax engine.
Detecting TensorFlow...
TensorFlow detected. Checking for tensorflow-probability...
Unable to load both TensorFlow and tensorflow-probability. TensorFlow engine disabled.
No module named 'tf_keras'
False
Calculating weights for radars 0 and 1
Calculating weights for radars 1 and 0
Calculating weights for models...
Starting solver
rmsVR = 16.213679825765688
Total points: 412470
The max of w_init is 0.0
Total number of model points: 2342081
Nfeval | Jvel | Jmass | Jsmooth | Jbg | Jvort | Jmodel | Jpoint | Max w
0|4115.2916| 0.0000| 0.0000| 0.0000| 0.0000|11243.5276| 0.0000| 0.0000