Example #1
def train_specified_manifold_flow(args, dataset, model, simulator):
    """ FOM training """

    trainer = (ForwardTrainer(model) if simulator.parameter_dim() is None
               else ConditionalForwardTrainer(model) if args.scandal is None
               else SCANDALForwardTrainer(model))
    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(
        args, dataset)

    logger.info("Starting training MF with specified manifold on NLL")
    learning_curves = trainer.train(
        loss_functions=[losses.mse, losses.nll] + scandal_loss,
        loss_labels=["MSE", "NLL"] + scandal_label,
        loss_weights=[
            0.0, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)
        ] + scandal_weight,
        epochs=args.epochs,
        callbacks=[
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", None, args))
        ],
        forward_kwargs={"mode": "mf"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T

    return learning_curves
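Python's conditional expression is right-associative, so the trainer selection at the top of this function (and of most examples below) means: use ForwardTrainer for unconditional data, ConditionalForwardTrainer when the simulator has parameters, and SCANDALForwardTrainer when args.scandal is also set. Spelled out as explicit branching (an equivalent sketch, same class names as above):

# Equivalent, explicit form of the one-line trainer selection:
if simulator.parameter_dim() is None:
    trainer = ForwardTrainer(model)             # unconditional data
elif args.scandal is None:
    trainer = ConditionalForwardTrainer(model)  # conditional density
else:
    trainer = SCANDALForwardTrainer(model)      # conditional + SCANDAL term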
Example #2
def train_pie(args, dataset, model, simulator):
    """ PIE training """
    trainer = (ForwardTrainer(model) if simulator.parameter_dim() is None
               else ConditionalForwardTrainer(model) if args.scandal is None
               else SCANDALForwardTrainer(model))
    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(
        args, dataset)
    callbacks_ = [
        callbacks.save_model_after_every_epoch(
            create_filename("checkpoint", None, args))
    ]
    if simulator.is_image():
        callbacks_.append(
            callbacks.plot_sample_images(
                create_filename("training_plot", None, args),
                context=(None if simulator.parameter_dim() is None
                         else torch.zeros(30, simulator.parameter_dim()))))
        callbacks_.append(
            callbacks.plot_reco_images(
                create_filename("training_plot", "reco_epoch", args)))

    logger.info("Starting training PIE on NLL")
    learning_curves = trainer.train(
        loss_functions=[losses.nll] + scandal_loss,
        loss_labels=["NLL"] + scandal_label,
        loss_weights=[args.nllfactor * nat_to_bit_per_dim(args.datadim)] +
        scandal_weight,
        epochs=args.epochs,
        callbacks=callbacks_,
        forward_kwargs={"mode": "pie"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T
    return learning_curves
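Both examples scale the NLL weight by nat_to_bit_per_dim(...), putting the likelihood term in bits per dimension so that args.nllfactor means the same thing across datasets of different dimensionality. The helper itself is not shown here; a minimal sketch consistent with its name and usage (an assumption, not necessarily the repository's exact code):

import numpy as np

def nat_to_bit_per_dim(dim):
    # ASSUMED implementation: rescale a log likelihood summed over `dim`
    # dimensions from nats to bits per dimension, i.e. divide by ln(2)
    # and by the dimensionality.
    return 1.0 / (np.log(2) * dim)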
Example #3
    def objective(trial):
        global counter

        counter += 1

        # Hyperparameters
        margs = pick_parameters(args, trial, counter)

        logger.info(f"Starting run {counter} / {args.trials}")
        logger.info(f"Hyperparams:")
        logger.info(f"  outer layers:      {margs.outerlayers}")
        logger.info(f"  inner layers:      {margs.innerlayers}")
        logger.info(f"  linear transform:  {margs.lineartransform}")
        logger.info(f"  spline range:      {margs.splinerange}")
        logger.info(f"  spline bins:       {margs.splinebins}")
        logger.info(f"  batchnorm:         {margs.batchnorm}")
        logger.info(f"  dropout:           {margs.dropout}")
        logger.info(f"  batch size:        {margs.batchsize}")
        logger.info(f"  MSE factor:        {margs.msefactor}")
        logger.info(f"  latent L2 reg:     {margs.uvl2reg}")
        logger.info(f"  weight decay:      {margs.weightdecay}")
        logger.info(f"  gradient clipping: {margs.clip}")

        # Workaround for a bug when num_workers > 1 is combined with CUDA. Bad things happen otherwise!
        torch.multiprocessing.set_start_method("spawn", force=True)

        # Load data
        simulator = load_simulator(margs)
        dataset = simulator.load_dataset(train=True,
                                         dataset_dir=create_filename(
                                             "dataset", None, args),
                                         limit_samplesize=margs.samplesize)

        # Create model
        model = create_model(margs, simulator)

        # Train
        trainer1 = (ForwardTrainer(model) if simulator.parameter_dim() is None
                    else ConditionalForwardTrainer(model))
        trainer2 = (ForwardTrainer(model) if simulator.parameter_dim() is None
                    else ConditionalForwardTrainer(model))
        common_kwargs, _, _, _ = train.make_training_kwargs(margs, dataset)

        logger.info("Starting training MF, phase 1: manifold training")
        np.random.seed(123)
        _, val_losses = trainer1.train(
            loss_functions=[losses.mse, losses.hiddenl2reg],
            loss_labels=["MSE", "L2_lat"],
            loss_weights=[
                margs.msefactor,
                0.0 if margs.uvl2reg is None else margs.uvl2reg
            ],
            epochs=margs.epochs,
            parameters=(list(model.outer_transform.parameters()) +
                        list(model.encoder.parameters())
                        if args.algorithm == "emf"
                        else model.outer_transform.parameters()),
            forward_kwargs={
                "mode": "projection",
                "return_hidden": True
            },
            **common_kwargs,
        )

        logger.info("Starting training MF, phase 2: density training")
        np.random.seed(123)
        _ = trainer2.train(
            loss_functions=[losses.nll],
            loss_labels=["NLL"],
            loss_weights=[args.nllfactor],
            epochs=args.densityepochs,
            parameters=model.inner_transform.parameters(),
            forward_kwargs={"mode": "mf-fixed-manifold"},
            **common_kwargs,
        )

        # Save
        torch.save(model.state_dict(), create_filename("model", None, margs))

        # Evaluate reco error
        logger.info("Evaluating reco error")
        model.eval()
        np.random.seed(123)
        x, params = next(
            iter(
                trainer1.make_dataloader(
                    simulator.load_dataset(train=True,
                                           dataset_dir=create_filename(
                                               "dataset", None, args),
                                           limit_samplesize=args.samplesize),
                    args.validationsplit, 1000, 0)[1]))
        x = x.to(device=trainer1.device, dtype=trainer1.dtype)
        params = None if simulator.parameter_dim() is None else params.to(
            device=trainer1.device, dtype=trainer1.dtype)
        x_reco, _, _ = model(x, context=params, mode="projection")
        reco_error = torch.mean(torch.sum((x - x_reco)**2,
                                          dim=1)**0.5).detach().cpu().numpy()

        # Generate samples
        logger.info("Evaluating sample closure")
        x_gen = evaluate.sample_from_model(margs, model, simulator)
        distances_gen = simulator.distance_from_manifold(x_gen)
        mean_gen_distance = np.mean(distances_gen)

        # Report results
        logger.info("Results:")
        logger.info("  reco err:     %s", reco_error)
        logger.info("  gen distance: %s", mean_gen_distance)

        return margs.metricrecoerrorfactor * reco_error + margs.metricdistancefactor * mean_gen_distance
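The trial argument and pick_parameters suggest this objective is driven by Optuna. A minimal driver, assuming optuna is available and args.trials is defined as used above:

import optuna

# Minimize the combined reconstruction-error / sample-distance metric
# returned by the objective above.
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=args.trials)
logger.info("Best hyperparameters: %s", study.best_params)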
Example #4
def train_generative_adversarial_manifold_flow_alternating(
        args, dataset, model, simulator):
    """ MFMF-OTA training """

    assert not args.specified

    gen_trainer = (AdversarialTrainer(model) if simulator.parameter_dim() is None
                   else ConditionalAdversarialTrainer(model))
    likelihood_trainer = (ForwardTrainer(model) if simulator.parameter_dim() is None
                          else ConditionalForwardTrainer(model) if args.scandal is None
                          else SCANDALForwardTrainer(model))
    metatrainer = AlternatingTrainer(model, gen_trainer, likelihood_trainer)

    meta_kwargs = {
        "dataset": dataset,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "validation_split": args.validationsplit
    }
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {
            "weight_decay": float(args.weightdecay)
        }
    _, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(
        args, dataset)

    phase1_kwargs = {"clip_gradient": args.clip}
    phase2_kwargs = {
        "forward_kwargs": {
            "mode": "mf-fixed-manifold"
        },
        "clip_gradient": args.clip
    }

    phase1_parameters = list(model.parameters())
    phase2_parameters = list(model.inner_transform.parameters())

    logger.info(
        "Starting training GAMF, alternating between Sinkhorn divergence and log likelihood"
    )
    learning_curves_ = metatrainer.train(
        loss_functions=[losses.make_sinkhorn_divergence(), losses.nll] +
        scandal_loss,
        loss_function_trainers=[0, 1] + ([1] if args.scandal is not None else []),
        loss_labels=["GED", "NLL"] + scandal_label,
        loss_weights=[
            args.sinkhornfactor,
            args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)
        ] + scandal_weight,
        batch_sizes=[args.genbatchsize, args.batchsize],
        epochs=args.epochs // 2,
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", None, args))
        ],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        subsets=args.subsets,
        subset_callbacks=[callbacks.print_mf_weight_statistics()]
        if args.debug else None,
        **meta_kwargs,
    )
    learning_curves = np.vstack(learning_curves_).T

    return learning_curves
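AlternatingTrainer itself is not shown here; per the call above it cycles between the two child trainers, each with its own loss, batch size, parameter subset, and trainer kwargs. A stripped-down sketch of that round-robin pattern in plain PyTorch (hypothetical names and structure, not the repository's implementation):

import torch

def alternate(phases, epochs):
    # Each phase is a (loss_fn, params, batches) triple with its own
    # optimizer, so steps in one phase only touch that phase's parameters.
    optimizers = [torch.optim.Adam(params) for _, params, _ in phases]
    for _ in range(epochs):
        for (loss_fn, _, batches), opt in zip(phases, optimizers):
            for batch in batches:
                opt.zero_grad()
                loss_fn(batch).backward()
                opt.step()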
Example #5
def train_manifold_flow_sequential(args, dataset, model, simulator):
    """ Sequential MFMF-M/D training """

    assert not args.specified

    if simulator.parameter_dim() is None:
        trainer1 = ForwardTrainer(model)
        trainer2 = ForwardTrainer(model)
    else:
        trainer1 = ConditionalForwardTrainer(model)
        if args.scandal is None:
            trainer2 = ConditionalForwardTrainer(model)
        else:
            trainer2 = SCANDALForwardTrainer(model)

    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(
        args, dataset)

    callbacks1 = [
        callbacks.save_model_after_every_epoch(
            create_filename("checkpoint", "A", args)),
        callbacks.print_mf_latent_statistics(),
        callbacks.print_mf_weight_statistics()
    ]
    callbacks2 = [
        callbacks.save_model_after_every_epoch(
            create_filename("checkpoint", "B", args)),
        callbacks.print_mf_latent_statistics(),
        callbacks.print_mf_weight_statistics()
    ]
    if simulator.is_image():
        callbacks1.append(
            callbacks.plot_sample_images(
                create_filename("training_plot", "sample_epoch_A", args),
                context=(None if simulator.parameter_dim() is None
                         else torch.zeros(30, simulator.parameter_dim())),
            ))
        callbacks2.append(
            callbacks.plot_sample_images(
                create_filename("training_plot", "sample_epoch_B", args),
                context=(None if simulator.parameter_dim() is None
                         else torch.zeros(30, simulator.parameter_dim())),
            ))
        callbacks1.append(
            callbacks.plot_reco_images(
                create_filename("training_plot", "reco_epoch_A", args)))
        callbacks2.append(
            callbacks.plot_reco_images(
                create_filename("training_plot", "reco_epoch_B", args)))

    logger.info("Starting training MF, phase 1: manifold training")
    learning_curves = trainer1.train(
        loss_functions=[losses.smooth_l1_loss if args.l1 else losses.mse] +
        ([] if args.uvl2reg is None else [losses.hiddenl2reg]),
        loss_labels=["L1" if args.l1 else "MSE"] +
        ([] if args.uvl2reg is None else ["L2_lat"]),
        loss_weights=[args.msefactor] +
        ([] if args.uvl2reg is None else [args.uvl2reg]),
        epochs=args.epochs // 2,
        parameters=(list(model.outer_transform.parameters()) +
                    list(model.encoder.parameters())
                    if args.algorithm == "emf"
                    else list(model.outer_transform.parameters())),
        callbacks=callbacks1,
        forward_kwargs={
            "mode": "projection",
            "return_hidden": args.uvl2reg is not None
        },
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T

    logger.info("Starting training MF, phase 2: density training")
    learning_curves_ = trainer2.train(
        loss_functions=[losses.nll] + scandal_loss,
        loss_labels=["NLL"] + scandal_label,
        loss_weights=[
            args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)
        ] + scandal_weight,
        epochs=args.epochs - (args.epochs // 2),
        parameters=list(model.inner_transform.parameters()),
        callbacks=callbacks2,
        forward_kwargs={"mode": "mf-fixed-manifold"},
        initial_epoch=args.startepoch - args.epochs // 2,
        **common_kwargs,
    )
    learning_curves = np.vstack(
        (learning_curves, np.vstack(learning_curves_).T))

    return learning_curves
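The epoch budget is split as args.epochs // 2 for the manifold phase and args.epochs - (args.epochs // 2) for the density phase, so the two phases always sum to args.epochs even when it is odd; the second initial_epoch is shifted by args.epochs // 2 so a resumed run lands in the right phase. A quick check:

# The two-phase split always covers the full budget, odd or even.
for total in (9, 10):
    phase1 = total // 2
    phase2 = total - total // 2
    print(total, "->", phase1, "+", phase2)  # 9 -> 4 + 5, 10 -> 5 + 5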
Example #6
def train_manifold_flow_alternating(args, dataset, model, simulator):
    """ MFMF-A training """

    assert not args.specified

    trainer1 = (ForwardTrainer(model) if simulator.parameter_dim() is None
                else ConditionalForwardTrainer(model))
    trainer2 = (ForwardTrainer(model) if simulator.parameter_dim() is None
                else ConditionalForwardTrainer(model) if args.scandal is None
                else SCANDALForwardTrainer(model))
    metatrainer = AlternatingTrainer(model, trainer1, trainer2)

    meta_kwargs = {
        "dataset": dataset,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "validation_split": args.validationsplit
    }
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {
            "weight_decay": float(args.weightdecay)
        }
    _, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(
        args, dataset)

    phase1_kwargs = {
        "forward_kwargs": {
            "mode": "projection"
        },
        "clip_gradient": args.clip
    }
    phase2_kwargs = {
        "forward_kwargs": {
            "mode": "mf-fixed-manifold"
        },
        "clip_gradient": args.clip
    }

    phase1_parameters = (list(model.outer_transform.parameters()) +
                         list(model.encoder.parameters())
                         if args.algorithm == "emf"
                         else model.outer_transform.parameters())
    phase2_parameters = list(model.inner_transform.parameters())

    logger.info(
        "Starting training MF, alternating between reconstruction error and log likelihood"
    )
    learning_curves_ = metatrainer.train(
        loss_functions=[
            losses.smooth_l1_loss if args.l1 else losses.mse, losses.nll
        ] + scandal_loss,
        loss_function_trainers=[0, 1] + ([1] if args.scandal is not None else []),
        loss_labels=["L1" if args.l1 else "MSE", "NLL"] + scandal_label,
        loss_weights=[
            args.msefactor,
            args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)
        ] + scandal_weight,
        epochs=args.epochs // 2,
        subsets=args.subsets,
        batch_sizes=[args.batchsize, args.batchsize],
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", None, args))
        ],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        **meta_kwargs,
    )
    learning_curves = np.vstack(learning_curves_).T

    return learning_curves
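What makes the alternation work is the per-phase parameters= argument: only the listed tensors go into that phase's optimizer, so the manifold step updates the outer transform (plus the encoder for "emf") while the density step updates only the inner transform. A minimal illustration of the mechanism in plain PyTorch (hypothetical module names, not the repository's model):

import torch
import torch.nn as nn

model = nn.ModuleDict({"outer": nn.Linear(4, 4), "inner": nn.Linear(4, 4)})
opt_inner = torch.optim.Adam(model["inner"].parameters())

x = torch.randn(8, 4)
loss = model["inner"](model["outer"](x)).pow(2).mean()
loss.backward()   # gradients reach both submodules ...
opt_inner.step()  # ... but this step changes only the inner weights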
Example #7
def train_manifold_flow(args, dataset, model, simulator):
    """ MFMF-S training """

    assert not args.specified

    trainer = (ForwardTrainer(model) if simulator.parameter_dim() is None
               else ConditionalForwardTrainer(model) if args.scandal is None
               else SCANDALForwardTrainer(model))
    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(
        args, dataset)

    logger.info(
        "Starting training MF, phase 1: pretraining on reconstruction error")
    learning_curves = trainer.train(
        loss_functions=[losses.mse],
        loss_labels=["MSE"],
        loss_weights=[args.msefactor],
        epochs=args.epochs // 3,
        callbacks=[
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", "A", args))
        ],
        forward_kwargs={"mode": "projection"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T

    logger.info("Starting training MF, phase 2: mixed training")
    learning_curves_ = trainer.train(
        loss_functions=[losses.mse, losses.nll] + scandal_loss,
        loss_labels=["MSE", "NLL"] + scandal_label,
        loss_weights=[
            args.msefactor,
            args.addnllfactor * nat_to_bit_per_dim(args.modellatentdim)
        ] + scandal_weight,
        epochs=args.epochs - 2 * (args.epochs // 3),
        parameters=list(model.parameters()),
        callbacks=[
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", "B", args))
        ],
        forward_kwargs={"mode": "mf"},
        initial_epoch=args.startepoch - (args.epochs // 3),
        **common_kwargs,
    )
    learning_curves_ = np.vstack(learning_curves_).T
    learning_curves = (learning_curves_ if learning_curves is None
                       else np.vstack((learning_curves, learning_curves_)))

    logger.info(
        "Starting training MF, phase 3: training only inner flow on NLL")
    learning_curves_ = trainer.train(
        loss_functions=[losses.mse, losses.nll] + scandal_loss,
        loss_labels=["MSE", "NLL"] + scandal_label,
        loss_weights=[
            0.0, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)
        ] + scandal_weight,
        epochs=args.epochs // 3,
        parameters=list(model.inner_transform.parameters()),
        callbacks=[
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", "C", args))
        ],
        forward_kwargs={"mode": "mf-fixed-manifold"},
        initial_epoch=args.startepoch - (args.epochs - (args.epochs // 3)),
        **common_kwargs,
    )
    learning_curves_ = np.vstack(learning_curves_).T
    learning_curves = np.vstack((learning_curves, learning_curves_))

    return learning_curves
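Here the three phases consume args.epochs // 3, args.epochs - 2 * (args.epochs // 3), and args.epochs // 3 epochs, which again sums exactly to args.epochs, and each initial_epoch offset subtracts the epochs already spent so a run can resume mid-schedule from the A/B/C checkpoints. For instance:

# The three-phase split always covers the full budget.
for total in (10, 11, 12):
    a, c = total // 3, total // 3
    b = total - 2 * (total // 3)
    print(total, "->", a, "+", b, "+", c)  # 10 -> 3 + 4 + 3, 11 -> 3 + 5 + 3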
Example #8
    def objective(trial):
        global counter

        counter += 1

        # Hyperparameters
        margs = pick_parameters(args, trial, counter)

        logger.info(f"Starting run {counter} / {args.trials}")
        logger.info(f"Hyperparams:")
        logger.info(f"  outer layers:      {margs.outerlayers}")
        logger.info(f"  linlayers:         {margs.linlayers}")
        logger.info(f"  linchannelfactor:  {margs.linchannelfactor}")
        logger.info(f"  inner layers:      {margs.innerlayers}")
        logger.info(f"  linear transform:  {margs.lineartransform}")
        logger.info(f"  spline range:      {margs.splinerange}")
        logger.info(f"  spline bins:       {margs.splinebins}")
        logger.info(f"  batchnorm:         {margs.batchnorm}")
        logger.info(f"  actnorm:           {margs.actnorm}")
        logger.info(f"  dropout:           {margs.dropout}")
        logger.info(f"  batch size:        {margs.batchsize}")
        logger.info(f"  MSE factor:        {margs.msefactor}")
        logger.info(f"  latent L2 reg:     {margs.uvl2reg}")
        logger.info(f"  weight decay:      {margs.weightdecay}")
        logger.info(f"  gradient clipping: {margs.clip}")

        # Workaround for a bug when num_workers > 1 is combined with CUDA. Bad things happen otherwise!
        torch.multiprocessing.set_start_method("spawn", force=True)

        # Load data
        simulator = load_simulator(margs)
        dataset = simulator.load_dataset(train=True,
                                         dataset_dir=create_filename(
                                             "dataset", None, margs),
                                         limit_samplesize=margs.samplesize)

        # Create model
        model = create_model(margs, simulator)

        # Train
        try:
            trainer = (ForwardTrainer(model) if simulator.parameter_dim() is None
                       else ConditionalForwardTrainer(model))
            common_kwargs, _, _, _ = train.make_training_kwargs(margs, dataset)

            logger.info("Starting training MF: manifold training")
            np.random.seed(123)
            _, val_losses = trainer.train(
                loss_functions=[losses.mse, losses.hiddenl2reg],
                loss_labels=["MSE", "L2_lat"],
                loss_weights=[
                    margs.msefactor,
                    0.0 if margs.uvl2reg is None else margs.uvl2reg
                ],
                epochs=margs.epochs,
                parameters=(list(model.outer_transform.parameters()) +
                            list(model.encoder.parameters())
                            if args.algorithm == "emf"
                            else model.outer_transform.parameters()),
                forward_kwargs={
                    "mode": "projection",
                    "return_hidden": True
                },
                **common_kwargs,
            )

            # Save
            torch.save(model.state_dict(),
                       create_filename("model", None, margs))

            # Evaluate reco error
            logger.info("Evaluating reco error")
            model.eval()
            torch.cuda.empty_cache()
            np.random.seed(123)
            dataloader = trainer.make_dataloader(
                simulator.load_dataset(train=True,
                                       dataset_dir=create_filename(
                                           "dataset", None, margs),
                                       limit_samplesize=margs.samplesize),
                args.validationsplit, 20, 4)[1]
            reco_errors = []
            x_plot, x_reco_plot = None, None
            for x, params in dataloader:
                x = x.to(device=trainer.device, dtype=trainer.dtype)
                params = (None if simulator.parameter_dim() is None
                          else params.to(device=trainer.device, dtype=trainer.dtype))
                x_reco, _, _ = model(x, context=params, mode="projection")
                reco_errors.append((torch.sum(
                    (x - x_reco)**2, dim=1)**0.5).detach().cpu().numpy())
                if x_plot is None:
                    x_plot = x.detach().cpu().numpy()
                    x_reco_plot = x_reco.detach().cpu().numpy()
            reco_error = np.mean(reco_errors)

            if not np.isfinite(reco_error):
                raise RuntimeError("Non-finite reconstruction error")

            # Report results
            logger.info("Results:")
            logger.info("  reco err:     %s", reco_error)

            # Plot reco error
            x = np.clip(np.transpose(x_plot, [0, 2, 3, 1]) / 256.0, 0.0, 1.0)
            x_reco = np.clip(
                np.transpose(x_reco_plot, [0, 2, 3, 1]) / 256.0, 0.0, 1.0)
            plt.figure(figsize=(6 * 3.0, 5 * 3.0))
            for i in range(15):
                plt.subplot(5, 6, 2 * i + 1)
                plt.imshow(x[i])
                plt.gca().get_xaxis().set_visible(False)
                plt.gca().get_yaxis().set_visible(False)
                plt.subplot(5, 6, 2 * i + 2)
                plt.imshow(x_reco[i])
                plt.gca().get_xaxis().set_visible(False)
                plt.gca().get_yaxis().set_visible(False)
            plt.tight_layout()
            filename = create_filename("training_plot", "reco", margs)
            plt.savefig(filename.format(""))
        except RuntimeError as e:
            logger.info("Error during training, returning 1e9\n  %s", e)
            return 1e9

        return reco_error
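One caveat in the reconstruction-error loop above: np.mean(reco_errors) averages a list of per-batch arrays, which only behaves when every batch has the same length (the dataloader here appears to use a fixed batch size of 20). If the final batch can be smaller, concatenating the per-sample errors first is safer, e.g.:

import numpy as np

# Flatten per-batch arrays into one per-sample array before averaging;
# this is robust to a ragged final batch.
reco_error = float(np.mean(np.concatenate(reco_errors)))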