def train_manifold_flow_sequential(args, dataset, model, simulator):
    """
    MFMF-A training.

    Two sequential phases on a single trainer: (1) manifold training on MSE in
    projection mode, (2) density training on NLL with the manifold held fixed.
    Returns the stacked learning curves of both phases (transposed to
    loss-per-row layout).
    """
    assert not args.specified

    # Conditional trainer only when the simulator exposes parameters.
    if simulator.parameter_dim() is None:
        trainer = ManifoldFlowTrainer(model)
    else:
        trainer = ConditionalManifoldFlowTrainer(model)

    shared_kwargs = {
        "dataset": dataset,
        "batch_size": args.batchsize,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit,
    }
    if args.weightdecay is not None:
        shared_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}

    # Phase 1 optimizes the outer (manifold) transform; EMF also trains the encoder.
    if args.algorithm == "emf":
        manifold_params = list(model.outer_transform.parameters()) + list(model.encoder.parameters())
    else:
        manifold_params = model.outer_transform.parameters()

    logger.info("Starting training MF, phase 1: manifold training")
    phase1_curves = trainer.train(
        loss_functions=[losses.mse],
        loss_labels=["MSE"],
        loss_weights=[args.msefactor],
        epochs=args.epochs // 2,
        parameters=manifold_params,
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_A{}.pt")],
        forward_kwargs={"mode": "projection"},
        **shared_kwargs,
    )
    learning_curves = np.vstack(phase1_curves).T

    logger.info("Starting training MF, phase 2: density training")
    phase2_curves = trainer.train(
        loss_functions=[losses.nll],
        loss_labels=["NLL"],
        loss_weights=[args.nllfactor],
        epochs=args.epochs - (args.epochs // 2),  # remaining epochs after phase 1
        parameters=model.inner_transform.parameters(),
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_B{}.pt")],
        forward_kwargs={"mode": "mf-fixed-manifold"},
        **shared_kwargs,
    )
    phase2_curves = np.vstack(phase2_curves).T

    # Guard mirrors the original control flow (learning_curves is in fact never None here).
    learning_curves = phase2_curves if learning_curves is None else np.vstack((learning_curves, phase2_curves))
    return learning_curves
def train_pie(args, dataset, model, simulator):
    """
    PIE training.

    Single-phase NLL training in "pie" forward mode, with optional SCANDAL
    score-matching terms supplied by make_training_kwargs. Image simulators
    additionally get sample/reconstruction plotting callbacks.
    """
    # Trainer choice: plain forward when unconditional; SCANDAL variant only when requested.
    if simulator.parameter_dim() is None:
        trainer = ForwardTrainer(model)
    elif args.scandal is None:
        trainer = ConditionalForwardTrainer(model)
    else:
        trainer = SCANDALForwardTrainer(model)

    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    training_callbacks = [callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args))]
    if simulator.is_image():
        # Conditional models get a fixed zero-context batch of 30 for the sample plots.
        plot_context = None if simulator.parameter_dim() is None else torch.zeros(30, simulator.parameter_dim())
        training_callbacks.append(callbacks.plot_sample_images(create_filename("training_plot", None, args), context=plot_context))
        training_callbacks.append(callbacks.plot_reco_images(create_filename("training_plot", "reco_epoch", args)))

    logger.info("Starting training PIE on NLL")
    history = trainer.train(
        loss_functions=[losses.nll] + scandal_loss,
        loss_labels=["NLL"] + scandal_label,
        # NLL weight converted from nats to bits per data dimension.
        loss_weights=[args.nllfactor * nat_to_bit_per_dim(args.datadim)] + scandal_weight,
        epochs=args.epochs,
        callbacks=training_callbacks,
        forward_kwargs={"mode": "pie"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    return np.vstack(history).T
def train_generative_adversarial_manifold_flow(args, dataset, model, simulator):
    """
    MFMF-OT training.

    Trains the model as a Sinkhorn GAN: a single phase minimizing the
    Sinkhorn divergence ("GED"), using the generative batch size.
    """
    # Adversarial trainer; conditional variant when the simulator has parameters.
    if simulator.parameter_dim() is None:
        gen_trainer = AdversarialTrainer(model)
    else:
        gen_trainer = ConditionalAdversarialTrainer(model)

    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)
    # Override the default batch size with the GAN-specific one.
    common_kwargs["batch_size"] = args.genbatchsize

    logger.info("Starting training GAMF: Sinkhorn-GAN")

    training_callbacks = [callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args))]
    if args.debug:
        training_callbacks.append(callbacks.print_mf_weight_statistics())

    history = gen_trainer.train(
        loss_functions=[losses.make_sinkhorn_divergence()],
        loss_labels=["GED"],
        loss_weights=[args.sinkhornfactor],
        epochs=args.epochs,
        callbacks=training_callbacks,
        compute_loss_variance=True,
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    return np.vstack(history).T
def train_specified_manifold_flow(args, dataset, model, simulator):
    """
    FOM training (flow on a specified manifold).

    Single-phase NLL training in "mf" mode; MSE is reported alongside but has
    zero loss weight, so only the NLL term (plus optional SCANDAL terms)
    drives optimization.
    """
    # Trainer choice: unconditional, conditional, or conditional + SCANDAL.
    if simulator.parameter_dim() is None:
        trainer = ForwardTrainer(model)
    elif args.scandal is None:
        trainer = ConditionalForwardTrainer(model)
    else:
        trainer = SCANDALForwardTrainer(model)

    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    logger.info("Starting training MF with specified manifold on NLL")
    history = trainer.train(
        loss_functions=[losses.mse, losses.nll] + scandal_loss,
        loss_labels=["MSE", "NLL"] + scandal_label,
        # MSE weight is 0.0: tracked for monitoring only.
        loss_weights=[0.0, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs,
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args))],
        forward_kwargs={"mode": "mf"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    return np.vstack(history).T
def train_pie(args, dataset, model, simulator):
    """
    PIE training (single-phase NLL in "pie" mode).

    NOTE(review): this redefines train_pie from earlier in the file; being the
    later definition, it is the one bound at import time — confirm which
    version is intended.
    """
    if simulator.parameter_dim() is None:
        trainer = ManifoldFlowTrainer(model)
    else:
        trainer = ConditionalManifoldFlowTrainer(model)

    logger.info("Starting training PIE on NLL")

    train_kwargs = {
        "dataset": dataset,
        "batch_size": args.batchsize,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit,
    }
    if args.weightdecay is not None:
        train_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}

    history = trainer.train(
        loss_functions=[losses.nll],
        loss_labels=["NLL"],
        loss_weights=[args.nllfactor],
        epochs=args.epochs,
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_{}.pt")],
        forward_kwargs={"mode": "pie"},
        **train_kwargs,
    )
    return np.vstack(history).T
