def distributed_strategy(args):
    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)
    path = os.getenv('CENSAI_PATH') + "/results/"
    dataset = []
    for file in sorted(glob.glob(path + args.h5_pattern)):
        try:
            dataset.append(h5py.File(file, "r"))
        except OSError:  # skip files that cannot be opened
            continue
    B = dataset[0]["source"].shape[0]
    data_len = len(dataset) * B // N_WORKERS

    ps_observation = PowerSpectrum(bins=args.observation_coherence_bins, pixels=128)
    ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=128)
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=128)

    phys = PhysicalModel(
        pixels=128,
        kappa_pixels=128,
        src_pixels=128,
        image_fov=7.69,
        kappa_fov=7.69,
        src_fov=3.,
        method="fft",
    )

    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()
    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim = RIM(phys, unet, **rim_params)

    kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.kappa_vae)
    with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1)
    checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()

    svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.source_vae)
    with open(os.path.join(svae_path, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1)
    checkpoint_manager2.checkpoint.restore(checkpoint_manager2.latest_checkpoint).expect_partial()

    # sqrt weighting of the kappa residuals, normalized to sum to 1 over pixels
    wk = lambda k: tf.sqrt(k) / tf.reduce_sum(tf.sqrt(k), axis=(1, 2, 3), keepdims=True)

    # Freeze L5
    # encoding layers
    # rim.unet.layers[0].trainable = False  # L1
    # rim.unet.layers[1].trainable = False
    # rim.unet.layers[2].trainable = False
    # rim.unet.layers[3].trainable = False
    # rim.unet.layers[4].trainable = False  # L5
    # GRU
    # rim.unet.layers[5].trainable = False
    # rim.unet.layers[6].trainable = False
    # rim.unet.layers[7].trainable = False
    # rim.unet.layers[8].trainable = False
    # rim.unet.layers[9].trainable = False
    # rim.unet.layers[15].trainable = False  # bottleneck GRU
    # output layer
    # rim.unet.layers[-2].trainable = False
    # input layer
    # rim.unet.layers[-1].trainable = False
    # decoding layers
    # rim.unet.layers[10].trainable = False  # L5
    # rim.unet.layers[11].trainable = False
    # rim.unet.layers[12].trainable = False
    # rim.unet.layers[13].trainable = False
    # rim.unet.layers[14].trainable = False  # L1

    with h5py.File(
            os.path.join(
                os.getenv("CENSAI_PATH"), "results",
                args.experiment_name + "_" + args.model + "_" + args.dataset + f"_{THIS_WORKER:03d}.h5"),
            'w') as hf:
        hf.create_dataset(name="observation", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="psf", shape=[data_len, 20, 20, 1], dtype=np.float32)
        hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="source", shape=[data_len, phys.src_pixels,
phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1]) hf.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1 ], dtype=np.float32) hf.create_dataset( name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized", shape=[data_len], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum2", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum_reoptimized", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_frequencies", shape=[args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32) for batch, j in enumerate( range((THIS_WORKER - 1) * data_len, THIS_WORKER * data_len)): b = j // B k = j % B observation = dataset[b]["observation"][k][None, ...] source = dataset[b]["source"][k][None, ...] kappa = dataset[b]["kappa"][k][None, ...] noise_rms = np.array([dataset[b]["noise_rms"][k]]) psf = dataset[b]["psf"][k][None, ...] 
fwhm = dataset[b]["psf_fwhm"][k] checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial( ) # reset model weights # Compute predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( observation, noise_rms, psf) observation_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) # reset the seed for reproducible sampling in the VAE for EWC tf.random.set_seed(args.seed) np.random.seed(args.seed) # Initialize regularization term ewc = EWC(observation=observation, noise_rms=noise_rms, psf=psf, phys=phys, rim=rim, source_vae=source_vae, kappa_vae=kappa_vae, n_samples=args.sample_size, sigma_source=args.source_vae_ball_size, sigma_kappa=args.kappa_vae_ball_size) # Re-optimize weights of the model STEPS = args.re_optimize_steps learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase) optim = tf.keras.optimizers.SGD( learning_rate=learning_rate_schedule) chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) source_mse = tf.TensorArray(DTYPE, size=STEPS) kappa_mse = tf.TensorArray(DTYPE, size=STEPS) best = chi_squared[-1, 0] source_best = source_pred[-1] kappa_best = kappa_pred[-1] source_mse_best = tf.reduce_mean( (source_best - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_sum( wk(kappa) * (kappa_best - rim.kappa_inverse_link(kappa))**2) for current_step in tqdm(range(STEPS)): with tf.GradientTape() as tape: tape.watch(unet.trainable_variables) s, k, chi_sq = rim.call(observation, noise_rms, psf, outer_tape=tape) cost = tf.reduce_mean(chi_sq) # mean over time steps cost += tf.reduce_sum(rim.unet.losses) # L2 regularisation cost += args.lam_ewc * ewc.penalty( rim) # Elastic Weights Consolidation log_likelihood = chi_sq[-1] chi_squared_series = chi_squared_series.write( index=current_step, value=log_likelihood) source_o = s[-1] kappa_o = k[-1] source_mse = source_mse.write( index=current_step, value=tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2)) kappa_mse = kappa_mse.write( index=current_step, value=tf.reduce_sum( wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2)) if 2 * chi_sq[-1, 0] < 1.0 and args.early_stopping: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] source_mse_best = tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_sum( wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2) break if chi_sq[-1, 0] < best: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] source_mse_best = tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_sum( wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2) grads = tape.gradient(cost, unet.trainable_variables) optim.apply_gradients(zip(grads, unet.trainable_variables)) source_o = source_best kappa_o = kappa_best y_pred = phys.forward(source_o, kappa_o, psf) chi_sq_series = tf.transpose(chi_squared_series.stack(), perm=[1, 0]) source_mse = source_mse.stack()[None, ...] kappa_mse = kappa_mse.stack()[None, ...] 
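            # --- Illustrative sketch (descriptive comment, not part of the original script) ---
            # The next block calls ps_observation/ps_source/ps_kappa.cross_correlation_coefficient,
            # whose implementation lives in censai's PowerSpectrum class and is not shown here.
            # The quantity it estimates is the radially binned coherence between a truth map x
            # and a prediction y,
            #     c(k) = P_xy(k) / sqrt(P_xx(k) * P_yy(k)),
            # which is 1 at scales where the prediction is perfectly correlated with the truth.
            # A rough numpy-only sketch (the binning details are an assumption, not the library API):
            #
            #     fx, fy = np.fft.fft2(x), np.fft.fft2(y)
            #     p_xy = (fx * np.conj(fy)).real
            #     p_xx, p_yy = np.abs(fx) ** 2, np.abs(fy) ** 2
            #     k = np.sqrt(np.add.outer(np.fft.fftfreq(x.shape[0]) ** 2,
            #                              np.fft.fftfreq(x.shape[1]) ** 2))
            #     edges = np.linspace(0, k.max(), bins + 1)
            #     idx = np.digitize(k, edges) - 1
            #     c = [p_xy[idx == b].sum() / np.sqrt(p_xx[idx == b].sum() * p_yy[idx == b].sum())
            #          for b in range(bins)]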
# Compute Power spectrum of converged predictions _ps_observation = ps_observation.cross_correlation_coefficient( observation[..., 0], observation_pred[..., 0]) _ps_observation2 = ps_observation.cross_correlation_coefficient( observation[..., 0], y_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_kappa2 = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_o[..., 0])) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) _ps_source2 = ps_source.cross_correlation_coefficient( source[..., 0], source_o[..., 0]) # save results hf["observation"][batch] = observation.astype(np.float32) hf["psf"][batch] = psf.astype(np.float32) hf["psf_fwhm"][batch] = fwhm hf["noise_rms"][batch] = noise_rms.astype(np.float32) hf["source"][batch] = source.astype(np.float32) hf["kappa"][batch] = kappa.astype(np.float32) hf["observation_pred"][batch] = observation_pred.numpy().astype( np.float32) hf["observation_pred_reoptimized"][batch] = y_pred.numpy().astype( np.float32) hf["source_pred"][batch] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["source_pred_reoptimized"][batch] = source_o.numpy().astype( np.float32) hf["kappa_pred"][batch] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype( np.float32) hf["chi_squared"][batch] = 2 * tf.transpose( chi_squared).numpy().astype(np.float32) hf["chi_squared_reoptimized"][batch] = 2 * best.numpy().astype( np.float32) hf["chi_squared_reoptimized_series"][ batch] = 2 * chi_sq_series.numpy().astype(np.float32) hf["source_optim_mse"][batch] = source_mse_best.numpy().astype( np.float32) hf["source_optim_mse_series"][batch] = source_mse.numpy().astype( np.float32) hf["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype( np.float32) hf["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype( np.float32) hf["observation_coherence_spectrum"][batch] = _ps_observation hf["observation_coherence_spectrum_reoptimized"][ batch] = _ps_observation2 hf["source_coherence_spectrum"][batch] = _ps_source hf["source_coherence_spectrum_reoptimized"][batch] = _ps_source2 hf["kappa_coherence_spectrum"][batch] = _ps_kappa hf["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2 if batch == 0: _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_observation.bins) f = (f[:-1] + f[1:]) / 2 hf["observation_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 hf["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 hf["kappa_frequencies"][:] = f hf["kappa_fov"][0] = phys.kappa_fov hf["source_fov"][0] = phys.src_fov
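# --- Illustrative sketch (assumption, not the censai EWC implementation) ---------------------
# The re-optimization loop above adds `args.lam_ewc * ewc.penalty(rim)` to the cost.
# Elastic Weight Consolidation penalizes drift of the fine-tuned weights theta away from the
# pre-trained weights theta*, weighted by a diagonal Fisher-information estimate F_i:
#     penalty = sum_i F_i * (theta_i - theta*_i)^2
# The censai EWC class presumably builds F_i from samples drawn with the source/kappa VAEs
# around the RIM prediction (hence the n_samples and *_vae_ball_size arguments). The helper
# below is only a minimal standalone sketch of the quadratic penalty itself; it is not called
# by the scripts in this file.
def _ewc_quadratic_penalty_sketch(current_vars, anchor_vars, fisher_diagonals):
    """current_vars, anchor_vars, fisher_diagonals: matching lists of tensors."""
    penalty = 0.
    for w, w0, f in zip(current_vars, anchor_vars, fisher_diagonals):
        penalty += tf.reduce_sum(f * (w - w0) ** 2)  # Fisher-weighted squared drift
    return penalty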
def main(args): if args.seed is not None: tf.random.set_seed(args.seed) np.random.seed(args.seed) if args.json_override is not None: if isinstance(args.json_override, list): files = args.json_override else: files = [args.json_override,] for file in files: with open(file, "r") as f: json_override = json.load(f) args_dict = vars(args) args_dict.update(json_override) files = [] for dataset in args.datasets: files.extend(glob.glob(os.path.join(dataset, "*.tfrecords"))) np.random.shuffle(files) if args.val_datasets is not None: """ In this conditional, we assume total items might be a subset of the dataset size. Thus we want to reshuffle at each epoch to get a different realisation of the dataset. In case total_items == true dataset size, this means we only change ordering of items each epochs. Also, validation is not a split of the training data, but a saved dataset on disk. """ files = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files), reshuffle_each_iteration=True) dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type), block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE) # Read off global parameters from first example in dataset for physical_params in dataset.map(decode_physical_model_info): break # preprocessing dataset = dataset.map(decode_train) if args.cache_file is not None: dataset = dataset.cache(args.cache_file) train_dataset = dataset.shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).take(args.total_items).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) val_files = [] for dataset in args.val_datasets: val_files.extend(glob.glob(os.path.join(dataset, "*.tfrecords"))) val_files = tf.data.Dataset.from_tensor_slices(val_files).shuffle(len(files), reshuffle_each_iteration=True) val_dataset = val_files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type), block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE) val_dataset = val_dataset.map(decode_train).shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).\ take(math.ceil((1 - args.train_split) * args.total_items)).\ batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) else: """ Here, we split the dataset, so we assume total_items is the true dataset size. Any extra items will be discarded. This is to make sure validation set is never seen by the model, so shuffling occurs after the split. 
""" files = tf.data.Dataset.from_tensor_slices(files) dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type), block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE) # Read off global parameters from first example in dataset for physical_params in dataset.map(decode_physical_model_info): break # preprocessing dataset = dataset.map(decode_train) if args.cache_file is not None: dataset = dataset.cache(args.cache_file) train_dataset = dataset.take(math.floor(args.train_split * args.total_items)).shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) val_dataset = dataset.skip(math.floor(args.train_split * args.total_items)).take(math.ceil((1 - args.train_split) * args.total_items)).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) train_dataset = STRATEGY.experimental_distribute_dataset(train_dataset) val_dataset = STRATEGY.experimental_distribute_dataset(val_dataset) with STRATEGY.scope(): # Replicate ops accross gpus if args.raytracer is not None: with open(os.path.join(args.raytracer, "ray_tracer_hparams.json"), "r") as f: raytracer_hparams = json.load(f) raytracer = RayTracer(**raytracer_hparams) # load last checkpoint in the checkpoint directory checkpoint = tf.train.Checkpoint(net=raytracer) manager = tf.train.CheckpointManager(checkpoint, directory=args.raytracer, max_to_keep=3) checkpoint.restore(manager.latest_checkpoint).expect_partial() else: raytracer = None phys = PhysicalModel( pixels=physical_params["pixels"].numpy(), kappa_pixels=physical_params["kappa pixels"].numpy(), src_pixels=physical_params["src pixels"].numpy(), image_fov=physical_params["image fov"].numpy(), kappa_fov=physical_params["kappa fov"].numpy(), src_fov=physical_params["source fov"].numpy(), method=args.forward_method, raytracer=raytracer, ) unet = Model( filters=args.filters, filter_scaling=args.filter_scaling, kernel_size=args.kernel_size, layers=args.layers, block_conv_layers=args.block_conv_layers, strides=args.strides, bottleneck_kernel_size=args.bottleneck_kernel_size, resampling_kernel_size=args.resampling_kernel_size, input_kernel_size=args.input_kernel_size, gru_kernel_size=args.gru_kernel_size, upsampling_interpolation=args.upsampling_interpolation, kernel_l2_amp=args.kernel_l2_amp, bias_l2_amp=args.bias_l2_amp, kernel_l1_amp=args.kernel_l1_amp, bias_l1_amp=args.bias_l1_amp, activation=args.activation, initializer=args.initializer, batch_norm=args.batch_norm, dropout_rate=args.dropout_rate, filter_cap=args.filter_cap ) rim = RIM( physical_model=phys, unet=unet, steps=args.steps, adam=args.adam, kappalog=args.kappalog, source_link=args.source_link, kappa_normalize=args.kappa_normalize, flux_lagrange_multiplier=args.flux_lagrange_multiplier ) learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.initial_learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase ) optim = tf.keras.optimizers.deserialize( { "class_name": args.optimizer, 'config': {"learning_rate": learning_rate_schedule} } ) # weights for time steps in the loss function if args.time_weights == "uniform": wt = tf.ones(shape=(args.steps), dtype=DTYPE) / args.steps elif args.time_weights == "linear": wt = 2 * (tf.range(args.steps, dtype=DTYPE) + 1) / args.steps / (args.steps + 1) elif args.time_weights == "quadratic": wt = 6 * (tf.range(args.steps, dtype=DTYPE) + 1)**2 / args.steps / (args.steps + 1) / (2 * 
                 args.steps + 1)
        else:
            raise ValueError("time_weights must be in ['uniform', 'linear', 'quadratic']")
        wt = wt[..., tf.newaxis]  # [steps, batch]

        if args.kappa_residual_weights == "uniform":
            wk = tf.keras.layers.Lambda(lambda k: tf.ones_like(k, dtype=DTYPE) / tf.cast(tf.math.reduce_prod(k.shape[1:]), DTYPE))
        elif args.kappa_residual_weights == "linear":
            wk = tf.keras.layers.Lambda(lambda k: k / tf.reduce_sum(k, axis=(1, 2, 3), keepdims=True))
        elif args.kappa_residual_weights == "sqrt":
            wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
        elif args.kappa_residual_weights == "quadratic":
            wk = tf.keras.layers.Lambda(lambda k: tf.square(k) / tf.reduce_sum(tf.square(k), axis=(1, 2, 3), keepdims=True))
        else:
            raise ValueError("kappa_residual_weights must be in ['uniform', 'linear', 'quadratic', 'sqrt']")

        if args.source_residual_weights == "uniform":
            ws = tf.keras.layers.Lambda(lambda s: tf.ones_like(s, dtype=DTYPE) / tf.cast(tf.math.reduce_prod(s.shape[1:]), DTYPE))
        elif args.source_residual_weights == "linear":
            ws = tf.keras.layers.Lambda(lambda s: s / tf.reduce_sum(s, axis=(1, 2, 3), keepdims=True))
        elif args.source_residual_weights == "quadratic":
            ws = tf.keras.layers.Lambda(lambda s: tf.square(s) / tf.reduce_sum(tf.square(s), axis=(1, 2, 3), keepdims=True))
        elif args.source_residual_weights == "sqrt":
            ws = tf.keras.layers.Lambda(lambda s: tf.sqrt(s) / tf.reduce_sum(tf.sqrt(s), axis=(1, 2, 3), keepdims=True))
        else:
            raise ValueError("source_residual_weights must be in ['uniform', 'linear', 'quadratic', 'sqrt']")

        # ==== Take care of where to write logs and stuff =================================================================
        if args.model_id.lower() != "none":
            if args.logname is not None:
                logname = args.model_id + "_" + args.logname
                model_id = args.model_id
            else:
                logname = args.model_id + "_" + datetime.now().strftime("%y%m%d%H%M%S")
                model_id = args.model_id
        elif args.logname is not None:
            logname = args.logname
            model_id = logname
        else:
            logname = args.logname_prefixe + "_" + datetime.now().strftime("%y%m%d%H%M%S")
            model_id = logname
        if args.logdir.lower() != "none":
            logdir = os.path.join(args.logdir, logname)
            if not os.path.isdir(logdir):
                os.mkdir(logdir)
            writer = tf.summary.create_file_writer(logdir)
        else:
            writer = nullwriter()
        # ===== Make sure directory and checkpoint manager are created to save model ===================================
        if args.model_dir.lower() != "none":
            checkpoints_dir = os.path.join(args.model_dir, logname)
            old_checkpoints_dir = os.path.join(args.model_dir, model_id)  # in case they differ we load model from a different directory
            if not os.path.isdir(checkpoints_dir):
                os.mkdir(checkpoints_dir)
            with open(os.path.join(checkpoints_dir, "script_params.json"), "w") as f:
                json.dump(vars(args), f, indent=4)
            with open(os.path.join(checkpoints_dir, "unet_hparams.json"), "w") as f:
                hparams_dict = {key: vars(args)[key] for key in UNET_MODEL_HPARAMS}
                json.dump(hparams_dict, f, indent=4)
            with open(os.path.join(checkpoints_dir, "rim_hparams.json"), "w") as f:
                hparams_dict = {key: vars(args)[key] for key in RIM_HPARAMS}
                json.dump(hparams_dict, f, indent=4)
            ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=rim.unet)
            checkpoint_manager = tf.train.CheckpointManager(ckpt, old_checkpoints_dir, max_to_keep=args.max_to_keep)
            save_checkpoint = True
            # ======= Load model if model_id is provided ===============================================================
            if args.model_id.lower() != "none":
checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint) if old_checkpoints_dir != checkpoints_dir: # save progress in another directory. if args.reset_optimizer_states: optim = tf.keras.optimizers.deserialize( { "class_name": args.optimizer, 'config': {"learning_rate": learning_rate_schedule} } ) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=rim.unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, checkpoints_dir, max_to_keep=args.max_to_keep) else: save_checkpoint = False # ================================================================================================================= def train_step(X, source, kappa, noise_rms, psf): with tf.GradientTape() as tape: tape.watch(rim.unet.trainable_variables) if args.unroll_time_steps: source_series, kappa_series, chi_squared = rim.call_function(X, noise_rms, psf) else: source_series, kappa_series, chi_squared = rim.call(X, noise_rms, psf, outer_tape=tape) # mean over image residuals (in model prediction space) source_cost1 = tf.reduce_sum(ws(source) * tf.square(source_series - rim.source_inverse_link(source)), axis=(2, 3, 4)) kappa_cost1 = tf.reduce_sum(wk(kappa) * tf.square(kappa_series - rim.kappa_inverse_link(kappa)), axis=(2, 3, 4)) # weighted mean over time steps source_cost = tf.reduce_sum(wt * source_cost1, axis=0) kappa_cost = tf.reduce_sum(wt * kappa_cost1, axis=0) # final cost is mean over global batch size cost = tf.reduce_sum(kappa_cost + source_cost) / args.batch_size gradient = tape.gradient(cost, rim.unet.trainable_variables) gradient = [tf.clip_by_norm(grad, 5.) for grad in gradient] optim.apply_gradients(zip(gradient, rim.unet.trainable_variables)) # Update metrics with "converged" score chi_squared = tf.reduce_sum(chi_squared[-1]) / args.batch_size source_cost = tf.reduce_sum(source_cost1[-1]) / args.batch_size kappa_cost = tf.reduce_sum(kappa_cost1[-1]) / args.batch_size return cost, chi_squared, source_cost, kappa_cost @tf.function def distributed_train_step(X, source, kappa, noise_rms, psf): per_replica_losses, per_replica_chi_squared, per_replica_source_cost, per_replica_kappa_cost = STRATEGY.run(train_step, args=(X, source, kappa, noise_rms, psf)) # Replica losses are aggregated by summing them global_loss = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) global_chi_squared = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_chi_squared, axis=None) global_source_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_source_cost, axis=None) global_kappa_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_kappa_cost, axis=None) return global_loss, global_chi_squared, global_source_cost, global_kappa_cost def test_step(X, source, kappa, noise_rms, psf): source_series, kappa_series, chi_squared = rim.call(X, noise_rms, psf) # mean over image residuals (in model prediction space) source_cost1 = tf.reduce_sum(ws(source) * tf.square(source_series - rim.source_inverse_link(source)), axis=(2, 3, 4)) kappa_cost1 = tf.reduce_sum(wk(kappa) * tf.square(kappa_series - rim.kappa_inverse_link(kappa)), axis=(2, 3, 4)) # weighted mean over time steps source_cost = tf.reduce_sum(wt * source_cost1, axis=0) kappa_cost = tf.reduce_sum(wt * kappa_cost1, axis=0) # final cost is mean over global batch size cost = tf.reduce_sum(kappa_cost + source_cost) / args.batch_size # Update metrics with "converged" score chi_squared = tf.reduce_sum(chi_squared[-1]) / args.batch_size source_cost = tf.reduce_sum(source_cost1[-1]) / args.batch_size 
kappa_cost = tf.reduce_sum(kappa_cost1[-1]) / args.batch_size return cost, chi_squared, source_cost, kappa_cost @tf.function def distributed_test_step(X, source, kappa, noise_rms, psf): per_replica_losses, per_replica_chi_squared, per_replica_source_cost, per_replica_kappa_cost = STRATEGY.run(test_step, args=(X, source, kappa, noise_rms, psf)) # Replica losses are aggregated by summing them global_loss = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) global_chi_squared = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_chi_squared, axis=None) global_source_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_source_cost, axis=None) global_kappa_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_kappa_cost, axis=None) return global_loss, global_chi_squared, global_source_cost, global_kappa_cost # ====== Training loop ============================================================================================ epoch_loss = tf.metrics.Mean() time_per_step = tf.metrics.Mean() val_loss = tf.metrics.Mean() epoch_chi_squared = tf.metrics.Mean() epoch_source_loss = tf.metrics.Mean() epoch_kappa_loss = tf.metrics.Mean() val_chi_squared = tf.metrics.Mean() val_source_loss = tf.metrics.Mean() val_kappa_loss = tf.metrics.Mean() history = { # recorded at the end of an epoch only "train_cost": [], "train_chi_squared": [], "train_source_cost": [], "train_kappa_cost": [], "val_cost": [], "val_chi_squared": [], "val_source_cost": [], "val_kappa_cost": [], "learning_rate": [], "time_per_step": [], "step": [], "wall_time": [] } best_loss = np.inf patience = args.patience step = 0 global_start = time.time() estimated_time_for_epoch = 0 out_of_time = False lastest_checkpoint = 1 for epoch in range(args.epochs): if (time.time() - global_start) > args.max_time*3600 - estimated_time_for_epoch: break epoch_start = time.time() epoch_loss.reset_states() epoch_chi_squared.reset_states() epoch_source_loss.reset_states() epoch_kappa_loss.reset_states() time_per_step.reset_states() with writer.as_default(): for batch, (X, source, kappa, noise_rms, psf) in enumerate(train_dataset): start = time.time() cost, chi_squared, source_cost, kappa_cost = distributed_train_step(X, source, kappa, noise_rms, psf) # ========== Summary and logs ================================================================================== _time = time.time() - start time_per_step.update_state([_time]) epoch_loss.update_state([cost]) epoch_chi_squared.update_state([chi_squared]) epoch_source_loss.update_state([source_cost]) epoch_kappa_loss.update_state([kappa_cost]) step += 1 # last batch we make a summary of residuals if args.n_residuals > 0: source_pred, kappa_pred, chi_squared = rim.predict(X, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) for res_idx in range(min(args.n_residuals, args.batch_size)): try: tf.summary.image(f"Residuals {res_idx}", plot_to_image( residual_plot( X[res_idx], source[res_idx], kappa[res_idx], lens_pred[res_idx], source_pred[-1][res_idx], kappa_pred[-1][res_idx], chi_squared[-1][res_idx] )), step=step) except ValueError: continue # ========== Validation set =================== val_loss.reset_states() val_chi_squared.reset_states() val_source_loss.reset_states() val_kappa_loss.reset_states() for X, source, kappa, noise_rms, psf in val_dataset: cost, chi_squared, source_cost, kappa_cost = distributed_test_step(X, source, kappa, noise_rms, psf) val_loss.update_state([cost]) val_chi_squared.update_state([chi_squared]) 
val_source_loss.update_state([source_cost]) val_kappa_loss.update_state([kappa_cost]) if args.n_residuals > 0 and math.ceil((1 - args.train_split) * args.total_items) > 0: # validation set not empty set not empty source_pred, kappa_pred, chi_squared = rim.predict(X, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) for res_idx in range(min(args.n_residuals, args.batch_size, math.ceil((1 - args.train_split) * args.total_items))): try: tf.summary.image(f"Val Residuals {res_idx}", plot_to_image( residual_plot( X[res_idx], # rescale intensity like it is done in the likelihood source[res_idx], kappa[res_idx], lens_pred[res_idx], source_pred[-1][res_idx], kappa_pred[-1][res_idx], chi_squared[-1][res_idx] )), step=step) except ValueError: continue val_cost = val_loss.result().numpy() train_cost = epoch_loss.result().numpy() val_chi_sq = val_chi_squared.result().numpy() train_chi_sq = epoch_chi_squared.result().numpy() val_kappa_cost = val_kappa_loss.result().numpy() train_kappa_cost = epoch_kappa_loss.result().numpy() val_source_cost = val_source_loss.result().numpy() train_source_cost = epoch_source_loss.result().numpy() tf.summary.scalar("Time per step", time_per_step.result(), step=step) tf.summary.scalar("Chi Squared", train_chi_sq, step=step) tf.summary.scalar("Kappa cost", train_kappa_cost, step=step) tf.summary.scalar("Val Kappa cost", val_kappa_cost, step=step) tf.summary.scalar("Source cost", train_source_cost, step=step) tf.summary.scalar("Val Source cost", val_source_cost, step=step) tf.summary.scalar("MSE", train_cost, step=step) tf.summary.scalar("Val MSE", val_cost, step=step) tf.summary.scalar("Learning Rate", optim.lr(step), step=step) tf.summary.scalar("Val Chi Squared", val_chi_sq, step=step) print(f"epoch {epoch} | train loss {train_cost:.3e} | val loss {val_cost:.3e} " f"| lr {optim.lr(step).numpy():.2e} | time per step {time_per_step.result().numpy():.2e} s" f"| kappa cost {train_kappa_cost:.2e} | source cost {train_source_cost:.2e} | chi sq {train_chi_sq:.2e}") history["train_cost"].append(train_cost) history["val_cost"].append(val_cost) history["learning_rate"].append(optim.lr(step).numpy()) history["train_chi_squared"].append(train_chi_sq) history["val_chi_squared"].append(val_chi_sq) history["time_per_step"].append(time_per_step.result().numpy()) history["train_kappa_cost"].append(train_kappa_cost) history["train_source_cost"].append(train_source_cost) history["val_kappa_cost"].append(val_kappa_cost) history["val_source_cost"].append(val_source_cost) history["step"].append(step) history["wall_time"].append(time.time() - global_start) cost = train_cost if args.track_train else val_cost if np.isnan(cost): print("Training broke the Universe") break if cost < (1 - args.tolerance) * best_loss: best_loss = cost patience = args.patience else: patience -= 1 if (time.time() - global_start) > args.max_time * 3600: out_of_time = True if save_checkpoint: checkpoint_manager.checkpoint.step.assign_add(1) # a bit of a hack if epoch % args.checkpoints == 0 or patience == 0 or epoch == args.epochs - 1 or out_of_time: with open(os.path.join(checkpoints_dir, "score_sheet.txt"), mode="a") as f: np.savetxt(f, np.array([[lastest_checkpoint, cost]])) lastest_checkpoint += 1 checkpoint_manager.save() print("Saved checkpoint for step {}: {}".format(int(checkpoint_manager.checkpoint.step), checkpoint_manager.latest_checkpoint)) if patience == 0: print("Reached patience") break if out_of_time: break if epoch > 0: # First epoch is always very slow and not a good estimate 
            # of an epoch time.
                estimated_time_for_epoch = time.time() - epoch_start
            if optim.lr(step).numpy() < 1e-8:
                print("Reached learning rate limit")
                break
        print(f"Finished training after {(time.time() - global_start)/3600:.3f} hours.")
        return history, best_loss
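# --- Illustrative sketch (not part of the training script) -----------------------------------
# In train_step/test_step above, every replica divides its per-example sum by the *global*
# batch size (args.batch_size), so summing the per-replica results with
# STRATEGY.reduce(tf.distribute.ReduceOp.SUM, ...) recovers the global batch mean:
#     mean_global = (1 / B) * sum_over_replicas(sum_over_local_examples(loss))
# A numpy-only check of that identity with hypothetical per-example losses; this helper is
# not called anywhere in the scripts.
def _global_mean_from_replica_sums_sketch(per_example_losses, n_replicas):
    """Shard a batch across replicas and show both reductions give the same number."""
    losses = np.asarray(per_example_losses, dtype=np.float64)
    global_batch_size = losses.size
    shards = np.array_split(losses, n_replicas)
    replica_terms = [shard.sum() / global_batch_size for shard in shards]  # one term per replica
    return sum(replica_terms), losses.mean()  # the two entries are equal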
def distributed_strategy(args): tf.random.set_seed(args.seed) np.random.seed(args.seed) model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model) files = glob.glob( os.path.join(os.getenv('CENSAI_PATH'), "data", args.train_dataset, "*.tfrecords")) files = tf.data.Dataset.from_tensor_slices(files) train_dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type).shuffle(len(files)), block_length=1, num_parallel_calls=tf.data.AUTOTUNE) # Read off global parameters from first example in dataset for physical_params in train_dataset.map(decode_physical_model_info): break train_dataset = train_dataset.map(decode_results).shuffle( buffer_size=args.buffer_size) files = glob.glob( os.path.join(os.getenv('CENSAI_PATH'), "data", args.val_dataset, "*.tfrecords")) files = tf.data.Dataset.from_tensor_slices(files) val_dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type).shuffle(len(files)), block_length=1, num_parallel_calls=tf.data.AUTOTUNE) val_dataset = val_dataset.map(decode_results).shuffle( buffer_size=args.buffer_size) files = glob.glob( os.path.join(os.getenv('CENSAI_PATH'), "data", args.test_dataset, "*.tfrecords")) files = tf.data.Dataset.from_tensor_slices(files) test_dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type).shuffle(len(files)), block_length=1, num_parallel_calls=tf.data.AUTOTUNE) test_dataset = test_dataset.map(decode_results).shuffle( buffer_size=args.buffer_size) ps_lens = PowerSpectrum(bins=args.lens_coherence_bins, pixels=physical_params["pixels"].numpy()) ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=physical_params["src pixels"].numpy()) ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=physical_params["kappa pixels"].numpy()) phys = PhysicalModel( pixels=physical_params["pixels"].numpy(), kappa_pixels=physical_params["kappa pixels"].numpy(), src_pixels=physical_params["src pixels"].numpy(), image_fov=physical_params["image fov"].numpy(), kappa_fov=physical_params["kappa fov"].numpy(), src_fov=physical_params["source fov"].numpy(), method="fft", ) phys_sie = AnalyticalPhysicalModel( pixels=physical_params["pixels"].numpy(), image_fov=physical_params["image fov"].numpy(), src_fov=physical_params["source fov"].numpy()) with open(os.path.join(model, "unet_hparams.json")) as f: unet_params = json.load(f) unet_params["kernel_l2_amp"] = args.l2_amp unet = Model(**unet_params) ckpt = tf.train.Checkpoint(net=unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1) checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial() with open(os.path.join(model, "rim_hparams.json")) as f: rim_params = json.load(f) rim = RIM(phys, unet, **rim_params) dataset_names = [args.train_dataset, args.val_dataset, args.test_dataset] dataset_shapes = [args.train_size, args.val_size, args.test_size] model_name = os.path.split(model)[-1] # from censai.utils import nulltape # def call_with_mask(self, lensed_image, noise_rms, psf, mask, outer_tape=nulltape): # """ # Used in training. Return linked kappa and source maps. 
# """ # batch_size = lensed_image.shape[0] # source, kappa, source_grad, kappa_grad, states = self.initial_states(batch_size) # initiate all tensors to 0 # source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad, # states) # Use lens to make an initial guess with Unet # source_series = tf.TensorArray(DTYPE, size=self.steps) # kappa_series = tf.TensorArray(DTYPE, size=self.steps) # chi_squared_series = tf.TensorArray(DTYPE, size=self.steps) # # record initial guess # source_series = source_series.write(index=0, value=source) # kappa_series = kappa_series.write(index=0, value=kappa) # # Main optimization loop # for current_step in tf.range(self.steps - 1): # with outer_tape.stop_recording(): # with tf.GradientTape() as g: # g.watch(source) # g.watch(kappa) # y_pred = self.physical_model.forward(self.source_link(source), self.kappa_link(kappa), psf) # flux_term = tf.square( # tf.reduce_sum(y_pred, axis=(1, 2, 3)) - tf.reduce_sum(lensed_image, axis=(1, 2, 3))) # log_likelihood = 0.5 * tf.reduce_sum( # tf.square(y_pred - mask * lensed_image) / noise_rms[:, None, None, None] ** 2, axis=(1, 2, 3)) # cost = tf.reduce_mean(log_likelihood + self.flux_lagrange_multiplier * flux_term) # source_grad, kappa_grad = g.gradient(cost, [source, kappa]) # source_grad, kappa_grad = self.grad_update(source_grad, kappa_grad, current_step) # source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad, states) # source_series = source_series.write(index=current_step + 1, value=source) # kappa_series = kappa_series.write(index=current_step + 1, value=kappa) # chi_squared_series = chi_squared_series.write(index=current_step, # value=log_likelihood / self.pixels ** 2) # renormalize chi squared here # # last step score # log_likelihood = self.physical_model.log_likelihood(y_true=lensed_image, source=self.source_link(source), # kappa=self.kappa_link(kappa), psf=psf, noise_rms=noise_rms) # chi_squared_series = chi_squared_series.write(index=self.steps - 1, value=log_likelihood) # return source_series.stack(), kappa_series.stack(), chi_squared_series.stack() with h5py.File( os.path.join( os.getenv("CENSAI_PATH"), "results", args.experiment_name + "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf: for i, dataset in enumerate([train_dataset, val_dataset, test_dataset]): g = hf.create_group(f'{dataset_names[i]}') data_len = dataset_shapes[i] // N_WORKERS g.create_dataset(name="lens", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="psf", shape=[ data_len, physical_params['psf pixels'], physical_params['psf pixels'], 1 ], dtype=np.float32) g.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32) g.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32) g.create_dataset( name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) g.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="source_pred", shape=[ data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1 ], dtype=np.float32) g.create_dataset( name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1]) g.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, 
phys.kappa_pixels, 1 ], dtype=np.float32) g.create_dataset( name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) g.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32) g.create_dataset(name="chi_squared_reoptimized", shape=[data_len], dtype=np.float32) g.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) g.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32) g.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) g.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32) g.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum2", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum_reoptimized", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_frequencies", shape=[args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) g.create_dataset(name="source_fov", shape=[1], dtype=np.float32) g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32) dataset = dataset.skip(data_len * (THIS_WORKER - 1)).take(data_len) for batch, (lens, source, kappa, noise_rms, psf, fwhm) in enumerate( dataset.batch(1).prefetch( tf.data.experimental.AUTOTUNE)): checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial( ) # reset model weights # Compute predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( lens, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) # Re-optimize weights of the model STEPS = args.re_optimize_steps learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase) optim = tf.keras.optimizers.RMSprop( learning_rate=learning_rate_schedule) chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) source_mse = tf.TensorArray(DTYPE, size=STEPS) kappa_mse = tf.TensorArray(DTYPE, size=STEPS) best = chi_squared[-1, 0] # best = abs(2*chi_squared[-1, 0] - 1) # best_chisq = 2*chi_squared[-1, 0] source_best = source_pred[-1] kappa_best = kappa_pred[-1] # source_mean = source_pred[-1] # kappa_mean = rim.kappa_link(kappa_pred[-1]) # source_std = tf.zeros_like(source_mean) # kappa_std = tf.zeros_like(kappa_mean) # 
counter = 0 for current_step in tqdm(range(STEPS)): with tf.GradientTape() as tape: tape.watch(unet.trainable_variables) # s, k, chi_sq = call_with_mask(rim, lens, noise_rms, psf, mask, tape) s, k, chi_sq = rim.call(lens, noise_rms, psf, outer_tape=tape) cost = tf.reduce_mean(chi_sq) # mean over time steps cost += tf.reduce_sum(rim.unet.losses) log_likelihood = chi_sq[-1] chi_squared_series = chi_squared_series.write( index=current_step, value=log_likelihood) source_o = s[-1] kappa_o = k[-1] source_mse = source_mse.write( index=current_step, value=tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2)) kappa_mse = kappa_mse.write( index=current_step, value=tf.reduce_mean( (kappa_o - rim.kappa_inverse_link(kappa))**2)) if chi_sq[-1, 0] < args.converged_chisq: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] break if chi_sq[-1, 0] < best: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] source_mse_best = tf.reduce_mean( (source_best - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_mean( (kappa_best - rim.kappa_inverse_link(kappa))**2) # if counter > 0: # # Welford's online algorithm # # source # delta = source_o - source_mean # source_mean = (counter * source_mean + (counter + 1) * source_o)/(counter + 1) # delta2 = source_o - source_mean # source_std += delta * delta2 # # kappa # delta = rim.kappa_link(kappa_o) - kappa_mean # kappa_mean = (counter * kappa_mean + (counter + 1) * rim.kappa_link(kappa_o)) / (counter + 1) # delta2 = rim.kappa_link(kappa_o) - kappa_mean # kappa_std += delta * delta2 # if best_chisq < args.converged_chisq: # counter += 1 # if counter == args.window: # break # if 2*chi_sq[-1, 0] < best_chisq: # best_chisq = 2*chi_sq[-1, 0] # if abs(2*chi_sq[-1, 0] - 1) < best: # source_best = rim.source_link(source_o) # kappa_best = rim.kappa_link(kappa_o) # best = abs(2 * chi_squared[-1, 0] - 1) # source_mse_best = tf.reduce_mean((source_best - rim.source_inverse_link(source)) ** 2) # kappa_mse_best = tf.reduce_mean((kappa_best - rim.kappa_inverse_link(kappa)) ** 2) grads = tape.gradient(cost, unet.trainable_variables) optim.apply_gradients(zip(grads, unet.trainable_variables)) source_o = source_best kappa_o = kappa_best y_pred = phys.forward(source_o, kappa_o, psf) chi_sq_series = tf.transpose(chi_squared_series.stack(), perm=[1, 0]) source_mse = source_mse.stack()[None, ...] kappa_mse = kappa_mse.stack()[None, ...] 
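                # --- Descriptive note (sketch, relates to the commented-out Welford block nearby) ---
                # Welford's online algorithm, referenced above, tracks a running mean and an
                # accumulator M2 of squared deviations; after n samples the variance is M2 / n:
                #     delta = x_new - mean
                #     mean += delta / n
                #     delta2 = x_new - mean
                #     M2 += delta * delta2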
                # kappa_std /= float(args.window)
                # source_std /= float(args.window)

                # Compute Power spectrum of converged predictions
                _ps_lens = ps_lens.cross_correlation_coefficient(lens[..., 0], lens_pred[..., 0])
                _ps_lens3 = ps_lens.cross_correlation_coefficient(lens[..., 0], y_pred[..., 0])
                _ps_kappa = ps_kappa.cross_correlation_coefficient(log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0])
                _ps_kappa2 = ps_kappa.cross_correlation_coefficient(log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
                _ps_source = ps_source.cross_correlation_coefficient(source[..., 0], source_pred[-1][..., 0])
                _ps_source3 = ps_source.cross_correlation_coefficient(source[..., 0], source_o[..., 0])

                # save results
                g["lens"][batch] = lens.numpy().astype(np.float32)
                g["psf"][batch] = psf.numpy().astype(np.float32)
                g["psf_fwhm"][batch] = fwhm.numpy().astype(np.float32)
                g["noise_rms"][batch] = noise_rms.numpy().astype(np.float32)
                g["source"][batch] = source.numpy().astype(np.float32)
                g["kappa"][batch] = kappa.numpy().astype(np.float32)
                g["lens_pred"][batch] = lens_pred.numpy().astype(np.float32)
                g["lens_pred_reoptimized"][batch] = y_pred.numpy().astype(np.float32)
                g["source_pred"][batch] = tf.transpose(source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
                g["source_pred_reoptimized"][batch] = source_o.numpy().astype(np.float32)
                g["kappa_pred"][batch] = tf.transpose(kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
                g["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype(np.float32)
                g["chi_squared"][batch] = tf.transpose(chi_squared).numpy().astype(np.float32)
                g["chi_squared_reoptimized"][batch] = best.numpy().astype(np.float32)
                g["chi_squared_reoptimized_series"][batch] = chi_sq_series.numpy().astype(np.float32)
                g["source_optim_mse"][batch] = source_mse_best.numpy().astype(np.float32)
                g["source_optim_mse_series"][batch] = source_mse.numpy().astype(np.float32)
                g["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype(np.float32)
                g["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype(np.float32)
                g["lens_coherence_spectrum"][batch] = _ps_lens
                g["lens_coherence_spectrum_reoptimized"][batch] = _ps_lens3
                g["source_coherence_spectrum"][batch] = _ps_source
                g["source_coherence_spectrum_reoptimized"][batch] = _ps_source3
                g["kappa_coherence_spectrum"][batch] = _ps_kappa
                g["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2
                if batch == 0:
                    _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_lens.bins)
                    f = (f[:-1] + f[1:]) / 2
                    g["lens_frequencies"][:] = f
                    _, f = np.histogram(np.fft.fftfreq(phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins)
                    f = (f[:-1] + f[1:]) / 2
                    g["source_frequencies"][:] = f
                    _, f = np.histogram(np.fft.fftfreq(phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins)
                    f = (f[:-1] + f[1:]) / 2
                    g["kappa_frequencies"][:] = f
                    g["kappa_fov"][0] = phys.kappa_fov
                    g["source_fov"][0] = phys.src_fov

        # Create SIE test
        g = hf.create_group("SIE_test")
        data_len = args.sie_size // N_WORKERS
        sie_dataset = test_dataset.skip(data_len * (THIS_WORKER - 1)).take(data_len)
        g.create_dataset(name="lens", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        g.create_dataset(name="psf",
                         shape=[data_len, physical_params['psf pixels'], physical_params['psf pixels'], 1],
                         dtype=np.float32)
        g.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        g.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
g.create_dataset(name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) g.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred2", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset( name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) g.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1 ], dtype=np.float32) g.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum2", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_frequencies", shape=[args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="einstein_radius", shape=[data_len], dtype=np.float32) g.create_dataset(name="position", shape=[data_len, 2], dtype=np.float32) g.create_dataset(name="orientation", shape=[data_len], dtype=np.float32) g.create_dataset(name="ellipticity", shape=[data_len], dtype=np.float32) g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) g.create_dataset(name="source_fov", shape=[1], dtype=np.float32) g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32) for batch, (_, source, _, noise_rms, psf, fwhm) in enumerate( sie_dataset.take(data_len).batch(args.batch_size).prefetch( tf.data.experimental.AUTOTUNE)): batch_size = source.shape[0] # Create some SIE kappa maps _r = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=args.max_shift) _theta = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=-np.pi, maxval=np.pi) x0 = _r * tf.math.cos(_theta) y0 = _r * tf.math.sin(_theta) ellipticity = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=args.max_ellipticity) phi = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=-np.pi, maxval=np.pi) einstein_radius = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=args.min_theta_e, maxval=args.max_theta_e) kappa = phys_sie.kappa_field(x0=x0, y0=y0, e=ellipticity, phi=phi, r_ein=einstein_radius) lens = phys.noisy_forward(source, kappa, noise_rms=noise_rms, psf=psf) # Compute predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( lens, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) # Compute Power spectrum of converged predictions _ps_lens = ps_lens.cross_correlation_coefficient( lens[..., 0], lens_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) # save results i_begin = batch * 
args.batch_size i_end = i_begin + batch_size g["lens"][i_begin:i_end] = lens.numpy().astype(np.float32) g["psf"][i_begin:i_end] = psf.numpy().astype(np.float32) g["psf_fwhm"][i_begin:i_end] = fwhm.numpy().astype(np.float32) g["noise_rms"][i_begin:i_end] = noise_rms.numpy().astype( np.float32) g["source"][i_begin:i_end] = source.numpy().astype(np.float32) g["kappa"][i_begin:i_end] = kappa.numpy().astype(np.float32) g["lens_pred"][i_begin:i_end] = lens_pred.numpy().astype( np.float32) g["source_pred"][i_begin:i_end] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["kappa_pred"][i_begin:i_end] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["chi_squared"][i_begin:i_end] = 2 * tf.transpose( chi_squared).numpy().astype(np.float32) g["lens_coherence_spectrum"][i_begin:i_end] = _ps_lens.numpy( ).astype(np.float32) g["source_coherence_spectrum"][i_begin:i_end] = _ps_source.numpy( ).astype(np.float32) g["kappa_coherence_spectrum"][i_begin:i_end] = _ps_kappa.numpy( ).astype(np.float32) g["einstein_radius"][ i_begin:i_end] = einstein_radius[:, 0, 0, 0].numpy().astype(np.float32) g["position"][i_begin:i_end] = tf.stack( [x0[:, 0, 0, 0], y0[:, 0, 0, 0]], axis=1).numpy().astype(np.float32) g["ellipticity"][i_begin:i_end] = ellipticity[:, 0, 0, 0].numpy().astype( np.float32) g["orientation"][i_begin:i_end] = phi[:, 0, 0, 0].numpy().astype(np.float32) if batch == 0: _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_lens.bins) f = (f[:-1] + f[1:]) / 2 g["lens_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 g["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 g["kappa_frequencies"][:] = f g["kappa_fov"][0] = phys.kappa_fov g["source_fov"][0] = phys.src_fov
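# --- Illustrative sketch (mirrors the `if batch == 0:` blocks above) --------------------------
# The '*_frequencies' datasets store the centres of the radial frequency bins used by the
# coherence spectra: positive FFT frequencies (in cycles per pixel) are histogrammed into
# `bins` bins and the midpoint of each bin is kept. Standalone helper, not called by the
# scripts themselves.
def _frequency_bin_centers_sketch(pixels, bins):
    """Return the bin centres written to the '*_frequencies' datasets."""
    freqs = np.fft.fftfreq(pixels)[:pixels // 2]  # positive frequencies only
    _, edges = np.histogram(freqs, bins=bins)     # only the bin edges are needed
    return (edges[:-1] + edges[1:]) / 2           # midpoints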
def distributed_strategy(args):
    psf_pixels = 20
    pixels = 128
    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)

    ps_observation = PowerSpectrum(bins=args.observation_coherence_bins, pixels=pixels)
    ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=pixels)
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=pixels)

    phys = PhysicalModel(
        pixels=pixels,
        kappa_pixels=pixels,
        src_pixels=pixels,
        image_fov=7.68,
        kappa_fov=7.68,
        src_fov=3.,
        method="fft",
    )

    # Restore the RIM (U-net) from its checkpoint
    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()
    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim_params["source_link"] = "relu"
    rim = RIM(phys, unet, **rim_params)

    # Restore the kappa VAE
    kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.kappa_vae)
    with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1)
    checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()

    # Restore the source VAE
    svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.source_vae)
    with open(os.path.join(svae_path, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1)
    checkpoint_manager2.checkpoint.restore(checkpoint_manager2.latest_checkpoint).expect_partial()

    model_name = os.path.split(model)[-1]
    # Pixel weights for the kappa loss: w(kappa) = sqrt(kappa) / sum(sqrt(kappa))
    wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(
        tf.sqrt(k), axis=(1, 2, 3), keepdims=True))

    with h5py.File(
            os.path.join(
                os.getenv("CENSAI_PATH"), "results",
                args.experiment_name + "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf:
        data_len = args.size // N_WORKERS
        hf.create_dataset(name="observation", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="psf", shape=[data_len, psf_pixels, psf_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="observation_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="observation_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="kappa_pred", shape=[data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized", shape=[data_len, rim.steps], dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, rim.steps, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="sampled_chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="sampled_source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="sampled_kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_init", shape=[data_len, kappa_vae.latent_size], dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_init", shape=[data_len, source_vae.latent_size], dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_end", shape=[data_len, kappa_vae.latent_size], dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_end", shape=[data_len, source_vae.latent_size], dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum", shape=[data_len, args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum_reoptimized", shape=[data_len, args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="observation_frequencies", shape=[args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32)

        for i in range(data_len):
            checkpoint_manager.checkpoint.restore(
                checkpoint_manager.latest_checkpoint).expect_partial()  # reset model weights
            # Produce an observation
            kappa = 10**kappa_vae.sample(1)
            source = tf.nn.relu(source_vae.sample(1))
            source /= tf.reduce_max(source, axis=(1, 2, 3), keepdims=True)
            noise_rms = 10**tf.random.uniform(shape=[1], minval=-2.5, maxval=-1)
            fwhm = tf.random.uniform(shape=[1], minval=0.06, maxval=0.3)
            psf = phys.psf_models(fwhm, cutout_size=psf_pixels)
            observation = phys.noisy_forward(source, kappa, noise_rms, psf)
            # RIM predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(observation, noise_rms, psf)
            observation_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
            source_o = source_pred[-1]
            kappa_o = kappa_pred[-1]
            # Latent code of model predictions
            z_source, _ = source_vae.encoder(source_o)
            z_kappa, _ = kappa_vae.encoder(log_10(kappa_o))
            # Ground-truth latent code for oracle metrics
            z_source_gt, _ = source_vae.encoder(source)
            z_kappa_gt, _ = kappa_vae.encoder(log_10(kappa))

            # Re-optimize the weights of the model
            STEPS = args.re_optimize_steps
            learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=args.learning_rate,
                decay_rate=args.decay_rate,
                decay_steps=args.decay_steps,
                staircase=args.staircase)
            optim = tf.keras.optimizers.RMSprop(learning_rate=learning_rate_schedule)
            chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            source_mse = tf.TensorArray(DTYPE, size=STEPS)
            kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            sampled_source_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
            best = chi_squared
            source_best = source_pred[-1]
            kappa_best = kappa_pred[-1]
            source_mse_best = tf.reduce_mean((source_best - source)**2)
            kappa_mse_best = tf.reduce_mean((kappa_best - log_10(kappa))**2)

            # ===================== Optimization ==============================
            for current_step in tqdm(range(STEPS)):
                # ===================== VAE sampling ==============================
                # The sampling ball around the latent code could be set to the L1
                # distance from the ground-truth code (commented out below); a
                # user-defined ball size is used instead, as required for real data.
                # z_source_std = tf.abs(z_source - z_source_gt)
                # z_kappa_std = tf.abs(z_kappa - z_kappa_gt)
                z_source_std = args.source_vae_ball_size
                z_kappa_std = args.kappa_vae_ball_size
                # Sample latent codes, then decode and forward
                z_s = tf.random.normal(shape=[args.sample_size, source_vae.latent_size],
                                       mean=z_source, stddev=z_source_std)
                z_k = tf.random.normal(shape=[args.sample_size, kappa_vae.latent_size],
                                       mean=z_kappa, stddev=z_kappa_std)
                sampled_source = tf.nn.relu(source_vae.decode(z_s))
                sampled_source /= tf.reduce_max(sampled_source, axis=(1, 2, 3), keepdims=True)
                sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
                sampled_observation = phys.noisy_forward(
                    sampled_source, 10**sampled_kappa, noise_rms,
                    tf.tile(psf, [args.sample_size, 1, 1, 1]))

                with tf.GradientTape() as tape:
                    tape.watch(unet.trainable_variables)
                    s, k, chi_sq = rim.call(
                        sampled_observation, noise_rms,
                        tf.tile(psf, [args.sample_size, 1, 1, 1]), outer_tape=tape)
                    _kappa_mse = tf.reduce_sum(
                        wk(10**sampled_kappa) * (k - sampled_kappa)**2, axis=(2, 3, 4))
                    cost = tf.reduce_mean(_kappa_mse)
                    cost += tf.reduce_mean((s - sampled_source)**2)
                    cost += tf.reduce_sum(rim.unet.losses)  # weight decay
                grads = tape.gradient(cost, unet.trainable_variables)
                optim.apply_gradients(zip(grads, unet.trainable_variables))

                # Record performance on the sampled dataset
                sampled_chi_squared_series = sampled_chi_squared_series.write(
                    index=current_step, value=tf.squeeze(tf.reduce_mean(chi_sq[-1])))
                sampled_source_mse = sampled_source_mse.write(
                    index=current_step, value=tf.reduce_mean((s[-1] - sampled_source)**2))
                sampled_kappa_mse = sampled_kappa_mse.write(
                    index=current_step, value=tf.reduce_mean((k[-1] - sampled_kappa)**2))

                # Record the model prediction on the data
                s, k, chi_sq = rim.call(observation, noise_rms, psf)
                chi_squared_series = chi_squared_series.write(
                    index=current_step, value=tf.squeeze(chi_sq))
                source_o = s[-1]
                kappa_o = k[-1]
                # Oracle metrics, remove when using real data
                source_mse = source_mse.write(
                    index=current_step, value=tf.reduce_mean((source_o - source)**2))
                kappa_mse = kappa_mse.write(
                    index=current_step, value=tf.reduce_mean((kappa_o - log_10(kappa))**2))
                if abs(chi_sq[-1, 0] - 1) < abs(best[-1, 0] - 1):
                    source_best = tf.nn.relu(source_o)
                    kappa_best = 10**kappa_o
                    best = chi_sq
                    source_mse_best = tf.reduce_mean((source_best - source)**2)
                    kappa_mse_best = tf.reduce_mean((kappa_best - log_10(kappa))**2)

            source_o = source_best
            kappa_o = kappa_best
            y_pred = phys.forward(source_o, kappa_o, psf)
            chi_sq_series = tf.transpose(chi_squared_series.stack())
            source_mse = source_mse.stack()
            kappa_mse = kappa_mse.stack()
            sampled_chi_squared_series = sampled_chi_squared_series.stack()
            sampled_source_mse = sampled_source_mse.stack()
            sampled_kappa_mse = sampled_kappa_mse.stack()
            # Latent code of the optimized model predictions
            z_source_opt, _ = source_vae.encoder(tf.nn.relu(source_o))
            z_kappa_opt, _ = kappa_vae.encoder(log_10(kappa_o))
            # Compute the power spectra of the converged predictions
            _ps_observation = ps_observation.cross_correlation_coefficient(
                observation[..., 0], observation_pred[..., 0])
            _ps_observation2 = ps_observation.cross_correlation_coefficient(
                observation[..., 0], y_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0])
            _ps_kappa2 = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
            _ps_source = ps_source.cross_correlation_coefficient(
                source[..., 0], source_pred[-1][..., 0])
            _ps_source2 = ps_source.cross_correlation_coefficient(
                source[..., 0], source_o[..., 0])
            # Save results
            hf["observation"][i] = observation.numpy().astype(np.float32)
            hf["psf"][i] = psf.numpy().astype(np.float32)
            hf["psf_fwhm"][i] = fwhm.numpy().astype(np.float32)
            hf["noise_rms"][i] = noise_rms.numpy().astype(np.float32)
            hf["source"][i] = source.numpy().astype(np.float32)
            hf["kappa"][i] = kappa.numpy().astype(np.float32)
            hf["observation_pred"][i] = observation_pred.numpy().astype(np.float32)
            hf["observation_pred_reoptimized"][i] = y_pred.numpy().astype(np.float32)
            hf["source_pred"][i] = tf.transpose(
                source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["source_pred_reoptimized"][i] = source_o.numpy().astype(np.float32)
            hf["kappa_pred"][i] = tf.transpose(
                kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["kappa_pred_reoptimized"][i] = kappa_o.numpy().astype(np.float32)
            hf["chi_squared"][i] = tf.squeeze(chi_squared).numpy().astype(np.float32)
            hf["chi_squared_reoptimized"][i] = tf.squeeze(best).numpy().astype(np.float32)
            hf["chi_squared_reoptimized_series"][i] = chi_sq_series.numpy().astype(np.float32)
            hf["sampled_chi_squared_reoptimized_series"][i] = 2 * sampled_chi_squared_series.numpy().astype(np.float32)
            hf["source_optim_mse"][i] = source_mse_best.numpy().astype(np.float32)
            hf["source_optim_mse_series"][i] = source_mse.numpy().astype(np.float32)
            hf["sampled_source_optim_mse_series"][i] = sampled_source_mse.numpy().astype(np.float32)
            hf["kappa_optim_mse"][i] = kappa_mse_best.numpy().astype(np.float32)
            hf["kappa_optim_mse_series"][i] = kappa_mse.numpy().astype(np.float32)
            hf["sampled_kappa_optim_mse_series"][i] = sampled_kappa_mse.numpy().astype(np.float32)
            hf["latent_source_gt_distance_init"][i] = tf.abs(z_source - z_source_gt).numpy().squeeze().astype(np.float32)
            hf["latent_kappa_gt_distance_init"][i] = tf.abs(z_kappa - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["latent_source_gt_distance_end"][i] = tf.abs(z_source_opt - z_source_gt).numpy().squeeze().astype(np.float32)
            hf["latent_kappa_gt_distance_end"][i] = tf.abs(z_kappa_opt - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["observation_coherence_spectrum"][i] = _ps_observation
hf["observation_coherence_spectrum_reoptimized"][ i] = _ps_observation2 hf["source_coherence_spectrum"][i] = _ps_source hf["source_coherence_spectrum_reoptimized"][i] = _ps_source2 hf["kappa_coherence_spectrum"][i] = _ps_kappa hf["kappa_coherence_spectrum_reoptimized"][i] = _ps_kappa2 if i == 0: _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_observation.bins) f = (f[:-1] + f[1:]) / 2 hf["observation_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 hf["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 hf["kappa_frequencies"][:] = f hf["kappa_fov"][0] = phys.kappa_fov hf["source_fov"][0] = phys.src_fov
# Method of the Fisher-information helper class (the class statement itself is not
# shown in this excerpt).
def __init__(self,
             observation,
             noise_rms,
             psf,
             phys: PhysicalModel,
             rim: RIM,
             source_vae: VAE,
             kappa_vae: VAE,
             n_samples=100,
             sigma_source=0.5,
             sigma_kappa=0.5):
    r"""
    Make a copy of the initial parameters \varphi^{(0)} and compute the Fisher diagonal F_{ii}.
    """
    wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(
        tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
    # Baseline prediction from the observation
    source_pred, kappa_pred, chi_squared = rim.predict(observation, noise_rms, psf)
    # Latent code of the model predictions
    z_source, _ = source_vae.encoder(source_pred[-1])
    z_kappa, _ = kappa_vae.encoder(log_10(kappa_pred[-1]))
    # Deep copy of the initial parameters
    self.initial_params = [deepcopy(w) for w in rim.unet.trainable_variables]
    self.fisher_diagonal = [tf.zeros_like(w) for w in self.initial_params]

    for n in range(n_samples):
        # Sample a latent code around the prediction mean
        z_s = tf.random.normal(shape=[1, source_vae.latent_size],
                               mean=z_source, stddev=sigma_source)
        z_k = tf.random.normal(shape=[1, kappa_vae.latent_size],
                               mean=z_kappa, stddev=sigma_kappa)
        # Decode
        sampled_source = tf.nn.relu(source_vae.decode(z_s))
        sampled_source /= tf.reduce_max(sampled_source, axis=(1, 2, 3), keepdims=True)
        sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
        # Simulate an observation
        sampled_observation = phys.noisy_forward(sampled_source, 10**sampled_kappa, noise_rms, psf)
        # Compute the gradient of the MSE
        with tf.GradientTape() as tape:
            tape.watch(rim.unet.trainable_variables)
            s, k, chi_squared = rim.call(sampled_observation, noise_rms, psf)
            # Remove the temperature from the loss when computing the Fisher: use a sum
            # instead of a mean, and renormalize the weighted sum by the number of pixels.
            _kappa_mse = phys.kappa_pixels**2 * tf.reduce_sum(
                wk(10**sampled_kappa) * (k - sampled_kappa)**2, axis=(2, 3, 4))
            cost = tf.reduce_sum(_kappa_mse)
            cost += tf.reduce_sum((s - sampled_source)**2)
        grad = tape.gradient(cost, rim.unet.trainable_variables)
        # Square the derivatives at the initial parameters and add them to the running average
        self.fisher_diagonal = [
            F + g**2 / n_samples for F, g in zip(self.fisher_diagonal, grad)
        ]
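
# The sketch below is illustrative and not part of the original excerpt. It shows the
# standard way the two quantities computed in __init__ (the copy of the initial weights
# and the Fisher diagonal) are used: as an elastic-weight-consolidation (EWC) style
# penalty that discourages the re-optimized weights from drifting away from the initial
# ones along directions the Fisher information marks as important. The function name,
# the `ewc` argument (an instance of the class above) and `lam` are hypothetical.
def ewc_penalty(ewc, rim, lam=1.0):
    """Return lam/2 * sum_i F_ii * (varphi_i - varphi_i^(0))**2 over the U-net weights."""
    penalty = 0.
    for F, w0, w in zip(ewc.fisher_diagonal, ewc.initial_params, rim.unet.trainable_variables):
        penalty += tf.reduce_sum(F * (w - w0)**2)
    return 0.5 * lam * penalty
# In a re-optimization loop, this term would typically be added to the cost before
# computing gradients, e.g. cost += ewc_penalty(ewc, rim).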