def run_exp(inputs, timesteps=300_000):
    """Train one HER-SAC run, optionally wrapped in Goal-GAN, on ``PandaEnv``.

    Args:
        inputs: a ``(seed, use_goalgan)`` pair packed into one argument
            (presumably so the function can be mapped over a grid of runs,
            e.g. with a process pool -- verify against the caller).
            ``seed`` seeds both the training and the evaluation env and is
            passed to the agent and the callback as ``rank``.
            ``use_goalgan`` wraps the HER-SAC agent in ``GoalGANAgent`` and
            prefixes the experiment name with ``"goalgan-"``.
        timesteps: total number of training timesteps. Defaults to 300 000,
            the value previously hard-coded here.
    """
    seed, use_goalgan = inputs
    env = PandaEnv(seed=seed)
    name = f"{'goalgan-' if use_goalgan else ''}her-sac"
    agent = HERSACAgent(env=env, rank=seed, experiment_name=name)
    if use_goalgan:
        agent = GoalGANAgent(env=env, agent=agent)
    # A separate env instance is used for evaluation, built with the same seed.
    cb = EvaluateCallback(agent=agent, eval_env=PandaEnv(seed=seed), rank=seed)
    agent.train(timesteps=timesteps, callbacks=[cb])
Example #2
0
def continuous_viz():
    """Roll out a HER-SAC policy in a visualized ``PandaEnv`` indefinitely.

    The agent is created with experiment name ``"goalgan-her-sac"``
    (presumably so it picks up that experiment's trained weights -- confirm
    against ``HERSACAgent``). The loop never exits on its own.
    """
    viz_env = PandaEnv(visualize=True)
    policy = HERSACAgent(env=viz_env, experiment_name="goalgan-her-sac")
    observation = viz_env.reset()
    while True:
        observation, _, episode_over, step_info = viz_env.step(policy(observation))
        if not episode_over:
            continue
        # reset_agent_pos is False after a successful episode and True after a
        # failed one (presumably the arm is only repositioned on failure).
        observation = viz_env.reset(reset_agent_pos=not step_info["is_success"])
Example #3
0
def main(task: str, use_gan: bool, do_train: bool, perform_eval: bool):
    """Train or visualize a HER-SAC (optionally Goal-GAN) agent on a Panda task.

    Args:
        task: ``Task.REACH`` selects ``PandaEnv``; any other value selects
            ``PandaPickAndPlace``.
        use_gan: wrap the HER-SAC agent in ``GoalGANAgent`` and give it the
            experiment name ``"goalgan-her-sac"``. Otherwise HER-SAC keeps its
            default experiment name.
        do_train: train for 50 000 timesteps with visualization off. Otherwise
            roll out trajectories forever with visualization on.
        perform_eval: only used when training; attaches an ``EvaluateCallback``
            that evaluates on a fresh env of the same kind.
    """
    visualize = not do_train
    if task == Task.REACH:
        make_env = PandaEnv
    else:
        make_env = PandaPickAndPlace
    env = make_env(visualize=visualize)

    if use_gan:
        base = HERSACAgent(env=env, experiment_name="goalgan-her-sac")
        agent = GoalGANAgent(env=env, agent=base)
    else:
        agent = HERSACAgent(env=env)

    if do_train:
        callbacks = []
        if perform_eval:
            callbacks.append(EvaluateCallback(agent=agent, eval_env=make_env(visualize=visualize)))
        agent.train(timesteps=50000, callbacks=callbacks)
        return

    # Visualization mode: never terminates on its own.
    while True:
        consume(trajectory(agent, env))
Example #4
0
def test_class_instantiation():
    """Smoke test: each agent type can be built on ``ToyLab`` and roll out a trajectory.

    Also checks that ``EvaluateCallback`` can be constructed. No explicit
    assertions are made; the test passes as long as nothing raises.
    """
    env = ToyLab()
    ppo = PPOAgent(env=env)
    her_sac = HERSACAgent(env=env)
    # GoalGANAgent wraps another agent; the PPO agent is the inner policy here.
    goal_gan = GoalGANAgent(env=env, agent=ppo)
    for agent in (ppo, her_sac, goal_gan):
        consume(trajectory(pi=agent, env=env))

    # Construction alone is the check; the instance itself is not used.
    EvaluateCallback(agent=ppo, eval_env=env)
Example #5
0
def main(seed: int, timesteps: int = 1_000_000):
    """Train a Goal-GAN agent wrapping HER-SAC on ``ToyLab`` for one seed.

    Args:
        seed: seeds both the training and the evaluation ``ToyLab`` env. It is
            passed as ``rank`` to the agent and the callback, and is appended
            to the experiment name (``"goalgan-her-sac-<seed>"``).
        timesteps: total number of training timesteps. Defaults to 1 000 000,
            the value previously hard-coded here.
    """
    env = ToyLab(seed=seed)
    # NOTE: an earlier variant used PPOAgent (experiment "goalgan-ppo-seed-<seed>")
    # as the inner policy instead of HER-SAC.
    policy = HERSACAgent(env=env,
                         experiment_name=f"goalgan-her-sac-{seed}",
                         rank=seed)
    agent = GoalGANAgent(env=env, agent=policy)

    # Evaluation uses its own env instance, built with the same seed.
    callback = EvaluateCallback(agent=agent,
                                eval_env=ToyLab(seed=seed),
                                rank=seed)
    agent.train(timesteps=timesteps, callbacks=[callback])