Example #1
0
    #     feats=FeatureStack(identity_feat, sign_feat, abs_feat, squared_feat,
    #                        MultFeat((2, 5)), MultFeat((3, 5)), MultFeat((4, 5)))
    # )
    # policy = LinearPolicy(spec=env.spec, **policy_hparam)
    # Policy: a QQubeSwingUpAndBalanceCtrl controller with hand-set gains
    policy_hparam = {"energy_gain": 0.587, "ref_energy": 0.827}
    policy = QQubeSwingUpAndBalanceCtrl(env.spec, **policy_hparam)

    # Algorithm: PoWER, configured with a fixed set of hyper-parameters
    algo_hparam = {
        # iteration / population sizes
        "max_iter": 5,
        "pop_size": 50,
        # rollouts: initial states per domain, number of domains, IS samples
        "num_init_states_per_domain": 4,
        "num_domains": 20,
        "num_is_samples": 10,
        # exploration standard deviation settings
        "expl_std_init": 2.0,
        "expl_std_min": 0.02,
        # parallelism
        "num_workers": 20,
    }
    algo = PoWER(ex_dir, env, policy, **algo_hparam)

    # Persist the env / policy / algorithm hyper-parameters (and the seed) to ex_dir
    save_dicts_to_yaml(
        {"env": env_hparams, "seed": args.seed},
        {"policy": policy_hparam},
        {"algo": algo_hparam, "algo_name": algo.name},
        save_dir=ex_dir,
    )

    # Train with the command-line seed (snapshot_mode="best")
    algo.train(seed=args.seed, snapshot_mode="best")
Example #2
0
    # Posterior (normalizing flow)
    # These settings are handed to the outer algorithm via `posterior_hparam=` below.
    # model="maf" presumably selects a masked autoregressive flow — TODO confirm.
    posterior_hparam = dict(model="maf", hidden_features=50, num_transforms=10)

    # Policy optimization subroutine
    # PoWER optimizes `policy` on the simulated environment `env_sim`. Its
    # `num_domains` is tied to `num_eval_samples`, the same value the outer
    # algorithm receives below; the worker count comes from `args.num_workers`.
    subrtn_policy_hparam = dict(
        max_iter=5,
        pop_size=50,
        num_init_states_per_domain=4,
        num_domains=num_eval_samples,
        num_is_samples=10,
        expl_std_init=2.0,
        expl_std_min=0.02,
        symm_sampling=False,
        num_workers=args.num_workers,
    )
    subrtn_policy = PoWER(ex_dir, env_sim, policy, **subrtn_policy_hparam)

    # Algorithm
    # Hyper-parameters of the outer algorithm; its constructor call is not part
    # of this snippet.
    algo_hparam = dict(
        max_iter=5,
        num_real_rollouts=num_real_rollouts,
        num_sim_per_round=5000,
        num_sbi_rounds=3,
        simulation_batch_size=10,
        normalize_posterior=False,
        num_eval_samples=num_eval_samples,
        num_segments=args.num_segments,
        len_segments=args.len_segments,
        stop_on_done=False,
        use_rec_act=True,
        posterior_hparam=posterior_hparam,
        # NOTE(review): the snippet appears to be spliced here. The remainder of
        # this `algo_hparam` dict is missing. The lines below look like the tail of
        # a different script's domain-randomizer argument list
        # (`...DomainParam(..., clip_lo=0)`, `UniformDomainParam(...)`), presumably
        # the `randomizer` used by `DomainRandWrapperLive` right after. As written,
        # the stray `)` on the next line does not parse. TODO restore both
        # scripts from their original source.
                           clip_lo=0),
        UniformDomainParam(name='joint_damping',
                           mean=9.4057e-03,
                           halfspan=5.0000e-04,
                           clip_lo=1e-6),
    )
    # Environment: wrap it with the domain randomizer
    env = DomainRandWrapperLive(env, randomizer)

    # Policy: rebuilt from the stored hyper-parameters, with the RBF 'scale' forced to None,
    # then initialized with the stored parameter values
    policy_hparam = hparams['policy']
    policy_hparam['rbf_hparam']['scale'] = None
    policy = DualRBFLinearPolicy(env.spec, **policy_hparam)
    policy_param_init = hparams['algo']['policy_param_init']
    policy.param_values = to.tensor(policy_param_init)

    # Algorithm: the stored subroutine hyper-parameters, overriding the worker count.
    # The 8 workers should be equivalent to the number of cores per job.
    algo_hparam = hparams['subroutine']
    algo_hparam['num_workers'] = 8
    algo = PoWER(ex_dir, env, policy, **algo_hparam)

    # Persist env / policy / algorithm hyper-parameters (and the experiment seed) to ex_dir
    hparam_dicts = [
        {'env': env_hparams, 'seed': ex_dir.seed},
        {'policy': policy_hparam},
        {'algo': algo_hparam, 'algo_name': algo.name},
    ]
    save_list_of_dicts_to_yaml(hparam_dicts, ex_dir)

    # Train with the experiment's seed (snapshot_mode='latest')
    algo.train(seed=ex_dir.seed, snapshot_mode='latest')