Example No. 1
def test_log_likelihood():
    phys = PhysicalModel(pixels=64, src_pixels=64)
    kappa = tf.random.normal([1, 64, 64, 1])
    source = tf.random.normal([1, 64, 64, 1])
    im_lensed = phys.forward(source, kappa)
    assert im_lensed.shape == [1, 64, 64, 1]
    cost = phys.log_likelihood(source, kappa, im_lensed)
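For reference, a minimal sketch (an assumption, not the library's actual implementation) of the Gaussian chi-squared form that phys.log_likelihood is expected to reduce to for a per-pixel noise level noise_rms; Example No. 2 below writes out the same expression explicitly:

import tensorflow as tf

def gaussian_log_likelihood(y_pred, y_obs, noise_rms):
    # 0.5 * noise-weighted mean of squared residuals, per batch element
    return 0.5 * tf.reduce_mean((y_pred - y_obs)**2 / noise_rms**2, axis=(1, 2, 3))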
Example No. 2
def test_lagrange_multiplier_for_lens_intensity():
    phys = PhysicalModel(pixels=128)
    phys_a = AnalyticalPhysicalModel(pixels=128)
    kappa = phys_a.kappa_field(2.0, e=0.2)
    x = np.linspace(-1, 1, 128) * phys.src_fov / 2
    xx, yy = np.meshgrid(x, x)
    rho = xx**2 + yy**2
    source = tf.math.exp(-0.5 * rho / 0.5**2)[tf.newaxis, ..., tf.newaxis]
    source = tf.cast(source, tf.float32)

    y_true = phys.forward(source, kappa)
    y_pred = phys.forward(0.001 * source,
                          kappa)  # rescale it, say it has different units
    lam_lagrange = tf.reduce_sum(y_true * y_pred, axis=(
        1, 2, 3)) / tf.reduce_sum(y_pred**2, axis=(1, 2, 3))
    lam_tests = tf.squeeze(
        tf.cast(tf.linspace(lam_lagrange / 10, lam_lagrange * 10, 1000),
                tf.float32))[..., tf.newaxis, tf.newaxis, tf.newaxis]
    log_likelihood_best = 0.5 * tf.reduce_mean(
        (lam_lagrange * y_pred - y_true)**2 / phys.noise_rms**2,
        axis=(1, 2, 3))
    log_likelihood_test = 0.5 * tf.reduce_mean(
        (lam_tests * y_pred - y_true)**2 / phys.noise_rms**2, axis=(1, 2, 3))
    return log_likelihood_test, log_likelihood_best, tf.squeeze(
        lam_tests), lam_lagrange
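One way to use the returned values (a sketch; the variable names below simply mirror the return order of the function above) is to check that the closed-form Lagrange multiplier lands at the minimum of the scanned log-likelihood curve:

import numpy as np

ll_test, ll_best, lam_tests, lam_lagrange = test_lagrange_multiplier_for_lens_intensity()
best_idx = int(np.argmin(ll_test.numpy()))
print("analytic lambda:", float(lam_lagrange.numpy()))             # closed-form multiplier
print("best scanned lambda:", float(lam_tests.numpy()[best_idx]))  # should be very close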
Example No. 3
def test_interpolated_kappa():
    import tensorflow_addons as tfa
    phys = PhysicalModel(pixels=128,
                         src_pixels=32,
                         image_fov=7.68,
                         kappa_fov=5)
    phys_a = AnalyticalPhysicalModel(pixels=128, image_fov=7.68)
    kappa = phys_a.kappa_field(r_ein=2., e=0.2)
    kappa += phys_a.kappa_field(r_ein=1., x0=2., y0=2.)
    true_lens = phys.lens_source_func(kappa, w=0.2)
    true_kappa = kappa

    # Test interpolation of alpha angles on a finer grid
    # phys = PhysicalModel(pixels=128, src_pixels=32, kappa_pixels=32)
    phys_a = AnalyticalPhysicalModel(pixels=32, image_fov=7.68)
    kappa = phys_a.kappa_field(r_ein=2., e=0.2)
    kappa += phys_a.kappa_field(r_ein=1., x0=2., y0=2.)

    # kappa2 = phys_a.kappa_field(r_ein=2., e=0.2)
    # kappa2 += phys_a.kappa_field(r_ein=1., x0=2., y0=2.)
    #
    # kappa = tf.concat([kappa, kappa2], axis=1)

    # Test interpolated kappa lens
    x = np.linspace(-1, 1, 128) * phys.kappa_fov / 2
    x, y = np.meshgrid(x, x)
    x = tf.constant(x[np.newaxis, ..., np.newaxis], tf.float32)
    y = tf.constant(y[np.newaxis, ..., np.newaxis], tf.float32)
    dx = phys.kappa_fov / (32 - 1)
    xmin = -0.5 * phys.kappa_fov
    ymin = -0.5 * phys.kappa_fov
    i_coord = (x - xmin) / dx
    j_coord = (y - ymin) / dx
    wrap = tf.concat([i_coord, j_coord], axis=-1)
    # test_kappa1 = tfa.image.resampler(kappa, wrap)  # bilinear interpolation of source on wrap grid
    # test_lens1 = phys.lens_source_func(test_kappa1, w=0.2)
    phys2 = PhysicalModel(pixels=128,
                          kappa_pixels=32,
                          method="fft",
                          image_fov=7.68,
                          kappa_fov=5)
    test_lens1 = phys2.lens_source_func(kappa, w=0.2)

    # Test interpolated alpha angles lens
    phys2 = PhysicalModel(pixels=32,
                          src_pixels=32,
                          image_fov=7.68,
                          kappa_fov=5)
    alpha1, alpha2 = phys2.deflection_angle(kappa)
    alpha = tf.concat([alpha1, alpha2], axis=-1)
    alpha = tfa.image.resampler(alpha, wrap)
    test_lens2 = phys.lens_source_func_given_alpha(alpha, w=0.2)

    return true_lens, test_lens1, test_lens2
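The wrap grid above converts physical coordinates into fractional pixel indices, which is what tfa.image.resampler interpolates at. A standalone sketch of that mapping (the helper name is illustrative):

def to_pixel_index(x, fov, n):
    # physical coordinate -> fractional pixel index on an n-pixel grid spanning fov,
    # using the same dx = fov / (n - 1) and xmin = -0.5 * fov as above
    dx = fov / (n - 1)
    xmin = -0.5 * fov
    return (x - xmin) / dx

print(to_pixel_index(-2.5, 5.0, 32), to_pixel_index(2.5, 5.0, 32))  # 0.0 and 31.0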
Example No. 4
def test_lens_source_conv2():
    pixels = 64
    src_pixels = 32
    phys = PhysicalModel(pixels=pixels,
                         src_pixels=src_pixels,
                         kappa_fov=16,
                         image_fov=16)
    phys_analytic = AnalyticalPhysicalModel(pixels=pixels, image_fov=16)
    source = tf.random.normal([1, src_pixels, src_pixels, 1])
    kappa = phys_analytic.kappa_field(7, 0.1, 0, 0, 0)
    lens = phys.lens_source(source, kappa)
    return lens
Example No. 5
def test_alpha_method_fft():
    pixels = 64
    phys = PhysicalModel(pixels=pixels, method="fft")
    phys_analytic = AnalyticalPhysicalModel(pixels=pixels, image_fov=7)
    phys2 = PhysicalModel(pixels=pixels, method="conv2d")

    # test out noise
    kappa = tf.random.uniform(shape=[1, pixels, pixels, 1])
    alphax, alphay = phys.deflection_angle(kappa)
    alphax2, alphay2 = phys2.deflection_angle(kappa)

    # assert np.allclose(alphax, alphax2, atol=1e-4)
    # assert np.allclose(alphay, alphay2, atol=1e-4)

    # test out an analytical profile
    kappa = phys_analytic.kappa_field(2, 0.4, 0, 0.1, 0.5)
    alphax, alphay = phys.deflection_angle(kappa)

    alphax2, alphay2 = phys2.deflection_angle(kappa)
    #
    # assert np.allclose(alphax, alphax2, atol=1e-4)
    # assert np.allclose(alphay, alphay2, atol=1e-4)
    im1 = phys_analytic.lens_source_func_given_alpha(
        tf.concat([alphax, alphay], axis=-1))
    im2 = phys_analytic.lens_source_func_given_alpha(
        tf.concat([alphax2, alphay2], axis=-1))
    return alphax, alphax2, im1, im2
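The commented-out assertions hint at the intended check: the "fft" and "conv2d" back-ends should agree on the deflection angles up to numerical error. A sketch of that comparison (the tolerance is an assumption, not a documented guarantee):

import numpy as np

alphax, alphax2, im1, im2 = test_alpha_method_fft()
print("max deflection-angle difference:", np.max(np.abs(alphax.numpy() - alphax2.numpy())))
print("max lensed-image difference:", np.max(np.abs(im1.numpy() - im2.numpy())))
# e.g. assert np.allclose(alphax, alphax2, atol=1e-2) if a looser tolerance is acceptable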
Example No. 6
def test_lens_func_given_alpha():
    phys = PhysicalModel(pixels=128)
    phys_a = AnalyticalPhysicalModel(pixels=128)
    alpha = phys_a.analytical_deflection_angles(x0=0.5,
                                                y0=0.5,
                                                e=0.4,
                                                phi=0.,
                                                r_ein=1.)
    lens_true = phys_a.lens_source_func(x0=0.5,
                                        y0=0.5,
                                        e=0.4,
                                        phi=0.,
                                        r_ein=1.,
                                        xs=0.5,
                                        ys=0.5)
    lens_pred = phys_a.lens_source_func_given_alpha(alpha, xs=0.5, ys=0.5)
    lens_pred2 = phys.lens_source_func_given_alpha(alpha, xs=0.5, ys=0.5)
    fig = raytracer_residual_plot(alpha[0], alpha[0], lens_true[0],
                                  lens_pred2[0])
Example No. 7
def distributed_strategy(args):
    # NISGenerator is only used to generate pixelated kappa fields
    kappa_gen = NISGenerator(
        kappa_fov=args.kappa_fov,
        src_fov=args.source_fov,
        pixels=args.kappa_pixels,
        z_source=args.z_source,
        z_lens=args.z_lens
    )

    min_theta_e = 0.1 * args.image_fov if args.min_theta_e is None else args.min_theta_e
    max_theta_e = 0.45 * args.image_fov if args.max_theta_e is None else args.max_theta_e

    cosmos_files = glob.glob(os.path.join(args.cosmos_dir, "*.tfrecords"))
    cosmos = tf.data.TFRecordDataset(cosmos_files, compression_type=args.compression_type)
    cosmos = cosmos.map(decode_image).map(preprocess_image)
    if args.shuffle_cosmos:
        cosmos = cosmos.shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True)
    cosmos = cosmos.batch(args.batch_size)

    window = tukey(args.src_pixels, alpha=args.tukey_alpha)
    window = np.outer(window, window)
    phys = PhysicalModel(
        image_fov=args.image_fov,
        kappa_fov=args.kappa_fov,
        src_fov=args.source_fov,
        pixels=args.lens_pixels,
        kappa_pixels=args.kappa_pixels,
        src_pixels=args.src_pixels,
        method="conv2d"
    )
    noise_a = (args.noise_rms_min - args.noise_rms_mean) / args.noise_rms_std
    noise_b = (args.noise_rms_max - args.noise_rms_mean) / args.noise_rms_std
    psf_a = (args.psf_fwhm_min - args.psf_fwhm_mean) / args.psf_fwhm_std
    psf_b = (args.psf_fwhm_max - args.psf_fwhm_mean) / args.psf_fwhm_std

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{THIS_WORKER}.tfrecords"), options) as writer:
        print(f"Started worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
        for i in range((THIS_WORKER - 1) * args.batch_size, args.len_dataset, N_WORKERS * args.batch_size):
            for galaxies in cosmos:
                break
            galaxies = window[np.newaxis, ..., np.newaxis] * galaxies

            noise_rms = truncnorm.rvs(noise_a, noise_b, loc=args.noise_rms_mean, scale=args.noise_rms_std, size=args.batch_size)
            fwhm = truncnorm.rvs(psf_a, psf_b, loc=args.psf_fwhm_mean, scale=args.psf_fwhm_std, size=args.batch_size)
            psf = phys.psf_models(fwhm, cutout_size=args.psf_cutout_size)

            batch_size = galaxies.shape[0]
            _r = tf.random.uniform(shape=[batch_size, 1, 1], minval=0, maxval=args.max_shift)
            _theta = tf.random.uniform(shape=[batch_size, 1, 1], minval=-np.pi, maxval=np.pi)
            x0 = _r * tf.math.cos(_theta)
            y0 = _r * tf.math.sin(_theta)
            ellipticity = tf.random.uniform(shape=[batch_size, 1, 1], minval=0., maxval=args.max_ellipticity)
            phi = tf.random.uniform(shape=[batch_size, 1, 1], minval=-np.pi, maxval=np.pi)
            einstein_radius = tf.random.uniform(shape=[batch_size, 1, 1], minval=min_theta_e, maxval=max_theta_e)

            kappa = kappa_gen.kappa_field(x0, y0, ellipticity, phi, einstein_radius)

            lensed_images = phys.noisy_forward(galaxies, kappa, noise_rms=noise_rms, psf=psf)

            records = encode_examples(
                kappa=kappa,
                galaxies=galaxies,
                lensed_images=lensed_images,
                z_source=args.z_source,
                z_lens=args.z_lens,
                image_fov=phys.image_fov,
                kappa_fov=phys.kappa_fov,
                source_fov=args.source_fov,
                noise_rms=noise_rms,
                psf=psf,
                fwhm=fwhm
            )
            for record in records:
                writer.write(record)
    print(f"Finished work at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
Example No. 8
def main(args):
    if args.seed is not None:
        tf.random.set_seed(args.seed)
        np.random.seed(args.seed)
    if args.json_override is not None:
        if isinstance(args.json_override, list):
            files = args.json_override
        else:
            files = [args.json_override,]
        for file in files:
            with open(file, "r") as f:
                json_override = json.load(f)
            args_dict = vars(args)
            args_dict.update(json_override)

    files = []
    for dataset in args.datasets:
        files.extend(glob.glob(os.path.join(dataset, "*.tfrecords")))
    np.random.shuffle(files)

    if args.val_datasets is not None:
        """    
        In this conditional, we assume total items might be a subset of the dataset size.
        Thus we want to reshuffle at each epoch to get a different realisation of the dataset. 
        In case total_items == true dataset size, this means we only change ordering of items each epochs.
        
        Also, validation is not a split of the training data, but a saved dataset on disk. 
        """
        files = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files), reshuffle_each_iteration=True)
        dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
                                   block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE)
        # Read off global parameters from first example in dataset
        for physical_params in dataset.map(decode_physical_model_info):
            break
        # preprocessing
        dataset = dataset.map(decode_train)
        if args.cache_file is not None:
            dataset = dataset.cache(args.cache_file)
        train_dataset = dataset.shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).take(args.total_items).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)

        val_files = []
        for dataset in args.val_datasets:
            val_files.extend(glob.glob(os.path.join(dataset, "*.tfrecords")))
        val_files = tf.data.Dataset.from_tensor_slices(val_files).shuffle(len(files), reshuffle_each_iteration=True)
        val_dataset = val_files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type), block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE)
        val_dataset = val_dataset.map(decode_train).shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).\
            take(math.ceil((1 - args.train_split) * args.total_items)).\
            batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    else:
        """
        Here, we split the dataset, so we assume total_items is the true dataset size. Any extra items will be discarded. 
        This is to make sure validation set is never seen by the model, so shuffling occurs after the split.
        """
        files = tf.data.Dataset.from_tensor_slices(files)
        dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
                                   block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE)
        # Read off global parameters from first example in dataset
        for physical_params in dataset.map(decode_physical_model_info):
            break
        # preprocessing
        dataset = dataset.map(decode_train)
        if args.cache_file is not None:
            dataset = dataset.cache(args.cache_file)
        train_dataset = dataset.take(math.floor(args.train_split * args.total_items)).shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
        val_dataset = dataset.skip(math.floor(args.train_split * args.total_items)).take(math.ceil((1 - args.train_split) * args.total_items)).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    train_dataset = STRATEGY.experimental_distribute_dataset(train_dataset)
    val_dataset = STRATEGY.experimental_distribute_dataset(val_dataset)
    with STRATEGY.scope():  # Replicate ops across GPUs
        if args.raytracer is not None:
            with open(os.path.join(args.raytracer, "ray_tracer_hparams.json"), "r") as f:
                raytracer_hparams = json.load(f)
            raytracer = RayTracer(**raytracer_hparams)
            # load last checkpoint in the checkpoint directory
            checkpoint = tf.train.Checkpoint(net=raytracer)
            manager = tf.train.CheckpointManager(checkpoint, directory=args.raytracer, max_to_keep=3)
            checkpoint.restore(manager.latest_checkpoint).expect_partial()
        else:
            raytracer = None
        phys = PhysicalModel(
            pixels=physical_params["pixels"].numpy(),
            kappa_pixels=physical_params["kappa pixels"].numpy(),
            src_pixels=physical_params["src pixels"].numpy(),
            image_fov=physical_params["image fov"].numpy(),
            kappa_fov=physical_params["kappa fov"].numpy(),
            src_fov=physical_params["source fov"].numpy(),
            method=args.forward_method,
            raytracer=raytracer,
        )

        unet = Model(
            filters=args.filters,
            filter_scaling=args.filter_scaling,
            kernel_size=args.kernel_size,
            layers=args.layers,
            block_conv_layers=args.block_conv_layers,
            strides=args.strides,
            bottleneck_kernel_size=args.bottleneck_kernel_size,
            resampling_kernel_size=args.resampling_kernel_size,
            input_kernel_size=args.input_kernel_size,
            gru_kernel_size=args.gru_kernel_size,
            upsampling_interpolation=args.upsampling_interpolation,
            kernel_l2_amp=args.kernel_l2_amp,
            bias_l2_amp=args.bias_l2_amp,
            kernel_l1_amp=args.kernel_l1_amp,
            bias_l1_amp=args.bias_l1_amp,
            activation=args.activation,
            initializer=args.initializer,
            batch_norm=args.batch_norm,
            dropout_rate=args.dropout_rate,
            filter_cap=args.filter_cap
        )
        rim = RIM(
            physical_model=phys,
            unet=unet,
            steps=args.steps,
            adam=args.adam,
            kappalog=args.kappalog,
            source_link=args.source_link,
            kappa_normalize=args.kappa_normalize,
            flux_lagrange_multiplier=args.flux_lagrange_multiplier
        )
        learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=args.initial_learning_rate,
            decay_rate=args.decay_rate,
            decay_steps=args.decay_steps,
            staircase=args.staircase
        )
        optim = tf.keras.optimizers.deserialize(
            {
                "class_name": args.optimizer,
                'config': {"learning_rate": learning_rate_schedule}
            }
        )
        # weights for time steps in the loss function
        if args.time_weights == "uniform":
            wt = tf.ones(shape=(args.steps), dtype=DTYPE) / args.steps
        elif args.time_weights == "linear":
            wt = 2 * (tf.range(args.steps, dtype=DTYPE) + 1) / args.steps / (args.steps + 1)
        elif args.time_weights == "quadratic":
            wt = 6 * (tf.range(args.steps, dtype=DTYPE) + 1)**2 / args.steps / (args.steps + 1) / (2 * args.steps + 1)
        else:
            raise ValueError("time_weights must be in ['uniform', 'linear', 'quadratic']")
        wt = wt[..., tf.newaxis]  # [steps, 1], broadcasts over the batch dimension

        if args.kappa_residual_weights == "uniform":
            wk = tf.keras.layers.Lambda(lambda k: tf.ones_like(k, dtype=DTYPE) / tf.cast(tf.math.reduce_prod(k.shape[1:]), DTYPE))
        elif args.kappa_residual_weights == "linear":
            wk = tf.keras.layers.Lambda(lambda k: k / tf.reduce_sum(k, axis=(1, 2, 3), keepdims=True))
        elif args.kappa_residual_weights == "sqrt":
            wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
        elif args.kappa_residual_weights == "quadratic":
            wk = tf.keras.layers.Lambda(lambda k: tf.square(k) / tf.reduce_sum(tf.square(k), axis=(1, 2, 3), keepdims=True))
        else:
            raise ValueError("kappa_residual_weights must be in ['uniform', 'linear', 'quadratic', 'sqrt']")

        if args.source_residual_weights == "uniform":
            ws = tf.keras.layers.Lambda(lambda s: tf.ones_like(s, dtype=DTYPE) / tf.cast(tf.math.reduce_prod(s.shape[1:]), DTYPE))
        elif args.source_residual_weights == "linear":
            ws = tf.keras.layers.Lambda(lambda s: s / tf.reduce_sum(s, axis=(1, 2, 3), keepdims=True))
        elif args.source_residual_weights == "quadratic":
            ws = tf.keras.layers.Lambda(lambda s: tf.square(s) / tf.reduce_sum(tf.square(s), axis=(1, 2, 3), keepdims=True))
        elif args.source_residual_weights == "sqrt":
            ws = tf.keras.layers.Lambda(lambda s: tf.sqrt(s) / tf.reduce_sum(tf.sqrt(s), axis=(1, 2, 3), keepdims=True))
        else:
            raise ValueError("kappa_residual_weights must be in ['uniform', 'linear', 'quadratic', 'sqrt']")

    # ==== Take care of where to write logs and stuff =================================================================
    if args.model_id.lower() != "none":
        if args.logname is not None:
            logname = args.model_id + "_" + args.logname
            model_id = args.model_id
        else:
            logname = args.model_id + "_" + datetime.now().strftime("%y%m%d%H%M%S")
            model_id = args.model_id
    elif args.logname is not None:
        logname = args.logname
        model_id = logname
    else:
        logname = args.logname_prefixe + "_" + datetime.now().strftime("%y%m%d%H%M%S")
        model_id = logname
    if args.logdir.lower() != "none":
        logdir = os.path.join(args.logdir, logname)
        if not os.path.isdir(logdir):
            os.mkdir(logdir)
        writer = tf.summary.create_file_writer(logdir)
    else:
        writer = nullwriter()
    # ===== Make sure directory and checkpoint manager are created to save model ===================================
    if args.model_dir.lower() != "none":
        checkpoints_dir = os.path.join(args.model_dir, logname)
        old_checkpoints_dir = os.path.join(args.model_dir, model_id)  # in case they differ we load model from a different directory
        if not os.path.isdir(checkpoints_dir):
            os.mkdir(checkpoints_dir)
            with open(os.path.join(checkpoints_dir, "script_params.json"), "w") as f:
                json.dump(vars(args), f, indent=4)
            with open(os.path.join(checkpoints_dir, "unet_hparams.json"), "w") as f:
                hparams_dict = {key: vars(args)[key] for key in UNET_MODEL_HPARAMS}
                json.dump(hparams_dict, f, indent=4)
            with open(os.path.join(checkpoints_dir, "rim_hparams.json"), "w") as f:
                hparams_dict = {key: vars(args)[key] for key in RIM_HPARAMS}
                json.dump(hparams_dict, f, indent=4)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=rim.unet)
        checkpoint_manager = tf.train.CheckpointManager(ckpt, old_checkpoints_dir, max_to_keep=args.max_to_keep)
        save_checkpoint = True
        # ======= Load model if model_id is provided ===============================================================
        if args.model_id.lower() != "none":
            checkpoint_manager.checkpoint.restore(checkpoint_manager.latest_checkpoint)
        if old_checkpoints_dir != checkpoints_dir:  # save progress in another directory.
            if args.reset_optimizer_states:
                optim = tf.keras.optimizers.deserialize(
                    {
                        "class_name": args.optimizer,
                        'config': {"learning_rate": learning_rate_schedule}
                    }
                )
                ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optim, net=rim.unet)
            checkpoint_manager = tf.train.CheckpointManager(ckpt, checkpoints_dir, max_to_keep=args.max_to_keep)

    else:
        save_checkpoint = False
    # =================================================================================================================

    def train_step(X, source, kappa, noise_rms, psf):
        with tf.GradientTape() as tape:
            tape.watch(rim.unet.trainable_variables)
            if args.unroll_time_steps:
                source_series, kappa_series, chi_squared = rim.call_function(X, noise_rms, psf)
            else:
                source_series, kappa_series, chi_squared = rim.call(X, noise_rms, psf, outer_tape=tape)
            # mean over image residuals (in model prediction space)
            source_cost1 = tf.reduce_sum(ws(source) * tf.square(source_series - rim.source_inverse_link(source)), axis=(2, 3, 4))
            kappa_cost1 = tf.reduce_sum(wk(kappa) * tf.square(kappa_series - rim.kappa_inverse_link(kappa)), axis=(2, 3, 4))
            # weighted mean over time steps
            source_cost = tf.reduce_sum(wt * source_cost1, axis=0)
            kappa_cost = tf.reduce_sum(wt * kappa_cost1, axis=0)
            # final cost is mean over global batch size
            cost = tf.reduce_sum(kappa_cost + source_cost) / args.batch_size
        gradient = tape.gradient(cost, rim.unet.trainable_variables)
        gradient = [tf.clip_by_norm(grad, 5.) for grad in gradient]
        optim.apply_gradients(zip(gradient, rim.unet.trainable_variables))
        # Update metrics with "converged" score
        chi_squared = tf.reduce_sum(chi_squared[-1]) / args.batch_size
        source_cost = tf.reduce_sum(source_cost1[-1]) / args.batch_size
        kappa_cost = tf.reduce_sum(kappa_cost1[-1]) / args.batch_size
        return cost, chi_squared, source_cost, kappa_cost

    @tf.function
    def distributed_train_step(X, source, kappa, noise_rms, psf):
        per_replica_losses, per_replica_chi_squared, per_replica_source_cost, per_replica_kappa_cost = STRATEGY.run(train_step, args=(X, source, kappa, noise_rms, psf))
        # Replica losses are aggregated by summing them
        global_loss = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        global_chi_squared = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_chi_squared, axis=None)
        global_source_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_source_cost, axis=None)
        global_kappa_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_kappa_cost, axis=None)
        return global_loss, global_chi_squared, global_source_cost, global_kappa_cost

    def test_step(X, source, kappa, noise_rms, psf):
        source_series, kappa_series, chi_squared = rim.call(X, noise_rms, psf)
        # mean over image residuals (in model prediction space)
        source_cost1 = tf.reduce_sum(ws(source) * tf.square(source_series - rim.source_inverse_link(source)), axis=(2, 3, 4))
        kappa_cost1 = tf.reduce_sum(wk(kappa) * tf.square(kappa_series - rim.kappa_inverse_link(kappa)), axis=(2, 3, 4))
        # weighted mean over time steps
        source_cost = tf.reduce_sum(wt * source_cost1, axis=0)
        kappa_cost = tf.reduce_sum(wt * kappa_cost1, axis=0)
        # final cost is mean over global batch size
        cost = tf.reduce_sum(kappa_cost + source_cost) / args.batch_size
        # Update metrics with "converged" score
        chi_squared = tf.reduce_sum(chi_squared[-1]) / args.batch_size
        source_cost = tf.reduce_sum(source_cost1[-1]) / args.batch_size
        kappa_cost = tf.reduce_sum(kappa_cost1[-1]) / args.batch_size
        return cost, chi_squared, source_cost, kappa_cost

    @tf.function
    def distributed_test_step(X, source, kappa, noise_rms, psf):
        per_replica_losses, per_replica_chi_squared, per_replica_source_cost, per_replica_kappa_cost = STRATEGY.run(test_step, args=(X, source, kappa, noise_rms, psf))
        # Replica losses are aggregated by summing them
        global_loss = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        global_chi_squared = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_chi_squared, axis=None)
        global_source_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_source_cost, axis=None)
        global_kappa_cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM, per_replica_kappa_cost, axis=None)
        return global_loss, global_chi_squared, global_source_cost, global_kappa_cost

    # ====== Training loop ============================================================================================
    epoch_loss = tf.metrics.Mean()
    time_per_step = tf.metrics.Mean()
    val_loss = tf.metrics.Mean()
    epoch_chi_squared = tf.metrics.Mean()
    epoch_source_loss = tf.metrics.Mean()
    epoch_kappa_loss = tf.metrics.Mean()
    val_chi_squared = tf.metrics.Mean()
    val_source_loss = tf.metrics.Mean()
    val_kappa_loss = tf.metrics.Mean()
    history = {  # recorded at the end of an epoch only
        "train_cost": [],
        "train_chi_squared": [],
        "train_source_cost": [],
        "train_kappa_cost": [],
        "val_cost": [],
        "val_chi_squared": [],
        "val_source_cost": [],
        "val_kappa_cost": [],
        "learning_rate": [],
        "time_per_step": [],
        "step": [],
        "wall_time": []
    }
    best_loss = np.inf
    patience = args.patience
    step = 0
    global_start = time.time()
    estimated_time_for_epoch = 0
    out_of_time = False
    latest_checkpoint = 1
    for epoch in range(args.epochs):
        if (time.time() - global_start) > args.max_time*3600 - estimated_time_for_epoch:
            break
        epoch_start = time.time()
        epoch_loss.reset_states()
        epoch_chi_squared.reset_states()
        epoch_source_loss.reset_states()
        epoch_kappa_loss.reset_states()
        time_per_step.reset_states()
        with writer.as_default():
            for batch, (X, source, kappa, noise_rms, psf) in enumerate(train_dataset):
                start = time.time()
                cost, chi_squared, source_cost, kappa_cost = distributed_train_step(X, source, kappa, noise_rms, psf)
        # ========== Summary and logs ==================================================================================
                _time = time.time() - start
                time_per_step.update_state([_time])
                epoch_loss.update_state([cost])
                epoch_chi_squared.update_state([chi_squared])
                epoch_source_loss.update_state([source_cost])
                epoch_kappa_loss.update_state([kappa_cost])
                step += 1
            # last batch we make a summary of residuals
            if args.n_residuals > 0:
                source_pred, kappa_pred, chi_squared = rim.predict(X, noise_rms, psf)
                lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
            for res_idx in range(min(args.n_residuals, args.batch_size)):
                try:
                    tf.summary.image(f"Residuals {res_idx}",
                                     plot_to_image(
                                         residual_plot(
                                             X[res_idx],
                                             source[res_idx],
                                             kappa[res_idx],
                                             lens_pred[res_idx],
                                             source_pred[-1][res_idx],
                                             kappa_pred[-1][res_idx],
                                             chi_squared[-1][res_idx]
                                         )), step=step)
                except ValueError:
                    continue

            # ========== Validation set ===================
            val_loss.reset_states()
            val_chi_squared.reset_states()
            val_source_loss.reset_states()
            val_kappa_loss.reset_states()
            for X, source, kappa, noise_rms, psf in val_dataset:
                cost, chi_squared, source_cost, kappa_cost = distributed_test_step(X, source, kappa, noise_rms, psf)
                val_loss.update_state([cost])
                val_chi_squared.update_state([chi_squared])
                val_source_loss.update_state([source_cost])
                val_kappa_loss.update_state([kappa_cost])

            if args.n_residuals > 0 and math.ceil((1 - args.train_split) * args.total_items) > 0:  # validation set not empty
                source_pred, kappa_pred, chi_squared = rim.predict(X, noise_rms, psf)
                lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
            for res_idx in range(min(args.n_residuals, args.batch_size, math.ceil((1 - args.train_split) * args.total_items))):
                try:
                    tf.summary.image(f"Val Residuals {res_idx}",
                                     plot_to_image(
                                         residual_plot(
                                             X[res_idx],  # rescale intensity like it is done in the likelihood
                                             source[res_idx],
                                             kappa[res_idx],
                                             lens_pred[res_idx],
                                             source_pred[-1][res_idx],
                                             kappa_pred[-1][res_idx],
                                             chi_squared[-1][res_idx]
                                         )), step=step)
                except ValueError:
                    continue
            val_cost = val_loss.result().numpy()
            train_cost = epoch_loss.result().numpy()
            val_chi_sq = val_chi_squared.result().numpy()
            train_chi_sq = epoch_chi_squared.result().numpy()
            val_kappa_cost = val_kappa_loss.result().numpy()
            train_kappa_cost = epoch_kappa_loss.result().numpy()
            val_source_cost = val_source_loss.result().numpy()
            train_source_cost = epoch_source_loss.result().numpy()
            tf.summary.scalar("Time per step", time_per_step.result(), step=step)
            tf.summary.scalar("Chi Squared", train_chi_sq, step=step)
            tf.summary.scalar("Kappa cost", train_kappa_cost, step=step)
            tf.summary.scalar("Val Kappa cost", val_kappa_cost, step=step)
            tf.summary.scalar("Source cost", train_source_cost, step=step)
            tf.summary.scalar("Val Source cost", val_source_cost, step=step)
            tf.summary.scalar("MSE", train_cost, step=step)
            tf.summary.scalar("Val MSE", val_cost, step=step)
            tf.summary.scalar("Learning Rate", optim.lr(step), step=step)
            tf.summary.scalar("Val Chi Squared", val_chi_sq, step=step)
        print(f"epoch {epoch} | train loss {train_cost:.3e} | val loss {val_cost:.3e} "
              f"| lr {optim.lr(step).numpy():.2e} | time per step {time_per_step.result().numpy():.2e} s"
              f"| kappa cost {train_kappa_cost:.2e} | source cost {train_source_cost:.2e} | chi sq {train_chi_sq:.2e}")
        history["train_cost"].append(train_cost)
        history["val_cost"].append(val_cost)
        history["learning_rate"].append(optim.lr(step).numpy())
        history["train_chi_squared"].append(train_chi_sq)
        history["val_chi_squared"].append(val_chi_sq)
        history["time_per_step"].append(time_per_step.result().numpy())
        history["train_kappa_cost"].append(train_kappa_cost)
        history["train_source_cost"].append(train_source_cost)
        history["val_kappa_cost"].append(val_kappa_cost)
        history["val_source_cost"].append(val_source_cost)
        history["step"].append(step)
        history["wall_time"].append(time.time() - global_start)

        cost = train_cost if args.track_train else val_cost
        if np.isnan(cost):
            print("Training broke the Universe")
            break
        if cost < (1 - args.tolerance) * best_loss:
            best_loss = cost
            patience = args.patience
        else:
            patience -= 1
        if (time.time() - global_start) > args.max_time * 3600:
            out_of_time = True
        if save_checkpoint:
            checkpoint_manager.checkpoint.step.assign_add(1) # a bit of a hack
            if epoch % args.checkpoints == 0 or patience == 0 or epoch == args.epochs - 1 or out_of_time:
                with open(os.path.join(checkpoints_dir, "score_sheet.txt"), mode="a") as f:
                    np.savetxt(f, np.array([[latest_checkpoint, cost]]))
                latest_checkpoint += 1
                checkpoint_manager.save()
                print("Saved checkpoint for step {}: {}".format(int(checkpoint_manager.checkpoint.step), checkpoint_manager.latest_checkpoint))
        if patience == 0:
            print("Reached patience")
            break
        if out_of_time:
            break
        if epoch > 0:  # First epoch is always very slow and not a good estimate of an epoch time.
            estimated_time_for_epoch = time.time() - epoch_start
        if optim.lr(step).numpy() < 1e-8:
            print("Reached learning rate limit")
            break
    print(f"Finished training after {(time.time() - global_start)/3600:.3f} hours.")
    return history, best_loss
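A quick numerical check (sketch) that each of the time-weight schemes defined in main() above normalizes to one over the RIM steps, so the weighted sum over time steps is a proper average:

import tensorflow as tf

steps = 8
uniform = tf.ones(shape=(steps,)) / steps
linear = 2 * (tf.range(steps, dtype=tf.float32) + 1) / steps / (steps + 1)
quadratic = 6 * (tf.range(steps, dtype=tf.float32) + 1)**2 / steps / (steps + 1) / (2 * steps + 1)
for name, w in [("uniform", uniform), ("linear", linear), ("quadratic", quadratic)]:
    print(name, float(tf.reduce_sum(w)))  # each sums to ~1.0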
Example No. 9
def main(args):
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type),
                               block_length=1,
                               num_parallel_calls=tf.data.AUTOTUNE)
    for physical_params in dataset.map(decode_physical_model_info):
        break
    dataset = dataset.map(decode_train)

    # files = glob.glob(os.path.join(args.source_dataset, "*.tfrecords"))
    # files = tf.data.Dataset.from_tensor_slices(files)
    # source_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
    #                            block_length=1, num_parallel_calls=tf.data.AUTOTUNE)
    # source_dataset = source_dataset.map(decode_image).map(preprocess_image).shuffle(10000).batch(args.sample_size)

    with open(os.path.join(args.kappa_vae, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, args.kappa_vae, 1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    with open(os.path.join(args.source_vae, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, args.source_vae, 1)
    checkpoint_manager2.checkpoint.restore(
        checkpoint_manager2.latest_checkpoint).expect_partial()

    phys = PhysicalModel(pixels=physical_params["pixels"].numpy(),
                         kappa_pixels=physical_params["kappa pixels"].numpy(),
                         src_pixels=physical_params["src pixels"].numpy(),
                         image_fov=physical_params["image fov"].numpy(),
                         kappa_fov=physical_params["kappa fov"].numpy(),
                         src_fov=physical_params["source fov"].numpy(),
                         method="fft")

    # simulate observations
    kappa = 10**kappa_vae.sample(args.sample_size)
    source = preprocess_image(source_vae.sample(args.sample_size))
    # for source in source_dataset:
    #     break
    fwhm = tf.random.normal(shape=[args.sample_size],
                            mean=1.5 * phys.image_fov / phys.pixels,
                            stddev=0.5 * phys.image_fov / phys.pixels)
    # noise_rms = tf.random.normal(shape=[args.sample_size], mean=args.noise_mean, stddev=args.noise_std)
    psf = phys.psf_models(fwhm, cutout_size=20)
    y_vae = phys.forward(source, kappa, psf)

    with h5py.File(
            os.path.join(os.getenv("CENSAI_PATH"), "results",
                         args.output_name + ".h5"), 'w') as hf:
        # rank these observations against the dataset with L2 norm
        for i in tqdm(range(args.sample_size)):
            distances = []
            for y_d, _, _, _, _ in dataset:
                distances.append(
                    tf.sqrt(tf.reduce_sum(
                        (y_d - y_vae[i][None, ...])**2)).numpy().astype(
                            np.float32))
            k_indices = np.argsort(distances)[:args.k]

            # save results
            g = hf.create_group(f"sample_{i:02d}")
            g.create_dataset(name="matched_source",
                             shape=[args.k, phys.src_pixels, phys.src_pixels],
                             dtype=np.float32)
            g.create_dataset(
                name="matched_kappa",
                shape=[args.k, phys.kappa_pixels, phys.kappa_pixels],
                dtype=np.float32)
            g.create_dataset(name="matched_obs",
                             shape=[args.k, phys.pixels, phys.pixels],
                             dtype=np.float32)
            g.create_dataset(name="matched_psf",
                             shape=[args.k, 20, 20],
                             dtype=np.float32)
            g.create_dataset(name="matched_noise_rms",
                             shape=[args.k],
                             dtype=np.float32)
            g.create_dataset(name="obs_L2_distance",
                             shape=[args.k],
                             dtype=np.float32)
            g["vae_source"] = source[i, ..., 0].numpy().astype(np.float32)
            g["vae_kappa"] = kappa[i, ..., 0].numpy().astype(np.float32)
            g["vae_obs"] = y_vae[i, ..., 0].numpy().astype(np.float32)
            g["vae_psf"] = psf[i, ..., 0].numpy().astype(np.float32)

            for rank, j in enumerate(k_indices):
                # fetch back the matched observation
                for y_d, source_d, kappa_d, noise_rms_d, psf_d in dataset.skip(
                        j):
                    break
                # g["vae_noise_rms"] = noise_rms[i].numpy().astype(np.float32)
                g["matched_source"][rank] = source_d[..., 0].numpy().astype(
                    np.float32)
                g["matched_kappa"][rank] = kappa_d[..., 0].numpy().astype(
                    np.float32)
                g["matched_obs"][rank] = y_d[..., 0].numpy().astype(np.float32)
                g["matched_noise_rms"][rank] = noise_rms_d.numpy().astype(
                    np.float32)
                g["matched_psf"][rank] = psf_d[...,
                                               0].numpy().astype(np.float32)
                g["obs_L2_distance"][rank] = distances[j]
Example No. 10
def test_deflection_angle_conv2():
    phys = PhysicalModel(pixels=64, src_pixels=64)
    kappa = tf.random.normal([1, 64, 64, 1])
    phys.deflection_angle(kappa)
Example No. 11
def distributed_strategy(args):
    kappa_datasets = []
    for path in args.kappa_datasets:
        files = glob.glob(os.path.join(path, "*.tfrecords"))
        files = tf.data.Dataset.from_tensor_slices(files).shuffle(
            len(files), reshuffle_each_iteration=True)
        dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
            x, compression_type=args.compression_type),
                                   block_length=args.block_length,
                                   num_parallel_calls=tf.data.AUTOTUNE)
        kappa_datasets.append(
            dataset.shuffle(args.buffer_size, reshuffle_each_iteration=True))
    kappa_dataset = tf.data.experimental.sample_from_datasets(
        kappa_datasets, weights=args.kappa_datasets_weights)
    # Read off global parameters from first example in dataset
    for example in kappa_dataset.map(decode_kappa_info):
        kappa_fov = example["kappa fov"].numpy()
        kappa_pixels = example["kappa pixels"].numpy()
        break
    kappa_dataset = kappa_dataset.map(decode_kappa).batch(args.batch_size)

    cosmos_datasets = []
    for path in args.cosmos_datasets:
        files = glob.glob(os.path.join(path, "*.tfrecords"))
        files = tf.data.Dataset.from_tensor_slices(files).shuffle(
            len(files), reshuffle_each_iteration=True)
        dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
            x, compression_type=args.compression_type),
                                   block_length=args.block_length,
                                   num_parallel_calls=tf.data.AUTOTUNE)
        cosmos_datasets.append(
            dataset.shuffle(args.buffer_size, reshuffle_each_iteration=True))
    cosmos_dataset = tf.data.experimental.sample_from_datasets(
        cosmos_datasets, weights=args.cosmos_datasets_weights)
    # Read off global parameters from first example in dataset
    for src_pixels in cosmos_dataset.map(decode_cosmos_info):
        src_pixels = src_pixels.numpy()
        break
    cosmos_dataset = cosmos_dataset.map(decode_cosmos).map(
        preprocess_cosmos).batch(args.batch_size)

    window = tukey(src_pixels, alpha=args.tukey_alpha)
    window = np.outer(window, window)[np.newaxis, ..., np.newaxis]
    window = tf.constant(window, dtype=DTYPE)

    phys = PhysicalModel(image_fov=kappa_fov,
                         src_fov=args.source_fov,
                         pixels=args.lens_pixels,
                         kappa_pixels=kappa_pixels,
                         src_pixels=src_pixels,
                         kappa_fov=kappa_fov,
                         method="conv2d")

    noise_a = (args.noise_rms_min - args.noise_rms_mean) / args.noise_rms_std
    noise_b = (args.noise_rms_max - args.noise_rms_mean) / args.noise_rms_std
    psf_a = (args.psf_fwhm_min - args.psf_fwhm_mean) / args.psf_fwhm_std
    psf_b = (args.psf_fwhm_max - args.psf_fwhm_mean) / args.psf_fwhm_std

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with tf.io.TFRecordWriter(
            os.path.join(args.output_dir, f"data_{THIS_WORKER}.tfrecords"),
            options) as writer:
        print(
            f"Started worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}"
        )
        for i in range((THIS_WORKER - 1) * args.batch_size, args.len_dataset,
                       N_WORKERS * args.batch_size):
            for galaxies in cosmos_dataset:  # select a random batch from our dataset, which is reshuffled each iteration
                break
            for kappa in kappa_dataset:
                break
            galaxies = window * galaxies
            noise_rms = truncnorm.rvs(noise_a,
                                      noise_b,
                                      loc=args.noise_rms_mean,
                                      scale=args.noise_rms_std,
                                      size=args.batch_size)
            fwhm = truncnorm.rvs(psf_a,
                                 psf_b,
                                 loc=args.psf_fwhm_mean,
                                 scale=args.psf_fwhm_std,
                                 size=args.batch_size)
            psf = phys.psf_models(fwhm, cutout_size=args.psf_cutout_size)
            lensed_images = phys.noisy_forward(galaxies,
                                               kappa,
                                               noise_rms=noise_rms,
                                               psf=psf)
            records = encode_examples(kappa=kappa,
                                      galaxies=galaxies,
                                      lensed_images=lensed_images,
                                      z_source=args.z_source,
                                      z_lens=args.z_lens,
                                      image_fov=phys.image_fov,
                                      kappa_fov=phys.kappa_fov,
                                      source_fov=args.source_fov,
                                      noise_rms=noise_rms,
                                      psf=psf,
                                      fwhm=fwhm)
            for record in records:
                writer.write(record)
    print(f"Finished work at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
Example No. 12
def test_analytic2():
    phys = AnalyticalPhysicalModelv2(pixels=64)
    # try to oversample kappa map for comparison
    # phys2 = AnalyticalPhysicalModelv2(pixels=256)
    # phys_ref = PhysicalModel(pixels=256, src_pixels=64, method="fft", kappa_fov=7.68)
    phys_ref = PhysicalModel(pixels=64, method="fft")

    # make some dummy coordinates
    x = np.linspace(-1, 1, 64) * 3.0 / 2
    xx, yy = np.meshgrid(x, x)

    # lens params
    r_ein = 2
    x0 = 0.1
    y0 = 0.1
    q = 0.98
    phi = 0.  #np.pi/3

    gamma = 0.
    phi_gamma = 0.

    # source params
    xs = 0.1
    ys = 0.1
    qs = 0.9
    phi_s = 0.  #np.pi/4
    r_eff = 0.4
    n = 1.

    source = phys.sersic_source(xx, yy, xs, ys, qs, phi_s, n, r_eff)[None, ...,
                                                                     None]
    # kappa = phys2.kappa_field(r_ein, q, phi, x0, y0)
    kappa = phys.kappa_field(r_ein, q, phi, x0, y0)
    lens_a = phys.lens_source_sersic_func(r_ein, q, phi, x0, y0, gamma,
                                          phi_gamma, xs, ys, qs, phi_s, n,
                                          r_eff)
    lens_b = phys_ref.lens_source(source, kappa)
    # lens_b = tf.image.resize(lens_b, [64, 64])
    lens_c = phys.lens_source(source, r_ein, q, phi, x0, y0, gamma, phi_gamma)
    # alpha1t, alpha2t = tf.split(phys.analytical_deflection_angles(r_ein, q, phi, x0, y0), 2, axis=-1)
    alpha1t, alpha2t = tf.split(phys.approximate_deflection_angles(
        r_ein, q, phi, x0, y0),
                                2,
                                axis=-1)
    alpha1, alpha2 = phys_ref.deflection_angle(kappa)
    # alpha1 = tf.image.resize(alpha1, [64, 64])
    # alpha2 = tf.image.resize(alpha2, [64, 64])
    beta1 = phys.theta1 - alpha1
    beta2 = phys.theta2 - alpha2
    lens_d = phys.sersic_source(beta1, beta2, xs, ys, qs, phi_s, n, r_eff)

    # plt.imshow(source[0, ..., 0], cmap="twilight", origin="lower")
    # plt.colorbar()
    # plt.show()
    # plt.imshow(np.log10(kappa[0, ..., 0]), cmap="twilight", origin="lower")
    # plt.colorbar()
    # plt.show()
    plt.imshow(lens_a[0, ..., 0], cmap="twilight", origin="lower")
    plt.colorbar()
    plt.show()
    plt.imshow(lens_b[0, ..., 0], cmap="twilight", origin="lower")
    plt.colorbar()
    plt.show()
    # plt.imshow(lens_c[0, ..., 0], cmap="twilight", origin="lower")
    # plt.colorbar()
    # plt.show()
    # plt.imshow(lens_d[0, ..., 0], cmap="twilight", origin="lower")
    # plt.colorbar()
    # plt.show()
    plt.imshow(((lens_a - lens_b))[0, ..., 0],
               cmap="seismic",
               vmin=-0.1,
               vmax=0.1,
               origin="lower")
    plt.colorbar()
    plt.show()
    # plt.imshow(((lens_a - lens_c))[0, ..., 0], cmap="seismic", vmin=-0.1, vmax=0.1, origin="lower")
    # plt.colorbar()
    # plt.show()
    # plt.imshow(((lens_a - lens_d))[0, ..., 0], cmap="seismic", vmin=-0.1, vmax=0.1, origin="lower")
    # plt.colorbar()
    # plt.show()
    plt.imshow(((alpha1t - alpha1))[0, ..., 0],
               cmap="seismic",
               vmin=-1,
               vmax=1,
               origin="lower")
    plt.colorbar()
    plt.show()

    pass
Example No. 13
def distributed_strategy(args):
    psf_pixels = 20
    pixels = 128
    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)

    ps_observation = PowerSpectrum(bins=args.observation_coherence_bins,
                                   pixels=pixels)
    ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=pixels)
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=pixels)

    phys = PhysicalModel(
        pixels=pixels,
        kappa_pixels=pixels,
        src_pixels=pixels,
        image_fov=7.68,
        kappa_fov=7.68,
        src_fov=3.,
        method="fft",
    )

    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(
        checkpoint_manager.latest_checkpoint).expect_partial()
    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim_params["source_link"] = "relu"
    rim = RIM(phys, unet, **rim_params)

    kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.kappa_vae)
    with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.source_vae)
    with open(os.path.join(svae_path, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1)
    checkpoint_manager2.checkpoint.restore(
        checkpoint_manager2.latest_checkpoint).expect_partial()

    model_name = os.path.split(model)[-1]
    wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(
        tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
    with h5py.File(
            os.path.join(
                os.getenv("CENSAI_PATH"), "results", args.experiment_name +
                "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf:
        data_len = args.size // N_WORKERS
        hf.create_dataset(name="observation",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf",
                          shape=[data_len, psf_pixels, psf_pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        hf.create_dataset(
            name="source",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="kappa",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="observation_pred",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="observation_pred_reoptimized",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(
            name="source_pred",
            shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="source_pred_reoptimized",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="kappa_pred",
                          shape=[
                              data_len, rim.steps, phys.kappa_pixels,
                              phys.kappa_pixels, 1
                          ],
                          dtype=np.float32)
        hf.create_dataset(
            name="kappa_pred_reoptimized",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="chi_squared",
                          shape=[data_len, rim.steps],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized",
                          shape=[data_len, rim.steps],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized_series",
                          shape=[data_len, rim.steps, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="sampled_chi_squared_reoptimized_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="sampled_source_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="sampled_kappa_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_init",
                          shape=[data_len, kappa_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_init",
                          shape=[data_len, source_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_end",
                          shape=[data_len, kappa_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_end",
                          shape=[data_len, source_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum_reoptimized",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum_reoptimized",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum_reoptimized",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_frequencies",
                          shape=[args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_frequencies",
                          shape=[args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_frequencies",
                          shape=[args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32)
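        # Main loop: for each example, sample a (source, kappa) pair from the VAEs,
        # simulate a noisy observation, run the RIM, then fine-tune the unet on
        # VAE samples drawn in a latent ball around the RIM's own predictions.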
        for i in range(data_len):
            checkpoint_manager.checkpoint.restore(
                checkpoint_manager.latest_checkpoint).expect_partial(
                )  # reset model weights

            # Produce an observation
            kappa = 10**kappa_vae.sample(1)
            source = tf.nn.relu(source_vae.sample(1))
            source /= tf.reduce_max(source, axis=(1, 2, 3), keepdims=True)
            noise_rms = 10**tf.random.uniform(shape=[1],
                                              minval=-2.5,
                                              maxval=-1)
            fwhm = tf.random.uniform(shape=[1], minval=0.06, maxval=0.3)
            psf = phys.psf_models(fwhm, cutout_size=psf_pixels)
            observation = phys.noisy_forward(source, kappa, noise_rms, psf)

            # RIM predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(
                observation, noise_rms, psf)
            observation_pred = phys.forward(source_pred[-1], kappa_pred[-1],
                                            psf)
            source_o = source_pred[-1]
            kappa_o = kappa_pred[-1]

            # Latent code of model predictions
            z_source, _ = source_vae.encoder(source_o)
            z_kappa, _ = kappa_vae.encoder(log_10(kappa_o))

            # Ground truth latent code for oracle metrics
            z_source_gt, _ = source_vae.encoder(source)
            z_kappa_gt, _ = kappa_vae.encoder(log_10(kappa))

            # Re-optimize weights of the model
            STEPS = args.re_optimize_steps
            learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=args.learning_rate,
                decay_rate=args.decay_rate,
                decay_steps=args.decay_steps,
                staircase=args.staircase)
            optim = tf.keras.optimizers.RMSprop(
                learning_rate=learning_rate_schedule)

            chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            source_mse = tf.TensorArray(DTYPE, size=STEPS)
            kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            sampled_source_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_kappa_mse = tf.TensorArray(DTYPE, size=STEPS)

            best = chi_squared
            source_best = source_pred[-1]
            kappa_best = kappa_pred[-1]
            source_mse_best = tf.reduce_mean((source_best - source)**2)
            # compare kappa in log10 space, consistent with the kappa_mse series below
            kappa_mse_best = tf.reduce_mean(
                (log_10(kappa_best) - log_10(kappa))**2)

            # ===================== Optimization ==============================
            for current_step in tqdm(range(STEPS)):
                # ===================== VAE SAMPLING ==============================

                # L1 distance with the ground truth in latent space -- replaced by a user-defined value when using real data
                # z_source_std = tf.abs(z_source - z_source_gt)
                # z_kappa_std = tf.abs(z_kappa - z_kappa_gt)
                z_source_std = args.source_vae_ball_size
                z_kappa_std = args.kappa_vae_ball_size

                # Sample latent code, then decode and forward
                z_s = tf.random.normal(
                    shape=[args.sample_size, source_vae.latent_size],
                    mean=z_source,
                    stddev=z_source_std)
                z_k = tf.random.normal(
                    shape=[args.sample_size, kappa_vae.latent_size],
                    mean=z_kappa,
                    stddev=z_kappa_std)
                sampled_source = tf.nn.relu(source_vae.decode(z_s))
                sampled_source /= tf.reduce_max(sampled_source,
                                                axis=(1, 2, 3),
                                                keepdims=True)
                sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
                sampled_observation = phys.noisy_forward(
                    sampled_source, 10**sampled_kappa, noise_rms,
                    tf.tile(psf, [args.sample_size, 1, 1, 1]))
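                # Fine-tuning step: run the RIM on the sampled observations and
                # minimize a convergence-weighted error on the sampled kappa maps
                # (decoded in log10 space), an MSE on the sampled sources, and the
                # unet's weight-decay losses.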
                with tf.GradientTape() as tape:
                    tape.watch(unet.trainable_variables)
                    s, k, chi_sq = rim.call(
                        sampled_observation,
                        noise_rms,
                        tf.tile(psf, [args.sample_size, 1, 1, 1]),
                        outer_tape=tape)
                    _kappa_mse = tf.reduce_sum(wk(10**sampled_kappa) *
                                               (k - sampled_kappa)**2,
                                               axis=(2, 3, 4))
                    cost = tf.reduce_mean(_kappa_mse)
                    cost += tf.reduce_mean((s - sampled_source)**2)
                    cost += tf.reduce_sum(rim.unet.losses)  # weight decay

                grads = tape.gradient(cost, unet.trainable_variables)
                optim.apply_gradients(zip(grads, unet.trainable_variables))

                # Record performance on sampled dataset
                sampled_chi_squared_series = sampled_chi_squared_series.write(
                    index=current_step,
                    value=tf.squeeze(tf.reduce_mean(chi_sq[-1])))
                sampled_source_mse = sampled_source_mse.write(
                    index=current_step,
                    value=tf.reduce_mean((s[-1] - sampled_source)**2))
                sampled_kappa_mse = sampled_kappa_mse.write(
                    index=current_step,
                    value=tf.reduce_mean((k[-1] - sampled_kappa)**2))
                # Record model prediction on data
                s, k, chi_sq = rim.call(observation, noise_rms, psf)
                chi_squared_series = chi_squared_series.write(
                    index=current_step, value=tf.squeeze(chi_sq))
                source_o = s[-1]
                kappa_o = k[-1]
                # oracle metrics, remove when using real data
                source_mse = source_mse.write(index=current_step,
                                              value=tf.reduce_mean(
                                                  (source_o - source)**2))
                kappa_mse = kappa_mse.write(index=current_step,
                                            value=tf.reduce_mean(
                                                (kappa_o - log_10(kappa))**2))

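                # Keep the prediction whose final chi-squared is closest to 1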
                if abs(chi_sq[-1, 0] - 1) < abs(best[-1, 0] - 1):
                    source_best = tf.nn.relu(source_o)
                    kappa_best = 10**kappa_o
                    best = chi_sq
                    source_mse_best = tf.reduce_mean((source_best - source)**2)
                    kappa_mse_best = tf.reduce_mean(
                        (log_10(kappa_best) - log_10(kappa))**2)

            source_o = source_best
            kappa_o = kappa_best
            y_pred = phys.forward(source_o, kappa_o, psf)

            chi_sq_series = tf.transpose(chi_squared_series.stack())
            source_mse = source_mse.stack()
            kappa_mse = kappa_mse.stack()
            sampled_chi_squared_series = sampled_chi_squared_series.stack()
            sampled_source_mse = sampled_source_mse.stack()
            sampled_kappa_mse = sampled_kappa_mse.stack()

            # Latent code of optimized model predictions
            z_source_opt, _ = source_vae.encoder(tf.nn.relu(source_o))
            z_kappa_opt, _ = kappa_vae.encoder(log_10(kappa_o))

            # Compute Power spectrum of converged predictions
            _ps_observation = ps_observation.cross_correlation_coefficient(
                observation[..., 0], observation_pred[..., 0])
            _ps_observation2 = ps_observation.cross_correlation_coefficient(
                observation[..., 0], y_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0],
                log_10(kappa_pred[-1])[..., 0])
            _ps_kappa2 = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
            _ps_source = ps_source.cross_correlation_coefficient(
                source[..., 0], source_pred[-1][..., 0])
            _ps_source2 = ps_source.cross_correlation_coefficient(
                source[..., 0], source_o[..., 0])

            # save results
            hf["observation"][i] = observation.numpy().astype(np.float32)
            hf["psf"][i] = psf.numpy().astype(np.float32)
            hf["psf_fwhm"][i] = fwhm.numpy().astype(np.float32)
            hf["noise_rms"][i] = noise_rms.numpy().astype(np.float32)
            hf["source"][i] = source.numpy().astype(np.float32)
            hf["kappa"][i] = kappa.numpy().astype(np.float32)
            hf["observation_pred"][i] = observation_pred.numpy().astype(
                np.float32)
            hf["observation_pred_reoptimized"][i] = y_pred.numpy().astype(
                np.float32)
            hf["source_pred"][i] = tf.transpose(
                source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["source_pred_reoptimized"][i] = source_o.numpy().astype(
                np.float32)
            hf["kappa_pred"][i] = tf.transpose(
                kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["kappa_pred_reoptimized"][i] = kappa_o.numpy().astype(
                np.float32)
            hf["chi_squared"][i] = tf.squeeze(chi_squared).numpy().astype(
                np.float32)
            hf["chi_squared_reoptimized"][i] = tf.squeeze(best).numpy().astype(
                np.float32)
            hf["chi_squared_reoptimized_series"][i] = chi_sq_series.numpy(
            ).astype(np.float32)
            hf["sampled_chi_squared_reoptimized_series"][
                i] = 2 * sampled_chi_squared_series.numpy().astype(np.float32)
            hf["source_optim_mse"][i] = source_mse_best.numpy().astype(
                np.float32)
            hf["source_optim_mse_series"][i] = source_mse.numpy().astype(
                np.float32)
            hf["sampled_source_optim_mse_series"][
                i] = sampled_source_mse.numpy().astype(np.float32)
            hf["kappa_optim_mse"][i] = kappa_mse_best.numpy().astype(
                np.float32)
            hf["kappa_optim_mse_series"][i] = kappa_mse.numpy().astype(
                np.float32)
            hf["sampled_kappa_optim_mse_series"][i] = sampled_kappa_mse.numpy(
            ).astype(np.float32)
            hf["latent_source_gt_distance_init"][i] = tf.abs(
                z_source - z_source_gt).numpy().squeeze().astype(np.float32)
            hf["latent_kappa_gt_distance_init"][i] = tf.abs(
                z_kappa - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["latent_source_gt_distance_end"][i] = tf.abs(
                z_source_opt - z_source_gt).numpy().squeeze().astype(
                    np.float32)
            hf["latent_kappa_gt_distance_end"][i] = tf.abs(
                z_kappa_opt - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["observation_coherence_spectrum"][i] = _ps_observation
            hf["observation_coherence_spectrum_reoptimized"][
                i] = _ps_observation2
            hf["source_coherence_spectrum"][i] = _ps_source
            hf["source_coherence_spectrum_reoptimized"][i] = _ps_source2
            hf["kappa_coherence_spectrum"][i] = _ps_kappa
            hf["kappa_coherence_spectrum_reoptimized"][i] = _ps_kappa2

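            # Frequency bin centers and fields of view are written only once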
            if i == 0:
                _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels //
                                                                2],
                                    bins=ps_observation.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["observation_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.src_pixels)[:phys.src_pixels // 2],
                                    bins=ps_source.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["source_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.kappa_pixels)[:phys.kappa_pixels // 2],
                                    bins=ps_kappa.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["kappa_frequencies"][:] = f
                hf["kappa_fov"][0] = phys.kappa_fov
                hf["source_fov"][0] = phys.src_fov
Ejemplo n.º 14
0
def distributed_strategy(args):
    tf.random.set_seed(args.seed)
    np.random.seed(args.seed)

    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)
    files = glob.glob(
        os.path.join(os.getenv('CENSAI_PATH'), "data", args.train_dataset,
                     "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    train_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type).shuffle(len(files)),
                                     block_length=1,
                                     num_parallel_calls=tf.data.AUTOTUNE)
    # Read off global parameters from first example in dataset
    for physical_params in train_dataset.map(decode_physical_model_info):
        break
    train_dataset = train_dataset.map(decode_results).shuffle(
        buffer_size=args.buffer_size)

    files = glob.glob(
        os.path.join(os.getenv('CENSAI_PATH'), "data", args.val_dataset,
                     "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    val_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type).shuffle(len(files)),
                                   block_length=1,
                                   num_parallel_calls=tf.data.AUTOTUNE)
    val_dataset = val_dataset.map(decode_results).shuffle(
        buffer_size=args.buffer_size)

    files = glob.glob(
        os.path.join(os.getenv('CENSAI_PATH'), "data", args.test_dataset,
                     "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    test_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type).shuffle(len(files)),
                                    block_length=1,
                                    num_parallel_calls=tf.data.AUTOTUNE)
    test_dataset = test_dataset.map(decode_results).shuffle(
        buffer_size=args.buffer_size)

    ps_lens = PowerSpectrum(bins=args.lens_coherence_bins,
                            pixels=physical_params["pixels"].numpy())
    ps_source = PowerSpectrum(bins=args.source_coherence_bins,
                              pixels=physical_params["src pixels"].numpy())
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins,
                             pixels=physical_params["kappa pixels"].numpy())

    phys = PhysicalModel(
        pixels=physical_params["pixels"].numpy(),
        kappa_pixels=physical_params["kappa pixels"].numpy(),
        src_pixels=physical_params["src pixels"].numpy(),
        image_fov=physical_params["image fov"].numpy(),
        kappa_fov=physical_params["kappa fov"].numpy(),
        src_fov=physical_params["source fov"].numpy(),
        method="fft",
    )

    phys_sie = AnalyticalPhysicalModel(
        pixels=physical_params["pixels"].numpy(),
        image_fov=physical_params["image fov"].numpy(),
        src_fov=physical_params["source fov"].numpy())

    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(
        checkpoint_manager.latest_checkpoint).expect_partial()
    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim = RIM(phys, unet, **rim_params)

    dataset_names = [args.train_dataset, args.val_dataset, args.test_dataset]
    dataset_shapes = [args.train_size, args.val_size, args.test_size]
    model_name = os.path.split(model)[-1]

    # from censai.utils import nulltape
    # def call_with_mask(self, lensed_image, noise_rms, psf, mask, outer_tape=nulltape):
    #     """
    #     Used in training. Return linked kappa and source maps.
    #     """
    #     batch_size = lensed_image.shape[0]
    #     source, kappa, source_grad, kappa_grad, states = self.initial_states(batch_size)  # initiate all tensors to 0
    #     source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad,
    #                                            states)  # Use lens to make an initial guess with Unet
    #     source_series = tf.TensorArray(DTYPE, size=self.steps)
    #     kappa_series = tf.TensorArray(DTYPE, size=self.steps)
    #     chi_squared_series = tf.TensorArray(DTYPE, size=self.steps)
    #     # record initial guess
    #     source_series = source_series.write(index=0, value=source)
    #     kappa_series = kappa_series.write(index=0, value=kappa)
    #     # Main optimization loop
    #     for current_step in tf.range(self.steps - 1):
    #         with outer_tape.stop_recording():
    #             with tf.GradientTape() as g:
    #                 g.watch(source)
    #                 g.watch(kappa)
    #                 y_pred = self.physical_model.forward(self.source_link(source), self.kappa_link(kappa), psf)
    #                 flux_term = tf.square(
    #                     tf.reduce_sum(y_pred, axis=(1, 2, 3)) - tf.reduce_sum(lensed_image, axis=(1, 2, 3)))
    #                 log_likelihood = 0.5 * tf.reduce_sum(
    #                     tf.square(y_pred - mask * lensed_image) / noise_rms[:, None, None, None] ** 2, axis=(1, 2, 3))
    #                 cost = tf.reduce_mean(log_likelihood + self.flux_lagrange_multiplier * flux_term)
    #             source_grad, kappa_grad = g.gradient(cost, [source, kappa])
    #             source_grad, kappa_grad = self.grad_update(source_grad, kappa_grad, current_step)
    #         source, kappa, states = self.time_step(lensed_image, source, kappa, source_grad, kappa_grad, states)
    #         source_series = source_series.write(index=current_step + 1, value=source)
    #         kappa_series = kappa_series.write(index=current_step + 1, value=kappa)
    #         chi_squared_series = chi_squared_series.write(index=current_step,
    #                                                       value=log_likelihood / self.pixels ** 2)  # renormalize chi squared here
    #     # last step score
    #     log_likelihood = self.physical_model.log_likelihood(y_true=lensed_image, source=self.source_link(source),
    #                                                         kappa=self.kappa_link(kappa), psf=psf, noise_rms=noise_rms)
    #     chi_squared_series = chi_squared_series.write(index=self.steps - 1, value=log_likelihood)
    #     return source_series.stack(), kappa_series.stack(), chi_squared_series.stack()

    with h5py.File(
            os.path.join(
                os.getenv("CENSAI_PATH"), "results", args.experiment_name +
                "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf:
        for i, dataset in enumerate([train_dataset, val_dataset,
                                     test_dataset]):
            g = hf.create_group(dataset_names[i])
            data_len = dataset_shapes[i] // N_WORKERS
            g.create_dataset(name="lens",
                             shape=[data_len, phys.pixels, phys.pixels, 1],
                             dtype=np.float32)
            g.create_dataset(name="psf",
                             shape=[
                                 data_len, physical_params['psf pixels'],
                                 physical_params['psf pixels'], 1
                             ],
                             dtype=np.float32)
            g.create_dataset(name="psf_fwhm",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(name="noise_rms",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(
                name="source",
                shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
                dtype=np.float32)
            g.create_dataset(
                name="kappa",
                shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
                dtype=np.float32)
            g.create_dataset(name="lens_pred",
                             shape=[data_len, phys.pixels, phys.pixels, 1],
                             dtype=np.float32)
            g.create_dataset(name="lens_pred_reoptimized",
                             shape=[data_len, phys.pixels, phys.pixels, 1],
                             dtype=np.float32)
            g.create_dataset(name="source_pred",
                             shape=[
                                 data_len, rim.steps, phys.src_pixels,
                                 phys.src_pixels, 1
                             ],
                             dtype=np.float32)
            g.create_dataset(
                name="source_pred_reoptimized",
                shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
                dtype=np.float32)
            g.create_dataset(name="kappa_pred",
                             shape=[
                                 data_len, rim.steps, phys.kappa_pixels,
                                 phys.kappa_pixels, 1
                             ],
                             dtype=np.float32)
            g.create_dataset(
                name="kappa_pred_reoptimized",
                shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
                dtype=np.float32)
            g.create_dataset(name="chi_squared",
                             shape=[data_len, rim.steps],
                             dtype=np.float32)
            g.create_dataset(name="chi_squared_reoptimized",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(name="chi_squared_reoptimized_series",
                             shape=[data_len, args.re_optimize_steps],
                             dtype=np.float32)
            g.create_dataset(name="source_optim_mse",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(name="source_optim_mse_series",
                             shape=[data_len, args.re_optimize_steps],
                             dtype=np.float32)
            g.create_dataset(name="kappa_optim_mse",
                             shape=[data_len],
                             dtype=np.float32)
            g.create_dataset(name="kappa_optim_mse_series",
                             shape=[data_len, args.re_optimize_steps],
                             dtype=np.float32)
            g.create_dataset(name="lens_coherence_spectrum",
                             shape=[data_len, args.lens_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="source_coherence_spectrum",
                             shape=[data_len, args.source_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="lens_coherence_spectrum2",
                             shape=[data_len, args.lens_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="lens_coherence_spectrum_reoptimized",
                             shape=[data_len, args.lens_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="source_coherence_spectrum2",
                             shape=[data_len, args.source_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="source_coherence_spectrum_reoptimized",
                             shape=[data_len, args.source_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="kappa_coherence_spectrum",
                             shape=[data_len, args.kappa_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="kappa_coherence_spectrum_reoptimized",
                             shape=[data_len, args.kappa_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="lens_frequencies",
                             shape=[args.lens_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="source_frequencies",
                             shape=[args.source_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="kappa_frequencies",
                             shape=[args.kappa_coherence_bins],
                             dtype=np.float32)
            g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
            g.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
            g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32)
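            # Each worker processes a disjoint, contiguous slice of this split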
            dataset = dataset.skip(data_len * (THIS_WORKER - 1)).take(data_len)
            for batch, (lens, source, kappa, noise_rms, psf,
                        fwhm) in enumerate(
                            dataset.batch(1).prefetch(
                                tf.data.experimental.AUTOTUNE)):
                checkpoint_manager.checkpoint.restore(
                    checkpoint_manager.latest_checkpoint).expect_partial(
                    )  # reset model weights
                # Compute predictions for kappa and source
                source_pred, kappa_pred, chi_squared = rim.predict(
                    lens, noise_rms, psf)
                lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
                # Re-optimize weights of the model
                STEPS = args.re_optimize_steps
                learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                    initial_learning_rate=args.learning_rate,
                    decay_rate=args.decay_rate,
                    decay_steps=args.decay_steps,
                    staircase=args.staircase)
                optim = tf.keras.optimizers.RMSprop(
                    learning_rate=learning_rate_schedule)

                chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
                source_mse = tf.TensorArray(DTYPE, size=STEPS)
                kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
                best = chi_squared[-1, 0]
                # best = abs(2*chi_squared[-1, 0] - 1)
                # best_chisq = 2*chi_squared[-1, 0]
                source_best = source_pred[-1]
                kappa_best = kappa_pred[-1]
                # initialize best-fit metrics so they are defined even if the
                # re-optimization loop never improves on the initial prediction
                source_mse_best = tf.reduce_mean(
                    (source_best - rim.source_inverse_link(source))**2)
                kappa_mse_best = tf.reduce_mean(
                    (kappa_best - rim.kappa_inverse_link(kappa))**2)
                # source_mean = source_pred[-1]
                # kappa_mean = rim.kappa_link(kappa_pred[-1])
                # source_std = tf.zeros_like(source_mean)
                # kappa_std = tf.zeros_like(kappa_mean)
                # counter = 0
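                # Test-time optimization: fine-tune the unet on this single lens by
                # minimizing the mean chi-squared over RIM steps plus weight decay,
                # keeping the lowest-chi-squared prediction and stopping early once
                # chi-squared drops below args.converged_chisq.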
                for current_step in tqdm(range(STEPS)):
                    with tf.GradientTape() as tape:
                        tape.watch(unet.trainable_variables)
                        # s, k, chi_sq = call_with_mask(rim, lens, noise_rms, psf, mask, tape)
                        s, k, chi_sq = rim.call(lens,
                                                noise_rms,
                                                psf,
                                                outer_tape=tape)
                        cost = tf.reduce_mean(chi_sq)  # mean over time steps
                        cost += tf.reduce_sum(rim.unet.losses)

                    log_likelihood = chi_sq[-1]
                    chi_squared_series = chi_squared_series.write(
                        index=current_step, value=log_likelihood)
                    source_o = s[-1]
                    kappa_o = k[-1]
                    source_mse = source_mse.write(
                        index=current_step,
                        value=tf.reduce_mean(
                            (source_o - rim.source_inverse_link(source))**2))
                    kappa_mse = kappa_mse.write(
                        index=current_step,
                        value=tf.reduce_mean(
                            (kappa_o - rim.kappa_inverse_link(kappa))**2))
                    if chi_sq[-1, 0] < args.converged_chisq:
                        source_best = rim.source_link(source_o)
                        kappa_best = rim.kappa_link(kappa_o)
                        best = chi_sq[-1, 0]
                        source_mse_best = tf.reduce_mean(
                            (source_best - rim.source_inverse_link(source))**2)
                        kappa_mse_best = tf.reduce_mean(
                            (kappa_best - rim.kappa_inverse_link(kappa))**2)
                        break
                    if chi_sq[-1, 0] < best:
                        source_best = rim.source_link(source_o)
                        kappa_best = rim.kappa_link(kappa_o)
                        best = chi_sq[-1, 0]
                        source_mse_best = tf.reduce_mean(
                            (source_best - rim.source_inverse_link(source))**2)
                        kappa_mse_best = tf.reduce_mean(
                            (kappa_best - rim.kappa_inverse_link(kappa))**2)
                    # if counter > 0:
                    #     # Welford's online algorithm
                    #     # source
                    #     delta = source_o - source_mean
                    #     source_mean = (counter * source_mean + (counter + 1) * source_o)/(counter + 1)
                    #     delta2 = source_o - source_mean
                    #     source_std += delta * delta2
                    #     # kappa
                    #     delta = rim.kappa_link(kappa_o) - kappa_mean
                    #     kappa_mean = (counter * kappa_mean + (counter + 1) * rim.kappa_link(kappa_o)) / (counter + 1)
                    #     delta2 = rim.kappa_link(kappa_o) - kappa_mean
                    #     kappa_std += delta * delta2
                    # if best_chisq < args.converged_chisq:
                    #     counter += 1
                    #     if counter == args.window:
                    #         break
                    # if 2*chi_sq[-1, 0] < best_chisq:
                    #     best_chisq = 2*chi_sq[-1, 0]
                    # if abs(2*chi_sq[-1, 0] - 1) < best:
                    #     source_best = rim.source_link(source_o)
                    #     kappa_best = rim.kappa_link(kappa_o)
                    #     best = abs(2 * chi_squared[-1, 0] - 1)
                    #     source_mse_best = tf.reduce_mean((source_best - rim.source_inverse_link(source)) ** 2)
                    #     kappa_mse_best = tf.reduce_mean((kappa_best - rim.kappa_inverse_link(kappa)) ** 2)

                    grads = tape.gradient(cost, unet.trainable_variables)
                    optim.apply_gradients(zip(grads, unet.trainable_variables))

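                # Use the best snapshot found during re-optimization as the final prediction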
                source_o = source_best
                kappa_o = kappa_best
                y_pred = phys.forward(source_o, kappa_o, psf)
                chi_sq_series = tf.transpose(chi_squared_series.stack(),
                                             perm=[1, 0])
                source_mse = source_mse.stack()[None, ...]
                kappa_mse = kappa_mse.stack()[None, ...]
                # kappa_std /= float(args.window)
                # source_std /= float(args.window)

                # Compute Power spectrum of converged predictions
                _ps_lens = ps_lens.cross_correlation_coefficient(
                    lens[..., 0], lens_pred[..., 0])
                _ps_lens3 = ps_lens.cross_correlation_coefficient(
                    lens[..., 0], y_pred[..., 0])
                _ps_kappa = ps_kappa.cross_correlation_coefficient(
                    log_10(kappa)[..., 0],
                    log_10(kappa_pred[-1])[..., 0])
                _ps_kappa2 = ps_kappa.cross_correlation_coefficient(
                    log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
                _ps_source = ps_source.cross_correlation_coefficient(
                    source[..., 0], source_pred[-1][..., 0])
                _ps_source3 = ps_source.cross_correlation_coefficient(
                    source[..., 0], source_o[..., 0])

                # save results
                g["lens"][batch] = lens.numpy().astype(np.float32)
                g["psf"][batch] = psf.numpy().astype(np.float32)
                g["psf_fwhm"][batch] = fwhm.numpy().astype(np.float32)
                g["noise_rms"][batch] = noise_rms.numpy().astype(np.float32)
                g["source"][batch] = source.numpy().astype(np.float32)
                g["kappa"][batch] = kappa.numpy().astype(np.float32)
                g["lens_pred"][batch] = lens_pred.numpy().astype(np.float32)
                g["lens_pred_reoptimized"][batch] = y_pred.numpy().astype(
                    np.float32)
                g["source_pred"][batch] = tf.transpose(
                    source_pred,
                    perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
                g["source_pred_reoptimized"][batch] = source_o.numpy().astype(
                    np.float32)
                g["kappa_pred"][batch] = tf.transpose(
                    kappa_pred,
                    perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
                g["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype(
                    np.float32)
                g["chi_squared"][batch] = tf.transpose(
                    chi_squared).numpy().astype(np.float32)
                g["chi_squared_reoptimized"][batch] = best.numpy().astype(
                    np.float32)
                g["chi_squared_reoptimized_series"][
                    batch] = chi_sq_series.numpy().astype(np.float32)
                g["source_optim_mse"][batch] = source_mse_best.numpy().astype(
                    np.float32)
                g["source_optim_mse_series"][batch] = source_mse.numpy(
                ).astype(np.float32)
                g["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype(
                    np.float32)
                g["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype(
                    np.float32)
                g["lens_coherence_spectrum"][batch] = _ps_lens
                g["lens_coherence_spectrum_reoptimized"][batch] = _ps_lens3
                g["source_coherence_spectrum"][batch] = _ps_source
                g["source_coherence_spectrum_reoptimized"][batch] = _ps_source3
                g["kappa_coherence_spectrum"][batch] = _ps_kappa
                g["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2

                if batch == 0:
                    _, f = np.histogram(np.fft.fftfreq(
                        phys.pixels)[:phys.pixels // 2],
                                        bins=ps_lens.bins)
                    f = (f[:-1] + f[1:]) / 2
                    g["lens_frequencies"][:] = f
                    _, f = np.histogram(np.fft.fftfreq(
                        phys.src_pixels)[:phys.src_pixels // 2],
                                        bins=ps_source.bins)
                    f = (f[:-1] + f[1:]) / 2
                    g["source_frequencies"][:] = f
                    _, f = np.histogram(np.fft.fftfreq(
                        phys.kappa_pixels)[:phys.kappa_pixels // 2],
                                        bins=ps_kappa.bins)
                    f = (f[:-1] + f[1:]) / 2
                    g["kappa_frequencies"][:] = f
                    g["kappa_fov"][0] = phys.kappa_fov
                    g["source_fov"][0] = phys.src_fov

        # Create SIE test
        g = hf.create_group('SIE_test')
        data_len = args.sie_size // N_WORKERS
        sie_dataset = test_dataset.skip(data_len *
                                        (THIS_WORKER - 1)).take(data_len)
        g.create_dataset(name="lens",
                         shape=[data_len, phys.pixels, phys.pixels, 1],
                         dtype=np.float32)
        g.create_dataset(name="psf",
                         shape=[
                             data_len, physical_params['psf pixels'],
                             physical_params['psf pixels'], 1
                         ],
                         dtype=np.float32)
        g.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        g.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        g.create_dataset(name="source",
                         shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
                         dtype=np.float32)
        g.create_dataset(
            name="kappa",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        g.create_dataset(name="lens_pred",
                         shape=[data_len, phys.pixels, phys.pixels, 1],
                         dtype=np.float32)
        g.create_dataset(name="lens_pred2",
                         shape=[data_len, phys.pixels, phys.pixels, 1],
                         dtype=np.float32)
        g.create_dataset(
            name="source_pred",
            shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        g.create_dataset(name="kappa_pred",
                         shape=[
                             data_len, rim.steps, phys.kappa_pixels,
                             phys.kappa_pixels, 1
                         ],
                         dtype=np.float32)
        g.create_dataset(name="chi_squared",
                         shape=[data_len, rim.steps],
                         dtype=np.float32)
        g.create_dataset(name="lens_coherence_spectrum",
                         shape=[data_len, args.lens_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="source_coherence_spectrum",
                         shape=[data_len, args.source_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="lens_coherence_spectrum2",
                         shape=[data_len, args.lens_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="source_coherence_spectrum2",
                         shape=[data_len, args.source_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="kappa_coherence_spectrum",
                         shape=[data_len, args.kappa_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="lens_frequencies",
                         shape=[args.lens_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="source_frequencies",
                         shape=[args.source_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="kappa_frequencies",
                         shape=[args.kappa_coherence_bins],
                         dtype=np.float32)
        g.create_dataset(name="einstein_radius",
                         shape=[data_len],
                         dtype=np.float32)
        g.create_dataset(name="position",
                         shape=[data_len, 2],
                         dtype=np.float32)
        g.create_dataset(name="orientation",
                         shape=[data_len],
                         dtype=np.float32)
        g.create_dataset(name="ellipticity",
                         shape=[data_len],
                         dtype=np.float32)
        g.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        g.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        g.create_dataset(name="lens_fov", shape=[1], dtype=np.float32)

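        # Lens each test source with a randomly drawn SIE kappa map (random position,
        # ellipticity, orientation and Einstein radius) and record the RIM predictions;
        # no re-optimization is performed for this test.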
        for batch, (_, source, _, noise_rms, psf, fwhm) in enumerate(
                sie_dataset.batch(args.batch_size).prefetch(
                    tf.data.experimental.AUTOTUNE)):
            batch_size = source.shape[0]
            # Create some SIE kappa maps
            _r = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                   minval=0,
                                   maxval=args.max_shift)
            _theta = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                       minval=-np.pi,
                                       maxval=np.pi)
            x0 = _r * tf.math.cos(_theta)
            y0 = _r * tf.math.sin(_theta)
            ellipticity = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                            minval=0.,
                                            maxval=args.max_ellipticity)
            phi = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                    minval=-np.pi,
                                    maxval=np.pi)
            einstein_radius = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                                minval=args.min_theta_e,
                                                maxval=args.max_theta_e)
            kappa = phys_sie.kappa_field(x0=x0,
                                         y0=y0,
                                         e=ellipticity,
                                         phi=phi,
                                         r_ein=einstein_radius)
            lens = phys.noisy_forward(source,
                                      kappa,
                                      noise_rms=noise_rms,
                                      psf=psf)

            # Compute predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(
                lens, noise_rms, psf)
            lens_pred = phys.forward(source_pred[-1], kappa_pred[-1], psf)
            # Compute Power spectrum of converged predictions
            _ps_lens = ps_lens.cross_correlation_coefficient(
                lens[..., 0], lens_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0],
                log_10(kappa_pred[-1])[..., 0])
            _ps_source = ps_source.cross_correlation_coefficient(
                source[..., 0], source_pred[-1][..., 0])

            # save results
            i_begin = batch * args.batch_size
            i_end = i_begin + batch_size
            g["lens"][i_begin:i_end] = lens.numpy().astype(np.float32)
            g["psf"][i_begin:i_end] = psf.numpy().astype(np.float32)
            g["psf_fwhm"][i_begin:i_end] = fwhm.numpy().astype(np.float32)
            g["noise_rms"][i_begin:i_end] = noise_rms.numpy().astype(
                np.float32)
            g["source"][i_begin:i_end] = source.numpy().astype(np.float32)
            g["kappa"][i_begin:i_end] = kappa.numpy().astype(np.float32)
            g["lens_pred"][i_begin:i_end] = lens_pred.numpy().astype(
                np.float32)
            g["source_pred"][i_begin:i_end] = tf.transpose(
                source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            g["kappa_pred"][i_begin:i_end] = tf.transpose(
                kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            g["chi_squared"][i_begin:i_end] = 2 * tf.transpose(
                chi_squared).numpy().astype(np.float32)
            g["lens_coherence_spectrum"][i_begin:i_end] = _ps_lens.numpy(
            ).astype(np.float32)
            g["source_coherence_spectrum"][i_begin:i_end] = _ps_source.numpy(
            ).astype(np.float32)
            g["kappa_coherence_spectrum"][i_begin:i_end] = _ps_kappa.numpy(
            ).astype(np.float32)
            g["einstein_radius"][
                i_begin:i_end] = einstein_radius[:, 0, 0,
                                                 0].numpy().astype(np.float32)
            g["position"][i_begin:i_end] = tf.stack(
                [x0[:, 0, 0, 0], y0[:, 0, 0, 0]],
                axis=1).numpy().astype(np.float32)
            g["ellipticity"][i_begin:i_end] = ellipticity[:, 0, 0,
                                                          0].numpy().astype(
                                                              np.float32)
            g["orientation"][i_begin:i_end] = phi[:, 0, 0,
                                                  0].numpy().astype(np.float32)

            if batch == 0:
                _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels //
                                                                2],
                                    bins=ps_lens.bins)
                f = (f[:-1] + f[1:]) / 2
                g["lens_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.src_pixels)[:phys.src_pixels // 2],
                                    bins=ps_source.bins)
                f = (f[:-1] + f[1:]) / 2
                g["source_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.kappa_pixels)[:phys.kappa_pixels // 2],
                                    bins=ps_kappa.bins)
                f = (f[:-1] + f[1:]) / 2
                g["kappa_frequencies"][:] = f
                g["kappa_fov"][0] = phys.kappa_fov
                g["source_fov"][0] = phys.src_fov
Ejemplo n.º 15
0
def distributed_strategy(args):

    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)
    path = os.getenv('CENSAI_PATH') + "/results/"
    dataset = []
    for file in sorted(glob.glob(path + args.h5_pattern)):
        try:
            dataset.append(h5py.File(file, "r"))
        except OSError:  # skip files that cannot be opened (e.g. corrupted or incomplete)
            continue
    B = dataset[0]["source"].shape[0]
    data_len = len(dataset) * B // N_WORKERS

    ps_observation = PowerSpectrum(bins=args.observation_coherence_bins,
                                   pixels=128)
    ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=128)
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=128)

    phys = PhysicalModel(
        pixels=128,
        kappa_pixels=128,
        src_pixels=128,
        image_fov=7.69,
        kappa_fov=7.69,
        src_fov=3.,
        method="fft",
    )

    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(
        checkpoint_manager.latest_checkpoint).expect_partial()
    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim = RIM(phys, unet, **rim_params)

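    # Load the kappa and source VAEs; they are used below to build the EWC
    # regularizer that anchors the unet weights during re-optimization.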
    kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.kappa_vae)
    with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.source_vae)
    with open(os.path.join(svae_path, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1)
    checkpoint_manager2.checkpoint.restore(
        checkpoint_manager2.latest_checkpoint).expect_partial()
    wk = lambda k: tf.sqrt(k) / tf.reduce_sum(
        tf.sqrt(k), axis=(1, 2, 3), keepdims=True)
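    # wk(k): sqrt(kappa) weight map normalized to sum to one per image, used to
    # weight the kappa error toward high-convergence regions.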

    # Freeze L5
    # encoding layers
    # rim.unet.layers[0].trainable = False # L1
    # rim.unet.layers[1].trainable = False
    # rim.unet.layers[2].trainable = False
    # rim.unet.layers[3].trainable = False
    # rim.unet.layers[4].trainable = False # L5
    # GRU
    # rim.unet.layers[5].trainable = False
    # rim.unet.layers[6].trainable = False
    # rim.unet.layers[7].trainable = False
    # rim.unet.layers[8].trainable = False
    # rim.unet.layers[9].trainable = False
    # rim.unet.layers[15].trainable = False  # bottleneck GRU
    # output layer
    # rim.unet.layers[-2].trainable = False
    # input layer
    # rim.unet.layers[-1].trainable = False
    # decoding layers
    # rim.unet.layers[10].trainable = False # L5
    # rim.unet.layers[11].trainable = False
    # rim.unet.layers[12].trainable = False
    # rim.unet.layers[13].trainable = False
    # rim.unet.layers[14].trainable = False # L1

    with h5py.File(
            os.path.join(
                os.getenv("CENSAI_PATH"), "results",
                args.experiment_name + "_" + args.model + "_" + args.dataset +
                f"_{THIS_WORKER:03d}.h5"), 'w') as hf:
        hf.create_dataset(name="observation",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf",
                          shape=[data_len, 20, 20, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        hf.create_dataset(
            name="source",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="kappa",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="observation_pred",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="observation_pred_reoptimized",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(
            name="source_pred",
            shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="source_pred_reoptimized",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="kappa_pred",
                          shape=[
                              data_len, rim.steps, phys.kappa_pixels,
                              phys.kappa_pixels, 1
                          ],
                          dtype=np.float32)
        hf.create_dataset(
            name="kappa_pred_reoptimized",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="chi_squared",
                          shape=[data_len, rim.steps],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum2",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum_reoptimized",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum2",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum_reoptimized",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum_reoptimized",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_frequencies",
                          shape=[args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_frequencies",
                          shape=[args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_frequencies",
                          shape=[args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32)
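        # Each worker handles its own contiguous range of examples; b indexes the
        # result file and k the example within that file.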
        for batch, j in enumerate(
                range((THIS_WORKER - 1) * data_len, THIS_WORKER * data_len)):
            b = j // B
            k = j % B
            observation = dataset[b]["observation"][k][None, ...]
            source = dataset[b]["source"][k][None, ...]
            kappa = dataset[b]["kappa"][k][None, ...]
            noise_rms = np.array([dataset[b]["noise_rms"][k]])
            psf = dataset[b]["psf"][k][None, ...]
            fwhm = dataset[b]["psf_fwhm"][k]

            checkpoint_manager.checkpoint.restore(
                checkpoint_manager.latest_checkpoint).expect_partial(
                )  # reset model weights
            # Compute predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(
                observation, noise_rms, psf)
            observation_pred = phys.forward(source_pred[-1], kappa_pred[-1],
                                            psf)
            # reset the seed for reproducible sampling in the VAE for EWC
            tf.random.set_seed(args.seed)
            np.random.seed(args.seed)
            # Initialize regularization term
            ewc = EWC(observation=observation,
                      noise_rms=noise_rms,
                      psf=psf,
                      phys=phys,
                      rim=rim,
                      source_vae=source_vae,
                      kappa_vae=kappa_vae,
                      n_samples=args.sample_size,
                      sigma_source=args.source_vae_ball_size,
                      sigma_kappa=args.kappa_vae_ball_size)
            # Re-optimize weights of the model
            STEPS = args.re_optimize_steps
            learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=args.learning_rate,
                decay_rate=args.decay_rate,
                decay_steps=args.decay_steps,
                staircase=args.staircase)
            optim = tf.keras.optimizers.SGD(
                learning_rate=learning_rate_schedule)

            chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            source_mse = tf.TensorArray(DTYPE, size=STEPS)
            kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
            best = chi_squared[-1, 0]
            source_best = source_pred[-1]
            kappa_best = kappa_pred[-1]
            source_mse_best = tf.reduce_mean(
                (source_best - rim.source_inverse_link(source))**2)
            kappa_mse_best = tf.reduce_sum(
                wk(kappa) * (kappa_best - rim.kappa_inverse_link(kappa))**2)

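            # Re-optimization loop: take gradient steps on the RIM weights for this
            # single example and keep the best iterate according to the chi squared.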
            for current_step in tqdm(range(STEPS)):
                with tf.GradientTape() as tape:
                    tape.watch(unet.trainable_variables)
                    s, k, chi_sq = rim.call(observation,
                                            noise_rms,
                                            psf,
                                            outer_tape=tape)
                    cost = tf.reduce_mean(chi_sq)  # mean over time steps
                    cost += tf.reduce_sum(rim.unet.losses)  # L2 regularisation
                    cost += args.lam_ewc * ewc.penalty(
                        rim)  # Elastic Weights Consolidation

                log_likelihood = chi_sq[-1]
                chi_squared_series = chi_squared_series.write(
                    index=current_step, value=log_likelihood)
                source_o = s[-1]
                kappa_o = k[-1]
                source_mse = source_mse.write(
                    index=current_step,
                    value=tf.reduce_mean(
                        (source_o - rim.source_inverse_link(source))**2))
                kappa_mse = kappa_mse.write(
                    index=current_step,
                    value=tf.reduce_sum(
                        wk(kappa) *
                        (kappa_o - rim.kappa_inverse_link(kappa))**2))
                if 2 * chi_sq[-1, 0] < 1.0 and args.early_stopping:
                    source_best = rim.source_link(source_o)
                    kappa_best = rim.kappa_link(kappa_o)
                    best = chi_sq[-1, 0]
                    source_mse_best = tf.reduce_mean(
                        (source_o - rim.source_inverse_link(source))**2)
                    kappa_mse_best = tf.reduce_sum(
                        wk(kappa) *
                        (kappa_o - rim.kappa_inverse_link(kappa))**2)
                    break
                if chi_sq[-1, 0] < best:
                    source_best = rim.source_link(source_o)
                    kappa_best = rim.kappa_link(kappa_o)
                    best = chi_sq[-1, 0]
                    source_mse_best = tf.reduce_mean(
                        (source_o - rim.source_inverse_link(source))**2)
                    kappa_mse_best = tf.reduce_sum(
                        wk(kappa) *
                        (kappa_o - rim.kappa_inverse_link(kappa))**2)

                grads = tape.gradient(cost, unet.trainable_variables)
                optim.apply_gradients(zip(grads, unet.trainable_variables))

            source_o = source_best
            kappa_o = kappa_best
            y_pred = phys.forward(source_o, kappa_o, psf)
            chi_sq_series = tf.transpose(chi_squared_series.stack(),
                                         perm=[1, 0])
            source_mse = source_mse.stack()[None, ...]
            kappa_mse = kappa_mse.stack()[None, ...]

            # Compute Power spectrum of converged predictions
            _ps_observation = ps_observation.cross_correlation_coefficient(
                observation[..., 0], observation_pred[..., 0])
            _ps_observation2 = ps_observation.cross_correlation_coefficient(
                observation[..., 0], y_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0],
                log_10(kappa_pred[-1])[..., 0])
            _ps_kappa2 = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
            _ps_source = ps_source.cross_correlation_coefficient(
                source[..., 0], source_pred[-1][..., 0])
            _ps_source2 = ps_source.cross_correlation_coefficient(
                source[..., 0], source_o[..., 0])

            # save results
            hf["observation"][batch] = observation.astype(np.float32)
            hf["psf"][batch] = psf.astype(np.float32)
            hf["psf_fwhm"][batch] = fwhm
            hf["noise_rms"][batch] = noise_rms.astype(np.float32)
            hf["source"][batch] = source.astype(np.float32)
            hf["kappa"][batch] = kappa.astype(np.float32)
            hf["observation_pred"][batch] = observation_pred.numpy().astype(
                np.float32)
            hf["observation_pred_reoptimized"][batch] = y_pred.numpy().astype(
                np.float32)
            hf["source_pred"][batch] = tf.transpose(
                source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["source_pred_reoptimized"][batch] = source_o.numpy().astype(
                np.float32)
            hf["kappa_pred"][batch] = tf.transpose(
                kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype(
                np.float32)
            hf["chi_squared"][batch] = 2 * tf.transpose(
                chi_squared).numpy().astype(np.float32)
            hf["chi_squared_reoptimized"][batch] = 2 * best.numpy().astype(
                np.float32)
            hf["chi_squared_reoptimized_series"][
                batch] = 2 * chi_sq_series.numpy().astype(np.float32)
            hf["source_optim_mse"][batch] = source_mse_best.numpy().astype(
                np.float32)
            hf["source_optim_mse_series"][batch] = source_mse.numpy().astype(
                np.float32)
            hf["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype(
                np.float32)
            hf["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype(
                np.float32)
            hf["observation_coherence_spectrum"][batch] = _ps_observation
            hf["observation_coherence_spectrum_reoptimized"][
                batch] = _ps_observation2
            hf["source_coherence_spectrum"][batch] = _ps_source
            hf["source_coherence_spectrum_reoptimized"][batch] = _ps_source2
            hf["kappa_coherence_spectrum"][batch] = _ps_kappa
            hf["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2

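            # Frequency bin centers and fields of view are identical for every example,
            # so they are stored only once.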
            if batch == 0:
                _, f = np.histogram(
                    np.fft.fftfreq(phys.pixels)[:phys.pixels // 2],
                    bins=ps_observation.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["observation_frequencies"][:] = f
                _, f = np.histogram(
                    np.fft.fftfreq(phys.src_pixels)[:phys.src_pixels // 2],
                    bins=ps_source.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["source_frequencies"][:] = f
                _, f = np.histogram(
                    np.fft.fftfreq(phys.kappa_pixels)[:phys.kappa_pixels // 2],
                    bins=ps_kappa.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["kappa_frequencies"][:] = f
                hf["kappa_fov"][0] = phys.kappa_fov
                hf["source_fov"][0] = phys.src_fov
Ejemplo n.º 16
0
def test_noisy_forward_conv2():
    phys = PhysicalModel(pixels=64, src_pixels=64)
    source = tf.random.normal([2, 64, 64, 1])
    kappa = tf.math.exp(tf.random.uniform([2, 64, 64, 1]))
    noise_rms = 0.1
    phys.noisy_forward(source, kappa, noise_rms)
Ejemplo n.º 17
0
def main(args):
    if args.seed is not None:
        tf.random.set_seed(args.seed)
        np.random.seed(args.seed)
    if args.json_override is not None:
        if isinstance(args.json_override, list):
            files = args.json_override
        else:
            files = [
                args.json_override,
            ]
        for file in files:
            with open(file, "r") as f:
                json_override = json.load(f)
            args_dict = vars(args)
            args_dict.update(json_override)
    if args.v2:  # overwrite decoding procedure with version 2
        from censai.data.alpha_tng_v2 import decode_train, decode_physical_info
    else:
        from censai.data.alpha_tng import decode_train, decode_physical_info
    # ========= Dataset=================================================================================================
    files = []
    for dataset in args.datasets:
        files.extend(glob.glob(os.path.join(dataset, "*.tfrecords")))
    np.random.shuffle(files)
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type),
                               block_length=args.block_length,
                               num_parallel_calls=tf.data.AUTOTUNE)
    # extract physical info from first example
    for image_fov, kappa_fov in dataset.map(decode_physical_info):
        break
    dataset = dataset.map(decode_train)
    if args.cache_file is not None:
        dataset = dataset.cache(args.cache_file)
    train_dataset = dataset.take(math.floor(args.train_split * args.total_items))\
        .shuffle(buffer_size=args.buffer_size).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    val_dataset = dataset.skip(math.floor(args.train_split * args.total_items))\
        .take(math.ceil((1 - args.train_split) * args.total_items)).batch(args.batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    train_dataset = STRATEGY.experimental_distribute_dataset(train_dataset)
    val_dataset = STRATEGY.experimental_distribute_dataset(val_dataset)

    # ========== Model and Physical Model =============================================================================
    # set up a physical model to compare lenses produced by the ray tracer's predicted deflection angles with those computed directly from the true kappa maps.

    phys = PhysicalModel(pixels=args.pixels,
                         image_fov=image_fov,
                         kappa_fov=kappa_fov,
                         src_fov=args.source_fov,
                         psf_sigma=args.psf_sigma)
    with STRATEGY.scope():  # Replicate ops across gpus
        ray_tracer = RayTracer(
            pixels=args.pixels,
            filter_scaling=args.filter_scaling,
            layers=args.layers,
            block_conv_layers=args.block_conv_layers,
            kernel_size=args.kernel_size,
            filters=args.filters,
            strides=args.strides,
            bottleneck_filters=args.bottleneck_filters,
            resampling_kernel_size=args.resampling_kernel_size,
            upsampling_interpolation=args.upsampling_interpolation,
            kernel_regularizer_amp=args.kernel_regularizer_amp,
            activation=args.activation,
            initializer=args.initializer,
            kappalog=args.kappalog,
            normalize=args.normalize,
        )

        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            args.initial_learning_rate,
            decay_steps=args.decay_steps,
            decay_rate=args.decay_rate,
            staircase=True)
        optim = tf.keras.optimizers.deserialize({
            "class_name": args.optimizer,
            'config': {
                "learning_rate": lr_schedule
            }
        })

    # ==== Take care of where to write logs and stuff =================================================================
    if args.model_id.lower() != "none":
        logname = args.model_id
    elif args.logname is not None:
        logname = args.logname
    else:
        logname = args.logname_prefixe + "_" + datetime.now().strftime(
            "%y-%m-%d_%H-%M-%S")
    # setup tensorboard writer (nullwriter in case we do not want to sync)
    if args.logdir.lower() != "none":
        logdir = os.path.join(args.logdir, logname)
        if not os.path.isdir(logdir):
            os.mkdir(logdir)
        writer = tf.summary.create_file_writer(logdir)
    else:
        writer = nullwriter()
    # ===== Make sure directory and checkpoint manager are created to save model ===================================
    if args.model_dir.lower() != "none":
        checkpoints_dir = os.path.join(args.model_dir, logname)
        if not os.path.isdir(checkpoints_dir):
            os.mkdir(checkpoints_dir)
            # save script parameter for future reference
            with open(os.path.join(checkpoints_dir, "script_params.json"),
                      "w") as f:
                json.dump(vars(args), f, indent=4)
            with open(os.path.join(checkpoints_dir, "ray_tracer_hparams.json"),
                      "w") as f:
                hparams_dict = {
                    key: vars(args)[key]
                    for key in RAYTRACER_HPARAMS
                }
                json.dump(hparams_dict, f, indent=4)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                   optimizer=optim,
                                   net=ray_tracer)
        checkpoint_manager = tf.train.CheckpointManager(
            ckpt, checkpoints_dir, max_to_keep=args.max_to_keep)
        save_checkpoint = True
        # ======= Load model if model_id is provided ===============================================================
        if args.model_id.lower() != "none":
            checkpoint_manager.checkpoint.restore(
                checkpoint_manager.latest_checkpoint)
    else:
        save_checkpoint = False
    # =================================================================================================================

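    # Per-replica training step: per-example mean squared error on the deflection
    # angles, summed over the replica batch and normalized by the global batch size.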
    def train_step(inputs):
        kappa, alpha = inputs
        with tf.GradientTape(watch_accessed_variables=True) as tape:
            tape.watch(ray_tracer.trainable_weights)
            cost = tf.reduce_mean(tf.square(ray_tracer(kappa) - alpha),
                                  axis=(1, 2, 3))
            cost = tf.reduce_sum(
                cost) / args.batch_size  # normalize by global batch size
        gradient = tape.gradient(cost, ray_tracer.trainable_weights)
        if args.clipping:
            clipped_gradient = [
                tf.clip_by_value(grad, -10, 10) for grad in gradient
            ]
        else:
            clipped_gradient = gradient
        optim.apply_gradients(
            zip(clipped_gradient, ray_tracer.trainable_variables))
        return cost

    @tf.function
    def distributed_train_step(dist_inputs):
        per_replica_losses = STRATEGY.run(train_step, args=(dist_inputs, ))
        cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_losses,
                               axis=None)
        cost += tf.reduce_sum(ray_tracer.losses)
        return cost

    def test_step(inputs):
        kappa, alpha = inputs
        cost = tf.reduce_mean(tf.square(ray_tracer(kappa) - alpha),
                              axis=(1, 2, 3))
        cost = tf.reduce_sum(
            cost) / args.batch_size  # normalize by global batch size
        return cost

    @tf.function
    def distributed_test_step(dist_inputs):
        per_replica_losses = STRATEGY.run(test_step, args=(dist_inputs, ))
        # Replica losses are aggregated by summing them
        cost = STRATEGY.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_losses,
                               axis=None)
        cost += tf.reduce_sum(ray_tracer.losses)
        return cost

    # =================================================================================================================
    epoch_loss = tf.metrics.Mean()
    val_loss = tf.metrics.Mean()
    time_per_step = tf.metrics.Mean()
    history = {  # recorded at the end of an epoch only
        "train_cost": [],
        "train_lens_residuals": [],
        "val_cost": [],
        "val_lens_residuals": [],
        "learning_rate": [],
        "step": [],
        "wall_time": []
    }
    best_loss = np.inf
    global_start = time.time()
    estimated_time_for_epoch = 0
    out_of_time = False
    patience = args.patience
    step = 0
    latest_checkpoint = 1
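    # Training loop: stops early on patience, wall-time limit, NaN cost,
    # or when the learning rate falls below 1e-9.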
    for epoch in range(1, args.epochs + 1):
        if (time.time() - global_start
            ) > args.max_time * 3600 - estimated_time_for_epoch:
            break
        epoch_start = time.time()
        epoch_loss.reset_states()
        time_per_step.reset_states()
        with writer.as_default():
            for batch, distributed_inputs in enumerate(train_dataset):
                start = time.time()
                cost = distributed_train_step(distributed_inputs)
                # ========== Summary and logs ==================================================================================
                _time = time.time() - start
                time_per_step.update_state([_time])
                epoch_loss.update_state([cost])
                step += 1

            # Deflection angle residual
            kappa_true, alpha_true = distributed_inputs
            alpha_pred = ray_tracer.call(kappa_true)
            # Lens residual
            lens_true = phys.lens_source_func(kappa_true, w=args.source_w)
            lens_pred = phys.lens_source_func_given_alpha(alpha_pred,
                                                          w=args.source_w)
            train_chi_squared = tf.reduce_mean(tf.square(lens_true -
                                                         lens_pred))
            for res_idx in range(
                    min(args.n_residuals,
                        args.batch_size)):  # Residuals in train set
                tf.summary.image(f"Residuals {res_idx}",
                                 plot_to_image(
                                     residual_plot(alpha_true[res_idx],
                                                   alpha_pred[res_idx],
                                                   lens_true[res_idx],
                                                   lens_pred[res_idx])),
                                 step=step)

        # ========== Validation set ===================
            val_loss.reset_states()
            for distributed_inputs in val_dataset:  # Cost of val set
                test_cost = distributed_test_step(distributed_inputs)
                val_loss.update_state([test_cost])

            kappa_true, alpha_true = distributed_inputs
            alpha_pred = ray_tracer.call(kappa_true)
            lens_true = phys.lens_source_func(kappa_true, w=args.source_w)
            lens_pred = phys.lens_source_func_given_alpha(alpha_pred,
                                                          w=args.source_w)
            val_chi_squared = tf.reduce_mean(tf.square(lens_true - lens_pred))
            for res_idx in range(min(args.n_residuals,
                                     args.batch_size)):  # Residuals in val set
                tf.summary.image(f"Val Residuals {res_idx}",
                                 plot_to_image(
                                     residual_plot(alpha_true[res_idx],
                                                   alpha_pred[res_idx],
                                                   lens_true[res_idx],
                                                   lens_pred[res_idx])),
                                 step=step)

            train_cost = epoch_loss.result().numpy()
            val_cost = val_loss.result().numpy()
            tf.summary.scalar("Time per step",
                              time_per_step.result().numpy(),
                              step=step)
            tf.summary.scalar("Train lens residual",
                              train_chi_squared,
                              step=step)
            tf.summary.scalar("Val lens residual", val_chi_squared, step=step)
            tf.summary.scalar("MSE", train_cost, step=step)
            tf.summary.scalar("Val MSE", val_cost, step=step)
            tf.summary.scalar("Learning rate", optim.lr(step), step=step)
        print(
            f"epoch {epoch} | train loss {train_cost:.3e} | val loss {val_cost:.3e} | learning rate {optim.lr(step).numpy():.2e} | "
            f"time per step {time_per_step.result():.2e} s")
        history["train_cost"].append(train_cost)
        history["train_lens_residuals"].append(train_chi_squared.numpy())
        history["val_cost"].append(val_cost)
        history["val_lens_residuals"].append(val_chi_squared.numpy())
        history["learning_rate"].append(optim.lr(step).numpy())
        history["step"].append(step)
        history["wall_time"].append(time.time() - global_start)

        cost = train_cost if args.track_train else val_cost
        if np.isnan(cost):
            print("Training broke the Universe")
            break
        if cost < best_loss * (1 - args.tolerance):
            best_loss = cost
            patience = args.patience
        else:
            patience -= 1
        if (time.time() - global_start) > args.max_time * 3600:
            out_of_time = True
        if save_checkpoint:
            checkpoint_manager.checkpoint.step.assign_add(1)  # a bit of a hack
            if epoch % args.checkpoints == 0 or patience == 0 or epoch == args.epochs or out_of_time:
                with open(os.path.join(checkpoints_dir, "score_sheet.txt"),
                          mode="a") as f:
                    np.savetxt(f, np.array([[latest_checkpoint, cost]]))
                latest_checkpoint += 1
                checkpoint_manager.save()
                print("Saved checkpoint for step {}: {}".format(
                    int(checkpoint_manager.checkpoint.step),
                    checkpoint_manager.latest_checkpoint))
        if patience == 0:
            print("Reached patience")
            break
        if out_of_time:
            break
        if epoch > 1:  # First epoch is always very slow and not a good estimate of an epoch time.
            estimated_time_for_epoch = time.time() - epoch_start
        if optim.lr(step).numpy() < 1e-9:
            print("Reached learning rate limit")
            break
    return history, best_loss
Ejemplo n.º 18
0
 def __init__(self,
              observation,
              noise_rms,
              psf,
              phys: PhysicalModel,
              rim: RIM,
              source_vae: VAE,
              kappa_vae: VAE,
              n_samples=100,
              sigma_source=0.5,
              sigma_kappa=0.5):
     """
     Make a copy of initial parameters \varphi^{(0)} and compute the Fisher diagonal F_{ii}
     """
     wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(
         tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
     # Baseline prediction from observation
     source_pred, kappa_pred, chi_squared = rim.predict(
         observation, noise_rms, psf)
     # Latent code of model predictions
     z_source, _ = source_vae.encoder(source_pred[-1])
     z_kappa, _ = kappa_vae.encoder(log_10(kappa_pred[-1]))
     # Deepcopy of the initial parameters
     self.initial_params = [
         deepcopy(w) for w in rim.unet.trainable_variables
     ]
     self.fisher_diagonal = [tf.zeros_like(w) for w in self.initial_params]
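     # Monte Carlo estimate of the Fisher diagonal: sample sources and kappa maps
     # around the RIM predictions in the VAE latent spaces, simulate observations,
     # and accumulate the squared gradients of the reconstruction cost.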
     for n in range(n_samples):
         # Sample latent code around the prediction mean
         z_s = tf.random.normal(shape=[1, source_vae.latent_size],
                                mean=z_source,
                                stddev=sigma_source)
         z_k = tf.random.normal(shape=[1, kappa_vae.latent_size],
                                mean=z_kappa,
                                stddev=sigma_kappa)
         # Decode
         sampled_source = tf.nn.relu(source_vae.decode(z_s))
         sampled_source /= tf.reduce_max(sampled_source,
                                         axis=(1, 2, 3),
                                         keepdims=True)
         sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
         # Simulate observation
         sampled_observation = phys.noisy_forward(sampled_source,
                                                  10**sampled_kappa,
                                                  noise_rms, psf)
         # Compute the gradient of the MSE
         with tf.GradientTape() as tape:
             tape.watch(rim.unet.trainable_variables)
             s, k, chi_squared = rim.call(sampled_observation, noise_rms,
                                          psf)
             # Remove the temperature from the loss when computing the Fisher: use sums instead of means, and rescale the weighted kappa sum by the number of pixels.
             _kappa_mse = phys.kappa_pixels**2 * tf.reduce_sum(
                 wk(10**sampled_kappa) * (k - sampled_kappa)**2,
                 axis=(2, 3, 4))
             cost = tf.reduce_sum(_kappa_mse)
             cost += tf.reduce_sum((s - sampled_source)**2)
         grad = tape.gradient(cost, rim.unet.trainable_variables)
         # Square the derivative relative to initial parameters and add to total
         self.fisher_diagonal = [
             F + g**2 / n_samples
             for F, g in zip(self.fisher_diagonal, grad)
         ]
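
The penalty method called during re-optimization (cost += args.lam_ewc * ewc.penalty(rim)) is not included in this excerpt. A minimal sketch, assuming the standard quadratic EWC penalty built from the initial_params and fisher_diagonal stored above, could be:

 def penalty(self, rim):
     # Hypothetical sketch: quadratic EWC term sum_i F_ii * (phi_i - phi_i^(0))^2,
     # pulling the re-optimized weights back toward their initial values.
     return tf.add_n([
         tf.reduce_sum(F * (w - w0)**2)
         for F, w, w0 in zip(self.fisher_diagonal,
                             rim.unet.trainable_variables,
                             self.initial_params)
     ])

Each term is a scalar, so the regularization enters the re-optimization cost as a single value weighted by args.lam_ewc.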