Code example #1
    def test_seed(self):
        # Single-env, single-step serial sampler on a continuous-control task.
        sampler = SerialSampler(
            EnvCls=gym_make,
            env_kwargs={"id": "MountainCarContinuous-v0"},
            batch_T=1,
            batch_B=1,
        )

        # Learning would only begin after 10000 iterations, so the agent keeps
        # its pre-training behavior for every batch collected in this test.
        agent = SacAgent(pretrain_std=0.0)
        agent.give_min_itr_learn(10000)

        # Two initializations with the same seed must yield identical batches.
        set_seed(0)
        sampler.initialize(agent, seed=0)
        samples_1 = sampler.obtain_samples(0)

        set_seed(0)
        sampler.initialize(agent, seed=0)
        samples_2 = sampler.obtain_samples(0)

        # Dirty hack to compare objects containing tensors.
        self.assertEqual(str(samples_1), str(samples_2))

        # A further batch without re-seeding should differ from the first one.
        samples_3 = sampler.obtain_samples(0)
        self.assertNotEqual(samples_1, samples_3)
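
The test above is a method of a unittest.TestCase subclass and omits its imports; assuming the standard rlpyt module layout, they would likely be the following (verify against your rlpyt version):

# Imports assumed for the test above (standard rlpyt module paths).
from rlpyt.agents.qpg.sac_agent import SacAgent
from rlpyt.envs.gym import make as gym_make
from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.utils.seed import set_seed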
Code example #2
def estimateForState(s):
    # Monte-Carlo estimate of the expected discounted return from start state s.
    cpus = list(range(C.N_PARALLEL))
    affinity = dict(cuda_idx=C.CUDA_IDX, workers_cpus=cpus)
    # New agent instance initialized with the current policy weights.
    agent_ = CategoricalPgAgent(
        AcrobotNet, initial_model_state_dict=agent.state_dict())
    sampler = SerialSampler(
        EnvCls=rlpyt_make,
        env_kwargs=dict(id=C.ENV,
                        reward=rewardFn,
                        internalStateFn=C.INTERNAL_STATE_FN,
                        s0=s),
        batch_T=C.HORIZON,
        batch_B=C.BATCH_B,
        max_decorrelation_steps=0,  # start every environment exactly at s0=s
    )
    sampler.initialize(agent=agent_, affinity=affinity, seed=C.SEED)
    # Collect one batch; traj_info holds statistics of the completed trajectories.
    _, traj_info = sampler.obtain_samples(0)
    returns = [t['DiscountedReturn'] for t in traj_info]
    return np.mean(returns)
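
Code example #2 builds a throwaway SerialSampler per call to get a value estimate for a single start state. A hedged usage sketch follows; the candidate states and the surrounding loop are hypothetical and not from the original project:

# Illustrative only: the original code does not show how start states are chosen.
import numpy as np

candidate_states = [np.zeros(4), np.array([0.1, 0.0, 0.0, 0.0])]  # hypothetical start states
for s in candidate_states:
    print("Estimated discounted return from", s, ":", estimateForState(s))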
Code example #3
def log_diagnostics(itr, algo, agent, sampler):
    # Log summary statistics of the ProMP weight distribution (mean and std).
    mp: ProMP = agent.model.promp
    mu, cov = mp.mu_and_cov_w
    std = cov.diagonal(dim1=-2, dim2=-1).sqrt().detach().numpy()
    # for i in range(std.shape[0]):
    #     record_tabular('agent/std{}'.format(i), std[i])
    record_tabular_misc_stat('AgentCov', std)
    record_tabular_misc_stat('AgentMu', mu.detach().numpy())


runner = MinibatchRlWithLog(algo=args.get_ppo_from_options(options),
                            agent=agent,
                            sampler=sampler,
                            log_traj_window=32,
                            n_steps=int(500000 * sampler.batch_size),
                            log_interval_steps=int(horizon * 32),
                            log_diagnostics_fun=log_diagnostics)

if not args.is_evaluation(options):
    # Training mode: run the full training loop under the configured context.
    with args.get_default_context(options):
        runner.train()
else:
    # Evaluation mode: keep sampling and visualizing rollouts without training.
    runner.startup()
    while True:
        sampler.obtain_samples(0)
        # Total reward per environment over the sampled batch.
        print(np.sum(sampler.samples_np.env.reward, 0))
        EnvProMP.plot_states(sampler.samples_np.env.observation,
                             sampler.samples_np.agent.action,
                             max_obs=64)
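
MinibatchRlWithLog is a project-specific runner, not part of rlpyt. A minimal sketch of how it could be implemented, assuming it simply forwards a log_diagnostics_fun hook from rlpyt's MinibatchRl; this is an assumption, the original implementation is not shown:

# Assumed sketch of a MinibatchRl subclass with a user-supplied logging hook.
# Not the original MinibatchRlWithLog; signatures follow rlpyt's MinibatchRl.
from rlpyt.runners.minibatch_rl import MinibatchRl


class MinibatchRlWithLog(MinibatchRl):
    def __init__(self, *args, log_diagnostics_fun=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._log_diagnostics_fun = log_diagnostics_fun

    def log_diagnostics(self, itr, *args, **kwargs):
        # Standard rlpyt logging first, then the custom per-iteration hook.
        super().log_diagnostics(itr, *args, **kwargs)
        if self._log_diagnostics_fun is not None:
            self._log_diagnostics_fun(itr, self.algo, self.agent, self.sampler)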