def train_generative_adversarial_manifold_flow_alternating(args, dataset, model, simulator):
    """
    MFMF-OTA training.

    Alternates between a Sinkhorn-divergence (OT) generative phase over all
    model parameters and an NLL phase over the inner transform only, driven
    by an AlternatingTrainer wrapping the two sub-trainers.
    """
    assert not args.specified

    # Generative (OT) sub-trainer; conditional variant when the simulator has parameters.
    gen_trainer = GenerativeTrainer(model) if simulator.parameter_dim() is None else ConditionalGenerativeTrainer(model)
    # Likelihood (NLL) sub-trainer, same conditionality.
    likelihood_trainer = ManifoldFlowTrainer(model) if simulator.parameter_dim() is None else ConditionalManifoldFlowTrainer(model)
    metatrainer = AlternatingTrainer(model, gen_trainer, likelihood_trainer)

    # Keyword arguments shared by both alternating phases.
    meta_kwargs = {"dataset": dataset, "initial_lr": args.lr, "scheduler": optim.lr_scheduler.CosineAnnealingLR, "validation_split": args.validationsplit}
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}

    # Phase 1 (OT): default forward mode; phase 2 (NLL): fixed-manifold mode.
    phase1_kwargs = {"clip_gradient": args.clip}
    phase2_kwargs = {"forward_kwargs": {"mode": "mf-fixed-manifold"}, "clip_gradient": args.clip}

    # Phase 1 updates everything; phase 2 only the inner (density) transform.
    phase1_parameters = model.parameters()
    phase2_parameters = model.inner_transform.parameters()

    logger.info("Starting training GAMF, alternating between Sinkhorn divergence and log likelihood")
    learning_curves_ = metatrainer.train(
        loss_functions=[losses.make_sinkhorn_divergence(), losses.nll],
        loss_function_trainers=[0, 1],  # maps each loss to its sub-trainer
        loss_labels=["GED", "NLL"],
        loss_weights=[args.sinkhornfactor, args.nllfactor],
        batch_sizes=[args.genbatchsize, args.batchsize],  # per-phase batch sizes
        epochs=args.epochs // 2,
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_{}.pt")],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        subsets=args.subsets,
        subset_callbacks=[callbacks.print_mf_weight_statistics()] if args.debug else None,
        **meta_kwargs,
    )
    # Transpose so each row is one loss curve over epochs.
    learning_curves = np.vstack(learning_curves_).T
    return learning_curves
