def distributed_strategy(args): model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model) path = os.getenv('CENSAI_PATH') + "/results/" dataset = [] for file in sorted(glob.glob(path + args.h5_pattern)): try: dataset.append(h5py.File(file, "r")) except: continue B = dataset[0]["source"].shape[0] data_len = len(dataset) * B // N_WORKERS ps_observation = PowerSpectrum(bins=args.observation_coherence_bins, pixels=128) ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=128) ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=128) phys = PhysicalModel( pixels=128, kappa_pixels=128, src_pixels=128, image_fov=7.69, kappa_fov=7.69, src_fov=3., method="fft", ) with open(os.path.join(model, "unet_hparams.json")) as f: unet_params = json.load(f) unet_params["kernel_l2_amp"] = args.l2_amp unet = Model(**unet_params) ckpt = tf.train.Checkpoint(net=unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1) checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial() with open(os.path.join(model, "rim_hparams.json")) as f: rim_params = json.load(f) rim = RIM(phys, unet, **rim_params) kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.kappa_vae) with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f: kappa_vae_hparams = json.load(f) kappa_vae = VAE(**kappa_vae_hparams) ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae) checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1) checkpoint_manager1.checkpoint.restore( checkpoint_manager1.latest_checkpoint).expect_partial() svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.source_vae) with open(os.path.join(svae_path, "model_hparams.json"), "r") as f: source_vae_hparams = json.load(f) source_vae = VAE(**source_vae_hparams) ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae) checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1) checkpoint_manager2.checkpoint.restore( checkpoint_manager2.latest_checkpoint).expect_partial() wk = lambda k: tf.sqrt(k) / tf.reduce_sum( tf.sqrt(k), axis=(1, 2, 3), keepdims=True) # Freeze L5 # encoding layers # rim.unet.layers[0].trainable = False # L1 # rim.unet.layers[1].trainable = False # rim.unet.layers[2].trainable = False # rim.unet.layers[3].trainable = False # rim.unet.layers[4].trainable = False # L5 # GRU # rim.unet.layers[5].trainable = False # rim.unet.layers[6].trainable = False # rim.unet.layers[7].trainable = False # rim.unet.layers[8].trainable = False # rim.unet.layers[9].trainable = False # rim.unet.layers[15].trainable = False # bottleneck GRU # output layer # rim.unet.layers[-2].trainable = False # input layer # rim.unet.layers[-1].trainable = False # decoding layers # rim.unet.layers[10].trainable = False # L5 # rim.unet.layers[11].trainable = False # rim.unet.layers[12].trainable = False # rim.unet.layers[13].trainable = False # rim.unet.layers[14].trainable = False # L1 with h5py.File( os.path.join( os.getenv("CENSAI_PATH"), "results", args.experiment_name + "_" + args.model + "_" + args.dataset + f"_{THIS_WORKER:03d}.h5"), 'w') as hf: hf.create_dataset(name="observation", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset(name="psf", shape=[data_len, 20, 20, 1], dtype=np.float32) hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32) hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32) hf.create_dataset( name="source", shape=[data_len, phys.src_pixels, 
phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1]) hf.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1 ], dtype=np.float32) hf.create_dataset( name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized", shape=[data_len], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum2", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum_reoptimized", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_frequencies", shape=[args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32) for batch, j in enumerate( range((THIS_WORKER - 1) * data_len, THIS_WORKER * data_len)): b = j // B k = j % B observation = dataset[b]["observation"][k][None, ...] source = dataset[b]["source"][k][None, ...] kappa = dataset[b]["kappa"][k][None, ...] noise_rms = np.array([dataset[b]["noise_rms"][k]]) psf = dataset[b]["psf"][k][None, ...] 
fwhm = dataset[b]["psf_fwhm"][k] checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial( ) # reset model weights # Compute predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( observation, noise_rms, psf) observation_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) # reset the seed for reproducible sampling in the VAE for EWC tf.random.set_seed(args.seed) np.random.seed(args.seed) # Initialize regularization term ewc = EWC(observation=observation, noise_rms=noise_rms, psf=psf, phys=phys, rim=rim, source_vae=source_vae, kappa_vae=kappa_vae, n_samples=args.sample_size, sigma_source=args.source_vae_ball_size, sigma_kappa=args.kappa_vae_ball_size) # Re-optimize weights of the model STEPS = args.re_optimize_steps learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase) optim = tf.keras.optimizers.SGD( learning_rate=learning_rate_schedule) chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) source_mse = tf.TensorArray(DTYPE, size=STEPS) kappa_mse = tf.TensorArray(DTYPE, size=STEPS) best = chi_squared[-1, 0] source_best = source_pred[-1] kappa_best = kappa_pred[-1] source_mse_best = tf.reduce_mean( (source_best - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_sum( wk(kappa) * (kappa_best - rim.kappa_inverse_link(kappa))**2) for current_step in tqdm(range(STEPS)): with tf.GradientTape() as tape: tape.watch(unet.trainable_variables) s, k, chi_sq = rim.call(observation, noise_rms, psf, outer_tape=tape) cost = tf.reduce_mean(chi_sq) # mean over time steps cost += tf.reduce_sum(rim.unet.losses) # L2 regularisation cost += args.lam_ewc * ewc.penalty( rim) # Elastic Weights Consolidation log_likelihood = chi_sq[-1] chi_squared_series = chi_squared_series.write( index=current_step, value=log_likelihood) source_o = s[-1] kappa_o = k[-1] source_mse = source_mse.write( index=current_step, value=tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2)) kappa_mse = kappa_mse.write( index=current_step, value=tf.reduce_sum( wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2)) if 2 * chi_sq[-1, 0] < 1.0 and args.early_stopping: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] source_mse_best = tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_sum( wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2) break if chi_sq[-1, 0] < best: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] source_mse_best = tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_sum( wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2) grads = tape.gradient(cost, unet.trainable_variables) optim.apply_gradients(zip(grads, unet.trainable_variables)) source_o = source_best kappa_o = kappa_best y_pred = phys.forward(source_o, kappa_o, psf) chi_sq_series = tf.transpose(chi_squared_series.stack(), perm=[1, 0]) source_mse = source_mse.stack()[None, ...] kappa_mse = kappa_mse.stack()[None, ...] 
# Compute Power spectrum of converged predictions _ps_observation = ps_observation.cross_correlation_coefficient( observation[..., 0], observation_pred[..., 0]) _ps_observation2 = ps_observation.cross_correlation_coefficient( observation[..., 0], y_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_kappa2 = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_o[..., 0])) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) _ps_source2 = ps_source.cross_correlation_coefficient( source[..., 0], source_o[..., 0]) # save results hf["observation"][batch] = observation.astype(np.float32) hf["psf"][batch] = psf.astype(np.float32) hf["psf_fwhm"][batch] = fwhm hf["noise_rms"][batch] = noise_rms.astype(np.float32) hf["source"][batch] = source.astype(np.float32) hf["kappa"][batch] = kappa.astype(np.float32) hf["observation_pred"][batch] = observation_pred.numpy().astype( np.float32) hf["observation_pred_reoptimized"][batch] = y_pred.numpy().astype( np.float32) hf["source_pred"][batch] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["source_pred_reoptimized"][batch] = source_o.numpy().astype( np.float32) hf["kappa_pred"][batch] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype( np.float32) hf["chi_squared"][batch] = 2 * tf.transpose( chi_squared).numpy().astype(np.float32) hf["chi_squared_reoptimized"][batch] = 2 * best.numpy().astype( np.float32) hf["chi_squared_reoptimized_series"][ batch] = 2 * chi_sq_series.numpy().astype(np.float32) hf["source_optim_mse"][batch] = source_mse_best.numpy().astype( np.float32) hf["source_optim_mse_series"][batch] = source_mse.numpy().astype( np.float32) hf["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype( np.float32) hf["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype( np.float32) hf["observation_coherence_spectrum"][batch] = _ps_observation hf["observation_coherence_spectrum_reoptimized"][ batch] = _ps_observation2 hf["source_coherence_spectrum"][batch] = _ps_source hf["source_coherence_spectrum_reoptimized"][batch] = _ps_source2 hf["kappa_coherence_spectrum"][batch] = _ps_kappa hf["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2 if batch == 0: _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_observation.bins) f = (f[:-1] + f[1:]) / 2 hf["observation_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 hf["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 hf["kappa_frequencies"][:] = f hf["kappa_fov"][0] = phys.kappa_fov hf["source_fov"][0] = phys.src_fov
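# The loop above shards work across workers by mapping a flat example index j to a
# (file, element) pair with b = j // B and k = j % B, where B is the per-file batch
# size read from the first h5 file. A minimal, self-contained sketch of that indexing
# (THIS_WORKER / N_WORKERS mirror the globals used above; the numbers below are
# illustrative only):
def worker_shard_indices(this_worker: int, n_workers: int, n_files: int, examples_per_file: int):
    """Yield the (file_index, element_index) pairs owned by `this_worker` (1-based)."""
    data_len = n_files * examples_per_file // n_workers
    for j in range((this_worker - 1) * data_len, this_worker * data_len):
        yield j // examples_per_file, j % examples_per_file

# Example: 4 files of 10 examples split across 5 workers gives 8 examples per worker,
# and worker 2 starts at the 9th example of the first file.
assert next(worker_shard_indices(2, 5, 4, 10)) == (0, 8)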
def __init__(
        self,
        physical_model: PhysicalModel,
        unet: Model,
        steps: int,
        adam=True,
        rmsprop=False,  # overwrites ADAM with RMSProp
        kappalog=True,
        kappa_normalize=False,
        source_link="relu",
        beta_1=0.9,
        beta_2=0.99,
        epsilon=1e-8,
        flux_lagrange_multiplier: float = 0.):
    self.physical_model = physical_model
    self.pixels = physical_model.kappa_pixels
    self.unet = unet
    self.steps = steps
    self.adam = adam
    self.kappalog = kappalog
    self._source_link_func = source_link
    self.kappa_normalize = kappa_normalize
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.flux_lagrange_multiplier = flux_lagrange_multiplier

    if self.kappalog:
        if self.kappa_normalize:
            self.kappa_inverse_link = tf.keras.layers.Lambda(
                lambda x: logkappa_normalization(log_10(x), forward=True))
            self.kappa_link = tf.keras.layers.Lambda(
                lambda x: 10**(logkappa_normalization(x, forward=False)))
        else:
            self.kappa_inverse_link = tf.keras.layers.Lambda(lambda x: log_10(x))
            self.kappa_link = tf.keras.layers.Lambda(lambda x: 10**x)
    else:
        self.kappa_link = tf.identity
        self.kappa_inverse_link = tf.identity

    if self._source_link_func == "exp":
        self.source_inverse_link = tf.keras.layers.Lambda(
            lambda x: tf.math.log(x + 1e-6))
        self.source_link = tf.keras.layers.Lambda(lambda x: tf.math.exp(x))
    elif self._source_link_func == "identity":
        self.source_inverse_link = tf.identity
        self.source_link = tf.identity
    elif self._source_link_func == "relu":
        self.source_inverse_link = tf.identity
        self.source_link = tf.nn.relu
    elif self._source_link_func == "sigmoid":
        self.source_inverse_link = logit
        self.source_link = tf.nn.sigmoid
    elif self._source_link_func == "leaky_relu":
        self.source_inverse_link = tf.identity
        self.source_link = tf.nn.leaky_relu
    elif self._source_link_func == "lrelu4p":
        self.source_inverse_link = tf.identity
        self.source_link = lrelu4p
    else:
        raise NotImplementedError(
            f"{source_link} not in ['exp', 'identity', 'relu', 'leaky_relu', 'lrelu4p', 'sigmoid']")

    if rmsprop:
        self.grad_update = self.rmsprop_grad_update
    elif adam:
        self.grad_update = self.adam_grad_update
    else:
        self.grad_update = lambda x, y, t: (x, y)
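# In the kappalog branch above, kappa_inverse_link maps kappa to log10-space for the
# network and kappa_link maps network output back with 10**x, so the two are inverses
# on positive kappa maps. A minimal numpy sketch of that round trip (np.log10 stands in
# for the censai log_10 helper; the values are illustrative):
import numpy as np

kappa = np.array([0.1, 1.0, 10.0])
latent = np.log10(kappa)      # kappa_inverse_link: image space -> network space
recovered = 10.0 ** latent    # kappa_link: network space -> image space
assert np.allclose(recovered, kappa)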
def distributed_strategy(args): psf_pixels = 20 pixels = 128 model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model) ps_observation = PowerSpectrum(bins=args.observation_coherence_bins, pixels=pixels) ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=pixels) ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=pixels) phys = PhysicalModel( pixels=pixels, kappa_pixels=pixels, src_pixels=pixels, image_fov=7.68, kappa_fov=7.68, src_fov=3., method="fft", ) with open(os.path.join(model, "unet_hparams.json")) as f: unet_params = json.load(f) unet_params["kernel_l2_amp"] = args.l2_amp unet = Model(**unet_params) ckpt = tf.train.Checkpoint(net=unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1) checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial() with open(os.path.join(model, "rim_hparams.json")) as f: rim_params = json.load(f) rim_params["source_link"] = "relu" rim = RIM(phys, unet, **rim_params) kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.kappa_vae) with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f: kappa_vae_hparams = json.load(f) kappa_vae = VAE(**kappa_vae_hparams) ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae) checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1) checkpoint_manager1.checkpoint.restore( checkpoint_manager1.latest_checkpoint).expect_partial() svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.source_vae) with open(os.path.join(svae_path, "model_hparams.json"), "r") as f: source_vae_hparams = json.load(f) source_vae = VAE(**source_vae_hparams) ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae) checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1) checkpoint_manager2.checkpoint.restore( checkpoint_manager2.latest_checkpoint).expect_partial() model_name = os.path.split(model)[-1] wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum( tf.sqrt(k), axis=(1, 2, 3), keepdims=True)) with h5py.File( os.path.join( os.getenv("CENSAI_PATH"), "results", args.experiment_name + "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf: data_len = args.size // N_WORKERS hf.create_dataset(name="observation", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset(name="psf", shape=[data_len, psf_pixels, psf_pixels, 1], dtype=np.float32) hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32) hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32) hf.create_dataset( name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) hf.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1 ], dtype=np.float32) hf.create_dataset( name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="chi_squared", 
shape=[data_len, rim.steps], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized", shape=[data_len, rim.steps], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, rim.steps, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="sampled_chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="sampled_source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="sampled_kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="latent_kappa_gt_distance_init", shape=[data_len, kappa_vae.latent_size], dtype=np.float32) hf.create_dataset(name="latent_source_gt_distance_init", shape=[data_len, source_vae.latent_size], dtype=np.float32) hf.create_dataset(name="latent_kappa_gt_distance_end", shape=[data_len, kappa_vae.latent_size], dtype=np.float32) hf.create_dataset(name="latent_source_gt_distance_end", shape=[data_len, source_vae.latent_size], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum_reoptimized", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_frequencies", shape=[args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32) for i in range(data_len): checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial( ) # reset model weights # Produce an observation kappa = 10**kappa_vae.sample(1) source = tf.nn.relu(source_vae.sample(1)) source /= tf.reduce_max(source, axis=(1, 2, 3), keepdims=True) noise_rms = 10**tf.random.uniform(shape=[1], minval=-2.5, maxval=-1) fwhm = tf.random.uniform(shape=[1], minval=0.06, maxval=0.3) psf = phys.psf_models(fwhm, cutout_size=psf_pixels) observation = phys.noisy_forward(source, kappa, noise_rms, psf) # RIM predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( observation, noise_rms, psf) observation_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) source_o = source_pred[-1] kappa_o = kappa_pred[-1] # Latent code of model predictions 
z_source, _ = source_vae.encoder(source_o) z_kappa, _ = kappa_vae.encoder(log_10(kappa_o)) # Ground truth latent code for oracle metrics z_source_gt, _ = source_vae.encoder(source) z_kappa_gt, _ = kappa_vae.encoder(log_10(kappa)) # Re-optimize weights of the model STEPS = args.re_optimize_steps learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase) optim = tf.keras.optimizers.RMSprop( learning_rate=learning_rate_schedule) chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) source_mse = tf.TensorArray(DTYPE, size=STEPS) kappa_mse = tf.TensorArray(DTYPE, size=STEPS) sampled_chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) sampled_source_mse = tf.TensorArray(DTYPE, size=STEPS) sampled_kappa_mse = tf.TensorArray(DTYPE, size=STEPS) best = chi_squared source_best = source_pred[-1] kappa_best = kappa_pred[-1] source_mse_best = tf.reduce_mean((source_best - source)**2) kappa_mse_best = tf.reduce_mean((kappa_best - log_10(kappa))**2) # ===================== Optimization ============================== for current_step in tqdm(range(STEPS)): # ===================== VAE SAMPLING ============================== # L1 distance with ground truth in latent space -- this is changed by an user defined value when using real data # z_source_std = tf.abs(z_source - z_source_gt) # z_kappa_std = tf.abs(z_kappa - z_kappa_gt) z_source_std = args.source_vae_ball_size z_kappa_std = args.kappa_vae_ball_size # Sample latent code, then decode and forward z_s = tf.random.normal( shape=[args.sample_size, source_vae.latent_size], mean=z_source, stddev=z_source_std) z_k = tf.random.normal( shape=[args.sample_size, kappa_vae.latent_size], mean=z_kappa, stddev=z_kappa_std) sampled_source = tf.nn.relu(source_vae.decode(z_s)) sampled_source /= tf.reduce_max(sampled_source, axis=(1, 2, 3), keepdims=True) sampled_kappa = kappa_vae.decode(z_k) # output in log_10 space sampled_observation = phys.noisy_forward( sampled_source, 10**sampled_kappa, noise_rms, tf.tile(psf, [args.sample_size, 1, 1, 1])) with tf.GradientTape() as tape: tape.watch(unet.trainable_variables) s, k, chi_sq = rim.call( sampled_observation, noise_rms, tf.tile(psf, [args.sample_size, 1, 1, 1]), outer_tape=tape) _kappa_mse = tf.reduce_sum(wk(10**sampled_kappa) * (k - sampled_kappa)**2, axis=(2, 3, 4)) cost = tf.reduce_mean(_kappa_mse) cost += tf.reduce_mean((s - sampled_source)**2) cost += tf.reduce_sum(rim.unet.losses) # weight decay grads = tape.gradient(cost, unet.trainable_variables) optim.apply_gradients(zip(grads, unet.trainable_variables)) # Record performance on sampled dataset sampled_chi_squared_series = sampled_chi_squared_series.write( index=current_step, value=tf.squeeze(tf.reduce_mean(chi_sq[-1]))) sampled_source_mse = sampled_source_mse.write( index=current_step, value=tf.reduce_mean((s[-1] - sampled_source)**2)) sampled_kappa_mse = sampled_kappa_mse.write( index=current_step, value=tf.reduce_mean((k[-1] - sampled_kappa)**2)) # Record model prediction on data s, k, chi_sq = rim.call(observation, noise_rms, psf) chi_squared_series = chi_squared_series.write( index=current_step, value=tf.squeeze(chi_sq)) source_o = s[-1] kappa_o = k[-1] # oracle metrics, remove when using real data source_mse = source_mse.write(index=current_step, value=tf.reduce_mean( (source_o - source)**2)) kappa_mse = kappa_mse.write(index=current_step, value=tf.reduce_mean( (kappa_o - log_10(kappa))**2)) if abs(chi_sq[-1, 0] - 1) 
< abs(best[-1, 0] - 1): source_best = tf.nn.relu(source_o) kappa_best = 10**kappa_o best = chi_sq source_mse_best = tf.reduce_mean((source_best - source)**2) kappa_mse_best = tf.reduce_mean( (kappa_best - log_10(kappa))**2) source_o = source_best kappa_o = kappa_best y_pred = phys.forward(source_o, kappa_o, psf) chi_sq_series = tf.transpose(chi_squared_series.stack()) source_mse = source_mse.stack() kappa_mse = kappa_mse.stack() sampled_chi_squared_series = sampled_chi_squared_series.stack() sampled_source_mse = sampled_source_mse.stack() sampled_kappa_mse = sampled_kappa_mse.stack() # Latent code of optimized model predictions z_source_opt, _ = source_vae.encoder(tf.nn.relu(source_o)) z_kappa_opt, _ = kappa_vae.encoder(log_10(kappa_o)) # Compute Power spectrum of converged predictions _ps_observation = ps_observation.cross_correlation_coefficient( observation[..., 0], observation_pred[..., 0]) _ps_observation2 = ps_observation.cross_correlation_coefficient( observation[..., 0], y_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_kappa2 = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_o[..., 0])) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) _ps_source2 = ps_source.cross_correlation_coefficient( source[..., 0], source_o[..., 0]) # save results hf["observation"][i] = observation.numpy().astype(np.float32) hf["psf"][i] = psf.numpy().astype(np.float32) hf["psf_fwhm"][i] = fwhm.numpy().astype(np.float32) hf["noise_rms"][i] = noise_rms.numpy().astype(np.float32) hf["source"][i] = source.numpy().astype(np.float32) hf["kappa"][i] = kappa.numpy().astype(np.float32) hf["observation_pred"][i] = observation_pred.numpy().astype( np.float32) hf["observation_pred_reoptimized"][i] = y_pred.numpy().astype( np.float32) hf["source_pred"][i] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["source_pred_reoptimized"][i] = source_o.numpy().astype( np.float32) hf["kappa_pred"][i] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["kappa_pred_reoptimized"][i] = kappa_o.numpy().astype( np.float32) hf["chi_squared"][i] = tf.squeeze(chi_squared).numpy().astype( np.float32) hf["chi_squared_reoptimized"][i] = tf.squeeze(best).numpy().astype( np.float32) hf["chi_squared_reoptimized_series"][i] = chi_sq_series.numpy( ).astype(np.float32) hf["sampled_chi_squared_reoptimized_series"][ i] = 2 * sampled_chi_squared_series.numpy().astype(np.float32) hf["source_optim_mse"][i] = source_mse_best.numpy().astype( np.float32) hf["source_optim_mse_series"][i] = source_mse.numpy().astype( np.float32) hf["sampled_source_optim_mse_series"][ i] = sampled_source_mse.numpy().astype(np.float32) hf["kappa_optim_mse"][i] = kappa_mse_best.numpy().astype( np.float32) hf["kappa_optim_mse_series"][i] = kappa_mse.numpy().astype( np.float32) hf["sampled_kappa_optim_mse_series"][i] = sampled_kappa_mse.numpy( ).astype(np.float32) hf["latent_source_gt_distance_init"][i] = tf.abs( z_source - z_source_gt).numpy().squeeze().astype(np.float32) hf["latent_kappa_gt_distance_init"][i] = tf.abs( z_kappa - z_kappa_gt).numpy().squeeze().astype(np.float32) hf["latent_source_gt_distance_end"][i] = tf.abs( z_source_opt - z_source_gt).numpy().squeeze().astype( np.float32) hf["latent_kappa_gt_distance_end"][i] = tf.abs( z_kappa_opt - z_kappa_gt).numpy().squeeze().astype(np.float32) hf["observation_coherence_spectrum"][i] = _ps_observation 
hf["observation_coherence_spectrum_reoptimized"][ i] = _ps_observation2 hf["source_coherence_spectrum"][i] = _ps_source hf["source_coherence_spectrum_reoptimized"][i] = _ps_source2 hf["kappa_coherence_spectrum"][i] = _ps_kappa hf["kappa_coherence_spectrum_reoptimized"][i] = _ps_kappa2 if i == 0: _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_observation.bins) f = (f[:-1] + f[1:]) / 2 hf["observation_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 hf["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 hf["kappa_frequencies"][:] = f hf["kappa_fov"][0] = phys.kappa_fov hf["source_fov"][0] = phys.src_fov
def distributed_strategy(args): tf.random.set_seed(args.seed) np.random.seed(args.seed) model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model) files = glob.glob( os.path.join(os.getenv('CENSAI_PATH'), "data", args.train_dataset, "*.tfrecords")) files = tf.data.Dataset.from_tensor_slices(files) train_dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type).shuffle(len(files)), block_length=1, num_parallel_calls=tf.data.AUTOTUNE) # Read off global parameters from first example in dataset for physical_params in train_dataset.map(decode_physical_model_info): break train_dataset = train_dataset.map(decode_results).shuffle( buffer_size=args.buffer_size) files = glob.glob( os.path.join(os.getenv('CENSAI_PATH'), "data", args.val_dataset, "*.tfrecords")) files = tf.data.Dataset.from_tensor_slices(files) val_dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type).shuffle(len(files)), block_length=1, num_parallel_calls=tf.data.AUTOTUNE) val_dataset = val_dataset.map(decode_results).shuffle( buffer_size=args.buffer_size) files = glob.glob( os.path.join(os.getenv('CENSAI_PATH'), "data", args.test_dataset, "*.tfrecords")) files = tf.data.Dataset.from_tensor_slices(files) test_dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type).shuffle(len(files)), block_length=1, num_parallel_calls=tf.data.AUTOTUNE) test_dataset = test_dataset.map(decode_results).shuffle( buffer_size=args.buffer_size) ps_lens = PowerSpectrum(bins=args.lens_coherence_bins, pixels=physical_params["pixels"].numpy()) ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=physical_params["src pixels"].numpy()) ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=physical_params["kappa pixels"].numpy()) phys = PhysicalModel( pixels=physical_params["pixels"].numpy(), kappa_pixels=physical_params["kappa pixels"].numpy(), src_pixels=physical_params["src pixels"].numpy(), image_fov=physical_params["image fov"].numpy(), kappa_fov=physical_params["kappa fov"].numpy(), src_fov=physical_params["source fov"].numpy(), method="fft", ) phys_sie = AnalyticalPhysicalModel( pixels=physical_params["pixels"].numpy(), image_fov=physical_params["image fov"].numpy(), src_fov=physical_params["source fov"].numpy()) with open(os.path.join(model, "unet_hparams.json")) as f: unet_params = json.load(f) unet_params["kernel_l2_amp"] = args.l2_amp unet = Model(**unet_params) ckpt = tf.train.Checkpoint(net=unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1) checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial() with open(os.path.join(model, "rim_hparams.json")) as f: rim_params = json.load(f) rim = RIM(phys, unet, **rim_params) dataset_names = [args.train_dataset, args.val_dataset, args.test_dataset] dataset_shapes = [args.train_size, args.val_size, args.test_size] model_name = os.path.split(model)[-1] # from censai.utils import nulltape # def call_with_mask(self, lensed_image, noise_rms, psf, mask, outer_tape=nulltape): # """ # Used in training. Return linked kappa and source maps. 
# """ # batch_size = lensed_image.shape[0] # source, kappa, source_grad, kappa_grad, states = self.initial_states(batch_size) # initiate all tensors to 0 # source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad, # states) # Use lens to make an initial guess with Unet # source_series = tf.TensorArray(DTYPE, size=self.steps) # kappa_series = tf.TensorArray(DTYPE, size=self.steps) # chi_squared_series = tf.TensorArray(DTYPE, size=self.steps) # # record initial guess # source_series = source_series.write(index=0, value=source) # kappa_series = kappa_series.write(index=0, value=kappa) # # Main optimization loop # for current_step in tf.range(self.steps - 1): # with outer_tape.stop_recording(): # with tf.GradientTape() as g: # g.watch(source) # g.watch(kappa) # y_pred = self.physical_model.forward(self.source_link(source), self.kappa_link(kappa), psf) # flux_term = tf.square( # tf.reduce_sum(y_pred, axis=(1, 2, 3)) - tf.reduce_sum(lensed_image, axis=(1, 2, 3))) # log_likelihood = 0.5 * tf.reduce_sum( # tf.square(y_pred - mask * lensed_image) / noise_rms[:, None, None, None] ** 2, axis=(1, 2, 3)) # cost = tf.reduce_mean(log_likelihood + self.flux_lagrange_multiplier * flux_term) # source_grad, kappa_grad = g.gradient(cost, [source, kappa]) # source_grad, kappa_grad = self.grad_update(source_grad, kappa_grad, current_step) # source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad, states) # source_series = source_series.write(index=current_step + 1, value=source) # kappa_series = kappa_series.write(index=current_step + 1, value=kappa) # chi_squared_series = chi_squared_series.write(index=current_step, # value=log_likelihood / self.pixels ** 2) # renormalize chi squared here # # last step score # log_likelihood = self.physical_model.log_likelihood(y_true=lensed_image, source=self.source_link(source), # kappa=self.kappa_link(kappa), psf=psf, noise_rms=noise_rms) # chi_squared_series = chi_squared_series.write(index=self.steps - 1, value=log_likelihood) # return source_series.stack(), kappa_series.stack(), chi_squared_series.stack() with h5py.File( os.path.join( os.getenv("CENSAI_PATH"), "results", args.experiment_name + "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf: for i, dataset in enumerate([train_dataset, val_dataset, test_dataset]): g = hf.create_group(f'{dataset_names[i]}') data_len = dataset_shapes[i] // N_WORKERS g.create_dataset(name="lens", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="psf", shape=[ data_len, physical_params['psf pixels'], physical_params['psf pixels'], 1 ], dtype=np.float32) g.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32) g.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32) g.create_dataset( name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) g.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="source_pred", shape=[ data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1 ], dtype=np.float32) g.create_dataset( name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1]) g.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, 
phys.kappa_pixels, 1 ], dtype=np.float32) g.create_dataset( name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) g.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32) g.create_dataset(name="chi_squared_reoptimized", shape=[data_len], dtype=np.float32) g.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) g.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32) g.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) g.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32) g.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum2", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum_reoptimized", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_frequencies", shape=[args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) g.create_dataset(name="source_fov", shape=[1], dtype=np.float32) g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32) dataset = dataset.skip(data_len * (THIS_WORKER - 1)).take(data_len) for batch, (lens, source, kappa, noise_rms, psf, fwhm) in enumerate( dataset.batch(1).prefetch( tf.data.experimental.AUTOTUNE)): checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial( ) # reset model weights # Compute predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( lens, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) # Re-optimize weights of the model STEPS = args.re_optimize_steps learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase) optim = tf.keras.optimizers.RMSprop( learning_rate=learning_rate_schedule) chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) source_mse = tf.TensorArray(DTYPE, size=STEPS) kappa_mse = tf.TensorArray(DTYPE, size=STEPS) best = chi_squared[-1, 0] # best = abs(2*chi_squared[-1, 0] - 1) # best_chisq = 2*chi_squared[-1, 0] source_best = source_pred[-1] kappa_best = kappa_pred[-1] # source_mean = source_pred[-1] # kappa_mean = rim.kappa_link(kappa_pred[-1]) # source_std = tf.zeros_like(source_mean) # kappa_std = tf.zeros_like(kappa_mean) # 
counter = 0 for current_step in tqdm(range(STEPS)): with tf.GradientTape() as tape: tape.watch(unet.trainable_variables) # s, k, chi_sq = call_with_mask(rim, lens, noise_rms, psf, mask, tape) s, k, chi_sq = rim.call(lens, noise_rms, psf, outer_tape=tape) cost = tf.reduce_mean(chi_sq) # mean over time steps cost += tf.reduce_sum(rim.unet.losses) log_likelihood = chi_sq[-1] chi_squared_series = chi_squared_series.write( index=current_step, value=log_likelihood) source_o = s[-1] kappa_o = k[-1] source_mse = source_mse.write( index=current_step, value=tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2)) kappa_mse = kappa_mse.write( index=current_step, value=tf.reduce_mean( (kappa_o - rim.kappa_inverse_link(kappa))**2)) if chi_sq[-1, 0] < args.converged_chisq: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] break if chi_sq[-1, 0] < best: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] source_mse_best = tf.reduce_mean( (source_best - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_mean( (kappa_best - rim.kappa_inverse_link(kappa))**2) # if counter > 0: # # Welford's online algorithm # # source # delta = source_o - source_mean # source_mean = (counter * source_mean + (counter + 1) * source_o)/(counter + 1) # delta2 = source_o - source_mean # source_std += delta * delta2 # # kappa # delta = rim.kappa_link(kappa_o) - kappa_mean # kappa_mean = (counter * kappa_mean + (counter + 1) * rim.kappa_link(kappa_o)) / (counter + 1) # delta2 = rim.kappa_link(kappa_o) - kappa_mean # kappa_std += delta * delta2 # if best_chisq < args.converged_chisq: # counter += 1 # if counter == args.window: # break # if 2*chi_sq[-1, 0] < best_chisq: # best_chisq = 2*chi_sq[-1, 0] # if abs(2*chi_sq[-1, 0] - 1) < best: # source_best = rim.source_link(source_o) # kappa_best = rim.kappa_link(kappa_o) # best = abs(2 * chi_squared[-1, 0] - 1) # source_mse_best = tf.reduce_mean((source_best - rim.source_inverse_link(source)) ** 2) # kappa_mse_best = tf.reduce_mean((kappa_best - rim.kappa_inverse_link(kappa)) ** 2) grads = tape.gradient(cost, unet.trainable_variables) optim.apply_gradients(zip(grads, unet.trainable_variables)) source_o = source_best kappa_o = kappa_best y_pred = phys.forward(source_o, kappa_o, psf) chi_sq_series = tf.transpose(chi_squared_series.stack(), perm=[1, 0]) source_mse = source_mse.stack()[None, ...] kappa_mse = kappa_mse.stack()[None, ...] 
# kappa_std /= float(args.window) # source_std /= float(args.window) # Compute Power spectrum of converged predictions _ps_lens = ps_lens.cross_correlation_coefficient( lens[..., 0], lens_pred[..., 0]) _ps_lens3 = ps_lens.cross_correlation_coefficient( lens[..., 0], y_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_kappa2 = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_o[..., 0])) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) _ps_source3 = ps_source.cross_correlation_coefficient( source[..., 0], source_o[..., 0]) # save results g["lens"][batch] = lens.numpy().astype(np.float32) g["psf"][batch] = psf.numpy().astype(np.float32) g["psf_fwhm"][batch] = fwhm.numpy().astype(np.float32) g["noise_rms"][batch] = noise_rms.numpy().astype(np.float32) g["source"][batch] = source.numpy().astype(np.float32) g["kappa"][batch] = kappa.numpy().astype(np.float32) g["lens_pred"][batch] = lens_pred.numpy().astype(np.float32) g["lens_pred_reoptimized"][batch] = y_pred.numpy().astype( np.float32) g["source_pred"][batch] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["source_pred_reoptimized"][batch] = source_o.numpy().astype( np.float32) g["kappa_pred"][batch] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype( np.float32) g["chi_squared"][batch] = tf.transpose( chi_squared).numpy().astype(np.float32) g["chi_squared_reoptimized"][batch] = best.numpy().astype( np.float32) g["chi_squared_reoptimized_series"][ batch] = chi_sq_series.numpy().astype(np.float32) g["source_optim_mse"][batch] = source_mse_best.numpy().astype( np.float32) g["source_optim_mse_series"][batch] = source_mse.numpy( ).astype(np.float32) g["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype( np.float32) g["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype( np.float32) g["lens_coherence_spectrum"][batch] = _ps_lens g["lens_coherence_spectrum_reoptimized"][batch] = _ps_lens3 g["source_coherence_spectrum"][batch] = _ps_source g["source_coherence_spectrum_reoptimized"][batch] = _ps_source3 g["lens_coherence_spectrum"][batch] = _ps_lens g["lens_coherence_spectrum"][batch] = _ps_lens g["kappa_coherence_spectrum"][batch] = _ps_kappa g["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2 if batch == 0: _, f = np.histogram(np.fft.fftfreq( phys.pixels)[:phys.pixels // 2], bins=ps_lens.bins) f = (f[:-1] + f[1:]) / 2 g["lens_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 g["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 g["kappa_frequencies"][:] = f g["kappa_fov"][0] = phys.kappa_fov g["source_fov"][0] = phys.src_fov # Create SIE test g = hf.create_group(f'SIE_test') data_len = args.sie_size // N_WORKERS sie_dataset = test_dataset.skip(data_len * (THIS_WORKER - 1)).take(data_len) g.create_dataset(name="lens", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="psf", shape=[ data_len, physical_params['psf pixels'], physical_params['psf pixels'], 1 ], dtype=np.float32) g.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32) g.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32) 
g.create_dataset(name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) g.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred2", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset( name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) g.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1 ], dtype=np.float32) g.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum2", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_frequencies", shape=[args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="einstein_radius", shape=[data_len], dtype=np.float32) g.create_dataset(name="position", shape=[data_len, 2], dtype=np.float32) g.create_dataset(name="orientation", shape=[data_len], dtype=np.float32) g.create_dataset(name="ellipticity", shape=[data_len], dtype=np.float32) g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) g.create_dataset(name="source_fov", shape=[1], dtype=np.float32) g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32) for batch, (_, source, _, noise_rms, psf, fwhm) in enumerate( sie_dataset.take(data_len).batch(args.batch_size).prefetch( tf.data.experimental.AUTOTUNE)): batch_size = source.shape[0] # Create some SIE kappa maps _r = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=args.max_shift) _theta = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=-np.pi, maxval=np.pi) x0 = _r * tf.math.cos(_theta) y0 = _r * tf.math.sin(_theta) ellipticity = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=args.max_ellipticity) phi = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=-np.pi, maxval=np.pi) einstein_radius = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=args.min_theta_e, maxval=args.max_theta_e) kappa = phys_sie.kappa_field(x0=x0, y0=y0, e=ellipticity, phi=phi, r_ein=einstein_radius) lens = phys.noisy_forward(source, kappa, noise_rms=noise_rms, psf=psf) # Compute predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( lens, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) # Compute Power spectrum of converged predictions _ps_lens = ps_lens.cross_correlation_coefficient( lens[..., 0], lens_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) # save results i_begin = batch * 
args.batch_size i_end = i_begin + batch_size g["lens"][i_begin:i_end] = lens.numpy().astype(np.float32) g["psf"][i_begin:i_end] = psf.numpy().astype(np.float32) g["psf_fwhm"][i_begin:i_end] = fwhm.numpy().astype(np.float32) g["noise_rms"][i_begin:i_end] = noise_rms.numpy().astype( np.float32) g["source"][i_begin:i_end] = source.numpy().astype(np.float32) g["kappa"][i_begin:i_end] = kappa.numpy().astype(np.float32) g["lens_pred"][i_begin:i_end] = lens_pred.numpy().astype( np.float32) g["source_pred"][i_begin:i_end] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["kappa_pred"][i_begin:i_end] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["chi_squared"][i_begin:i_end] = 2 * tf.transpose( chi_squared).numpy().astype(np.float32) g["lens_coherence_spectrum"][i_begin:i_end] = _ps_lens.numpy( ).astype(np.float32) g["source_coherence_spectrum"][i_begin:i_end] = _ps_source.numpy( ).astype(np.float32) g["kappa_coherence_spectrum"][i_begin:i_end] = _ps_kappa.numpy( ).astype(np.float32) g["einstein_radius"][ i_begin:i_end] = einstein_radius[:, 0, 0, 0].numpy().astype(np.float32) g["position"][i_begin:i_end] = tf.stack( [x0[:, 0, 0, 0], y0[:, 0, 0, 0]], axis=1).numpy().astype(np.float32) g["ellipticity"][i_begin:i_end] = ellipticity[:, 0, 0, 0].numpy().astype( np.float32) g["orientation"][i_begin:i_end] = phi[:, 0, 0, 0].numpy().astype(np.float32) if batch == 0: _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_lens.bins) f = (f[:-1] + f[1:]) / 2 g["lens_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 g["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 g["kappa_frequencies"][:] = f g["kappa_fov"][0] = phys.kappa_fov g["source_fov"][0] = phys.src_fov
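# The SIE test above draws lens centres in polar coordinates: a radius uniform in
# [0, max_shift] and an angle uniform in [-pi, pi], then x0 = r*cos(theta) and
# y0 = r*sin(theta). A minimal numpy sketch of that draw (max_shift mirrors the
# --max_shift argument; note that a uniform radius concentrates centres toward the
# origin rather than uniformly over the disk):
import numpy as np

rng = np.random.default_rng(0)
max_shift, batch_size = 0.3, 4
r = rng.uniform(0.0, max_shift, size=(batch_size, 1, 1, 1))
theta = rng.uniform(-np.pi, np.pi, size=(batch_size, 1, 1, 1))
x0, y0 = r * np.cos(theta), r * np.sin(theta)
assert np.all(np.hypot(x0, y0) <= max_shift)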
def __init__( self, pixels, filter_scaling=1, layers=4, block_conv_layers=2, kernel_size=3, filters=32, strides=2, bottleneck_filters=None, resampling_kernel_size=None, input_kernel_size=11, upsampling_interpolation=False, # use strided transposed convolution if false kernel_regularizer_amp=0., bias_regularizer_amp=0., # if bias is used activation="linear", initializer="random_uniform", use_bias=False, kappalog=True, normalize=False, trainable=True, input_layer=False, name="ray_tracer", ): super(RayTracer, self).__init__(name=name) self.trainable = trainable self.kappalog = kappalog self.kappa_normalize = normalize common_params = { "padding": "same", "kernel_initializer": initializer, "data_format": "channels_last", "use_bias": use_bias, "kernel_regularizer": tf.keras.regularizers.L2(l2=kernel_regularizer_amp) } if use_bias: common_params.update({ "bias_regularizer": tf.keras.regularizers.L2(l2=bias_regularizer_amp) }) resampling_kernel_size = resampling_kernel_size if resampling_kernel_size is not None else kernel_size bottleneck_filters = bottleneck_filters if bottleneck_filters is not None else int( filter_scaling**(layers) * filters) activation = get_activation(activation) # compute size of bottleneck here bottleneck_size = pixels // strides**(layers) self.encoding_layers = [] self.decoding_layers = [] for i in range(layers): self.encoding_layers.append( UnetEncodingLayer( kernel_size=kernel_size, downsampling_kernel_size=resampling_kernel_size, filters=int(filter_scaling**(i) * filters), strides=strides, conv_layers=block_conv_layers, activation=activation, **common_params)) self.decoding_layers.append( UnetDecodingLayer( kernel_size=kernel_size, upsampling_kernel_size=resampling_kernel_size, filters=int(filter_scaling**(i) * filters), conv_layers=block_conv_layers, strides=strides, activation=activation, bilinear=upsampling_interpolation, **common_params)) # reverse decoding layers order self.decoding_layers = self.decoding_layers[::-1] self.bottleneck_layer1 = tf.keras.layers.Conv2D( filters=bottleneck_filters, kernel_size=2 * bottleneck_size, # we perform a convolution over the full image at this point, activation="linear", **common_params) self.bottleneck_layer2 = tf.keras.layers.Conv2D( filters=bottleneck_filters, kernel_size=2 * bottleneck_size, activation="linear", **common_params) self.output_layer = tf.keras.layers.Conv2D(filters=2, kernel_size=(1, 1), activation="linear", **common_params) if input_layer is True: self.input_layer = tf.keras.layers.Conv2D( filters=filters, kernel_size=input_kernel_size, activation="linear", **common_params) else: self.input_layer = tf.identity if self.kappalog: if self.kappa_normalize: self.kappa_link = tf.keras.layers.Lambda( lambda x: log_10(logkappa_normalization(x, forward=True))) else: self.kappa_link = tf.keras.layers.Lambda(lambda x: log_10(x)) else: self.kappa_link = tf.identity
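# In the constructor above the encoder halves the map `layers` times, so the bottleneck
# feature map has side pixels // strides**layers, and the bottleneck convolutions use a
# kernel of twice that size so a single filter covers the whole (padded) map. A minimal
# sketch of that size bookkeeping (the numbers are illustrative):
def bottleneck_side(pixels: int, strides: int, layers: int) -> int:
    """Side length of the encoder output after `layers` stride-`strides` downsamplings."""
    return pixels // strides ** layers

assert bottleneck_side(pixels=128, strides=2, layers=4) == 8   # bottleneck kernel would be 16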
def __init__(self, observation, noise_rms, psf, phys: PhysicalModel, rim: RIM,
             source_vae: VAE, kappa_vae: VAE, n_samples=100,
             sigma_source=0.5, sigma_kappa=0.5):
    r"""
    Make a copy of the initial parameters \varphi^{(0)} and compute the Fisher diagonal F_{ii}.
    """
    wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(
        tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
    # Baseline prediction from the observation
    source_pred, kappa_pred, chi_squared = rim.predict(observation, noise_rms, psf)
    # Latent codes of the model predictions
    z_source, _ = source_vae.encoder(source_pred[-1])
    z_kappa, _ = kappa_vae.encoder(log_10(kappa_pred[-1]))
    # Deep copy of the initial parameters
    self.initial_params = [deepcopy(w) for w in rim.unet.trainable_variables]
    self.fisher_diagonal = [tf.zeros_like(w) for w in self.initial_params]
    for n in range(n_samples):
        # Sample latent codes around the prediction mean
        z_s = tf.random.normal(shape=[1, source_vae.latent_size],
                               mean=z_source, stddev=sigma_source)
        z_k = tf.random.normal(shape=[1, kappa_vae.latent_size],
                               mean=z_kappa, stddev=sigma_kappa)
        # Decode
        sampled_source = tf.nn.relu(source_vae.decode(z_s))
        sampled_source /= tf.reduce_max(sampled_source, axis=(1, 2, 3), keepdims=True)
        sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
        # Simulate an observation
        sampled_observation = phys.noisy_forward(sampled_source, 10**sampled_kappa,
                                                 noise_rms, psf)
        # Compute the gradient of the MSE
        with tf.GradientTape() as tape:
            tape.watch(rim.unet.trainable_variables)
            s, k, chi_squared = rim.call(sampled_observation, noise_rms, psf)
            # Remove the temperature from the loss when computing the Fisher:
            # sum instead of mean, and the weighted sum is renormalized by the number of pixels.
            _kappa_mse = phys.kappa_pixels**2 * tf.reduce_sum(
                wk(10**sampled_kappa) * (k - sampled_kappa)**2, axis=(2, 3, 4))
            cost = tf.reduce_sum(_kappa_mse)
            cost += tf.reduce_sum((s - sampled_source)**2)
        grad = tape.gradient(cost, rim.unet.trainable_variables)
        # Square the gradients at the initial parameters and accumulate the Monte Carlo average
        self.fisher_diagonal = [
            F + g**2 / n_samples for F, g in zip(self.fisher_diagonal, grad)
        ]
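# The constructor above stores a copy of the initial weights and a diagonal Fisher
# estimate; the fine-tuning script then adds `args.lam_ewc * ewc.penalty(rim)` to its
# cost. A minimal sketch of how such a quadratic EWC penalty is typically assembled
# from those two lists (an assumed form shown for illustration, not a copy of the
# repository's `penalty` method):
import tensorflow as tf

def ewc_penalty_sketch(fisher_diagonal, current_weights, initial_weights):
    """Sum of F_i * (w_i - w_i^(0))**2 accumulated over every trainable tensor."""
    return tf.add_n([
        tf.reduce_sum(F * (w - w0) ** 2)
        for F, w, w0 in zip(fisher_diagonal, current_weights, initial_weights)
    ])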