def test_plot_data():
    """Show one Voronoi map of dtec per time step for a fixed local h5parm.

    Opens a hard-coded h5parm read-only, selects solset ``sol000`` (first
    polarisation, antennas 0 and 50, first 100 times) and, for every selected
    time, plots the dtec at antenna-axis index 1 against the first two ENU
    components of the direction unit vectors.
    """
    datapack = DataPack(
        '/home/albert/data/gains_screen/data/L342938_DDS5_full_merged.h5',
        readonly=True)
    with datapack:
        datapack.current_solset = 'sol000'
        datapack.select(pol=slice(0, 1, 1), ant=[0, 50],
                        time=slice(0, 100, 1))
        tec_values, axes = datapack.tec
        tec_values = tec_values[0, ...]
        const_values, axes = datapack.const
        const_values = const_values[0, ...]
        tec_weights, axes = datapack.weights_tec
        tec_weights = tec_weights[0, ...]
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    # The ENU origin is the first selected antenna (tagged with the first time).
    origin = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])[0].earth_location

    for step, obs_time in enumerate(times):
        enu = ENU(obstime=obs_time, location=origin)
        directions = directions.transform_to(enu)
        k_enu = directions.cartesian.xyz.value.T
        axis = plot_vornoi_map(k_enu[:, 0:2], tec_values[:, 1, step])
        axis.set_title(obs_time)
        plt.show()
def itrs_to_enu_6D(X, ref_location=None):
    """
    Convert packed (time, ra, dec, ITRS x, y, z) rows into an ENU frame.

    :param X: float array [b0,...,bB,6]
        Rows ordered [time, ra, dec, itrs.x, itrs.y, itrs.z]. Every row must
        carry the same time value (divided by 86400 to obtain MJD).
    :param ref_location: float array [3] or None
        ITRS position (in ``dist_type`` units) used as the ENU origin.
        When None, the ITRS position of the first row is used.
    :return: float array [b0,...,bB, 7]
        Rows ordered [time, direction ENU (3), position ENU (3)].
    :raises ValueError: if X contains more than one distinct time.
    """
    distinct_times = np.unique(X[..., 0])
    if distinct_times.size > 1:
        raise ValueError("Times should be the same.")

    batch_shape = X.shape[:-1]
    rows = X.reshape((-1, 6))
    origin = rows[0, 3:] if ref_location is None else ref_location

    obstime = at.Time(distinct_times / 86400., format='mjd')
    origin_itrs = ac.ITRS(x=origin[0] * dist_type,
                          y=origin[1] * dist_type,
                          z=origin[2] * dist_type)
    enu = ENU(location=origin_itrs, obstime=obstime)

    position_itrs = ac.ITRS(x=rows[:, 3] * dist_type,
                            y=rows[:, 4] * dist_type,
                            z=rows[:, 5] * dist_type,
                            obstime=obstime)
    position_enu = position_itrs.transform_to(enu).cartesian.xyz.to(
        dist_type).value.T

    direction_icrs = ac.ICRS(ra=rows[:, 1] * angle_type,
                             dec=rows[:, 2] * angle_type)
    direction_enu = direction_icrs.transform_to(enu).cartesian.xyz.value.T

    packed = np.concatenate([rows[:, 0:1], direction_enu, position_enu],
                            axis=1)
    return packed.reshape(batch_shape + (7, ))
