예제 #1
0
def train_generative_adversarial_manifold_flow(args, dataset, model,
                                               simulator):
    """Run MFMF-OT training: fit ``model`` with a Sinkhorn divergence loss.

    Args:
        args: Run configuration. This function reads ``genbatchsize``,
            ``debug``, ``sinkhornfactor``, ``epochs`` and ``startepoch``;
            the whole object is also handed to ``make_training_kwargs`` and
            ``create_filename``.
        dataset: Training data, forwarded to ``make_training_kwargs``.
        model: Model handed to the trainer.
        simulator: Only ``parameter_dim()`` is called. A ``None`` result
            selects ``AdversarialTrainer``; anything else selects
            ``ConditionalAdversarialTrainer``.

    Returns:
        The learning curves returned by the trainer, stacked with
        ``np.vstack`` and transposed.
    """
    # Pick the trainer flavour depending on whether the simulator reports a
    # parameter dimension.
    if simulator.parameter_dim() is None:
        trainer_cls = AdversarialTrainer
    else:
        trainer_cls = ConditionalAdversarialTrainer
    trainer = trainer_cls(model)

    # make_training_kwargs returns four values; only the shared keyword
    # arguments are used by this routine.
    train_kwargs, _loss, _label, _weight = make_training_kwargs(args, dataset)
    train_kwargs["batch_size"] = args.genbatchsize

    logger.info("Starting training GAMF: Sinkhorn-GAN")

    checkpoint_name = create_filename("checkpoint", None, args)
    epoch_callbacks = [callbacks.save_model_after_every_epoch(checkpoint_name)]
    if args.debug:
        epoch_callbacks.append(callbacks.print_mf_weight_statistics())

    sinkhorn_loss = losses.make_sinkhorn_divergence()
    curves = trainer.train(
        loss_functions=[sinkhorn_loss],
        loss_labels=["GED"],
        loss_weights=[args.sinkhornfactor],
        epochs=args.epochs,
        callbacks=epoch_callbacks,
        compute_loss_variance=True,
        initial_epoch=args.startepoch,
        **train_kwargs,
    )

    return np.vstack(curves).T
예제 #2
0
def train_generative_adversarial_manifold_flow_alternating(
        args, dataset, model, simulator):
    """Run MFMF-OTA training: alternate Sinkhorn-divergence and NLL phases.

    Two trainers share ``model`` and are driven by an ``AlternatingTrainer``.
    The first loss (Sinkhorn divergence, label ``"GED"``) is assigned to the
    generative trainer and optimises ``model.parameters()``. The second loss
    (``losses.nll``, label ``"NLL"``) is assigned to the manifold-flow
    trainer, runs the model with ``mode="mf-fixed-manifold"`` and optimises
    only ``model.inner_transform.parameters()``.

    Args:
        args: Run configuration. This function reads ``specified``, ``lr``,
            ``validationsplit``, ``weightdecay``, ``clip``,
            ``sinkhornfactor``, ``nllfactor``, ``genbatchsize``,
            ``batchsize``, ``epochs``, ``subsets`` and ``debug``; the whole
            object is also handed to ``create_filename``.
        dataset: Training data, passed to the meta-trainer.
        model: Model to train; must expose ``parameters()`` and
            ``inner_transform.parameters()``.
        simulator: Only ``parameter_dim()`` is called. A ``None`` result
            selects the unconditional trainers; anything else selects the
            conditional ones.

    Returns:
        The learning curves returned by the meta-trainer, stacked with
        ``np.vstack`` and transposed.

    Raises:
        AssertionError: If ``args.specified`` is truthy.
    """
    # Explicit raise instead of a bare ``assert`` so the check also runs under
    # ``python -O``. The exception type is kept for backward compatibility.
    if args.specified:
        raise AssertionError(
            "train_generative_adversarial_manifold_flow_alternating requires "
            "args.specified to be false")

    conditional = simulator.parameter_dim() is not None
    gen_trainer = (ConditionalGenerativeTrainer(model)
                   if conditional else GenerativeTrainer(model))
    likelihood_trainer = (ConditionalManifoldFlowTrainer(model)
                          if conditional else ManifoldFlowTrainer(model))
    metatrainer = AlternatingTrainer(model, gen_trainer, likelihood_trainer)

    # Settings shared by both phases.
    meta_kwargs = {
        "dataset": dataset,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "validation_split": args.validationsplit,
    }
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {
            "weight_decay": float(args.weightdecay)
        }

    # Per-phase trainer settings; phase 2 additionally fixes the forward mode.
    phase1_kwargs = {"clip_gradient": args.clip}
    phase2_kwargs = {
        "forward_kwargs": {
            "mode": "mf-fixed-manifold"
        },
        "clip_gradient": args.clip,
    }

    # parameters() presumably returns a one-shot iterator (torch.nn.Module
    # behaviour -- confirm for this model class). Materialise it so the
    # trainer can safely walk each parameter set more than once.
    phase1_parameters = list(model.parameters())
    phase2_parameters = list(model.inner_transform.parameters())

    logger.info(
        "Starting training GAMF, alternating between Sinkhorn divergence and log likelihood"
    )
    learning_curves_ = metatrainer.train(
        loss_functions=[losses.make_sinkhorn_divergence(), losses.nll],
        # Index into the trainers given to AlternatingTrainer, one per loss.
        loss_function_trainers=[0, 1],
        loss_labels=["GED", "NLL"],
        loss_weights=[args.sinkhornfactor, args.nllfactor],
        batch_sizes=[args.genbatchsize, args.batchsize],
        epochs=args.epochs // 2,
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[
            # Drop the last three characters of the checkpoint name and
            # append an epoch placeholder plus ".pt".
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", None, args)[:-3] +
                "_epoch_{}.pt")
        ],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        subsets=args.subsets,
        subset_callbacks=[callbacks.print_mf_weight_statistics()]
        if args.debug else None,
        **meta_kwargs,
    )

    return np.vstack(learning_curves_).T
