def _compute_with_sigma(sigma):
    """
    Score every data row under a zero-mean prior with covariance ``sigma**2 * K``.

    ``K``, ``dtec`` and ``dtec_uncert`` are free names resolved from the enclosing
    scope; each row of ``dtec`` is paired with the matching row of ``dtec_uncert``.

    Args:
        sigma: amplitude applied to ``K``.

    Returns:
        The per-row values of ``log_normal_with_outliers``, stacked by ``chunked_pmap``.
    """

    def _row_log_likelihood(row, row_uncert):
        return log_normal_with_outliers(row, 0., sigma**2 * K, row_uncert)

    # M
    return chunked_pmap(_row_log_likelihood, dtec, dtec_uncert, chunksize=1)
def precompute_log_prob_components_without_wind(kernel, X, dtec, dtec_uncert,
                                                bottom_array, width_array,
                                                lengthscale_array, sigma_array,
                                                chunksize=2):
    """
    Precompute the log_prob for each parameter on a hyper-parameter grid.

    For every (bottom, width, lengthscale) grid point the kernel is evaluated once
    with amplitude argument ``1.`` and ``wind_velocity=None``. Every row of ``dtec``
    is then scored with ``log_normal_with_outliers`` for each amplitude in
    ``sigma_array``, using covariance ``sigma**2 * K``.

    Args:
        kernel: called as ``kernel(X, X, bottom, width, lengthscale, 1., wind_velocity=None)``.
        X: coordinates passed to ``kernel``.
        dtec: [M, ...] data; one likelihood evaluation per row.
        dtec_uncert: [M, ...] uncertainties matching ``dtec`` row for row.
        bottom_array: [Nb]
        width_array: [Nw]
        lengthscale_array: [Nl]
        sigma_array: [Ns]
        chunksize: chunk size for the outer ``chunked_pmap`` over the grid.

    Returns:
        [M, Nb, Nw, Nl, Ns] log-likelihood table.
    """
    num_bottom = bottom_array.shape[0]
    num_width = width_array.shape[0]
    num_lengthscale = lengthscale_array.shape[0]
    num_sigma = sigma_array.shape[0]
    num_rows = dtec.shape[0]

    # Flattened (bottom, width, lengthscale) grid; 'ij' indexing keeps C order.
    grid_axes = jnp.meshgrid(bottom_array, width_array, lengthscale_array, indexing='ij')
    flat_grid = [axis.ravel() for axis in grid_axes]

    def compute_log_prob_components(bottom, width, lengthscale):
        # N, N
        unit_cov = kernel(X, X, bottom, width, lengthscale, 1., wind_velocity=None)

        def _compute_with_sigma(sigma):
            def _compute(row, row_uncert):
                return log_normal_with_outliers(row, 0., sigma**2 * unit_cov, row_uncert)

            # M
            return chunked_pmap(_compute, dtec, dtec_uncert, chunksize=1)

        # Ns, M
        return chunked_pmap(_compute_with_sigma, sigma_array, chunksize=1)

    # Nb*Nw*Nl, Ns, M
    table = chunked_pmap(compute_log_prob_components, *flat_grid, chunksize=chunksize)
    # M, Nb*Nw*Nl, Ns -> M, Nb, Nw, Nl, Ns
    table = jnp.moveaxis(table, -1, 0)
    return table.reshape((num_rows, num_bottom, num_width, num_lengthscale, num_sigma))
def smooth(y, weights):
    """
    Polynomially smooth ``y`` along time with per-sample ``weights``.

    ``size_dict`` and ``times`` are free names from the enclosing scope. Inputs are
    laid out as ['dat', 'b'] and are regrouped to ['da', 'tb'] so that each row is a
    full time series. Each row is fitted with ``poly_smooth(times, ..., deg=5)``.
    The result is returned in the original ['dat', 'b'] layout.
    """

    def _to_series(arr):
        return axes_move(arr, ['dat', 'b'], ['da', 'tb'], size_dict=size_dict)

    series = _to_series(y)
    series_weights = _to_series(weights)
    smoothed = chunked_pmap(
        lambda row, row_weights: poly_smooth(times, row, deg=5, weights=row_weights),
        series, series_weights)
    return axes_move(smoothed, ['da', 'tb'], ['dat', 'b'], size_dict=size_dict)