def visualisation(h5parm, ant=None, time=None):
    """Plot a grid of Voronoi dtec maps: one row per plotted antenna, one column per time.

    :param h5parm: path of the h5parm to read (solset ``sol000``).
    :param ant: antenna selection forwarded to ``DataPack.select``.
    :param time: time selection forwarded to ``DataPack.select``. This is a
        selector (None/slice/index), not an astropy Time.
    """
    with DataPack(h5parm, readonly=True) as dp:
        dp.current_solset = 'sol000'
        dp.select(ant=ant, time=time)
        dtec, axes = dp.tec
        dtec = dtec[0]
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])

    # The ENU frame needs an actual observation time. ``time`` is only the
    # selection argument, so use the first selected timestamp instead.
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    directions = directions.transform_to(frame)

    # Seconds since the first selected time (used for the column titles).
    t = times.mjd * 86400.
    t -= t[0]

    # NOTE(review): the first selected antenna is dropped from ``x``
    # (presumably the reference antenna), but ``dtec`` below is indexed from
    # 0, so the plotted rows are antenna indices 0..Na-1 of the selection,
    # not 1..Na. Confirm which is intended.
    x = antennas.cartesian.xyz.to(au.km).value.T[1:, :]
    k = directions.cartesian.xyz.value.T
    logger.info(f"Directions: {directions}")
    logger.info(f"Antennas: {x} {antenna_labels}")
    logger.info(f"Times: {t}")

    Na = x.shape[0]
    logger.info(f"Number of antenna to plot: {Na}")
    Nt = t.shape[0]

    fig, axs = plt.subplots(Na, Nt,
                            sharex=True,
                            sharey=True,
                            figsize=(2 * Nt, 2 * Na),
                            squeeze=False)
    for a in range(Na):
        for i in range(Nt):
            ax = plot_vornoi_map(k[:, 0:2], dtec[:, a, i], ax=axs[a][i],
                                 colorbar=False)
            # Only label the outer edges of the shared-axis grid.
            if a == (Na - 1):
                ax.set_xlabel(r"$k_{\rm east}$")
            if i == 0:
                ax.set_ylabel(r"$k_{\rm north}$")
            if a == 0:
                ax.set_title(f"Time: {int(t[i])} sec")
    plt.show()
def transform(X, ref_location=ref_location):
    """
    Convert the given coordinates from ITRS to ENU

    ``ref_antenna`` and ``ref_direction`` are read from the enclosing scope.
    Each one is appended to the output only when it is not None.

    :param X: float array [Nd, Na, 6]
        The coordinates are ordered [time, ra, dec, itrs.x, itrs.y, itrs.z].
        All rows must share one time value (divided by 86400 to get MJD).
    :param ref_location: float array [3] or None
        ITRS origin of the ENU frame in ``dist_type`` units. Defaults to the
        position in the first row.
    :return: float array [Nd, Na, 7(10(13))]
        [time, direction ENU (3), antenna ENU (3)], then the reference antenna
        in ENU (3) if ``ref_antenna`` is not None, then the reference
        direction in ENU (3) if ``ref_direction`` is not None.
    :raises ValueError: if X contains more than one distinct time.
    """
    time = np.unique(X[..., 0])
    if time.size > 1:
        raise ValueError("Times should be the same.")
    shape = X.shape[:-1]
    X = X.reshape((-1, 6))
    if ref_location is None:
        ref_location = X[0, 3:6]
    obstime = at.Time(time / 86400., format='mjd')
    location = ac.ITRS(x=ref_location[0] * dist_type,
                       y=ref_location[1] * dist_type,
                       z=ref_location[2] * dist_type)
    enu = ENU(location=location, obstime=obstime)

    antennas = ac.ITRS(x=X[:, 3] * dist_type,
                       y=X[:, 4] * dist_type,
                       z=X[:, 5] * dist_type,
                       obstime=obstime)
    antennas = antennas.transform_to(enu).cartesian.xyz.to(
        dist_type).value.T
    directions = ac.ICRS(ra=X[:, 1] * angle_type, dec=X[:, 2] * angle_type)
    directions = directions.transform_to(enu).cartesian.xyz.value.T
    result = np.concatenate([X[:, 0:1], directions, antennas], axis=1)

    # The references are only evaluated when present. Previously both were
    # built unconditionally, so a None reference raised TypeError and the
    # ``is not None`` guard could never take effect.
    if ref_antenna is not None:
        ref_ant = ac.ITRS(x=ref_antenna[0] * dist_type,
                          y=ref_antenna[1] * dist_type,
                          z=ref_antenna[2] * dist_type,
                          obstime=obstime)
        ref_ant = ref_ant.transform_to(enu).cartesian.xyz.to(
            dist_type).value.T
        result = np.concatenate(
            [result, np.tile(ref_ant, (result.shape[0], 1))], axis=-1)
    if ref_direction is not None:
        ref_dir = ac.ICRS(ra=ref_direction[0] * angle_type,
                          dec=ref_direction[1] * angle_type)
        ref_dir = ref_dir.transform_to(enu).cartesian.xyz.value.T
        result = np.concatenate(
            [result, np.tile(ref_dir, (result.shape[0], 1))], axis=-1)
    return result.reshape(shape + result.shape[-1:])
