Example 1
if __name__ == "__main__":
    # Parse command line arguments
    args = get_argparser().parse_args()

    # Experiment (set seed before creating the modules)
    ex_dir = setup_experiment(
        QQubeSwingUpSim.name,
        f"{SimOpt.name}-{NES.name}-{PPO.name}_{FNNPolicy.name}")
    num_workers = 16

    # Set seed if desired
    pyrado.set_seed(args.seed, verbose=True)

    # Environments
    env_hparams = dict(dt=1 / 100.0, max_steps=600)
    env_real = QQubeSwingUpReal(**env_hparams)

    env_sim = QQubeSwingUpSim(**env_hparams)
    randomizer = DomainRandomizer(
        NormalDomainParam(name="mass_rot_pole",
                          mean=0.0,
                          std=1e6,
                          clip_lo=1e-3),
        NormalDomainParam(name="mass_pend_pole",
                          mean=0.0,
                          std=1e6,
                          clip_lo=1e-3),
        NormalDomainParam(name="length_rot_pole",
                          mean=0.0,
                          std=1e6,
                          clip_lo=1e-3),
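Not shown in the excerpt: to take effect, a randomizer like this is typically attached to the simulation environment through a live-randomizing wrapper. A minimal sketch, assuming Pyrado's wrapper module layout (import paths may differ between versions):

# Sketch (assumption): resample the domain parameters on every reset,
# so each rollout simulates a differently parametrized Qube.
from pyrado.environment_wrappers.domain_randomization import DomainRandWrapperLive

env_sim = DomainRandWrapperLive(env_sim, randomizer)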
Example 2
def default_qq_real():
    return QQubeSwingUpReal(dt=1 / 500.0, max_steps=500)
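A possible usage sketch for this factory function, assuming Pyrado's rollout sampler and its predefined swing-up controller (the import paths are assumptions and may differ between Pyrado versions):

# Sketch (assumption): one evaluation episode on the real Qube at 500 Hz.
from pyrado.policies.special.environment_specific import QQubeSwingUpAndBalanceCtrl
from pyrado.sampling.rollout import rollout

env = default_qq_real()
policy = QQubeSwingUpAndBalanceCtrl(env.spec)
ro = rollout(env, policy, eval=True)  # records the episode as a StepSequence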
Example 3
                  "c")

    elif args.env_name == QCartPoleStabReal.name:
        env = QCartPoleStabReal(args.dt, args.max_steps)
        policy = QCartPoleSwingUpAndBalanceCtrl(env.spec)
        print_cbt("Set up controller for the QCartPoleStabReal environment.",
                  "c")

    elif args.env_name == QCartPoleSwingUpReal.name:
        env = QCartPoleSwingUpReal(args.dt, args.max_steps)
        policy = QCartPoleSwingUpAndBalanceCtrl(env.spec)
        print_cbt(
            "Set up controller for the QCartPoleSwingUpReal environment.", "c")

    elif args.env_name == QQubeSwingUpReal.name:
        env = QQubeSwingUpReal(args.dt, args.max_steps)
        policy = QQubeSwingUpAndBalanceCtrl(env.spec)
        print_cbt("Set up controller for the QQubeSwingUpReal environment.",
                  "c")

    else:
        raise pyrado.ValueErr(
            given=args.env_name,
            eq_constraint=f"{QBallBalancerReal.name}, {QCartPoleSwingUpReal.name}, "
            f"{QCartPoleStabReal.name}, or {QQubeSwingUpReal.name}",
        )

    # Run on device
    done = False
    print_cbt("Running predefined controller ...", "c", bright=True)
Example 4
        policy.reset(**dict(domain_param=ml_domain_param))
    elif args.src_domain_param == "posterior":
        prefix_str = "" if args.iter == -1 and args.round == -1 else f"iter_{args.iter}_round_{args.round}"
        posterior = pyrado.load("posterior.pt", ex_dir, prefix=prefix_str)
    elif args.src_domain_param == "prior":
        prior = pyrado.load("prior.pt", ex_dir)
    elif args.src_domain_param == "nominal":
        policy.reset(**dict(domain_param=env_sim.get_nominal_domain_param()))

    # Detect the correct real-world counterpart and create it
    if isinstance(inner_env(env_sim), QBallBalancerSim):
        env_real = QBallBalancerReal(dt=args.dt, max_steps=args.max_steps)
    elif isinstance(inner_env(env_sim), QCartPoleSim):
        env_real = QCartPoleReal(dt=args.dt, max_steps=args.max_steps)
    elif isinstance(inner_env(env_sim), QQubeSim):
        env_real = QQubeSwingUpReal(dt=args.dt, max_steps=args.max_steps)
    else:
        raise pyrado.TypeErr(
            given=env_sim,
            expected_type=[QBallBalancerSim, QCartPoleSim, QQubeSim])

    # Wrap the real environment in the same way as done during training
    env_real = wrap_like_other_env(env_real, env_sim)

    # Run on device
    done, first_round = False, True
    print_cbt("Running loaded policy ...", "c", bright=True)
    while not done:
        # Sample a new domain parameter
        if (args.resample or first_round) and args.src_domain_param in [
                "posterior", "prior"
Example 5
    # Get the experiment's directory to load from, if it was not given as a command line argument
    ex_dir = ask_for_experiment(
        hparam_list=args.show_hparams) if args.dir is None else args.dir

    # Load the policy and the environment (for constructing the real-world counterpart)
    env_sim, policy, _ = load_experiment(ex_dir, args)
    if "argmax" in args.policy_name:
        policy = to.load(osp.join(ex_dir, "policy_argmax.pt"))
        print_cbt(f"Loaded {osp.join(ex_dir, 'policy_argmax.pt')}",
                  "g",
                  bright=True)

    # Create real-world counterpart
    max_steps = args.max_steps if args.max_steps < pyrado.inf else env_sim.max_steps
    dt = args.dt if args.dt is not None else env_sim.dt
    env_real = QQubeSwingUpReal(dt, max_steps)
    print_cbt(
        f"Set up the QQubeSwingUpReal environment with dt={env_real.dt} max_steps={env_real.max_steps}.",
        "c")

    # Finally, wrap the env in the same way as done during training
    env_real = wrap_like_other_env(env_real, env_sim)

    ex_ts = datetime.now().strftime(pyrado.timestamp_format)
    save_dir = osp.join(ex_dir, "evaluation")
    os.makedirs(save_dir, exist_ok=True)
    num_rollouts_per_config = args.num_rollouts_per_config if args.num_rollouts_per_config is not None else 5
    # The trailing arguments are inferred from the variables prepared above
    est_ret = BayRn.eval_policy(save_dir,
                                env_real,
                                policy,
                                mc_estimator=True,
                                prefix=ex_ts,
                                num_rollouts=num_rollouts_per_config)
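A possible follow-up once the call returns, assuming eval_policy yields the Monte Carlo estimate of the real-world return and that pyrado.save accepts (object, name, directory); both are assumptions about the Pyrado API:

# Sketch (assumption): report and persist the estimated return.
print_cbt(f"Estimated real-world return: {est_ret}", "g", bright=True)
pyrado.save(est_ret, "est_ret.pt", save_dir)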