def compute_log_prob_components(lengthscale):
    """
    Log-likelihood table for a single lengthscale.

    ``kernel`` is evaluated once with amplitude argument ``1.``. Every row of
    ``dtec``/``dtec_uncert`` is then scored for each amplitude in ``sigma_array``,
    using covariance ``sigma**2 * K``. ``kernel``, ``X``, ``dtec``, ``dtec_uncert``
    and ``sigma_array`` come from the enclosing scope.

    Returns:
        [Ns, M] log-likelihoods.
    """
    # N, N
    base_cov = kernel(X, X, lengthscale, 1.)

    def _compute_with_sigma(sigma):
        def _score_row(obs, obs_uncert):  # each [Nd]
            return log_normal_with_outliers(obs, 0., sigma**2 * base_cov, obs_uncert)

        return chunked_pmap(_score_row, dtec, dtec_uncert, chunksize=1)  # M

    # Ns, M
    return chunked_pmap(_compute_with_sigma, sigma_array, chunksize=1)
def compute_log_prob_components(bottom, width, lengthscale):
    """
    Log-likelihood table for one (bottom, width, lengthscale) grid point.

    ``kernel`` is evaluated once with amplitude argument ``1.`` and no wind. Every
    row of ``dtec``/``dtec_uncert`` is then scored for each amplitude in
    ``sigma_array``, using covariance ``sigma**2 * K``. ``kernel``, ``X``, ``dtec``,
    ``dtec_uncert`` and ``sigma_array`` come from the enclosing scope.

    Returns:
        [Ns, M] log-likelihoods.
    """
    # N, N
    base_cov = kernel(X, X, bottom, width, lengthscale, 1., wind_velocity=None)

    def _compute_with_sigma(sigma):
        def _score_row(obs, obs_uncert):
            return log_normal_with_outliers(obs, 0., sigma**2 * base_cov, obs_uncert)

        # M
        return chunked_pmap(_score_row, dtec, dtec_uncert, chunksize=1)

    # Ns, M
    return chunked_pmap(_compute_with_sigma, sigma_array, chunksize=1)
def get_data(solution_file):
    """
    Load a fixed selection of phase/amplitude solutions and smooth the amplitudes.

    The selection is: first polarisation, antenna slice 50:51, all directions and
    timesteps 0:100.

    Args:
        solution_file: path opened with ``DataPack(..., readonly=True)``.

    Returns:
        gain_outliers: [Nd, Na, Nf, Nt] bool, True where the phase weight equals 1.
        phase: [Nd, Na, Nf, Nt]
        amp: [Nd, Na, Nf, Nt] amplitudes smoothed in log space with ``poly_smooth``
            (deg=3) first along time, then along frequency; flagged samples get
            weight 0.
        times: [Nt] seconds relative to the first selected timestep.
        freqs: [Nf] in Hz.
    """
    logger.info("Getting DDS4 data.")
    with DataPack(solution_file, readonly=True) as datapack:
        selection = dict(pol=slice(0, 1, 1), ant=slice(50, 51), dir=slice(0, None, 1),
                         time=slice(0, 100, 1))
        datapack.select(**selection)
        raw_phase, _ = datapack.phase
        phase = raw_phase[0, ...]
        raw_weights, _ = datapack.weights_phase
        gain_outliers = raw_weights[0, ...] == 1
        raw_amp, amp_axes = datapack.amplitude
        amp = raw_amp[0, ...]
        _, freqs = datapack.get_freqs(amp_axes['freq'])
        freqs = freqs.to(au.Hz).value
        _, times = datapack.get_times(amp_axes['time'])
        times = jnp.asarray(times.mjd, dtype=jnp.float64) * 86400.
        times = times - times[0]
        logger.info("Shape: {}".format(phase.shape))
        Nd, Na, Nf, Nt = phase.shape

        @jit
        def _smooth_amplitude(amp_tf, outliers_tf):
            """Smooth one [Nt, Nf] amplitude panel in log space; outliers get zero weight."""
            weights_tf = jnp.where(outliers_tf, 0., 1.)
            smooth_in_time = vmap(
                lambda series, w: poly_smooth(times, series, deg=3, weights=w))
            smooth_in_freq = vmap(
                lambda spectrum, w: poly_smooth(freqs, spectrum, deg=3, weights=w))
            log_amp = smooth_in_time(jnp.log(amp_tf).T, weights_tf.T).T
            log_amp = smooth_in_freq(log_amp, weights_tf)
            return jnp.exp(log_amp)

        logger.info("Smoothing amplitudes")
        # Nd*Na, Nt, Nf
        amp_panels = amp.reshape((Nd * Na, Nf, Nt)).transpose((0, 2, 1))
        outlier_panels = gain_outliers.reshape((Nd * Na, Nf, Nt)).transpose((0, 2, 1))
        smoothed = chunked_pmap(_smooth_amplitude, amp_panels, outlier_panels)
        amp = smoothed.transpose((0, 2, 1)).reshape((Nd, Na, Nf, Nt))
    return gain_outliers, phase, amp, times, freqs