def test_tomographic_kernel():
    """Draw and plot one dtec sample from a TomographicKernel prior.

    Builds an example datapack, forms (antenna, direction) coordinates for
    antenna index 50 in the ENU frame of the first antenna, evaluates the
    kernel covariance (RBF field, layer bottom 200, width 50), shows it, and
    plots a sample ``L @ z`` (``L`` = Cholesky of the jittered covariance)
    over the east/north direction plane.
    """
    datapack = make_example_datapack(500, 24, 1, clobber=True)
    with datapack:
        datapack.current_solset = 'sol000'
        datapack.select(pol=slice(0, 1, 1), ant=slice(0, None, 1))
        tec_values, axes = datapack.tec
        tec_values = tec_values[0, ...]
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    first_time = times[0]
    itrs_antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=first_time)
    enu = ENU(obstime=first_time, location=itrs_antennas[0].earth_location)
    enu_antennas = itrs_antennas.transform_to(enu)
    enu_directions = directions.transform_to(enu)

    antenna_km = enu_antennas.cartesian.xyz.to(au.km).value.T
    k = enu_directions.cartesian.xyz.value.T
    coords = make_coord_array(antenna_km[50:51, :], k)
    origin_km = enu_antennas[0].cartesian.xyz.to(au.km).value
    print(k.shape)

    kernel = TomographicKernel(origin_km, origin_km, RBF(), S_marg=25)
    covariance_fn = jit(lambda Y: kernel(
        Y, Y, bottom=200., width=50.,
        fed_kernel_params=dict(l=7., sigma=1.)))
    K = covariance_fn(jnp.asarray(coords))
    plt.imshow(K)
    plt.colorbar()
    plt.show()

    # Small diagonal jitter keeps the Cholesky factorisation stable.
    chol = jnp.linalg.cholesky(K + 1e-6 * jnp.eye(K.shape[0]))
    print(chol)
    white = random.normal(random.PRNGKey(24532), shape=(K.shape[0], ))
    dtec = chol @ white
    print(jnp.std(dtec))

    ax = plot_vornoi_map(k[:, 0:2], dtec)
    ax.set_xlabel(r"$k_{\rm east}$")
    ax.set_ylabel(r"$k_{\rm north}$")
    ax.set_xlim(-0.1, 0.1)
    ax.set_ylim(-0.1, 0.1)
    plt.show()
