Example No. 1
def main(args):
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with open(os.path.join(args.first_stage_model_id, "model_hparams.json"),
              "r") as f:
        vae_hparams = json.load(f)
    # load weights
    vae = VAE(**vae_hparams)
    ckpt1 = tf.train.Checkpoint(net=vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1,
                                                     args.first_stage_model_id,
                                                     1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    n_batch = args.total_items // args.batch_size
    batch_per_record = n_batch // args.n_records
    last_record_n_batch = batch_per_record + n_batch % args.n_records

    for record in range(args.n_records - 1):
        with tf.io.TFRecordWriter(
                os.path.join(args.output_dir, f"data_{record:02d}.tfrecords"),
                options) as writer:
            for batch in range(batch_per_record):
                z = tf.random.normal(shape=[args.batch_size, vae.latent_size])
                kappa_batch = vae.decode(z)
                for kappa in kappa_batch:
                    features = {
                        "kappa": _bytes_feature(kappa.numpy().tobytes()),
                        "kappa pixels": _int64_feature(kappa.shape[0]),
                    }

                    example = tf.train.Example(features=tf.train.Features(
                        feature=features)).SerializeToString()
                    writer.write(example)

    with tf.io.TFRecordWriter(
            os.path.join(args.output_dir,
                         f"data_{args.n_record-1:02d}.tfrecords"),
            options) as writer:
        for batch in range(last_record_n_batch):
            z = tf.random.normal(shape=[args.batch_size, vae.latent_size])
            kappa_batch = vae.decode(z)
            for kappa in kappa_batch:
                features = {
                    "kappa": _bytes_feature(kappa.numpy().tobytes()),
                    "kappa pixels": _int64_feature(kappa.shape[0]),
                }

                example = tf.train.Example(features=tf.train.Features(
                    feature=features)).SerializeToString()
                writer.write(example)
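
Example No. 1 relies on _bytes_feature and _int64_feature helpers that are not shown. A minimal sketch of what they presumably are, namely the standard tf.train.Feature wrappers:

import tensorflow as tf

def _bytes_feature(value):
    # Wrap raw bytes in a tf.train.Feature (standard TFRecord helper).
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    # Wrap a Python int in a tf.train.Feature (standard TFRecord helper).
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
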
Example No. 2
def main(args):
    if THIS_WORKER > 1:
        time.sleep(5)
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
    if args.seed is not None:
        tf.random.set_seed(args.seed)

    # Load first stage
    with open(os.path.join(args.kappa_first_stage_vae, "model_hparams.json"),
              "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(
        ckpt1, args.kappa_first_stage_vae, 1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()
    kappa_vae.trainable = False

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with tf.io.TFRecordWriter(
            os.path.join(args.output_dir, f"data_{THIS_WORKER}.tfrecords"),
            options) as writer:
        print(
            f"Started worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}"
        )
        for _ in range((THIS_WORKER - 1) * args.batch_size, args.len_dataset,
                       N_WORKERS * args.batch_size):
            kappa = 10**kappa_vae.sample(args.batch_size)
            # The most important fields in these records are kappa and kappa_fov;
            # the rest are placeholders kept to match the schema of the earlier TNG tfrecords.
            records = encode_examples(kappa=kappa,
                                      einstein_radius_init=[0.] *
                                      args.batch_size,
                                      einstein_radius=[0.] * args.batch_size,
                                      rescalings=[0.] * args.batch_size,
                                      z_source=0.,
                                      z_lens=0.,
                                      kappa_fov=args.kappa_fov,
                                      sigma_crit=0.,
                                      kappa_ids=[0] * args.batch_size)
            for record in records:
                writer.write(record)
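
The records written by the two scripts above can be read back with tf.data. A sketch of a parser for the Example No. 1 schema, assuming each kappa map was serialized as float32 with a single channel:

import tensorflow as tf

def decode_kappa_example(serialized):
    # Hypothetical parser for records with the "kappa" / "kappa pixels" schema.
    features = tf.io.parse_single_example(
        serialized, {
            "kappa": tf.io.FixedLenFeature([], tf.string),
            "kappa pixels": tf.io.FixedLenFeature([], tf.int64),
        })
    pixels = tf.cast(features["kappa pixels"], tf.int32)
    kappa = tf.io.decode_raw(features["kappa"], tf.float32)
    return tf.reshape(kappa, [pixels, pixels, 1])

# Usage: tf.data.TFRecordDataset(files, compression_type=compression).map(decode_kappa_example)
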
Example No. 3
def distributed_strategy(args):

    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)
    path = os.getenv('CENSAI_PATH') + "/results/"
    dataset = []
    for file in sorted(glob.glob(path + args.h5_pattern)):
        try:
            dataset.append(h5py.File(file, "r"))
        except OSError:  # skip files that cannot be opened
            continue
    B = dataset[0]["source"].shape[0]
    data_len = len(dataset) * B // N_WORKERS

    ps_observation = PowerSpectrum(bins=args.observation_coherence_bins,
                                   pixels=128)
    ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=128)
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=128)

    phys = PhysicalModel(
        pixels=128,
        kappa_pixels=128,
        src_pixels=128,
        image_fov=7.69,
        kappa_fov=7.69,
        src_fov=3.,
        method="fft",
    )

    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(
        checkpoint_manager.latest_checkpoint).expect_partial()
    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim = RIM(phys, unet, **rim_params)

    kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.kappa_vae)
    with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.source_vae)
    with open(os.path.join(svae_path, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1)
    checkpoint_manager2.checkpoint.restore(
        checkpoint_manager2.latest_checkpoint).expect_partial()
    wk = lambda k: tf.sqrt(k) / tf.reduce_sum(
        tf.sqrt(k), axis=(1, 2, 3), keepdims=True)

    # Freeze L5
    # encoding layers
    # rim.unet.layers[0].trainable = False # L1
    # rim.unet.layers[1].trainable = False
    # rim.unet.layers[2].trainable = False
    # rim.unet.layers[3].trainable = False
    # rim.unet.layers[4].trainable = False # L5
    # GRU
    # rim.unet.layers[5].trainable = False
    # rim.unet.layers[6].trainable = False
    # rim.unet.layers[7].trainable = False
    # rim.unet.layers[8].trainable = False
    # rim.unet.layers[9].trainable = False
    # rim.unet.layers[15].trainable = False  # bottleneck GRU
    # output layer
    # rim.unet.layers[-2].trainable = False
    # input layer
    # rim.unet.layers[-1].trainable = False
    # decoding layers
    # rim.unet.layers[10].trainable = False # L5
    # rim.unet.layers[11].trainable = False
    # rim.unet.layers[12].trainable = False
    # rim.unet.layers[13].trainable = False
    # rim.unet.layers[14].trainable = False # L1

    with h5py.File(
            os.path.join(
                os.getenv("CENSAI_PATH"), "results",
                args.experiment_name + "_" + args.model + "_" + args.dataset +
                f"_{THIS_WORKER:03d}.h5"), 'w') as hf:
        hf.create_dataset(name="observation",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf",
                          shape=[data_len, 20, 20, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        hf.create_dataset(
            name="source",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="kappa",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="observation_pred",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="observation_pred_reoptimized",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(
            name="source_pred",
            shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="source_pred_reoptimized",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="kappa_pred",
                          shape=[
                              data_len, rim.steps, phys.kappa_pixels,
                              phys.kappa_pixels, 1
                          ],
                          dtype=np.float32)
        hf.create_dataset(
            name="kappa_pred_reoptimized",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="chi_squared",
                          shape=[data_len, rim.steps],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum2",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum_reoptimized",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum2",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum_reoptimized",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum_reoptimized",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_frequencies",
                          shape=[args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_frequencies",
                          shape=[args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_frequencies",
                          shape=[args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32)
        for batch, j in enumerate(
                range((THIS_WORKER - 1) * data_len, THIS_WORKER * data_len)):
            b = j // B
            k = j % B
            observation = dataset[b]["observation"][k][None, ...]
            source = dataset[b]["source"][k][None, ...]
            kappa = dataset[b]["kappa"][k][None, ...]
            noise_rms = np.array([dataset[b]["noise_rms"][k]])
            psf = dataset[b]["psf"][k][None, ...]
            fwhm = dataset[b]["psf_fwhm"][k]

            checkpoint_manager.checkpoint.restore(
                checkpoint_manager.latest_checkpoint).expect_partial(
                )  # reset model weights
            # Compute predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(
                observation, noise_rms, psf)
            observation_pred = phys.forward(source_pred[-1], kappa_pred[-1],
                                            psf)
            # reset the seed for reproducible sampling in the VAE for EWC
            tf.random.set_seed(args.seed)
            np.random.seed(args.seed)
            # Initialize regularization term
            ewc = EWC(observation=observation,
                      noise_rms=noise_rms,
                      psf=psf,
                      phys=phys,
                      rim=rim,
                      source_vae=source_vae,
                      kappa_vae=kappa_vae,
                      n_samples=args.sample_size,
                      sigma_source=args.source_vae_ball_size,
                      sigma_kappa=args.kappa_vae_ball_size)
            # Re-optimize weights of the model
            STEPS = args.re_optimize_steps
            learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=args.learning_rate,
                decay_rate=args.decay_rate,
                decay_steps=args.decay_steps,
                staircase=args.staircase)
            optim = tf.keras.optimizers.SGD(
                learning_rate=learning_rate_schedule)

            chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            source_mse = tf.TensorArray(DTYPE, size=STEPS)
            kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
            best = chi_squared[-1, 0]
            source_best = source_pred[-1]
            kappa_best = kappa_pred[-1]
            source_mse_best = tf.reduce_mean(
                (source_best - rim.source_inverse_link(source))**2)
            kappa_mse_best = tf.reduce_sum(
                wk(kappa) * (kappa_best - rim.kappa_inverse_link(kappa))**2)

            for current_step in tqdm(range(STEPS)):
                with tf.GradientTape() as tape:
                    tape.watch(unet.trainable_variables)
                    s, k, chi_sq = rim.call(observation,
                                            noise_rms,
                                            psf,
                                            outer_tape=tape)
                    cost = tf.reduce_mean(chi_sq)  # mean over time steps
                    cost += tf.reduce_sum(rim.unet.losses)  # L2 regularisation
                    cost += args.lam_ewc * ewc.penalty(
                        rim)  # Elastic Weights Consolidation

                log_likelihood = chi_sq[-1]
                chi_squared_series = chi_squared_series.write(
                    index=current_step, value=log_likelihood)
                source_o = s[-1]
                kappa_o = k[-1]
                source_mse = source_mse.write(
                    index=current_step,
                    value=tf.reduce_mean(
                        (source_o - rim.source_inverse_link(source))**2))
                kappa_mse = kappa_mse.write(
                    index=current_step,
                    value=tf.reduce_sum(
                        wk(kappa) *
                        (kappa_o - rim.kappa_inverse_link(kappa))**2))
                if 2 * chi_sq[-1, 0] < 1.0 and args.early_stopping:
                    source_best = rim.source_link(source_o)
                    kappa_best = rim.kappa_link(kappa_o)
                    best = chi_sq[-1, 0]
                    source_mse_best = tf.reduce_mean(
                        (source_o - rim.source_inverse_link(source))**2)
                    kappa_mse_best = tf.reduce_sum(
                        wk(kappa) *
                        (kappa_o - rim.kappa_inverse_link(kappa))**2)
                    break
                if chi_sq[-1, 0] < best:
                    source_best = rim.source_link(source_o)
                    kappa_best = rim.kappa_link(kappa_o)
                    best = chi_sq[-1, 0]
                    source_mse_best = tf.reduce_mean(
                        (source_o - rim.source_inverse_link(source))**2)
                    kappa_mse_best = tf.reduce_sum(
                        wk(kappa) *
                        (kappa_o - rim.kappa_inverse_link(kappa))**2)

                grads = tape.gradient(cost, unet.trainable_variables)
                optim.apply_gradients(zip(grads, unet.trainable_variables))

            source_o = source_best
            kappa_o = kappa_best
            y_pred = phys.forward(source_o, kappa_o, psf)
            chi_sq_series = tf.transpose(chi_squared_series.stack(),
                                         perm=[1, 0])
            source_mse = source_mse.stack()[None, ...]
            kappa_mse = kappa_mse.stack()[None, ...]

            # Compute Power spectrum of converged predictions
            _ps_observation = ps_observation.cross_correlation_coefficient(
                observation[..., 0], observation_pred[..., 0])
            _ps_observation2 = ps_observation.cross_correlation_coefficient(
                observation[..., 0], y_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0],
                log_10(kappa_pred[-1])[..., 0])
            _ps_kappa2 = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
            _ps_source = ps_source.cross_correlation_coefficient(
                source[..., 0], source_pred[-1][..., 0])
            _ps_source2 = ps_source.cross_correlation_coefficient(
                source[..., 0], source_o[..., 0])

            # save results
            hf["observation"][batch] = observation.astype(np.float32)
            hf["psf"][batch] = psf.astype(np.float32)
            hf["psf_fwhm"][batch] = fwhm
            hf["noise_rms"][batch] = noise_rms.astype(np.float32)
            hf["source"][batch] = source.astype(np.float32)
            hf["kappa"][batch] = kappa.astype(np.float32)
            hf["observation_pred"][batch] = observation_pred.numpy().astype(
                np.float32)
            hf["observation_pred_reoptimized"][batch] = y_pred.numpy().astype(
                np.float32)
            hf["source_pred"][batch] = tf.transpose(
                source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["source_pred_reoptimized"][batch] = source_o.numpy().astype(
                np.float32)
            hf["kappa_pred"][batch] = tf.transpose(
                kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["kappa_pred_reoptimized"][batch] = kappa_o.numpy().astype(
                np.float32)
            hf["chi_squared"][batch] = 2 * tf.transpose(
                chi_squared).numpy().astype(np.float32)
            hf["chi_squared_reoptimized"][batch] = 2 * best.numpy().astype(
                np.float32)
            hf["chi_squared_reoptimized_series"][
                batch] = 2 * chi_sq_series.numpy().astype(np.float32)
            hf["source_optim_mse"][batch] = source_mse_best.numpy().astype(
                np.float32)
            hf["source_optim_mse_series"][batch] = source_mse.numpy().astype(
                np.float32)
            hf["kappa_optim_mse"][batch] = kappa_mse_best.numpy().astype(
                np.float32)
            hf["kappa_optim_mse_series"][batch] = kappa_mse.numpy().astype(
                np.float32)
            hf["observation_coherence_spectrum"][batch] = _ps_observation
            hf["observation_coherence_spectrum_reoptimized"][
                batch] = _ps_observation2
            hf["source_coherence_spectrum"][batch] = _ps_source
            hf["source_coherence_spectrum_reoptimized"][batch] = _ps_source2
            hf["kappa_coherence_spectrum"][batch] = _ps_kappa
            hf["kappa_coherence_spectrum_reoptimized"][batch] = _ps_kappa2

            if batch == 0:
                _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels //
                                                                2],
                                    bins=ps_observation.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["observation_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.src_pixels)[:phys.src_pixels // 2],
                                    bins=ps_source.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["source_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.kappa_pixels)[:phys.kappa_pixels // 2],
                                    bins=ps_kappa.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["kappa_frequencies"][:] = f
                hf["kappa_fov"][0] = phys.kappa_fov
                hf["source_fov"][0] = phys.src_fov
Example No. 4
def main(args):
    files = []
    files.extend(glob.glob(os.path.join(args.dataset, "*.tfrecords")))
    np.random.shuffle(files)
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type),
                               block_length=args.block_length,
                               num_parallel_calls=tf.data.AUTOTUNE)

    if args.type == "cosmos":
        from censai.data.cosmos import decode_shape, decode_image as decode, preprocess_image as preprocess
    elif args.type == "kappa":
        from censai.data.kappa_tng import decode_shape, decode_train as decode
        from censai.definitions import log_10 as preprocess
    # Read off global parameters from first example in dataset
    for pixels in dataset.map(decode_shape):
        break
    vars(args).update({"pixels": int(pixels)})
    dataset = dataset.map(decode).map(preprocess).shuffle(
        args.buffer_size).batch(args.batch_size).take(args.n_plots).cache(
            args.cache)

    model_list = glob.glob(
        os.path.join(os.getenv("CENSAI_PATH"), "models",
                     args.model_prefixe + "*"))
    for model in model_list:
        if "second_stage" in model:
            continue
        with open(os.path.join(model, "model_hparams.json")) as f:
            vae_hparams = json.load(f)

        # load weights
        vae = VAE(**vae_hparams)
        ckpt1 = tf.train.Checkpoint(net=vae)
        checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, model, 1)
        checkpoint_manager1.checkpoint.restore(
            checkpoint_manager1.latest_checkpoint).expect_partial()
        vae.trainable = False
        model_name = os.path.split(model)[-1]

        for batch, images in enumerate(dataset):
            y_pred = vae(images)
            fig = reconstruction_plot(images, y_pred)
            fig.suptitle(model_name)
            fig.savefig(
                os.path.join(
                    os.getenv("CENSAI_PATH"), "results",
                    "vae_reconstruction_" + model_name + "_" +
                    args.output_postfixe + f"_{batch:02d}.png"))
            fig.clf()
            y_pred = vae.sample(args.sampling_size)
            fig = sampling_plot(y_pred)
            fig.suptitle(model_name)
            fig.savefig(
                os.path.join(
                    os.getenv("CENSAI_PATH"), "results",
                    "vae_sampling_" + model_name + "_" + args.output_postfixe +
                    f"_{batch:02d}.png"))
            fig.clf()
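
Example No. 4 draws unconditional samples with vae.sample. Judging from Example No. 1, which performs the same operation by hand, this presumably amounts to decoding draws from the standard-normal prior; a sketch:

import tensorflow as tf

def sample_from_prior(vae, n_samples):
    # Draw latent codes from the N(0, I) prior and push them through the decoder.
    z = tf.random.normal(shape=[n_samples, vae.latent_size])
    return vae.decode(z)
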
Example No. 5
def main(args):
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type),
                               block_length=1,
                               num_parallel_calls=tf.data.AUTOTUNE)
    for physical_params in dataset.map(decode_physical_model_info):
        break
    dataset = dataset.map(decode_train)

    # files = glob.glob(os.path.join(args.source_dataset, "*.tfrecords"))
    # files = tf.data.Dataset.from_tensor_slices(files)
    # source_dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
    #                            block_length=1, num_parallel_calls=tf.data.AUTOTUNE)
    # source_dataset = source_dataset.map(decode_image).map(preprocess_image).shuffle(10000).batch(args.sample_size)

    with open(os.path.join(args.kappa_vae, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, args.kappa_vae, 1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    with open(os.path.join(args.source_vae, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, args.source_vae, 1)
    checkpoint_manager2.checkpoint.restore(
        checkpoint_manager2.latest_checkpoint).expect_partial()

    phys = PhysicalModel(pixels=physical_params["pixels"].numpy(),
                         kappa_pixels=physical_params["kappa pixels"].numpy(),
                         src_pixels=physical_params["src pixels"].numpy(),
                         image_fov=physical_params["image fov"].numpy(),
                         kappa_fov=physical_params["kappa fov"].numpy(),
                         src_fov=physical_params["source fov"].numpy(),
                         method="fft")

    # simulate observations
    kappa = 10**kappa_vae.sample(args.sample_size)
    source = preprocess_image(source_vae.sample(args.sample_size))
    # for source in source_dataset:
    #     break
    fwhm = tf.random.normal(shape=[args.sample_size],
                            mean=1.5 * phys.image_fov / phys.pixels,
                            stddev=0.5 * phys.image_fov / phys.pixels)
    # noise_rms = tf.random.normal(shape=[args.sample_size], mean=args.noise_mean, stddev=args.noise_std)
    psf = phys.psf_models(fwhm, cutout_size=20)
    y_vae = phys.forward(source, kappa, psf)

    with h5py.File(
            os.path.join(os.getenv("CENSAI_PATH"), "results",
                         args.output_name + ".h5"), 'w') as hf:
        # rank these observations against the dataset with L2 norm
        for i in tqdm(range(args.sample_size)):
            distances = []
            for y_d, _, _, _, _ in dataset:
                distances.append(
                    tf.sqrt(tf.reduce_sum(
                        (y_d - y_vae[i][None, ...])**2)).numpy().astype(
                            np.float32))
            k_indices = np.argsort(distances)[:args.k]

            # save results
            g = hf.create_group(f"sample_{i:02d}")
            g.create_dataset(name="matched_source",
                             shape=[args.k, phys.src_pixels, phys.src_pixels],
                             dtype=np.float32)
            g.create_dataset(
                name="matched_kappa",
                shape=[args.k, phys.kappa_pixels, phys.kappa_pixels],
                dtype=np.float32)
            g.create_dataset(name="matched_obs",
                             shape=[args.k, phys.pixels, phys.pixels],
                             dtype=np.float32)
            g.create_dataset(name="matched_psf",
                             shape=[args.k, 20, 20],
                             dtype=np.float32)
            g.create_dataset(name="matched_noise_rms",
                             shape=[args.k],
                             dtype=np.float32)
            g.create_dataset(name="obs_L2_distance",
                             shape=[args.k],
                             dtype=np.float32)
            g["vae_source"] = source[i, ..., 0].numpy().astype(np.float32)
            g["vae_kappa"] = kappa[i, ..., 0].numpy().astype(np.float32)
            g["vae_obs"] = y_vae[i, ..., 0].numpy().astype(np.float32)
            g["vae_psf"] = psf[i, ..., 0].numpy().astype(np.float32)

            for rank, j in enumerate(k_indices):
                # fetch back the matched observation
                for y_d, source_d, kappa_d, noise_rms_d, psf_d in dataset.skip(
                        j):
                    break
                # g["vae_noise_rms"] = noise_rms[i].numpy().astype(np.float32)
                g["matched_source"][rank] = source_d[..., 0].numpy().astype(
                    np.float32)
                g["matched_kappa"][rank] = kappa_d[..., 0].numpy().astype(
                    np.float32)
                g["matched_obs"][rank] = y_d[..., 0].numpy().astype(np.float32)
                g["matched_noise_rms"][rank] = noise_rms_d.numpy().astype(
                    np.float32)
                g["matched_psf"][rank] = psf_d[...,
                                               0].numpy().astype(np.float32)
                g["obs_L2_distance"][rank] = distances[j]
Example No. 6
def distributed_strategy(args):
    psf_pixels = 20
    pixels = 128
    model = os.path.join(os.getenv('CENSAI_PATH'), "models", args.model)

    ps_observation = PowerSpectrum(bins=args.observation_coherence_bins,
                                   pixels=pixels)
    ps_source = PowerSpectrum(bins=args.source_coherence_bins, pixels=pixels)
    ps_kappa = PowerSpectrum(bins=args.kappa_coherence_bins, pixels=pixels)

    phys = PhysicalModel(
        pixels=pixels,
        kappa_pixels=pixels,
        src_pixels=pixels,
        image_fov=7.68,
        kappa_fov=7.68,
        src_fov=3.,
        method="fft",
    )

    with open(os.path.join(model, "unet_hparams.json")) as f:
        unet_params = json.load(f)
    unet_params["kernel_l2_amp"] = args.l2_amp
    unet = Model(**unet_params)
    ckpt = tf.train.Checkpoint(net=unet)
    checkpoint_manager = tf.train.CheckpointManager(ckpt, model, 1)
    checkpoint_manager.checkpoint.restore(
        checkpoint_manager.latest_checkpoint).expect_partial()
    with open(os.path.join(model, "rim_hparams.json")) as f:
        rim_params = json.load(f)
    rim_params["source_link"] = "relu"
    rim = RIM(phys, unet, **rim_params)

    kvae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.kappa_vae)
    with open(os.path.join(kvae_path, "model_hparams.json"), "r") as f:
        kappa_vae_hparams = json.load(f)
    kappa_vae = VAE(**kappa_vae_hparams)
    ckpt1 = tf.train.Checkpoint(step=tf.Variable(1), net=kappa_vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, kvae_path, 1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    svae_path = os.path.join(os.getenv('CENSAI_PATH'), "models",
                             args.source_vae)
    with open(os.path.join(svae_path, "model_hparams.json"), "r") as f:
        source_vae_hparams = json.load(f)
    source_vae = VAE(**source_vae_hparams)
    ckpt2 = tf.train.Checkpoint(step=tf.Variable(1), net=source_vae)
    checkpoint_manager2 = tf.train.CheckpointManager(ckpt2, svae_path, 1)
    checkpoint_manager2.checkpoint.restore(
        checkpoint_manager2.latest_checkpoint).expect_partial()

    model_name = os.path.split(model)[-1]
    wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(
        tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
    with h5py.File(
            os.path.join(
                os.getenv("CENSAI_PATH"), "results", args.experiment_name +
                "_" + model_name + f"_{THIS_WORKER:02d}.h5"), 'w') as hf:
        data_len = args.size // N_WORKERS
        hf.create_dataset(name="observation",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf",
                          shape=[data_len, psf_pixels, psf_pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="psf_fwhm", shape=[data_len], dtype=np.float32)
        hf.create_dataset(name="noise_rms", shape=[data_len], dtype=np.float32)
        hf.create_dataset(
            name="source",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="kappa",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="observation_pred",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(name="observation_pred_reoptimized",
                          shape=[data_len, phys.pixels, phys.pixels, 1],
                          dtype=np.float32)
        hf.create_dataset(
            name="source_pred",
            shape=[data_len, rim.steps, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(
            name="source_pred_reoptimized",
            shape=[data_len, phys.src_pixels, phys.src_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="kappa_pred",
                          shape=[
                              data_len, rim.steps, phys.kappa_pixels,
                              phys.kappa_pixels, 1
                          ],
                          dtype=np.float32)
        hf.create_dataset(
            name="kappa_pred_reoptimized",
            shape=[data_len, phys.kappa_pixels, phys.kappa_pixels, 1],
            dtype=np.float32)
        hf.create_dataset(name="chi_squared",
                          shape=[data_len, rim.steps],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized",
                          shape=[data_len, rim.steps],
                          dtype=np.float32)
        hf.create_dataset(name="chi_squared_reoptimized_series",
                          shape=[data_len, rim.steps, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="sampled_chi_squared_reoptimized_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="source_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="sampled_source_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse",
                          shape=[data_len],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="sampled_kappa_optim_mse_series",
                          shape=[data_len, args.re_optimize_steps],
                          dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_init",
                          shape=[data_len, kappa_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_init",
                          shape=[data_len, source_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="latent_kappa_gt_distance_end",
                          shape=[data_len, kappa_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="latent_source_gt_distance_end",
                          shape=[data_len, source_vae.latent_size],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_coherence_spectrum_reoptimized",
                          shape=[data_len, args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_coherence_spectrum_reoptimized",
                          shape=[data_len, args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_coherence_spectrum_reoptimized",
                          shape=[data_len, args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="observation_frequencies",
                          shape=[args.observation_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="source_frequencies",
                          shape=[args.source_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_frequencies",
                          shape=[args.kappa_coherence_bins],
                          dtype=np.float32)
        hf.create_dataset(name="kappa_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="source_fov", shape=[1], dtype=np.float32)
        hf.create_dataset(name="observation_fov", shape=[1], dtype=np.float32)
        for i in range(data_len):
            checkpoint_manager.checkpoint.restore(
                checkpoint_manager.latest_checkpoint).expect_partial(
                )  # reset model weights

            # Produce an observation
            kappa = 10**kappa_vae.sample(1)
            source = tf.nn.relu(source_vae.sample(1))
            source /= tf.reduce_max(source, axis=(1, 2, 3), keepdims=True)
            noise_rms = 10**tf.random.uniform(shape=[1],
                                              minval=-2.5,
                                              maxval=-1)
            fwhm = tf.random.uniform(shape=[1], minval=0.06, maxval=0.3)
            psf = phys.psf_models(fwhm, cutout_size=psf_pixels)
            observation = phys.noisy_forward(source, kappa, noise_rms, psf)

            # RIM predictions for kappa and source
            source_pred, kappa_pred, chi_squared = rim.predict(
                observation, noise_rms, psf)
            observation_pred = phys.forward(source_pred[-1], kappa_pred[-1],
                                            psf)
            source_o = source_pred[-1]
            kappa_o = kappa_pred[-1]

            # Latent code of model predictions
            z_source, _ = source_vae.encoder(source_o)
            z_kappa, _ = kappa_vae.encoder(log_10(kappa_o))

            # Ground truth latent code for oracle metrics
            z_source_gt, _ = source_vae.encoder(source)
            z_kappa_gt, _ = kappa_vae.encoder(log_10(kappa))

            # Re-optimize weights of the model
            STEPS = args.re_optimize_steps
            learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=args.learning_rate,
                decay_rate=args.decay_rate,
                decay_steps=args.decay_steps,
                staircase=args.staircase)
            optim = tf.keras.optimizers.RMSprop(
                learning_rate=learning_rate_schedule)

            chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            source_mse = tf.TensorArray(DTYPE, size=STEPS)
            kappa_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_chi_squared_series = tf.TensorArray(DTYPE, size=STEPS)
            sampled_source_mse = tf.TensorArray(DTYPE, size=STEPS)
            sampled_kappa_mse = tf.TensorArray(DTYPE, size=STEPS)

            best = chi_squared
            source_best = source_pred[-1]
            kappa_best = kappa_pred[-1]
            source_mse_best = tf.reduce_mean((source_best - source)**2)
            kappa_mse_best = tf.reduce_mean((kappa_best - log_10(kappa))**2)

            # ===================== Optimization ==============================
            for current_step in tqdm(range(STEPS)):
                # ===================== VAE SAMPLING ==============================

                # L1 distance to the ground truth in latent space -- replaced by a user-defined value when using real data
                # z_source_std = tf.abs(z_source - z_source_gt)
                # z_kappa_std = tf.abs(z_kappa - z_kappa_gt)
                z_source_std = args.source_vae_ball_size
                z_kappa_std = args.kappa_vae_ball_size

                # Sample latent code, then decode and forward
                z_s = tf.random.normal(
                    shape=[args.sample_size, source_vae.latent_size],
                    mean=z_source,
                    stddev=z_source_std)
                z_k = tf.random.normal(
                    shape=[args.sample_size, kappa_vae.latent_size],
                    mean=z_kappa,
                    stddev=z_kappa_std)
                sampled_source = tf.nn.relu(source_vae.decode(z_s))
                sampled_source /= tf.reduce_max(sampled_source,
                                                axis=(1, 2, 3),
                                                keepdims=True)
                sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
                sampled_observation = phys.noisy_forward(
                    sampled_source, 10**sampled_kappa, noise_rms,
                    tf.tile(psf, [args.sample_size, 1, 1, 1]))
                with tf.GradientTape() as tape:
                    tape.watch(unet.trainable_variables)
                    s, k, chi_sq = rim.call(
                        sampled_observation,
                        noise_rms,
                        tf.tile(psf, [args.sample_size, 1, 1, 1]),
                        outer_tape=tape)
                    _kappa_mse = tf.reduce_sum(wk(10**sampled_kappa) *
                                               (k - sampled_kappa)**2,
                                               axis=(2, 3, 4))
                    cost = tf.reduce_mean(_kappa_mse)
                    cost += tf.reduce_mean((s - sampled_source)**2)
                    cost += tf.reduce_sum(rim.unet.losses)  # weight decay

                grads = tape.gradient(cost, unet.trainable_variables)
                optim.apply_gradients(zip(grads, unet.trainable_variables))

                # Record performance on sampled dataset
                sampled_chi_squared_series = sampled_chi_squared_series.write(
                    index=current_step,
                    value=tf.squeeze(tf.reduce_mean(chi_sq[-1])))
                sampled_source_mse = sampled_source_mse.write(
                    index=current_step,
                    value=tf.reduce_mean((s[-1] - sampled_source)**2))
                sampled_kappa_mse = sampled_kappa_mse.write(
                    index=current_step,
                    value=tf.reduce_mean((k[-1] - sampled_kappa)**2))
                # Record model prediction on data
                s, k, chi_sq = rim.call(observation, noise_rms, psf)
                chi_squared_series = chi_squared_series.write(
                    index=current_step, value=tf.squeeze(chi_sq))
                source_o = s[-1]
                kappa_o = k[-1]
                # oracle metrics, remove when using real data
                source_mse = source_mse.write(index=current_step,
                                              value=tf.reduce_mean(
                                                  (source_o - source)**2))
                kappa_mse = kappa_mse.write(index=current_step,
                                            value=tf.reduce_mean(
                                                (kappa_o - log_10(kappa))**2))

                if abs(chi_sq[-1, 0] - 1) < abs(best[-1, 0] - 1):
                    source_best = tf.nn.relu(source_o)
                    kappa_best = 10**kappa_o
                    best = chi_sq
                    source_mse_best = tf.reduce_mean((source_best - source)**2)
                    kappa_mse_best = tf.reduce_mean(
                        (kappa_best - log_10(kappa))**2)

            source_o = source_best
            kappa_o = kappa_best
            y_pred = phys.forward(source_o, kappa_o, psf)

            chi_sq_series = tf.transpose(chi_squared_series.stack())
            source_mse = source_mse.stack()
            kappa_mse = kappa_mse.stack()
            sampled_chi_squared_series = sampled_chi_squared_series.stack()
            sampled_source_mse = sampled_source_mse.stack()
            sampled_kappa_mse = sampled_kappa_mse.stack()

            # Latent code of optimized model predictions
            z_source_opt, _ = source_vae.encoder(tf.nn.relu(source_o))
            z_kappa_opt, _ = kappa_vae.encoder(log_10(kappa_o))

            # Compute Power spectrum of converged predictions
            _ps_observation = ps_observation.cross_correlation_coefficient(
                observation[..., 0], observation_pred[..., 0])
            _ps_observation2 = ps_observation.cross_correlation_coefficient(
                observation[..., 0], y_pred[..., 0])
            _ps_kappa = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0],
                log_10(kappa_pred[-1])[..., 0])
            _ps_kappa2 = ps_kappa.cross_correlation_coefficient(
                log_10(kappa)[..., 0], log_10(kappa_o[..., 0]))
            _ps_source = ps_source.cross_correlation_coefficient(
                source[..., 0], source_pred[-1][..., 0])
            _ps_source2 = ps_source.cross_correlation_coefficient(
                source[..., 0], source_o[..., 0])

            # save results
            hf["observation"][i] = observation.numpy().astype(np.float32)
            hf["psf"][i] = psf.numpy().astype(np.float32)
            hf["psf_fwhm"][i] = fwhm.numpy().astype(np.float32)
            hf["noise_rms"][i] = noise_rms.numpy().astype(np.float32)
            hf["source"][i] = source.numpy().astype(np.float32)
            hf["kappa"][i] = kappa.numpy().astype(np.float32)
            hf["observation_pred"][i] = observation_pred.numpy().astype(
                np.float32)
            hf["observation_pred_reoptimized"][i] = y_pred.numpy().astype(
                np.float32)
            hf["source_pred"][i] = tf.transpose(
                source_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["source_pred_reoptimized"][i] = source_o.numpy().astype(
                np.float32)
            hf["kappa_pred"][i] = tf.transpose(
                kappa_pred, perm=(1, 0, 2, 3, 4)).numpy().astype(np.float32)
            hf["kappa_pred_reoptimized"][i] = kappa_o.numpy().astype(
                np.float32)
            hf["chi_squared"][i] = tf.squeeze(chi_squared).numpy().astype(
                np.float32)
            hf["chi_squared_reoptimized"][i] = tf.squeeze(best).numpy().astype(
                np.float32)
            hf["chi_squared_reoptimized_series"][i] = chi_sq_series.numpy(
            ).astype(np.float32)
            hf["sampled_chi_squared_reoptimized_series"][
                i] = 2 * sampled_chi_squared_series.numpy().astype(np.float32)
            hf["source_optim_mse"][i] = source_mse_best.numpy().astype(
                np.float32)
            hf["source_optim_mse_series"][i] = source_mse.numpy().astype(
                np.float32)
            hf["sampled_source_optim_mse_series"][
                i] = sampled_source_mse.numpy().astype(np.float32)
            hf["kappa_optim_mse"][i] = kappa_mse_best.numpy().astype(
                np.float32)
            hf["kappa_optim_mse_series"][i] = kappa_mse.numpy().astype(
                np.float32)
            hf["sampled_kappa_optim_mse_series"][i] = sampled_kappa_mse.numpy(
            ).astype(np.float32)
            hf["latent_source_gt_distance_init"][i] = tf.abs(
                z_source - z_source_gt).numpy().squeeze().astype(np.float32)
            hf["latent_kappa_gt_distance_init"][i] = tf.abs(
                z_kappa - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["latent_source_gt_distance_end"][i] = tf.abs(
                z_source_opt - z_source_gt).numpy().squeeze().astype(
                    np.float32)
            hf["latent_kappa_gt_distance_end"][i] = tf.abs(
                z_kappa_opt - z_kappa_gt).numpy().squeeze().astype(np.float32)
            hf["observation_coherence_spectrum"][i] = _ps_observation
            hf["observation_coherence_spectrum_reoptimized"][
                i] = _ps_observation2
            hf["source_coherence_spectrum"][i] = _ps_source
            hf["source_coherence_spectrum_reoptimized"][i] = _ps_source2
            hf["kappa_coherence_spectrum"][i] = _ps_kappa
            hf["kappa_coherence_spectrum_reoptimized"][i] = _ps_kappa2

            if i == 0:
                _, f = np.histogram(np.fft.fftfreq(phys.pixels)[:phys.pixels //
                                                                2],
                                    bins=ps_observation.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["observation_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.src_pixels)[:phys.src_pixels // 2],
                                    bins=ps_source.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["source_frequencies"][:] = f
                _, f = np.histogram(np.fft.fftfreq(
                    phys.kappa_pixels)[:phys.kappa_pixels // 2],
                                    bins=ps_kappa.bins)
                f = (f[:-1] + f[1:]) / 2
                hf["kappa_frequencies"][:] = f
                hf["kappa_fov"][0] = phys.kappa_fov
                hf["source_fov"][0] = phys.src_fov
Example No. 7
def main(args):
    if args.seed is not None:
        tf.random.set_seed(args.seed)
        np.random.seed(args.seed)
    if args.json_override is not None:
        if isinstance(args.json_override, list):
            files = args.json_override
        else:
            files = [
                args.json_override,
            ]
        for file in files:
            with open(file, "r") as f:
                json_override = json.load(f)
            args_dict = vars(args)
            args_dict.update(json_override)

    files = []
    for dataset in args.datasets:
        files.extend(glob.glob(os.path.join(dataset, "*.tfrecords")))
    np.random.shuffle(files)
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type),
                               block_length=args.block_length,
                               num_parallel_calls=tf.data.AUTOTUNE)
    # Read off global parameters from first example in dataset
    for pixels in dataset.map(decode_shape):
        break
    vars(args).update({"pixels": int(pixels)})
    dataset = dataset.map(decode_image).map(preprocess_image).batch(
        args.batch_size)
    if args.cache_file is not None:
        dataset = dataset.cache(args.cache_file).prefetch(
            tf.data.experimental.AUTOTUNE)
    else:
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    train_dataset = dataset.take(
        math.floor(args.train_split * args.total_items /
                   args.batch_size))  # don't forget to divide by the batch size!
    val_dataset = dataset.skip(
        math.floor(args.train_split * args.total_items / args.batch_size))
    val_dataset = val_dataset.take(
        math.ceil((1 - args.train_split) * args.total_items / args.batch_size))

    vae = VAE(pixels=pixels,
              layers=args.layers,
              conv_layers=args.conv_layers,
              filter_scaling=args.filter_scaling,
              filters=args.filters,
              kernel_size=args.kernel_size,
              kernel_reg_amp=args.kernel_reg_amp,
              bias_reg_amp=args.bias_reg_amp,
              activation=args.activation,
              dropout_rate=args.dropout_rate,
              batch_norm=args.batch_norm,
              latent_size=args.latent_size,
              output_activation=args.output_activation)
    learning_rate_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=args.initial_learning_rate,
        decay_rate=args.decay_rate,
        decay_steps=args.decay_steps,
        staircase=args.staircase)
    beta_schedule = PolynomialSchedule(initial_value=args.beta_init,
                                       end_value=args.beta_end_value,
                                       power=args.beta_decay_power,
                                       decay_steps=args.beta_decay_steps,
                                       cyclical=args.beta_cyclical)
    skip_strength_schedule = PolynomialSchedule(
        initial_value=args.skip_strength_init,
        end_value=0.,
        power=args.skip_strength_decay_power,
        decay_steps=args.skip_strength_decay_steps)
    l2_bottleneck_schedule = PolynomialSchedule(
        initial_value=args.l2_bottleneck_init,
        end_value=0.,
        power=args.l2_bottleneck_decay_power,
        decay_steps=args.l2_bottleneck_decay_steps)
    optim = tf.keras.optimizers.deserialize({
        "class_name": args.optimizer,
        'config': {
            "learning_rate": learning_rate_schedule
        }
    })
    # ==== Take care of where to write logs and stuff =================================================================
    if args.model_id.lower() != "none":
        logname = args.model_id
    elif args.logname is not None:
        logname = args.logname
    else:
        logname = args.logname_prefixe + "_" + datetime.now().strftime(
            "%y%m%d%H%M%S")
    if args.logdir.lower() != "none":
        logdir = os.path.join(args.logdir, logname)
        if not os.path.isdir(logdir):
            os.mkdir(logdir)
        writer = tf.summary.create_file_writer(logdir)
    else:
        writer = nullwriter()
    # ===== Make sure directory and checkpoint manager are created to save model ===================================
    if args.model_dir.lower() != "none":
        checkpoints_dir = os.path.join(args.model_dir, logname)
        if not os.path.isdir(checkpoints_dir):
            os.mkdir(checkpoints_dir)
            with open(os.path.join(checkpoints_dir, "script_params.json"),
                      "w") as f:
                json.dump(vars(args), f, indent=4)
            with open(os.path.join(checkpoints_dir, "model_hparams.json"),
                      "w") as f:
                hparams_dict = {key: vars(args)[key] for key in VAE_HPARAMS}
                json.dump(hparams_dict, f, indent=4)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                   optimizer=optim,
                                   net=vae)
        checkpoint_manager = tf.train.CheckpointManager(
            ckpt, checkpoints_dir, max_to_keep=args.max_to_keep)
        save_checkpoint = True
        # ======= Load model if model_id is provided ===============================================================
        if args.model_id.lower() != "none":
            if args.load_checkpoint == "lastest":
                checkpoint_manager.checkpoint.restore(
                    checkpoint_manager.latest_checkpoint)
            elif args.load_checkpoint == "best":
                scores = np.loadtxt(
                    os.path.join(checkpoints_dir, "score_sheet.txt"))
                # Cast to int so the value read back from the score sheet can index the checkpoint list.
                _checkpoint = int(scores[np.argmin(scores[:, 1]), 0])
                checkpoint = checkpoint_manager.checkpoints[_checkpoint]
                checkpoint_manager.checkpoint.restore(checkpoint)
            else:
                checkpoint = checkpoint_manager.checkpoints[int(
                    args.load_checkpoint)]
                checkpoint_manager.checkpoint.restore(checkpoint)
    else:
        save_checkpoint = False

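    # Annealed beta-VAE objective: reconstruction + beta(step) * KL + bottleneck L2,
    # summed over the batch and divided by the batch size. The skip-connection
    # strength and the bottleneck L2 penalty are annealed to zero by the
    # polynomial schedules defined above.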
    def train_step(x, step):
        with tf.GradientTape() as tape:
            tape.watch(vae.trainable_weights)
            reconstruction_loss, kl_loss, bottleneck_l2_loss = vae.cost_function_training(
                x, skip_strength_schedule(step), l2_bottleneck_schedule(step))
            cost = tf.reduce_sum(reconstruction_loss +
                                 beta_schedule(step) * kl_loss +
                                 bottleneck_l2_loss) / args.batch_size
        gradients = tape.gradient(cost, vae.trainable_weights)
        if args.clipping:
            gradients = [tf.clip_by_value(grad, -10, 10) for grad in gradients]
        optim.apply_gradients(zip(gradients, vae.trainable_weights))
        reconstruction_loss = tf.reduce_mean(reconstruction_loss)
        kl_loss = tf.reduce_mean(kl_loss)
        return cost, reconstruction_loss, kl_loss

    def test_step(x, step):
        reconstruction_loss, kl_loss, bottleneck_l2_loss = vae.cost_function_training(
            x, skip_strength_schedule(step), l2_bottleneck_schedule(step))
        cost = tf.reduce_sum(reconstruction_loss + beta_schedule(step) *
                             kl_loss + bottleneck_l2_loss) / args.batch_size
        reconstruction_loss = tf.reduce_mean(reconstruction_loss)
        kl_loss = tf.reduce_mean(kl_loss)
        return cost, reconstruction_loss, kl_loss

    # ====== Training loop ============================================================================================
    epoch_loss = tf.metrics.Mean()
    epoch_reconstruction_loss = tf.metrics.Mean()
    epoch_kl_loss = tf.metrics.Mean()
    time_per_step = tf.metrics.Mean()
    val_loss = tf.metrics.Mean()
    val_reconstruction_loss = tf.metrics.Mean()
    val_kl_loss = tf.metrics.Mean()
    history = {  # recorded at the end of an epoch only
        "train_cost": [],
        "val_cost": [],
        "learning_rate": [],
        "time_per_step": [],
        "train_reconstruction_loss": [],
        "val_reconstruction_loss": [],
        "train_kl_loss": [],
        "val_kl_loss": [],
        "step": [],
        "wall_time": []
    }
    best_loss = np.inf
    patience = args.patience
    step = 0
    global_start = time.time()
    estimated_time_for_epoch = 0
    out_of_time = False
    latest_checkpoint = 1
    for epoch in range(args.epochs):
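        # Break out of the training loop when the remaining wall-time budget
        # cannot fit another epoch.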
        if (time.time() - global_start
            ) > args.max_time * 3600 - estimated_time_for_epoch:
            break
        epoch_start = time.time()
        epoch_loss.reset_states()
        epoch_reconstruction_loss.reset_states()
        epoch_kl_loss.reset_states()
        time_per_step.reset_states()
        with writer.as_default():
            for batch, x in enumerate(train_dataset):
                start = time.time()
                cost, reconstruction_loss, kl_loss = train_step(x, step=step)
                # ========== Summary and logs ==================================================================================
                _time = time.time() - start
                time_per_step.update_state([_time])
                epoch_loss.update_state([cost])
                epoch_reconstruction_loss.update_state([reconstruction_loss])
                epoch_kl_loss.update_state([kl_loss])
                step += 1
            # use the last training batch of the epoch to log residual images
            for res_idx in range(min(args.n_residuals, args.batch_size)):
                y_true = x[res_idx, ...]
                y_pred = vae.call(y_true[None, ...])[0, ...]
                tf.summary.image(f"Residuals {res_idx}",
                                 plot_to_image(residual_plot(y_true, y_pred)),
                                 step=step)
            # ========== Validation set ===================
            val_loss.reset_states()
            val_reconstruction_loss.reset_states()
            val_kl_loss.reset_states()
            for x in val_dataset:
                cost, reconstruction_loss, kl_loss = test_step(x, step=step)
                val_loss.update_state([cost])
                val_reconstruction_loss.update_state([reconstruction_loss])
                val_kl_loss.update_state([kl_loss])
            for res_idx in range(min(args.n_residuals, args.batch_size)):
                y_true = x[res_idx, ...]
                y_pred = vae.call(y_true[None, ...])[0, ...]
                tf.summary.image(f"Val Residuals {res_idx}",
                                 plot_to_image(residual_plot(y_true, y_pred)),
                                 step=step)

            val_cost = val_loss.result().numpy()
            train_cost = epoch_loss.result().numpy()
            train_reconstruction_cost = epoch_reconstruction_loss.result(
            ).numpy()
            val_reconstruction_cost = val_reconstruction_loss.result().numpy()
            train_kl_cost = epoch_kl_loss.result().numpy()
            val_kl_cost = val_kl_loss.result().numpy()
            tf.summary.scalar("KL", train_kl_cost, step=step)
            tf.summary.scalar("Val KL", val_kl_cost, step=step)
            tf.summary.scalar("Reconstruction loss",
                              train_reconstruction_cost,
                              step=step)
            tf.summary.scalar("Val reconstruction loss",
                              val_reconstruction_cost,
                              step=step)
            tf.summary.scalar("MSE", train_cost, step=step)
            tf.summary.scalar("Val MSE", val_cost, step=step)
            tf.summary.scalar("Learning Rate", optim.lr(step), step=step)
            tf.summary.scalar("beta", beta_schedule(step), step=step)
            tf.summary.scalar("l2 bottleneck",
                              l2_bottleneck_schedule(step),
                              step=step)
        print(
            f"epoch {epoch} | train loss {train_cost:.3e} | val loss {val_cost:.3e} "
            f"| learning rate {optim.lr(step).numpy():.2e} | time per step {time_per_step.result().numpy():.2e} s"
        )
        history["train_cost"].append(train_cost)
        history["val_cost"].append(val_cost)
        history["train_reconstruction_loss"].append(train_reconstruction_cost)
        history["val_reconstruction_loss"].append(val_reconstruction_cost)
        history["train_kl_loss"].append(train_kl_cost)
        history["val_kl_loss"].append(val_kl_cost)
        history["learning_rate"].append(optim.lr(step).numpy())
        history["time_per_step"].append(time_per_step.result().numpy())
        history["step"].append(step)
        history["wall_time"].append(time.time() - global_start)

        cost = train_cost if args.track_train else val_cost
        if np.isnan(cost):
            print("Training broke the Universe")
            break
        if cost < (1 - args.tolerance) * best_loss:
            best_loss = cost
            patience = args.patience
        else:
            patience -= 1
        if (time.time() - global_start) > args.max_time * 3600:
            out_of_time = True
        if save_checkpoint:
            checkpoint_manager.checkpoint.step.assign_add(1)  # a bit of a hack
            if epoch % args.checkpoints == 0 or patience == 0 or epoch == args.epochs - 1 or out_of_time:
                with open(os.path.join(checkpoints_dir, "score_sheet.txt"),
                          mode="a") as f:
                    np.savetxt(f, np.array([[latest_checkpoint, cost]]))
                latest_checkpoint += 1
                checkpoint_manager.save()
                print("Saved checkpoint for step {}: {}".format(
                    int(checkpoint_manager.checkpoint.step),
                    checkpoint_manager.latest_checkpoint))
        if patience == 0:
            print("Reached patience")
            break
        if out_of_time:
            break
        if epoch > 0:  # First epoch is always very slow and not a good estimate of an epoch time.
            estimated_time_for_epoch = time.time() - epoch_start
    print(
        f"Finished training after {(time.time() - global_start)/3600:.3f} hours."
    )
    return history, best_loss
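
The beta, skip-strength and bottleneck annealings above all go through PolynomialSchedule, whose implementation is not shown in this example. The class below is a minimal stand-in inferred from the call sites (initial_value, end_value, power, decay_steps, optional cyclical flag); it follows the usual polynomial-decay formula and ignores the cyclical option, so the implementation in the source repository may differ.

import tensorflow as tf


class PolynomialSchedule:
    """Minimal stand-in: polynomial interpolation from initial_value to end_value
    over decay_steps, inferred from the call sites above (cyclical is ignored)."""

    def __init__(self, initial_value, end_value, power, decay_steps, cyclical=False):
        self.initial_value = float(initial_value)
        self.end_value = float(end_value)
        self.power = float(power)
        self.decay_steps = decay_steps
        self.cyclical = cyclical  # not implemented in this sketch

    def __call__(self, step):
        # Fraction of the decay completed, clipped to [0, 1].
        t = tf.minimum(tf.cast(step, tf.float32) / self.decay_steps, 1.0)
        return self.end_value + (self.initial_value - self.end_value) * (1.0 - t)**self.power

With power=1 this is a linear ramp between the two values over decay_steps steps.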
Example #8
0
 def __init__(self,
              observation,
              noise_rms,
              psf,
              phys: PhysicalModel,
              rim: RIM,
              source_vae: VAE,
              kappa_vae: VAE,
              n_samples=100,
              sigma_source=0.5,
              sigma_kappa=0.5):
     """
     Make a copy of initial parameters \varphi^{(0)} and compute the Fisher diagonal F_{ii}
     """
     wk = tf.keras.layers.Lambda(lambda k: tf.sqrt(k) / tf.reduce_sum(
         tf.sqrt(k), axis=(1, 2, 3), keepdims=True))
     # Baseline prediction from observation
     source_pred, kappa_pred, chi_squared = rim.predict(
         observation, noise_rms, psf)
     # Latent code of model predictions
     z_source, _ = source_vae.encoder(source_pred[-1])
     z_kappa, _ = kappa_vae.encoder(log_10(kappa_pred[-1]))
     # Deepcopy of the initial parameters
     self.initial_params = [
         deepcopy(w) for w in rim.unet.trainable_variables
     ]
     self.fisher_diagonal = [tf.zeros_like(w) for w in self.initial_params]
     for n in range(n_samples):
         # Sample latent code around the prediction mean
         z_s = tf.random.normal(shape=[1, source_vae.latent_size],
                                mean=z_source,
                                stddev=sigma_source)
         z_k = tf.random.normal(shape=[1, kappa_vae.latent_size],
                                mean=z_kappa,
                                stddev=sigma_kappa)
         # Decode
         sampled_source = tf.nn.relu(source_vae.decode(z_s))
         sampled_source /= tf.reduce_max(sampled_source,
                                         axis=(1, 2, 3),
                                         keepdims=True)
         sampled_kappa = kappa_vae.decode(z_k)  # output in log_10 space
         # Simulate observation
         sampled_observation = phys.noisy_forward(sampled_source,
                                                  10**sampled_kappa,
                                                  noise_rms, psf)
         # Compute the gradient of the MSE
         with tf.GradientTape() as tape:
             tape.watch(rim.unet.trainable_variables)
             s, k, chi_squared = rim.call(sampled_observation, noise_rms,
                                          psf)
              # Remove the temperature from the loss when computing the Fisher: take a sum instead of a mean, and renormalize the weighted sum by the number of pixels
             _kappa_mse = phys.kappa_pixels**2 * tf.reduce_sum(
                 wk(10**sampled_kappa) * (k - sampled_kappa)**2,
                 axis=(2, 3, 4))
             cost = tf.reduce_sum(_kappa_mse)
             cost += tf.reduce_sum((s - sampled_source)**2)
         grad = tape.gradient(cost, rim.unet.trainable_variables)
          # Square the gradient evaluated at the initial parameters and accumulate its Monte Carlo average over the samples
         self.fisher_diagonal = [
             F + g**2 / n_samples
             for F, g in zip(self.fisher_diagonal, grad)
         ]
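
The constructor above stores a copy of the initial network parameters together with a Monte Carlo estimate of the diagonal of the Fisher information, which is the standard setup for an elastic-weight-consolidation (EWC) style penalty during later fine-tuning. The function below is an illustrative sketch of how those two lists could enter such a penalty; the name ewc_penalty and the weight lam are assumptions made for this example and are not part of the original class.

import tensorflow as tf


def ewc_penalty(current_weights, initial_params, fisher_diagonal, lam=1.0):
    """Illustrative EWC-style regularizer: 0.5 * lam * sum_i F_ii * (w_i - w0_i)**2.

    The three lists are assumed to be aligned the same way as
    rim.unet.trainable_variables, self.initial_params and self.fisher_diagonal
    are in the constructor above.
    """
    penalty = 0.0
    for w, w0, fisher in zip(current_weights, initial_params, fisher_diagonal):
        penalty += tf.reduce_sum(fisher * (w - w0)**2)
    return 0.5 * lam * penalty

During fine-tuning of the RIM on a new observation, this term would typically be added to the data-fit cost inside the gradient tape, pulling parameters that matter for the pre-trained task back toward their initial values.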