def detect_tec_outliers(times, tec_mean, tec_std):
    """
    Detect outliers in tec (in batch), one (direction, antenna) time series at a time.

    Args:
        times: [Nt]
        tec_mean: [Nd, Na, Nt] tec estimates
        tec_std: [Nd, Na, Nt] tec uncertainties

    Returns:
        The output of ``single_detect_tec_outliers`` for every series, with each leaf
        reshaped to [Nd, Na, Nt]. Callers in this file unpack it as
        ``tec_est, tec_outliers``.
    """
    times = jnp.asarray(times)
    tec_mean = jnp.asarray(tec_mean)
    tec_std = jnp.asarray(tec_std)
    Nd, Na, Nt = tec_mean.shape
    num_series = Nd * Na

    def _detect(series_mean, series_std):
        return single_detect_tec_outliers(times, series_mean, series_std)

    batched = chunked_pmap(_detect,
                           tec_mean.reshape((num_series, Nt)),
                           tec_std.reshape((num_series, Nt)),
                           chunksize=None)
    return tree_map(lambda leaf: leaf.reshape((Nd, Na, Nt)), batched)
def detect_dphase_outliers(dphase):
    """
    Detect outliers in dphase (in batch).

    A sample is initially flagged when ``|dphase| > 1`` or when it exceeds five times
    the RMS of the whole ``dphase`` array. These flags seed ``single_detect_outliers``,
    which runs on every (direction, antenna, frequency) time series with window=15.

    Args:
        dphase: [Nd, Na, Nf, Nt]

    Returns:
        outliers [Nd, Na, Nf, Nt]
    """
    Nd, Na, Nf, Nt = dphase.shape
    dphase = dphase.reshape((Nd * Na * Nf, Nt))
    magnitude = jnp.abs(dphase)
    # NOTE(review): this RMS is global over every series, not per series — confirm intended.
    global_rms = jnp.sqrt(jnp.mean(dphase**2))
    outliers = (magnitude > 1.) | (magnitude > 5. * global_rms)
    # Diagnostics go through the module logger instead of stray print() calls.
    logger.info("Initial dphase outlier count: %s", outliers.sum())
    outliers = chunked_pmap(
        lambda series, init: single_detect_outliers(series, window=15, init_outliers=init),
        dphase, outliers, chunksize=None)
    outliers = outliers.reshape((Nd, Na, Nf, Nt))
    logger.info("Final dphase outlier count: %s", outliers.sum())
    return outliers
