def train_specified_manifold_flow(args, dataset, model, simulator):
    """ FOM training """

    # Pick the trainer: unconditional, conditional, or SCANDAL (conditional + score matching)
    if simulator.parameter_dim() is None:
        trainer = ForwardTrainer(model)
    elif args.scandal is None:
        trainer = ConditionalForwardTrainer(model)
    else:
        trainer = SCANDALForwardTrainer(model)

    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    logger.info("Starting training MF with specified manifold on NLL")
    learning_curves = trainer.train(
        loss_functions=[losses.mse, losses.nll] + scandal_loss,
        loss_labels=["MSE", "NLL"] + scandal_label,
        # MSE is monitored but not optimized (weight 0.0)
        loss_weights=[0.0, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs,
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args))],
        forward_kwargs={"mode": "mf"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T

    return learning_curves
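# All NLL loss weights in this module are rescaled by nat_to_bit_per_dim,
# which is imported from elsewhere in the package. A minimal sketch of such a
# conversion, assuming it turns a log likelihood in nats into bits per data
# dimension; the "_sketch" suffix marks this as an illustration, not the
# package's definition:
def nat_to_bit_per_dim_sketch(dim):
    import numpy as np

    # 1 nat = 1 / ln(2) bits; dividing by dim normalizes over dimensions, so
    # multiplying an NLL in nats by this factor yields bits per dimension.
    return 1.0 / (np.log(2) * dim)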
def train_pie(args, dataset, model, simulator):
    """ PIE training """

    # Pick the trainer: unconditional, conditional, or SCANDAL (conditional + score matching)
    if simulator.parameter_dim() is None:
        trainer = ForwardTrainer(model)
    elif args.scandal is None:
        trainer = ConditionalForwardTrainer(model)
    else:
        trainer = SCANDALForwardTrainer(model)

    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    callbacks_ = [callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args))]
    if simulator.is_image():
        callbacks_.append(
            callbacks.plot_sample_images(
                create_filename("training_plot", None, args),
                context=None if simulator.parameter_dim() is None else torch.zeros(30, simulator.parameter_dim()),
            )
        )
        callbacks_.append(callbacks.plot_reco_images(create_filename("training_plot", "reco_epoch", args)))

    logger.info("Starting training PIE on NLL")
    learning_curves = trainer.train(
        loss_functions=[losses.nll] + scandal_loss,
        loss_labels=["NLL"] + scandal_label,
        loss_weights=[args.nllfactor * nat_to_bit_per_dim(args.datadim)] + scandal_weight,
        epochs=args.epochs,
        callbacks=callbacks_,
        forward_kwargs={"mode": "pie"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T

    return learning_curves
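# The trainer-selection logic above recurs across the training functions in
# this module. A hypothetical helper (not part of this module) that spells
# out the same precedence in one place:
def _make_forward_trainer_sketch(args, model, simulator):
    # Unconditional data gets a plain ForwardTrainer. Conditional data gets
    # the conditional variant, upgraded to SCANDAL when args.scandal is set.
    if simulator.parameter_dim() is None:
        return ForwardTrainer(model)
    if args.scandal is None:
        return ConditionalForwardTrainer(model)
    return SCANDALForwardTrainer(model)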
def objective(trial):
    global counter

    counter += 1

    # Hyperparameters
    margs = pick_parameters(args, trial, counter)

    logger.info(f"Starting run {counter} / {args.trials}")
    logger.info("Hyperparams:")
    logger.info(f"  outer layers:      {margs.outerlayers}")
    logger.info(f"  inner layers:      {margs.innerlayers}")
    logger.info(f"  linear transform:  {margs.lineartransform}")
    logger.info(f"  spline range:      {margs.splinerange}")
    logger.info(f"  spline bins:       {margs.splinebins}")
    logger.info(f"  batchnorm:         {margs.batchnorm}")
    logger.info(f"  dropout:           {margs.dropout}")
    logger.info(f"  batch size:        {margs.batchsize}")
    logger.info(f"  MSE factor:        {margs.msefactor}")
    logger.info(f"  latent L2 reg:     {margs.uvl2reg}")
    logger.info(f"  weight decay:      {margs.weightdecay}")
    logger.info(f"  gradient clipping: {margs.clip}")

    # Bug fix related to num_workers > 1 combined with CUDA. Bad things happen otherwise!
    torch.multiprocessing.set_start_method("spawn", force=True)

    # Load data
    simulator = load_simulator(margs)
    dataset = simulator.load_dataset(train=True, dataset_dir=create_filename("dataset", None, args), limit_samplesize=margs.samplesize)

    # Create model
    model = create_model(margs, simulator)

    # Train
    if simulator.parameter_dim() is None:
        trainer1 = ForwardTrainer(model)
        trainer2 = ForwardTrainer(model)
    else:
        trainer1 = ConditionalForwardTrainer(model)
        trainer2 = ConditionalForwardTrainer(model)
    common_kwargs, _, _, _ = train.make_training_kwargs(margs, dataset)

    logger.info("Starting training MF, phase 1: manifold training")
    np.random.seed(123)
    _, val_losses = trainer1.train(
        loss_functions=[losses.mse, losses.hiddenl2reg],
        loss_labels=["MSE", "L2_lat"],
        loss_weights=[margs.msefactor, 0.0 if margs.uvl2reg is None else margs.uvl2reg],
        epochs=margs.epochs,
        parameters=(
            list(model.outer_transform.parameters()) + list(model.encoder.parameters())
            if args.algorithm == "emf"
            else model.outer_transform.parameters()
        ),
        forward_kwargs={"mode": "projection", "return_hidden": True},
        **common_kwargs,
    )

    logger.info("Starting training MF, phase 2: density training")
    np.random.seed(123)
    _ = trainer2.train(
        loss_functions=[losses.nll],
        loss_labels=["NLL"],
        loss_weights=[args.nllfactor],
        epochs=args.densityepochs,
        parameters=model.inner_transform.parameters(),
        forward_kwargs={"mode": "mf-fixed-manifold"},
        **common_kwargs,
    )

    # Save
    torch.save(model.state_dict(), create_filename("model", None, margs))

    # Evaluate reco error
    logger.info("Evaluating reco error")
    model.eval()
    np.random.seed(123)
    # Take one batch from the validation loader
    x, params = next(
        iter(
            trainer1.make_dataloader(
                simulator.load_dataset(train=True, dataset_dir=create_filename("dataset", None, args), limit_samplesize=args.samplesize),
                args.validationsplit,
                1000,
                0,
            )[1]
        )
    )
    x = x.to(device=trainer1.device, dtype=trainer1.dtype)
    params = None if simulator.parameter_dim() is None else params.to(device=trainer1.device, dtype=trainer1.dtype)
    x_reco, _, _ = model(x, context=params, mode="projection")
    reco_error = torch.mean(torch.sum((x - x_reco) ** 2, dim=1) ** 0.5).detach().cpu().numpy()

    # Generate samples
    logger.info("Evaluating sample closure")
    x_gen = evaluate.sample_from_model(margs, model, simulator)
    distances_gen = simulator.distance_from_manifold(x_gen)
    mean_gen_distance = np.mean(distances_gen)

    # Report results
    logger.info("Results:")
    logger.info("  reco err:     %s", reco_error)
    logger.info("  gen distance: %s", mean_gen_distance)

    return margs.metricrecoerrorfactor * reco_error + margs.metricdistancefactor * mean_gen_distance
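# objective(trial) follows the single-argument convention of Optuna (trial
# objects whose suggest_* methods are presumably consumed inside
# pick_parameters). A minimal sketch of a driver loop under that assumption;
# the study direction and the use of the module-level args are illustrative:
def _run_hyperparameter_search_sketch():
    import optuna

    study = optuna.create_study(direction="minimize")  # objective returns a score to minimize
    study.optimize(objective, n_trials=args.trials)
    logger.info("Best value %s with params %s", study.best_value, study.best_params)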
def train_generative_adversarial_manifold_flow_alternating(args, dataset, model, simulator):
    """ MFMF-OTA training """

    assert not args.specified

    gen_trainer = AdversarialTrainer(model) if simulator.parameter_dim() is None else ConditionalAdversarialTrainer(model)
    if simulator.parameter_dim() is None:
        likelihood_trainer = ForwardTrainer(model)
    elif args.scandal is None:
        likelihood_trainer = ConditionalForwardTrainer(model)
    else:
        likelihood_trainer = SCANDALForwardTrainer(model)
    metatrainer = AlternatingTrainer(model, gen_trainer, likelihood_trainer)

    meta_kwargs = {"dataset": dataset, "initial_lr": args.lr, "scheduler": optim.lr_scheduler.CosineAnnealingLR, "validation_split": args.validationsplit}
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}
    _, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    phase1_kwargs = {"clip_gradient": args.clip}
    phase2_kwargs = {"forward_kwargs": {"mode": "mf-fixed-manifold"}, "clip_gradient": args.clip}

    phase1_parameters = list(model.parameters())
    phase2_parameters = list(model.inner_transform.parameters())

    logger.info("Starting training GAMF, alternating between Sinkhorn divergence and log likelihood")
    learning_curves_ = metatrainer.train(
        loss_functions=[losses.make_sinkhorn_divergence(), losses.nll] + scandal_loss,
        # GED goes to the adversarial trainer (0); NLL and SCANDAL go to the
        # likelihood trainer (1). The parentheses keep the ternary from
        # swallowing the whole list when SCANDAL is off.
        loss_function_trainers=[0, 1] + ([1] if args.scandal is not None else []),
        loss_labels=["GED", "NLL"] + scandal_label,
        loss_weights=[args.sinkhornfactor, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        batch_sizes=[args.genbatchsize, args.batchsize],
        epochs=args.epochs // 2,
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args))],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        subsets=args.subsets,
        subset_callbacks=[callbacks.print_mf_weight_statistics()] if args.debug else None,
        **meta_kwargs,
    )
    learning_curves = np.vstack(learning_curves_).T

    return learning_curves
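# losses.make_sinkhorn_divergence is defined elsewhere in the package. A
# minimal sketch of a Sinkhorn divergence loss factory built on geomloss's
# SamplesLoss; the geomloss backend and the blur and scaling values are
# assumptions for illustration:
def make_sinkhorn_divergence_sketch(blur=0.05, scaling=0.5):
    from geomloss import SamplesLoss

    # With loss="sinkhorn", SamplesLoss computes a (debiased) Sinkhorn
    # divergence between two point clouds of shape (batch, dim).
    return SamplesLoss(loss="sinkhorn", p=2, blur=blur, scaling=scaling)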
def train_manifold_flow_sequential(args, dataset, model, simulator):
    """ Sequential MFMF-M/D training """

    assert not args.specified

    if simulator.parameter_dim() is None:
        trainer1 = ForwardTrainer(model)
        trainer2 = ForwardTrainer(model)
    else:
        trainer1 = ConditionalForwardTrainer(model)
        if args.scandal is None:
            trainer2 = ConditionalForwardTrainer(model)
        else:
            trainer2 = SCANDALForwardTrainer(model)

    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    callbacks1 = [
        callbacks.save_model_after_every_epoch(create_filename("checkpoint", "A", args)),
        callbacks.print_mf_latent_statistics(),
        callbacks.print_mf_weight_statistics(),
    ]
    callbacks2 = [
        callbacks.save_model_after_every_epoch(create_filename("checkpoint", "B", args)),
        callbacks.print_mf_latent_statistics(),
        callbacks.print_mf_weight_statistics(),
    ]
    if simulator.is_image():
        callbacks1.append(
            callbacks.plot_sample_images(
                create_filename("training_plot", "sample_epoch_A", args),
                context=None if simulator.parameter_dim() is None else torch.zeros(30, simulator.parameter_dim()),
            )
        )
        callbacks2.append(
            callbacks.plot_sample_images(
                create_filename("training_plot", "sample_epoch_B", args),
                context=None if simulator.parameter_dim() is None else torch.zeros(30, simulator.parameter_dim()),
            )
        )
        callbacks1.append(callbacks.plot_reco_images(create_filename("training_plot", "reco_epoch_A", args)))
        callbacks2.append(callbacks.plot_reco_images(create_filename("training_plot", "reco_epoch_B", args)))

    logger.info("Starting training MF, phase 1: manifold training")
    learning_curves = trainer1.train(
        loss_functions=[losses.smooth_l1_loss if args.l1 else losses.mse] + ([] if args.uvl2reg is None else [losses.hiddenl2reg]),
        loss_labels=["L1" if args.l1 else "MSE"] + ([] if args.uvl2reg is None else ["L2_lat"]),
        loss_weights=[args.msefactor] + ([] if args.uvl2reg is None else [args.uvl2reg]),
        epochs=args.epochs // 2,
        parameters=(
            list(model.outer_transform.parameters()) + list(model.encoder.parameters())
            if args.algorithm == "emf"
            else list(model.outer_transform.parameters())
        ),
        callbacks=callbacks1,
        forward_kwargs={"mode": "projection", "return_hidden": args.uvl2reg is not None},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T

    logger.info("Starting training MF, phase 2: density training")
    learning_curves_ = trainer2.train(
        loss_functions=[losses.nll] + scandal_loss,
        loss_labels=["NLL"] + scandal_label,
        loss_weights=[args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs - (args.epochs // 2),
        parameters=list(model.inner_transform.parameters()),
        callbacks=callbacks2,
        forward_kwargs={"mode": "mf-fixed-manifold"},
        initial_epoch=args.startepoch - args.epochs // 2,
        **common_kwargs,
    )
    learning_curves = np.vstack((learning_curves, np.vstack(learning_curves_).T))

    return learning_curves
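# A small illustration (not part of the training code) of the epoch
# accounting above: phase 1 gets epochs // 2 epochs, phase 2 the remainder,
# and initial_epoch shifts a resumed run to the right phase-local epoch.
def _sequential_epoch_split_sketch(epochs, startepoch):
    phase1 = epochs // 2
    phase2 = epochs - phase1
    assert phase1 + phase2 == epochs
    # E.g. epochs=101, startepoch=70: phase 1 covers 50 epochs, and phase 2
    # resumes at its local epoch 70 - 50 = 20 of 51. How the trainer treats
    # out-of-range initial_epoch values is up to its implementation.
    return phase1, phase2, startepoch - phase1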
def train_manifold_flow_alternating(args, dataset, model, simulator):
    """ MFMF-A training """

    assert not args.specified

    trainer1 = ForwardTrainer(model) if simulator.parameter_dim() is None else ConditionalForwardTrainer(model)
    if simulator.parameter_dim() is None:
        trainer2 = ForwardTrainer(model)
    elif args.scandal is None:
        trainer2 = ConditionalForwardTrainer(model)
    else:
        trainer2 = SCANDALForwardTrainer(model)
    metatrainer = AlternatingTrainer(model, trainer1, trainer2)

    meta_kwargs = {"dataset": dataset, "initial_lr": args.lr, "scheduler": optim.lr_scheduler.CosineAnnealingLR, "validation_split": args.validationsplit}
    if args.weightdecay is not None:
        meta_kwargs["optimizer_kwargs"] = {"weight_decay": float(args.weightdecay)}
    _, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    phase1_kwargs = {"forward_kwargs": {"mode": "projection"}, "clip_gradient": args.clip}
    phase2_kwargs = {"forward_kwargs": {"mode": "mf-fixed-manifold"}, "clip_gradient": args.clip}

    phase1_parameters = (
        list(model.outer_transform.parameters()) + list(model.encoder.parameters())
        if args.algorithm == "emf"
        else list(model.outer_transform.parameters())
    )
    phase2_parameters = list(model.inner_transform.parameters())

    logger.info("Starting training MF, alternating between reconstruction error and log likelihood")
    learning_curves_ = metatrainer.train(
        loss_functions=[losses.smooth_l1_loss if args.l1 else losses.mse, losses.nll] + scandal_loss,
        # Reconstruction goes to trainer 0 (manifold phase); NLL and SCANDAL
        # go to trainer 1 (density phase). The parentheses keep the ternary
        # from swallowing the whole list when SCANDAL is off.
        loss_function_trainers=[0, 1] + ([1] if args.scandal is not None else []),
        loss_labels=["L1" if args.l1 else "MSE", "NLL"] + scandal_label,
        loss_weights=[args.msefactor, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs // 2,
        subsets=args.subsets,
        batch_sizes=[args.batchsize, args.batchsize],
        parameters=[phase1_parameters, phase2_parameters],
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", None, args))],
        trainer_kwargs=[phase1_kwargs, phase2_kwargs],
        **meta_kwargs,
    )
    learning_curves = np.vstack(learning_curves_).T

    return learning_curves
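# meta_kwargs hands the scheduler *class* torch.optim.lr_scheduler.
# CosineAnnealingLR to the meta-trainer, which instantiates it internally. A
# minimal sketch of the standard usage pattern for this scheduler; the Adam
# optimizer, the T_max choice, and the step placement are illustrative, and
# the trainer's internal wiring may differ:
def _cosine_annealing_sketch(model, epochs, lr):
    import torch.optim as optim

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    for _ in range(epochs):
        # ... one epoch of forward/backward passes and optimizer.step() ...
        scheduler.step()  # anneal the learning rate along a cosine curve
    return optimizer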
def train_manifold_flow(args, dataset, model, simulator):
    """ MFMF-S training """

    assert not args.specified

    # Pick the trainer: unconditional, conditional, or SCANDAL (conditional + score matching)
    if simulator.parameter_dim() is None:
        trainer = ForwardTrainer(model)
    elif args.scandal is None:
        trainer = ConditionalForwardTrainer(model)
    else:
        trainer = SCANDALForwardTrainer(model)

    common_kwargs, scandal_loss, scandal_label, scandal_weight = make_training_kwargs(args, dataset)

    logger.info("Starting training MF, phase 1: pretraining on reconstruction error")
    learning_curves = trainer.train(
        loss_functions=[losses.mse],
        loss_labels=["MSE"],
        loss_weights=[args.msefactor],
        epochs=args.epochs // 3,
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", "A", args))],
        forward_kwargs={"mode": "projection"},
        initial_epoch=args.startepoch,
        **common_kwargs,
    )
    learning_curves = np.vstack(learning_curves).T

    logger.info("Starting training MF, phase 2: mixed training")
    learning_curves_ = trainer.train(
        loss_functions=[losses.mse, losses.nll] + scandal_loss,
        loss_labels=["MSE", "NLL"] + scandal_label,
        loss_weights=[args.msefactor, args.addnllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs - 2 * (args.epochs // 3),
        parameters=list(model.parameters()),
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", "B", args))],
        forward_kwargs={"mode": "mf"},
        initial_epoch=args.startepoch - (args.epochs // 3),
        **common_kwargs,
    )
    learning_curves_ = np.vstack(learning_curves_).T
    learning_curves = np.vstack((learning_curves, learning_curves_))

    logger.info("Starting training MF, phase 3: training only inner flow on NLL")
    learning_curves_ = trainer.train(
        loss_functions=[losses.mse, losses.nll] + scandal_loss,
        loss_labels=["MSE", "NLL"] + scandal_label,
        # MSE is monitored but not optimized (weight 0.0)
        loss_weights=[0.0, args.nllfactor * nat_to_bit_per_dim(args.modellatentdim)] + scandal_weight,
        epochs=args.epochs // 3,
        parameters=list(model.inner_transform.parameters()),
        callbacks=[callbacks.save_model_after_every_epoch(create_filename("checkpoint", "C", args))],
        forward_kwargs={"mode": "mf-fixed-manifold"},
        initial_epoch=args.startepoch - (args.epochs - (args.epochs // 3)),
        **common_kwargs,
    )
    learning_curves_ = np.vstack(learning_curves_).T
    learning_curves = np.vstack((learning_curves, learning_curves_))

    return learning_curves
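# An illustration (not part of the training code) of the MFMF-S epoch
# accounting above: pretraining and the final inner-flow phase each get
# epochs // 3, and the mixed phase absorbs the remainder, so the three
# phases always sum to the full budget.
def _three_phase_budget_sketch(epochs):
    phase1 = epochs // 3
    phase2 = epochs - 2 * (epochs // 3)
    phase3 = epochs // 3
    assert phase1 + phase2 + phase3 == epochs
    # E.g. epochs=100 gives (33, 34, 33).
    return phase1, phase2, phase3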
def objective(trial):
    global counter

    counter += 1

    # Hyperparameters
    margs = pick_parameters(args, trial, counter)

    logger.info(f"Starting run {counter} / {args.trials}")
    logger.info("Hyperparams:")
    logger.info(f"  outer layers:      {margs.outerlayers}")
    logger.info(f"  linlayers:         {margs.linlayers}")
    logger.info(f"  linchannelfactor:  {margs.linchannelfactor}")
    logger.info(f"  inner layers:      {margs.innerlayers}")
    logger.info(f"  linear transform:  {margs.lineartransform}")
    logger.info(f"  spline range:      {margs.splinerange}")
    logger.info(f"  spline bins:       {margs.splinebins}")
    logger.info(f"  batchnorm:         {margs.batchnorm}")
    logger.info(f"  actnorm:           {margs.actnorm}")
    logger.info(f"  dropout:           {margs.dropout}")
    logger.info(f"  batch size:        {margs.batchsize}")
    logger.info(f"  MSE factor:        {margs.msefactor}")
    logger.info(f"  latent L2 reg:     {margs.uvl2reg}")
    logger.info(f"  weight decay:      {margs.weightdecay}")
    logger.info(f"  gradient clipping: {margs.clip}")

    # Bug fix related to num_workers > 1 combined with CUDA. Bad things happen otherwise!
    torch.multiprocessing.set_start_method("spawn", force=True)

    # Load data
    simulator = load_simulator(margs)
    dataset = simulator.load_dataset(train=True, dataset_dir=create_filename("dataset", None, margs), limit_samplesize=margs.samplesize)

    # Create model
    model = create_model(margs, simulator)

    # Train
    try:
        trainer = ForwardTrainer(model) if simulator.parameter_dim() is None else ConditionalForwardTrainer(model)
        common_kwargs, _, _, _ = train.make_training_kwargs(margs, dataset)

        logger.info("Starting training MF: manifold training")
        np.random.seed(123)
        _, val_losses = trainer.train(
            loss_functions=[losses.mse, losses.hiddenl2reg],
            loss_labels=["MSE", "L2_lat"],
            loss_weights=[margs.msefactor, 0.0 if margs.uvl2reg is None else margs.uvl2reg],
            epochs=margs.epochs,
            parameters=(
                list(model.outer_transform.parameters()) + list(model.encoder.parameters())
                if args.algorithm == "emf"
                else model.outer_transform.parameters()
            ),
            forward_kwargs={"mode": "projection", "return_hidden": True},
            **common_kwargs,
        )

        # Save
        torch.save(model.state_dict(), create_filename("model", None, margs))

        # Evaluate reco error
        logger.info("Evaluating reco error")
        model.eval()
        torch.cuda.empty_cache()
        np.random.seed(123)
        dataloader = trainer.make_dataloader(
            simulator.load_dataset(train=True, dataset_dir=create_filename("dataset", None, margs), limit_samplesize=margs.samplesize),
            args.validationsplit,
            20,
            4,
        )[1]
        reco_errors = []
        x_plot, x_reco_plot = None, None
        for x, params in dataloader:
            x = x.to(device=trainer.device, dtype=trainer.dtype)
            params = None if simulator.parameter_dim() is None else params.to(device=trainer.device, dtype=trainer.dtype)
            x_reco, _, _ = model(x, context=params, mode="projection")
            reco_errors.append((torch.sum((x - x_reco) ** 2, dim=1) ** 0.5).detach().cpu().numpy())
            if x_plot is None:
                x_plot = x.detach().cpu().numpy()
                x_reco_plot = x_reco.detach().cpu().numpy()
        # Concatenate before averaging so a smaller final batch is handled correctly
        reco_error = np.mean(np.concatenate(reco_errors))
        if not np.isfinite(reco_error):
            raise RuntimeError("Reconstruction error is not finite")

        # Report results
        logger.info("Results:")
        logger.info("  reco err: %s", reco_error)

        # Plot reco error
        x = np.clip(np.transpose(x_plot, [0, 2, 3, 1]) / 256.0, 0.0, 1.0)
        x_reco = np.clip(np.transpose(x_reco_plot, [0, 2, 3, 1]) / 256.0, 0.0, 1.0)
        plt.figure(figsize=(6 * 3.0, 5 * 3.0))
        for i in range(15):
            plt.subplot(5, 6, 2 * i + 1)
            plt.imshow(x[i])
            plt.gca().get_xaxis().set_visible(False)
            plt.gca().get_yaxis().set_visible(False)
            plt.subplot(5, 6, 2 * i + 2)
            plt.imshow(x_reco[i])
            plt.gca().get_xaxis().set_visible(False)
            plt.gca().get_yaxis().set_visible(False)
        plt.tight_layout()
        filename = create_filename("training_plot", "reco", margs)
        plt.savefig(filename.format(""))

    except RuntimeError as e:
        logger.info("Error during training, returning 1e9\n %s", e)
        return 1e9

    return reco_error
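# Returning the large sentinel 1e9 keeps the study running when a trial
# diverges, at the cost of recording a (bad) score for it. Assuming Optuna as
# the search backend, an alternative sketch is to prune failed trials so they
# are recorded as failed rather than scored; train_fn is a hypothetical
# stand-in for the training body above:
def _objective_with_pruning_sketch(trial, train_fn):
    import optuna

    try:
        return train_fn(trial)
    except RuntimeError as e:
        logger.info("Error during training, pruning trial\n %s", e)
        raise optuna.TrialPruned()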