Exemplo n.º 1
0
    def testRolloutDictSpace(self):
        """Check save/restore of a PGTrainer on the registered "nested" env.

        Trains a trainer once and saves it, then restores the saved path
        into a second trainer and checks that the second trainer can both
        train and be passed to ``rollout``.
        """
        register_env("nested", lambda _: NestedDictEnv())

        agent = PGTrainer(env="nested")
        try:
            agent.train()
            path = agent.save()
        finally:
            # Always stop the first trainer, even when train/save raises.
            agent.stop()

        agent2 = PGTrainer(env="nested")
        try:
            # Test train works on restore
            agent2.restore(path)
            agent2.train()

            # Test rollout works on restore
            rollout(agent2, "nested", 100)
        finally:
            # The restored trainer was previously never stopped.
            agent2.stop()
Exemplo n.º 2
0
    def test_rollout_dict_space(self):
        """Check save/restore of a tf PGTrainer on the registered "nested" env.

        Trains a trainer once and saves it, then restores the saved path
        into a second trainer and checks that the second trainer can both
        train and be passed to ``rollout``.
        """
        register_env("nested", lambda _: NestedDictEnv())

        agent = PGTrainer(env="nested", config={"framework": "tf"})
        try:
            agent.train()
            path = agent.save()
        finally:
            # Always stop the first trainer, even when train/save raises.
            agent.stop()

        agent2 = PGTrainer(env="nested", config={"framework": "tf"})
        try:
            # Test train works on restore
            agent2.restore(path)
            agent2.train()

            # Test rollout works on restore
            rollout(agent2, "nested", 100)
        finally:
            # The restored trainer was previously never stopped.
            agent2.stop()
Exemplo n.º 3
0
def _restore_without_explore(trainer_cls, config, checkpoint):
    """Build a trainer on ``SSA_Tasker_Env``, restore it, and disable explore.

    Args:
        trainer_cls: Trainer class to instantiate.
        config: Config passed to the trainer constructor.
        checkpoint: Value handed to the trainer's ``restore``.

    Returns:
        The restored trainer, with ``'explore'`` set to ``False`` in the
        config of the policy returned by ``get_policy()``.
    """
    trainer = trainer_cls(config=config, env=SSA_Tasker_Env)
    trainer.restore(checkpoint)
    trainer.get_policy().config['explore'] = False
    return trainer


MARWIL_agent = _restore_without_explore(
    MARWILTrainer, marwil_config, marwil_checkpoint)

# One PG config shared by the PGR, PGRE and OLR trainers below.
pg_config = PG_CONFIG.copy()
pg_config.update({
    'batch_mode': 'complete_episodes',
    'train_batch_size': 2000,
    'lr': 0.0001,
    'evaluation_interval': None,
    'postprocess_inputs': True,
    'env_config': env_config,
    'explore': False,
})

PGR_agent = _restore_without_explore(PGTrainer, pg_config, pgr_checkpoint)
PGRE_agent = _restore_without_explore(PGTrainer, pg_config, pgre_checkpoint)
OLR_agent = _restore_without_explore(PGTrainer, pg_config, olr_checkpoint)


def ppo_agent(obs, env):
    """Return the action the module-level ``PPO_agent`` computes for ``obs``.

    ``env`` is accepted but not used by this function.
    """
    action = PPO_agent.compute_action(obs)
    return action