Example #1
def train_generative_adversarial_manifold_flow_alternating(
        args, dataset, model, simulator):
    """ MFMF-OTA training """

    assert not args.specified

    # Use the conditional trainer variants when the simulator provides parameters to condition on.
    gen_trainer = GenerativeTrainer(model) if simulator.parameter_dim() is None else ConditionalGenerativeTrainer(model)
    likelihood_trainer = ManifoldFlowTrainer(model) if simulator.parameter_dim() is None else ConditionalManifoldFlowTrainer(model)
    metatrainer = AlternatingTrainer(model, gen_trainer, likelihood_trainer)

    meta_kwargs = {
        "dataset": dataset,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "validation_split": args.validationsplit
    }
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {
            "weight_decay": float(args.weightdecay)
        }

    phase1_kwargs = {"clip_gradient": args.clip}
    phase2_kwargs = {
        "forward_kwargs": {
            "mode": "mf-fixed-manifold"
        },
        "clip_gradient": args.clip
    }

    phase1_parameters = model.parameters()
    phase2_parameters = model.inner_transform.parameters()

    logger.info(
        "Starting training GAMF, alternating between Sinkhorn divergence and log likelihood"
    )
    learning_curves_ = metatrainer.train(
        loss_functions=[losses.make_sinkhorn_divergence(), losses.nll],
        loss_function_trainers=[0, 1],
        loss_labels=["GED", "NLL"],
        loss_weights=[args.sinkhornfactor, args.nllfactor],
        batch_sizes=[args.genbatchsize, args.batchsize],
        epochs=args.epochs // 2,
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", None, args)[:-3] +
                "_epoch_{}.pt")
        ],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        subsets=args.subsets,
        subset_callbacks=[callbacks.print_mf_weight_statistics()] if args.debug else None,
        **meta_kwargs,
    )
    learning_curves = np.vstack(learning_curves_).T

    return learning_curves
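
The AlternatingTrainer internals are not shown on this page. As a rough sketch of the pattern this call presumably drives, two loss functions with separate parameter groups, each stepped in turn every cycle, here is a self-contained PyTorch toy; ToyModel and the stand-in losses are hypothetical placeholders, not part of the manifold-flow codebase:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.outer_transform = nn.Linear(4, 4)
        self.inner_transform = nn.Linear(4, 4)

    def forward(self, x):
        return self.inner_transform(self.outer_transform(x))

model = ToyModel()
# Phase 1 updates all parameters; phase 2 only the inner transform,
# mirroring phase1_parameters / phase2_parameters above.
phase_params = [list(model.parameters()), list(model.inner_transform.parameters())]
optimizers = [torch.optim.Adam(p, lr=1e-3) for p in phase_params]
loss_fns = [nn.MSELoss(), nn.L1Loss()]  # stand-ins for Sinkhorn divergence and NLL

x = torch.randn(32, 4)
for epoch in range(4):
    for phase in (0, 1):  # alternate between the two trainers
        optimizers[phase].zero_grad()
        loss = loss_fns[phase](model(x), x)
        loss.backward()
        optimizers[phase].step()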
Example #2
def train_manifold_flow_alternating(args, dataset, model, simulator):
    """ MFMF-A training """

    assert not args.specified

    trainer1 = ForwardTrainer(model) if simulator.parameter_dim() is None else ConditionalForwardTrainer(model)
    trainer2 = (
        ForwardTrainer(model)
        if simulator.parameter_dim() is None
        else (ConditionalForwardTrainer(model) if args.scandal is None else SCANDALForwardTrainer(model))
    )
    metatrainer = AlternatingTrainer(model, trainer1, trainer2)

    meta_kwargs = {
        "dataset": dataset,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "validation_split": args.validationsplit
    }
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {
            "weight_decay": float(args.weightdecay)
        }
    _, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(
        args, dataset)

    phase1_kwargs = {
        "forward_kwargs": {
            "mode": "projection"
        },
        "clip_gradient": args.clip
    }
    phase2_kwargs = {
        "forward_kwargs": {
            "mode": "mf-fixed-manifold"
        },
        "clip_gradient": args.clip
    }

    phase1_parameters = (
        list(model.outer_transform.parameters()) + list(model.encoder.parameters())
        if args.algorithm == "emf"
        else model.outer_transform.parameters()
    )
    phase2_parameters = list(model.inner_transform.parameters())

    logger.info(
        "Starting training MF, alternating between reconstruction error and log likelihood"
    )
    learning_curves_ = metatrainer.train(
        loss_functions=[losses.smooth_l1_loss if args.l1 else losses.mse, losses.nll] + scandal_loss,
        # Parentheses matter here: a bare "[0, 1] + [1] if ... else []" binds as
        # "([0, 1] + [1]) if ... else []" and leaves no trainers when SCANDAL is off.
        loss_function_trainers=[0, 1] + ([1] if args.scandal is not None else []),
        loss_labels=["L1" if args.l1 else "MSE", "NLL"] + scandal_label,
        loss_weights=[args.msefactor, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs // 2,
        subsets=args.subsets,
        batch_sizes=[args.batchsize, args.batchsize],
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", None, args))
        ],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        **meta_kwargs,
    )
    learning_curves = np.vstack(learning_curves_).T

    return learning_curves
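
This variant reads a dozen hyperparameters off args. For orientation, a hypothetical argparse.Namespace covering the fields the example touches; every value below is an illustrative placeholder, not a default from the repository:

from argparse import Namespace

args = Namespace(
    specified=False, algorithm="mf", scandal=None, l1=False,
    lr=3.0e-4, weightdecay=1.0e-5, validationsplit=0.25,
    clip=5.0, epochs=50, subsets=1, batchsize=100,
    msefactor=1000.0, nllfactor=0.1, modellatentdim=2,
)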
Example #3
def train_manifold_flow_alternating(args, dataset, model, simulator):
    """ MFMF-A training """

    assert not args.specified

    trainer = ManifoldFlowTrainer(model) if simulator.parameter_dim() is None else ConditionalManifoldFlowTrainer(model)
    # The same trainer instance drives both alternating phases here; the phases
    # differ only in their forward_kwargs and parameter groups below.
    metatrainer = AlternatingTrainer(model, trainer, trainer)

    meta_kwargs = {
        "dataset": dataset,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR
    }
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {
            "weight_decay": float(args.weightdecay)
        }

    phase1_kwargs = {
        "forward_kwargs": {
            "mode": "projection"
        },
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit
    }
    phase2_kwargs = {
        "forward_kwargs": {
            "mode": "mf-fixed-manifold"
        },
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit
    }

    phase1_parameters = (list(model.outer_transform.parameters()) +
                         list(model.encoder.parameters()) if args.algorithm
                         == "emf" else model.outer_transform.parameters())
    phase2_parameters = model.inner_transform.parameters()

    logger.info(
        "Starting training MF, alternating between reconstruction error and log likelihood"
    )
    learning_curves_ = metatrainer.train(
        loss_functions=[losses.mse, losses.nll],
        loss_function_trainers=[0, 1],
        loss_labels=["MSE", "NLL"],
        loss_weights=[args.msefactor, args.nllfactor],
        epochs=args.epochs // 2,
        subsets=args.subsets,
        batch_sizes=[args.batchsize, args.batchsize],
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[
            callbacks.save_model_after_every_epoch(
                create_filename("checkpoint", None, args)[:-3] +
                "_epoch_{}.pt")
        ],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        **meta_kwargs,
    )
    learning_curves = np.vstack(learning_curves_).T

    return learning_curves
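
Note that meta_kwargs passes the scheduler as a class, optim.lr_scheduler.CosineAnnealingLR, rather than an instance, so the trainer presumably builds its own optimizer from initial_lr and optimizer_kwargs and then instantiates the scheduler around it. A minimal sketch under that assumption; the trainer internals are guessed, only the torch calls are real APIs:

import torch

# Stand-in parameter group; a real trainer would receive phase parameters.
params = [torch.nn.Parameter(torch.zeros(3))]

# initial_lr and weight_decay as they would arrive via meta_kwargs.
optimizer = torch.optim.Adam(params, lr=3.0e-4, weight_decay=1.0e-5)
scheduler_cls = torch.optim.lr_scheduler.CosineAnnealingLR
scheduler = scheduler_cls(optimizer, T_max=25)  # e.g. epochs // 2 per phase

for _ in range(25):
    optimizer.step()
    scheduler.step()  # anneals the lr from initial_lr toward zero over T_max steps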