Example #1
def test_lagrange_multiplier_for_lens_intensity():
    phys = PhysicalModel(pixels=128)
    phys_a = AnalyticalPhysicalModel(pixels=128)
    kappa = phys_a.kappa_field(2.0, e=0.2)
    x = np.linspace(-1, 1, 128) * phys.src_fov / 2
    xx, yy = np.meshgrid(x, x)
    rho = xx**2 + yy**2
    source = tf.math.exp(-0.5 * rho / 0.5**2)[tf.newaxis, ..., tf.newaxis]
    source = tf.cast(source, tf.float32)

    y_true = phys.forward(source, kappa)
    y_pred = phys.forward(0.001 * source,
                          kappa)  # rescale it, say it has different units
    # closed-form flux rescaling that minimizes the chi squared with respect to the multiplier
    lam_lagrange = tf.reduce_sum(y_true * y_pred, axis=(
        1, 2, 3)) / tf.reduce_sum(y_pred**2, axis=(1, 2, 3))
    lam_tests = tf.squeeze(
        tf.cast(tf.linspace(lam_lagrange / 10, lam_lagrange * 10, 1000),
                tf.float32))[..., tf.newaxis, tf.newaxis, tf.newaxis]
    log_likelihood_best = 0.5 * tf.reduce_mean(
        (lam_lagrange * y_pred - y_true)**2 / phys.noise_rms**2,
        axis=(1, 2, 3))
    log_likelihood_test = 0.5 * tf.reduce_mean(
        (lam_tests * y_pred - y_true)**2 / phys.noise_rms**2, axis=(1, 2, 3))
    return log_likelihood_test, log_likelihood_best, tf.squeeze(
        lam_tests), lam_lagrange
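The test above relies on the closed-form flux rescaling lambda = sum(y_true * y_pred) / sum(y_pred**2), which is the minimizer of the quadratic cost 0.5 * sum((lambda * y_pred - y_true)**2) / noise_rms**2 with respect to lambda. A minimal NumPy sketch (independent of PhysicalModel, with made-up arrays) that checks this property numerically:

import numpy as np

rng = np.random.default_rng(0)
y_true = rng.normal(size=(8, 8))                                # stand-in for the observed lens
y_pred = 0.001 * y_true + rng.normal(scale=1e-4, size=(8, 8))   # mis-scaled model prediction

# closed-form minimizer of 0.5 * sum((lam * y_pred - y_true)**2)
lam_star = np.sum(y_true * y_pred) / np.sum(y_pred**2)

# a brute-force scan around lam_star should not find a lower cost
lams = np.linspace(lam_star / 10, lam_star * 10, 1000)
costs = 0.5 * np.mean((lams[:, None, None] * y_pred - y_true)**2, axis=(1, 2))
assert np.isclose(lams[np.argmin(costs)], lam_star, rtol=1e-2)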
Example #2
def test_log_likelihood():
    phys = PhysicalModel(pixels=64, src_pixels=64)
    kappa = tf.random.normal([1, 64, 64, 1])
    source = tf.random.normal([1, 64, 64, 1])
    im_lensed = phys.forward(source, kappa)
    assert im_lensed.shape == [1, 64, 64, 1]
    # smoke test: the likelihood should evaluate without error on random inputs
    cost = phys.log_likelihood(source, kappa, im_lensed)
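The test above is only a smoke test. For reference, a plain NumPy version of the Gaussian data term that example #1 writes out explicitly (the actual normalization and signature of PhysicalModel.log_likelihood may differ):

import numpy as np

def gaussian_chi_squared(y_true, y_pred, noise_rms):
    # 0.5 * noise-weighted squared residuals, summed per batch element
    return 0.5 * np.sum((y_pred - y_true)**2 / noise_rms**2, axis=(1, 2, 3))

y_true = np.random.normal(size=(1, 64, 64, 1)).astype(np.float32)
y_pred = y_true + 0.01 * np.random.normal(size=y_true.shape).astype(np.float32)
print(gaussian_chi_squared(y_true, y_pred, noise_rms=0.01))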
Example #3
def main(args):
    if args.seed is not None:
        tf.random.set_seed(args.seed)
        np.random.seed(args.seed)
    if args.json_override is not None:
        if isinstance(args.json_override, list):
            files = args.json_override
        else:
            files = [args.json_override,]
        for file in files:
            with open(file, "r") as f:
                json_override = json.load(f)
            args_dict = vars(args)
            args_dict.update(json_override)

    files = []
    for dataset in args.datasets:
        files.extend(glob.glob(os.path.join(dataset, "*.tfrecords")))
    np.random.shuffle(files)

    if args.val_datasets is not None:
        """    
        In this conditional, we assume total items might be a subset of the dataset size.
        Thus we want to reshuffle at each epoch to get a different realisation of the dataset. 
        In case total_items == true dataset size, this means we only change ordering of items each epochs.
        
        Also, validation is not a split of the training data, but a saved dataset on disk. 
        """
        files = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files), reshuffle_each_iteration=True)
        dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
                                   block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE)
        # Read off global parameters from first example in dataset
        for physical_params in dataset.map(decode_physical_model_info):
            break
        # preprocessing
        dataset = dataset.map(decode_train)
        if args.cache_file is not None:
            dataset = dataset.cache(args.cache_file)
        train_dataset = dataset.shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).take(args.total_items).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)

        val_files = []
        for dataset in args.val_datasets:
            val_files.extend(glob.glob(os.path.join(dataset, "*.tfrecords")))
        val_files = tf.data.Dataset.from_tensor_slices(val_files).shuffle(len(val_files), reshuffle_each_iteration=True)
        val_dataset = val_files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type), block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE)
        val_dataset = val_dataset.map(decode_train).shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).\
            take(math.ceil((1 - args.train_split) * args.total_items)).\
            batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    else:
        """
        Here, we split the dataset, so we assume total_items is the true dataset size. Any extra items will be discarded. 
        This is to make sure validation set is never seen by the model, so shuffling occurs after the split.
        """
        files = tf.data.Dataset.from_tensor_slices(files)
        dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
                                   block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE)
        # Read off global parameters from first example in dataset
        for physical_params in dataset.map(decode_physical_model_info):
            break
        # preprocessing
        dataset = dataset.map(decode_train)
        if args.cache_file is not None:
            dataset = dataset.cache(args.cache_file)
        train_dataset = dataset.take(math.floor(args.train_split * args.total_items)).shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
        val_dataset = dataset.skip(math.floor(args.train_split * args.total_items)).take(math.ceil((1 - args.train_split) * args.total_items)).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    train_dataset = STRATEGY.experimental_distribute_dataset(train_dataset)
    val_dataset = STRATEGY.experimental_distribute_dataset(val_dataset)
    with STRATEGY.scope():  # Replicate ops across GPUs
        if args.raytracer is not None:
            with open(os.path.join(args.raytracer, "ray_tracer_hparams.json"), "r") as f:
                raytracer_hparams = json.load(f)
            raytracer = RayTracer(**raytracer_hparams)
            # load last checkpoint in the checkpoint directory
            checkpoint = tf.train.Checkpoint(net=raytracer)
            manager = tf.train.CheckpointManager(checkpoint, directory=args.raytracer, max_to_keep=3)
            checkpoint.restore(manager.latest_checkpoint).expect_partial()
        else:
            raytracer = None
        phys = PhysicalModel(
            pixels=physical_params["pixels"].numpy(),
            kappa_pixels=physical_params["kappa pixels"].numpy(),
            src_pixels=physical_params["src pixels"].numpy(),
            image_fov=physical_params["image fov"].numpy(),
            kappa_fov=physical_params["kappa fov"].numpy(),
            src_fov=physical_params["source fov"].numpy(),
            method=args.forward_method,
            raytracer=raytracer,
        )

        unet = Model(
            filters=args.filters,
            filter_scaling=args.filter_scaling,
            kernel_size=args.kernel_size,
            layers=args.layers,
            block_conv_layers=args.block_conv_layers,
            strides=args.strides,
            bottleneck_kernel_size=args.bottleneck_kernel_size,
            resampling_kernel_size=args.resampling_kernel_size,
            input_kernel_size=args.input_kernel_size,
            gru_kernel_size=args.gru_kernel_size,
            upsampling_interpolation=args.upsampling_interpolation,
            kernel_l2_amp=args.kernel_l2_amp,
            bias_l2_amp=args.bias_l2_amp,
            kernel_l1_amp=args.kernel_l1_amp,
            bias_l1_amp=args.bias_l1_amp,
            activation=args.activation,
            initializer=args.initializer,
            batch_norm=args.batch_norm,
            dropout_rate=args.dropout_rate,
            filter_cap=args.filter_cap
        )
        rim = RIM(
            physical_model=phys,
            unet=unet,
            steps=args.steps,
            adam=args.adam,
            kappalog=args.kappalog,
            source_link=args.source_link,
            kappa_normalize=args.kappa_normalize,
            flux_lagrange_multiplier=args.flux_lagrange_multiplier
        )
        learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=args.initial_learning_rate,
            decay_rate=args.decay_rate,
            decay_steps=args.decay_steps,
            staircase=args.staircase
        )
        optim = tf.keras.optimizers.deserialize(
            {
                "class_name": args.optimizer,
                'config': {"learning_rate": learning_rate_schedule}
            }
        )
        # weights for time steps in the loss function
        if args.time_weights == "uniform":
            wt = tf.ones(shape=(args.steps), dtype=DTYPE) / args.steps
        elif args.time_weights == "linear":
            wt = 2 * (tf.range(args.steps, dtype=DTYPE) + 1) / args.steps / (args.steps + 1)
        elif args.time_weights == "quadratic":
            wt = 6 * (tf.range(args.steps, dtype=DTYPE) + 1)**2 / args.steps / (args.steps + 1) / (2 * args.steps + 1)
        else:
            raise ValueError("time_weights must be in ['uniform', 'linear', 'quadratic']")
        wt = wt[..., tf.newaxis]  # shape [steps, 1], broadcasts over the batch dimension

        if args.kappa_residual_weights == "uniform":
            wk = tf.keras.layers.Lambda(lambda k: tf.ones_like(k, dtype=DTYPE) / tf.cast(tf.math.reduce_prod(k.shape[1:]), DTYPE))
        elif args.kappa_residual_weights == "linear":
            wk = tf.keras.layers.Lambda(lambda k: k / tf.reduce_sum(k, axis=(1, 2, 3), keepdims=True))
        elif args.kappa_residual_weights == "sqrt":
            wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
        elif args.kappa_residual_weights == "quadratic":
            wk = tf.keras.layers.Lambda(lambda k: tf.square(k) / tf.reduce_sum(tf.square(k), axis=(1, 2, 3), keepdims=True))
        else:
            raise ValueError("kappa_residual_weights must be in ['uniform', 'linear', 'quadratic', 'sqrt']")

        if args.source_residual_weights == "uniform":
            ws = tf.keras.layers.Lambda(lambda s: tf.ones_like(s, dtype=DTYPE) / tf.cast(tf.math.reduce_prod(s.shape[1:]), DTYPE))
        elif args.source_residual_weights == "linear":
            ws = tf.keras.layers.Lambda(lambda s: s / tf.reduce_sum(s, axis=(1, 2, 3), keepdims=True))
        elif args.source_residual_weights == "quadratic":
            ws = tf.keras.layers.Lambda(lambda s: tf.square(s) / tf.reduce_sum(tf.square(s), axis=(1, 2, 3), keepdims=True))
        elif args.source_residual_weights == "sqrt":
            ws = tf.keras.layers.Lambda(lambda s: tf.sqrt(s) / tf.reduce_sum(tf.sqrt(s), axis=(1, 2, 3), keepdims=True))
        else:
            raise ValueError("kappa_residual_weights must be in ['uniform', 'linear', 'quadratic', 'sqrt']")

    # ==== Take care of where to write logs and stuff =================================================================
    if args.model_id.lower() != "none":
        if args.logname is not None:
            logname = args.model_id + "_" + args.logname
            model_id = args.model_id
        else:
            logname = args.model_id + "_" + datetime.now().strftime("%y%m%d%H%M%S")
            model_id = args.model_id
    elif args.logname is not None:
        logname = args.logname
        model_id = logname
    else:
        logname = args.logname_prefixe + "_" + datetime.now().strftime("%y%m%d%H%M%S")
        model_id = logname
    if args.logdir.lower() != "none":
        logdir = os.path.join(args.logdir, logname)
        if not os.path.isdir(logdir):
            os.mkdir(logdir)
        writer = tf.summary.create_file_writer(logdir)
    else:
        writer = nullwriter()
    # ===== Make sure directory and checkpoint manager are created to save model ===================================
    if args.model_dir.lower() != "none":
        checkpoints_dir = os.path.join(args.model_dir, logname)
        old_checkpoints_dir = os.path.join(args.model_dir, model_id)  # in case they differ we load model from a different directory
        if not os.path.isdir(checkpoints_dir):
            os.mkdir(checkpoints_dir)
            with open(os.path.join(checkpoints_dir, "script_params.json"), "w") as f:
                json.dump(vars(args), f, indent=4)
            with open(os.path.join(checkpoints_dir, "unet_hparams.json"), "w") as f:
                hparams_dict = {key: vars(args)[key] for key in UNET_MODEL_HPARAMS}
                json.dump(hparams_dict, f, indent=4)
            with open(os.path.join(checkpoints_dir, "rim_hparams.json"), "w") as f:
                hparams_dict = {key: vars(args)[key] for key in RIM_HPARAMS}
                json.dump(hparams_dict, f, indent=4)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=rim.unet)
        checkpoint_manager = tf.train.CheckpointManager(ckpt, old_checkpoints_dir, max_to_keep=args.max_to_keep)
        save_checkpoint = True
        # ======= Load model if model_id is provided ===============================================================
        if args.model_id.lower() != "none":
            checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint)
        if old_checkpoints_dir != checkpoints_dir:  # save progress in another directory.
            if args.reset_optimizer_states:
                optim = tf.keras.optimizers.deserialize(
                    {
                        "class_name": args.optimizer,
                        'config': {"learning_rate": learning_rate_schedule}
                    }
                )
                ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=rim.unet)
            checkpoint_manager = tf.train.CheckpointManager(ckpt, checkpoints_dir, max_to_keep=args.max_to_keep)

    else:
        save_checkpoint = False
    # =================================================================================================================

    def train_step(X, source, kappa, noise_rms, psf):
        with tf.GradientTape() as tape:
            tape.watch(rim.unet.trainable_variables)
            if args.unroll_time_steps:
                source_series, kappa_series, chi_squared = rim.call_function(X, noise_rms, psf)
            else:
                source_series, kappa_series, chi_squared = rim.call(X, noise_rms, psf, outer_tape=tape)
            # mean over image residuals (in model prediction space)
            source_cost1 = tf.reduce_sum(ws(source) * tf.square(source_series - rim.source_inverse_link(source)), axis=(2, 3, 4))
            kappa_cost1 = tf.reduce_sum(wk(kappa) * tf.square(kappa_series - rim.kappa_inverse_link(kappa)), axis=(2, 3, 4))
            # weighted mean over time steps
            source_cost = tf.reduce_sum(wt * source_cost1, axis=0)
            kappa_cost = tf.reduce_sum(wt * kappa_cost1, axis=0)
            # final cost is mean over global batch size
            cost = tf.reduce_sum(kappa_cost + source_cost) / args.batch_size
        gradient = tape.gradient(cost, rim.unet.trainable_variables)
        gradient = [tf.clip_by_norm(grad, 5.) for grad in gradient]
        optim.apply_gradients(zip(gradient, rim.unet.trainable_variables))
        # Update metrics with "converged" score
        chi_squared = tf.reduce_sum(chi_squared[-1]) / args.batch_size
        source_cost = tf.reduce_sum(source_cost1[-1]) / args.batch_size
        kappa_cost = tf.reduce_sum(kappa_cost1[-1]) / args.batch_size
        return cost, chi_squared, source_cost, kappa_cost

    @tf.function
    def distributed_train_step(X, source, kappa, noise_rms, psf):
        per_replica_losses, per_replica_chi_squared, per_replica_source_cost, per_replica_kappa_cost = STRATEGY.run(train_step, args=(X, source, kappa, noise_rms, psf))
        # Replica losses are aggregated by summing them
        global_loss = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        global_chi_squared = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_chi_squared, axis=None)
        global_source_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_source_cost, axis=None)
        global_kappa_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_kappa_cost, axis=None)
        return global_loss, global_chi_squared, global_source_cost, global_kappa_cost

    def test_step(X, source, kappa, noise_rms, psf):
        source_series, kappa_series, chi_squared = rim.call(X, noise_rms, psf)
        # mean over image residuals (in model prediction space)
        source_cost1 = tf.reduce_sum(ws(source) * tf.square(source_series - rim.source_inverse_link(source)), axis=(2, 3, 4))
        kappa_cost1 = tf.reduce_sum(wk(kappa) * tf.square(kappa_series - rim.kappa_inverse_link(kappa)), axis=(2, 3, 4))
        # weighted mean over time steps
        source_cost = tf.reduce_sum(wt * source_cost1, axis=0)
        kappa_cost = tf.reduce_sum(wt * kappa_cost1, axis=0)
        # final cost is mean over global batch size
        cost = tf.reduce_sum(kappa_cost + source_cost) / args.batch_size
        # Update metrics with "converged" score
        chi_squared = tf.reduce_sum(chi_squared[-1]) / args.batch_size
        source_cost = tf.reduce_sum(source_cost1[-1]) / args.batch_size
        kappa_cost = tf.reduce_sum(kappa_cost1[-1]) / args.batch_size
        return cost, chi_squared, source_cost, kappa_cost

    @tf.function
    def distributed_test_step(X, source, kappa, noise_rms, psf):
        per_replica_losses, per_replica_chi_squared, per_replica_source_cost, per_replica_kappa_cost = STRATEGY.run(test_step, args=(X, source, kappa, noise_rms, psf))
        # Replica losses are aggregated by summing them
        global_loss = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        global_chi_squared = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_chi_squared, axis=None)
        global_source_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_source_cost, axis=None)
        global_kappa_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_kappa_cost, axis=None)
        return global_loss, global_chi_squared, global_source_cost, global_kappa_cost

    # ====== Training loop ============================================================================================
    epoch_loss = tf.metrics.Mean()
    time_per_step = tf.metrics.Mean()
    val_loss = tf.metrics.Mean()
    epoch_chi_squared = tf.metrics.Mean()
    epoch_source_loss = tf.metrics.Mean()
    epoch_kappa_loss = tf.metrics.Mean()
    val_chi_squared = tf.metrics.Mean()
    val_source_loss = tf.metrics.Mean()
    val_kappa_loss = tf.metrics.Mean()
    history = {  # recorded at the end of an epoch only
        "train_cost": [],
        "train_chi_squared": [],
        "train_source_cost": [],
        "train_kappa_cost": [],
        "val_cost": [],
        "val_chi_squared": [],
        "val_source_cost": [],
        "val_kappa_cost": [],
        "learning_rate": [],
        "time_per_step": [],
        "step": [],
        "wall_time": []
    }
    best_loss = np.inf
    patience = args.patience
    step = 0
    global_start = time.time()
    estimated_time_for_epoch = 0
    out_of_time = False
    latest_checkpoint = 1
    for epoch in range(args.epochs):
        if (time.time() - global_start) > args.max_time*3600 - estimated_time_for_epoch:
            break
        epoch_start = time.time()
        epoch_loss.reset_states()
        epoch_chi_squared.reset_states()
        epoch_source_loss.reset_states()
        epoch_kappa_loss.reset_states()
        time_per_step.reset_states()
        with writer.as_default():
            for batch, (X, source, kappa, noise_rms, psf) in enumerate(train_dataset):
                start = time.time()
                cost, chi_squared, source_cost, kappa_cost = distributed_train_step(X, source, kappa, noise_rms, psf)
                # ========== Summary and logs ==========================================================================
                _time = time.time() - start
                time_per_step.update_state([_time])
                epoch_loss.update_state([cost])
                epoch_chi_squared.update_state([chi_squared])
                epoch_source_loss.update_state([source_cost])
                epoch_kappa_loss.update_state([kappa_cost])
                step += 1
            # last batch we make a summary of residuals
            if args.n_residuals > 0:
                source_pred, kappa_pred, chi_squared = rim.predict(X, noise_rms, psf)
                lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
            for res_idx in range(min(args.n_residuals, args.batch_size)):
                try:
                    tf.summary.image(f"Residuals {res_idx}",
                                     plot_to_image(
                                         residual_plot(
                                             X[res_idx],
                                             source[res_idx],
                                             kappa[res_idx],
                                             lens_pred[res_idx],
                                             source_pred[-1][res_idx],
                                             kappa_pred[-1][res_idx],
                                             chi_squared[-1][res_idx]
                                         )), step=step)
                except ValueError:
                    continue

            # ========== Validation set ===================
            val_loss.reset_states()
            val_chi_squared.reset_states()
            val_source_loss.reset_states()
            val_kappa_loss.reset_states()
            for X, source, kappa, noise_rms, psf in val_dataset:
                cost, chi_squared, source_cost, kappa_cost = distributed_test_step(X, source, kappa, noise_rms, psf)
                val_loss.update_state([cost])
                val_chi_squared.update_state([chi_squared])
                val_source_loss.update_state([source_cost])
                val_kappa_loss.update_state([kappa_cost])

            if args.n_residuals > 0 and math.ceil((1 - args.train_split) * args.total_items) > 0:  # validation set not empty
                source_pred, kappa_pred, chi_squared = rim.predict(X, noise_rms, psf)
                lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
            for res_idx in range(min(args.n_residuals, args.batch_size, math.ceil((1 - args.train_split) * args.total_items))):
                try:
                    tf.summary.image(f"Val Residuals {res_idx}",
                                     plot_to_image(
                                         residual_plot(
                                             X[res_idx],  # rescale intensity like it is done in the likelihood
                                             source[res_idx],
                                             kappa[res_idx],
                                             lens_pred[res_idx],
                                             source_pred[-1][res_idx],
                                             kappa_pred[-1][res_idx],
                                             chi_squared[-1][res_idx]
                                         )), step=step)
                except ValueError:
                    continue
            val_cost = val_loss.result().numpy()
            train_cost = epoch_loss.result().numpy()
            val_chi_sq = val_chi_squared.result().numpy()
            train_chi_sq = epoch_chi_squared.result().numpy()
            val_kappa_cost = val_kappa_loss.result().numpy()
            train_kappa_cost = epoch_kappa_loss.result().numpy()
            val_source_cost = val_source_loss.result().numpy()
            train_source_cost = epoch_source_loss.result().numpy()
            tf.summary.scalar("Time per step", time_per_step.result(), step=step)
            tf.summary.scalar("Chi Squared", train_chi_sq, step=step)
            tf.summary.scalar("Kappa cost", train_kappa_cost, step=step)
            tf.summary.scalar("Val Kappa cost", val_kappa_cost, step=step)
            tf.summary.scalar("Source cost", train_source_cost, step=step)
            tf.summary.scalar("Val Source cost", val_source_cost, step=step)
            tf.summary.scalar("MSE", train_cost, step=step)
            tf.summary.scalar("Val MSE", val_cost, step=step)
            tf.summary.scalar("Learning Rate", optim.lr(step), step=step)
            tf.summary.scalar("Val Chi Squared", val_chi_sq, step=step)
        print(f"epoch {epoch} | train loss {train_cost:.3e} | val loss {val_cost:.3e} "
              f"| lr {optim.lr(step).numpy():.2e} | time per step {time_per_step.result().numpy():.2e} s"
              f"| kappa cost {train_kappa_cost:.2e} | source cost {train_source_cost:.2e} | chi sq {train_chi_sq:.2e}")
        history["train_cost"].append(train_cost)
        history["val_cost"].append(val_cost)
        history["learning_rate"].append(optim.lr(step).numpy())
        history["train_chi_squared"].append(train_chi_sq)
        history["val_chi_squared"].append(val_chi_sq)
        history["time_per_step"].append(time_per_step.result().numpy())
        history["train_kappa_cost"].append(train_kappa_cost)
        history["train_source_cost"].append(train_source_cost)
        history["val_kappa_cost"].append(val_kappa_cost)
        history["val_source_cost"].append(val_source_cost)
        history["step"].append(step)
        history["wall_time"].append(time.time() - global_start)

        cost = train_cost if args.track_train else val_cost
        if np.isnan(cost):
            print("Training broke the Universe")
            break
        if cost < (1 - args.tolerance) * best_loss:
            best_loss = cost
            patience = args.patience
        else:
            patience -= 1
        if (time.time() - global_start) > args.max_time * 3600:
            out_of_time = True
        if save_checkpoint:
            checkpoint_manager.checkpoint.step.assign_add(1) # a bit of a hack
            if epoch % args.checkpoints == 0 or patience == 0 or epoch == args.epochs - 1 or out_of_time:
                with open(os.path.join(checkpoints_dir, "score_sheet.txt"), mode="a") as f:
                    np.savetxt(f, np.array([[latest_checkpoint, cost]]))
                latest_checkpoint += 1
                checkpoint_manager.save()
                print("Saved checkpoint for step {}: {}".format(int(checkpoint_manager.checkpoint.step), checkpoint_manager.latest_checkpoint))
        if patience == 0:
            print("Reached patience")
            break
        if out_of_time:
            break
        if epoch > 0:  # First epoch is always very slow and not a good estimate of an epoch time.
            estimated_time_for_epoch = time.time() - epoch_start
        if optim.lr(step).numpy() < 1e-8:
            print("Reached learning rate limit")
            break
    print(f"Finished training after {(time.time() - global_start)/3600:.3f} hours.")
    return history, best_loss
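In the training script above, the per-time-step weights wt are normalized so that each scheme sums to one over the RIM steps: uniform gives 1/T, linear gives 2t / (T(T+1)), and quadratic gives 6t**2 / (T(T+1)(2T+1)) for t = 1..T. A stand-alone check of that normalization (pure TensorFlow; DTYPE is assumed to be tf.float32):

import tensorflow as tf

DTYPE = tf.float32
steps = 8
t = tf.range(steps, dtype=DTYPE) + 1
uniform = tf.ones([steps], dtype=DTYPE) / steps
linear = 2 * t / steps / (steps + 1)
quadratic = 6 * t**2 / steps / (steps + 1) / (2 * steps + 1)
for wt in (uniform, linear, quadratic):
    tf.debugging.assert_near(tf.reduce_sum(wt), 1.0)  # each scheme sums to 1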
Example #4
def main(args):
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type),
                               block_length=1,
                               num_parallel_calls=tf.data.AUTOTUNE)
    for physical_params in dataset.map(decode_physical_model_info):
        break
    dataset = dataset.map(decode_train)

    # files = glob.glob(os.path.join(args.source_dataset, "*.tfrecords"))
    # files = tf.data.Dataset.from_tensor_slices(files)
    # source_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
    #                            block_length=1, num_parallel_calls=tf.data.AUTOTUNE)
    # source_dataset = source_dataset.map(decode_image).map(preprocess_image).shuffle(10000).batch(args.sample_size)

    with open(os.path.join(args.kappa_vae, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, args.kappa_vae, 1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    with open(os.path.join(args.source_vae, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, args.source_vae, 1)
    checkpoint_manager2.checkpoint.restore(
        checkpoint_manager2.latest_checkpoint).expect_partial()

    phys = PhysicalModel(pixels=physical_params["pixels"].numpy(),
                         kappa_pixels=physical_params["kappa pixels"].numpy(),
                         src_pixels=physical_params["src pixels"].numpy(),
                         image_fov=physical_params["image fov"].numpy(),
                         kappa_fov=physical_params["kappa fov"].numpy(),
                         src_fov=physical_params["source fov"].numpy(),
                         method="fft")

    # simulate observations
    kappa = 10**kappa_vae.sample(args.sample_size)
    source = preprocess_image(source_vae.sample(args.sample_size))
    # for source in source_dataset:
    #     break
    fwhm = tf.random.normal(shape=[args.sample_size],
                            mean=1.5 * phys.image_fov / phys.pixels,
                            stddev=0.5 * phys.image_fov / phys.pixels)
    # noise_rms = tf.random.normal(shape=[args.sample_size], mean=args.noise_mean, stddev=args.noise_std)
    psf = phys.psf_models(fwhm, cutout_size=20)
    y_vae = phys.forward(source, kappa, psf)

    with h5py.File(
            os.path.join(os.getenv("CENSAI_PATH"), "results",
                         args.output_name + ".h5"), 'w') as hf:
        # rank these observations against the dataset with L2 norm
        for i in tqdm(range(args.sample_size)):
            distances = []
            for y_d, _, _, _, _ in dataset:
                distances.append(
                    tf.sqrt(tf.reduce_sum(
                        (y_d - y_vae[i][None, ...])**2)).numpy().astype(
                            np.float32))
            k_indices = np.argsort(distances)[:args.k]

            # save results
            g = hf.create_group(f"sample_{i:02d}")
            g.create_dataset(name="matched_source",
                             shape=[args.k, phys.src_pixels, phys.src_pixels],
                             dtype=np.float32)
            g.create_dataset(
                name="matched_kappa",
                shape=[args.k, phys.kappa_pixels, phys.kappa_pixels],
                dtype=np.float32)
            g.create_dataset(name="matched_obs",
                             shape=[args.k, phys.pixels, phys.pixels],
                             dtype=np.float32)
            g.create_dataset(name="matched_psf",
                             shape=[args.k, 20, 20],
                             dtype=np.float32)
            g.create_dataset(name="matched_noise_rms",
                             shape=[args.k],
                             dtype=np.float32)
            g.create_dataset(name="obs_L2_distance",
                             shape=[args.k],
                             dtype=np.float32)
            g["vae_source"] = source[i, ..., 0].numpy().astype(np.float32)
            g["vae_kappa"] = kappa[i, ..., 0].numpy().astype(np.float32)
            g["vae_obs"] = y_vae[i, ..., 0].numpy().astype(np.float32)
            g["vae_psf"] = psf[i, ..., 0].numpy().astype(np.float32)

            for rank, j in enumerate(k_indices):
                # fetch back the matched observation
                for y_d, source_d, kappa_d, noise_rms_d, psf_d in dataset.skip(
                        j):
                    break
                # g["vae_noise_rms"] = noise_rms[i].numpy().astype(np.float32)
                g["matched_source"][rank] = source_d[..., 0].numpy().astype(
                    np.float32)
                g["matched_kappa"][rank] = kappa_d[..., 0].numpy().astype(
                    np.float32)
                g["matched_obs"][rank] = y_d[..., 0].numpy().astype(np.float32)
                g["matched_noise_rms"][rank] = noise_rms_d.numpy().astype(
                    np.float32)
                g["matched_psf"][rank] = psf_d[...,
                                               0].numpy().astype(np.float32)
                g["obs_L2_distance"][rank] = distances[j]
Example #5
def distributed_strategy(args):
    psf_pixels = 20
    pixels = 128
    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)

    ps_observation = PowerSpectrum(bins=args.observation_coherence_bins,
                                   pixels=pixels)
    ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=pixels)
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=pixels)

    phys = PhysicalModel(
        pixels=pixels,
        kappa_pixels=pixels,
        src_pixels=pixels,
        image_fov=7.68,
        kappa_fov=7.68,
        src_fov=3.,
        method="fft",
    )

    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(
        checkpoint_manager.latest_checkpoint).expect_partial()
    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim_params["source_link"] = "relu"
    rim = RIM(phys, unet, **rim_params)

    kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.kappa_vae)
    with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.source_vae)
    with open(os.path.join(svae_path, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1)
    checkpoint_manager2.checkpoint.restore(
        checkpoint_manager2.latest_checkpoint).expect_partial()

    model_name = os.path.split(model)[-1]
    wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(
        tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
    with h5py.File(
            os.path.join(
                os.getenv("CENSAI_PATH"), "results", args.experiment_name +
                "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf:
        data_len = args.size // N_WORKERS
        hf.create_dataset(name="observation",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf",
                          shape=[data_len, psf_pixels, psf_pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        hf.create_dataset(
            name="source",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="kappa",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="observation_pred",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="observation_pred_reoptimized",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(
            name="source_pred",
            shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="source_pred_reoptimized",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="kappa_pred",
                          shape=[
                              data_len, rim.steps, phys.kappa_pixels,
                              phys.kappa_pixels, 1
                          ],
                          dtype=np.float32)
        hf.create_dataset(
            name="kappa_pred_reoptimized",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="chi_squared",
                          shape=[data_len, rim.steps],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized",
                          shape=[data_len, rim.steps],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized_series",
                          shape=[data_len, rim.steps, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="sampled_chi_squared_reoptimized_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="sampled_source_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="sampled_kappa_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_init",
                          shape=[data_len, kappa_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_init",
                          shape=[data_len, source_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_end",
                          shape=[data_len, kappa_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_end",
                          shape=[data_len, source_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum_reoptimized",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum_reoptimized",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum_reoptimized",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_frequencies",
                          shape=[args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_frequencies",
                          shape=[args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_frequencies",
                          shape=[args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32)
        for i in range(data_len):
            # reset model weights
            checkpoint_manager.checkpoint.restore(
                checkpoint_manager.latest_checkpoint).expect_partial()

            # Produce an observation
            kappa = 10**kappa_vae.sample(1)
            source = tf.nn.relu(source_vae.sample(1))
            source /= tf.reduce_max(source, axis=(1, 2, 3), keepdims=True)
            noise_rms = 10**tf.random.uniform(shape=[1],
                                              minval=-2.5,
                                              maxval=-1)
            fwhm = tf.random.uniform(shape=[1], minval=0.06, maxval=0.3)
            psf = phys.psf_models(fwhm, cutout_size=psf_pixels)
            observation = phys.noisy_forward(source, kappa, noise_rms, psf)

            # RIM predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(
                observation, noise_rms, psf)
            observation_pred = phys.forward(source_pred[-1], kappa_pred[-1],
                                            psf)
            source_o = source_pred[-1]
            kappa_o = kappa_pred[-1]

            # Latent code of model predictions
            z_source, _ = source_vae.encoder(source_o)
            z_kappa, _ = kappa_vae.encoder(log_10(kappa_o))

            # Ground truth latent code for oracle metrics
            z_source_gt, _ = source_vae.encoder(source)
            z_kappa_gt, _ = kappa_vae.encoder(log_10(kappa))

            # Re-optimize weights of the model
            STEPS = args.re_optimize_steps
            learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=args.learning_rate,
                decay_rate=args.decay_rate,
                decay_steps=args.decay_steps,
                staircase=args.staircase)
            optim = tf.keras.optimizers.RMSprop(
                learning_rate=learning_rate_schedule)

            chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            source_mse = tf.TensorArray(DTYPE, size=STEPS)
            kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            sampled_source_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_kappa_mse = tf.TensorArray(DTYPE, size=STEPS)

            best = chi_squared
            source_best = source_pred[-1]
            kappa_best = kappa_pred[-1]
            source_mse_best = tf.reduce_mean((source_best - source)**2)
            kappa_mse_best = tf.reduce_mean((kappa_best - log_10(kappa))**2)

            # ===================== Optimization ==============================
            for current_step in tqdm(range(STEPS)):
                # ===================== VAE SAMPLING ==============================

                # L1 distance with ground truth in latent space -- replaced by a user-defined value when using real data
                # z_source_std = tf.abs(z_source - z_source_gt)
                # z_kappa_std = tf.abs(z_kappa - z_kappa_gt)
                z_source_std = args.source_vae_ball_size
                z_kappa_std = args.kappa_vae_ball_size

                # Sample latent code, then decode and forward
                z_s = tf.random.normal(
                    shape=[args.sample_size, source_vae.latent_size],
                    mean=z_source,
                    stddev=z_source_std)
                z_k = tf.random.normal(
                    shape=[args.sample_size, kappa_vae.latent_size],
                    mean=z_kappa,
                    stddev=z_kappa_std)
                sampled_source = tf.nn.relu(source_vae.decode(z_s))
                sampled_source /= tf.reduce_max(sampled_source,
                                                axis=(1, 2, 3),
                                                keepdims=True)
                sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
                sampled_observation = phys.noisy_forward(
                    sampled_source, 10**sampled_kappa, noise_rms,
                    tf.tile(psf, [args.sample_size, 1, 1, 1]))
                with tf.GradientTape() as tape:
                    tape.watch(unet.trainable_variables)
                    s, k, chi_sq = rim.call(
                        sampled_observation,
                        noise_rms,
                        tf.tile(psf, [args.sample_size, 1, 1, 1]),
                        outer_tape=tape)
                    _kappa_mse = tf.reduce_sum(wk(10**sampled_kappa) *
                                               (k - sampled_kappa)**2,
                                               axis=(2, 3, 4))
                    cost = tf.reduce_mean(_kappa_mse)
                    cost += tf.reduce_mean((s - sampled_source)**2)
                    cost += tf.reduce_sum(rim.unet.losses)  # weight decay

                grads = tape.gradient(cost, unet.trainable_variables)
                optim.apply_gradients(zip(grads, unet.trainable_variables))

                # Record performance on sampled dataset
                sampled_chi_squared_series = sampled_chi_squared_series.write(
                    index=current_step,
                    value=tf.squeeze(tf.reduce_mean(chi_sq[-1])))
                sampled_source_mse = sampled_source_mse.write(
                    index=current_step,
                    value=tf.reduce_mean((s[-1] - sampled_source)**2))
                sampled_kappa_mse = sampled_kappa_mse.write(
                    index=current_step,
                    value=tf.reduce_mean((k[-1] - sampled_kappa)**2))
                # Record model prediction on data
                s, k, chi_sq = rim.call(observation, noise_rms, psf)
                chi_squared_series = chi_squared_series.write(
                    index=current_step, value=tf.squeeze(chi_sq))
                source_o = s[-1]
                kappa_o = k[-1]
                # oracle metrics, remove when using real data
                source_mse = source_mse.write(index=current_step,
                                              value=tf.reduce_mean(
                                                  (source_o - source)**2))
                kappa_mse = kappa_mse.write(index=current_step,
                                            value=tf.reduce_mean(
                                                (kappa_o - log_10(kappa))**2))

                if abs(chi_sq[-1, 0] - 1) < abs(best[-1, 0] - 1):
                    source_best = tf.nn.relu(source_o)
                    kappa_best = 10**kappa_o
                    best = chi_sq
                    source_mse_best = tf.reduce_mean((source_best - source)**2)
                    kappa_mse_best = tf.reduce_mean(
                        (kappa_best - log_10(kappa))**2)

            source_o = source_best
            kappa_o = kappa_best
            y_pred = phys.forward(source_o, kappa_o, psf)

            chi_sq_series = tf.transpose(chi_squared_series.stack())
            source_mse = source_mse.stack()
            kappa_mse = kappa_mse.stack()
            sampled_chi_squared_series = sampled_chi_squared_series.stack()
            sampled_source_mse = sampled_source_mse.stack()
            sampled_kappa_mse = sampled_kappa_mse.stack()

            # Latent code of optimized model predictions
            z_source_opt, _ = source_vae.encoder(tf.nn.relu(source_o))
            z_kappa_opt, _ = kappa_vae.encoder(log_10(kappa_o))

            # Compute Power spectrum of converged predictions
            _ps_observation = ps_observation.cross_correlation_coefficient(
                observation[..., 0], observation_pred[..., 0])
            _ps_observation2 = ps_observation.cross_correlation_coefficient(
                observation[..., 0], y_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0],
                log_10(kappa_pred[-1])[..., 0])
            _ps_kappa2 = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
            _ps_source = ps_source.cross_correlation_coefficient(
                source[..., 0], source_pred[-1][..., 0])
            _ps_source2 = ps_source.cross_correlation_coefficient(
                source[..., 0], source_o[..., 0])

            # save results
            hf["observation"][i] = observation.numpy().astype(np.float32)
            hf["psf"][i] = psf.numpy().astype(np.float32)
            hf["psf_fwhm"][i] = fwhm.numpy().astype(np.float32)
            hf["noise_rms"][i] = noise_rms.numpy().astype(np.float32)
            hf["source"][i] = source.numpy().astype(np.float32)
            hf["kappa"][i] = kappa.numpy().astype(np.float32)
            hf["observation_pred"][i] = observation_pred.numpy().astype(
                np.float32)
            hf["observation_pred_reoptimized"][i] = y_pred.numpy().astype(
                np.float32)
            hf["source_pred"][i] = tf.transpose(
                source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["source_pred_reoptimized"][i] = source_o.numpy().astype(
                np.float32)
            hf["kappa_pred"][i] = tf.transpose(
                kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["kappa_pred_reoptimized"][i] = kappa_o.numpy().astype(
                np.float32)
            hf["chi_squared"][i] = tf.squeeze(chi_squared).numpy().astype(
                np.float32)
            hf["chi_squared_reoptimized"][i] = tf.squeeze(best).numpy().astype(
                np.float32)
            hf["chi_squared_reoptimized_series"][i] = chi_sq_series.numpy(
            ).astype(np.float32)
            hf["sampled_chi_squared_reoptimized_series"][
                i] = 2 * sampled_chi_squared_series.numpy().astype(np.float32)
            hf["source_optim_mse"][i] = source_mse_best.numpy().astype(
                np.float32)
            hf["source_optim_mse_series"][i] = source_mse.numpy().astype(
                np.float32)
            hf["sampled_source_optim_mse_series"][
                i] = sampled_source_mse.numpy().astype(np.float32)
            hf["kappa_optim_mse"][i] = kappa_mse_best.numpy().astype(
                np.float32)
            hf["kappa_optim_mse_series"][i] = kappa_mse.numpy().astype(
                np.float32)
            hf["sampled_kappa_optim_mse_series"][i] = sampled_kappa_mse.numpy(
            ).astype(np.float32)
            hf["latent_source_gt_distance_init"][i] = tf.abs(
                z_source - z_source_gt).numpy().squeeze().astype(np.float32)
            hf["latent_kappa_gt_distance_init"][i] = tf.abs(
                z_kappa - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["latent_source_gt_distance_end"][i] = tf.abs(
                z_source_opt - z_source_gt).numpy().squeeze().astype(
                    np.float32)
            hf["latent_kappa_gt_distance_end"][i] = tf.abs(
                z_kappa_opt - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["observation_coherence_spectrum"][i] = _ps_observation
            hf["observation_coherence_spectrum_reoptimized"][
                i] = _ps_observation2
            hf["source_coherence_spectrum"][i] = _ps_source
            hf["source_coherence_spectrum_reoptimized"][i] = _ps_source2
            hf["kappa_coherence_spectrum"][i] = _ps_kappa
            hf["kappa_coherence_spectrum_reoptimized"][i] = _ps_kappa2

            if i == 0:
                _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels //
                                                                2],
                                    bins=ps_observation.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["observation_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.src_pixels)[:phys.src_pixels // 2],
                                    bins=ps_source.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["source_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.kappa_pixels)[:phys.kappa_pixels // 2],
                                    bins=ps_kappa.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["kappa_frequencies"][:] = f
                hf["kappa_fov"][0] = phys.kappa_fov
                hf["source_fov"][0] = phys.src_fov
Example #6
def distributed_strategy(args):
    tf.random.set_seed(args.seed)
    np.random.seed(args.seed)

    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)
    files = glob.glob(
        os.path.join(os.getenv('CENSAI_PATH'), "data", args.train_dataset,
                     "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    train_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type).shuffle(len(files)),
                                     block_length=1,
                                     num_parallel_calls=tf.data.AUTOTUNE)
    # Read off global parameters from first example in dataset
    for physical_params in train_dataset.map(decode_physical_model_info):
        break
    train_dataset = train_dataset.map(decode_results).shuffle(
        buffer_size=args.buffer_size)

    files = glob.glob(
        os.path.join(os.getenv('CENSAI_PATH'), "data", args.val_dataset,
                     "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    val_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type).shuffle(len(files)),
                                   block_length=1,
                                   num_parallel_calls=tf.data.AUTOTUNE)
    val_dataset = val_dataset.map(decode_results).shuffle(
        buffer_size=args.buffer_size)

    files = glob.glob(
        os.path.join(os.getenv('CENSAI_PATH'), "data", args.test_dataset,
                     "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    test_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type).shuffle(len(files)),
                                    block_length=1,
                                    num_parallel_calls=tf.data.AUTOTUNE)
    test_dataset = test_dataset.map(decode_results).shuffle(
        buffer_size=args.buffer_size)

    ps_lens = PowerSpectrum(bins=args.lens_coherence_bins,
                            pixels=physical_params["pixels"].numpy())
    ps_source = PowerSpectrum(bins=args.source_coherence_bins,
                              pixels=physical_params["src pixels"].numpy())
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins,
                             pixels=physical_params["kappa pixels"].numpy())

    phys = PhysicalModel(
        pixels=physical_params["pixels"].numpy(),
        kappa_pixels=physical_params["kappa pixels"].numpy(),
        src_pixels=physical_params["src pixels"].numpy(),
        image_fov=physical_params["image fov"].numpy(),
        kappa_fov=physical_params["kappa fov"].numpy(),
        src_fov=physical_params["source fov"].numpy(),
        method="fft",
    )

    phys_sie = AnalyticalPhysicalModel(
        pixels=physical_params["pixels"].numpy(),
        image_fov=physical_params["image fov"].numpy(),
        src_fov=physical_params["source fov"].numpy())

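    # Rebuild the RIM: load the U-net hyperparameters, restore the latest checkpoint, then wrap it in a RIM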
    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(
        checkpoint_manager.latest_checkpoint).expect_partial()
    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim = RIM(phys, unet, **rim_params)

    dataset_names = [args.train_dataset, args.val_dataset, args.test_dataset]
    dataset_shapes = [args.train_size, args.val_size, args.test_size]
    model_name = os.path.split(model)[-1]

    # from censai.utils import nulltape
    # def call_with_mask(self, lensed_image, noise_rms, psf, mask, outer_tape=nulltape):
    #     """
    #     Used in training. Return linked kappa and source maps.
    #     """
    #     batch_size = lensed_image.shape[0]
    #     source, kappa, source_grad, kappa_grad, states = self.initial_states(batch_size)  # initiate all tensors to 0
    #     source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad,
    #                                            states)  # Use lens to make an initial guess with Unet
    #     source_series = tf.TensorArray(DTYPE, size=self.steps)
    #     kappa_series = tf.TensorArray(DTYPE, size=self.steps)
    #     chi_squared_series = tf.TensorArray(DTYPE, size=self.steps)
    #     # record initial guess
    #     source_series = source_series.write(index=0, value=source)
    #     kappa_series = kappa_series.write(index=0, value=kappa)
    #     # Main optimization loop
    #     for current_step in tf.range(self.steps - 1):
    #         with outer_tape.stop_recording():
    #             with tf.GradientTape() as g:
    #                 g.watch(source)
    #                 g.watch(kappa)
    #                 y_pred = self.physical_model.forward(self.source_link(source), self.kappa_link(kappa), psf)
    #                 flux_term = tf.square(
    #                     tf.reduce_sum(y_pred, axis=(1, 2, 3)) - tf.reduce_sum(lensed_image, axis=(1, 2, 3)))
    #                 log_likelihood = 0.5 * tf.reduce_sum(
    #                     tf.square(y_pred - mask * lensed_image) / noise_rms[:, None, None, None] ** 2, axis=(1, 2, 3))
    #                 cost = tf.reduce_mean(log_likelihood + self.flux_lagrange_multiplier * flux_term)
    #             source_grad, kappa_grad = g.gradient(cost, [source, kappa])
    #             source_grad, kappa_grad = self.grad_update(source_grad, kappa_grad, current_step)
    #         source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad, states)
    #         source_series = source_series.write(index=current_step + 1, value=source)
    #         kappa_series = kappa_series.write(index=current_step + 1, value=kappa)
    #         chi_squared_series = chi_squared_series.write(index=current_step,
    #                                                       value=log_likelihood / self.pixels ** 2)  # renormalize chi squared here
    #     # last step score
    #     log_likelihood = self.physical_model.log_likelihood(y_true=lensed_image, source=self.source_link(source),
    #                                                         kappa=self.kappa_link(kappa), psf=psf, noise_rms=noise_rms)
    #     chi_squared_series = chi_squared_series.write(index=self.steps - 1, value=log_likelihood)
    #     return source_series.stack(), kappa_series.stack(), chi_squared_series.stack()

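    # One HDF5 results file per worker; each data split is written to its own group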
    with h5py.File(
            os.path.join(
                os.getenv("CENSAI_PATH"), "results", args.experiment_name +
                "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf:
        for i, dataset in enumerate([train_dataset, val_dataset,
                                     test_dataset]):
            g = hf.create_group(dataset_names[i])
            data_len = dataset_shapes[i] // N_WORKERS
            g.create_dataset(name="lens",
                             shape=[data_len, phys.pixels, phys.pixels, 1],
                             dtype=np.float32)
            g.create_dataset(name="psf",
                             shape=[
                                 data_len, physical_params['psf pixels'],
                                 physical_params['psf pixels'], 1
                             ],
                             dtype=np.float32)
            g.create_dataset(name="psf_fwhm",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(name="noise_rms",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(
                name="source",
                shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
                dtype=np.float32)
            g.create_dataset(
                name="kappa",
                shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
                dtype=np.float32)
            g.create_dataset(name="lens_pred",
                             shape=[data_len, phys.pixels, phys.pixels, 1],
                             dtype=np.float32)
            g.create_dataset(name="lens_pred_reoptimized",
                             shape=[data_len, phys.pixels, phys.pixels, 1],
                             dtype=np.float32)
            g.create_dataset(name="source_pred",
                             shape=[
                                 data_len, rim.steps, phys.src_pixels,
                                 phys.src_pixels, 1
                             ],
                             dtype=np.float32)
            g.create_dataset(
                name="source_pred_reoptimized",
                shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
                dtype=np.float32)
            g.create_dataset(name="kappa_pred",
                             shape=[
                                 data_len, rim.steps, phys.kappa_pixels,
                                 phys.kappa_pixels, 1
                             ],
                             dtype=np.float32)
            g.create_dataset(
                name="kappa_pred_reoptimized",
                shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
                dtype=np.float32)
            g.create_dataset(name="chi_squared",
                             shape=[data_len, rim.steps],
                             dtype=np.float32)
            g.create_dataset(name="chi_squared_reoptimized",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(name="chi_squared_reoptimized_series",
                             shape=[data_len, args.re_optimize_steps],
                             dtype=np.float32)
            g.create_dataset(name="source_optim_mse",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(name="source_optim_mse_series",
                             shape=[data_len, args.re_optimize_steps],
                             dtype=np.float32)
            g.create_dataset(name="kappa_optim_mse",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(name="kappa_optim_mse_series",
                             shape=[data_len, args.re_optimize_steps],
                             dtype=np.float32)
            g.create_dataset(name="lens_coherence_spectrum",
                             shape=[data_len, args.lens_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="source_coherence_spectrum",
                             shape=[data_len, args.source_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="lens_coherence_spectrum2",
                             shape=[data_len, args.lens_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="lens_coherence_spectrum_reoptimized",
                             shape=[data_len, args.lens_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="source_coherence_spectrum2",
                             shape=[data_len, args.source_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="source_coherence_spectrum_reoptimized",
                             shape=[data_len, args.source_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="kappa_coherence_spectrum",
                             shape=[data_len, args.kappa_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="kappa_coherence_spectrum_reoptimized",
                             shape=[data_len, args.kappa_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="lens_frequencies",
                             shape=[args.lens_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="source_frequencies",
                             shape=[args.source_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="kappa_frequencies",
                             shape=[args.kappa_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
            g.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
            g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32)
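            # Each worker processes its own disjoint slice of this split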
            dataset = dataset.skip(data_len * (THIS_WORKER - 1)).take(data_len)
            for batch, (lens, source, kappa, noise_rms, psf,
                        fwhm) in enumerate(
                            dataset.batch(1).prefetch(
                                tf.data.experimental.AUTOTUNE)):
                checkpoint_manager.checkpoint.restore(
                    checkpoint_manager.latest_checkpoint).expect_partial(
                    )  # reset model weights
                # Compute predictions for kappa and source
                source_pred, kappa_pred, chi_squared = rim.predict(
                    lens, noise_rms, psf)
                lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
                # Re-optimize weights of the model
                STEPS = args.re_optimize_steps
                learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate=args.learning_rate,
                    decay_rate=args.decay_rate,
                    decay_steps=args.decay_steps,
                    staircase=args.staircase)
                optim = tf.keras.optimizers.RMSprop(
                    learning_rate=learning_rate_schedule)

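                # Track chi^2 and source/kappa MSE at every re-optimization step, keeping the best iterate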
                chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
                source_mse = tf.TensorArray(DTYPE, size=STEPS)
                kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
                best = chi_squared[-1, 0]
                # best = abs(2*chi_squared[-1, 0] - 1)
                # best_chisq = 2*chi_squared[-1, 0]
                source_best = source_pred[-1]
                kappa_best = kappa_pred[-1]
                # initialize the best-MSE trackers so they are always defined, even if the
                # re-optimization loop below never improves on the RIM prediction
                source_mse_best = tf.reduce_mean(
                    (source_best - rim.source_inverse_link(source))**2)
                kappa_mse_best = tf.reduce_mean(
                    (kappa_best - rim.kappa_inverse_link(kappa))**2)
                # source_mean = source_pred[-1]
                # kappa_mean = rim.kappa_link(kappa_pred[-1])
                # source_std = tf.zeros_like(source_mean)
                # kappa_std = tf.zeros_like(kappa_mean)
                # counter = 0
                for current_step in tqdm(range(STEPS)):
                    with tf.GradientTape() as tape:
                        tape.watch(unet.trainable_variables)
                        # s, k, chi_sq = call_with_mask(rim, lens, noise_rms, psf, mask, tape)
                        s, k, chi_sq = rim.call(lens,
                                                noise_rms,
                                                psf,
                                                outer_tape=tape)
                        cost = tf.reduce_mean(chi_sq)  # mean over time steps
                        cost += tf.reduce_sum(rim.unet.losses)

                    log_likelihood = chi_sq[-1]
                    chi_squared_series = chi_squared_series.write(
                        index=current_step, value=log_likelihood)
                    source_o = s[-1]
                    kappa_o = k[-1]
                    source_mse = source_mse.write(
                        index=current_step,
                        value=tf.reduce_mean(
                            (source_o - rim.source_inverse_link(source))**2))
                    kappa_mse = kappa_mse.write(
                        index=current_step,
                        value=tf.reduce_mean(
                            (kappa_o - rim.kappa_inverse_link(kappa))**2))
                    if chi_sq[-1, 0] < args.converged_chisq:
                        source_best = rim.source_link(source_o)
                        kappa_best = rim.kappa_link(kappa_o)
                        best = chi_sq[-1, 0]
                        # record the MSE of the converged solution before leaving the loop
                        source_mse_best = tf.reduce_mean(
                            (source_best - rim.source_inverse_link(source))**2)
                        kappa_mse_best = tf.reduce_mean(
                            (kappa_best - rim.kappa_inverse_link(kappa))**2)
                        break
                    if chi_sq[-1, 0] < best:
                        source_best = rim.source_link(source_o)
                        kappa_best = rim.kappa_link(kappa_o)
                        best = chi_sq[-1, 0]
                        source_mse_best = tf.reduce_mean(
                            (source_best - rim.source_inverse_link(source))**2)
                        kappa_mse_best = tf.reduce_mean(
                            (kappa_best - rim.kappa_inverse_link(kappa))**2)
                    # if counter > 0:
                    #     # Welford's online algorithm
                    #     # source
                    #     delta = source_o - source_mean
                    #     source_mean = (counter * source_mean + (counter + 1) * source_o)/(counter + 1)
                    #     delta2 = source_o - source_mean
                    #     source_std += delta * delta2
                    #     # kappa
                    #     delta = rim.kappa_link(kappa_o) - kappa_mean
                    #     kappa_mean = (counter * kappa_mean + (counter + 1) * rim.kappa_link(kappa_o)) / (counter + 1)
                    #     delta2 = rim.kappa_link(kappa_o) - kappa_mean
                    #     kappa_std += delta * delta2
                    # if best_chisq < args.converged_chisq:
                    #     counter += 1
                    #     if counter == args.window:
                    #         break
                    # if 2*chi_sq[-1, 0] < best_chisq:
                    #     best_chisq = 2*chi_sq[-1, 0]
                    # if abs(2*chi_sq[-1, 0] - 1) < best:
                    #     source_best = rim.source_link(source_o)
                    #     kappa_best = rim.kappa_link(kappa_o)
                    #     best = abs(2 * chi_squared[-1, 0] - 1)
                    #     source_mse_best = tf.reduce_mean((source_best - rim.source_inverse_link(source)) ** 2)
                    #     kappa_mse_best = tf.reduce_mean((kappa_best - rim.kappa_inverse_link(kappa)) ** 2)

                    grads = tape.gradient(cost, unet.trainable_variables)
                    optim.apply_gradients(zip(grads, unet.trainable_variables))

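                # Re-simulate the lensed image from the best iterate found during re-optimization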
                source_o = source_best
                kappa_o = kappa_best
                y_pred = phys.forward(source_o, kappa_o, psf)
                chi_sq_series = tf.transpose(chi_squared_series.stack(),
                                             perm=[1, 0])
                source_mse = source_mse.stack()[None, ...]
                kappa_mse = kappa_mse.stack()[None, ...]
                # kappa_std /= float(args.window)
                # source_std /= float(args.window)

                # Compute Power spectrum of converged predictions
                _ps_lens = ps_lens.cross_correlation_coefficient(
                    lens[..., 0], lens_pred[..., 0])
                _ps_lens3 = ps_lens.cross_correlation_coefficient(
                    lens[..., 0], y_pred[..., 0])
                _ps_kappa = ps_kappa.cross_correlation_coefficient(
                    log_10(kappa)[..., 0],
                    log_10(kappa_pred[-1])[..., 0])
                _ps_kappa2 = ps_kappa.cross_correlation_coefficient(
                    log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
                _ps_source = ps_source.cross_correlation_coefficient(
                    source[..., 0], source_pred[-1][..., 0])
                _ps_source3 = ps_source.cross_correlation_coefficient(
                    source[..., 0], source_o[..., 0])

                # save results
                g["lens"][batch] = lens.numpy().astype(np.float32)
                g["psf"][batch] = psf.numpy().astype(np.float32)
                g["psf_fwhm"][batch] = fwhm.numpy().astype(np.float32)
                g["noise_rms"][batch] = noise_rms.numpy().astype(np.float32)
                g["source"][batch] = source.numpy().astype(np.float32)
                g["kappa"][batch] = kappa.numpy().astype(np.float32)
                g["lens_pred"][batch] = lens_pred.numpy().astype(np.float32)
                g["lens_pred_reoptimized"][batch] = y_pred.numpy().astype(
                    np.float32)
                g["source_pred"][batch] = tf.transpose(
                    source_pred,
                    perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
                g["source_pred_reoptimized"][batch] = source_o.numpy().astype(
                    np.float32)
                g["kappa_pred"][batch] = tf.transpose(
                    kappa_pred,
                    perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
                g["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype(
                    np.float32)
                g["chi_squared"][batch] = tf.transpose(
                    chi_squared).numpy().astype(np.float32)
                g["chi_squared_reoptimized"][batch] = best.numpy().astype(
                    np.float32)
                g["chi_squared_reoptimized_series"][
                    batch] = chi_sq_series.numpy().astype(np.float32)
                g["source_optim_mse"][batch] = source_mse_best.numpy().astype(
                    np.float32)
                g["source_optim_mse_series"][batch] = source_mse.numpy(
                ).astype(np.float32)
                g["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype(
                    np.float32)
                g["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype(
                    np.float32)
                g["lens_coherence_spectrum"][batch] = _ps_lens
                g["lens_coherence_spectrum_reoptimized"][batch] = _ps_lens3
                g["source_coherence_spectrum"][batch] = _ps_source
                g["source_coherence_spectrum_reoptimized"][batch] = _ps_source3
                g["kappa_coherence_spectrum"][batch] = _ps_kappa
                g["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2

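                # On the first batch, record the frequency bin centers and fields of view once per group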
                if batch == 0:
                    _, f = np.histogram(np.fft.fftfreq(
                        phys.pixels)[:phys.pixels // 2],
                                        bins=ps_lens.bins)
                    f = (f[:-1] + f[1:]) / 2
                    g["lens_frequencies"][:] = f
                    _, f = np.histogram(np.fft.fftfreq(
                        phys.src_pixels)[:phys.src_pixels // 2],
                                        bins=ps_source.bins)
                    f = (f[:-1] + f[1:]) / 2
                    g["source_frequencies"][:] = f
                    _, f = np.histogram(np.fft.fftfreq(
                        phys.kappa_pixels)[:phys.kappa_pixels // 2],
                                        bins=ps_kappa.bins)
                    f = (f[:-1] + f[1:]) / 2
                    g["kappa_frequencies"][:] = f
                    g["kappa_fov"][0] = phys.kappa_fov
                    g["source_fov"][0] = phys.src_fov

        # Create SIE test
        g = hf.create_group('SIE_test')
        data_len = args.sie_size // N_WORKERS
        sie_dataset = test_dataset.skip(data_len *
                                        (THIS_WORKER - 1)).take(data_len)
        g.create_dataset(name="lens",
                         shape=[data_len, phys.pixels, phys.pixels, 1],
                         dtype=np.float32)
        g.create_dataset(name="psf",
                         shape=[
                             data_len, physical_params['psf pixels'],
                             physical_params['psf pixels'], 1
                         ],
                         dtype=np.float32)
        g.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        g.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        g.create_dataset(name="source",
                         shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
                         dtype=np.float32)
        g.create_dataset(
            name="kappa",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        g.create_dataset(name="lens_pred",
                         shape=[data_len, phys.pixels, phys.pixels, 1],
                         dtype=np.float32)
        g.create_dataset(name="lens_pred2",
                         shape=[data_len, phys.pixels, phys.pixels, 1],
                         dtype=np.float32)
        g.create_dataset(
            name="source_pred",
            shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        g.create_dataset(name="kappa_pred",
                         shape=[
                             data_len, rim.steps, phys.kappa_pixels,
                             phys.kappa_pixels, 1
                         ],
                         dtype=np.float32)
        g.create_dataset(name="chi_squared",
                         shape=[data_len, rim.steps],
                         dtype=np.float32)
        g.create_dataset(name="lens_coherence_spectrum",
                         shape=[data_len, args.lens_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="source_coherence_spectrum",
                         shape=[data_len, args.source_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="lens_coherence_spectrum2",
                         shape=[data_len, args.lens_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="source_coherence_spectrum2",
                         shape=[data_len, args.source_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="kappa_coherence_spectrum",
                         shape=[data_len, args.kappa_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="lens_frequencies",
                         shape=[args.lens_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="source_frequencies",
                         shape=[args.source_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="kappa_frequencies",
                         shape=[args.kappa_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="einstein_radius",
                         shape=[data_len],
                         dtype=np.float32)
        g.create_dataset(name="position",
                         shape=[data_len, 2],
                         dtype=np.float32)
        g.create_dataset(name="orientation",
                         shape=[data_len],
                         dtype=np.float32)
        g.create_dataset(name="ellipticity",
                         shape=[data_len],
                         dtype=np.float32)
        g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        g.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32)

        for batch, (_, source, _, noise_rms, psf, fwhm) in enumerate(
                sie_dataset.take(data_len).batch(args.batch_size).prefetch(
                    tf.data.experimental.AUTOTUNE)):
            batch_size = source.shape[0]
            # Create some SIE kappa maps
            _r = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                   minval=0,
                                   maxval=args.max_shift)
            _theta = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                       minval=-np.pi,
                                       maxval=np.pi)
            x0 = _r * tf.math.cos(_theta)
            y0 = _r * tf.math.sin(_theta)
            ellipticity = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                            minval=0.,
                                            maxval=args.max_ellipticity)
            phi = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                    minval=-np.pi,
                                    maxval=np.pi)
            einstein_radius = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                                minval=args.min_theta_e,
                                                maxval=args.max_theta_e)
            kappa = phys_sie.kappa_field(x0=x0,
                                         y0=y0,
                                         e=ellipticity,
                                         phi=phi,
                                         r_ein=einstein_radius)
            lens = phys.noisy_forward(source,
                                      kappa,
                                      noise_rms=noise_rms,
                                      psf=psf)

            # Compute predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(
                lens, noise_rms, psf)
            lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
            # Compute Power spectrum of converged predictions
            _ps_lens = ps_lens.cross_correlation_coefficient(
                lens[..., 0], lens_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0],
                log_10(kappa_pred[-1])[..., 0])
            _ps_source = ps_source.cross_correlation_coefficient(
                source[..., 0], source_pred[-1][..., 0])

            # save results
            i_begin = batch * args.batch_size
            i_end = i_begin + batch_size
            g["lens"][i_begin:i_end] = lens.numpy().astype(np.float32)
            g["psf"][i_begin:i_end] = psf.numpy().astype(np.float32)
            g["psf_fwhm"][i_begin:i_end] = fwhm.numpy().astype(np.float32)
            g["noise_rms"][i_begin:i_end] = noise_rms.numpy().astype(
                np.float32)
            g["source"][i_begin:i_end] = source.numpy().astype(np.float32)
            g["kappa"][i_begin:i_end] = kappa.numpy().astype(np.float32)
            g["lens_pred"][i_begin:i_end] = lens_pred.numpy().astype(
                np.float32)
            g["source_pred"][i_begin:i_end] = tf.transpose(
                source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            g["kappa_pred"][i_begin:i_end] = tf.transpose(
                kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            g["chi_squared"][i_begin:i_end] = 2 * tf.transpose(
                chi_squared).numpy().astype(np.float32)
            g["lens_coherence_spectrum"][i_begin:i_end] = _ps_lens.numpy(
            ).astype(np.float32)
            g["source_coherence_spectrum"][i_begin:i_end] = _ps_source.numpy(
            ).astype(np.float32)
            g["kappa_coherence_spectrum"][i_begin:i_end] = _ps_kappa.numpy(
            ).astype(np.float32)
            g["einstein_radius"][
                i_begin:i_end] = einstein_radius[:, 0, 0,
                                                 0].numpy().astype(np.float32)
            g["position"][i_begin:i_end] = tf.stack(
                [x0[:, 0, 0, 0], y0[:, 0, 0, 0]],
                axis=1).numpy().astype(np.float32)
            g["ellipticity"][i_begin:i_end] = ellipticity[:, 0, 0,
                                                          0].numpy().astype(
                                                              np.float32)
            g["orientation"][i_begin:i_end] = phi[:, 0, 0,
                                                  0].numpy().astype(np.float32)

            if batch == 0:
                _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels //
                                                                2],
                                    bins=ps_lens.bins)
                f = (f[:-1] + f[1:]) / 2
                g["lens_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.src_pixels)[:phys.src_pixels // 2],
                                    bins=ps_source.bins)
                f = (f[:-1] + f[1:]) / 2
                g["source_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.kappa_pixels)[:phys.kappa_pixels // 2],
                                    bins=ps_kappa.bins)
                f = (f[:-1] + f[1:]) / 2
                g["kappa_frequencies"][:] = f
                g["kappa_fov"][0] = phys.kappa_fov
                g["source_fov"][0] = phys.src_fov