def test_enu():
    """Check the ENU frame against AltAz for a small synthetic array.

    Ten antenna positions are scattered (unit-normal jitter in ``dist_type``
    units) around a fixed point on the Earth's surface. The test asserts:

    * every antenna lies within 100 ``dist_type`` of the first one in ENU;
    * ENU -> AltAz -> ENU and AltAz -> ENU -> AltAz round trips are identities;
    * the ENU north/east/up unit vectors coincide with AltAz
      (az=0, alt=0) / (az=90, alt=0) / (az=0, alt=90), in both directions,
      both dimensionless and with ``dist_type`` units.
    """
    # Build a genuine (10, 3) array near the Earth's surface so that the
    # geodetic conversion of the reference location is well defined.
    surface_point = ac.EarthLocation.from_geodetic(lon=6.87 * au.deg,
                                                   lat=52.91 * au.deg,
                                                   height=0. * au.m)
    centre = np.array(
        [q.to(dist_type).value for q in surface_point.to_geocentric()])
    rng = np.random.default_rng(0)
    antennas = centre[None, :] + rng.normal(size=[10, 3])

    obstime = at.Time("2018-01-01T00:00:00.000", format='isot')
    location = ac.ITRS(x=antennas[0, 0] * dist_type,
                       y=antennas[0, 1] * dist_type,
                       z=antennas[0, 2] * dist_type)
    enu = ENU(obstime=obstime, location=location.earth_location)
    altaz = ac.AltAz(obstime=obstime, location=location.earth_location)
    array_itrs = ac.ITRS(x=antennas[:, 0] * dist_type,
                         y=antennas[:, 1] * dist_type,
                         z=antennas[:, 2] * dist_type,
                         obstime=obstime)

    def with_units(coord):
        return coord.cartesian.xyz.to(dist_type).value

    def unitless(coord):
        return coord.cartesian.xyz.value

    array_enu = array_itrs.transform_to(enu)
    array_altaz = array_itrs.transform_to(altaz)
    assert np.all(np.linalg.norm(with_units(array_enu), axis=0) < 100.)
    assert np.all(
        np.isclose(with_units(array_enu),
                   with_units(array_enu.transform_to(altaz).transform_to(enu))))
    assert np.all(
        np.isclose(
            with_units(array_altaz),
            with_units(array_altaz.transform_to(enu).transform_to(altaz))))

    def check_pair(enu_coord, altaz_coord, values):
        # ENU -> AltAz must match the AltAz description and vice versa.
        assert np.all(
            np.isclose(values(enu_coord.transform_to(altaz)),
                       values(altaz_coord)))
        assert np.all(
            np.isclose(values(enu_coord),
                       values(altaz_coord.transform_to(enu))))

    # (east, north, up) unit vector paired with its (az, alt) in degrees.
    cardinal_axes = [
        ((0., 1., 0.), 0, 0),  # north
        ((1., 0., 0.), 90, 0),  # east
        ((0., 0., 1.), 0, 90),  # up
    ]
    for (east, north, up), az, alt in cardinal_axes:
        check_pair(
            ac.SkyCoord(east=east, north=north, up=up, frame=enu),
            ac.SkyCoord(az=az * au.deg, alt=alt * au.deg, distance=1.,
                        frame=altaz),
            unitless)
        check_pair(
            ac.SkyCoord(east=east * dist_type,
                        north=north * dist_type,
                        up=up * dist_type,
                        frame=enu),
            ac.SkyCoord(az=az * au.deg,
                        alt=alt * au.deg,
                        distance=1. * dist_type,
                        frame=altaz),
            with_units)
