def test_interpolated_kappa():
    import tensorflow_addons as tfa
    phys = PhysicalModel(pixels=128, src_pixels=32, image_fov=7.68, kappa_fov=5)
    phys_a = AnalyticalPhysicalModel(pixels=128, image_fov=7.68)
    kappa = phys_a.kappa_field(r_ein=2., e=0.2)
    kappa += phys_a.kappa_field(r_ein=1., x0=2., y0=2.)
    true_lens = phys.lens_source_func(kappa, w=0.2)
    true_kappa = kappa

    # Test interpolation of alpha angles on a finer grid
    # phys = PhysicalModel(pixels=128, src_pixels=32, kappa_pixels=32)
    phys_a = AnalyticalPhysicalModel(pixels=32, image_fov=7.68)
    kappa = phys_a.kappa_field(r_ein=2., e=0.2)
    kappa += phys_a.kappa_field(r_ein=1., x0=2., y0=2.)
    # kappa2 = phys_a.kappa_field(r_ein=2., e=0.2)
    # kappa2 += phys_a.kappa_field(r_ein=1., x0=2., y0=2.)
    # kappa = tf.concat([kappa, kappa2], axis=1)

    # Test interpolated kappa lens
    x = np.linspace(-1, 1, 128) * phys.kappa_fov / 2
    x, y = np.meshgrid(x, x)
    x = tf.constant(x[np.newaxis, ..., np.newaxis], tf.float32)
    y = tf.constant(y[np.newaxis, ..., np.newaxis], tf.float32)
    dx = phys.kappa_fov / (32 - 1)
    xmin = -0.5 * phys.kappa_fov
    ymin = -0.5 * phys.kappa_fov
    i_coord = (x - xmin) / dx
    j_coord = (y - ymin) / dx
    wrap = tf.concat([i_coord, j_coord], axis=-1)
    # test_kappa1 = tfa.image.resampler(kappa, wrap)  # bilinear interpolation of kappa on the wrap grid
    # test_lens1 = phys.lens_source_func(test_kappa1, w=0.2)
    phys2 = PhysicalModel(pixels=128, kappa_pixels=32, method="fft", image_fov=7.68, kappa_fov=5)
    test_lens1 = phys2.lens_source_func(kappa, w=0.2)

    # Test interpolated alpha angles lens
    phys2 = PhysicalModel(pixels=32, src_pixels=32, image_fov=7.68, kappa_fov=5)
    alpha1, alpha2 = phys2.deflection_angle(kappa)
    alpha = tf.concat([alpha1, alpha2], axis=-1)
    alpha = tfa.image.resampler(alpha, wrap)
    test_lens2 = phys.lens_source_func_given_alpha(alpha, w=0.2)
    return true_lens, test_lens1, test_lens2
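
# Illustrative helper (an assumption, not part of the original script): a quick visual check of the
# three lensed images returned by test_interpolated_kappa(), assuming matplotlib is installed.
def _plot_interpolated_kappa_test():
    import matplotlib.pyplot as plt
    true_lens, test_lens1, test_lens2 = test_interpolated_kappa()
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    titles = ["true lens", "interpolated kappa lens", "interpolated alpha lens"]
    for ax, img, title in zip(axs, [true_lens, test_lens1, test_lens2], titles):
        ax.imshow(img[0, ..., 0], cmap="hot", origin="lower")  # drop batch and channel dims for display
        ax.set_title(title)
    plt.show()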
def main(args):
    if args.seed is not None:
        tf.random.set_seed(args.seed)
        np.random.seed(args.seed)
    if args.json_override is not None:
        if isinstance(args.json_override, list):
            files = args.json_override
        else:
            files = [args.json_override]
        for file in files:
            with open(file, "r") as f:
                json_override = json.load(f)
            args_dict = vars(args)
            args_dict.update(json_override)
    if args.v2:  # overwrite decoding procedure with version 2
        from censai.data.alpha_tng_v2 import decode_train, decode_physical_info
    else:
        from censai.data.alpha_tng import decode_train, decode_physical_info

    # ========= Dataset ===============================================================================================
    files = []
    for dataset in args.datasets:
        files.extend(glob.glob(os.path.join(dataset, "*.tfrecords")))
    np.random.shuffle(files)
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
        block_length=args.block_length,
        num_parallel_calls=tf.data.AUTOTUNE)
    # extract physical info from first example
    for image_fov, kappa_fov in dataset.map(decode_physical_info):
        break
    dataset = dataset.map(decode_train)
    if args.cache_file is not None:
        dataset = dataset.cache(args.cache_file)
    train_dataset = dataset.take(math.floor(args.train_split * args.total_items))\
        .shuffle(buffer_size=args.buffer_size)\
        .batch(args.batch_size)\
        .prefetch(tf.data.experimental.AUTOTUNE)
    val_dataset = dataset.skip(math.floor(args.train_split * args.total_items))\
        .take(math.ceil((1 - args.train_split) * args.total_items))\
        .batch(args.batch_size)\
        .prefetch(tf.data.experimental.AUTOTUNE)
    train_dataset = STRATEGY.experimental_distribute_dataset(train_dataset)
    val_dataset = STRATEGY.experimental_distribute_dataset(val_dataset)

    # ========== Model and Physical Model =============================================================================
    # setup an analytical physical model to compare lenses from raytracer and analytical deflection angles.
    phys = PhysicalModel(
        pixels=args.pixels,
        image_fov=image_fov,
        kappa_fov=kappa_fov,
        src_fov=args.source_fov,
        psf_sigma=args.psf_sigma)
    with STRATEGY.scope():  # Replicate ops across gpus
        ray_tracer = RayTracer(
            pixels=args.pixels,
            filter_scaling=args.filter_scaling,
            layers=args.layers,
            block_conv_layers=args.block_conv_layers,
            kernel_size=args.kernel_size,
            filters=args.filters,
            strides=args.strides,
            bottleneck_filters=args.bottleneck_filters,
            resampling_kernel_size=args.resampling_kernel_size,
            upsampling_interpolation=args.upsampling_interpolation,
            kernel_regularizer_amp=args.kernel_regularizer_amp,
            activation=args.activation,
            initializer=args.initializer,
            kappalog=args.kappalog,
            normalize=args.normalize,
        )
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            args.initial_learning_rate,
            decay_steps=args.decay_steps,
            decay_rate=args.decay_rate,
            staircase=True)
        optim = tf.keras.optimizers.deserialize({
            "class_name": args.optimizer,
            "config": {"learning_rate": lr_schedule}
        })

    # ==== Take care of where to write logs and stuff =================================================================
    if args.model_id.lower() != "none":
        logname = args.model_id
    elif args.logname is not None:
        logname = args.logname
    else:
        logname = args.logname_prefixe + "_" + datetime.now().strftime("%y-%m-%d_%H-%M-%S")
    # setup tensorboard writer (nullwriter in case we do not want to sync)
    if args.logdir.lower() != "none":
        logdir = os.path.join(args.logdir, logname)
        if not os.path.isdir(logdir):
            os.mkdir(logdir)
        writer = tf.summary.create_file_writer(logdir)
    else:
        writer = nullwriter()
    # ===== Make sure directory and checkpoint manager are created to save model =====================================
    if args.model_dir.lower() != "none":
        checkpoints_dir = os.path.join(args.model_dir, logname)
        if not os.path.isdir(checkpoints_dir):
            os.mkdir(checkpoints_dir)
        # save script parameters for future reference
        with open(os.path.join(checkpoints_dir, "script_params.json"), "w") as f:
            json.dump(vars(args), f, indent=4)
        with open(os.path.join(checkpoints_dir, "ray_tracer_hparams.json"), "w") as f:
            hparams_dict = {key: vars(args)[key] for key in RAYTRACER_HPARAMS}
            json.dump(hparams_dict, f, indent=4)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=ray_tracer)
        checkpoint_manager = tf.train.CheckpointManager(ckpt, checkpoints_dir, max_to_keep=args.max_to_keep)
        save_checkpoint = True
        # ======= Load model if model_id is provided ==================================================================
        if args.model_id.lower() != "none":
            checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint)
    else:
        save_checkpoint = False

    # =================================================================================================================
    def train_step(inputs):
        kappa, alpha = inputs
        with tf.GradientTape(watch_accessed_variables=True) as tape:
            tape.watch(ray_tracer.trainable_weights)
            cost = tf.reduce_mean(tf.square(ray_tracer(kappa) - alpha), axis=(1, 2, 3))
            cost = tf.reduce_sum(cost) / args.batch_size  # normalize by global batch size
        gradient = tape.gradient(cost, ray_tracer.trainable_weights)
        if args.clipping:
            clipped_gradient = [tf.clip_by_value(grad, -10, 10) for grad in gradient]
        else:
            clipped_gradient = gradient
        optim.apply_gradients(zip(clipped_gradient, ray_tracer.trainable_variables))
        return cost

    @tf.function
    def distributed_train_step(dist_inputs):
        per_replica_losses = STRATEGY.run(train_step, args=(dist_inputs,))
        cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        cost += tf.reduce_sum(ray_tracer.losses)
        return cost

    def test_step(inputs):
        kappa, alpha = inputs
        cost = tf.reduce_mean(tf.square(ray_tracer(kappa) - alpha), axis=(1, 2, 3))
        cost = tf.reduce_sum(cost) / args.batch_size  # normalize by global batch size
        return cost

    @tf.function
    def distributed_test_step(dist_inputs):
        per_replica_losses = STRATEGY.run(test_step, args=(dist_inputs,))
        # Replica losses are aggregated by summing them
        cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        cost += tf.reduce_sum(ray_tracer.losses)
        return cost

    # =================================================================================================================
    epoch_loss = tf.metrics.Mean()
    val_loss = tf.metrics.Mean()
    time_per_step = tf.metrics.Mean()
    history = {  # recorded at the end of an epoch only
        "train_cost": [],
        "train_lens_residuals": [],
        "val_cost": [],
        "val_lens_residuals": [],
        "learning_rate": [],
        "step": [],
        "wall_time": []
    }
    best_loss = np.inf
    global_start = time.time()
    estimated_time_for_epoch = 0
    out_of_time = False
    patience = args.patience
    step = 0
    lastest_checkpoint = 1
    for epoch in range(1, args.epochs + 1):
        if (time.time() - global_start) > args.max_time * 3600 - estimated_time_for_epoch:
            break
        epoch_start = time.time()
        epoch_loss.reset_states()
        time_per_step.reset_states()
        with writer.as_default():
            for batch, distributed_inputs in enumerate(train_dataset):
                start = time.time()
                cost = distributed_train_step(distributed_inputs)
                # ========== Summary and logs =============================================================================
                _time = time.time() - start
                time_per_step.update_state([_time])
                epoch_loss.update_state([cost])
                step += 1
            # Deflection angle residual (computed on the last training batch of the epoch)
            kappa_true, alpha_true = distributed_inputs
            alpha_pred = ray_tracer.call(kappa_true)
            # Lens residual
            lens_true = phys.lens_source_func(kappa_true, w=args.source_w)
            lens_pred = phys.lens_source_func_given_alpha(alpha_pred, w=args.source_w)
            train_chi_squared = tf.reduce_mean(tf.square(lens_true - lens_pred))
            for res_idx in range(min(args.n_residuals, args.batch_size)):  # Residuals in train set
                tf.summary.image(f"Residuals {res_idx}",
                                 plot_to_image(residual_plot(alpha_true[res_idx], alpha_pred[res_idx],
                                                             lens_true[res_idx], lens_pred[res_idx])),
                                 step=step)
            # ========== Validation set ===================
            val_loss.reset_states()
            for distributed_inputs in val_dataset:  # Cost of val set
                test_cost = distributed_test_step(distributed_inputs)
                val_loss.update_state([test_cost])
            # Residuals on the last validation batch
            kappa_true, alpha_true = distributed_inputs
            alpha_pred = ray_tracer.call(kappa_true)
            lens_true = phys.lens_source_func(kappa_true, w=args.source_w)
            lens_pred = phys.lens_source_func_given_alpha(alpha_pred, w=args.source_w)
            val_chi_squared = tf.reduce_mean(tf.square(lens_true - lens_pred))
            for res_idx in range(min(args.n_residuals, args.batch_size)):  # Residuals in val set
                tf.summary.image(f"Val Residuals {res_idx}",
                                 plot_to_image(residual_plot(alpha_true[res_idx], alpha_pred[res_idx],
                                                             lens_true[res_idx], lens_pred[res_idx])),
                                 step=step)
            train_cost = epoch_loss.result().numpy()
            val_cost = val_loss.result().numpy()
            tf.summary.scalar("Time per step", time_per_step.result().numpy(), step=step)
            tf.summary.scalar("Train lens residual", train_chi_squared, step=step)
            tf.summary.scalar("Val lens residual", val_chi_squared, step=step)
            tf.summary.scalar("MSE", train_cost, step=step)
            tf.summary.scalar("Val MSE", val_cost, step=step)
            tf.summary.scalar("Learning rate", optim.lr(step), step=step)
            print(f"epoch {epoch} | train loss {train_cost:.3e} | val loss {val_cost:.3e} | "
                  f"learning rate {optim.lr(step).numpy():.2e} | time per step {time_per_step.result():.2e} s")
            history["train_cost"].append(train_cost)
            history["train_lens_residuals"].append(train_chi_squared.numpy())
            history["val_cost"].append(val_cost)
            history["val_lens_residuals"].append(val_chi_squared.numpy())
            history["learning_rate"].append(optim.lr(step).numpy())
            history["step"].append(step)
            history["wall_time"].append(time.time() - global_start)
        cost = train_cost if args.track_train else val_cost
        if np.isnan(cost):
            print("Training broke the Universe")
            break
        if cost < best_loss * (1 - args.tolerance):
            best_loss = cost
            patience = args.patience
        else:
            patience -= 1
        if (time.time() - global_start) > args.max_time * 3600:
            out_of_time = True
        if save_checkpoint:
            checkpoint_manager.checkpoint.step.assign_add(1)  # a bit of a hack
            if epoch % args.checkpoints == 0 or patience == 0 or epoch == args.epochs or out_of_time:
                with open(os.path.join(checkpoints_dir, "score_sheet.txt"), mode="a") as f:
                    np.savetxt(f, np.array([[lastest_checkpoint, cost]]))
                lastest_checkpoint += 1
                checkpoint_manager.save()
                print("Saved checkpoint for step {}: {}".format(
                    int(checkpoint_manager.checkpoint.step), checkpoint_manager.latest_checkpoint))
        if patience == 0:
            print("Reached patience")
            break
        if out_of_time:
            break
        if epoch > 1:  # First epoch is always very slow and not a good estimate of an epoch time.
            estimated_time_for_epoch = time.time() - epoch_start
        if optim.lr(step).numpy() < 1e-9:
            print("Reached learning rate limit")
            break
    return history, best_loss
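
# Illustrative usage sketch (an assumption, not part of the original script): the history dict returned
# by main() could be dumped to disk for later inspection, e.g.
#   history, best_loss = main(args)
#   with open("history.json", "w") as f:
#       json.dump({k: [float(v) for v in values] for k, values in history.items()}, f, indent=4)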