Example #1
            ]

    ##########################################
    ########### OPTIMIZE SOLUTION ############
    ##########################################
    # Assemble the constrained optimization problem and move it to the target device
    model = Problem(objectives, constraints, components).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    # Visualizer for open-loop trajectories; can also render training visuals and a trace movie
    visualizer = VisualizerOpen(dataset,
                                dynamics_model,
                                args.verbosity,
                                args.savedir,
                                training_visuals=args.train_visuals,
                                trace_movie=args.trace_movie)
    # Simulator used for open-loop evaluation; the plain OpenLoopSimulator
    # alternative is kept below for reference.
    # simulator = OpenLoopSimulator(model=model, dataset=dataset, eval_sim=not args.skip_eval_sim)
    simulator = MHOpenLoopSimulator(model=model,
                                    dataset=dataset,
                                    eval_sim=not args.skip_eval_sim)
    # Trainer drives the optimization loop with early stopping (patience) after a warmup period
    trainer = Trainer(model,
                      dataset,
                      optimizer,
                      logger=logger,
                      visualizer=visualizer,
                      simulator=simulator,
                      epochs=args.epochs,
                      eval_metric=args.eval_metric,
                      patience=args.patience,
                      warmup=args.warmup)
    # Train, evaluate the best model found, and close out the logger
    best_model = trainer.train()
    trainer.evaluate(best_model)
    logger.clean_up()
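
Both examples assume a pre-built `args` namespace, a `device`, and a `logger`. The snippet below is a minimal, hypothetical sketch of the command-line setup implied by the `args.` attributes used above; the flag names are inferred from the code and all defaults are illustrative placeholders, not the original script's values.

    import argparse
    import torch

    # Hypothetical CLI setup inferred from the `args.` attributes above;
    # defaults are placeholders only.
    parser = argparse.ArgumentParser()
    parser.add_argument("-lr", type=float, default=1e-3, help="optimizer learning rate")
    parser.add_argument("-epochs", type=int, default=1000)
    parser.add_argument("-patience", type=int, default=100, help="early-stopping patience")
    parser.add_argument("-warmup", type=int, default=10, help="epochs before early stopping kicks in")
    parser.add_argument("-eval_metric", type=str, default="dev_loss")
    parser.add_argument("-verbosity", type=int, default=1)
    parser.add_argument("-savedir", type=str, default="test")
    parser.add_argument("-train_visuals", action="store_true")
    parser.add_argument("-trace_movie", action="store_true")
    parser.add_argument("-skip_eval_sim", action="store_true")
    args = parser.parse_args()

    # Device selection used by Problem(...).to(device)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
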
Example #2
    # Optimizer over the control problem's trainable parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    plot_keys = ["Y_pred", "U_pred"]  # variables to be plotted
    # Closed-loop visualizer: plots predicted outputs (Y_pred) and control actions (U_pred)
    visualizer = VisualizerClosedLoop(
        dataset, policy, plot_keys, args.verbosity, savedir=args.savedir
    )

    policy.input_keys[0] = "Yp"  # workaround: rename the policy's first input key to match what the simulator expects
    # Closed-loop simulator rolls the policy out against the emulator dynamics
    simulator = ClosedLoopSimulator(
        model=model, dataset=dataset, emulator=dynamics_model, policy=policy
    )
    trainer = Trainer(
        model,
        dataset,
        optimizer,
        logger=logger,
        visualizer=visualizer,
        simulator=simulator,
        epochs=args.epochs,
        patience=args.patience,
        warmup=args.warmup,
    )

    # Train the control policy, then evaluate the best model and render plots
    best_model = trainer.train()
    best_outputs = trainer.evaluate(best_model)
    plots = visualizer.eval(best_outputs)

    # Log the evaluation plots as artifacts before shutting the logger down
    logger.log_artifacts(plots)
    logger.clean_up()
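
After training, the best weights can be persisted for later reuse. This is a minimal sketch, assuming `best_model` is a state-dict-like mapping returned by `trainer.train()` (as the `trainer.evaluate(best_model)` call suggests); the file name is illustrative.

    import os
    import torch

    # Hypothetical persistence step; assumes best_model is a state dict.
    os.makedirs(args.savedir, exist_ok=True)
    torch.save(best_model, os.path.join(args.savedir, "best_model.pth"))

    # Later, restore into a freshly constructed Problem instance:
    # model.load_state_dict(torch.load(os.path.join(args.savedir, "best_model.pth")))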