def solve_with_vanilla_kernel(key, dtec, dtec_uncert, X, Xstar, fed_kernel, time_block_size, chunksize):
    """
    Infer (lengthscale, sigma) per antenna and time block, then predict screens.

    A log-likelihood look-up table over a (lengthscale, sigma) grid is precomputed
    for every (antenna, time) data row. The table is summed within each time block
    and used as the likelihood of one nested sampler per block. Screen predictions
    are then marginalised over the sampled hyper-parameters.

    Args:
        key: PRNG key
        dtec: [Nd, Na, Nt] TECU
        dtec_uncert: [Nd, Na, Nt] TECU
        X: [Nd, 2] coordinates in deg
        Xstar: [Nd_screen, 2] screen coordinates
        fed_kernel: StationaryKernel
        time_block_size: int
        chunksize: int number of parallel devices to use.

    Returns:
        mean, uncert (laid out as ['n', 'a', 'tb'], i.e. [Nd_screen, Na, Nt]), and
        mean_lengthscale, mean_sigma, ESS, logZ, likelihood_evals (laid out as
        ['a', 'tb'], i.e. [Na, Nt]). The time axis of every output is trimmed back
        to the input's Nt.
    """
    field_of_view = 4.  # deg
    min_separation_arcmin = 4.  # arcmin
    min_separation_deg = min_separation_arcmin / 60.
    lengthscale_array = jnp.linspace(min_separation_deg, field_of_view, 120)
    sigma_array = jnp.linspace(0., 150., 150)
    kernel = fed_kernel
    lookup_func = build_lookup_index(lengthscale_array, sigma_array)
    dtec = jnp.asarray(dtec)
    dtec_uncert = jnp.maximum(dtec_uncert, 1e-6)

    Nd, Na, Nt = dtec.shape
    num_times = Nt  # original length; outputs are trimmed back to this
    # Pad the time axis up to a multiple of time_block_size by repeating the trailing
    # timesteps (cyclically when the padding is longer than the series). No padding
    # is added when Nt already divides evenly.
    extra = (-Nt) % time_block_size
    if extra > 0:
        pad_idx = jnp.arange(Nt - extra, Nt) % Nt
        dtec = jnp.concatenate([dtec, jnp.take(dtec, pad_idx, axis=-1)], axis=-1)
        dtec_uncert = jnp.concatenate(
            [dtec_uncert, jnp.take(dtec_uncert, pad_idx, axis=-1)], axis=-1)
    Nt = dtec.shape[-1]

    size_dict = dict(a=Na, d=Nd, b=time_block_size)
    dtec = axes_move(dtec, ['d', 'a', 'tb'], ['atb', 'd'], size_dict=size_dict)
    dtec_uncert = axes_move(dtec_uncert, ['d', 'a', 'tb'], ['atb', 'd'], size_dict=size_dict)

    def compute_log_prob_components(lengthscale):
        # N, N
        K = kernel(X, X, lengthscale, 1.)

        def _compute_with_sigma(sigma):
            def _compute(dtec, dtec_uncert):
                # each [Nd]
                return log_normal_with_outliers(dtec, 0., sigma**2 * K, dtec_uncert)

            return chunked_pmap(_compute, dtec, dtec_uncert, chunksize=1)  # M

        # Ns, M
        return chunked_pmap(_compute_with_sigma, sigma_array, chunksize=1)

    # Nl, Ns, M
    log_prob = chunked_pmap(compute_log_prob_components, lengthscale_array, chunksize=chunksize)
    # Na * (Nt // time_block_size), block_size, Nl, Ns
    log_prob = axes_move(log_prob, ['l', 's', 'atb'], ['at', 'b', 'l', 's'], size_dict=size_dict)
    # Na * (Nt // time_block_size), Nl, Ns
    log_prob = jnp.sum(log_prob, axis=1)  # independent datasets summed up.

    def run_block(key, dtec, dtec_uncert, log_prob):
        key1, key2 = random.split(key, 2)

        def log_likelihood(lengthscale, sigma, **kwargs):
            return lookup_func(log_prob, lengthscale, sigma)

        lengthscale = UniformPrior('lengthscale', jnp.min(lengthscale_array), jnp.max(lengthscale_array))
        sigma = UniformPrior('sigma', sigma_array.min(), sigma_array.max())
        prior_chain = PriorChain(lengthscale, sigma)
        ns = NestedSampler(loglikelihood=log_likelihood,
                           prior_chain=prior_chain,
                           sampler_kwargs=dict(num_slices=prior_chain.U_ndims * 1),
                           num_live_points=prior_chain.U_ndims * 50)
        ns = jit(ns)
        results = ns(key1, termination_evidence_frac=0.1)

        def marg_func(lengthscale, sigma, **kwargs):
            def screen(dtec, dtec_uncert, **kw):
                K = kernel(X, X, lengthscale, sigma)
                Kstar = kernel(X, Xstar, lengthscale, sigma)
                # Whitened system: U^-1 K U^-1 + I = L L^T with U = diag(dtec_uncert).
                L = jnp.linalg.cholesky(
                    K / (dtec_uncert[:, None] * dtec_uncert[None, :]) + jnp.eye(dtec.shape[0]))
                dx = solve_triangular(L, dtec / dtec_uncert, lower=True)
                JT = solve_triangular(L, Kstar / dtec_uncert[:, None], lower=True)
                mean = JT.T @ dx
                # NOTE(review): sum(JT**2) is diag(Kstar^T (K + U^2)^-1 Kstar), i.e. the
                # variance explained by the data. The GP posterior variance would be
                # diag(Kstar_star) minus this — confirm which quantity is intended.
                var = jnp.sum(JT * JT, axis=0)
                return mean, var

            # [time_block_size, Nd_screen], [time_block_size, Nd_screen], scalar, scalar
            return vmap(screen)(dtec, dtec_uncert), lengthscale, jnp.log(sigma)

        (mean, var), mean_lengthscale, mean_logsigma = marginalise_static(
            key2, results.samples, results.log_p, 500, marg_func)
        uncert = jnp.sqrt(var)
        mean_sigma = jnp.exp(mean_logsigma)
        mean_lengthscale = jnp.ones(time_block_size) * mean_lengthscale
        mean_sigma = jnp.ones(time_block_size) * mean_sigma
        ESS = results.ESS * jnp.ones(time_block_size)
        logZ = results.logZ * jnp.ones(time_block_size)
        likelihood_evals = results.num_likelihood_evaluations * jnp.ones(time_block_size)
        return mean, uncert, mean_lengthscale, mean_sigma, ESS, logZ, likelihood_evals

    T = Na * (Nt // time_block_size)
    keys = random.split(key, T)
    dtec = axes_move(dtec, ['atb', 'd'], ['at', 'b', 'd'], size_dict=size_dict)
    dtec_uncert = axes_move(dtec_uncert, ['atb', 'd'], ['at', 'b', 'd'], size_dict=size_dict)
    # [T, time_block_size, Nd_screen] x2, then [T, time_block_size] x5
    mean, uncert, mean_lengthscale, mean_sigma, ESS, logZ, likelihood_evals = chunked_pmap(
        run_block, keys, dtec, dtec_uncert, log_prob, chunksize=chunksize)
    mean = axes_move(mean, ['at', 'b', 'n'], ['n', 'a', 'tb'], size_dict=size_dict)
    uncert = axes_move(uncert, ['at', 'b', 'n'], ['n', 'a', 'tb'], size_dict=size_dict)
    mean_lengthscale = axes_move(mean_lengthscale, ['at', 'b'], ['a', 'tb'], size_dict=size_dict)
    mean_sigma = axes_move(mean_sigma, ['at', 'b'], ['a', 'tb'], size_dict=size_dict)
    ESS = axes_move(ESS, ['at', 'b'], ['a', 'tb'], size_dict=size_dict)
    logZ = axes_move(logZ, ['at', 'b'], ['a', 'tb'], size_dict=size_dict)
    likelihood_evals = axes_move(likelihood_evals, ['at', 'b'], ['a', 'tb'], size_dict=size_dict)

    # Keep the leading num_times steps (the real data) and drop the padding — the
    # previous [..., Nt - extra:] returned only the padding — and trim logZ as well.
    return tuple(x[..., :num_times]
                 for x in (mean, uncert, mean_lengthscale, mean_sigma, ESS, logZ, likelihood_evals))
def solve_with_tomographic_kernel(dtec, dtec_uncert, X, x0, fed_kernel, time_block_size):
    """
    Precompute look-up tables for all blocks and run nested sampling per time block.

    Assumes that each antenna is independent and doesn't take into account time.

    Args:
        dtec: [Nd, Na, Nt]
        dtec_uncert: [Nd, Na, Nt]
        X: [Nd, 6]
        x0: [3]
        fed_kernel: StationaryKernel
        time_block_size: int

    Returns:
        The nested-sampler results of every time block, stacked by ``chunked_pmap``.
    """
    # Rescale out-of-place: `/=` would mutate a caller-owned numpy array.
    scale = jnp.std(dtec) / 35.
    dtec = dtec / scale
    dtec_uncert = dtec_uncert / scale
    bottom_array = jnp.linspace(200., 400., 5)
    width_array = jnp.linspace(50., 50., 1)
    lengthscale_array = jnp.linspace(0.1, 7.5, 7)
    sigma_array = jnp.linspace(0.1, 2., 11)
    kernel = TomographicKernel(x0, x0, fed_kernel, S_marg=25, compute_tec=False)
    lookup_func = build_lookup_index(bottom_array, width_array, lengthscale_array, sigma_array)

    Nd, Na, Nt = dtec.shape
    # Pad the time axis up to a multiple of time_block_size by repeating the trailing
    # timesteps (cyclically when the padding is longer than the series). The previous
    # code padded by `remainder` and, for remainder == 0, duplicated the whole series.
    extra = (-Nt) % time_block_size
    if extra > 0:
        pad_idx = jnp.arange(Nt - extra, Nt) % Nt
        dtec = jnp.concatenate([dtec, jnp.take(dtec, pad_idx, axis=-1)], axis=-1)
        dtec_uncert = jnp.concatenate(
            [dtec_uncert, jnp.take(dtec_uncert, pad_idx, axis=-1)], axis=-1)
    Nt = dtec.shape[-1]
    num_blocks = Nt // time_block_size

    # Nt*Na, Nd (time-major rows)
    dtec = dtec.transpose((2, 1, 0)).reshape((Nt * Na, Nd))
    dtec_uncert = dtec_uncert.transpose((2, 1, 0)).reshape((Nt * Na, Nd))
    # Nt*Na, Nb, Nw, Nl, Ns
    log_prob = precompute_log_prob_components_without_wind(kernel, X, dtec, dtec_uncert,
                                                           bottom_array, width_array,
                                                           lengthscale_array, sigma_array,
                                                           chunksize=4)
    # num_blocks, time_block_size, Na, Nb, Nw, Nl, Ns
    log_prob = jnp.reshape(log_prob, (num_blocks, time_block_size, Na) + log_prob.shape[1:])

    def run_block(block_idx):
        def log_likelihood(bottom, width, lengthscale, sigma, **kwargs):
            # NOTE(review): vmap runs over the time axis only, so each lookup_func call
            # receives a [Na, Nb, Nw, Nl, Ns] slice — confirm lookup_func handles the
            # leading antenna axis.
            return jnp.sum(
                vmap(lambda table: lookup_func(table, bottom, width, lengthscale, sigma))(
                    log_prob[block_idx]))

        bottom = UniformPrior('bottom', bottom_array.min(), bottom_array.max())
        width = DeltaPrior('width', 50., tracked=False)
        lengthscale = UniformPrior('lengthscale', jnp.min(lengthscale_array), jnp.max(lengthscale_array))
        sigma = UniformPrior('sigma', sigma_array.min(), sigma_array.max())
        prior_chain = PriorChain(lengthscale, sigma, bottom, width)
        ns = NestedSampler(loglikelihood=log_likelihood,
                           prior_chain=prior_chain,
                           sampler_name='slice',
                           sampler_kwargs=dict(num_slices=prior_chain.U_ndims * 5),
                           num_live_points=prior_chain.U_ndims * 50)
        ns = jit(ns)
        results = ns(random.PRNGKey(42), termination_frac=0.001)
        return results

    # Previously the results were computed and then discarded (the function returned None).
    return chunked_pmap(run_block, jnp.arange(num_blocks))
def solve_and_smooth(gain_outliers, phase_obs, times, freqs):
    """
    Solve for tec and const from phases, smooth const, refine tec and flag outliers.

    Args:
        gain_outliers: [Nd, Na, Nf, Nt]
        phase_obs: [Nd, Na, Nf, Nt]
        times: [Nt]
        freqs: [Nf]

    Returns:
        phase_mean [Nd, Na, Nf, Nt], phase_uncert [Nd, Na, Nf, Nt],
        tec_mean [Nd, Na, Nt], tec_std [Nd, Na, Nt] (inf where flagged),
        tec_outliers [Nd, Na, Nt], const_mean [Nd, Na, Nt], const_std [Nd, Na, Nt].
        The time axis is the input's Nt (any padding is removed).

    Raises:
        ValueError: if there are fewer timesteps than the block size.
    """
    logger.info("Performing solve for tec and const from phases.")
    Nd, Na, Nf, Nt = phase_obs.shape
    logger.info("Number of nan: {}".format(jnp.sum(jnp.isnan(phase_obs))))
    logger.info("Number of inf: {}".format(jnp.sum(jnp.isinf(phase_obs))))
    # blocksize chosen to maximise Fisher information, which is 2 for tec+const, and 3 for tec+const+clock
    blocksize = 2
    # Previously `if remainder < Nt: raise`, which fired for every odd Nt >= 3.
    if Nt < blocksize:
        raise ValueError(
            f"Block size {blocksize} too big for number of timesteps {Nt}.")
    num_times = Nt  # original length, restored before returning
    # Number of timesteps to append so that Nt becomes a multiple of blocksize.
    extra = (-Nt) % blocksize
    if extra != 0:
        (gain_outliers, phase_obs) = tree_map(
            lambda x: jnp.concatenate(
                [x, jnp.repeat(x[..., -1:], extra, axis=-1)], axis=-1),
            (gain_outliers, phase_obs))
        Nt = Nt + extra
        times = jnp.concatenate([
            times,
            times[-1] + jnp.arange(1, 1 + extra) * jnp.mean(jnp.diff(times))
        ])

    size_dict = dict(d=Nd, a=Na, f=Nf, b=blocksize)
    # [Nd*Na*(Nt//blocksize), blocksize, Nf]
    gain_outliers = axes_move(gain_outliers, ['d', 'a', 'f', 'tb'], ['dat', 'b', 'f'], size_dict=size_dict)
    phase_obs = axes_move(phase_obs, ['d', 'a', 'f', 'tb'], ['dat', 'b', 'f'], size_dict=size_dict)

    T = Nd * Na * (Nt // blocksize)
    keys = random.split(random.PRNGKey(int(1000 * default_timer())), T)
    # each [Nd*Na*(Nt//blocksize), blocksize]
    tec_mean, tec_std, const_mean, const_std, uncert_mean = chunked_pmap(
        lambda *args: unconstrained_solve(freqs, *args), keys, phase_obs, gain_outliers)
    const_weights = 1. / const_std**2

    def smooth(y, weights):
        # Regroup blocks into full time series, fit poly_smooth (deg=5), regroup back.
        y = axes_move(y, ['dat', 'b'], ['da', 'tb'], size_dict=size_dict)
        weights = axes_move(weights, ['dat', 'b'], ['da', 'tb'], size_dict=size_dict)
        y = chunked_pmap(
            lambda y, weights: poly_smooth(times, y, deg=5, weights=weights),
            y, weights)
        y = axes_move(y, ['da', 'tb'], ['dat', 'b'], size_dict=size_dict)
        return y

    logger.info("Smoothing and outlier rejection of const (a weak prior).")
    # Smooth on the unit circle to avoid phase-wrapping artefacts.
    const_real_mean = smooth(jnp.cos(const_mean), const_weights)
    const_imag_mean = smooth(jnp.sin(const_mean), const_weights)
    const_mean_smoothed = jnp.arctan2(const_imag_mean, const_real_mean)

    # empirically determined uncertainty point where sigma(tec - tec_true) > 6 mTECU
    # NOTE(review): the threshold used is 0., so any positive uncert_mean triggers a
    # refit — confirm it matches the 6 mTECU criterion described above.
    which_reprocess = jnp.any(uncert_mean > 0., axis=1)  # Nd*Na*(Nt//blocksize)
    num_reprocess = int(jnp.sum(which_reprocess))
    logger.info("Performing refined tec-only solve, with fixed const.")
    if num_reprocess > 0:
        replace_map = jnp.where(which_reprocess)
        keys = random.split(random.PRNGKey(int(1000 * default_timer())), num_reprocess)
        (tec_mean_constrained, tec_std_constrained,
         const_mean_constrained, const_std_constrained) = chunked_pmap(
            lambda *args: constrained_solve(freqs, *args), keys,
            phase_obs[which_reprocess], gain_outliers[which_reprocess],
            const_mean_smoothed[which_reprocess], const_std[which_reprocess])
        tec_mean = tec_mean.at[replace_map].set(tec_mean_constrained)
        tec_std = tec_std.at[replace_map].set(tec_std_constrained)
        const_std = const_std.at[replace_map].set(const_std_constrained)
        const_mean = const_mean.at[replace_map].set(const_mean_constrained)

    # Nd, Na, Nt
    (tec_mean, tec_std, const_mean, const_std) = tree_map(
        lambda x: x.reshape((Nd, Na, Nt)), (tec_mean, tec_std, const_mean, const_std))

    logger.info("Performing outlier detection on tec values.")
    _, tec_outliers = detect_tec_outliers(times, tec_mean, tec_std)
    tec_std = jnp.where(tec_outliers, jnp.inf, tec_std)

    # Remove the padding; tec_outliers is trimmed too so every output shares one length.
    if extra != 0:
        (tec_mean, tec_std, tec_outliers, const_mean, const_std) = tree_map(
            lambda x: x[..., :num_times],
            (tec_mean, tec_std, tec_outliers, const_mean, const_std))

    # compute phase mean with outlier-suppressed tec.
    phase_mean = tec_mean[..., None, :] * (TEC_CONV / freqs[:, None]) + const_mean[..., None, :]
    phase_uncert = jnp.sqrt((tec_std[..., None, :] * (TEC_CONV / freqs[:, None]))**2
                            + (const_std[..., None, :])**2)
    return phase_mean, phase_uncert, tec_mean, tec_std, tec_outliers, const_mean, const_std
def run(self, output_h5parm, ncpu, avg_direction_spacing, field_of_view_diameter, duration,
        time_resolution, start_time, array_name, phase_tracking):
    """
    Simulate one dtec realisation from the tomographic kernel and save it.

    An empty datapack is created. The kernel mean and covariance are evaluated on
    every (time, antenna, direction) coordinate, a sample is drawn via Cholesky
    (falling back to ``msqrt``), and the result is written as ``dp.tec``.

    Args:
        output_h5parm: path of the datapack to create (clobbered).
        ncpu: chunksize passed to ``chunked_pmap`` for the covariance rows.
        avg_direction_spacing, field_of_view_diameter: direction layout.
        duration, time_resolution, start_time: time axis definition.
        array_name: key into ``ARRAYS``.
        phase_tracking: sky coordinate with ``.ra.deg`` / ``.dec.deg``.
    """
    Nd = get_num_directions(
        avg_direction_spacing,
        field_of_view_diameter,
    )
    Nf = 2  # 8000
    Nt = int(duration / time_resolution) + 1
    min_freq = 700.
    max_freq = 2000.
    dp = create_empty_datapack(
        Nd, Nf, Nt,
        pols=None,
        field_of_view_diameter=field_of_view_diameter,
        start_time=start_time,
        time_resolution=time_resolution,
        min_freq=min_freq,
        max_freq=max_freq,
        array_file=ARRAYS[array_name],
        phase_tracking=(phase_tracking.ra.deg, phase_tracking.dec.deg),
        save_name=output_h5parm,
        clobber=True)

    with dp:
        dp.current_solset = 'sol000'
        dp.select(pol=slice(0, 1, 1))
        axes = dp.axes_tec
        _, directions = dp.get_directions(axes['dir'])
        _, antennas = dp.get_antennas(axes['ant'])
        _, times = dp.get_times(axes['time'])
    ref_ant = antennas[0]
    ref_time = times[0]
    Na = len(antennas)
    Nd = len(directions)
    Nt = len(times)
    logger.info(f"Number of directions: {Nd}")
    logger.info(f"Number of antennas: {Na}")
    logger.info(f"Number of times: {Nt}")
    logger.info(f"Reference Ant: {ref_ant}")
    logger.info(f"Reference Time: {ref_time.isot}")

    # East-North-Up frame anchored at the reference antenna and time.
    ref_frame = ENU(obstime=ref_time, location=ref_ant.earth_location)
    x0 = ac.ITRS(*antennas[0].cartesian.xyz, obstime=ref_time).transform_to(
        ref_frame).cartesian.xyz.to(au.km).value
    earth_centre_x = ac.ITRS(x=0 * au.m, y=0 * au.m, z=0. * au.m, obstime=ref_time).transform_to(
        ref_frame).cartesian.xyz.to(au.km).value
    self._kernel = TomographicKernel(x0, earth_centre_x, M32(), S_marg=20, compute_tec=False)
    k = directions.transform_to(ref_frame).cartesian.xyz.value.T
    t = times.mjd * 86400.
    t -= t[0]

    X1 = GeodesicTuple(x=[], k=[], t=[], ref_x=[])
    logger.info("Computing coordinates in frame ...")
    for i, time in tqdm(enumerate(times)):
        x = ac.ITRS(*antennas.cartesian.xyz, obstime=time).transform_to(
            ref_frame).cartesian.xyz.to(au.km).value.T
        ref_ant_x = ac.ITRS(*ref_ant.cartesian.xyz, obstime=time).transform_to(
            ref_frame).cartesian.xyz.to(au.km).value
        # Columns: x (3), k (3), t (1), ref_x (3).
        X = make_coord_array(x, k, t[i:i + 1, None], ref_ant_x[None, :], flat=True)
        X1.x.append(X[:, 0:3])
        X1.k.append(X[:, 3:6])
        X1.t.append(X[:, 6:7])
        # ref_ant_x is a 3-vector, so it spans columns 7:10 (was 7:8, first component only).
        X1.ref_x.append(X[:, 7:10])
    X1 = X1._replace(
        x=jnp.concatenate(X1.x, axis=0),
        k=jnp.concatenate(X1.k, axis=0),
        t=jnp.concatenate(X1.t, axis=0),
        ref_x=jnp.concatenate(X1.ref_x, axis=0),
    )
    logger.info(f"Total number of coordinates: {X1.x.shape[0]}")

    def compute_covariance_row(X1: GeodesicTuple, X2: GeodesicTuple):
        K = self._kernel(X1, X2, self._bottom, self._width, self._fed_sigma,
                         self._fed_kernel_params, wind_velocity=self._wind_vector)
        # 1, N
        return K[0, :]

    covariance_row = lambda X: compute_covariance_row(
        tree_map(lambda x: x.reshape((1, -1)), X), X1)

    mean = jit(lambda X1: self._kernel.mean_function(
        X1, self._bottom, self._width, self._fed_mu, wind_velocity=self._wind_vector))(X1)

    cov = chunked_pmap(covariance_row, X1, batch_size=X1.x.shape[0], chunksize=ncpu)
    # (A blocking plt.imshow/plt.show debug display was removed here.)

    Z = random.normal(random.PRNGKey(42), (cov.shape[0], 1), dtype=cov.dtype)
    t0 = default_timer()
    jitter = 1e-6
    logger.info(f"Computing Cholesky with jitter: {jitter}")
    L = jnp.linalg.cholesky(cov + jitter * jnp.eye(cov.shape[0]))
    if np.any(np.isnan(L)):
        logger.info("Numerically instable. Using SVD.")
        L = msqrt(cov)
    logger.info(f"Cholesky took {default_timer() - t0} seconds.")

    # Coordinates were appended one time step at a time, so rows are time-major:
    # (time, antenna, direction). The antenna-before-direction order follows
    # make_coord_array(x, k, ...) — NOTE(review): confirm it is first-argument-outer.
    # Reshape accordingly and transpose to [Nd, Na, Nt].
    dtec = (L @ Z + mean[:, None])[:, 0].reshape((Nt, Na, Nd)).transpose((2, 1, 0))
    logger.info(f"Saving result to {output_h5parm}")
    with dp:
        dp.current_solset = 'sol000'
        dp.select(pol=slice(0, 1, 1))
        dp.tec = np.asarray(dtec[None])