# --- Disturbance-influence bound penalties (tail of a constraint block) ---
# NOTE(review): the next two lines are the trailing arguments of an
# Objective(...) call opened earlier in the file (the lower-bound penalty
# on the disturbance model's output); confirm against the preceding chunk.
# F.relu(-x + dxudmin) is a one-sided hinge: nonzero only where x < dxudmin.
lambda x: torch.mean(F.relu(-x + dxudmin)),
    weight=args.Q_con_fdu, name='dist_influence_lb')
# Upper-bound counterpart: penalize fD_dynamics outputs that exceed dxudmax.
disturbances_max_influence_ub = Objective(
    [f'fD_dynamics'],
    # hinge penalty: nonzero only where x > dxudmax
    lambda x: torch.mean(F.relu(x - dxudmax)),
    weight=args.Q_con_fdu,
    name='dist_influence_ub')
# Register both bound penalties with the problem's soft-constraint list.
constraints += [
    disturbances_max_influence_lb,
    disturbances_max_influence_ub
]
##########################################
########## OPTIMIZE SOLUTION ############
##########################################
# Assemble the constrained learning problem from objectives, penalties, and
# the component pipeline, then move it to the target device.
model = Problem(objectives, constraints, components).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
# Open-loop visualizer; CLI flags control per-epoch visuals and movie export.
visualizer = VisualizerOpen(dataset, dynamics_model, args.verbosity, args.savedir,
                            training_visuals=args.train_visuals,
                            trace_movie=args.trace_movie)
# simulator = OpenLoopSimulator(model=model, dataset=dataset, eval_sim=not args.skip_eval_sim)
# NOTE(review): the MH variant replaces the plain OpenLoopSimulator kept
# commented out above — confirm this swap is intentional.
simulator = MHOpenLoopSimulator(model=model, dataset=dataset, eval_sim=not args.skip_eval_sim)
# Trainer construction continues past this chunk (call intentionally left
# open here; remaining keyword arguments follow in the next part of the file).
trainer = Trainer(model, dataset, optimizer, logger=logger,
# NOTE(review): the lines down to the lone ')' are the trailing arguments of
# a reference-signal generator constructor opened earlier in the file
# (periodic target signal named "Y_ctrl_" sized by the model's output dim);
# confirm against the preceding chunk.
args.nsteps,
    dynamics_model.fy.out_features,
    xmax=(0.8, 0.7),
    xmin=0.2,
    min_period=1,
    max_period=20,
    name="Y_ctrl_",
)
# 5% relative noise injected into the dynamics model's predicted outputs.
# NOTE(review): noise_generator is constructed but not passed to Problem
# below — confirm it is consumed elsewhere or intentionally unused.
noise_generator = NoiseGenerator(
    ratio=0.05, keys=["Y_pred_dynamics"], name="_noise"
)
objectives, constraints = get_objective_terms(args, policy)
# Closed-loop problem pipeline: signal generator -> state estimator ->
# control policy -> system dynamics model.
model = Problem(
    objectives,
    constraints,
    [signal_generator, estimator, policy, dynamics_model],
)
model = model.to(device)
# train only policy component: freeze/unfreeze selected modules by name
# before building the optimizer.
freeze_weight(model, module_names=args.freeze)
unfreeze_weight(model, module_names=args.unfreeze)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
plot_keys = ["Y_pred", "U_pred"]  # variables to be plotted
visualizer = VisualizerClosedLoop(
    dataset, policy, plot_keys, args.verbosity, savedir=args.savedir
)
policy.input_keys[0] = "Yp"  # hack for policy input key compatibility w/ simulator
# observation_lower_bound_penalty = Objective(['Y_pred_dynamics_noise', 'Y_minf'], # lambda x, xmin: torch.mean(F.relu(-x[:, :, :1] + xmin)), # weight=args.Q_con_y, name='observation_lower_bound').to(device) # observation_upper_bound_penalty = Objective(['Y_pred_dynamics_noise', 'Y_maxf'], # lambda x, xmax: torch.mean(F.relu(x[:, :, :1] - xmax)), # weight=args.Q_con_y, name='observation_upper_bound').to(device) objectives = [regularization, reference_loss] constraints = [observation_lower_bound_penalty, observation_upper_bound_penalty, inputs_lower_bound_penalty, inputs_upper_bound_penalty] ########################################## ########## OPTIMIZE SOLUTION ############ ########################################## model = Problem(objectives, constraints, components).to(device) freeze_weight(model, module_names=args.freeze) unfreeze_weight(model, module_names=args.unfreeze) optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) plot_keys = ['Y_pred', 'U_pred', 'x0_estim'] # variables to be plotted visualizer = VisualizerClosedLoop(dataset, policy, plot_keys, args.verbosity, savedir=args.savedir) emulator = dynamics_model # TODO: hacky solution for policy input keys compatibility with simulator policy.input_keys[0] = 'Yp' simulator = ClosedLoopSimulator(model=model, dataset=dataset, emulator=emulator, policy=policy) trainer = Trainer(model, dataset, optimizer, logger=logger, visualizer=visualizer, simulator=simulator, epochs=args.epochs, patience=args.patience, warmup=args.warmup) best_model = trainer.train() trainer.evaluate(best_model) logger.log_metrics({'alive': 0.0})
# NOTE(review): tail of a multi-line `from ... import (` statement opened
# earlier in the file.
get_parser
)


if __name__ == "__main__":
    # Parse CLI arguments and echo only the flags that were actually set
    # (falsy/unset values are filtered out of the printout).
    args = get_parser().parse_args()
    print({k: str(getattr(args, k)) for k in vars(args) if getattr(args, k)})
    # Select GPU by index when provided, otherwise fall back to CPU.
    device = f"cuda:{args.gpu}" if args.gpu is not None else "cpu"
    logger = get_logger(args)
    dataset = load_dataset(args, device, "openloop")
    print(dataset.dims)
    # Build the estimator + dynamics model and their training terms.
    estimator, dynamics_model = get_model_components(args, dataset)
    objectives, constraints = get_objective_terms(args, dataset, estimator, dynamics_model)
    model = Problem(objectives, constraints, [estimator, dynamics_model])
    model = model.to(device)
    # Open-loop rollout used for evaluation unless explicitly skipped.
    simulator = OpenLoopSimulator(model=model, dataset=dataset, eval_sim=not args.skip_eval_sim)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    # NOTE(review): the script evidently continues past this chunk
    # (trainer is constructed but not yet used here).
    trainer = Trainer(
        model,
        dataset,
        optimizer,
        logger=logger,
        simulator=simulator,
        epochs=args.epochs,
        eval_metric=args.eval_metric,
        patience=args.patience,
        warmup=args.warmup,
    )