def train_generative_adversarial_manifold_flow(args, dataset, model, simulator):
    """
    MFMF-OT training (Sinkhorn-GAN, single phase).

    NOTE(review): this redefines train_generative_adversarial_manifold_flow
    from earlier in the file; being later, it is the binding that survives
    import — confirm which version is intended.
    """
    if simulator.parameter_dim() is None:
        gen_trainer = GenerativeTrainer(model)
    else:
        gen_trainer = ConditionalGenerativeTrainer(model)

    train_kwargs = {
        "dataset": dataset,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit,
    }
    if args.weightdecay is not None:
        train_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}

    logger.info("Starting training GAMF: Sinkhorn-GAN")

    training_callbacks = [callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_{}.pt")]
    if args.debug:
        training_callbacks.append(callbacks.print_mf_weight_statistics())

    history = gen_trainer.train(
        loss_functions=[losses.make_sinkhorn_divergence()],
        loss_labels=["GED"],
        loss_weights=[args.sinkhornfactor],
        epochs=args.epochs,
        callbacks=training_callbacks,
        batch_size=args.genbatchsize,  # GAN-specific batch size
        compute_loss_variance=True,
        **train_kwargs,
    )
    return np.vstack(history).T
def train_dough(args, dataset, model, simulator):
    """
    PIE-with-variable-epsilons ("dough") training.

    Single NLL phase with an L1 penalty (args.doughl1reg) passed through to
    the variable-dimension trainer.
    """
    if simulator.parameter_dim() is None:
        trainer = VariableDimensionManifoldFlowTrainer(model)
    else:
        trainer = ConditionalVariableDimensionManifoldFlowTrainer(model)

    train_kwargs = {
        "dataset": dataset,
        "batch_size": args.batchsize,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit,
    }
    if args.weightdecay is not None:
        train_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}

    logger.info("Starting training dough, phase 1: NLL without latent regularization")
    history = trainer.train(
        loss_functions=[losses.nll],
        loss_labels=["NLL"],
        loss_weights=[args.nllfactor],
        epochs=args.epochs,
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_{}.pt")],
        l1=args.doughl1reg,  # L1 regularization strength on the latent epsilons
        **train_kwargs,
    )
    return np.vstack(history).T
