] ########################################## ########## OPTIMIZE SOLUTION ############ ########################################## model = Problem(objectives, constraints, components).to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) visualizer = VisualizerOpen(dataset, dynamics_model, args.verbosity, args.savedir, training_visuals=args.train_visuals, trace_movie=args.trace_movie) # simulator = OpenLoopSimulator(model=model, dataset=dataset, eval_sim=not args.skip_eval_sim) simulator = MHOpenLoopSimulator(model=model, dataset=dataset, eval_sim=not args.skip_eval_sim) trainer = Trainer(model, dataset, optimizer, logger=logger, visualizer=visualizer, simulator=simulator, epochs=args.epochs, eval_metric=args.eval_metric, patience=args.patience, warmup=args.warmup) best_model = trainer.train() trainer.evaluate(best_model) logger.clean_up()
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

# Variables to be plotted by the closed-loop visualizer.
plot_keys = ["Y_pred", "U_pred"]
visualizer = VisualizerClosedLoop(
    dataset,
    policy,
    plot_keys,
    args.verbosity,
    savedir=args.savedir,
)

# HACK: rename the policy's first input key so it is compatible with the
# simulator. This must happen before the simulator is constructed.
policy.input_keys[0] = "Yp"
simulator = ClosedLoopSimulator(
    model=model,
    dataset=dataset,
    emulator=dynamics_model,
    policy=policy,
)

# Keyword options forwarded to the Trainer, gathered in one place.
trainer_options = dict(
    logger=logger,
    visualizer=visualizer,
    simulator=simulator,
    epochs=args.epochs,
    patience=args.patience,
    warmup=args.warmup,
)
trainer = Trainer(model, dataset, optimizer, **trainer_options)

# Train the control policy, evaluate the model returned by training, and
# turn the evaluation outputs into plots.
best_model = trainer.train()
best_outputs = trainer.evaluate(best_model)
plots = visualizer.eval(best_outputs)

# Hand the plots to the logger as artifacts, then let it finish up.
logger.log_artifacts(plots)
logger.clean_up()