def main(args):
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)

    with open(os.path.join(args.first_stage_model_id, "model_hparams.json"), "r") as f:
        vae_hparams = json.load(f)

    # load weights
    vae = VAE(**vae_hparams)
    ckpt1 = tf.train.Checkpoint(net=vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, args.first_stage_model_id, 1)
    checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()

    n_batch = args.total_items // args.batch_size
    batch_per_record = n_batch // args.n_records  # was inverted (args.n_records // n_batch)
    last_record_n_batch = batch_per_record + n_batch % args.n_records  # last record absorbs the remainder

    for record in range(args.n_records - 1):
        with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{record:02d}.tfrecords"), options) as writer:
            for batch in range(batch_per_record):
                z = tf.random.normal(shape=[args.batch_size, vae.latent_size])
                kappa_batch = vae.decode(z)
                for kappa in kappa_batch:
                    features = {
                        "kappa": _bytes_feature(kappa.numpy().tobytes()),
                        "kappa pixels": _int64_feature(kappa.shape[0]),
                    }
                    serialized = tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
                    writer.write(serialized)

    with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{args.n_records - 1:02d}.tfrecords"), options) as writer:
        for batch in range(last_record_n_batch):
            z = tf.random.normal(shape=[args.batch_size, vae.latent_size])
            kappa_batch = vae.decode(z)
            for kappa in kappa_batch:
                features = {
                    "kappa": _bytes_feature(kappa.numpy().tobytes()),
                    "kappa pixels": _int64_feature(kappa.shape[0]),
                }
                serialized = tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
                writer.write(serialized)
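# The writer above relies on `_bytes_feature` and `_int64_feature`, which are defined elsewhere
# in the repo. A minimal sketch of the standard TensorFlow serialization helpers it assumes
# (hedged -- the repo's own definitions may differ):
def _bytes_feature(value):
    # Wrap raw bytes in a tf.train.Feature for serialization into a tf.train.Example.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    # Wrap a scalar integer in a tf.train.Feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))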
def main(args):
    if THIS_WORKER > 1:
        time.sleep(5)
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
    if args.seed is not None:
        tf.random.set_seed(args.seed)

    # Load first stage
    with open(os.path.join(args.kappa_first_stage_vae, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, args.kappa_first_stage_vae, 1)
    checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()
    kappa_vae.trainable = False

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{THIS_WORKER}.tfrecords"), options) as writer:
        print(f"Started worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
        for _ in range((THIS_WORKER - 1) * args.batch_size, args.len_dataset, N_WORKERS * args.batch_size):
            kappa = 10**kappa_vae.sample(args.batch_size)
            # Most important info to record are kappa and kappa fov, the rest are just fill-ins
            # to match the same description as previous TNG tfrecords
            records = encode_examples(kappa=kappa,
                                      einstein_radius_init=[0.] * args.batch_size,
                                      einstein_radius=[0.] * args.batch_size,
                                      rescalings=[0.] * args.batch_size,
                                      z_source=0.,
                                      z_lens=0.,
                                      kappa_fov=args.kappa_fov,
                                      sigma_crit=0.,
                                      kappa_ids=[0] * args.batch_size)
            for record in records:
                writer.write(record)
def distributed_strategy(args):
    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)
    path = os.getenv('CENSAI_PATH') + "/results/"
    dataset = []
    for file in sorted(glob.glob(path + args.h5_pattern)):
        try:
            dataset.append(h5py.File(file, "r"))
        except OSError:  # skip files that cannot be opened (was a bare except)
            continue
    B = dataset[0]["source"].shape[0]
    data_len = len(dataset) * B // N_WORKERS

    ps_observation = PowerSpectrum(bins=args.observation_coherence_bins, pixels=128)
    ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=128)
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=128)

    phys = PhysicalModel(
        pixels=128,
        kappa_pixels=128,
        src_pixels=128,
        image_fov=7.69,
        kappa_fov=7.69,
        src_fov=3.,
        method="fft",
    )

    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()

    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim = RIM(phys, unet, **rim_params)

    kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.kappa_vae)
    with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1)
    checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()

    svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.source_vae)
    with open(os.path.join(svae_path, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1)
    checkpoint_manager2.checkpoint.restore(checkpoint_manager2.latest_checkpoint).expect_partial()

    # Weighting map for the kappa MSE
    wk = lambda k: tf.sqrt(k) / tf.reduce_sum(tf.sqrt(k), axis=(1, 2, 3), keepdims=True)

    # Freeze L5
    # encoding layers
    # rim.unet.layers[0].trainable = False  # L1
    # rim.unet.layers[1].trainable = False
    # rim.unet.layers[2].trainable = False
    # rim.unet.layers[3].trainable = False
    # rim.unet.layers[4].trainable = False  # L5
    # GRU
    # rim.unet.layers[5].trainable = False
    # rim.unet.layers[6].trainable = False
    # rim.unet.layers[7].trainable = False
    # rim.unet.layers[8].trainable = False
    # rim.unet.layers[9].trainable = False
    # rim.unet.layers[15].trainable = False  # bottleneck GRU
    # output layer
    # rim.unet.layers[-2].trainable = False
    # input layer
    # rim.unet.layers[-1].trainable = False
    # decoding layers
    # rim.unet.layers[10].trainable = False  # L5
    # rim.unet.layers[11].trainable = False
    # rim.unet.layers[12].trainable = False
    # rim.unet.layers[13].trainable = False
    # rim.unet.layers[14].trainable = False  # L1

    with h5py.File(
            os.path.join(os.getenv("CENSAI_PATH"), "results",
                         args.experiment_name + "_" + args.model + "_" + args.dataset + f"_{THIS_WORKER:03d}.h5"),
            'w') as hf:
        hf.create_dataset(name="observation", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="psf", shape=[data_len, 20, 20, 1], dtype=np.float32)
        hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="observation_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="observation_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32)  # dtype added for consistency
        hf.create_dataset(name="kappa_pred", shape=[data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum", shape=[data_len, args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum2", shape=[data_len, args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum_reoptimized", shape=[data_len, args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="observation_frequencies", shape=[args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32)

        for batch, j in enumerate(range((THIS_WORKER - 1) * data_len, THIS_WORKER * data_len)):
            b = j // B
            k = j % B
            observation = dataset[b]["observation"][k][None, ...]
            source = dataset[b]["source"][k][None, ...]
            kappa = dataset[b]["kappa"][k][None, ...]
            noise_rms = np.array([dataset[b]["noise_rms"][k]])
            psf = dataset[b]["psf"][k][None, ...]
            fwhm = dataset[b]["psf_fwhm"][k]

            checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()  # reset model weights

            # Compute predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(observation, noise_rms, psf)
            observation_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)

            # reset the seed for reproducible sampling in the VAE for EWC
            tf.random.set_seed(args.seed)
            np.random.seed(args.seed)

            # Initialize regularization term
            ewc = EWC(observation=observation,
                      noise_rms=noise_rms,
                      psf=psf,
                      phys=phys,
                      rim=rim,
                      source_vae=source_vae,
                      kappa_vae=kappa_vae,
                      n_samples=args.sample_size,
                      sigma_source=args.source_vae_ball_size,
                      sigma_kappa=args.kappa_vae_ball_size)

            # Re-optimize weights of the model
            STEPS = args.re_optimize_steps
            learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=args.learning_rate,
                decay_rate=args.decay_rate,
                decay_steps=args.decay_steps,
                staircase=args.staircase)
            optim = tf.keras.optimizers.SGD(learning_rate=learning_rate_schedule)

            chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            source_mse = tf.TensorArray(DTYPE, size=STEPS)
            kappa_mse = tf.TensorArray(DTYPE, size=STEPS)

            best = chi_squared[-1, 0]
            source_best = source_pred[-1]
            kappa_best = kappa_pred[-1]
            source_mse_best = tf.reduce_mean((source_best - rim.source_inverse_link(source))**2)
            kappa_mse_best = tf.reduce_sum(wk(kappa) * (kappa_best - rim.kappa_inverse_link(kappa))**2)

            for current_step in tqdm(range(STEPS)):
                with tf.GradientTape() as tape:
                    tape.watch(unet.trainable_variables)
                    s, k, chi_sq = rim.call(observation, noise_rms, psf, outer_tape=tape)
                    cost = tf.reduce_mean(chi_sq)  # mean over time steps
                    cost += tf.reduce_sum(rim.unet.losses)  # L2 regularisation
                    cost += args.lam_ewc * ewc.penalty(rim)  # Elastic Weights Consolidation

                log_likelihood = chi_sq[-1]
                chi_squared_series = chi_squared_series.write(index=current_step, value=log_likelihood)
                source_o = s[-1]
                kappa_o = k[-1]
                source_mse = source_mse.write(
                    index=current_step,
                    value=tf.reduce_mean((source_o - rim.source_inverse_link(source))**2))
                kappa_mse = kappa_mse.write(
                    index=current_step,
                    value=tf.reduce_sum(wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2))

                if 2 * chi_sq[-1, 0] < 1.0 and args.early_stopping:
                    source_best = rim.source_link(source_o)
                    kappa_best = rim.kappa_link(kappa_o)
                    best = chi_sq[-1, 0]
                    source_mse_best = tf.reduce_mean((source_o - rim.source_inverse_link(source))**2)
                    kappa_mse_best = tf.reduce_sum(wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2)
                    break
                if chi_sq[-1, 0] < best:
                    source_best = rim.source_link(source_o)
                    kappa_best = rim.kappa_link(kappa_o)
                    best = chi_sq[-1, 0]
                    source_mse_best = tf.reduce_mean((source_o - rim.source_inverse_link(source))**2)
                    kappa_mse_best = tf.reduce_sum(wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2)

                grads = tape.gradient(cost, unet.trainable_variables)
                optim.apply_gradients(zip(grads, unet.trainable_variables))

            source_o = source_best
            kappa_o = kappa_best
            y_pred = phys.forward(source_o, kappa_o, psf)
            chi_sq_series = tf.transpose(chi_squared_series.stack(), perm=[1, 0])
            source_mse = source_mse.stack()[None, ...]
            kappa_mse = kappa_mse.stack()[None, ...]

            # Compute Power spectrum of converged predictions
            _ps_observation = ps_observation.cross_correlation_coefficient(observation[..., 0], observation_pred[..., 0])
            _ps_observation2 = ps_observation.cross_correlation_coefficient(observation[..., 0], y_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0])
            _ps_kappa2 = ps_kappa.cross_correlation_coefficient(log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
            _ps_source = ps_source.cross_correlation_coefficient(source[..., 0], source_pred[-1][..., 0])
            _ps_source2 = ps_source.cross_correlation_coefficient(source[..., 0], source_o[..., 0])

            # save results
            hf["observation"][batch] = observation.astype(np.float32)
            hf["psf"][batch] = psf.astype(np.float32)
            hf["psf_fwhm"][batch] = fwhm
            hf["noise_rms"][batch] = noise_rms.astype(np.float32)
            hf["source"][batch] = source.astype(np.float32)
            hf["kappa"][batch] = kappa.astype(np.float32)
            hf["observation_pred"][batch] = observation_pred.numpy().astype(np.float32)
            hf["observation_pred_reoptimized"][batch] = y_pred.numpy().astype(np.float32)
            hf["source_pred"][batch] = tf.transpose(source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["source_pred_reoptimized"][batch] = source_o.numpy().astype(np.float32)
            hf["kappa_pred"][batch] = tf.transpose(kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype(np.float32)
            hf["chi_squared"][batch] = 2 * tf.transpose(chi_squared).numpy().astype(np.float32)
            hf["chi_squared_reoptimized"][batch] = 2 * best.numpy().astype(np.float32)
            hf["chi_squared_reoptimized_series"][batch] = 2 * chi_sq_series.numpy().astype(np.float32)
            hf["source_optim_mse"][batch] = source_mse_best.numpy().astype(np.float32)
            hf["source_optim_mse_series"][batch] = source_mse.numpy().astype(np.float32)
            hf["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype(np.float32)
            hf["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype(np.float32)
            hf["observation_coherence_spectrum"][batch] = _ps_observation
            hf["observation_coherence_spectrum_reoptimized"][batch] = _ps_observation2
            hf["source_coherence_spectrum"][batch] = _ps_source
            hf["source_coherence_spectrum_reoptimized"][batch] = _ps_source2
            hf["kappa_coherence_spectrum"][batch] = _ps_kappa
            hf["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2

            if batch == 0:
                _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_observation.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["observation_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["source_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["kappa_frequencies"][:] = f
                hf["kappa_fov"][0] = phys.kappa_fov
                hf["source_fov"][0] = phys.src_fov
def main(args):
    files = []
    files.extend(glob.glob(os.path.join(args.dataset, "*.tfrecords")))
    np.random.shuffle(files)
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
        block_length=args.block_length,
        num_parallel_calls=tf.data.AUTOTUNE)

    if args.type == "cosmos":
        from censai.data.cosmos import decode_shape, decode_image as decode, preprocess_image as preprocess
    elif args.type == "kappa":
        from censai.data.kappa_tng import decode_shape, decode_train as decode
        from censai.definitions import log_10 as preprocess

    # Read off global parameters from first example in dataset
    for pixels in dataset.map(decode_shape):
        break
    vars(args).update({"pixels": int(pixels)})

    dataset = dataset.map(decode).map(preprocess).shuffle(args.buffer_size).batch(args.batch_size).take(args.n_plots).cache(args.cache)

    model_list = glob.glob(os.path.join(os.getenv("CENSAI_PATH"), "models", args.model_prefixe + "*"))
    for model in model_list:
        if "second_stage" in model:
            continue
        with open(os.path.join(model, "model_hparams.json")) as f:
            vae_hparams = json.load(f)

        # load weights
        vae = VAE(**vae_hparams)
        ckpt1 = tf.train.Checkpoint(net=vae)
        checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, model, 1)
        checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()
        vae.trainable = False

        model_name = os.path.split(model)[-1]
        for batch, images in enumerate(dataset):
            y_pred = vae(images)
            fig = reconstruction_plot(images, y_pred)
            fig.suptitle(model_name)
            fig.savefig(os.path.join(
                os.getenv("CENSAI_PATH"), "results",
                "vae_reconstruction_" + model_name + "_" + args.output_postfixe + f"_{batch:02d}.png"))
            fig.clf()

            y_pred = vae.sample(args.sampling_size)
            fig = sampling_plot(y_pred)
            fig.suptitle(model_name)
            fig.savefig(os.path.join(
                os.getenv("CENSAI_PATH"), "results",
                "vae_sampling_" + model_name + "_" + args.output_postfixe + f"_{batch:02d}.png"))
            fig.clf()
def main(args):
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
        block_length=1,
        num_parallel_calls=tf.data.AUTOTUNE)
    for physical_params in dataset.map(decode_physical_model_info):
        break
    dataset = dataset.map(decode_train)

    # files = glob.glob(os.path.join(args.source_dataset, "*.tfrecords"))
    # files = tf.data.Dataset.from_tensor_slices(files)
    # source_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
    #                                   block_length=1, num_parallel_calls=tf.data.AUTOTUNE)
    # source_dataset = source_dataset.map(decode_image).map(preprocess_image).shuffle(10000).batch(args.sample_size)

    with open(os.path.join(args.kappa_vae, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, args.kappa_vae, 1)
    checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()

    with open(os.path.join(args.source_vae, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, args.source_vae, 1)
    checkpoint_manager2.checkpoint.restore(checkpoint_manager2.latest_checkpoint).expect_partial()

    phys = PhysicalModel(pixels=physical_params["pixels"].numpy(),
                         kappa_pixels=physical_params["kappa pixels"].numpy(),
                         src_pixels=physical_params["src pixels"].numpy(),
                         image_fov=physical_params["image fov"].numpy(),
                         kappa_fov=physical_params["kappa fov"].numpy(),
                         src_fov=physical_params["source fov"].numpy(),
                         method="fft")

    # simulate observations
    kappa = 10**kappa_vae.sample(args.sample_size)
    source = preprocess_image(source_vae.sample(args.sample_size))
    # for source in source_dataset:
    #     break
    fwhm = tf.random.normal(shape=[args.sample_size],
                            mean=1.5 * phys.image_fov / phys.pixels,
                            stddev=0.5 * phys.image_fov / phys.pixels)
    # noise_rms = tf.random.normal(shape=[args.sample_size], mean=args.noise_mean, stddev=args.noise_std)
    psf = phys.psf_models(fwhm, cutout_size=20)
    y_vae = phys.forward(source, kappa, psf)

    with h5py.File(os.path.join(os.getenv("CENSAI_PATH"), "results", args.output_name + ".h5"), 'w') as hf:
        # rank these observations against the dataset with L2 norm
        for i in tqdm(range(args.sample_size)):
            distances = []
            for y_d, _, _, _, _ in dataset:
                distances.append(
                    tf.sqrt(tf.reduce_sum((y_d - y_vae[i][None, ...])**2)).numpy().astype(np.float32))
            k_indices = np.argsort(distances)[:args.k]

            # save results
            g = hf.create_group(f"sample_{i:02d}")
            g.create_dataset(name="matched_source", shape=[args.k, phys.src_pixels, phys.src_pixels], dtype=np.float32)
            g.create_dataset(name="matched_kappa", shape=[args.k, phys.kappa_pixels, phys.kappa_pixels], dtype=np.float32)
            g.create_dataset(name="matched_obs", shape=[args.k, phys.pixels, phys.pixels], dtype=np.float32)
            g.create_dataset(name="matched_psf", shape=[args.k, 20, 20], dtype=np.float32)
            g.create_dataset(name="matched_noise_rms", shape=[args.k], dtype=np.float32)
            g.create_dataset(name="obs_L2_distance", shape=[args.k], dtype=np.float32)
            g["vae_source"] = source[i, ..., 0].numpy().astype(np.float32)
            g["vae_kappa"] = kappa[i, ..., 0].numpy().astype(np.float32)
            g["vae_obs"] = y_vae[i, ..., 0].numpy().astype(np.float32)
            g["vae_psf"] = psf[i, ..., 0].numpy().astype(np.float32)

            for rank, j in enumerate(k_indices):
                # fetch back the matched observation
                for y_d, source_d, kappa_d, noise_rms_d, psf_d in dataset.skip(j):
                    break
                # g["vae_noise_rms"] = noise_rms[i].numpy().astype(np.float32)
                g["matched_source"][rank] = source_d[..., 0].numpy().astype(np.float32)
                g["matched_kappa"][rank] = kappa_d[..., 0].numpy().astype(np.float32)
                g["matched_obs"][rank] = y_d[..., 0].numpy().astype(np.float32)
                g["matched_noise_rms"][rank] = noise_rms_d.numpy().astype(np.float32)
                g["matched_psf"][rank] = psf_d[..., 0].numpy().astype(np.float32)
                g["obs_L2_distance"][rank] = distances[j]
def distributed_strategy(args):
    psf_pixels = 20
    pixels = 128

    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)

    ps_observation = PowerSpectrum(bins=args.observation_coherence_bins, pixels=pixels)
    ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=pixels)
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=pixels)

    phys = PhysicalModel(
        pixels=pixels,
        kappa_pixels=pixels,
        src_pixels=pixels,
        image_fov=7.68,
        kappa_fov=7.68,
        src_fov=3.,
        method="fft",
    )

    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()

    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim_params["source_link"] = "relu"
    rim = RIM(phys, unet, **rim_params)

    kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.kappa_vae)
    with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1)
    checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()

    svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.source_vae)
    with open(os.path.join(svae_path, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1)
    checkpoint_manager2.checkpoint.restore(checkpoint_manager2.latest_checkpoint).expect_partial()

    model_name = os.path.split(model)[-1]

    # Weighting map for the kappa MSE
    wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(tf.sqrt(k), axis=(1, 2, 3), keepdims=True))

    with h5py.File(
            os.path.join(os.getenv("CENSAI_PATH"), "results",
                         args.experiment_name + "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf:
        data_len = args.size // N_WORKERS
        hf.create_dataset(name="observation", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="psf", shape=[data_len, psf_pixels, psf_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="observation_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="observation_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32)
        hf.create_dataset(name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="kappa_pred", shape=[data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32)
        hf.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized", shape=[data_len, rim.steps], dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, rim.steps, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="sampled_chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="sampled_source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="sampled_kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_init", shape=[data_len, kappa_vae.latent_size], dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_init", shape=[data_len, source_vae.latent_size], dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_end", shape=[data_len, kappa_vae.latent_size], dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_end", shape=[data_len, source_vae.latent_size], dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum", shape=[data_len, args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum_reoptimized", shape=[data_len, args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="observation_frequencies", shape=[args.observation_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32)
        hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32)

        for i in range(data_len):
            checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint).expect_partial()  # reset model weights

            # Produce an observation
            kappa = 10**kappa_vae.sample(1)
            source = tf.nn.relu(source_vae.sample(1))
            source /= tf.reduce_max(source, axis=(1, 2, 3), keepdims=True)
            noise_rms = 10**tf.random.uniform(shape=[1], minval=-2.5, maxval=-1)
            fwhm = tf.random.uniform(shape=[1], minval=0.06, maxval=0.3)
            psf = phys.psf_models(fwhm, cutout_size=psf_pixels)
            observation = phys.noisy_forward(source, kappa, noise_rms, psf)

            # RIM predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(observation, noise_rms, psf)
            observation_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
            source_o = source_pred[-1]
            kappa_o = kappa_pred[-1]

            # Latent code of model predictions
            z_source, _ = source_vae.encoder(source_o)
            z_kappa, _ = kappa_vae.encoder(log_10(kappa_o))
            # Ground truth latent code for oracle metrics
            z_source_gt, _ = source_vae.encoder(source)
            z_kappa_gt, _ = kappa_vae.encoder(log_10(kappa))

            # Re-optimize weights of the model
            STEPS = args.re_optimize_steps
            learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=args.learning_rate,
                decay_rate=args.decay_rate,
                decay_steps=args.decay_steps,
                staircase=args.staircase)
            optim = tf.keras.optimizers.RMSprop(learning_rate=learning_rate_schedule)

            chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            source_mse = tf.TensorArray(DTYPE, size=STEPS)
            kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            sampled_source_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_kappa_mse = tf.TensorArray(DTYPE, size=STEPS)

            best = chi_squared
            source_best = source_pred[-1]
            kappa_best = kappa_pred[-1]
            source_mse_best = tf.reduce_mean((source_best - source)**2)
            kappa_mse_best = tf.reduce_mean((kappa_best - log_10(kappa))**2)

            # ===================== Optimization ==============================
            for current_step in tqdm(range(STEPS)):
                # ===================== VAE SAMPLING ==============================
                # L1 distance with ground truth in latent space -- this is replaced by a
                # user-defined value when using real data
                # z_source_std = tf.abs(z_source - z_source_gt)
                # z_kappa_std = tf.abs(z_kappa - z_kappa_gt)
                z_source_std = args.source_vae_ball_size
                z_kappa_std = args.kappa_vae_ball_size

                # Sample latent code, then decode and forward
                z_s = tf.random.normal(shape=[args.sample_size, source_vae.latent_size], mean=z_source, stddev=z_source_std)
                z_k = tf.random.normal(shape=[args.sample_size, kappa_vae.latent_size], mean=z_kappa, stddev=z_kappa_std)
                sampled_source = tf.nn.relu(source_vae.decode(z_s))
                sampled_source /= tf.reduce_max(sampled_source, axis=(1, 2, 3), keepdims=True)
                sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
                sampled_observation = phys.noisy_forward(sampled_source, 10**sampled_kappa, noise_rms,
                                                         tf.tile(psf, [args.sample_size, 1, 1, 1]))

                with tf.GradientTape() as tape:
                    tape.watch(unet.trainable_variables)
                    s, k, chi_sq = rim.call(sampled_observation, noise_rms,
                                            tf.tile(psf, [args.sample_size, 1, 1, 1]), outer_tape=tape)
                    _kappa_mse = tf.reduce_sum(wk(10**sampled_kappa) * (k - sampled_kappa)**2, axis=(2, 3, 4))
                    cost = tf.reduce_mean(_kappa_mse)
                    cost += tf.reduce_mean((s - sampled_source)**2)
                    cost += tf.reduce_sum(rim.unet.losses)  # weight decay
                grads = tape.gradient(cost, unet.trainable_variables)
                optim.apply_gradients(zip(grads, unet.trainable_variables))

                # Record performance on sampled dataset
                sampled_chi_squared_series = sampled_chi_squared_series.write(
                    index=current_step, value=tf.squeeze(tf.reduce_mean(chi_sq[-1])))
                sampled_source_mse = sampled_source_mse.write(
                    index=current_step, value=tf.reduce_mean((s[-1] - sampled_source)**2))
                sampled_kappa_mse = sampled_kappa_mse.write(
                    index=current_step, value=tf.reduce_mean((k[-1] - sampled_kappa)**2))

                # Record model prediction on data
                s, k, chi_sq = rim.call(observation, noise_rms, psf)
                chi_squared_series = chi_squared_series.write(index=current_step, value=tf.squeeze(chi_sq))
                source_o = s[-1]
                kappa_o = k[-1]
                # oracle metrics, remove when using real data
                source_mse = source_mse.write(index=current_step, value=tf.reduce_mean((source_o - source)**2))
                kappa_mse = kappa_mse.write(index=current_step, value=tf.reduce_mean((kappa_o - log_10(kappa))**2))

                if abs(chi_sq[-1, 0] - 1) < abs(best[-1, 0] - 1):
                    source_best = tf.nn.relu(source_o)
                    kappa_best = 10**kappa_o
                    best = chi_sq
                    source_mse_best = tf.reduce_mean((source_best - source)**2)
                    kappa_mse_best = tf.reduce_mean((kappa_best - log_10(kappa))**2)

            source_o = source_best
            kappa_o = kappa_best
            y_pred = phys.forward(source_o, kappa_o, psf)
            chi_sq_series = tf.transpose(chi_squared_series.stack())
            source_mse = source_mse.stack()
            kappa_mse = kappa_mse.stack()
            sampled_chi_squared_series = sampled_chi_squared_series.stack()
            sampled_source_mse = sampled_source_mse.stack()
            sampled_kappa_mse = sampled_kappa_mse.stack()

            # Latent code of optimized model predictions
            z_source_opt, _ = source_vae.encoder(tf.nn.relu(source_o))
            z_kappa_opt, _ = kappa_vae.encoder(log_10(kappa_o))

            # Compute Power spectrum of converged predictions
            _ps_observation = ps_observation.cross_correlation_coefficient(observation[..., 0], observation_pred[..., 0])
            _ps_observation2 = ps_observation.cross_correlation_coefficient(observation[..., 0], y_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0])
            _ps_kappa2 = ps_kappa.cross_correlation_coefficient(log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
            _ps_source = ps_source.cross_correlation_coefficient(source[..., 0], source_pred[-1][..., 0])
            _ps_source2 = ps_source.cross_correlation_coefficient(source[..., 0], source_o[..., 0])

            # save results
            hf["observation"][i] = observation.numpy().astype(np.float32)
            hf["psf"][i] = psf.numpy().astype(np.float32)
            hf["psf_fwhm"][i] = fwhm.numpy().astype(np.float32)
            hf["noise_rms"][i] = noise_rms.numpy().astype(np.float32)
            hf["source"][i] = source.numpy().astype(np.float32)
            hf["kappa"][i] = kappa.numpy().astype(np.float32)
            hf["observation_pred"][i] = observation_pred.numpy().astype(np.float32)
            hf["observation_pred_reoptimized"][i] = y_pred.numpy().astype(np.float32)
            hf["source_pred"][i] = tf.transpose(source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["source_pred_reoptimized"][i] = source_o.numpy().astype(np.float32)
            hf["kappa_pred"][i] = tf.transpose(kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["kappa_pred_reoptimized"][i] = kappa_o.numpy().astype(np.float32)
            hf["chi_squared"][i] = tf.squeeze(chi_squared).numpy().astype(np.float32)
            hf["chi_squared_reoptimized"][i] = tf.squeeze(best).numpy().astype(np.float32)
            hf["chi_squared_reoptimized_series"][i] = chi_sq_series.numpy().astype(np.float32)
            hf["sampled_chi_squared_reoptimized_series"][i] = 2 * sampled_chi_squared_series.numpy().astype(np.float32)
            hf["source_optim_mse"][i] = source_mse_best.numpy().astype(np.float32)
            hf["source_optim_mse_series"][i] = source_mse.numpy().astype(np.float32)
            hf["sampled_source_optim_mse_series"][i] = sampled_source_mse.numpy().astype(np.float32)
            hf["kappa_optim_mse"][i] = kappa_mse_best.numpy().astype(np.float32)
            hf["kappa_optim_mse_series"][i] = kappa_mse.numpy().astype(np.float32)
            hf["sampled_kappa_optim_mse_series"][i] = sampled_kappa_mse.numpy().astype(np.float32)
            hf["latent_source_gt_distance_init"][i] = tf.abs(z_source - z_source_gt).numpy().squeeze().astype(np.float32)
            hf["latent_kappa_gt_distance_init"][i] = tf.abs(z_kappa - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["latent_source_gt_distance_end"][i] = tf.abs(z_source_opt - z_source_gt).numpy().squeeze().astype(np.float32)
            hf["latent_kappa_gt_distance_end"][i] = tf.abs(z_kappa_opt - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["observation_coherence_spectrum"][i] = _ps_observation
            hf["observation_coherence_spectrum_reoptimized"][i] = _ps_observation2
            hf["source_coherence_spectrum"][i] = _ps_source
            hf["source_coherence_spectrum_reoptimized"][i] = _ps_source2
            hf["kappa_coherence_spectrum"][i] = _ps_kappa
            hf["kappa_coherence_spectrum_reoptimized"][i] = _ps_kappa2

            if i == 0:
                _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_observation.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["observation_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["source_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["kappa_frequencies"][:] = f
                hf["kappa_fov"][0] = phys.kappa_fov
                hf["source_fov"][0] = phys.src_fov
def main(args):
    if args.seed is not None:
        tf.random.set_seed(args.seed)
        np.random.seed(args.seed)
    if args.json_override is not None:
        if isinstance(args.json_override, list):
            files = args.json_override
        else:
            files = [args.json_override]
        for file in files:
            with open(file, "r") as f:
                json_override = json.load(f)
            args_dict = vars(args)
            args_dict.update(json_override)

    files = []
    for dataset in args.datasets:
        files.extend(glob.glob(os.path.join(dataset, "*.tfrecords")))
    np.random.shuffle(files)
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
        block_length=args.block_length,
        num_parallel_calls=tf.data.AUTOTUNE)

    # Read off global parameters from first example in dataset
    for pixels in dataset.map(decode_shape):
        break
    vars(args).update({"pixels": int(pixels)})

    dataset = dataset.map(decode_image).map(preprocess_image).batch(args.batch_size)
    if args.cache_file is not None:
        dataset = dataset.cache(args.cache_file).prefetch(tf.data.experimental.AUTOTUNE)
    else:
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    # don't forget to divide by batch size!
    train_dataset = dataset.take(math.floor(args.train_split * args.total_items / args.batch_size))
    val_dataset = dataset.skip(math.floor(args.train_split * args.total_items / args.batch_size))
    val_dataset = val_dataset.take(math.ceil((1 - args.train_split) * args.total_items / args.batch_size))

    vae = VAE(pixels=pixels,
              layers=args.layers,
              conv_layers=args.conv_layers,
              filter_scaling=args.filter_scaling,
              filters=args.filters,
              kernel_size=args.kernel_size,
              kernel_reg_amp=args.kernel_reg_amp,
              bias_reg_amp=args.bias_reg_amp,
              activation=args.activation,
              dropout_rate=args.dropout_rate,
              batch_norm=args.batch_norm,
              latent_size=args.latent_size,
              output_activation=args.output_activation)

    learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=args.initial_learning_rate,
        decay_rate=args.decay_rate,
        decay_steps=args.decay_steps,
        staircase=args.staircase)
    beta_schedule = PolynomialSchedule(initial_value=args.beta_init,
                                       end_value=args.beta_end_value,
                                       power=args.beta_decay_power,
                                       decay_steps=args.beta_decay_steps,
                                       cyclical=args.beta_cyclical)
    skip_strength_schedule = PolynomialSchedule(initial_value=args.skip_strength_init,
                                                end_value=0.,
                                                power=args.skip_strength_decay_power,
                                                decay_steps=args.skip_strength_decay_steps)
    l2_bottleneck_schedule = PolynomialSchedule(initial_value=args.l2_bottleneck_init,
                                                end_value=0.,
                                                power=args.l2_bottleneck_decay_power,
                                                decay_steps=args.l2_bottleneck_decay_steps)
    optim = tf.keras.optimizers.deserialize({
        "class_name": args.optimizer,
        "config": {"learning_rate": learning_rate_schedule}
    })

    # ==== Take care of where to write logs and stuff =================================================================
    if args.model_id.lower() != "none":
        logname = args.model_id
    elif args.logname is not None:
        logname = args.logname
    else:
        logname = args.logname_prefixe + "_" + datetime.now().strftime("%y%m%d%H%M%S")
    if args.logdir.lower() != "none":
        logdir = os.path.join(args.logdir, logname)
        if not os.path.isdir(logdir):
            os.mkdir(logdir)
        writer = tf.summary.create_file_writer(logdir)
    else:
        writer = nullwriter()

    # ===== Make sure directory and checkpoint manager are created to save model ======================================
    if args.model_dir.lower() != "none":
        checkpoints_dir = os.path.join(args.model_dir, logname)
        if not os.path.isdir(checkpoints_dir):
            os.mkdir(checkpoints_dir)
            with open(os.path.join(checkpoints_dir, "script_params.json"), "w") as f:
                json.dump(vars(args), f, indent=4)
            with open(os.path.join(checkpoints_dir, "model_hparams.json"), "w") as f:
                hparams_dict = {key: vars(args)[key] for key in VAE_HPARAMS}
                json.dump(hparams_dict, f, indent=4)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=vae)
        checkpoint_manager = tf.train.CheckpointManager(ckpt, checkpoints_dir, max_to_keep=args.max_to_keep)
        save_checkpoint = True
        # ======= Load model if model_id is provided ===================================================================
        if args.model_id.lower() != "none":
            if args.load_checkpoint == "lastest":
                checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint)
            elif args.load_checkpoint == "best":
                scores = np.loadtxt(os.path.join(checkpoints_dir, "score_sheet.txt"))
                _checkpoint = int(scores[np.argmin(scores[:, 1]), 0])  # cast to int so it can index the checkpoint list
                checkpoint = checkpoint_manager.checkpoints[_checkpoint]
                checkpoint_manager.checkpoint.restore(checkpoint)
            else:
                checkpoint = checkpoint_manager.checkpoints[int(args.load_checkpoint)]
                checkpoint_manager.checkpoint.restore(checkpoint)
    else:
        save_checkpoint = False

    def train_step(x, step):
        with tf.GradientTape() as tape:
            tape.watch(vae.trainable_weights)
            reconstruction_loss, kl_loss, bottleneck_l2_loss = vae.cost_function_training(
                x, skip_strength_schedule(step), l2_bottleneck_schedule(step))
            cost = tf.reduce_sum(reconstruction_loss + beta_schedule(step) * kl_loss + bottleneck_l2_loss) / args.batch_size
        gradients = tape.gradient(cost, vae.trainable_weights)
        if args.clipping:
            gradients = [tf.clip_by_value(grad, -10, 10) for grad in gradients]
        optim.apply_gradients(zip(gradients, vae.trainable_weights))
        reconstruction_loss = tf.reduce_mean(reconstruction_loss)
        kl_loss = tf.reduce_mean(kl_loss)
        return cost, reconstruction_loss, kl_loss

    def test_step(x, step):
        reconstruction_loss, kl_loss, bottleneck_l2_loss = vae.cost_function_training(
            x, skip_strength_schedule(step), l2_bottleneck_schedule(step))
        cost = tf.reduce_sum(reconstruction_loss + beta_schedule(step) * kl_loss + bottleneck_l2_loss) / args.batch_size
        reconstruction_loss = tf.reduce_mean(reconstruction_loss)
        kl_loss = tf.reduce_mean(kl_loss)
        return cost, reconstruction_loss, kl_loss

    # ====== Training loop ============================================================================================
    epoch_loss = tf.metrics.Mean()
    epoch_reconstruction_loss = tf.metrics.Mean()
    epoch_kl_loss = tf.metrics.Mean()
    time_per_step = tf.metrics.Mean()
    val_loss = tf.metrics.Mean()
    val_reconstruction_loss = tf.metrics.Mean()
    val_kl_loss = tf.metrics.Mean()
    history = {  # recorded at the end of an epoch only
        "train_cost": [],
        "val_cost": [],
        "learning_rate": [],
        "time_per_step": [],
        "train_reconstruction_loss": [],
        "val_reconstruction_loss": [],
        "train_kl_loss": [],
        "val_kl_loss": [],
        "step": [],
        "wall_time": []
    }
    best_loss = np.inf
    patience = args.patience
    step = 0
    global_start = time.time()
    estimated_time_for_epoch = 0
    out_of_time = False
    lastest_checkpoint = 1
    for epoch in range(args.epochs):
        if (time.time() - global_start) > args.max_time * 3600 - estimated_time_for_epoch:
            break
        epoch_start = time.time()
        epoch_loss.reset_states()
        epoch_reconstruction_loss.reset_states()
        epoch_kl_loss.reset_states()
        time_per_step.reset_states()
        with writer.as_default():
            for batch, x in enumerate(train_dataset):
                start = time.time()
                cost, reconstruction_loss, kl_loss = train_step(x, step=step)
                # ========== Summary and logs ==========================================================================
                _time = time.time() - start
                time_per_step.update_state([_time])
                epoch_loss.update_state([cost])
                epoch_reconstruction_loss.update_state([reconstruction_loss])
                epoch_kl_loss.update_state([kl_loss])
                step += 1
            # last batch: we make a summary of residuals
            for res_idx in range(min(args.n_residuals, args.batch_size)):
                y_true = x[res_idx, ...]
                y_pred = vae.call(y_true[None, ...])[0, ...]
                tf.summary.image(f"Residuals {res_idx}",
                                 plot_to_image(residual_plot(y_true, y_pred)),
                                 step=step)

            # ========== Validation set ===================
            val_loss.reset_states()
            val_reconstruction_loss.reset_states()
            val_kl_loss.reset_states()
            for x in val_dataset:
                cost, reconstruction_loss, kl_loss = test_step(x, step=step)
                val_loss.update_state([cost])
                val_reconstruction_loss.update_state([reconstruction_loss])
                val_kl_loss.update_state([kl_loss])
            for res_idx in range(min(args.n_residuals, args.batch_size)):
                y_true = x[res_idx, ...]
                y_pred = vae.call(y_true[None, ...])[0, ...]
                tf.summary.image(f"Val Residuals {res_idx}",
                                 plot_to_image(residual_plot(y_true, y_pred)),
                                 step=step)

            val_cost = val_loss.result().numpy()
            train_cost = epoch_loss.result().numpy()
            train_reconstruction_cost = epoch_reconstruction_loss.result().numpy()
            val_reconstruction_cost = val_reconstruction_loss.result().numpy()
            train_kl_cost = epoch_kl_loss.result().numpy()
            val_kl_cost = val_kl_loss.result().numpy()
            tf.summary.scalar("KL", train_kl_cost, step=step)
            tf.summary.scalar("Val KL", val_kl_cost, step=step)
            tf.summary.scalar("Reconstruction loss", train_reconstruction_cost, step=step)
            tf.summary.scalar("Val reconstruction loss", val_reconstruction_cost, step=step)
            tf.summary.scalar("MSE", train_cost, step=step)
            tf.summary.scalar("Val MSE", val_cost, step=step)
            tf.summary.scalar("Learning Rate", optim.lr(step), step=step)
            tf.summary.scalar("beta", beta_schedule(step), step=step)
            tf.summary.scalar("l2 bottleneck", l2_bottleneck_schedule(step), step=step)
        print(f"epoch {epoch} | train loss {train_cost:.3e} | val loss {val_cost:.3e} "
              f"| learning rate {optim.lr(step).numpy():.2e} | time per step {time_per_step.result().numpy():.2e} s")
        history["train_cost"].append(train_cost)
        history["val_cost"].append(val_cost)
        history["train_reconstruction_loss"].append(train_reconstruction_cost)
        history["val_reconstruction_loss"].append(val_reconstruction_cost)
        history["train_kl_loss"].append(train_kl_cost)
        history["val_kl_loss"].append(val_kl_cost)
        history["learning_rate"].append(optim.lr(step).numpy())
        history["time_per_step"].append(time_per_step.result().numpy())
        history["step"].append(step)
        history["wall_time"].append(time.time() - global_start)

        cost = train_cost if args.track_train else val_cost
        if np.isnan(cost):
            print("Training broke the Universe")
            break
        if cost < (1 - args.tolerance) * best_loss:
            best_loss = cost
            patience = args.patience
        else:
            patience -= 1
        if (time.time() - global_start) > args.max_time * 3600:
            out_of_time = True
        if save_checkpoint:
            checkpoint_manager.checkpoint.step.assign_add(1)  # a bit of a hack
            if epoch % args.checkpoints == 0 or patience == 0 or epoch == args.epochs - 1 or out_of_time:
                with open(os.path.join(checkpoints_dir, "score_sheet.txt"), mode="a") as f:
                    np.savetxt(f, np.array([[lastest_checkpoint, cost]]))
                lastest_checkpoint += 1
                checkpoint_manager.save()
                print("Saved checkpoint for step {}: {}".format(
                    int(checkpoint_manager.checkpoint.step), checkpoint_manager.latest_checkpoint))
        if patience == 0:
            print("Reached patience")
            break
        if out_of_time:
            break
        if epoch > 0:  # First epoch is always very slow and not a good estimate of an epoch time.
            estimated_time_for_epoch = time.time() - epoch_start

    print(f"Finished training after {(time.time() - global_start)/3600:.3f} hours.")
    return history, best_loss
def __init__(self,
             observation,
             noise_rms,
             psf,
             phys: PhysicalModel,
             rim: RIM,
             source_vae: VAE,
             kappa_vae: VAE,
             n_samples=100,
             sigma_source=0.5,
             sigma_kappa=0.5):
    r"""
    Make a copy of initial parameters \varphi^{(0)} and compute the Fisher diagonal F_{ii}
    """
    wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(tf.sqrt(k), axis=(1, 2, 3), keepdims=True))

    # Baseline prediction from observation
    source_pred, kappa_pred, chi_squared = rim.predict(observation, noise_rms, psf)
    # Latent code of model predictions
    z_source, _ = source_vae.encoder(source_pred[-1])
    z_kappa, _ = kappa_vae.encoder(log_10(kappa_pred[-1]))

    # Deepcopy of the initial parameters
    self.initial_params = [deepcopy(w) for w in rim.unet.trainable_variables]
    self.fisher_diagonal = [tf.zeros_like(w) for w in self.initial_params]

    for n in range(n_samples):
        # Sample latent code around the prediction mean
        z_s = tf.random.normal(shape=[1, source_vae.latent_size], mean=z_source, stddev=sigma_source)
        z_k = tf.random.normal(shape=[1, kappa_vae.latent_size], mean=z_kappa, stddev=sigma_kappa)
        # Decode
        sampled_source = tf.nn.relu(source_vae.decode(z_s))
        sampled_source /= tf.reduce_max(sampled_source, axis=(1, 2, 3), keepdims=True)
        sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
        # Simulate observation
        sampled_observation = phys.noisy_forward(sampled_source, 10**sampled_kappa, noise_rms, psf)

        # Compute the gradient of the MSE
        with tf.GradientTape() as tape:
            tape.watch(rim.unet.trainable_variables)
            s, k, chi_squared = rim.call(sampled_observation, noise_rms, psf)
            # Remove the temperature from the loss when computing the Fisher: sum instead of mean,
            # and the weighted sum is renormalized by the number of pixels
            _kappa_mse = phys.kappa_pixels**2 * tf.reduce_sum(
                wk(10**sampled_kappa) * (k - sampled_kappa)**2, axis=(2, 3, 4))
            cost = tf.reduce_sum(_kappa_mse)
            cost += tf.reduce_sum((s - sampled_source)**2)
        grad = tape.gradient(cost, rim.unet.trainable_variables)
        # Square the derivative relative to initial parameters and add to total
        self.fisher_diagonal = [F + g**2 / n_samples for F, g in zip(self.fisher_diagonal, grad)]
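# The re-optimization loop above applies `args.lam_ewc * ewc.penalty(rim)`, but the penalty
# method itself is not shown in this section. Below is a minimal sketch of the standard Elastic
# Weight Consolidation quadratic penalty, sum_i F_ii (phi_i - phi_i^(0))^2, built only from the
# attributes computed in __init__ (hedged -- the repo's actual method may differ in signature
# or normalization):
def penalty(self, rim):
    # Quadratic distance to the initial parameters, weighted by the Fisher diagonal.
    return tf.add_n([
        tf.reduce_sum(F * (w - w0)**2)
        for F, w, w0 in zip(self.fisher_diagonal, rim.unet.trainable_variables, self.initial_params)
    ])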