Example #1
# Imports assumed for this excerpt (module paths as used in pyrado)
import os.path as osp
import shutil

import numpy as np
import torch as to

from pyrado import TEMP_DIR
from pyrado.logger.experiment import load_dict_from_yaml, save_dicts_to_yaml, setup_experiment


def test_save_and_load_yaml():
    ex_dir = setup_experiment("testenv",
                              "testalgo",
                              "testinfo",
                              base_dir=TEMP_DIR)

    # Save test data to a YAML file
    save_dicts_to_yaml(
        dict(a=1),
        dict(b=2.0),
        dict(c=np.array([1.0, 2.0])),
        dict(d=to.tensor([3.0, 4.0])),
        dict(e="string"),
        dict(f=[5, "f"]),
        dict(g=(6, "g")),
        save_dir=ex_dir,
        file_name="testfile",
    )

    data = load_dict_from_yaml(osp.join(ex_dir, "testfile.yaml"))
    assert isinstance(data, dict)
    assert data["a"] == 1
    assert data["b"] == 2
    assert data["c"] == [1.0, 2.0]  # now a list
    assert data["d"] == [3.0, 4.0]  # now a list
    assert data["e"] == "string"
    assert data["f"] == [5, "f"]
    assert data["g"] == (6, "g")

    # Delete the created folder recursively
    shutil.rmtree(osp.join(TEMP_DIR, "testenv"),
                  ignore_errors=True)  # ignore errors, e.g. caused by read-only files
Example #2
    policy = ADNPolicy(spec=EnvSpec(act_space=InfBoxSpace(shape=1),
                                    obs_space=InfBoxSpace(shape=1)),
                       **policy_hparam)

    # Algorithm
    algo_hparam = dict(
        max_iter=1000,
        windowed=False,
        cascaded=True,
        optim_class=optim.Adam,
        optim_hparam=dict(lr=1e-1, eps=1e-8,
                          weight_decay=1e-4),  # momentum=0.7
        loss_fcn=nn.MSELoss(),
        lr_scheduler=lr_scheduler.ExponentialLR,
        lr_scheduler_hparam=dict(gamma=0.995),
    )
    algo = TSPred(ex_dir, dataset, policy, **algo_hparam)

    # Save the hyper-parameters
    save_dicts_to_yaml(
        dict(data_set=data_set_hparam,
             data_set_name=data_set_name,
             seed=args.seed),
        dict(policy=policy_hparam),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    # Jeeeha
    algo.train()
Example #3
        acq_fc="EI",
        # acq_param=dict(beta=0.2),
        acq_restarts=500,
        acq_samples=1000,
        num_init_cand=10,
        warmstart=False,
        # policy_param_init=policy_init.param_values.data,
        # valuefcn_param_init=valuefcn_init.param_values.data,
    )

    # Save the environments and the hyper-parameters
    save_dicts_to_yaml(
        dict(env_sim=env_sim_hparams,
             env_real=env_real_hparams,
             seed=args.seed),
        dict(policy=policy_hparam),
        dict(critic=critic_hparam, vfcn=vfcn_hparam),
        dict(subrtn=subrtn_hparam, subrtn_name=PPO.name),
        dict(algo=bayrn_hparam, algo_name=BayRn.name, dp_map=dp_map),
        save_dir=ex_dir,
    )

    algo = BayRn(ex_dir,
                 env_sim,
                 env_real,
                 subrtn,
                 ddp_space=ddp_space,
                 **bayrn_hparam)

    # Jeeeha
    algo.train(
        snapshot_mode="latest",
Example #4
            show_train_summary=False,  # default: False
            # max_num_epochs=5,  # only use for debugging
        ),
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=False),
        num_workers=20,
    )
    algo = NPDR(
        ex_dir,
        env_sim,
        env_real,
        policy,
        dp_mapping,
        prior,
        embedding,
        subrtn_sbi_class=SNPE_C,
        **algo_hparam,
    )

    # Save the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_sim_hparams, seed=args.seed),
        dict(prior=prior_hparam),
        dict(posterior_nn=posterior_hparam),
        dict(embedding=embedding_hparam, embedding_name=embedding.name),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    # Jeeeha
    algo.train(snapshot_mode="latest", seed=args.seed)
Example #5
    args = get_argparser().parse_args()

    if args.dir is None:
        ex_dir = setup_experiment("hyperparams", TSPred.name, f"{TSPred.name}_{ADNPolicy.name}")
        study_dir = osp.join(pyrado.TEMP_DIR, ex_dir)
        print_cbt(f"Starting a new Optuna study.", "c", bright=True)
    else:
        study_dir = args.dir
        if not osp.isdir(study_dir):
            raise pyrado.PathErr(given=study_dir)
        print_cbt(f"Continuing an existing Optuna study.", "c", bright=True)

    name = f"{TSPred.name}_{TSPred.name}_{ADNPolicy.name}"
    study = optuna.create_study(
        study_name=name,
        storage=f"sqlite:////{osp.join(study_dir, f'{name}.db')}",
        direction="maximize",
        load_if_exists=True,
    )

    # Start optimizing
    study.optimize(functools.partial(train_and_eval, study_dir=study_dir, seed=args.seed), n_trials=100, n_jobs=16)

    # Save the best hyper-parameters
    save_dicts_to_yaml(
        study.best_params,
        dict(seed=args.seed),
        save_dir=study_dir,
        file_name="best_hyperparams",
    )
Example #6
algo_hparam = dict(
    max_iter=8,
    pop_size=20,
    num_init_states_per_domain=10,
    expl_factor=1.1,
    expl_std_init=1.0,
    num_workers=4,
)
algo = HCNormal(ex_dir, env, policy, **algo_hparam)
"""
Save the hyper-parameters to a YAML file before starting the training. This step is not strictly necessary, but it
helps you to later see which hyper-parameters you used, i.e. which settings led to a successfully trained policy.
A sketch of reading the file back follows the call below.
"""
save_dicts_to_yaml(
    dict(env=env_hparams, seed=0),
    dict(policy=policy_hparam),
    dict(algo=algo_hparam, algo_name=algo.name),
    save_dir=ex_dir,
)
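"""
As noted above, the saved YAML can be read back later to inspect the hyper-parameters. Below is only a minimal sketch:
the import path `pyrado.logger.experiment` for `load_dict_from_yaml` (used the same way in Example #1) and the default
file name "hyperparams.yaml" written by `save_dicts_to_yaml` are assumptions; adjust them if your version differs.
"""
import os.path as osp

from pyrado.logger.experiment import load_dict_from_yaml

# All dicts passed to save_dicts_to_yaml above end up in one YAML document, so loading yields a single merged dict
loaded_hparams = load_dict_from_yaml(osp.join(ex_dir, "hyperparams.yaml"))  # file name is an assumption, see above
print(loaded_hparams["algo"])  # e.g. the algorithm's hyper-parameters, stored under the key "algo"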
"""
Finally, start the training. The `train()` function is the same for all algorithms inheriting from the `Algorithm`
base class. It repeatedly calls the algorithm's custom `step()` and `update()` functions.
You can load and continue a previous experiment using the `Algorithm`'s `load()` method. The `snapshot_mode` argument
of `train()` determines when to save the current training state, e.g. 'latest' saves after every step of the
algorithm, and 'best' only saves if the average return is a new high score.
Moreover, you can set the random number generator's seed. Setting the seed here comes in handy when you want to
continue from a previous experiment multiple times. A commented-out variant follows the call below.
"""
algo.train(snapshot_mode="latest", seed=None)
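"""
A variant of the call above, as described in the preceding comment block: 'best' snapshot mode together with a fixed
seed, so the training state is only written when the average return reaches a new high score and the run stays
reproducible. The seed value 0 is an arbitrary choice. It is commented out because `train()` was already called above;
use it instead of that call if you prefer this behavior.
"""
# algo.train(snapshot_mode="best", seed=0)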

input("\nFinished training. Hit enter to simulate the policy.\n")
"""
Example #7
        lr=1e-3,
        standardize_adv=False,
        max_grad_norm=5.0,
    )
    particle_hparam = dict(actor=actor_hparam,
                           vfcn=vfcn_hparam,
                           critic=critic_hparam)

    # Algorithm
    algo_hparam = dict(
        max_iter=200,
        min_steps=30 * env.max_steps,
        num_particles=3,
        temperature=1,
        lr=1e-3,
        std_init=1.0,
        horizon=50,
        num_workers=12,
    )
    algo = SVPG(ex_dir, env, particle_hparam, **algo_hparam)

    # Save the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_hparams, seed=args.seed),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    # Jeeeha
    algo.train(seed=args.seed)
Example #8
    # Set the hyper-parameters of SysIdViaEpisodicRL
    num_eval_rollouts = 5
    algo_hparam = dict(
        metric=None,
        std_obs_filt=5,
        obs_dim_weight=[1, 1, 1, 1, 10, 10],
        num_rollouts_per_distr=len(dp_map) * 10,  # former 50
        num_workers=subrtn_hparam["num_workers"],
    )

    # Save the environments and the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_hparams),
        dict(subrtn=subrtn_hparam, subrtn_name=subrtn.name),
        dict(algo=algo_hparam,
             algo_name=SysIdViaEpisodicRL.name,
             dp_map=dp_map),
        save_dir=ex_dir,
    )

    algo = SysIdViaEpisodicRL(subrtn, behavior_policy, **algo_hparam)

    # Jeeeha
    while algo.curr_iter < algo.max_iter and not algo.stopping_criterion_met():
        algo.logger.add_value(algo.iteration_key, algo.curr_iter)

        # Create fake real-world data
        ro_real = []
        for _ in range(num_eval_rollouts):
            ro_real.append(rollout(env_real, behavior_policy, eval=True))
Example #9
            stop_after_epochs=20,  # default: 20
            retrain_from_scratch_each_round=False,  # default: False
            show_train_summary=False,  # default: False
            # max_num_epochs=5,  # only use for debugging
        ),
        num_workers=20,
    )
    algo = BayesSim(
        save_dir=ex_dir,
        env_sim=env_sim,
        env_real=env_real,
        policy=policy,
        dp_mapping=dp_mapping,
        prior=prior,
        embedding=embedding,
        **algo_hparam,
    )

    # Save the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_sim_hparams, seed=args.seed),
        dict(dp_mapping=dp_mapping),
        dict(policy=policy_hparam, policy_name=policy.name),
        dict(prior=prior_hparam),
        dict(embedding=embedding_hparam, embedding_name=embedding.name),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    algo.train(seed=args.seed)
Example #10
        f"\nVolt disturbance: {6} volts for {steps_disturb} steps",
        "c",
    )

    # Center the cart, reset the velocity filters, and wait until the user or the controller has put the pole upright
    env_real.reset()
    print_cbt("Ready", "g")

    ros = []
    for r in range(args.num_runs):
        if args.mode == "wo":
            ro = experiment_wo_distruber(env_real, env_sim)
        elif args.mode == "w":
            ro = experiment_w_distruber(env_real, env_sim)
        else:
            raise pyrado.ValueErr(given=args.mode, eq_constraint="without (wo), or with (w) disturber")
        ros.append(ro)

    env_real.close()

    # Print and save results
    avg_return = np.mean([ro.undiscounted_return() for ro in ros])
    print_cbt(f"Average return: {avg_return}", "g", bright=True)
    save_dir = setup_experiment("evaluation", "qcp-st_experiment", ex_tag, base_dir=pyrado.TEMP_DIR)
    joblib.dump(ros, osp.join(save_dir, "experiment_rollouts.pkl"))
    save_dicts_to_yaml(
        dict(ex_dir=ex_dir, avg_return=avg_return, num_runs=len(ros), steps_disturb=steps_disturb),
        save_dir=save_dir,
        file_name="experiment_summary",
    )
Example #11
            rollout(env_real,
                    policy,
                    eval=True,
                    max_steps=args.max_steps,
                    render_mode=RenderMode()))

    # Print and save results
    avg_return = np.mean([ro.undiscounted_return() for ro in ros])
    print_cbt(f"Average return: {avg_return}", "g", bright=True)
    save_dir = setup_experiment("evaluation",
                                "qbb_experiment",
                                ex_tag,
                                base_dir=pyrado.TEMP_DIR)
    joblib.dump(ros, osp.join(save_dir, "experiment_rollouts.pkl"))
    save_dicts_to_yaml(
        dict(ex_dir=ex_dir,
             init_state=init_state,
             avg_return=avg_return,
             num_runs=len(ros)),
        save_dir=save_dir,
        file_name="experiment_summary",
    )

    # Stabilize at the end
    pdctrl.reset(state_des=np.zeros(2))
    rollout(env_real,
            pdctrl,
            eval=True,
            max_steps=1000,
            render_mode=RenderMode(text=True))
Example #12
def evaluate_policy(args, ex_dir):
    """Helper function to evaluate the policy from an experiment in the associated environment."""
    env, policy, _ = load_experiment(ex_dir, args)

    # Create multi-dim evaluation grid
    param_spec = dict()
    param_spec_dim = None

    if isinstance(inner_env(env), BallOnPlateSim):
        param_spec["ball_radius"] = np.linspace(0.02, 0.08, num=2, endpoint=True)
        param_spec["ball_rolling_friction_coefficient"] = np.linspace(0.0295, 0.9, num=2, endpoint=True)

    elif isinstance(inner_env(env), QQubeSwingUpSim):
        eval_num = 200
        # Use nominal values for all other parameters.
        for param, nominal_value in env.get_nominal_domain_param().items():
            param_spec[param] = nominal_value
        # param_spec["gravity_const"] = np.linspace(5.0, 15.0, num=eval_num, endpoint=True)
        param_spec["damping_pend_pole"] = np.linspace(0.0, 0.0001, num=eval_num, endpoint=True)
        param_spec["damping_rot_pole"] = np.linspace(0.0, 0.0006, num=eval_num, endpoint=True)
        param_spec_dim = 2

    elif isinstance(inner_env(env), QBallBalancerSim):
        # param_spec["gravity_const"] = np.linspace(7.91, 11.91, num=11, endpoint=True)
        # param_spec["ball_mass"] = np.linspace(0.003, 0.3, num=11, endpoint=True)
        # param_spec["ball_radius"] = np.linspace(0.01, 0.1, num=11, endpoint=True)
        param_spec["plate_length"] = np.linspace(0.275, 0.275, num=11, endpoint=True)
        param_spec["arm_radius"] = np.linspace(0.0254, 0.0254, num=11, endpoint=True)
        # param_spec["load_inertia"] = np.linspace(5.2822e-5*0.5, 5.2822e-5*1.5, num=11, endpoint=True)
        # param_spec["motor_inertia"] = np.linspace(4.6063e-7*0.5, 4.6063e-7*1.5, num=11, endpoint=True)
        # param_spec["gear_ratio"] = np.linspace(60, 80, num=11, endpoint=True)
        # param_spec["gear_efficiency"] = np.linspace(0.6, 1.0, num=11, endpoint=True)
        # param_spec["motor_efficiency"] = np.linspace(0.49, 0.89, num=11, endpoint=True)
        # param_spec["motor_back_emf"] = np.linspace(0.006, 0.066, num=11, endpoint=True)
        # param_spec["motor_resistance"] = np.linspace(2.6*0.5, 2.6*1.5, num=11, endpoint=True)
        # param_spec["combined_damping"] = np.linspace(0.0, 0.05, num=11, endpoint=True)
        # param_spec["friction_coeff"] = np.linspace(0, 0.015, num=11, endpoint=True)
        # param_spec["voltage_thold_x_pos"] = np.linspace(0.0, 1.0, num=11, endpoint=True)
        # param_spec["voltage_thold_x_neg"] = np.linspace(-1., 0.0, num=11, endpoint=True)
        # param_spec["voltage_thold_y_pos"] = np.linspace(0.0, 1.0, num=11, endpoint=True)
        # param_spec["voltage_thold_y_neg"] = np.linspace(-1.0, 0, num=11, endpoint=True)
        # param_spec["offset_th_x"] = np.linspace(-5/180*np.pi, 5/180*np.pi, num=11, endpoint=True)
        # param_spec["offset_th_y"] = np.linspace(-5/180*np.pi, 5/180*np.pi, num=11, endpoint=True)

    else:
        raise NotImplementedError

    # Always add an action delay wrapper (with 0 delay by default)
    if typed_env(env, ActDelayWrapper) is None:
        env = ActDelayWrapper(env)
    # param_spec['act_delay'] = np.linspace(0, 30, num=11, endpoint=True, dtype=int)

    add_info = "-".join(param_spec.keys())

    # Create multidimensional results grid and ensure right number of rollouts
    param_list = param_grid(param_spec)
    param_list *= args.num_rollouts_per_config

    # Fix initial state (set to None if it should not be fixed)
    init_state = np.array([0.0, 0.0, 0.0, 0.0])

    # Create sampler
    pool = SamplerPool(args.num_workers)
    if args.seed is not None:
        pool.set_seed(args.seed)
        print_cbt(f"Set the random number generators' seed to {args.seed}.", "w")
    else:
        print_cbt("No seed was set", "y")

    # Sample rollouts
    ros = eval_domain_params(pool, env, policy, param_list, init_state)

    # Compute metrics
    lod = []
    for ro in ros:
        d = dict(**ro.rollout_info["domain_param"], ret=ro.undiscounted_return(), len=ro.length)
        # Simply remove the observation noise from the domain parameters
        try:
            d.pop("obs_noise_mean")
            d.pop("obs_noise_std")
        except KeyError:
            pass
        lod.append(d)

    df = pd.DataFrame(lod)
    metrics = dict(
        avg_len=df["len"].mean(),
        avg_ret=df["ret"].mean(),
        median_ret=df["ret"].median(),
        min_ret=df["ret"].min(),
        max_ret=df["ret"].max(),
        std_ret=df["ret"].std(),
    )
    pprint(metrics, indent=4)

    # Create subfolder and save
    timestamp = datetime.datetime.now()
    add_info = timestamp.strftime(pyrado.timestamp_format) + "--" + add_info
    save_dir = osp.join(ex_dir, "eval_domain_grid", add_info)
    os.makedirs(save_dir, exist_ok=True)

    save_dicts_to_yaml(
        {"ex_dir": str(ex_dir)},
        {"varied_params": list(param_spec.keys())},
        {"num_rpp": args.num_rollouts_per_config, "seed": args.seed},
        {"metrics": dict_arraylike_to_float(metrics)},
        save_dir=save_dir,
        file_name="summary",
    )
    pyrado.save(df, f"df_sp_grid_{len(param_spec) if param_spec_dim is None else param_spec_dim}d.pkl", save_dir)
Example #13
            # max_num_epochs=5,  # only use for debugging
        ),
        subrtn_policy_snapshot_mode="best",
        train_initial_policy=True,
        num_workers=args.num_workers,
    )
    algo = BayesSim(
        save_dir=ex_dir,
        env_sim=env_sim,
        env_real=env_real,
        policy=policy,
        dp_mapping=dp_mapping,
        prior=prior,
        embedding=embedding,
        subrtn_policy=subrtn_policy,
        **algo_hparam,
    )

    # Save the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_sim_hparams, seed=args.seed),
        dict(prior=prior_hparam),
        dict(embedding=embedding_hparam, embedding_name=embedding.name),
        dict(subrtn_policy=subrtn_policy_hparam,
             subrtn_policy_name=subrtn_policy.name),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    algo.train(seed=args.seed)
Example #14
            dict(policy=ex_labels[i], ret=rets, len=lengths)),
                       ignore_index=True)

    metrics = dict(
        avg_len=df.groupby("policy").mean()["len"].to_dict(),
        avg_ret=df.groupby("policy").mean()["ret"].to_dict(),
        median_ret=df.groupby("policy").median()["ret"].to_dict(),
        min_ret=df.groupby("policy").min()["ret"].to_dict(),
        max_ret=df.groupby("policy").max()["ret"].to_dict(),
        std_ret=df.groupby("policy").std()["ret"].to_dict(),
    )
    pprint(metrics, indent=4)

    # Create sub-folder and save
    save_dir = setup_experiment("multiple_policies",
                                args.env_name,
                                "nominal",
                                base_dir=pyrado.EVAL_DIR)

    save_dicts_to_yaml(
        {"ex_dirs": ex_dirs},
        {
            "num_rpp": args.num_rollouts_per_config,
            "seed": args.seed
        },
        {"metrics": dict_arraylike_to_float(metrics)},
        save_dir=save_dir,
        file_name="summary",
    )
    df.to_pickle(osp.join(save_dir, "df_nom_mp.pkl"))
Example #15
                                      **subrtn_distr_hparam)

    # Algorithm
    algo_hparam = dict(
        max_iter=15,
        num_eval_rollouts=5,
        warmstart=True,
        thold_succ_subrtn=100,
        subrtn_snapshot_mode="latest",
    )
    algo = SimOpt(ex_dir, env_sim, env_real, subrtn_policy, subrtn_distr,
                  **algo_hparam)

    # Save the environments and the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_hparams, seed=args.seed),
        dict(behav_policy=behav_policy_hparam),
        dict(ddp_policy=ddp_policy_hparam, subrtn_distr_name=ddp_policy.name),
        dict(subrtn_distr=subrtn_distr_hparam,
             subrtn_distr_name=subrtn_distr.name),
        dict(subsubrtn_distr=subsubrtn_distr_hparam,
             subsubrtn_distr_name=subsubrtn_distr.name),
        dict(subrtn_policy=subrtn_policy_hparam,
             subrtn_policy_name=subrtn_policy.name),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    # Jeeeha
    algo.train(seed=args.seed)
Example #16
    # Policy
    policy = to.load(osp.join(ref_ex_dir, "policy.pt"))
    policy.init_param()

    # Critic
    vfcn = to.load(osp.join(ref_ex_dir, "valuefcn.pt"))
    vfcn.init_param()
    critic = GAE(vfcn, **hparams["critic"])

    # Algorithm
    algo_hparam = hparams["subrtn"]
    algo_hparam.update({"num_workers": 1})  # should be equivalent to the number of cores per job
    # algo_hparam.update({'max_iter': 300})
    # algo_hparam.update({'max_iter': 600})
    # algo_hparam.update({'min_steps': 3*algo_hparam['min_steps']})
    algo = PPO(ex_dir, env, policy, critic, **algo_hparam)

    # Save the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_hparams, seed=args.seed),
        dict(policy=hparams["policy"]),
        dict(critic=hparams["critic"]),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    # Jeeeha
    algo.train(seed=args.seed, snapshot_mode="latest")
Example #17
    )
    critic = GAE(vfcn, **critic_hparam)

    # Subroutine
    algo_hparam = dict(
        max_iter=200 if policy.name == FNNPolicy.name else 75,
        eps_clip=0.12648736789309026,
        min_steps=30 * env.max_steps,
        num_epoch=7,
        batch_size=500,
        std_init=0.7573286998997557,
        lr=6.999956625305722e-04,
        max_grad_norm=1.0,
        num_workers=8,
        lr_scheduler=lr_scheduler.ExponentialLR,
        lr_scheduler_hparam=dict(gamma=0.999),
    )
    algo = PPO(ex_dir, env, policy, critic, **algo_hparam)

    # Save the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_hparams, seed=args.seed),
        dict(policy=policy_hparam),
        dict(critic=critic_hparam, vfcn=vfcn_hparam),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    # Jeeeha
    algo.train(snapshot_mode="latest", seed=args.seed)
Example #18
        alpha=0.05,
        beta=0.1,
        nG=20,
        nJ=120,
        ntau=5,
        nc_init=5,
        nr_init=1,
        sequence_cand=sequence_add_init,
        sequence_refs=sequence_const,
        warmstart_cand=True,
        warmstart_refs=True,
        cand_policy_param_init=init_policy_param_values,
        num_bs_reps=1000,
        studentized_ci=False,
    )
    algo = SPOTA(ex_dir, env, sr_cand, sr_refs, **algo_hparam)

    # Save the environments and the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_hparams, seed=args.seed),
        dict(policy=policy_hparam),
        dict(subrtn_name=sr_cand.name,
             subrtn_cand=subrtn_hparam_cand,
             subrtn_refs=subrtn_hparam_refs),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    # Jeeeha
    algo.train(seed=args.seed)
Example #19
    env = DomainRandWrapperLive(
        env,
        randomizer=DomainRandomizer(
            *[SelfPacedDomainParam(**p) for p in env_sprl_params]))

    sprl_hparam = dict(
        kl_constraints_ub=8000,
        performance_lower_bound=500,
        std_lower_bound=0.4,
        kl_threshold=200,
        max_iter=args.sprl_iterations,
        optimize_mean=not args.cov_only,
    )
    algo = SPRL(env, PPO(ex_dir, env, policy, critic, **algo_hparam),
                **sprl_hparam)

    # Save the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_hparams, seed=args.seed),
        dict(policy=policy_hparam),
        dict(critic=critic_hparam, vfcn=vfcn_hparam),
        dict(subrtn=algo_hparam, subrtn_name=PPO.name),
        dict(algo=sprl_hparam,
             algo_name=algo.name,
             env_sprl_params=env_sprl_params),
        save_dir=ex_dir,
    )

    # Jeeeha
    algo.train(snapshot_mode="latest", seed=args.seed)
Example #20
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=False),
        num_workers=args.num_workers,
    )
    algo = NPDR(
        ex_dir,
        env_sim,
        env_real,
        policy,
        dp_mapping,
        prior,
        embedding,
        subrtn_sbi_class=SNPE_C,
        subrtn_policy=subrtn_policy,
        **algo_hparam,
    )

    # Save the hyper-parameters
    save_dicts_to_yaml(
        dict(env=env_hparams, seed=args.seed),
        dict(prior=prior_hparam),
        dict(posterior_nn=posterior_hparam),
        dict(policy=policy_hparam),
        dict(subrtn_policy=subrtn_policy_hparam,
             subrtn_policy_name=subrtn_policy.name),
        dict(algo=algo_hparam, algo_name=algo.name),
        save_dir=ex_dir,
    )

    # Jeeeha
    algo.train(seed=args.seed)