예제 #3
0
def train_generative_adversarial_manifold_flow(args, dataset, model,
                                               simulator):
    """Run MFMF-OT training: fit ``model`` with a Sinkhorn divergence loss.

    Args:
        args: Run configuration. This function reads ``lr``, ``clip``,
            ``validationsplit``, ``weightdecay``, ``debug``,
            ``sinkhornfactor``, ``epochs`` and ``genbatchsize``; the whole
            object is also handed to ``create_filename``.
        dataset: Training data, passed to the trainer.
        model: Model handed to the trainer.
        simulator: Only ``parameter_dim()`` is called. A ``None`` result
            selects ``GenerativeTrainer``; anything else selects
            ``ConditionalGenerativeTrainer``.

    Returns:
        The learning curves returned by the trainer, stacked with
        ``np.vstack`` and transposed.
    """
    # Pick the trainer flavour depending on whether the simulator reports a
    # parameter dimension.
    if simulator.parameter_dim() is None:
        trainer_cls = GenerativeTrainer
    else:
        trainer_cls = ConditionalGenerativeTrainer
    trainer = trainer_cls(model)

    # Optimiser / scheduler settings forwarded verbatim to the trainer.
    shared_kwargs = dict(
        dataset=dataset,
        initial_lr=args.lr,
        scheduler=optim.lr_scheduler.CosineAnnealingLR,
        clip_gradient=args.clip,
        validation_split=args.validationsplit,
    )
    if args.weightdecay is not None:
        decay = float(args.weightdecay)
        shared_kwargs["optimizer_kwargs"] = {"weight_decay": decay}

    logger.info("Starting training GAMF: Sinkhorn-GAN")

    # Drop the last three characters of the checkpoint name and append an
    # epoch placeholder plus ".pt".
    checkpoint_base = create_filename("checkpoint", None, args)[:-3]
    epoch_callbacks = [
        callbacks.save_model_after_every_epoch(checkpoint_base +
                                               "_epoch_{}.pt")
    ]
    if args.debug:
        epoch_callbacks.append(callbacks.print_mf_weight_statistics())

    sinkhorn_loss = losses.make_sinkhorn_divergence()
    curves = trainer.train(
        loss_functions=[sinkhorn_loss],
        loss_labels=["GED"],
        loss_weights=[args.sinkhornfactor],
        epochs=args.epochs,
        callbacks=epoch_callbacks,
        batch_size=args.genbatchsize,
        compute_loss_variance=True,
        **shared_kwargs,
    )

    return np.vstack(curves).T