def train_generative_adversarial_manifold_flow(args, dataset, model, simulator):
    """MFMF-OT training: fit ``model`` with a Sinkhorn-divergence ("GED") loss only.

    NOTE(review): this module contains a later function with exactly this
    name. A later ``def`` rebinds the name, so this version is not the one
    callers get when they look it up by name. Confirm which version is meant
    to be live, then delete or rename the other.

    Args:
        args: Run configuration. This function reads ``genbatchsize``,
            ``debug``, ``sinkhornfactor``, ``epochs`` and ``startepoch``.
            It is also forwarded to ``make_training_kwargs`` and
            ``create_filename``.
        dataset: Training data, forwarded to ``make_training_kwargs``.
        model: The model to train.
        simulator: Only ``simulator.parameter_dim()`` is used. It selects
            the unconditional trainer when it returns ``None`` and the
            conditional one otherwise.

    Returns:
        The trainer's learning curves, stacked with ``np.vstack`` and
        transposed.
    """
    if simulator.parameter_dim() is None:
        gen_trainer = AdversarialTrainer(model)
    else:
        gen_trainer = ConditionalAdversarialTrainer(model)

    # make_training_kwargs returns three more values, which the original
    # code named scandal_loss / scandal_label / scandal_weight. This
    # objective does not use them. The explicit 4-way unpacking is kept so
    # an unexpected return shape still raises.
    common_kwargs, _, _, _ = make_training_kwargs(args, dataset)
    # The adversarial objective uses its own batch size.
    common_kwargs["batch_size"] = args.genbatchsize

    logger.info("Starting training GAMF: Sinkhorn-GAN")

    epoch_callbacks = [
        callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args))
    ]
    if args.debug:
        epoch_callbacks.append(callbacks.print_mf_weight_statistics())

    learning_curves_ = gen_trainer.train(
        loss_functions=[losses.make_sinkhorn_divergence()],
        loss_labels=["GED"],
        loss_weights=[args.sinkhornfactor],
        epochs=args.epochs,
        callbacks=epoch_callbacks,
        compute_loss_variance=True,
        initial_epoch=args.startepoch,
        **common_kwargs,
    )

    learning_curves = np.vstack(learning_curves_).T
    return learning_curves
def train_generative_adversarial_manifold_flow_alternating(args, dataset, model, simulator):
    """MFMF-OTA training: alternate a Sinkhorn-divergence phase with an NLL phase.

    An ``AlternatingTrainer`` drives two sub-trainers:

    * trainer 0 minimises the Sinkhorn divergence (label ``"GED"``,
      weight ``args.sinkhornfactor``, batch size ``args.genbatchsize``)
      over all ``model.parameters()``;
    * trainer 1 minimises ``losses.nll`` (label ``"NLL"``, weight
      ``args.nllfactor``, batch size ``args.batchsize``) over
      ``model.inner_transform.parameters()`` only, calling the model with
      ``mode="mf-fixed-manifold"``.

    Args:
        args: Run configuration. This function reads ``specified``, ``lr``,
            ``validationsplit``, ``weightdecay``, ``clip``, ``sinkhornfactor``,
            ``nllfactor``, ``genbatchsize``, ``batchsize``, ``epochs``,
            ``subsets`` and ``debug``. It is also forwarded to
            ``create_filename``.
        dataset: Training data passed to the meta-trainer.
        model: The model to train. It must expose ``inner_transform``.
        simulator: Only ``simulator.parameter_dim()`` is used. It selects
            the unconditional trainers when it returns ``None`` and the
            conditional ones otherwise.

    Returns:
        The meta-trainer's learning curves, stacked with ``np.vstack`` and
        transposed.

    Raises:
        AssertionError: if ``args.specified`` is truthy. This is only raised
            while assertions are enabled, i.e. not under ``python -O``.
    """
    assert not args.specified

    gen_trainer_cls = (
        GenerativeTrainer if simulator.parameter_dim() is None else ConditionalGenerativeTrainer
    )
    gen_trainer = gen_trainer_cls(model)
    flow_trainer_cls = (
        ManifoldFlowTrainer
        if simulator.parameter_dim() is None
        else ConditionalManifoldFlowTrainer
    )
    likelihood_trainer = flow_trainer_cls(model)
    metatrainer = AlternatingTrainer(model, gen_trainer, likelihood_trainer)

    meta_kwargs = dict(
        dataset=dataset,
        initial_lr=args.lr,
        scheduler=optim.lr_scheduler.CosineAnnealingLR,
        validation_split=args.validationsplit,
    )
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = dict(weight_decay=float(args.weightdecay))

    # Per-trainer settings, in trainer order (0: Sinkhorn, 1: NLL).
    sinkhorn_phase_kwargs = dict(clip_gradient=args.clip)
    # The mode name suggests the manifold stays fixed during the NLL phase.
    # NOTE(review): confirm in the model's forward().
    nll_phase_kwargs = dict(
        forward_kwargs=dict(mode="mf-fixed-manifold"),
        clip_gradient=args.clip,
    )
    sinkhorn_phase_params = model.parameters()
    nll_phase_params = model.inner_transform.parameters()

    logger.info(
        "Starting training GAMF, alternating between Sinkhorn divergence and log likelihood"
    )

    phase_losses = [losses.make_sinkhorn_divergence(), losses.nll]
    # Drops the last 3 characters of the checkpoint filename (presumably a
    # ".pt" suffix) and appends an epoch placeholder.
    checkpoint_pattern = create_filename("checkpoint", None, args)[:-3] + "_epoch_{}.pt"
    epoch_callbacks = [callbacks.save_model_after_every_epoch(checkpoint_pattern)]
    debug_callbacks = [callbacks.print_mf_weight_statistics()] if args.debug else None

    learning_curves_ = metatrainer.train(
        loss_functions=phase_losses,
        loss_function_trainers=[0, 1],
        loss_labels=["GED", "NLL"],
        loss_weights=[args.sinkhornfactor, args.nllfactor],
        batch_sizes=[args.genbatchsize, args.batchsize],
        epochs=args.epochs // 2,
        parameters=[sinkhorn_phase_params, nll_phase_params],
        callbacks=epoch_callbacks,
        trainer_kwargs=[sinkhorn_phase_kwargs, nll_phase_kwargs],
        subsets=args.subsets,
        subset_callbacks=debug_callbacks,
        **meta_kwargs,
    )

    return np.vstack(learning_curves_).T