def test_compare_with_forward_model():
    """Compare a Monte-Carlo TEC covariance with TomographicKernel's analytic one.

    For antenna index 20 and every direction of a small example datapack,
    rays are sampled through a layer of height ``bottom`` to
    ``bottom + width`` (measured along the up axis relative to the reference
    antenna). 3000 realisations of an M12 field are drawn on those ray points
    and integrated along each ray. Their empirical covariance is plotted next
    to the covariance returned by ``TomographicKernel(..., compute_tec=True)``.
    This is a visual check only; nothing is asserted.
    """
    dp = make_example_datapack(5, 24, 1, clobber=True)
    with dp:
        select = dict(pol=slice(0, 1, 1), ant=slice(0, None, 1))
        dp.current_solset = 'sol000'
        dp.select(**select)
        tec_mean, axes = dp.tec
        tec_mean = tec_mean[0, ...]
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])
    # Express antennas, directions and the Earth's centre in the ENU frame
    # of the first antenna at the first time.
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    ref_ant = antennas[0]
    earth_centre = ac.ITRS(x=0 * au.m, y=0 * au.m, z=0. * au.m, obstime=times[0])
    frame = ENU(obstime=times[0], location=ref_ant.earth_location)
    antennas = antennas.transform_to(frame)
    earth_centre = earth_centre.transform_to(frame)
    ref_ant = antennas[0]
    directions = directions.transform_to(frame)
    # A single antenna (index 20) combined with every direction.
    x = antennas.cartesian.xyz.to(au.km).value.T[20:21, :]
    k = directions.cartesian.xyz.value.T
    X = make_coord_array(x, k)
    x0 = ref_ant.cartesian.xyz.to(au.km).value
    earth_centre = earth_centre.cartesian.xyz.to(au.km).value
    # Layer geometry and field hyper-parameters (distances in km, matching
    # the .to(au.km) conversions above).
    bottom = 200.
    width = 50.
    l = 10.
    sigma = 1.
    fed_kernel_params = dict(l=l, sigma=sigma)
    # Number of integration intervals per ray for the Monte-Carlo estimate.
    S_marg = 1000
    fed_kernel = M12()

    def get_points_on_rays(X):
        # X is one coordinate row: [x (3), k (3)].
        x = X[0:3]
        k = X[3:6]
        # Ray parameters where x + s*k reaches up-height bottom and
        # bottom + width above the reference antenna's up coordinate.
        smin = (bottom - (x[2] - x0[2])) / k[2]
        smax = (bottom + width - (x[2] - x0[2])) / k[2]
        ds = (smax - smin)
        _x = x + k * smin
        _k = k * ds
        t = jnp.linspace(0., 1., S_marg + 1)
        # S_marg + 1 equally spaced points along the in-layer segment, and the
        # step in the ray parameter between them.
        return _x + _k * t[:, None], ds / S_marg

    points, ds = vmap(get_points_on_rays)(X)
    points = points.reshape((-1, 3))
    # Visual check of the sampled points (north vs up).
    plt.scatter(points[:, 1], points[:, 2], marker='.')
    plt.show()
    K = fed_kernel(points, points, l, sigma)
    plt.imshow(K)
    plt.show()
    # Draw 3000 field realisations on the ray points (jitter for stability).
    L = jnp.linalg.cholesky(K + 1e-6 * jnp.eye(K.shape[0]))
    Z = L @ random.normal(random.PRNGKey(1245214), (L.shape[0], 3000))
    # Group samples per ray, then sum times step size: a Riemann-sum
    # integral along each ray.
    Z = Z.reshape((directions.shape[0], -1, Z.shape[1]))
    Y = jnp.sum(Z * ds[:, None, None], axis=1)
    # Empirical covariance across realisations (mean of outer products; no
    # sample mean is subtracted).
    K = jnp.mean(Y[:, None, :] * Y[None, :, :], axis=2)
    plt.imshow(K)
    plt.colorbar()
    plt.title("Directly Computed TEC Covariance")
    plt.show()
    # Analytic counterpart for the same coordinates.
    kernel = TomographicKernel(x0, earth_centre, fed_kernel, S_marg=200, compute_tec=True)
    X1 = GeodesicTuple(x=X[:, 0:3], k=X[:, 3:6], t=jnp.zeros_like(X[:, :1]), ref_x=x0)
    print(X1)
    K = kernel(X1, X1, bottom, width, fed_kernel_params)
    plt.imshow(K)
    plt.colorbar()
    plt.title("Analytic TEC Covariance")
    plt.show()
    print("Analytic TEC Covariance", K)
