Example #1
0
def train_pie(args, dataset, model, simulator):
    """ PIE training

    Runs a single maximum-likelihood training phase for a PIE model.

    The trainer is chosen from the configuration:
      * unconditional simulator (``parameter_dim() is None``): ForwardTrainer
      * conditional, ``args.scandal`` unset: ConditionalForwardTrainer
      * conditional, ``args.scandal`` set: SCANDALForwardTrainer

    The loss is the NLL (weighted by ``args.nllfactor`` times
    ``nat_to_bit_per_dim(args.datadim)``) plus whatever SCANDAL losses,
    labels and weights ``make_training_kwargs`` returns.

    Returns:
        The learning curves returned by ``trainer.train``, stacked with
        ``np.vstack`` and transposed.
    """
    if simulator.parameter_dim() is None:
        trainer = ForwardTrainer(model)
    elif args.scandal is None:
        trainer = ConditionalForwardTrainer(model)
    else:
        trainer = SCANDALForwardTrainer(model)

    (common_kwargs, scandal_loss, scandal_label,
     scandal_weight) = make_training_kwargs(args, dataset)

    checkpoint_file = create_filename("checkpoint", None, args)
    callbacks_ = [callbacks.save_model_after_every_epoch(checkpoint_file)]

    if simulator.is_image():
        sample_plot_file = create_filename("training_plot", None, args)
        # Conditional models need a (zero) parameter context to sample from.
        if simulator.parameter_dim() is None:
            sample_context = None
        else:
            sample_context = torch.zeros(30, simulator.parameter_dim())
        callbacks_.append(
            callbacks.plot_sample_images(sample_plot_file,
                                         context=sample_context))

        reco_plot_file = create_filename("training_plot", "reco_epoch", args)
        callbacks_.append(callbacks.plot_reco_images(reco_plot_file))

    logger.info("Starting training PIE on NLL")
    nll_weight = args.nllfactor * nat_to_bit_per_dim(args.datadim)
    raw_curves = trainer.train(
        loss_functions=[losses.nll] + scandal_loss,
        loss_labels=["NLL"] + scandal_label,
        loss_weights=[nll_weight] + scandal_weight,
        epochs=args.epochs,
        callbacks=callbacks_,
        forward_kwargs={"mode": "pie"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    return np.vstack(raw_curves).T
Example #2
0
def train_flow(args, dataset, model, simulator):
    """ AF training

    Trains a standard flow by maximum likelihood in a single phase.

    Uses ``ManifoldFlowTrainer`` for unconditional simulators
    (``simulator.parameter_dim() is None``), otherwise
    ``ConditionalManifoldFlowTrainer``. The optimizer uses a cosine-annealing
    LR schedule. Weight decay is applied only when ``args.weightdecay`` is
    set.

    Callbacks:
      * a checkpoint is saved after every epoch, with the epoch number
        inserted into the checkpoint filename;
      * for image data, sample images are plotted after each epoch. For
        conditional models the samples are drawn at a zero parameter context
        (``torch.zeros(30, parameter_dim)``), the same as the other training
        routines in this module.

    Returns:
        The learning curves returned by ``trainer.train``, stacked with
        ``np.vstack`` and transposed.
    """
    conditional = simulator.parameter_dim() is not None
    trainer = (ConditionalManifoldFlowTrainer(model)
               if conditional else ManifoldFlowTrainer(model))
    logger.info("Starting training standard flow on NLL")

    common_kwargs = {
        "dataset": dataset,
        "batch_size": args.batchsize,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit,
    }
    if args.weightdecay is not None:
        common_kwargs["optimizer_kwargs"] = {
            "weight_decay": float(args.weightdecay)
        }

    # Drops the last three characters of the checkpoint filename (presumably
    # a ".pt" suffix -- TODO confirm) to insert a per-epoch placeholder.
    callbacks_ = [
        callbacks.save_model_after_every_epoch(
            create_filename("checkpoint", None, args)[:-3] + "_epoch_{}.pt")
    ]
    if simulator.is_image():
        # A conditional model cannot be sampled without a context. The
        # previous version passed none here, unlike every sibling routine.
        sample_context = (torch.zeros(30, simulator.parameter_dim())
                          if conditional else None)
        callbacks_.append(
            callbacks.plot_sample_images(
                create_filename("training_plot", None, args),
                context=sample_context))

    learning_curves = trainer.train(loss_functions=[losses.nll],
                                    loss_labels=["NLL"],
                                    loss_weights=[args.nllfactor],
                                    epochs=args.epochs,
                                    callbacks=callbacks_,
                                    **common_kwargs)

    learning_curves = np.vstack(learning_curves).T
    return learning_curves
Example #3
0
def train_manifold_flow_sequential(args, dataset, model, simulator):
    """ Sequential MFMF-M/D training

    Trains a manifold flow in two consecutive phases on the same ``model``.

    Phase 1 (manifold) runs for ``args.epochs // 2`` epochs on a
    reconstruction loss: smooth L1 if ``args.l1`` is set, otherwise MSE. An
    optional latent L2 term is added when ``args.uvl2reg`` is set. It updates
    only the outer transform, plus the encoder when
    ``args.algorithm == "emf"``, with forward mode ``"projection"``.

    Phase 2 (density) runs for the remaining epochs on the NLL, plus any
    SCANDAL terms from ``make_training_kwargs``. It updates only the inner
    transform, with forward mode ``"mf-fixed-manifold"``.

    Returns:
        The phase-1 learning curves (stacked and transposed) followed row-wise
        by the phase-2 learning curves (stacked and transposed).
    """
    assert not args.specified

    # Trainer for each phase; only phase 2 may use the SCANDAL variant.
    if simulator.parameter_dim() is not None:
        manifold_trainer = ConditionalForwardTrainer(model)
        density_trainer = (ConditionalForwardTrainer(model)
                           if args.scandal is None else
                           SCANDALForwardTrainer(model))
    else:
        manifold_trainer = ForwardTrainer(model)
        density_trainer = ForwardTrainer(model)

    (common_kwargs, scandal_loss, scandal_label,
     scandal_weight) = make_training_kwargs(args, dataset)

    def _base_callbacks(checkpoint_tag):
        """Checkpointing plus latent/weight statistics for one phase."""
        return [
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", checkpoint_tag, args)),
            callbacks.print_mf_latent_statistics(),
            callbacks.print_mf_weight_statistics(),
        ]

    manifold_callbacks = _base_callbacks("A")
    density_callbacks = _base_callbacks("B")

    if simulator.is_image():
        for sample_tag, reco_tag, phase_callbacks in (
            ("sample_epoch_A", "reco_epoch_A", manifold_callbacks),
            ("sample_epoch_B", "reco_epoch_B", density_callbacks),
        ):
            # Conditional models are sampled at a zero parameter context.
            if simulator.parameter_dim() is None:
                sample_context = None
            else:
                sample_context = torch.zeros(30, simulator.parameter_dim())
            phase_callbacks.append(
                callbacks.plot_sample_images(
                    create_filename("training_plot", sample_tag, args),
                    context=sample_context,
                ))
            phase_callbacks.append(
                callbacks.plot_reco_images(
                    create_filename("training_plot", reco_tag, args)))

    manifold_epochs = args.epochs // 2

    logger.info("Starting training MF, phase 1: manifold training")
    use_latent_reg = args.uvl2reg is not None
    manifold_losses = [losses.smooth_l1_loss if args.l1 else losses.mse]
    manifold_labels = ["L1" if args.l1 else "MSE"]
    manifold_weights = [args.msefactor]
    if use_latent_reg:
        manifold_losses.append(losses.hiddenl2reg)
        manifold_labels.append("L2_lat")
        manifold_weights.append(args.uvl2reg)

    manifold_params = list(model.outer_transform.parameters())
    if args.algorithm == "emf":
        manifold_params = manifold_params + list(model.encoder.parameters())

    raw_manifold_curves = manifold_trainer.train(
        loss_functions=manifold_losses,
        loss_labels=manifold_labels,
        loss_weights=manifold_weights,
        epochs=manifold_epochs,
        parameters=manifold_params,
        callbacks=manifold_callbacks,
        forward_kwargs={
            "mode": "projection",
            "return_hidden": use_latent_reg
        },
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    manifold_curves = np.vstack(raw_manifold_curves).T

    logger.info("Starting training MF, phase 2: density training")
    density_nll_weight = args.nllfactor * nat_to_bit_per_dim(
        args.modellatentdim)
    raw_density_curves = density_trainer.train(
        loss_functions=[losses.nll] + scandal_loss,
        loss_labels=["NLL"] + scandal_label,
        loss_weights=[density_nll_weight] + scandal_weight,
        epochs=args.epochs - manifold_epochs,
        parameters=list(model.inner_transform.parameters()),
        callbacks=density_callbacks,
        forward_kwargs={"mode": "mf-fixed-manifold"},
        # Start epoch relative to phase 2. This is negative when resuming
        # before phase 2 starts; presumably the trainer then starts from
        # scratch (NOTE(review): confirm in trainer.train).
        initial_epoch=args.startepoch - manifold_epochs,
        **common_kwargs,
    )

    return np.vstack((manifold_curves, np.vstack(raw_density_curves).T))