コード例 #1
0
    def test_env_trajectory(self, env_fn):
        """Check that a rollout yields exactly ``max_episode_len`` items.

        The environment is built with ``max_episode_len=10`` and driven by a
        ``null_agent``; the trajectory is counted once without a goal and once
        with the upper bound of the ``"desired_goal"`` observation space
        passed as the explicit goal.
        """
        episode_len = 10
        env = env_fn(max_episode_len=episode_len)
        agent = null_agent(action_space=env.action_space)

        # Rollout without an explicit goal.
        n_steps = sum(1 for _ in trajectory(pi=agent, env=env))
        assert n_steps == episode_len

        # Rollout towards the upper bound of the goal space.
        goal = env.observation_space["desired_goal"].high
        n_goal_steps = sum(1 for _ in trajectory(pi=agent, env=env, goal=goal))
        assert n_goal_steps == episode_len
コード例 #2
0
def test_class_instantiation():
    """Smoke test: build every agent and the evaluation callback.

    Each agent is constructed on a ``ToyLab`` environment and used to drive
    one full trajectory; the test passes if nothing raises.
    """
    env = ToyLab()
    ppo_agent = PPOAgent(env=env)
    her_sac_agent = HERSACAgent(env=env)
    # The GoalGAN agent wraps the PPO agent built above.
    gan_agent = GoalGANAgent(env=env, agent=ppo_agent)

    for agent in (ppo_agent, her_sac_agent, gan_agent):
        consume(trajectory(pi=agent, env=env))

    # Only construction is checked here; the callback instance itself is unused.
    EvaluateCallback(agent=ppo_agent, eval_env=env)
コード例 #3
0
ファイル: agents.py プロジェクト: Hongkuan-Zhou/goalGAN
 def _log_specific_goal_performance(self):
     """Roll out one episode towards the specific goal and log its outcome.

     Runs ``trajectory`` with ``self._agent`` on ``self._eval_env`` using
     ``self._specific_goal``, keeps only the last yielded step, and appends
     one line ``"<num_timesteps>,<0|1>"`` to ``self._log_fname``. The flag is
     ``int(info['is_success'])``, where ``info`` is the last element of that
     final step.

     If the trajectory yields no steps, nothing is written.
     """
     # Pre-bind the loop variable: a ``for`` loop over an empty iterable never
     # assigns its target, which previously caused an UnboundLocalError below.
     final_step = None
     for final_step in trajectory(pi=self._agent,
                                  env=self._eval_env,
                                  goal=self._specific_goal):
         pass
     if final_step is None:
         # Empty rollout: there is no outcome to record.
         return

     info = final_step[-1]
     log = f"{self.num_timesteps},{int(info['is_success'])}\n"
     with open(self._log_fname, "a") as file:
         file.write(log)
コード例 #4
0
def main(task: str, use_gan: bool, do_train: bool, perform_eval: bool):
    """Train an agent on a Panda task, or replay it forever with visualization.

    Args:
        task: Task selector; ``Task.REACH`` picks ``PandaEnv``, anything else
            picks ``PandaPickAndPlace``.
        use_gan: Wrap the HER-SAC agent in a ``GoalGANAgent`` (and give the
            inner agent the ``"goalgan-her-sac"`` experiment name).
        do_train: Train for 50000 timesteps when true; otherwise run
            trajectories in an endless loop with ``visualize`` enabled.
        perform_eval: When training, attach an ``EvaluateCallback`` that uses
            a second, separately constructed environment.
    """
    # Visualization is only switched on when we are not training.
    env_params = {"visualize": not do_train}
    if task == Task.REACH:
        env_fn = PandaEnv
    else:
        env_fn = PandaPickAndPlace
    env = env_fn(**env_params)

    agent_params = {"env": env}
    if use_gan:
        agent_params["experiment_name"] = "goalgan-her-sac"
    base_agent = HERSACAgent(**agent_params)
    agent = GoalGANAgent(env=env, agent=base_agent) if use_gan else base_agent

    if not do_train:
        # Replay mode: never returns.
        while True:
            consume(trajectory(agent, env))

    callbacks = []
    if perform_eval:
        eval_env = env_fn(**env_params)
        callbacks.append(EvaluateCallback(agent=agent, eval_env=eval_env))
    agent.train(timesteps=50000, callbacks=callbacks)