def visualise_grid(self):
    """Plot TEC over the antenna layout per direction, then per antenna over the sky.

    Reads ``self._input_datapack`` (solset ``sol000``, first polarisation,
    first time).

    1. For directions 0-4, scatter each antenna's ENU east/north position (km),
       coloured by TEC.
    2. For antennas 0, 50, 150, 200 and 250, show a two-panel figure:
       - top: the layout coloured by the TEC of the last direction from
         step 1, with the chosen antenna marked;
       - bottom: a Voronoi map of that antenna's TEC over (RA, Dec) in degrees.
    """

    def load(**selection):
        # Read TEC (polarisation axis dropped), directions, antennas and
        # times for one selection of the input datapack.
        with DataPack(self._input_datapack, readonly=True) as dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1), time=0, **selection)
            tec, axes = dp.tec
            _, directions = dp.get_directions(axes['dir'])
            _, antennas = dp.get_antennas(axes['ant'])
            _, times = dp.get_times(axes['time'])
        return tec[0], directions, antennas, times

    def to_enu_km(antennas, ref_location, obstime):
        # Antenna positions in the ENU frame about ``ref_location``, in km.
        frame = ENU(location=ref_location, obstime=obstime)
        enu = ac.ITRS(*antennas.cartesian.xyz,
                      obstime=obstime).transform_to(frame)
        return enu.cartesian.xyz.to(au.km).value.T

    for d in range(5):
        tec_grid, directions_grid, antennas_grid, times_grid = load(dir=d)
        ref_location = antennas_grid[0].earth_location
        ant_pos = to_enu_km(antennas_grid, ref_location, times_grid[0])
        plt.scatter(ant_pos[:, 0],
                    ant_pos[:, 1],
                    c=tec_grid[0, :, 0],
                    cmap=plt.cm.PuOr)
        plt.xlabel('East [km]')
        plt.ylabel("North [km]")
        plt.title(f"Direction {repr(directions_grid)}")
        plt.show()
    # Background layout for the per-antenna figures (last direction above).
    ant_scatter_args = (ant_pos[:, 0], ant_pos[:, 1], tec_grid[0, :, 0])

    for a in [0, 50, 150, 200, 250]:
        tec_grid, directions_grid, antennas_grid, times_grid = load(ant=a)
        _ant_pos = to_enu_km(antennas_grid, ref_location, times_grid[0])[0]
        _, axs = plt.subplots(2, 1, figsize=(4, 8))
        axs[0].scatter(*ant_scatter_args[0:2],
                       c=ant_scatter_args[2],
                       cmap=plt.cm.PuOr,
                       alpha=0.5)
        axs[0].scatter(*_ant_pos[0:2], marker='x', c='red')
        axs[0].set_xlabel('East [km]')
        axs[0].set_ylabel('North [km]')
        pos = 180 / np.pi * np.stack(
            [wrap(directions_grid.ra.rad),
             wrap(directions_grid.dec.rad)],
            axis=-1)
        plot_vornoi_map(pos, tec_grid[:, 0, 0], fov_circle=True, ax=axs[1])
        axs[1].set_xlabel('RA(2000) [deg]')
        axs[1].set_ylabel('DEC(2000) [deg]')
        plt.show()