def train_generative_adversarial_manifold_flow(args, dataset, model, simulator):
    """MFMF-OT training: fit ``model`` with a Sinkhorn-divergence ("GED") loss only.

    NOTE(review): an earlier function in this module has exactly this name.
    This ``def`` comes later and rebinds the name, so this is the version
    callers get. Confirm that is intended, and remove or rename the other
    copy.

    Args:
        args: Run configuration. This function reads ``lr``, ``clip``,
            ``validationsplit``, ``weightdecay``, ``debug``,
            ``sinkhornfactor``, ``epochs`` and ``genbatchsize``. It is also
            forwarded to ``create_filename``.
        dataset: Training data passed to the trainer.
        model: The model to train.
        simulator: Only ``simulator.parameter_dim()`` is used. It selects
            the unconditional trainer when it returns ``None`` and the
            conditional one otherwise.

    Returns:
        The trainer's learning curves, stacked with ``np.vstack`` and
        transposed.
    """
    trainer_cls = (
        GenerativeTrainer if simulator.parameter_dim() is None else ConditionalGenerativeTrainer
    )
    gen_trainer = trainer_cls(model)

    common_kwargs = dict(
        dataset=dataset,
        initial_lr=args.lr,
        scheduler=optim.lr_scheduler.CosineAnnealingLR,
        clip_gradient=args.clip,
        validation_split=args.validationsplit,
    )
    if args.weightdecay is not None:
        common_kwargs["optimizer_kwargs"] = dict(weight_decay=float(args.weightdecay))

    logger.info("Starting training GAMF: Sinkhorn-GAN")

    # Drops the last 3 characters of the checkpoint filename (presumably a
    # ".pt" suffix) and appends an epoch placeholder.
    checkpoint_pattern = create_filename("checkpoint", None, args)[:-3] + "_epoch_{}.pt"
    epoch_callbacks = [callbacks.save_model_after_every_epoch(checkpoint_pattern)]
    if args.debug:
        epoch_callbacks.append(callbacks.print_mf_weight_statistics())

    learning_curves_ = gen_trainer.train(
        loss_functions=[losses.make_sinkhorn_divergence()],
        loss_labels=["GED"],
        loss_weights=[args.sinkhornfactor],
        epochs=args.epochs,
        callbacks=epoch_callbacks,
        batch_size=args.genbatchsize,
        compute_loss_variance=True,
        **common_kwargs,
    )

    return np.vstack(learning_curves_).T