def test_log_likelihood():
    phys = PhysicalModel(pixels=64, src_pixels=64)
    kappa = tf.random.normal([1, 64, 64, 1])
    source = tf.random.normal([1, 64, 64, 1])
    im_lensed = phys.forward(source, kappa)
    assert im_lensed.shape == [1, 64, 64, 1]
    cost = phys.log_likelihood(source, kappa, im_lensed)
def test_lagrange_multiplier_for_lens_intensity():
    phys = PhysicalModel(pixels=128)
    phys_a = AnalyticalPhysicalModel(pixels=128)
    kappa = phys_a.kappa_field(2.0, e=0.2)
    x = np.linspace(-1, 1, 128) * phys.src_fov / 2
    xx, yy = np.meshgrid(x, x)
    rho = xx**2 + yy**2
    source = tf.math.exp(-0.5 * rho / 0.5**2)[tf.newaxis, ..., tf.newaxis]
    source = tf.cast(source, tf.float32)
    y_true = phys.forward(source, kappa)
    y_pred = phys.forward(0.001 * source, kappa)  # rescale it, say it has different units
    lam_lagrange = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3)) / tf.reduce_sum(y_pred**2, axis=(1, 2, 3))
    lam_tests = tf.squeeze(
        tf.cast(tf.linspace(lam_lagrange / 10, lam_lagrange * 10, 1000), tf.float32))[..., tf.newaxis, tf.newaxis, tf.newaxis]
    log_likelihood_best = 0.5 * tf.reduce_mean(
        (lam_lagrange * y_pred - y_true)**2 / phys.noise_rms**2, axis=(1, 2, 3))
    log_likelihood_test = 0.5 * tf.reduce_mean(
        (lam_tests * y_pred - y_true)**2 / phys.noise_rms**2, axis=(1, 2, 3))
    return log_likelihood_test, log_likelihood_best, tf.squeeze(lam_tests), lam_lagrange
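# A minimal standalone check (not part of the original test) of the closed-form
# Lagrange multiplier used above: lambda = sum(y_true * y_pred) / sum(y_pred**2)
# minimizes the quadratic 0.5 * sum((lambda * y_pred - y_true)**2), so any other
# scaling should give an equal or larger residual. Pure NumPy sketch, no lensing
# code assumed; the helper name is hypothetical.
def _check_flux_rescaling_multiplier(seed=42):
    rng = np.random.default_rng(seed)
    y_true = rng.normal(size=(8, 8))
    y_pred = 0.001 * y_true + 0.01 * rng.normal(size=(8, 8))  # same image in "different units" plus noise
    lam = np.sum(y_true * y_pred) / np.sum(y_pred**2)
    best = 0.5 * np.sum((lam * y_pred - y_true)**2)
    for trial in np.linspace(lam / 10, lam * 10, 100):
        # every other scaling factor must do at least as poorly as the closed-form one
        assert best <= 0.5 * np.sum((trial * y_pred - y_true)**2) + 1e-8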
def test_interpolated_kappa():
    import tensorflow_addons as tfa
    phys = PhysicalModel(pixels=128, src_pixels=32, image_fov=7.68, kappa_fov=5)
    phys_a = AnalyticalPhysicalModel(pixels=128, image_fov=7.68)
    kappa = phys_a.kappa_field(r_ein=2., e=0.2)
    kappa += phys_a.kappa_field(r_ein=1., x0=2., y0=2.)
    true_lens = phys.lens_source_func(kappa, w=0.2)
    true_kappa = kappa

    # Test interpolation of alpha angles on a finer grid
    # phys = PhysicalModel(pixels=128, src_pixels=32, kappa_pixels=32)
    phys_a = AnalyticalPhysicalModel(pixels=32, image_fov=7.68)
    kappa = phys_a.kappa_field(r_ein=2., e=0.2)
    kappa += phys_a.kappa_field(r_ein=1., x0=2., y0=2.)
    # kappa2 = phys_a.kappa_field(r_ein=2., e=0.2)
    # kappa2 += phys_a.kappa_field(r_ein=1., x0=2., y0=2.)
    # kappa = tf.concat([kappa, kappa2], axis=1)

    # Test interpolated kappa lens
    x = np.linspace(-1, 1, 128) * phys.kappa_fov / 2
    x, y = np.meshgrid(x, x)
    x = tf.constant(x[np.newaxis, ..., np.newaxis], tf.float32)
    y = tf.constant(y[np.newaxis, ..., np.newaxis], tf.float32)
    dx = phys.kappa_fov / (32 - 1)
    xmin = -0.5 * phys.kappa_fov
    ymin = -0.5 * phys.kappa_fov
    i_coord = (x - xmin) / dx
    j_coord = (y - ymin) / dx
    wrap = tf.concat([i_coord, j_coord], axis=-1)
    # test_kappa1 = tfa.image.resampler(kappa, wrap)  # bilinear interpolation of kappa on wrap grid
    # test_lens1 = phys.lens_source_func(test_kappa1, w=0.2)
    phys2 = PhysicalModel(pixels=128, kappa_pixels=32, method="fft", image_fov=7.68, kappa_fov=5)
    test_lens1 = phys2.lens_source_func(kappa, w=0.2)

    # Test interpolated alpha angles lens
    phys2 = PhysicalModel(pixels=32, src_pixels=32, image_fov=7.68, kappa_fov=5)
    alpha1, alpha2 = phys2.deflection_angle(kappa)
    alpha = tf.concat([alpha1, alpha2], axis=-1)
    alpha = tfa.image.resampler(alpha, wrap)
    test_lens2 = phys.lens_source_func_given_alpha(alpha, w=0.2)
    return true_lens, test_lens1, test_lens2
def test_lens_source_conv2():
    pixels = 64
    src_pixels = 32
    phys = PhysicalModel(pixels=pixels, src_pixels=src_pixels, kappa_fov=16, image_fov=16)
    phys_analytic = AnalyticalPhysicalModel(pixels=pixels, image_fov=16)
    source = tf.random.normal([1, src_pixels, src_pixels, 1])
    kappa = phys_analytic.kappa_field(7, 0.1, 0, 0, 0)
    lens = phys.lens_source(source, kappa)
    return lens
def test_alpha_method_fft():
    pixels = 64
    phys = PhysicalModel(pixels=pixels, method="fft")
    phys_analytic = AnalyticalPhysicalModel(pixels=pixels, image_fov=7)
    phys2 = PhysicalModel(pixels=pixels, method="conv2d")

    # test out noise
    kappa = tf.random.uniform(shape=[1, pixels, pixels, 1])
    alphax, alphay = phys.deflection_angle(kappa)
    alphax2, alphay2 = phys2.deflection_angle(kappa)
    # assert np.allclose(alphax, alphax2, atol=1e-4)
    # assert np.allclose(alphay, alphay2, atol=1e-4)

    # test out an analytical profile
    kappa = phys_analytic.kappa_field(2, 0.4, 0, 0.1, 0.5)
    alphax, alphay = phys.deflection_angle(kappa)
    alphax2, alphay2 = phys2.deflection_angle(kappa)
    # assert np.allclose(alphax, alphax2, atol=1e-4)
    # assert np.allclose(alphay, alphay2, atol=1e-4)
    im1 = phys_analytic.lens_source_func_given_alpha(tf.concat([alphax, alphay], axis=-1))
    im2 = phys_analytic.lens_source_func_given_alpha(tf.concat([alphax2, alphay2], axis=-1))
    return alphax, alphax2, im1, im2
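# Hedged sketch: one way the commented-out comparisons above could be turned into an
# actual assertion, using a relative tolerance rather than a strict atol, since the
# "fft" and "conv2d" methods are only expected to agree up to discretization error.
# The helper name and tolerances below are illustrative, not part of the original tests.
def _assert_deflection_methods_agree(pixels=64, rtol=1e-2, atol=1e-3):
    phys_fft = PhysicalModel(pixels=pixels, method="fft")
    phys_conv = PhysicalModel(pixels=pixels, method="conv2d")
    kappa = tf.random.uniform(shape=[1, pixels, pixels, 1])
    ax1, ay1 = phys_fft.deflection_angle(kappa)
    ax2, ay2 = phys_conv.deflection_angle(kappa)
    assert np.allclose(ax1.numpy(), ax2.numpy(), rtol=rtol, atol=atol)
    assert np.allclose(ay1.numpy(), ay2.numpy(), rtol=rtol, atol=atol)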
def test_lens_func_given_alpha():
    phys = PhysicalModel(pixels=128)
    phys_a = AnalyticalPhysicalModel(pixels=128)
    alpha = phys_a.analytical_deflection_angles(x0=0.5, y0=0.5, e=0.4, phi=0., r_ein=1.)
    lens_true = phys_a.lens_source_func(x0=0.5, y0=0.5, e=0.4, phi=0., r_ein=1., xs=0.5, ys=0.5)
    lens_pred = phys_a.lens_source_func_given_alpha(alpha, xs=0.5, ys=0.5)
    lens_pred2 = phys.lens_source_func_given_alpha(alpha, xs=0.5, ys=0.5)
    fig = raytracer_residual_plot(alpha[0], alpha[0], lens_true[0], lens_pred2[0])
def distributed_strategy(args):
    kappa_gen = NISGenerator(  # only used to generate pixelated kappa fields
        kappa_fov=args.kappa_fov,
        src_fov=args.source_fov,
        pixels=args.kappa_pixels,
        z_source=args.z_source,
        z_lens=args.z_lens
    )
    min_theta_e = 0.1 * args.image_fov if args.min_theta_e is None else args.min_theta_e
    max_theta_e = 0.45 * args.image_fov if args.max_theta_e is None else args.max_theta_e

    cosmos_files = glob.glob(os.path.join(args.cosmos_dir, "*.tfrecords"))
    cosmos = tf.data.TFRecordDataset(cosmos_files, compression_type=args.compression_type)
    cosmos = cosmos.map(decode_image).map(preprocess_image)
    if args.shuffle_cosmos:
        cosmos = cosmos.shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True)
    cosmos = cosmos.batch(args.batch_size)

    window = tukey(args.src_pixels, alpha=args.tukey_alpha)
    window = np.outer(window, window)

    phys = PhysicalModel(
        image_fov=args.image_fov,
        kappa_fov=args.kappa_fov,
        src_fov=args.source_fov,
        pixels=args.lens_pixels,
        kappa_pixels=args.kappa_pixels,
        src_pixels=args.src_pixels,
        method="conv2d"
    )

    noise_a = (args.noise_rms_min - args.noise_rms_mean) / args.noise_rms_std
    noise_b = (args.noise_rms_max - args.noise_rms_mean) / args.noise_rms_std
    psf_a = (args.psf_fwhm_min - args.psf_fwhm_mean) / args.psf_fwhm_std
    psf_b = (args.psf_fwhm_max - args.psf_fwhm_mean) / args.psf_fwhm_std

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{THIS_WORKER}.tfrecords"), options) as writer:
        print(f"Started worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
        for i in range((THIS_WORKER - 1) * args.batch_size, args.len_dataset, N_WORKERS * args.batch_size):
            for galaxies in cosmos:
                break
            galaxies = window[np.newaxis, ..., np.newaxis] * galaxies
            noise_rms = truncnorm.rvs(noise_a, noise_b, loc=args.noise_rms_mean, scale=args.noise_rms_std, size=args.batch_size)
            fwhm = truncnorm.rvs(psf_a, psf_b, loc=args.psf_fwhm_mean, scale=args.psf_fwhm_std, size=args.batch_size)
            psf = phys.psf_models(fwhm, cutout_size=args.psf_cutout_size)

            batch_size = galaxies.shape[0]
            _r = tf.random.uniform(shape=[batch_size, 1, 1], minval=0, maxval=args.max_shift)
            _theta = tf.random.uniform(shape=[batch_size, 1, 1], minval=-np.pi, maxval=np.pi)
            x0 = _r * tf.math.cos(_theta)
            y0 = _r * tf.math.sin(_theta)
            ellipticity = tf.random.uniform(shape=[batch_size, 1, 1], minval=0., maxval=args.max_ellipticity)
            phi = tf.random.uniform(shape=[batch_size, 1, 1], minval=-np.pi, maxval=np.pi)
            einstein_radius = tf.random.uniform(shape=[batch_size, 1, 1], minval=min_theta_e, maxval=max_theta_e)
            kappa = kappa_gen.kappa_field(x0, y0, ellipticity, phi, einstein_radius)

            lensed_images = phys.noisy_forward(galaxies, kappa, noise_rms=noise_rms, psf=psf)

            records = encode_examples(
                kappa=kappa,
                galaxies=galaxies,
                lensed_images=lensed_images,
                z_source=args.z_source,
                z_lens=args.z_lens,
                image_fov=phys.image_fov,
                kappa_fov=phys.kappa_fov,
                source_fov=args.source_fov,
                noise_rms=noise_rms,
                psf=psf,
                fwhm=fwhm
            )
            for record in records:
                writer.write(record)
    print(f"Finished work at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
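# Note on the truncated-normal sampling above: scipy.stats.truncnorm expects its clip
# points a, b in standardized units, i.e. a = (min - loc) / scale and b = (max - loc) / scale,
# which is why noise_a/noise_b and psf_a/psf_b are computed that way before calling
# truncnorm.rvs. A minimal sanity check (illustrative values, hypothetical helper name):
def _check_truncnorm_bounds(lo=0.005, hi=0.05, mean=0.02, std=0.01, n=1000):
    a, b = (lo - mean) / std, (hi - mean) / std
    samples = truncnorm.rvs(a, b, loc=mean, scale=std, size=n)
    # all draws must respect the un-standardized bounds
    assert samples.min() >= lo and samples.max() <= hi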
def main(args): if args.seed is not None: tf.random.set_seed(args.seed) np.random.seed(args.seed) if args.json_override is not None: if isinstance(args.json_override, list): files = args.json_override else: files = [args.json_override,] for file in files: with open(file, "r") as f: json_override = json.load(f) args_dict = vars(args) args_dict.update(json_override) files = [] for dataset in args.datasets: files.extend(glob.glob(os.path.join(dataset, "*.tfrecords"))) np.random.shuffle(files) if args.val_datasets is not None: """ In this conditional, we assume total items might be a subset of the dataset size. Thus we want to reshuffle at each epoch to get a different realisation of the dataset. In case total_items == true dataset size, this means we only change ordering of items each epochs. Also, validation is not a split of the training data, but a saved dataset on disk. """ files = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files), reshuffle_each_iteration=True) dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type), block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE) # Read off global parameters from first example in dataset for physical_params in dataset.map(decode_physical_model_info): break # preprocessing dataset = dataset.map(decode_train) if args.cache_file is not None: dataset = dataset.cache(args.cache_file) train_dataset = dataset.shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).take(args.total_items).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) val_files = [] for dataset in args.val_datasets: val_files.extend(glob.glob(os.path.join(dataset, "*.tfrecords"))) val_files = tf.data.Dataset.from_tensor_slices(val_files).shuffle(len(files), reshuffle_each_iteration=True) val_dataset = val_files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type), block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE) val_dataset = val_dataset.map(decode_train).shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).\ take(math.ceil((1 - args.train_split) * args.total_items)).\ batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) else: """ Here, we split the dataset, so we assume total_items is the true dataset size. Any extra items will be discarded. This is to make sure validation set is never seen by the model, so shuffling occurs after the split. 
""" files = tf.data.Dataset.from_tensor_slices(files) dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type), block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE) # Read off global parameters from first example in dataset for physical_params in dataset.map(decode_physical_model_info): break # preprocessing dataset = dataset.map(decode_train) if args.cache_file is not None: dataset = dataset.cache(args.cache_file) train_dataset = dataset.take(math.floor(args.train_split * args.total_items)).shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) val_dataset = dataset.skip(math.floor(args.train_split * args.total_items)).take(math.ceil((1 - args.train_split) * args.total_items)).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) train_dataset = STRATEGY.experimental_distribute_dataset(train_dataset) val_dataset = STRATEGY.experimental_distribute_dataset(val_dataset) with STRATEGY.scope(): # Replicate ops accross gpus if args.raytracer is not None: with open(os.path.join(args.raytracer, "ray_tracer_hparams.json"), "r") as f: raytracer_hparams = json.load(f) raytracer = RayTracer(**raytracer_hparams) # load last checkpoint in the checkpoint directory checkpoint = tf.train.Checkpoint(net=raytracer) manager = tf.train.CheckpointManager(checkpoint, directory=args.raytracer, max_to_keep=3) checkpoint.restore(manager.latest_checkpoint).expect_partial() else: raytracer = None phys = PhysicalModel( pixels=physical_params["pixels"].numpy(), kappa_pixels=physical_params["kappa pixels"].numpy(), src_pixels=physical_params["src pixels"].numpy(), image_fov=physical_params["image fov"].numpy(), kappa_fov=physical_params["kappa fov"].numpy(), src_fov=physical_params["source fov"].numpy(), method=args.forward_method, raytracer=raytracer, ) unet = Model( filters=args.filters, filter_scaling=args.filter_scaling, kernel_size=args.kernel_size, layers=args.layers, block_conv_layers=args.block_conv_layers, strides=args.strides, bottleneck_kernel_size=args.bottleneck_kernel_size, resampling_kernel_size=args.resampling_kernel_size, input_kernel_size=args.input_kernel_size, gru_kernel_size=args.gru_kernel_size, upsampling_interpolation=args.upsampling_interpolation, kernel_l2_amp=args.kernel_l2_amp, bias_l2_amp=args.bias_l2_amp, kernel_l1_amp=args.kernel_l1_amp, bias_l1_amp=args.bias_l1_amp, activation=args.activation, initializer=args.initializer, batch_norm=args.batch_norm, dropout_rate=args.dropout_rate, filter_cap=args.filter_cap ) rim = RIM( physical_model=phys, unet=unet, steps=args.steps, adam=args.adam, kappalog=args.kappalog, source_link=args.source_link, kappa_normalize=args.kappa_normalize, flux_lagrange_multiplier=args.flux_lagrange_multiplier ) learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.initial_learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase ) optim = tf.keras.optimizers.deserialize( { "class_name": args.optimizer, 'config': {"learning_rate": learning_rate_schedule} } ) # weights for time steps in the loss function if args.time_weights == "uniform": wt = tf.ones(shape=(args.steps), dtype=DTYPE) / args.steps elif args.time_weights == "linear": wt = 2 * (tf.range(args.steps, dtype=DTYPE) + 1) / args.steps / (args.steps + 1) elif args.time_weights == "quadratic": wt = 6 * (tf.range(args.steps, dtype=DTYPE) + 1)**2 / args.steps / (args.steps + 1) / (2 * 
args.steps + 1) else: raise ValueError("time_weights must be in ['uniform', 'linear', 'quadratic']") wt = wt[..., tf.newaxis] # [steps, batch] if args.kappa_residual_weights == "uniform": wk = tf.keras.layers.Lambda(lambda k: tf.ones_like(k, dtype=DTYPE) / tf.cast(tf.math.reduce_prod(k.shape[1:]), DTYPE)) elif args.kappa_residual_weights == "linear": wk = tf.keras.layers.Lambda(lambda k: k / tf.reduce_sum(k, axis=(1, 2, 3), keepdims=True)) elif args.kappa_residual_weights == "sqrt": wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(tf.sqrt(k), axis=(1, 2, 3), keepdims=True)) elif args.kappa_residual_weights == "quadratic": wk = tf.keras.layers.Lambda(lambda k: tf.square(k) / tf.reduce_sum(tf.square(k), axis=(1, 2, 3), keepdims=True)) else: raise ValueError("kappa_residual_weights must be in ['uniform', 'linear', 'quadratic', 'sqrt']") if args.source_residual_weights == "uniform": ws = tf.keras.layers.Lambda(lambda s: tf.ones_like(s, dtype=DTYPE) / tf.cast(tf.math.reduce_prod(s.shape[1:]), DTYPE)) elif args.source_residual_weights == "linear": ws = tf.keras.layers.Lambda(lambda s: s / tf.reduce_sum(s, axis=(1, 2, 3), keepdims=True)) elif args.source_residual_weights == "quadratic": ws = tf.keras.layers.Lambda(lambda s: tf.square(s) / tf.reduce_sum(tf.square(s), axis=(1, 2, 3), keepdims=True)) elif args.source_residual_weights == "sqrt": ws = tf.keras.layers.Lambda(lambda s: tf.sqrt(s) / tf.reduce_sum(tf.sqrt(s), axis=(1, 2, 3), keepdims=True)) else: raise ValueError("kappa_residual_weights must be in ['uniform', 'linear', 'quadratic', 'sqrt']") # ==== Take care of where to write logs and stuff ================================================================= if args.model_id.lower() != "none": if args.logname is not None: logname = args.model_id + "_" + args.logname model_id = args.model_id else: logname = args.model_id + "_" + datetime.now().strftime("%y%m%d%H%M%S") model_id = args.model_id elif args.logname is not None: logname = args.logname model_id = logname else: logname = args.logname_prefixe + "_" + datetime.now().strftime("%y%m%d%H%M%S") model_id = logname if args.logdir.lower() != "none": logdir = os.path.join(args.logdir, logname) if not os.path.isdir(logdir): os.mkdir(logdir) writer = tf.summary.create_file_writer(logdir) else: writer = nullwriter() # ===== Make sure directory and checkpoint manager are created to save model =================================== if args.model_dir.lower() != "none": checkpoints_dir = os.path.join(args.model_dir, logname) old_checkpoints_dir = os.path.join(args.model_dir, model_id) # in case they differ we load model from a different directory if not os.path.isdir(checkpoints_dir): os.mkdir(checkpoints_dir) with open(os.path.join(checkpoints_dir, "script_params.json"), "w") as f: json.dump(vars(args), f, indent=4) with open(os.path.join(checkpoints_dir, "unet_hparams.json"), "w") as f: hparams_dict = {key: vars(args)[key] for key in UNET_MODEL_HPARAMS} json.dump(hparams_dict, f, indent=4) with open(os.path.join(checkpoints_dir, "rim_hparams.json"), "w") as f: hparams_dict = {key: vars(args)[key] for key in RIM_HPARAMS} json.dump(hparams_dict, f, indent=4) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=rim.unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, old_checkpoints_dir, max_to_keep=args.max_to_keep) save_checkpoint = True # ======= Load model if model_id is provided =============================================================== if args.model_id.lower() != "none": 
checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint) if old_checkpoints_dir != checkpoints_dir: # save progress in another directory. if args.reset_optimizer_states: optim = tf.keras.optimizers.deserialize( { "class_name": args.optimizer, 'config': {"learning_rate": learning_rate_schedule} } ) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=rim.unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, checkpoints_dir, max_to_keep=args.max_to_keep) else: save_checkpoint = False # ================================================================================================================= def train_step(X, source, kappa, noise_rms, psf): with tf.GradientTape() as tape: tape.watch(rim.unet.trainable_variables) if args.unroll_time_steps: source_series, kappa_series, chi_squared = rim.call_function(X, noise_rms, psf) else: source_series, kappa_series, chi_squared = rim.call(X, noise_rms, psf, outer_tape=tape) # mean over image residuals (in model prediction space) source_cost1 = tf.reduce_sum(ws(source) * tf.square(source_series - rim.source_inverse_link(source)), axis=(2, 3, 4)) kappa_cost1 = tf.reduce_sum(wk(kappa) * tf.square(kappa_series - rim.kappa_inverse_link(kappa)), axis=(2, 3, 4)) # weighted mean over time steps source_cost = tf.reduce_sum(wt * source_cost1, axis=0) kappa_cost = tf.reduce_sum(wt * kappa_cost1, axis=0) # final cost is mean over global batch size cost = tf.reduce_sum(kappa_cost + source_cost) / args.batch_size gradient = tape.gradient(cost, rim.unet.trainable_variables) gradient = [tf.clip_by_norm(grad, 5.) for grad in gradient] optim.apply_gradients(zip(gradient, rim.unet.trainable_variables)) # Update metrics with "converged" score chi_squared = tf.reduce_sum(chi_squared[-1]) / args.batch_size source_cost = tf.reduce_sum(source_cost1[-1]) / args.batch_size kappa_cost = tf.reduce_sum(kappa_cost1[-1]) / args.batch_size return cost, chi_squared, source_cost, kappa_cost @tf.function def distributed_train_step(X, source, kappa, noise_rms, psf): per_replica_losses, per_replica_chi_squared, per_replica_source_cost, per_replica_kappa_cost = STRATEGY.run(train_step, args=(X, source, kappa, noise_rms, psf)) # Replica losses are aggregated by summing them global_loss = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) global_chi_squared = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_chi_squared, axis=None) global_source_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_source_cost, axis=None) global_kappa_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_kappa_cost, axis=None) return global_loss, global_chi_squared, global_source_cost, global_kappa_cost def test_step(X, source, kappa, noise_rms, psf): source_series, kappa_series, chi_squared = rim.call(X, noise_rms, psf) # mean over image residuals (in model prediction space) source_cost1 = tf.reduce_sum(ws(source) * tf.square(source_series - rim.source_inverse_link(source)), axis=(2, 3, 4)) kappa_cost1 = tf.reduce_sum(wk(kappa) * tf.square(kappa_series - rim.kappa_inverse_link(kappa)), axis=(2, 3, 4)) # weighted mean over time steps source_cost = tf.reduce_sum(wt * source_cost1, axis=0) kappa_cost = tf.reduce_sum(wt * kappa_cost1, axis=0) # final cost is mean over global batch size cost = tf.reduce_sum(kappa_cost + source_cost) / args.batch_size # Update metrics with "converged" score chi_squared = tf.reduce_sum(chi_squared[-1]) / args.batch_size source_cost = tf.reduce_sum(source_cost1[-1]) / args.batch_size 
kappa_cost = tf.reduce_sum(kappa_cost1[-1]) / args.batch_size return cost, chi_squared, source_cost, kappa_cost @tf.function def distributed_test_step(X, source, kappa, noise_rms, psf): per_replica_losses, per_replica_chi_squared, per_replica_source_cost, per_replica_kappa_cost = STRATEGY.run(test_step, args=(X, source, kappa, noise_rms, psf)) # Replica losses are aggregated by summing them global_loss = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) global_chi_squared = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_chi_squared, axis=None) global_source_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_source_cost, axis=None) global_kappa_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_kappa_cost, axis=None) return global_loss, global_chi_squared, global_source_cost, global_kappa_cost # ====== Training loop ============================================================================================ epoch_loss = tf.metrics.Mean() time_per_step = tf.metrics.Mean() val_loss = tf.metrics.Mean() epoch_chi_squared = tf.metrics.Mean() epoch_source_loss = tf.metrics.Mean() epoch_kappa_loss = tf.metrics.Mean() val_chi_squared = tf.metrics.Mean() val_source_loss = tf.metrics.Mean() val_kappa_loss = tf.metrics.Mean() history = { # recorded at the end of an epoch only "train_cost": [], "train_chi_squared": [], "train_source_cost": [], "train_kappa_cost": [], "val_cost": [], "val_chi_squared": [], "val_source_cost": [], "val_kappa_cost": [], "learning_rate": [], "time_per_step": [], "step": [], "wall_time": [] } best_loss = np.inf patience = args.patience step = 0 global_start = time.time() estimated_time_for_epoch = 0 out_of_time = False lastest_checkpoint = 1 for epoch in range(args.epochs): if (time.time() - global_start) > args.max_time*3600 - estimated_time_for_epoch: break epoch_start = time.time() epoch_loss.reset_states() epoch_chi_squared.reset_states() epoch_source_loss.reset_states() epoch_kappa_loss.reset_states() time_per_step.reset_states() with writer.as_default(): for batch, (X, source, kappa, noise_rms, psf) in enumerate(train_dataset): start = time.time() cost, chi_squared, source_cost, kappa_cost = distributed_train_step(X, source, kappa, noise_rms, psf) # ========== Summary and logs ================================================================================== _time = time.time() - start time_per_step.update_state([_time]) epoch_loss.update_state([cost]) epoch_chi_squared.update_state([chi_squared]) epoch_source_loss.update_state([source_cost]) epoch_kappa_loss.update_state([kappa_cost]) step += 1 # last batch we make a summary of residuals if args.n_residuals > 0: source_pred, kappa_pred, chi_squared = rim.predict(X, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) for res_idx in range(min(args.n_residuals, args.batch_size)): try: tf.summary.image(f"Residuals {res_idx}", plot_to_image( residual_plot( X[res_idx], source[res_idx], kappa[res_idx], lens_pred[res_idx], source_pred[-1][res_idx], kappa_pred[-1][res_idx], chi_squared[-1][res_idx] )), step=step) except ValueError: continue # ========== Validation set =================== val_loss.reset_states() val_chi_squared.reset_states() val_source_loss.reset_states() val_kappa_loss.reset_states() for X, source, kappa, noise_rms, psf in val_dataset: cost, chi_squared, source_cost, kappa_cost = distributed_test_step(X, source, kappa, noise_rms, psf) val_loss.update_state([cost]) val_chi_squared.update_state([chi_squared]) 
val_source_loss.update_state([source_cost]) val_kappa_loss.update_state([kappa_cost]) if args.n_residuals > 0 and math.ceil((1 - args.train_split) * args.total_items) > 0: # validation set not empty set not empty source_pred, kappa_pred, chi_squared = rim.predict(X, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) for res_idx in range(min(args.n_residuals, args.batch_size, math.ceil((1 - args.train_split) * args.total_items))): try: tf.summary.image(f"Val Residuals {res_idx}", plot_to_image( residual_plot( X[res_idx], # rescale intensity like it is done in the likelihood source[res_idx], kappa[res_idx], lens_pred[res_idx], source_pred[-1][res_idx], kappa_pred[-1][res_idx], chi_squared[-1][res_idx] )), step=step) except ValueError: continue val_cost = val_loss.result().numpy() train_cost = epoch_loss.result().numpy() val_chi_sq = val_chi_squared.result().numpy() train_chi_sq = epoch_chi_squared.result().numpy() val_kappa_cost = val_kappa_loss.result().numpy() train_kappa_cost = epoch_kappa_loss.result().numpy() val_source_cost = val_source_loss.result().numpy() train_source_cost = epoch_source_loss.result().numpy() tf.summary.scalar("Time per step", time_per_step.result(), step=step) tf.summary.scalar("Chi Squared", train_chi_sq, step=step) tf.summary.scalar("Kappa cost", train_kappa_cost, step=step) tf.summary.scalar("Val Kappa cost", val_kappa_cost, step=step) tf.summary.scalar("Source cost", train_source_cost, step=step) tf.summary.scalar("Val Source cost", val_source_cost, step=step) tf.summary.scalar("MSE", train_cost, step=step) tf.summary.scalar("Val MSE", val_cost, step=step) tf.summary.scalar("Learning Rate", optim.lr(step), step=step) tf.summary.scalar("Val Chi Squared", val_chi_sq, step=step) print(f"epoch {epoch} | train loss {train_cost:.3e} | val loss {val_cost:.3e} " f"| lr {optim.lr(step).numpy():.2e} | time per step {time_per_step.result().numpy():.2e} s" f"| kappa cost {train_kappa_cost:.2e} | source cost {train_source_cost:.2e} | chi sq {train_chi_sq:.2e}") history["train_cost"].append(train_cost) history["val_cost"].append(val_cost) history["learning_rate"].append(optim.lr(step).numpy()) history["train_chi_squared"].append(train_chi_sq) history["val_chi_squared"].append(val_chi_sq) history["time_per_step"].append(time_per_step.result().numpy()) history["train_kappa_cost"].append(train_kappa_cost) history["train_source_cost"].append(train_source_cost) history["val_kappa_cost"].append(val_kappa_cost) history["val_source_cost"].append(val_source_cost) history["step"].append(step) history["wall_time"].append(time.time() - global_start) cost = train_cost if args.track_train else val_cost if np.isnan(cost): print("Training broke the Universe") break if cost < (1 - args.tolerance) * best_loss: best_loss = cost patience = args.patience else: patience -= 1 if (time.time() - global_start) > args.max_time * 3600: out_of_time = True if save_checkpoint: checkpoint_manager.checkpoint.step.assign_add(1) # a bit of a hack if epoch % args.checkpoints == 0 or patience == 0 or epoch == args.epochs - 1 or out_of_time: with open(os.path.join(checkpoints_dir, "score_sheet.txt"), mode="a") as f: np.savetxt(f, np.array([[lastest_checkpoint, cost]])) lastest_checkpoint += 1 checkpoint_manager.save() print("Saved checkpoint for step {}: {}".format(int(checkpoint_manager.checkpoint.step), checkpoint_manager.latest_checkpoint)) if patience == 0: print("Reached patience") break if out_of_time: break if epoch > 0: # First epoch is always very slow and not a good estimate 
of an epoch time. estimated_time_for_epoch = time.time() - epoch_start if optim.lr(step).numpy() < 1e-8: print("Reached learning rate limit") break print(f"Finished training after {(time.time() - global_start)/3600:.3f} hours.") return history, best_loss
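# Hedged sketch (not part of the training script): the three time-weight schedules defined
# in main() above are all normalized so the weights over the RIM steps sum to one:
#   uniform:   w_t = 1 / T
#   linear:    w_t = 2 t / (T (T + 1)),             t = 1..T
#   quadratic: w_t = 6 t^2 / (T (T + 1) (2 T + 1)), t = 1..T
# A quick NumPy check of that normalization:
def _check_time_weight_normalization(steps=8):
    t = np.arange(steps) + 1
    uniform = np.ones(steps) / steps
    linear = 2 * t / steps / (steps + 1)
    quadratic = 6 * t**2 / steps / (steps + 1) / (2 * steps + 1)
    for w in (uniform, linear, quadratic):
        assert np.isclose(w.sum(), 1.0)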
def main(args):
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
        block_length=1, num_parallel_calls=tf.data.AUTOTUNE)
    for physical_params in dataset.map(decode_physical_model_info):
        break
    dataset = dataset.map(decode_train)

    # files = glob.glob(os.path.join(args.source_dataset, "*.tfrecords"))
    # files = tf.data.Dataset.from_tensor_slices(files)
    # source_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
    #                                   block_length=1, num_parallel_calls=tf.data.AUTOTUNE)
    # source_dataset = source_dataset.map(decode_image).map(preprocess_image).shuffle(10000).batch(args.sample_size)

    with open(os.path.join(args.kappa_vae, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, args.kappa_vae, 1)
    checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()

    with open(os.path.join(args.source_vae, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, args.source_vae, 1)
    checkpoint_manager2.checkpoint.restore(checkpoint_manager2.latest_checkpoint).expect_partial()

    phys = PhysicalModel(
        pixels=physical_params["pixels"].numpy(),
        kappa_pixels=physical_params["kappa pixels"].numpy(),
        src_pixels=physical_params["src pixels"].numpy(),
        image_fov=physical_params["image fov"].numpy(),
        kappa_fov=physical_params["kappa fov"].numpy(),
        src_fov=physical_params["source fov"].numpy(),
        method="fft"
    )

    # simulate observations
    kappa = 10**kappa_vae.sample(args.sample_size)
    source = preprocess_image(source_vae.sample(args.sample_size))
    # for source in source_dataset:
    #     break
    fwhm = tf.random.normal(shape=[args.sample_size],
                            mean=1.5 * phys.image_fov / phys.pixels,
                            stddev=0.5 * phys.image_fov / phys.pixels)
    # noise_rms = tf.random.normal(shape=[args.sample_size], mean=args.noise_mean, stddev=args.noise_std)
    psf = phys.psf_models(fwhm, cutout_size=20)
    y_vae = phys.forward(source, kappa, psf)

    with h5py.File(os.path.join(os.getenv("CENSAI_PATH"), "results", args.output_name + ".h5"), 'w') as hf:
        # rank these observations against the dataset with L2 norm
        for i in tqdm(range(args.sample_size)):
            distances = []
            for y_d, _, _, _, _ in dataset:
                distances.append(
                    tf.sqrt(tf.reduce_sum((y_d - y_vae[i][None, ...])**2)).numpy().astype(np.float32))
            k_indices = np.argsort(distances)[:args.k]

            # save results
            g = hf.create_group(f"sample_{i:02d}")
            g.create_dataset(name="matched_source", shape=[args.k, phys.src_pixels, phys.src_pixels], dtype=np.float32)
            g.create_dataset(name="matched_kappa", shape=[args.k, phys.kappa_pixels, phys.kappa_pixels], dtype=np.float32)
            g.create_dataset(name="matched_obs", shape=[args.k, phys.pixels, phys.pixels], dtype=np.float32)
            g.create_dataset(name="matched_psf", shape=[args.k, 20, 20], dtype=np.float32)
            g.create_dataset(name="matched_noise_rms", shape=[args.k], dtype=np.float32)
            g.create_dataset(name="obs_L2_distance", shape=[args.k], dtype=np.float32)
            g["vae_source"] = source[i, ..., 0].numpy().astype(np.float32)
            g["vae_kappa"] = kappa[i, ..., 0].numpy().astype(np.float32)
            g["vae_obs"] = y_vae[i, ..., 0].numpy().astype(np.float32)
            g["vae_psf"] = psf[i, ..., 0].numpy().astype(np.float32)

            for rank, j in enumerate(k_indices):
                # fetch back the matched observation
                for y_d, source_d, kappa_d, noise_rms_d, psf_d in dataset.skip(j):
                    break
                # g["vae_noise_rms"] = noise_rms[i].numpy().astype(np.float32)
                g["matched_source"][rank] = source_d[..., 0].numpy().astype(np.float32)
                g["matched_kappa"][rank] = kappa_d[..., 0].numpy().astype(np.float32)
                g["matched_obs"][rank] = y_d[..., 0].numpy().astype(np.float32)
                g["matched_noise_rms"][rank] = noise_rms_d.numpy().astype(np.float32)
                g["matched_psf"][rank] = psf_d[..., 0].numpy().astype(np.float32)
                g["obs_L2_distance"][rank] = distances[j]
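# Hedged illustration (standalone, hypothetical helper): the matching loop above ranks a
# VAE-simulated observation against dataset observations by Euclidean (L2) distance and keeps
# the k closest items via np.argsort. The same ranking logic, isolated in pure NumPy:
def _top_k_matches(query, candidates, k=5):
    # query: [H, W]; candidates: [N, H, W]
    distances = np.sqrt(np.sum((candidates - query[None, ...])**2, axis=(1, 2)))
    indices = np.argsort(distances)[:k]
    return indices, distances[indices]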
def test_deflection_angle_conv2():
    phys = PhysicalModel(pixels=64, src_pixels=64)
    kappa = tf.random.normal([1, 64, 64, 1])
    phys.deflection_angle(kappa)
def distributed_strategy(args):
    kappa_datasets = []
    for path in args.kappa_datasets:
        files = glob.glob(os.path.join(path, "*.tfrecords"))
        files = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files), reshuffle_each_iteration=True)
        dataset = files.interleave(
            lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
            block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE)
        kappa_datasets.append(dataset.shuffle(args.buffer_size, reshuffle_each_iteration=True))
    kappa_dataset = tf.data.experimental.sample_from_datasets(kappa_datasets, weights=args.kappa_datasets_weights)
    # Read off global parameters from first example in dataset
    for example in kappa_dataset.map(decode_kappa_info):
        kappa_fov = example["kappa fov"].numpy()
        kappa_pixels = example["kappa pixels"].numpy()
        break
    kappa_dataset = kappa_dataset.map(decode_kappa).batch(args.batch_size)

    cosmos_datasets = []
    for path in args.cosmos_datasets:
        files = glob.glob(os.path.join(path, "*.tfrecords"))
        files = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files), reshuffle_each_iteration=True)
        dataset = files.interleave(
            lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
            block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE)
        cosmos_datasets.append(dataset.shuffle(args.buffer_size, reshuffle_each_iteration=True))
    cosmos_dataset = tf.data.experimental.sample_from_datasets(cosmos_datasets, weights=args.cosmos_datasets_weights)
    # Read off global parameters from first example in dataset
    for src_pixels in cosmos_dataset.map(decode_cosmos_info):
        src_pixels = src_pixels.numpy()
        break
    cosmos_dataset = cosmos_dataset.map(decode_cosmos).map(preprocess_cosmos).batch(args.batch_size)

    window = tukey(src_pixels, alpha=args.tukey_alpha)
    window = np.outer(window, window)[np.newaxis, ..., np.newaxis]
    window = tf.constant(window, dtype=DTYPE)

    phys = PhysicalModel(
        image_fov=kappa_fov,
        src_fov=args.source_fov,
        pixels=args.lens_pixels,
        kappa_pixels=kappa_pixels,
        src_pixels=src_pixels,
        kappa_fov=kappa_fov,
        method="conv2d"
    )

    noise_a = (args.noise_rms_min - args.noise_rms_mean) / args.noise_rms_std
    noise_b = (args.noise_rms_max - args.noise_rms_mean) / args.noise_rms_std
    psf_a = (args.psf_fwhm_min - args.psf_fwhm_mean) / args.psf_fwhm_std
    psf_b = (args.psf_fwhm_max - args.psf_fwhm_mean) / args.psf_fwhm_std

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{THIS_WORKER}.tfrecords"), options) as writer:
        print(f"Started worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
        for i in range((THIS_WORKER - 1) * args.batch_size, args.len_dataset, N_WORKERS * args.batch_size):
            for galaxies in cosmos_dataset:  # select a random batch from our dataset that is reshuffled each iteration
                break
            for kappa in kappa_dataset:
                break
            galaxies = window * galaxies
            noise_rms = truncnorm.rvs(noise_a, noise_b, loc=args.noise_rms_mean, scale=args.noise_rms_std, size=args.batch_size)
            fwhm = truncnorm.rvs(psf_a, psf_b, loc=args.psf_fwhm_mean, scale=args.psf_fwhm_std, size=args.batch_size)
            psf = phys.psf_models(fwhm, cutout_size=args.psf_cutout_size)
            lensed_images = phys.noisy_forward(galaxies, kappa, noise_rms=noise_rms, psf=psf)
            records = encode_examples(
                kappa=kappa,
                galaxies=galaxies,
                lensed_images=lensed_images,
                z_source=args.z_source,
                z_lens=args.z_lens,
                image_fov=phys.image_fov,
                kappa_fov=phys.kappa_fov,
                source_fov=args.source_fov,
                noise_rms=noise_rms,
                psf=psf,
                fwhm=fwhm
            )
            for record in records:
                writer.write(record)
    print(f"Finished work at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
def test_analytic2():
    phys = AnalyticalPhysicalModelv2(pixels=64)
    # try to oversample kappa map for comparison
    # phys2 = AnalyticalPhysicalModelv2(pixels=256)
    # phys_ref = PhysicalModel(pixels=256, src_pixels=64, method="fft", kappa_fov=7.68)
    phys_ref = PhysicalModel(pixels=64, method="fft")

    # make some dummy coordinates
    x = np.linspace(-1, 1, 64) * 3.0 / 2
    xx, yy = np.meshgrid(x, x)

    # lens params
    r_ein = 2
    x0 = 0.1
    y0 = 0.1
    q = 0.98
    phi = 0.  # np.pi/3
    gamma = 0.
    phi_gamma = 0.
    # source params
    xs = 0.1
    ys = 0.1
    qs = 0.9
    phi_s = 0.  # np.pi/4
    r_eff = 0.4
    n = 1.

    source = phys.sersic_source(xx, yy, xs, ys, qs, phi_s, n, r_eff)[None, ..., None]
    # kappa = phys2.kappa_field(r_ein, q, phi, x0, y0)
    kappa = phys.kappa_field(r_ein, q, phi, x0, y0)

    lens_a = phys.lens_source_sersic_func(r_ein, q, phi, x0, y0, gamma, phi_gamma, xs, ys, qs, phi_s, n, r_eff)
    lens_b = phys_ref.lens_source(source, kappa)
    # lens_b = tf.image.resize(lens_b, [64, 64])
    lens_c = phys.lens_source(source, r_ein, q, phi, x0, y0, gamma, phi_gamma)

    # alpha1t, alpha2t = tf.split(phys.analytical_deflection_angles(r_ein, q, phi, x0, y0), 2, axis=-1)
    alpha1t, alpha2t = tf.split(phys.approximate_deflection_angles(r_ein, q, phi, x0, y0), 2, axis=-1)
    alpha1, alpha2 = phys_ref.deflection_angle(kappa)
    # alpha1 = tf.image.resize(alpha1, [64, 64])
    # alpha2 = tf.image.resize(alpha2, [64, 64])
    beta1 = phys.theta1 - alpha1
    beta2 = phys.theta2 - alpha2
    lens_d = phys.sersic_source(beta1, beta2, xs, ys, q, phi, n, r_eff)

    # plt.imshow(source[0, ..., 0], cmap="twilight", origin="lower")
    # plt.colorbar()
    # plt.show()
    # plt.imshow(np.log10(kappa[0, ..., 0]), cmap="twilight", origin="lower")
    # plt.colorbar()
    # plt.show()
    plt.imshow(lens_a[0, ..., 0], cmap="twilight", origin="lower")
    plt.colorbar()
    plt.show()
    plt.imshow(lens_b[0, ..., 0], cmap="twilight", origin="lower")
    plt.colorbar()
    plt.show()
    # plt.imshow(lens_c[0, ..., 0], cmap="twilight", origin="lower")
    # plt.colorbar()
    # plt.show()
    # plt.imshow(lens_d[0, ..., 0], cmap="twilight", origin="lower")
    # plt.colorbar()
    # plt.show()
    plt.imshow((lens_a - lens_b)[0, ..., 0], cmap="seismic", vmin=-0.1, vmax=0.1, origin="lower")
    plt.colorbar()
    plt.show()
    # plt.imshow((lens_a - lens_c)[0, ..., 0], cmap="seismic", vmin=-0.1, vmax=0.1, origin="lower")
    # plt.colorbar()
    # plt.show()
    # plt.imshow((lens_a - lens_d)[0, ..., 0], cmap="seismic", vmin=-0.1, vmax=0.1, origin="lower")
    # plt.colorbar()
    # plt.show()
    plt.imshow((alpha1t - alpha1)[0, ..., 0], cmap="seismic", vmin=-1, vmax=1, origin="lower")
    plt.colorbar()
    plt.show()
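# Hedged sketch: a non-interactive variant of the comparison above, summarizing the
# analytic-vs-numerical residuals with numbers instead of plt.show(). The helper name is
# hypothetical and the printed statistics are purely diagnostic; the agreement actually
# achievable depends on pixelization and the approximate deflection formula.
def _summarize_analytic_vs_numeric(lens_a, lens_b, alpha1t, alpha1):
    lens_res = (lens_a - lens_b).numpy()
    alpha_res = (alpha1t - alpha1).numpy()
    print(f"lens residual:  max abs {np.abs(lens_res).max():.3e}, rms {np.sqrt(np.mean(lens_res**2)):.3e}")
    print(f"alpha residual: max abs {np.abs(alpha_res).max():.3e}, rms {np.sqrt(np.mean(alpha_res**2)):.3e}")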
def distributed_strategy(args): psf_pixels = 20 pixels = 128 model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model) ps_observation = PowerSpectrum(bins=args.observation_coherence_bins, pixels=pixels) ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=pixels) ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=pixels) phys = PhysicalModel( pixels=pixels, kappa_pixels=pixels, src_pixels=pixels, image_fov=7.68, kappa_fov=7.68, src_fov=3., method="fft", ) with open(os.path.join(model, "unet_hparams.json")) as f: unet_params = json.load(f) unet_params["kernel_l2_amp"] = args.l2_amp unet = Model(**unet_params) ckpt = tf.train.Checkpoint(net=unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1) checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial() with open(os.path.join(model, "rim_hparams.json")) as f: rim_params = json.load(f) rim_params["source_link"] = "relu" rim = RIM(phys, unet, **rim_params) kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.kappa_vae) with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f: kappa_vae_hparams = json.load(f) kappa_vae = VAE(**kappa_vae_hparams) ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae) checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1) checkpoint_manager1.checkpoint.restore( checkpoint_manager1.latest_checkpoint).expect_partial() svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.source_vae) with open(os.path.join(svae_path, "model_hparams.json"), "r") as f: source_vae_hparams = json.load(f) source_vae = VAE(**source_vae_hparams) ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae) checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1) checkpoint_manager2.checkpoint.restore( checkpoint_manager2.latest_checkpoint).expect_partial() model_name = os.path.split(model)[-1] wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum( tf.sqrt(k), axis=(1, 2, 3), keepdims=True)) with h5py.File( os.path.join( os.getenv("CENSAI_PATH"), "results", args.experiment_name + "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf: data_len = args.size // N_WORKERS hf.create_dataset(name="observation", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset(name="psf", shape=[data_len, psf_pixels, psf_pixels, 1], dtype=np.float32) hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32) hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32) hf.create_dataset( name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) hf.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1 ], dtype=np.float32) hf.create_dataset( name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="chi_squared", 
shape=[data_len, rim.steps], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized", shape=[data_len, rim.steps], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, rim.steps, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="sampled_chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="sampled_source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="sampled_kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="latent_kappa_gt_distance_init", shape=[data_len, kappa_vae.latent_size], dtype=np.float32) hf.create_dataset(name="latent_source_gt_distance_init", shape=[data_len, source_vae.latent_size], dtype=np.float32) hf.create_dataset(name="latent_kappa_gt_distance_end", shape=[data_len, kappa_vae.latent_size], dtype=np.float32) hf.create_dataset(name="latent_source_gt_distance_end", shape=[data_len, source_vae.latent_size], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum_reoptimized", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_frequencies", shape=[args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32) for i in range(data_len): checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial( ) # reset model weights # Produce an observation kappa = 10**kappa_vae.sample(1) source = tf.nn.relu(source_vae.sample(1)) source /= tf.reduce_max(source, axis=(1, 2, 3), keepdims=True) noise_rms = 10**tf.random.uniform(shape=[1], minval=-2.5, maxval=-1) fwhm = tf.random.uniform(shape=[1], minval=0.06, maxval=0.3) psf = phys.psf_models(fwhm, cutout_size=psf_pixels) observation = phys.noisy_forward(source, kappa, noise_rms, psf) # RIM predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( observation, noise_rms, psf) observation_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) source_o = source_pred[-1] kappa_o = kappa_pred[-1] # Latent code of model predictions 
z_source, _ = source_vae.encoder(source_o) z_kappa, _ = kappa_vae.encoder(log_10(kappa_o)) # Ground truth latent code for oracle metrics z_source_gt, _ = source_vae.encoder(source) z_kappa_gt, _ = kappa_vae.encoder(log_10(kappa)) # Re-optimize weights of the model STEPS = args.re_optimize_steps learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase) optim = tf.keras.optimizers.RMSprop( learning_rate=learning_rate_schedule) chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) source_mse = tf.TensorArray(DTYPE, size=STEPS) kappa_mse = tf.TensorArray(DTYPE, size=STEPS) sampled_chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) sampled_source_mse = tf.TensorArray(DTYPE, size=STEPS) sampled_kappa_mse = tf.TensorArray(DTYPE, size=STEPS) best = chi_squared source_best = source_pred[-1] kappa_best = kappa_pred[-1] source_mse_best = tf.reduce_mean((source_best - source)**2) kappa_mse_best = tf.reduce_mean((kappa_best - log_10(kappa))**2) # ===================== Optimization ============================== for current_step in tqdm(range(STEPS)): # ===================== VAE SAMPLING ============================== # L1 distance with ground truth in latent space -- this is changed by an user defined value when using real data # z_source_std = tf.abs(z_source - z_source_gt) # z_kappa_std = tf.abs(z_kappa - z_kappa_gt) z_source_std = args.source_vae_ball_size z_kappa_std = args.kappa_vae_ball_size # Sample latent code, then decode and forward z_s = tf.random.normal( shape=[args.sample_size, source_vae.latent_size], mean=z_source, stddev=z_source_std) z_k = tf.random.normal( shape=[args.sample_size, kappa_vae.latent_size], mean=z_kappa, stddev=z_kappa_std) sampled_source = tf.nn.relu(source_vae.decode(z_s)) sampled_source /= tf.reduce_max(sampled_source, axis=(1, 2, 3), keepdims=True) sampled_kappa = kappa_vae.decode(z_k) # output in log_10 space sampled_observation = phys.noisy_forward( sampled_source, 10**sampled_kappa, noise_rms, tf.tile(psf, [args.sample_size, 1, 1, 1])) with tf.GradientTape() as tape: tape.watch(unet.trainable_variables) s, k, chi_sq = rim.call( sampled_observation, noise_rms, tf.tile(psf, [args.sample_size, 1, 1, 1]), outer_tape=tape) _kappa_mse = tf.reduce_sum(wk(10**sampled_kappa) * (k - sampled_kappa)**2, axis=(2, 3, 4)) cost = tf.reduce_mean(_kappa_mse) cost += tf.reduce_mean((s - sampled_source)**2) cost += tf.reduce_sum(rim.unet.losses) # weight decay grads = tape.gradient(cost, unet.trainable_variables) optim.apply_gradients(zip(grads, unet.trainable_variables)) # Record performance on sampled dataset sampled_chi_squared_series = sampled_chi_squared_series.write( index=current_step, value=tf.squeeze(tf.reduce_mean(chi_sq[-1]))) sampled_source_mse = sampled_source_mse.write( index=current_step, value=tf.reduce_mean((s[-1] - sampled_source)**2)) sampled_kappa_mse = sampled_kappa_mse.write( index=current_step, value=tf.reduce_mean((k[-1] - sampled_kappa)**2)) # Record model prediction on data s, k, chi_sq = rim.call(observation, noise_rms, psf) chi_squared_series = chi_squared_series.write( index=current_step, value=tf.squeeze(chi_sq)) source_o = s[-1] kappa_o = k[-1] # oracle metrics, remove when using real data source_mse = source_mse.write(index=current_step, value=tf.reduce_mean( (source_o - source)**2)) kappa_mse = kappa_mse.write(index=current_step, value=tf.reduce_mean( (kappa_o - log_10(kappa))**2)) if abs(chi_sq[-1, 0] - 1) 
< abs(best[-1, 0] - 1): source_best = tf.nn.relu(source_o) kappa_best = 10**kappa_o best = chi_sq source_mse_best = tf.reduce_mean((source_best - source)**2) kappa_mse_best = tf.reduce_mean( (kappa_best - log_10(kappa))**2) source_o = source_best kappa_o = kappa_best y_pred = phys.forward(source_o, kappa_o, psf) chi_sq_series = tf.transpose(chi_squared_series.stack()) source_mse = source_mse.stack() kappa_mse = kappa_mse.stack() sampled_chi_squared_series = sampled_chi_squared_series.stack() sampled_source_mse = sampled_source_mse.stack() sampled_kappa_mse = sampled_kappa_mse.stack() # Latent code of optimized model predictions z_source_opt, _ = source_vae.encoder(tf.nn.relu(source_o)) z_kappa_opt, _ = kappa_vae.encoder(log_10(kappa_o)) # Compute Power spectrum of converged predictions _ps_observation = ps_observation.cross_correlation_coefficient( observation[..., 0], observation_pred[..., 0]) _ps_observation2 = ps_observation.cross_correlation_coefficient( observation[..., 0], y_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_kappa2 = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_o[..., 0])) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) _ps_source2 = ps_source.cross_correlation_coefficient( source[..., 0], source_o[..., 0]) # save results hf["observation"][i] = observation.numpy().astype(np.float32) hf["psf"][i] = psf.numpy().astype(np.float32) hf["psf_fwhm"][i] = fwhm.numpy().astype(np.float32) hf["noise_rms"][i] = noise_rms.numpy().astype(np.float32) hf["source"][i] = source.numpy().astype(np.float32) hf["kappa"][i] = kappa.numpy().astype(np.float32) hf["observation_pred"][i] = observation_pred.numpy().astype( np.float32) hf["observation_pred_reoptimized"][i] = y_pred.numpy().astype( np.float32) hf["source_pred"][i] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["source_pred_reoptimized"][i] = source_o.numpy().astype( np.float32) hf["kappa_pred"][i] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["kappa_pred_reoptimized"][i] = kappa_o.numpy().astype( np.float32) hf["chi_squared"][i] = tf.squeeze(chi_squared).numpy().astype( np.float32) hf["chi_squared_reoptimized"][i] = tf.squeeze(best).numpy().astype( np.float32) hf["chi_squared_reoptimized_series"][i] = chi_sq_series.numpy( ).astype(np.float32) hf["sampled_chi_squared_reoptimized_series"][ i] = 2 * sampled_chi_squared_series.numpy().astype(np.float32) hf["source_optim_mse"][i] = source_mse_best.numpy().astype( np.float32) hf["source_optim_mse_series"][i] = source_mse.numpy().astype( np.float32) hf["sampled_source_optim_mse_series"][ i] = sampled_source_mse.numpy().astype(np.float32) hf["kappa_optim_mse"][i] = kappa_mse_best.numpy().astype( np.float32) hf["kappa_optim_mse_series"][i] = kappa_mse.numpy().astype( np.float32) hf["sampled_kappa_optim_mse_series"][i] = sampled_kappa_mse.numpy( ).astype(np.float32) hf["latent_source_gt_distance_init"][i] = tf.abs( z_source - z_source_gt).numpy().squeeze().astype(np.float32) hf["latent_kappa_gt_distance_init"][i] = tf.abs( z_kappa - z_kappa_gt).numpy().squeeze().astype(np.float32) hf["latent_source_gt_distance_end"][i] = tf.abs( z_source_opt - z_source_gt).numpy().squeeze().astype( np.float32) hf["latent_kappa_gt_distance_end"][i] = tf.abs( z_kappa_opt - z_kappa_gt).numpy().squeeze().astype(np.float32) hf["observation_coherence_spectrum"][i] = _ps_observation 
hf["observation_coherence_spectrum_reoptimized"][ i] = _ps_observation2 hf["source_coherence_spectrum"][i] = _ps_source hf["source_coherence_spectrum_reoptimized"][i] = _ps_source2 hf["kappa_coherence_spectrum"][i] = _ps_kappa hf["kappa_coherence_spectrum_reoptimized"][i] = _ps_kappa2 if i == 0: _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_observation.bins) f = (f[:-1] + f[1:]) / 2 hf["observation_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 hf["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 hf["kappa_frequencies"][:] = f hf["kappa_fov"][0] = phys.kappa_fov hf["source_fov"][0] = phys.src_fov
def distributed_strategy(args): tf.random.set_seed(args.seed) np.random.seed(args.seed) model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model) files = glob.glob( os.path.join(os.getenv('CENSAI_PATH'), "data", args.train_dataset, "*.tfrecords")) files = tf.data.Dataset.from_tensor_slices(files) train_dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type).shuffle(len(files)), block_length=1, num_parallel_calls=tf.data.AUTOTUNE) # Read off global parameters from first example in dataset for physical_params in train_dataset.map(decode_physical_model_info): break train_dataset = train_dataset.map(decode_results).shuffle( buffer_size=args.buffer_size) files = glob.glob( os.path.join(os.getenv('CENSAI_PATH'), "data", args.val_dataset, "*.tfrecords")) files = tf.data.Dataset.from_tensor_slices(files) val_dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type).shuffle(len(files)), block_length=1, num_parallel_calls=tf.data.AUTOTUNE) val_dataset = val_dataset.map(decode_results).shuffle( buffer_size=args.buffer_size) files = glob.glob( os.path.join(os.getenv('CENSAI_PATH'), "data", args.test_dataset, "*.tfrecords")) files = tf.data.Dataset.from_tensor_slices(files) test_dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type).shuffle(len(files)), block_length=1, num_parallel_calls=tf.data.AUTOTUNE) test_dataset = test_dataset.map(decode_results).shuffle( buffer_size=args.buffer_size) ps_lens = PowerSpectrum(bins=args.lens_coherence_bins, pixels=physical_params["pixels"].numpy()) ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=physical_params["src pixels"].numpy()) ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=physical_params["kappa pixels"].numpy()) phys = PhysicalModel( pixels=physical_params["pixels"].numpy(), kappa_pixels=physical_params["kappa pixels"].numpy(), src_pixels=physical_params["src pixels"].numpy(), image_fov=physical_params["image fov"].numpy(), kappa_fov=physical_params["kappa fov"].numpy(), src_fov=physical_params["source fov"].numpy(), method="fft", ) phys_sie = AnalyticalPhysicalModel( pixels=physical_params["pixels"].numpy(), image_fov=physical_params["image fov"].numpy(), src_fov=physical_params["source fov"].numpy()) with open(os.path.join(model, "unet_hparams.json")) as f: unet_params = json.load(f) unet_params["kernel_l2_amp"] = args.l2_amp unet = Model(**unet_params) ckpt = tf.train.Checkpoint(net=unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1) checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial() with open(os.path.join(model, "rim_hparams.json")) as f: rim_params = json.load(f) rim = RIM(phys, unet, **rim_params) dataset_names = [args.train_dataset, args.val_dataset, args.test_dataset] dataset_shapes = [args.train_size, args.val_size, args.test_size] model_name = os.path.split(model)[-1] # from censai.utils import nulltape # def call_with_mask(self, lensed_image, noise_rms, psf, mask, outer_tape=nulltape): # """ # Used in training. Return linked kappa and source maps. 
# """ # batch_size = lensed_image.shape[0] # source, kappa, source_grad, kappa_grad, states = self.initial_states(batch_size) # initiate all tensors to 0 # source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad, # states) # Use lens to make an initial guess with Unet # source_series = tf.TensorArray(DTYPE, size=self.steps) # kappa_series = tf.TensorArray(DTYPE, size=self.steps) # chi_squared_series = tf.TensorArray(DTYPE, size=self.steps) # # record initial guess # source_series = source_series.write(index=0, value=source) # kappa_series = kappa_series.write(index=0, value=kappa) # # Main optimization loop # for current_step in tf.range(self.steps - 1): # with outer_tape.stop_recording(): # with tf.GradientTape() as g: # g.watch(source) # g.watch(kappa) # y_pred = self.physical_model.forward(self.source_link(source), self.kappa_link(kappa), psf) # flux_term = tf.square( # tf.reduce_sum(y_pred, axis=(1, 2, 3)) - tf.reduce_sum(lensed_image, axis=(1, 2, 3))) # log_likelihood = 0.5 * tf.reduce_sum( # tf.square(y_pred - mask * lensed_image) / noise_rms[:, None, None, None] ** 2, axis=(1, 2, 3)) # cost = tf.reduce_mean(log_likelihood + self.flux_lagrange_multiplier * flux_term) # source_grad, kappa_grad = g.gradient(cost, [source, kappa]) # source_grad, kappa_grad = self.grad_update(source_grad, kappa_grad, current_step) # source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad, states) # source_series = source_series.write(index=current_step + 1, value=source) # kappa_series = kappa_series.write(index=current_step + 1, value=kappa) # chi_squared_series = chi_squared_series.write(index=current_step, # value=log_likelihood / self.pixels ** 2) # renormalize chi squared here # # last step score # log_likelihood = self.physical_model.log_likelihood(y_true=lensed_image, source=self.source_link(source), # kappa=self.kappa_link(kappa), psf=psf, noise_rms=noise_rms) # chi_squared_series = chi_squared_series.write(index=self.steps - 1, value=log_likelihood) # return source_series.stack(), kappa_series.stack(), chi_squared_series.stack() with h5py.File( os.path.join( os.getenv("CENSAI_PATH"), "results", args.experiment_name + "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf: for i, dataset in enumerate([train_dataset, val_dataset, test_dataset]): g = hf.create_group(f'{dataset_names[i]}') data_len = dataset_shapes[i] // N_WORKERS g.create_dataset(name="lens", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="psf", shape=[ data_len, physical_params['psf pixels'], physical_params['psf pixels'], 1 ], dtype=np.float32) g.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32) g.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32) g.create_dataset( name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) g.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="source_pred", shape=[ data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1 ], dtype=np.float32) g.create_dataset( name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1]) g.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, 
phys.kappa_pixels, 1 ], dtype=np.float32) g.create_dataset( name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) g.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32) g.create_dataset(name="chi_squared_reoptimized", shape=[data_len], dtype=np.float32) g.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) g.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32) g.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) g.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32) g.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum2", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum_reoptimized", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_frequencies", shape=[args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) g.create_dataset(name="source_fov", shape=[1], dtype=np.float32) g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32) dataset = dataset.skip(data_len * (THIS_WORKER - 1)).take(data_len) for batch, (lens, source, kappa, noise_rms, psf, fwhm) in enumerate( dataset.batch(1).prefetch( tf.data.experimental.AUTOTUNE)): checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial( ) # reset model weights # Compute predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( lens, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) # Re-optimize weights of the model STEPS = args.re_optimize_steps learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase) optim = tf.keras.optimizers.RMSprop( learning_rate=learning_rate_schedule) chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) source_mse = tf.TensorArray(DTYPE, size=STEPS) kappa_mse = tf.TensorArray(DTYPE, size=STEPS) best = chi_squared[-1, 0] # best = abs(2*chi_squared[-1, 0] - 1) # best_chisq = 2*chi_squared[-1, 0] source_best = source_pred[-1] kappa_best = kappa_pred[-1] # source_mean = source_pred[-1] # kappa_mean = rim.kappa_link(kappa_pred[-1]) # source_std = tf.zeros_like(source_mean) # kappa_std = tf.zeros_like(kappa_mean) # 
counter = 0 for current_step in tqdm(range(STEPS)): with tf.GradientTape() as tape: tape.watch(unet.trainable_variables) # s, k, chi_sq = call_with_mask(rim, lens, noise_rms, psf, mask, tape) s, k, chi_sq = rim.call(lens, noise_rms, psf, outer_tape=tape) cost = tf.reduce_mean(chi_sq) # mean over time steps cost += tf.reduce_sum(rim.unet.losses) log_likelihood = chi_sq[-1] chi_squared_series = chi_squared_series.write( index=current_step, value=log_likelihood) source_o = s[-1] kappa_o = k[-1] source_mse = source_mse.write( index=current_step, value=tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2)) kappa_mse = kappa_mse.write( index=current_step, value=tf.reduce_mean( (kappa_o - rim.kappa_inverse_link(kappa))**2)) if chi_sq[-1, 0] < args.converged_chisq: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] break if chi_sq[-1, 0] < best: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] source_mse_best = tf.reduce_mean( (source_best - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_mean( (kappa_best - rim.kappa_inverse_link(kappa))**2) # if counter > 0: # # Welford's online algorithm # # source # delta = source_o - source_mean # source_mean = (counter * source_mean + (counter + 1) * source_o)/(counter + 1) # delta2 = source_o - source_mean # source_std += delta * delta2 # # kappa # delta = rim.kappa_link(kappa_o) - kappa_mean # kappa_mean = (counter * kappa_mean + (counter + 1) * rim.kappa_link(kappa_o)) / (counter + 1) # delta2 = rim.kappa_link(kappa_o) - kappa_mean # kappa_std += delta * delta2 # if best_chisq < args.converged_chisq: # counter += 1 # if counter == args.window: # break # if 2*chi_sq[-1, 0] < best_chisq: # best_chisq = 2*chi_sq[-1, 0] # if abs(2*chi_sq[-1, 0] - 1) < best: # source_best = rim.source_link(source_o) # kappa_best = rim.kappa_link(kappa_o) # best = abs(2 * chi_squared[-1, 0] - 1) # source_mse_best = tf.reduce_mean((source_best - rim.source_inverse_link(source)) ** 2) # kappa_mse_best = tf.reduce_mean((kappa_best - rim.kappa_inverse_link(kappa)) ** 2) grads = tape.gradient(cost, unet.trainable_variables) optim.apply_gradients(zip(grads, unet.trainable_variables)) source_o = source_best kappa_o = kappa_best y_pred = phys.forward(source_o, kappa_o, psf) chi_sq_series = tf.transpose(chi_squared_series.stack(), perm=[1, 0]) source_mse = source_mse.stack()[None, ...] kappa_mse = kappa_mse.stack()[None, ...] 
# kappa_std /= float(args.window) # source_std /= float(args.window) # Compute Power spectrum of converged predictions _ps_lens = ps_lens.cross_correlation_coefficient( lens[..., 0], lens_pred[..., 0]) _ps_lens3 = ps_lens.cross_correlation_coefficient( lens[..., 0], y_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_kappa2 = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_o[..., 0])) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) _ps_source3 = ps_source.cross_correlation_coefficient( source[..., 0], source_o[..., 0]) # save results g["lens"][batch] = lens.numpy().astype(np.float32) g["psf"][batch] = psf.numpy().astype(np.float32) g["psf_fwhm"][batch] = fwhm.numpy().astype(np.float32) g["noise_rms"][batch] = noise_rms.numpy().astype(np.float32) g["source"][batch] = source.numpy().astype(np.float32) g["kappa"][batch] = kappa.numpy().astype(np.float32) g["lens_pred"][batch] = lens_pred.numpy().astype(np.float32) g["lens_pred_reoptimized"][batch] = y_pred.numpy().astype( np.float32) g["source_pred"][batch] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["source_pred_reoptimized"][batch] = source_o.numpy().astype( np.float32) g["kappa_pred"][batch] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype( np.float32) g["chi_squared"][batch] = tf.transpose( chi_squared).numpy().astype(np.float32) g["chi_squared_reoptimized"][batch] = best.numpy().astype( np.float32) g["chi_squared_reoptimized_series"][ batch] = chi_sq_series.numpy().astype(np.float32) g["source_optim_mse"][batch] = source_mse_best.numpy().astype( np.float32) g["source_optim_mse_series"][batch] = source_mse.numpy( ).astype(np.float32) g["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype( np.float32) g["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype( np.float32) g["lens_coherence_spectrum"][batch] = _ps_lens g["lens_coherence_spectrum_reoptimized"][batch] = _ps_lens3 g["source_coherence_spectrum"][batch] = _ps_source g["source_coherence_spectrum_reoptimized"][batch] = _ps_source3 g["lens_coherence_spectrum"][batch] = _ps_lens g["lens_coherence_spectrum"][batch] = _ps_lens g["kappa_coherence_spectrum"][batch] = _ps_kappa g["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2 if batch == 0: _, f = np.histogram(np.fft.fftfreq( phys.pixels)[:phys.pixels // 2], bins=ps_lens.bins) f = (f[:-1] + f[1:]) / 2 g["lens_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 g["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 g["kappa_frequencies"][:] = f g["kappa_fov"][0] = phys.kappa_fov g["source_fov"][0] = phys.src_fov # Create SIE test g = hf.create_group(f'SIE_test') data_len = args.sie_size // N_WORKERS sie_dataset = test_dataset.skip(data_len * (THIS_WORKER - 1)).take(data_len) g.create_dataset(name="lens", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="psf", shape=[ data_len, physical_params['psf pixels'], physical_params['psf pixels'], 1 ], dtype=np.float32) g.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32) g.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32) 
g.create_dataset(name="source", shape=[data_len, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) g.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset(name="lens_pred2", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) g.create_dataset( name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) g.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1 ], dtype=np.float32) g.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_coherence_spectrum2", shape=[data_len, args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="lens_frequencies", shape=[args.lens_coherence_bins], dtype=np.float32) g.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) g.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) g.create_dataset(name="einstein_radius", shape=[data_len], dtype=np.float32) g.create_dataset(name="position", shape=[data_len, 2], dtype=np.float32) g.create_dataset(name="orientation", shape=[data_len], dtype=np.float32) g.create_dataset(name="ellipticity", shape=[data_len], dtype=np.float32) g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) g.create_dataset(name="source_fov", shape=[1], dtype=np.float32) g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32) for batch, (_, source, _, noise_rms, psf, fwhm) in enumerate( sie_dataset.take(data_len).batch(args.batch_size).prefetch( tf.data.experimental.AUTOTUNE)): batch_size = source.shape[0] # Create some SIE kappa maps _r = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=args.max_shift) _theta = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=-np.pi, maxval=np.pi) x0 = _r * tf.math.cos(_theta) y0 = _r * tf.math.sin(_theta) ellipticity = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=args.max_ellipticity) phi = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=-np.pi, maxval=np.pi) einstein_radius = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=args.min_theta_e, maxval=args.max_theta_e) kappa = phys_sie.kappa_field(x0=x0, y0=y0, e=ellipticity, phi=phi, r_ein=einstein_radius) lens = phys.noisy_forward(source, kappa, noise_rms=noise_rms, psf=psf) # Compute predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( lens, noise_rms, psf) lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) # Compute Power spectrum of converged predictions _ps_lens = ps_lens.cross_correlation_coefficient( lens[..., 0], lens_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) # save results i_begin = batch * 
args.batch_size i_end = i_begin + batch_size g["lens"][i_begin:i_end] = lens.numpy().astype(np.float32) g["psf"][i_begin:i_end] = psf.numpy().astype(np.float32) g["psf_fwhm"][i_begin:i_end] = fwhm.numpy().astype(np.float32) g["noise_rms"][i_begin:i_end] = noise_rms.numpy().astype( np.float32) g["source"][i_begin:i_end] = source.numpy().astype(np.float32) g["kappa"][i_begin:i_end] = kappa.numpy().astype(np.float32) g["lens_pred"][i_begin:i_end] = lens_pred.numpy().astype( np.float32) g["source_pred"][i_begin:i_end] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["kappa_pred"][i_begin:i_end] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) g["chi_squared"][i_begin:i_end] = 2 * tf.transpose( chi_squared).numpy().astype(np.float32) g["lens_coherence_spectrum"][i_begin:i_end] = _ps_lens.numpy( ).astype(np.float32) g["source_coherence_spectrum"][i_begin:i_end] = _ps_source.numpy( ).astype(np.float32) g["kappa_coherence_spectrum"][i_begin:i_end] = _ps_kappa.numpy( ).astype(np.float32) g["einstein_radius"][ i_begin:i_end] = einstein_radius[:, 0, 0, 0].numpy().astype(np.float32) g["position"][i_begin:i_end] = tf.stack( [x0[:, 0, 0, 0], y0[:, 0, 0, 0]], axis=1).numpy().astype(np.float32) g["ellipticity"][i_begin:i_end] = ellipticity[:, 0, 0, 0].numpy().astype( np.float32) g["orientation"][i_begin:i_end] = phi[:, 0, 0, 0].numpy().astype(np.float32) if batch == 0: _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_lens.bins) f = (f[:-1] + f[1:]) / 2 g["lens_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 g["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 g["kappa_frequencies"][:] = f g["kappa_fov"][0] = phys.kappa_fov g["source_fov"][0] = phys.src_fov
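
# The per-example re-optimization above fine-tunes the U-Net weights on a
# single observation, keeps the (source, kappa) prediction with the lowest
# chi squared, and stops early once `args.converged_chisq` is reached. A
# minimal sketch of that keep-best / early-stopping pattern, assuming a
# placeholder `model_step` callable that runs one gradient update and returns
# (source, kappa, chi_squared); none of these names come from the script:
def example_keep_best_reoptimization(model_step, steps=500, converged_chisq=1.0):
    best_chisq = float("inf")
    best_solution = None
    for _ in range(steps):
        source, kappa, chi_squared = model_step()
        if chi_squared < best_chisq:
            # Keep the best-fitting prediction seen so far, not the last one.
            best_chisq = chi_squared
            best_solution = (source, kappa)
        if chi_squared < converged_chisq:
            break  # early stopping once the fit is deemed converged
    return best_solution, best_chisq
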
def distributed_strategy(args): model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model) path = os.getenv('CENSAI_PATH') + "/results/" dataset = [] for file in sorted(glob.glob(path + args.h5_pattern)): try: dataset.append(h5py.File(file, "r")) except: continue B = dataset[0]["source"].shape[0] data_len = len(dataset) * B // N_WORKERS ps_observation = PowerSpectrum(bins=args.observation_coherence_bins, pixels=128) ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=128) ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=128) phys = PhysicalModel( pixels=128, kappa_pixels=128, src_pixels=128, image_fov=7.69, kappa_fov=7.69, src_fov=3., method="fft", ) with open(os.path.join(model, "unet_hparams.json")) as f: unet_params = json.load(f) unet_params["kernel_l2_amp"] = args.l2_amp unet = Model(**unet_params) ckpt = tf.train.Checkpoint(net=unet) checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1) checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial() with open(os.path.join(model, "rim_hparams.json")) as f: rim_params = json.load(f) rim = RIM(phys, unet, **rim_params) kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.kappa_vae) with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f: kappa_vae_hparams = json.load(f) kappa_vae = VAE(**kappa_vae_hparams) ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae) checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1) checkpoint_manager1.checkpoint.restore( checkpoint_manager1.latest_checkpoint).expect_partial() svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models", args.source_vae) with open(os.path.join(svae_path, "model_hparams.json"), "r") as f: source_vae_hparams = json.load(f) source_vae = VAE(**source_vae_hparams) ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae) checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1) checkpoint_manager2.checkpoint.restore( checkpoint_manager2.latest_checkpoint).expect_partial() wk = lambda k: tf.sqrt(k) / tf.reduce_sum( tf.sqrt(k), axis=(1, 2, 3), keepdims=True) # Freeze L5 # encoding layers # rim.unet.layers[0].trainable = False # L1 # rim.unet.layers[1].trainable = False # rim.unet.layers[2].trainable = False # rim.unet.layers[3].trainable = False # rim.unet.layers[4].trainable = False # L5 # GRU # rim.unet.layers[5].trainable = False # rim.unet.layers[6].trainable = False # rim.unet.layers[7].trainable = False # rim.unet.layers[8].trainable = False # rim.unet.layers[9].trainable = False # rim.unet.layers[15].trainable = False # bottleneck GRU # output layer # rim.unet.layers[-2].trainable = False # input layer # rim.unet.layers[-1].trainable = False # decoding layers # rim.unet.layers[10].trainable = False # L5 # rim.unet.layers[11].trainable = False # rim.unet.layers[12].trainable = False # rim.unet.layers[13].trainable = False # rim.unet.layers[14].trainable = False # L1 with h5py.File( os.path.join( os.getenv("CENSAI_PATH"), "results", args.experiment_name + "_" + args.model + "_" + args.dataset + f"_{THIS_WORKER:03d}.h5"), 'w') as hf: hf.create_dataset(name="observation", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset(name="psf", shape=[data_len, 20, 20, 1], dtype=np.float32) hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32) hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32) hf.create_dataset( name="source", shape=[data_len, phys.src_pixels, 
phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="kappa", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset(name="observation_pred_reoptimized", shape=[data_len, phys.pixels, phys.pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred", shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1], dtype=np.float32) hf.create_dataset( name="source_pred_reoptimized", shape=[data_len, phys.src_pixels, phys.src_pixels, 1]) hf.create_dataset(name="kappa_pred", shape=[ data_len, rim.steps, phys.kappa_pixels, phys.kappa_pixels, 1 ], dtype=np.float32) hf.create_dataset( name="kappa_pred_reoptimized", shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1], dtype=np.float32) hf.create_dataset(name="chi_squared", shape=[data_len, rim.steps], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized", shape=[data_len], dtype=np.float32) hf.create_dataset(name="chi_squared_reoptimized_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="source_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="source_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse", shape=[data_len], dtype=np.float32) hf.create_dataset(name="kappa_optim_mse_series", shape=[data_len, args.re_optimize_steps], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum2", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_coherence_spectrum_reoptimized", shape=[data_len, args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum2", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_coherence_spectrum_reoptimized", shape=[data_len, args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_coherence_spectrum_reoptimized", shape=[data_len, args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="observation_frequencies", shape=[args.observation_coherence_bins], dtype=np.float32) hf.create_dataset(name="source_frequencies", shape=[args.source_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_frequencies", shape=[args.kappa_coherence_bins], dtype=np.float32) hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32) hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32) for batch, j in enumerate( range((THIS_WORKER - 1) * data_len, THIS_WORKER * data_len)): b = j // B k = j % B observation = dataset[b]["observation"][k][None, ...] source = dataset[b]["source"][k][None, ...] kappa = dataset[b]["kappa"][k][None, ...] noise_rms = np.array([dataset[b]["noise_rms"][k]]) psf = dataset[b]["psf"][k][None, ...] 
fwhm = dataset[b]["psf_fwhm"][k] checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint).expect_partial( ) # reset model weights # Compute predictions for kappa and source source_pred, kappa_pred, chi_squared = rim.predict( observation, noise_rms, psf) observation_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf) # reset the seed for reproducible sampling in the VAE for EWC tf.random.set_seed(args.seed) np.random.seed(args.seed) # Initialize regularization term ewc = EWC(observation=observation, noise_rms=noise_rms, psf=psf, phys=phys, rim=rim, source_vae=source_vae, kappa_vae=kappa_vae, n_samples=args.sample_size, sigma_source=args.source_vae_ball_size, sigma_kappa=args.kappa_vae_ball_size) # Re-optimize weights of the model STEPS = args.re_optimize_steps learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=args.learning_rate, decay_rate=args.decay_rate, decay_steps=args.decay_steps, staircase=args.staircase) optim = tf.keras.optimizers.SGD( learning_rate=learning_rate_schedule) chi_squared_series = tf.TensorArray(DTYPE, size=STEPS) source_mse = tf.TensorArray(DTYPE, size=STEPS) kappa_mse = tf.TensorArray(DTYPE, size=STEPS) best = chi_squared[-1, 0] source_best = source_pred[-1] kappa_best = kappa_pred[-1] source_mse_best = tf.reduce_mean( (source_best - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_sum( wk(kappa) * (kappa_best - rim.kappa_inverse_link(kappa))**2) for current_step in tqdm(range(STEPS)): with tf.GradientTape() as tape: tape.watch(unet.trainable_variables) s, k, chi_sq = rim.call(observation, noise_rms, psf, outer_tape=tape) cost = tf.reduce_mean(chi_sq) # mean over time steps cost += tf.reduce_sum(rim.unet.losses) # L2 regularisation cost += args.lam_ewc * ewc.penalty( rim) # Elastic Weights Consolidation log_likelihood = chi_sq[-1] chi_squared_series = chi_squared_series.write( index=current_step, value=log_likelihood) source_o = s[-1] kappa_o = k[-1] source_mse = source_mse.write( index=current_step, value=tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2)) kappa_mse = kappa_mse.write( index=current_step, value=tf.reduce_sum( wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2)) if 2 * chi_sq[-1, 0] < 1.0 and args.early_stopping: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] source_mse_best = tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_sum( wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2) break if chi_sq[-1, 0] < best: source_best = rim.source_link(source_o) kappa_best = rim.kappa_link(kappa_o) best = chi_sq[-1, 0] source_mse_best = tf.reduce_mean( (source_o - rim.source_inverse_link(source))**2) kappa_mse_best = tf.reduce_sum( wk(kappa) * (kappa_o - rim.kappa_inverse_link(kappa))**2) grads = tape.gradient(cost, unet.trainable_variables) optim.apply_gradients(zip(grads, unet.trainable_variables)) source_o = source_best kappa_o = kappa_best y_pred = phys.forward(source_o, kappa_o, psf) chi_sq_series = tf.transpose(chi_squared_series.stack(), perm=[1, 0]) source_mse = source_mse.stack()[None, ...] kappa_mse = kappa_mse.stack()[None, ...] 
# Compute Power spectrum of converged predictions _ps_observation = ps_observation.cross_correlation_coefficient( observation[..., 0], observation_pred[..., 0]) _ps_observation2 = ps_observation.cross_correlation_coefficient( observation[..., 0], y_pred[..., 0]) _ps_kappa = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_pred[-1])[..., 0]) _ps_kappa2 = ps_kappa.cross_correlation_coefficient( log_10(kappa)[..., 0], log_10(kappa_o[..., 0])) _ps_source = ps_source.cross_correlation_coefficient( source[..., 0], source_pred[-1][..., 0]) _ps_source2 = ps_source.cross_correlation_coefficient( source[..., 0], source_o[..., 0]) # save results hf["observation"][batch] = observation.astype(np.float32) hf["psf"][batch] = psf.astype(np.float32) hf["psf_fwhm"][batch] = fwhm hf["noise_rms"][batch] = noise_rms.astype(np.float32) hf["source"][batch] = source.astype(np.float32) hf["kappa"][batch] = kappa.astype(np.float32) hf["observation_pred"][batch] = observation_pred.numpy().astype( np.float32) hf["observation_pred_reoptimized"][batch] = y_pred.numpy().astype( np.float32) hf["source_pred"][batch] = tf.transpose( source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["source_pred_reoptimized"][batch] = source_o.numpy().astype( np.float32) hf["kappa_pred"][batch] = tf.transpose( kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32) hf["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype( np.float32) hf["chi_squared"][batch] = 2 * tf.transpose( chi_squared).numpy().astype(np.float32) hf["chi_squared_reoptimized"][batch] = 2 * best.numpy().astype( np.float32) hf["chi_squared_reoptimized_series"][ batch] = 2 * chi_sq_series.numpy().astype(np.float32) hf["source_optim_mse"][batch] = source_mse_best.numpy().astype( np.float32) hf["source_optim_mse_series"][batch] = source_mse.numpy().astype( np.float32) hf["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype( np.float32) hf["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype( np.float32) hf["observation_coherence_spectrum"][batch] = _ps_observation hf["observation_coherence_spectrum_reoptimized"][ batch] = _ps_observation2 hf["source_coherence_spectrum"][batch] = _ps_source hf["source_coherence_spectrum_reoptimized"][batch] = _ps_source2 hf["kappa_coherence_spectrum"][batch] = _ps_kappa hf["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2 if batch == 0: _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels // 2], bins=ps_observation.bins) f = (f[:-1] + f[1:]) / 2 hf["observation_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.src_pixels)[:phys.src_pixels // 2], bins=ps_source.bins) f = (f[:-1] + f[1:]) / 2 hf["source_frequencies"][:] = f _, f = np.histogram(np.fft.fftfreq( phys.kappa_pixels)[:phys.kappa_pixels // 2], bins=ps_kappa.bins) f = (f[:-1] + f[1:]) / 2 hf["kappa_frequencies"][:] = f hf["kappa_fov"][0] = phys.kappa_fov hf["source_fov"][0] = phys.src_fov
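
# The kappa metric in this function is a weighted MSE: the `wk` lambda weights
# each pixel by sqrt(kappa), normalized to sum to one over the map, so
# high-convergence regions dominate the error. A small self-contained sketch
# of that weighting (TensorFlow; the random maps are illustrative only):
def example_sqrt_kappa_weighted_mse():
    import tensorflow as tf
    kappa_true = tf.math.exp(tf.random.normal([1, 128, 128, 1]))  # stand-in convergence map
    kappa_pred = kappa_true + 0.01 * tf.random.normal([1, 128, 128, 1])
    # Weights proportional to sqrt(kappa), normalized per map (the `wk` lambda above).
    w = tf.sqrt(kappa_true) / tf.reduce_sum(tf.sqrt(kappa_true), axis=(1, 2, 3), keepdims=True)
    return tf.reduce_sum(w * (kappa_pred - kappa_true)**2, axis=(1, 2, 3))
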
def test_noisy_forward_conv2():
    phys = PhysicalModel(pixels=64, src_pixels=64)
    source = tf.random.normal([2, 64, 64, 1])
    kappa = tf.math.exp(tf.random.uniform([2, 64, 64, 1]))
    noise_rms = 0.1
    phys.noisy_forward(source, kappa, noise_rms)
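
# `PhysicalModel.noisy_forward` is exercised above, but its implementation is
# not shown in this excerpt. As a rough illustration only (not the repository's
# code), a noisy forward operator of this kind typically blurs the lensed image
# with the PSF and adds Gaussian pixel noise at the requested rms:
def example_noisy_observation(lensed_image, psf, noise_rms):
    import tensorflow as tf
    # `lensed_image`: [batch, H, W, 1]; `psf`: [kh, kw]. Shapes are illustrative.
    kernel = tf.cast(psf, lensed_image.dtype)[:, :, tf.newaxis, tf.newaxis]
    kernel = kernel / tf.reduce_sum(kernel)  # normalize the PSF kernel
    blurred = tf.nn.conv2d(lensed_image, kernel, strides=1, padding="SAME")
    return blurred + noise_rms * tf.random.normal(tf.shape(blurred))
# e.g. example_noisy_observation(tf.zeros([1, 64, 64, 1]), tf.ones([8, 8]), 0.1)
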
def main(args): if args.seed is not None: tf.random.set_seed(args.seed) np.random.seed(args.seed) if args.json_override is not None: if isinstance(args.json_override, list): files = args.json_override else: files = [ args.json_override, ] for file in files: with open(file, "r") as f: json_override = json.load(f) args_dict = vars(args) args_dict.update(json_override) if args.v2: # overwrite decoding procedure with version 2 from censai.data.alpha_tng_v2 import decode_train, decode_physical_info else: from censai.data.alpha_tng import decode_train, decode_physical_info # ========= Dataset================================================================================================= files = [] for dataset in args.datasets: files.extend(glob.glob(os.path.join(dataset, "*.tfrecords"))) np.random.shuffle(files) files = tf.data.Dataset.from_tensor_slices(files) dataset = files.interleave(lambda x: tf.data.TFRecordDataset( x, compression_type=args.compression_type), block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE) # extract physical info from first example for image_fov, kappa_fov in dataset.map(decode_physical_info): break dataset = dataset.map(decode_train) if args.cache_file is not None: dataset = dataset.cache(args.cache_file) train_dataset = dataset.take(math.floor(args.train_split * args.total_items))\ .shuffle(buffer_size=args.buffer_size).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) val_dataset = dataset.skip(math.floor(args.train_split * args.total_items))\ .take(math.ceil((1 - args.train_split) * args.total_items)).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE) train_dataset = STRATEGY.experimental_distribute_dataset(train_dataset) val_dataset = STRATEGY.experimental_distribute_dataset(val_dataset) # ========== Model and Physical Model ============================================================================= # setup an analytical physical model to compare lenses from raytracer and analytical deflection angles. 
phys = PhysicalModel(pixels=args.pixels, image_fov=image_fov, kappa_fov=kappa_fov, src_fov=args.source_fov, psf_sigma=args.psf_sigma) with STRATEGY.scope(): # Replicate ops accross gpus ray_tracer = RayTracer( pixels=args.pixels, filter_scaling=args.filter_scaling, layers=args.layers, block_conv_layers=args.block_conv_layers, kernel_size=args.kernel_size, filters=args.filters, strides=args.strides, bottleneck_filters=args.bottleneck_filters, resampling_kernel_size=args.resampling_kernel_size, upsampling_interpolation=args.upsampling_interpolation, kernel_regularizer_amp=args.kernel_regularizer_amp, activation=args.activation, initializer=args.initializer, kappalog=args.kappalog, normalize=args.normalize, ) lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( args.initial_learning_rate, decay_steps=args.decay_steps, decay_rate=args.decay_rate, staircase=True) optim = tf.keras.optimizers.deserialize({ "class_name": args.optimizer, 'config': { "learning_rate": lr_schedule } }) # ==== Take care of where to write logs and stuff ================================================================= if args.model_id.lower() != "none": logname = args.model_id elif args.logname is not None: logname = args.logname else: logname = args.logname_prefixe + "_" + datetime.now().strftime( "%y-%m-%d_%H-%M-%S") # setup tensorboard writer (nullwriter in case we do not want to sync) if args.logdir.lower() != "none": logdir = os.path.join(args.logdir, logname) if not os.path.isdir(logdir): os.mkdir(logdir) writer = tf.summary.create_file_writer(logdir) else: writer = nullwriter() # ===== Make sure directory and checkpoint manager are created to save model =================================== if args.model_dir.lower() != "none": checkpoints_dir = os.path.join(args.model_dir, logname) if not os.path.isdir(checkpoints_dir): os.mkdir(checkpoints_dir) # save script parameter for future reference with open(os.path.join(checkpoints_dir, "script_params.json"), "w") as f: json.dump(vars(args), f, indent=4) with open(os.path.join(checkpoints_dir, "ray_tracer_hparams.json"), "w") as f: hparams_dict = { key: vars(args)[key] for key in RAYTRACER_HPARAMS } json.dump(hparams_dict, f, indent=4) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=ray_tracer) checkpoint_manager = tf.train.CheckpointManager( ckpt, checkpoints_dir, max_to_keep=args.max_to_keep) save_checkpoint = True # ======= Load model if model_id is provided =============================================================== if args.model_id.lower() != "none": checkpoint_manager.checkpoint.restore( checkpoint_manager.latest_checkpoint) else: save_checkpoint = False # ================================================================================================================= def train_step(inputs): kappa, alpha = inputs with tf.GradientTape(watch_accessed_variables=True) as tape: tape.watch(ray_tracer.trainable_weights) cost = tf.reduce_mean(tf.square(ray_tracer(kappa) - alpha), axis=(1, 2, 3)) cost = tf.reduce_sum( cost) / args.batch_size # normalize by global batch size gradient = tape.gradient(cost, ray_tracer.trainable_weights) if args.clipping: clipped_gradient = [ tf.clip_by_value(grad, -10, 10) for grad in gradient ] else: clipped_gradient = gradient optim.apply_gradients( zip(clipped_gradient, ray_tracer.trainable_variables)) return cost @tf.function def distributed_train_step(dist_inputs): per_replica_losses = STRATEGY.run(train_step, args=(dist_inputs, )) cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, 
axis=None) cost += tf.reduce_sum(ray_tracer.losses) return cost def test_step(inputs): kappa, alpha = inputs cost = tf.reduce_mean(tf.square(ray_tracer(kappa) - alpha), axis=(1, 2, 3)) cost = tf.reduce_sum( cost) / args.batch_size # normalize by global batch size return cost @tf.function def distributed_test_step(dist_inputs): per_replica_losses = STRATEGY.run(test_step, args=(dist_inputs, )) # Replica losses are aggregated by summing them cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) cost += tf.reduce_sum(ray_tracer.losses) return cost # ================================================================================================================= epoch_loss = tf.metrics.Mean() val_loss = tf.metrics.Mean() time_per_step = tf.metrics.Mean() history = { # recorded at the end of an epoch only "train_cost": [], "train_lens_residuals": [], "val_cost": [], "val_lens_residuals": [], "learning_rate": [], "step": [], "wall_time": [] } best_loss = np.inf global_start = time.time() estimated_time_for_epoch = 0 out_of_time = False patience = args.patience step = 0 lastest_checkpoint = 1 for epoch in range(1, args.epochs + 1): if (time.time() - global_start ) > args.max_time * 3600 - estimated_time_for_epoch: break epoch_start = time.time() epoch_loss.reset_states() time_per_step.reset_states() with writer.as_default(): for batch, distributed_inputs in enumerate(train_dataset): start = time.time() cost = distributed_train_step(distributed_inputs) # ========== Summary and logs ================================================================================== _time = time.time() - start time_per_step.update_state([_time]) epoch_loss.update_state([cost]) step += 1 # Deflection angle residual kappa_true, alpha_true = distributed_inputs alpha_pred = ray_tracer.call(kappa_true) # Lens residual lens_true = phys.lens_source_func(kappa_true, w=args.source_w) lens_pred = phys.lens_source_func_given_alpha(alpha_pred, w=args.source_w) train_chi_squared = tf.reduce_mean(tf.square(lens_true - lens_pred)) for res_idx in range( min(args.n_residuals, args.batch_size)): # Residuals in train set tf.summary.image(f"Residuals {res_idx}", plot_to_image( residual_plot(alpha_true[res_idx], alpha_pred[res_idx], lens_true[res_idx], lens_pred[res_idx])), step=step) # ========== Validation set =================== val_loss.reset_states() for distributed_inputs in val_dataset: # Cost of val set test_cost = distributed_test_step(distributed_inputs) val_loss.update_state([test_cost]) kappa_true, alpha_true = distributed_inputs alpha_pred = ray_tracer.call(kappa_true) lens_true = phys.lens_source_func(kappa_true, args.source_w) lens_pred = phys.lens_source_func_given_alpha(alpha_pred, w=args.source_w) val_chi_squared = tf.reduce_mean(tf.square(lens_true - lens_pred)) for res_idx in range(min(args.n_residuals, args.batch_size)): # Residuals in val set tf.summary.image(f"Val Residuals {res_idx}", plot_to_image( residual_plot(alpha_true[res_idx], alpha_pred[res_idx], lens_true[res_idx], lens_pred[res_idx])), step=step) train_cost = epoch_loss.result().numpy() val_cost = val_loss.result().numpy() tf.summary.scalar("Time per step", time_per_step.result().numpy(), step=step) tf.summary.scalar("Train lens residual", train_chi_squared, step=step) tf.summary.scalar("Val lens residual", val_chi_squared, step=step) tf.summary.scalar("MSE", train_cost, step=step) tf.summary.scalar("Val MSE", val_cost, step=step) tf.summary.scalar("Learning rate", optim.lr(step), step=step) print( f"epoch {epoch} | train loss 
{train_cost:.3e} | val loss {val_cost:.3e} | learning rate {optim.lr(step).numpy():.2e} | " f"time per step {time_per_step.result():.2e} s") history["train_cost"].append(train_cost) history["train_lens_residuals"].append(train_chi_squared.numpy()) history["val_cost"].append(val_cost) history["val_lens_residuals"].append(val_chi_squared.numpy()) history["learning_rate"].append(optim.lr(step).numpy()) history["step"].append(step) history["wall_time"].append(time.time() - global_start) cost = train_cost if args.track_train else val_cost if np.isnan(cost): print("Training broke the Universe") break if cost < best_loss * (1 - args.tolerance): best_loss = cost patience = args.patience else: patience -= 1 if (time.time() - global_start) > args.max_time * 3600: out_of_time = True if save_checkpoint: checkpoint_manager.checkpoint.step.assign_add(1) # a bit of a hack if epoch % args.checkpoints == 0 or patience == 0 or epoch == args.epochs - 1 or out_of_time: with open(os.path.join(checkpoints_dir, "score_sheet.txt"), mode="a") as f: np.savetxt(f, np.array([[lastest_checkpoint, cost]])) lastest_checkpoint += 1 checkpoint_manager.save() print("Saved checkpoint for step {}: {}".format( int(checkpoint_manager.checkpoint.step), checkpoint_manager.latest_checkpoint)) if patience == 0: print("Reached patience") break if out_of_time: break if epoch > 0: # First epoch is always very slow and not a good estimate of an epoch time. estimated_time_for_epoch = time.time() - epoch_start if optim.lr(step).numpy() < 1e-9: print("Reached learning rate limit") break return history, best_loss
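
# The distributed train/test steps above follow the standard tf.distribute
# recipe: each replica sums its per-example losses and divides by the *global*
# batch size, and `Strategy.reduce(SUM, ...)` then adds the per-replica results,
# reproducing the single-device mean loss. A minimal sketch of that pattern with
# a generic Keras model and strategy (all names here are placeholders, not the
# script's):
def example_distributed_mean_loss(strategy, model, optimizer, global_batch_size):
    import tensorflow as tf

    def step(inputs):
        x, y = inputs
        with tf.GradientTape() as tape:
            per_example = tf.reduce_mean(tf.square(model(x) - y), axis=-1)
            # Normalize by the global batch size so the cross-replica sum
            # equals the ordinary mean over the full batch.
            loss = tf.reduce_sum(per_example) / global_batch_size
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    @tf.function
    def distributed_step(dist_inputs):
        per_replica = strategy.run(step, args=(dist_inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

    return distributed_step
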
def __init__(self, observation, noise_rms, psf, phys: PhysicalModel, rim: RIM,
             source_vae: VAE, kappa_vae: VAE, n_samples=100,
             sigma_source=0.5, sigma_kappa=0.5):
    r"""
    Make a copy of the initial parameters \varphi^{(0)} and compute the Fisher diagonal F_{ii}.
    """
    wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(
        tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
    # Baseline prediction from the observation
    source_pred, kappa_pred, chi_squared = rim.predict(observation, noise_rms, psf)
    # Latent codes of the model predictions
    z_source, _ = source_vae.encoder(source_pred[-1])
    z_kappa, _ = kappa_vae.encoder(log_10(kappa_pred[-1]))
    # Deep copy of the initial parameters
    self.initial_params = [deepcopy(w) for w in rim.unet.trainable_variables]
    self.fisher_diagonal = [tf.zeros_like(w) for w in self.initial_params]
    for n in range(n_samples):
        # Sample latent codes around the prediction mean
        z_s = tf.random.normal(shape=[1, source_vae.latent_size],
                               mean=z_source, stddev=sigma_source)
        z_k = tf.random.normal(shape=[1, kappa_vae.latent_size],
                               mean=z_kappa, stddev=sigma_kappa)
        # Decode
        sampled_source = tf.nn.relu(source_vae.decode(z_s))
        sampled_source /= tf.reduce_max(sampled_source, axis=(1, 2, 3), keepdims=True)
        sampled_kappa = kappa_vae.decode(z_k)  # output is in log_10 space
        # Simulate an observation
        sampled_observation = phys.noisy_forward(sampled_source, 10**sampled_kappa,
                                                 noise_rms, psf)
        # Compute the gradient of the MSE
        with tf.GradientTape() as tape:
            tape.watch(rim.unet.trainable_variables)
            s, k, chi_squared = rim.call(sampled_observation, noise_rms, psf)
            # Remove the temperature from the loss when computing the Fisher:
            # sum instead of mean, and the weighted sum is renormalized by the number of pixels.
            _kappa_mse = phys.kappa_pixels**2 * tf.reduce_sum(
                wk(10**sampled_kappa) * (k - sampled_kappa)**2, axis=(2, 3, 4))
            cost = tf.reduce_sum(_kappa_mse)
            cost += tf.reduce_sum((s - sampled_source)**2)
        grad = tape.gradient(cost, rim.unet.trainable_variables)
        # Square the gradients and accumulate their Monte Carlo average into the Fisher diagonal.
        self.fisher_diagonal = [
            F + g**2 / n_samples for F, g in zip(self.fisher_diagonal, grad)
        ]
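
# The `ewc.penalty(rim)` term added to the cost during re-optimization is not
# shown in this excerpt. A penalty consistent with the Fisher diagonal and the
# parameter copy stored by the constructor above would be the usual quadratic
# EWC term, 0.5 * sum_i F_ii * (w_i - w_i^(0))^2. A hedged sketch (illustrative,
# not the repository's implementation):
def example_ewc_penalty(ewc, rim):
    import tensorflow as tf
    # Large Fisher entries anchor the corresponding U-Net weights close to
    # their values before per-example fine-tuning.
    return 0.5 * tf.add_n([
        tf.reduce_sum(F * (w - w0)**2)
        for F, w, w0 in zip(ewc.fisher_diagonal,
                            rim.unet.trainable_variables,
                            ewc.initial_params)
    ])
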