def train_neural_network(datapack: DataPack, batch_size, learning_rate,
                         num_batches):
    """Train a NeuralTomographicKernel to reproduce a TomographicKernel.

    Coordinates come from the first time of ``datapack`` (solset ``sol000``,
    first polarisation, all antennas). They combine the observed directions
    with 250 random unit directions drawn inside the bounding box of the
    observed ones, all in the ENU frame of the first antenna.

    At each step a random batch of coordinates and random layer geometry,
    wind and length scale are drawn. Plain gradient descent then minimises
    ``mean((K - K_neural)**2) / width**2``. The loss curve is plotted.

    :param datapack: DataPack providing the array and direction geometry.
    :param batch_size: number of coordinates per training batch.
    :param learning_rate: gradient-descent step size.
    :param num_batches: number of gradient steps.
    :return: the trained neural-kernel parameters.
    """
    with datapack:
        select = dict(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
        datapack.current_solset = 'sol000'
        datapack.select(**select)
        axes = datapack.axes_tec
        patch_names, directions = datapack.get_directions(axes['dir'])
        antenna_labels, antennas = datapack.get_antennas(axes['ant'])
        timestamps, times = datapack.get_times(axes['time'])
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    frame = ENU(obstime=times[0], location=antennas[0].earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)
    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    # Seconds relative to the middle time. This is computed out of place so
    # the array returned by ``times.mjd`` is never mutated.
    mjd = np.asarray(times.mjd)
    t = (mjd - mjd[len(mjd) // 2]) * 86400.

    n_screen = 250
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0),
                           maxval=jnp.max(k, axis=0))
    kstar /= jnp.linalg.norm(kstar, axis=-1, keepdims=True)
    X = jnp.asarray(
        make_coord_array(x, jnp.concatenate([k, kstar], axis=0), t[:, None]))
    x0 = jnp.asarray(antennas.cartesian.xyz.to(au.km).value.T[0, :])
    ref_ant = x0
    kernel = TomographicKernel(x0, ref_ant, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, ref_ant)

    def loss(params, key):
        keys = random.split(key, 5)
        indices = random.permutation(keys[0], jnp.arange(
            X.shape[0]))[:batch_size]
        X_batch = X[indices, :]
        # Horizontal wind only (vertical component is fixed at 0). Bounds are
        # divided by 1000 before use.
        wind_velocity = random.uniform(keys[1], shape=(3, ),
                                       minval=jnp.asarray([-200., -200., 0.]),
                                       maxval=jnp.asarray([200., 200., 0.
                                                           ])) / 1000.
        bottom = random.uniform(keys[2], minval=50., maxval=500.)
        width = random.uniform(keys[3], minval=40., maxval=300.)
        l = random.uniform(keys[4], minval=1., maxval=30.)
        sigma = 1.
        K = kernel(X_batch, X_batch, bottom, width, l, sigma,
                   wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch, X_batch, bottom, width, l, sigma,
                                 wind_velocity=wind_velocity)
        return jnp.mean((K - neural_K)**2) / width**2

    init_params = neural_kernel.init_params(random.PRNGKey(42))

    def train_one_batch(params, key):
        loss_value, grads = value_and_grad(
            lambda params: loss(params, key))(params)
        # ``tree_map`` over several trees replaces the removed
        # ``tree_multimap``.
        params = tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return params, loss_value

    final_params, losses = jit(lambda key: scan(
        train_one_batch, init_params, random.split(key, num_batches)))(
            random.PRNGKey(42))
    plt.plot(losses)
    plt.yscale('log')
    plt.show()
    return final_params
def run(self, output_h5parm, ncpu, avg_direction_spacing,
        field_of_view_diameter, duration, time_resolution, start_time,
        array_name, phase_tracking):
    """Simulate a TEC screen from the tomographic prior and save it to an h5parm.

    Steps:
    1. Create an empty datapack for the requested array, field of view and
       time span.
    2. Build one GeodesicTuple coordinate per (time, antenna, direction).
    3. Evaluate ``self._kernel``'s mean and full covariance over those
       coordinates.
    4. Draw one sample ``L @ Z + mean`` and store it as ``tec`` in solset
       ``sol000``.

    Reads ``self._bottom``, ``self._width``, ``self._fed_sigma``,
    ``self._fed_kernel_params``, ``self._fed_mu`` and ``self._wind_vector``.
    Sets ``self._kernel``.

    :param output_h5parm: path of the h5parm to create (clobbered).
    :param ncpu: ``chunksize`` forwarded to ``chunked_pmap``.
    :param avg_direction_spacing: forwarded to ``get_num_directions``.
    :param field_of_view_diameter: forwarded to ``get_num_directions`` and
        ``create_empty_datapack``.
    :param duration: total span. ``int(duration / time_resolution) + 1``
        times are requested.
    :param time_resolution: spacing between time samples.
    :param start_time: first time of the datapack.
    :param array_name: key into ``ARRAYS`` selecting the antenna layout.
    :param phase_tracking: phase-tracking centre (object exposing ``.ra`` and
        ``.dec``).
    """
    Nd = get_num_directions(
        avg_direction_spacing,
        field_of_view_diameter,
    )
    Nf = 2  # 8000
    Nt = int(duration / time_resolution) + 1
    min_freq = 700.
    max_freq = 2000.
    dp = create_empty_datapack(
        Nd,
        Nf,
        Nt,
        pols=None,
        field_of_view_diameter=field_of_view_diameter,
        start_time=start_time,
        time_resolution=time_resolution,
        min_freq=min_freq,
        max_freq=max_freq,
        array_file=ARRAYS[array_name],
        phase_tracking=(phase_tracking.ra.deg, phase_tracking.dec.deg),
        save_name=output_h5parm,
        clobber=True)

    with dp:
        dp.current_solset = 'sol000'
        dp.select(pol=slice(0, 1, 1))
        axes = dp.axes_tec
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])

    ref_ant = antennas[0]
    ref_time = times[0]
    Na = len(antennas)
    Nd = len(directions)
    Nt = len(times)
    logger.info(f"Number of directions: {Nd}")
    logger.info(f"Number of antennas: {Na}")
    logger.info(f"Number of times: {Nt}")
    logger.info(f"Reference Ant: {ref_ant}")
    logger.info(f"Reference Time: {ref_time.isot}")

    # Fixed ENU frame: first antenna at the first time.
    ref_frame = ENU(obstime=ref_time, location=ref_ant.earth_location)
    x0 = ac.ITRS(
        *antennas[0].cartesian.xyz,
        obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
            au.km).value
    earth_centre_x = ac.ITRS(
        x=0 * au.m, y=0 * au.m, z=0. * au.m,
        obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
            au.km).value
    self._kernel = TomographicKernel(x0, earth_centre_x, M32(), S_marg=20,
                                     compute_tec=False)
    k = directions.transform_to(ref_frame).cartesian.xyz.value.T
    t = times.mjd * 86400.
    t -= t[0]

    # Coordinates are appended time step by time step, so in the flattened
    # arrays time is the slowest axis. Within a step the row order is the
    # (antenna, direction) order make_coord_array produces for (x, k).
    X1 = GeodesicTuple(x=[], k=[], t=[], ref_x=[])
    logger.info("Computing coordinates in frame ...")
    for i, time in tqdm(enumerate(times)):
        x = ac.ITRS(*antennas.cartesian.xyz,
                    obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                        au.km).value.T
        ref_ant_x = ac.ITRS(
            *ref_ant.cartesian.xyz,
            obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value
        # Columns: x (3) | k (3) | t (1) | ref_x (3).
        X = make_coord_array(x, k, t[i:i + 1, None], ref_ant_x[None, :],
                             flat=True)
        X1.x.append(X[:, 0:3])
        X1.k.append(X[:, 3:6])
        X1.t.append(X[:, 6:7])
        # ref_x is a 3-vector; previously only column 7 was kept.
        X1.ref_x.append(X[:, 7:10])
    X1 = X1._replace(
        x=jnp.concatenate(X1.x, axis=0),
        k=jnp.concatenate(X1.k, axis=0),
        t=jnp.concatenate(X1.t, axis=0),
        ref_x=jnp.concatenate(X1.ref_x, axis=0),
    )
    logger.info(f"Total number of coordinates: {X1.x.shape[0]}")

    def compute_covariance_row(X1: GeodesicTuple, X2: GeodesicTuple):
        K = self._kernel(X1, X2, self._bottom, self._width, self._fed_sigma,
                         self._fed_kernel_params,
                         wind_velocity=self._wind_vector)  # 1, N
        return K[0, :]

    covariance_row = lambda X: compute_covariance_row(
        tree_map(lambda x: x.reshape((1, -1)), X), X1)

    mean = jit(lambda X1: self._kernel.mean_function(
        X1, self._bottom, self._width, self._fed_mu,
        wind_velocity=self._wind_vector))(X1)

    cov = chunked_pmap(covariance_row, X1, batch_size=X1.x.shape[0],
                       chunksize=ncpu)

    plt.imshow(cov)
    plt.show()
    Z = random.normal(random.PRNGKey(42), (cov.shape[0], 1), dtype=cov.dtype)
    t0 = default_timer()
    jitter = 1e-6
    logger.info(f"Computing Cholesky with jitter: {jitter}")
    L = jnp.linalg.cholesky(cov + jitter * jnp.eye(cov.shape[0]))
    if np.any(np.isnan(L)):
        logger.info("Numerically instable. Using SVD.")
        L = msqrt(cov)
    logger.info(f"Cholesky took {default_timer() - t0} seconds.")

    sample = (L @ Z + mean[:, None])[:, 0]
    # Rows are ordered (time, antenna, direction). Reorder them to the
    # (direction, antenna, time) layout written to ``dp.tec``.
    dtec = sample.reshape((Nt, Na, Nd)).transpose((2, 1, 0))
    logger.info(f"Saving result to {output_h5parm}")
    with dp:
        dp.current_solset = 'sol000'
        dp.select(pol=slice(0, 1, 1))
        dp.tec = np.asarray(dtec[None])