def train_manifold_flow_sequential(args, dataset, model, simulator):
    """
    Sequential MFMF-M/D training.

    Phase 1 (trainer1) learns the manifold via reconstruction (L1 or MSE,
    optionally with a hidden-L2 regularizer) in "projection" mode; phase 2
    (trainer2) learns the density via NLL (plus optional SCANDAL terms) with
    the manifold fixed. Returns the vertically stacked learning curves.
    """
    assert not args.specified

    # Two separate trainers: phase 2 may need the SCANDAL variant, phase 1 never does.
    if simulator.parameter_dim() is None:
        trainer1 = ForwardTrainer(model)
        trainer2 = ForwardTrainer(model)
    else:
        trainer1 = ConditionalForwardTrainer(model)
        if args.scandal is None:
            trainer2 = ConditionalForwardTrainer(model)
        else:
            trainer2 = SCANDALForwardTrainer(model)

    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    # Per-phase callbacks: checkpoints labeled "A"/"B" plus latent/weight diagnostics.
    callbacks1 = [callbacks.save_model_after_every_epoch(create_filename("checkpoint", "A", args)), callbacks.print_mf_latent_statistics(), callbacks.print_mf_weight_statistics()]
    callbacks2 = [callbacks.save_model_after_every_epoch(create_filename("checkpoint", "B", args)), callbacks.print_mf_latent_statistics(), callbacks.print_mf_weight_statistics()]
    if simulator.is_image():
        # Image simulators also get sample and reconstruction plots; conditional
        # models use a fixed zero context of batch size 30 for sampling.
        callbacks1.append(
            callbacks.plot_sample_images(
                create_filename("training_plot", "sample_epoch_A", args),
                context=None if simulator.parameter_dim() is None else torch.zeros(30, simulator.parameter_dim()),
            )
        )
        callbacks2.append(
            callbacks.plot_sample_images(
                create_filename("training_plot", "sample_epoch_B", args),
                context=None if simulator.parameter_dim() is None else torch.zeros(30, simulator.parameter_dim()),
            )
        )
        callbacks1.append(callbacks.plot_reco_images(create_filename("training_plot", "reco_epoch_A", args)))
        callbacks2.append(callbacks.plot_reco_images(create_filename("training_plot", "reco_epoch_B", args)))

    logger.info("Starting training MF, phase 1: manifold training")
    learning_curves = trainer1.train(
        # Reconstruction loss (smooth L1 or MSE), plus optional hidden-L2 regularizer.
        loss_functions=[losses.smooth_l1_loss if args.l1 else losses.mse] + ([] if args.uvl2reg is None else [losses.hiddenl2reg]),
        loss_labels=["L1" if args.l1 else "MSE"] + ([] if args.uvl2reg is None else ["L2_lat"]),
        loss_weights=[args.msefactor] + ([] if args.uvl2reg is None else [args.uvl2reg]),
        epochs=args.epochs // 2,
        # EMF additionally trains the encoder alongside the outer transform.
        parameters=list(model.outer_transform.parameters()) + list(model.encoder.parameters()) if args.algorithm == "emf" else list(model.outer_transform.parameters()),
        callbacks=callbacks1,
        # return_hidden is only needed when the hidden-L2 regularizer is active.
        forward_kwargs={"mode": "projection", "return_hidden": args.uvl2reg is not None},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T

    logger.info("Starting training MF, phase 2: density training")
    learning_curves_ = trainer2.train(
        loss_functions=[losses.nll] + scandal_loss,
        loss_labels=["NLL"] + scandal_label,
        # NLL weight converted from nats to bits per latent dimension.
        loss_weights=[args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs - (args.epochs // 2),
        parameters=list(model.inner_transform.parameters()),
        callbacks=callbacks2,
        forward_kwargs={"mode": "mf-fixed-manifold"},
        # Phase-2 epoch numbering continues after phase 1.
        # NOTE(review): negative when startepoch < epochs // 2 — presumably the
        # trainer tolerates this; confirm.
        initial_epoch=args.startepoch - args.epochs // 2,
        **common_kwargs,
    )
    # Stack phase-2 curves (transposed) below the phase-1 curves.
    learning_curves = np.vstack((learning_curves, np.vstack(learning_curves_).T))
    return learning_curves
def train_manifold_flow_alternating(args, dataset, model, simulator):
    """
    MFMF-A training.

    Alternates between a reconstruction phase (L1 or MSE, "projection" mode,
    outer transform [+ encoder for EMF]) and a likelihood phase (NLL plus
    optional SCANDAL terms, "mf-fixed-manifold" mode, inner transform only),
    via an AlternatingTrainer wrapping two forward trainers.

    Returns the learning curves transposed to loss-per-row layout.
    """
    assert not args.specified

    # Phase-1 trainer never needs SCANDAL; phase-2 trainer does when args.scandal is set.
    trainer1 = ForwardTrainer(model) if simulator.parameter_dim() is None else ConditionalForwardTrainer(model)
    trainer2 = ForwardTrainer(model) if simulator.parameter_dim() is None else ConditionalForwardTrainer(model) if args.scandal is None else SCANDALForwardTrainer(model)
    metatrainer = AlternatingTrainer(model, trainer1, trainer2)

    # Kwargs shared by both alternating phases.
    meta_kwargs = {"dataset": dataset, "initial_lr": args.lr, "scheduler": optim.lr_scheduler.CosineAnnealingLR, "validation_split": args.validationsplit}
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}
    # Only the SCANDAL loss pieces are used from here; common kwargs come from meta_kwargs.
    _, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    phase1_kwargs = {"forward_kwargs": {"mode": "projection"}, "clip_gradient": args.clip}
    phase2_kwargs = {"forward_kwargs": {"mode": "mf-fixed-manifold"}, "clip_gradient": args.clip}

    # Phase 1 trains the outer (manifold) transform; EMF also trains the encoder.
    phase1_parameters = list(model.outer_transform.parameters()) + list(model.encoder.parameters()) if args.algorithm == "emf" else model.outer_transform.parameters()
    phase2_parameters = list(model.inner_transform.parameters())

    logger.info("Starting training MF, alternating between reconstruction error and log likelihood")
    learning_curves_ = metatrainer.train(
        loss_functions=[losses.smooth_l1_loss if args.l1 else losses.mse, losses.nll] + scandal_loss,
        # BUGFIX: the original `[0, 1] + [1] if args.scandal is not None else []`
        # parsed as `([0, 1] + [1]) if ... else []` (ternary binds looser than +),
        # yielding an empty trainer list whenever SCANDAL was off. Parenthesize so
        # the base mapping [0, 1] is always present and the SCANDAL loss maps to
        # trainer 1 only when enabled.
        loss_function_trainers=[0, 1] + ([1] if args.scandal is not None else []),
        loss_labels=["L1" if args.l1 else "MSE", "NLL"] + scandal_label,
        loss_weights=[args.msefactor, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs // 2,
        subsets=args.subsets,
        batch_sizes=[args.batchsize, args.batchsize],
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args))],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        **meta_kwargs,
    )
    learning_curves = np.vstack(learning_curves_).T

    return learning_curves
