def test_env_trajectory(self, env_fn):
    """A full rollout (with or without an explicit goal) lasts exactly max_episode_len steps."""
    environment = env_fn(max_episode_len=10)
    policy = null_agent(action_space=environment.action_space)

    # Goal-free rollout is truncated at the episode length cap.
    steps = list(trajectory(pi=policy, env=environment))
    assert len(steps) == 10

    # Supplying an explicit desired goal must not change the episode length.
    desired = environment.observation_space["desired_goal"].high
    goal_steps = list(trajectory(pi=policy, env=environment, goal=desired))
    assert len(goal_steps) == 10
def test_class_instantiation():
    """Smoke test: every agent type can be built and can complete one rollout."""
    env = ToyLab()
    ppo = PPOAgent(env=env)
    her_sac = HERSACAgent(env=env)
    gan = GoalGANAgent(env=env, agent=ppo)
    for agent in (ppo, her_sac, gan):
        consume(trajectory(pi=agent, env=env))
    # Constructing the callback is the assertion: it must not raise.
    cb = EvaluateCallback(agent=ppo, eval_env=env)
def _log_specific_goal_performance(self) -> None:
    """Roll out one episode toward ``self._specific_goal`` and append the result to the log file.

    Appends one CSV line ``<num_timesteps>,<is_success>`` to ``self._log_fname``,
    where success is read from the ``info`` dict of the episode's final step.
    Does nothing if the trajectory yields no steps.
    """
    # Drain the generator, keeping only the final step. Initializing to None
    # guards against a NameError when the trajectory is empty (the original
    # code left `step` unbound in that case).
    last_step = None
    for last_step in trajectory(
        pi=self._agent, env=self._eval_env, goal=self._specific_goal
    ):
        pass
    if last_step is None:
        return  # empty trajectory: nothing to log

    info = last_step[-1]  # info dict is the last element of a step tuple
    log = f"{self.num_timesteps},{int(info['is_success'])}\n"
    with open(self._log_fname, "a") as file:
        file.write(log)
def main(
    task: str,
    use_gan: bool,
    do_train: bool,
    perform_eval: bool,
    timesteps: int = 50000,
) -> None:
    """Build the requested Panda environment/agent and either train or visualize it.

    Args:
        task: Which task to run; ``Task.REACH`` selects ``PandaEnv``, anything
            else selects ``PandaPickAndPlace``.
        use_gan: Wrap the HER-SAC agent in a GoalGAN curriculum agent.
        do_train: Train the agent; otherwise loop rollouts forever (with
            visualization enabled).
        perform_eval: During training, attach an ``EvaluateCallback`` on a
            separately constructed evaluation environment.
        timesteps: Training budget (previously a hard-coded constant);
            defaults to the original 50000.
    """
    # Visualization is only useful when watching rollouts, not while training.
    env_params = {"visualize": not do_train}
    env_fn = PandaEnv if task == Task.REACH else PandaPickAndPlace
    env = env_fn(**env_params)

    agent_params = {"env": env}
    if use_gan:
        # Distinct experiment name keeps GoalGAN runs separate in logging.
        agent_params["experiment_name"] = "goalgan-her-sac"
    agent = HERSACAgent(**agent_params)
    if use_gan:
        agent = GoalGANAgent(env=env, agent=agent)

    if do_train:
        # Evaluation uses a fresh env instance so it does not interfere
        # with the training env's episode state.
        cbs = (
            [EvaluateCallback(agent=agent, eval_env=env_fn(**env_params))]
            if perform_eval
            else []
        )
        agent.train(timesteps=timesteps, callbacks=cbs)
    else:
        # Endless visualization loop; terminate externally (Ctrl-C).
        while True:
            consume(trajectory(agent, env))