def train_manifold_flow(args, dataset, model, simulator):
    """
    MFMF-S training.

    Three sequential phases on one trainer:
      1. pretraining on reconstruction error ("projection" mode, checkpoint "A"),
      2. mixed MSE+NLL training of all parameters ("mf" mode, checkpoint "B"),
      3. NLL-only training of the inner flow ("mf-fixed-manifold" mode,
         checkpoint "C"; MSE kept with zero weight for monitoring).

    Returns all phases' learning curves stacked vertically, one loss per row.
    """
    assert not args.specified

    # Trainer choice: unconditional, conditional, or conditional + SCANDAL.
    trainer = ForwardTrainer(model) if simulator.parameter_dim() is None else ConditionalForwardTrainer(model) if args.scandal is None else SCANDALForwardTrainer(model)
    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    logger.info("Starting training MF, phase 1: pretraining on reconstruction error")
    learning_curves = trainer.train(
        loss_functions=[losses.mse],
        loss_labels=["MSE"],
        loss_weights=[args.msefactor],
        epochs=args.epochs // 3,
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", "A", args))],
        forward_kwargs={"mode": "projection"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T

    logger.info("Starting training MF, phase 2: mixed training")
    learning_curves_ = trainer.train(
        loss_functions=[losses.mse, losses.nll] + scandal_loss,
        loss_labels=["MSE", "NLL"] + scandal_label,
        loss_weights=[args.msefactor, args.addnllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs - 2 * (args.epochs // 3),  # middle third (plus rounding remainder)
        parameters=list(model.parameters()),
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", "B", args))],
        forward_kwargs={"mode": "mf"},
        initial_epoch=args.startepoch - (args.epochs // 3),
        **common_kwargs,
    )
    learning_curves_ = np.vstack(learning_curves_).T
    # Simplified from `learning_curves_ if learning_curves is None else ...`:
    # learning_curves is always a phase-1 array here, never None.
    learning_curves = np.vstack((learning_curves, learning_curves_))

    logger.info("Starting training MF, phase 3: training only inner flow on NLL")
    learning_curves_ = trainer.train(
        loss_functions=[losses.mse, losses.nll] + scandal_loss,
        loss_labels=["MSE", "NLL"] + scandal_label,
        # MSE weight 0.0: monitored only; NLL (bits/dim) drives phase 3.
        loss_weights=[0.0, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs // 3,
        parameters=list(model.inner_transform.parameters()),
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", "C", args))],
        forward_kwargs={"mode": "mf-fixed-manifold"},
        initial_epoch=args.startepoch - (args.epochs - (args.epochs // 3)),
        **common_kwargs,
    )
    learning_curves_ = np.vstack(learning_curves_).T
    # BUGFIX: the original stacked `np.vstack(learning_curves_).T` here even
    # though learning_curves_ had already been transposed on the line above,
    # undoing the transpose and mis-orienting (or shape-mismatching) the
    # phase-3 rows. Stack the already-transposed curves, consistent with
    # phases 1-2.
    learning_curves = np.vstack((learning_curves, learning_curves_))

    return learning_curves
def train_slice_of_pie(args, dataset, model, simulator):
    """
    SLICE training.

    Up to three phases on one trainer: (1) optional reconstruction
    pretraining in "projection" mode, (2) mixed MSE+NLL training in "slice"
    mode, (3) NLL-weighted training of the inner flow, also in "slice" mode.
    Returns the concatenated learning curves (loss-per-row after transpose).
    """
    trainer = ManifoldFlowTrainer(model) if simulator.parameter_dim() is None else ConditionalManifoldFlowTrainer(model)

    # Kwargs shared by every phase.
    common_kwargs = {
        "dataset": dataset,
        "batch_size": args.batchsize,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit,
    }
    if args.weightdecay is not None:
        common_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}

    if args.nopretraining or args.epochs // 3 < 1:
        logger.info("Skipping pretraining phase")
        # Empty placeholder with 2 columns so later np.vstack calls still work.
        learning_curves = np.zeros((0, 2))
    else:
        logger.info("Starting training slice of PIE, phase 1: pretraining on reconstruction error")
        learning_curves = trainer.train(
            loss_functions=[losses.mse],
            loss_labels=["MSE"],
            loss_weights=[args.initialmsefactor],
            epochs=args.epochs // 3,
            callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_A{}.pt")],
            forward_kwargs={"mode": "projection"},
            **common_kwargs,
        )
        learning_curves = np.vstack(learning_curves).T

    logger.info("Starting training slice of PIE, phase 2: mixed training")
    learning_curves_ = trainer.train(
        loss_functions=[losses.mse, losses.nll],
        loss_labels=["MSE", "NLL"],
        loss_weights=[args.initialmsefactor, args.initialnllfactor],
        # Remaining epochs after the (possibly skipped) pretraining third and
        # the reserved final third.
        epochs=args.epochs - (1 if args.nopretraining else 2) * (args.epochs // 3),
        parameters=model.inner_transform.parameters(),
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_B{}.pt")],
        forward_kwargs={"mode": "slice"},
        **common_kwargs,
    )
    learning_curves_ = np.vstack(learning_curves_).T
    learning_curves = np.vstack((learning_curves, learning_curves_))

    logger.info("Starting training slice of PIE, phase 3: training only inner flow on NLL")
    learning_curves_ = trainer.train(
        loss_functions=[losses.mse, losses.nll],
        loss_labels=["MSE", "NLL"],
        # Phase 3 switches to the final (non-"initial") loss weights.
        loss_weights=[args.msefactor, args.nllfactor],
        epochs=args.epochs // 3,
        parameters=model.inner_transform.parameters(),
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_C{}.pt")],
        forward_kwargs={"mode": "slice"},
        **common_kwargs,
    )
    learning_curves_ = np.vstack(learning_curves_).T
    learning_curves = np.vstack((learning_curves, learning_curves_))
    return learning_curves
def train_manifold_flow_alternating(args, dataset, model, simulator):
    """
    MFMF-A training: alternate reconstruction ("projection") and likelihood
    ("mf-fixed-manifold") phases with one shared trainer.

    NOTE(review): this redefines train_manifold_flow_alternating from earlier
    in the file; being later, it is the binding that survives import —
    confirm which version is intended.
    """
    assert not args.specified

    if simulator.parameter_dim() is None:
        base_trainer = ManifoldFlowTrainer(model)
    else:
        base_trainer = ConditionalManifoldFlowTrainer(model)
    # The same trainer instance serves both alternating phases.
    metatrainer = AlternatingTrainer(model, base_trainer, base_trainer)

    meta_kwargs = {"dataset": dataset, "initial_lr": args.lr, "scheduler": optim.lr_scheduler.CosineAnnealingLR}
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}

    recon_phase_kwargs = {
        "forward_kwargs": {"mode": "projection"},
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit,
    }
    density_phase_kwargs = {
        "forward_kwargs": {"mode": "mf-fixed-manifold"},
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit,
    }

    # Reconstruction phase trains the outer transform (plus encoder for EMF);
    # density phase trains only the inner transform.
    if args.algorithm == "emf":
        recon_params = list(model.outer_transform.parameters()) + list(model.encoder.parameters())
    else:
        recon_params = model.outer_transform.parameters()
    density_params = model.inner_transform.parameters()

    logger.info("Starting training MF, alternating between reconstruction error and log likelihood")
    history = metatrainer.train(
        loss_functions=[losses.mse, losses.nll],
        loss_function_trainers=[0, 1],
        loss_labels=["MSE", "NLL"],
        loss_weights=[args.msefactor, args.nllfactor],
        epochs=args.epochs // 2,
        subsets=args.subsets,
        batch_sizes=[args.batchsize, args.batchsize],
        parameters=[recon_params, density_params],
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_{}.pt")],
        trainer_kwargs=[recon_phase_kwargs, density_phase_kwargs],
        **meta_kwargs,
    )
    return np.vstack(history).T
def train_manifold_flow(args, dataset, model, simulator):
    """
    MFMF-S training.

    With args.specified: a single NLL-weighted phase in "mf" mode. Otherwise
    up to three phases: (1) optional pretraining on reconstruction error or
    PIE likelihood, (2) mixed MSE+NLL training of all parameters, (3)
    NLL-only training of the inner flow with the manifold fixed.

    NOTE(review): this redefines train_manifold_flow from earlier in the
    file; being the later definition, it is the one bound at import time —
    confirm which version is intended.
    """
    trainer = ManifoldFlowTrainer(model) if simulator.parameter_dim() is None else ConditionalManifoldFlowTrainer(model)

    # Kwargs shared by every phase.
    common_kwargs = {
        "dataset": dataset,
        "batch_size": args.batchsize,
        "initial_lr": args.lr,
        "scheduler": optim.lr_scheduler.CosineAnnealingLR,
        "clip_gradient": args.clip,
        "validation_split": args.validationsplit,
    }
    if args.weightdecay is not None:
        common_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}

    if args.specified:
        logger.info("Starting training MF with specified manifold on NLL")
        learning_curves = trainer.train(
            loss_functions=[losses.mse, losses.nll],
            loss_labels=["MSE", "NLL"],
            loss_weights=[0.0, args.nllfactor],  # MSE monitored only
            epochs=args.epochs,
            callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_{}.pt")],
            forward_kwargs={"mode": "mf"},
            **common_kwargs,
        )
        learning_curves = np.vstack(learning_curves).T
    else:
        # Phase 1: optional pretraining (skipped, PIE-likelihood, or reconstruction).
        if args.nopretraining or args.epochs // args.prepostfraction < 1:
            logger.info("Skipping pretraining phase")
            learning_curves = None
        elif args.prepie:
            logger.info("Starting training MF, phase 1: pretraining on PIE likelihood")
            learning_curves = trainer.train(
                loss_functions=[losses.nll],
                loss_labels=["NLL"],
                loss_weights=[args.nllfactor],
                epochs=args.epochs // args.prepostfraction,
                callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_A{}.pt")],
                forward_kwargs={"mode": "pie"},
                **common_kwargs,
            )
            learning_curves = np.vstack(learning_curves).T
        else:
            logger.info("Starting training MF, phase 1: pretraining on reconstruction error")
            learning_curves = trainer.train(
                loss_functions=[losses.mse],
                loss_labels=["MSE"],
                loss_weights=[args.msefactor],
                epochs=args.epochs // args.prepostfraction,
                callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_A{}.pt")],
                forward_kwargs={"mode": "projection"},
                **common_kwargs,
            )
            learning_curves = np.vstack(learning_curves).T

        # Phase 2: mixed training over all model parameters in "mf" mode.
        logger.info("Starting training MF, phase 2: mixed training")
        learning_curves_ = trainer.train(
            loss_functions=[losses.mse, losses.nll],
            loss_labels=["MSE", "NLL"],
            loss_weights=[args.msefactor, args.addnllfactor],
            # Middle epochs: total minus the pre/post fractions actually run.
            epochs=args.epochs - (2 - int(args.nopretraining) - int(args.noposttraining)) * (args.epochs // args.prepostfraction),
            parameters=model.parameters(),
            callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_B{}.pt")],
            forward_kwargs={"mode": "mf"},
            **common_kwargs,
        )
        learning_curves_ = np.vstack(learning_curves_).T
        # learning_curves is None only when pretraining was skipped.
        learning_curves = learning_curves_ if learning_curves is None else np.vstack((learning_curves, learning_curves_))

        # Phase 3: inner flow only, manifold fixed.
        # NOTE(review): this skip condition checks args.nopretraining, not
        # args.noposttraining, even though the phase-2 epoch budget above uses
        # noposttraining for this phase — possibly a typo; confirm intent.
        if args.nopretraining or args.epochs // args.prepostfraction < 1:
            logger.info("Skipping inner flow phase")
        else:
            logger.info("Starting training MF, phase 3: training only inner flow on NLL")
            learning_curves_ = trainer.train(
                loss_functions=[losses.mse, losses.nll],
                loss_labels=["MSE", "NLL"],
                loss_weights=[0.0, args.nllfactor],  # MSE monitored only
                epochs=args.epochs // args.prepostfraction,
                parameters=model.inner_transform.parameters(),
                callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args)[:-3] + "_epoch_C{}.pt")],
                forward_kwargs={"mode": "mf-fixed-manifold"},
                **common_kwargs,
            )
            learning_curves_ = np.vstack(learning_curves_).T
            learning_curves = np.vstack((learning_curves, learning_curves_))

    return learning_curves