Example 1
import gym
import numpy as np
import tensorflow as tf
from gym import wrappers


def main(max_episodes: int, max_episode_len: int, env: gym.Env, gym_dir: str):

    # Seed every source of randomness so runs are reproducible.
    seed = 0
    np.random.seed(seed)
    tf.set_random_seed(seed)
    env.seed(seed)

    render_env = True
    use_gym_monitor = True

    # TensorFlowDDPGAgent and Experiment are classes supplied by the
    # surrounding project, so they are not imported here.
    agent = TensorFlowDDPGAgent(
        state_dim=env.observation_space.shape[0],
        action_space=env.action_space,
    )

    # Wrap the environment so episode statistics (and videos, where
    # supported) are recorded under gym_dir.
    if use_gym_monitor:
        env = wrappers.Monitor(env, gym_dir, force=True)

    Experiment(
        agent=agent,
        env=env,
        render_env=render_env,
        num_episodes=max_episodes,
        max_episode_len=max_episode_len,
    ).run(seed=seed)

    if use_gym_monitor:
        env.close()
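
For reference, a minimal driver for main could look like the sketch below; the environment name, episode budget, and output directory are illustrative assumptions rather than values from the original source:

if __name__ == "__main__":
    # "Pendulum-v0" is an assumed example; any continuous-action (Box)
    # gym environment fits the DDPG setting used here.
    environment = gym.make("Pendulum-v0")
    main(
        max_episodes=100,
        max_episode_len=200,
        env=environment,
        gym_dir="./gym-results",
    )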

import pytest
from unittest import mock

# STATE_DIM, ACTION_SPACE, STATE_SPACE, and EPISODE are module-level
# test constants defined elsewhere in this test suite.


def test_tf_ddpg_agent_trigger_train(cleanup):
    replay_buffer = InMemoryReplayBuffer(lower_size_limit=2, batch_size=2)
    agent = TensorFlowDDPGAgent(state_dim=STATE_DIM, action_space=ACTION_SPACE,
                                replay_buffer=replay_buffer)
    # Spy on the private training method to check that training fires
    # once the replay buffer holds enough samples.
    with mock.patch.object(agent, "_train", wraps=agent._train) as train_spy:
        agent.observe_episode(EPISODE)
        agent.act(current_state=STATE_SPACE.sample())
        train_spy.assert_called()


def test_tf_ddpg_agent_reject_invalid_seed(flower: TensorFlowDDPGAgent):
    for invalid_seed in [None, "one"]:
        with pytest.raises(TypeError):
            flower.set_seed(invalid_seed)


def test_tf_ddpg_agent_set_seed(flower: TensorFlowDDPGAgent):
    flower.set_seed(1)


def test_tf_ddpg_agent_rejects_invalid_episodes(flower: TensorFlowDDPGAgent):
    for invalid_episode in [None, "one"]:
        with pytest.raises(TypeError):
            flower.observe_episode(invalid_episode)


def test_tf_ddpg_agent_observe_episode(flower: TensorFlowDDPGAgent, cleanup):
    """
    Observing an episode may trigger model saving, so we need to remove
    the created folder.
    """
    flower.observe_episode(EPISODE)


def test_tf_ddpg_agent_act(flower: TensorFlowDDPGAgent):
    action = flower.act(current_state=STATE_SPACE.sample())
    assert ACTION_SPACE.contains(action)


@pytest.fixture
def flower():
    return TensorFlowDDPGAgent(state_dim=STATE_DIM, action_space=ACTION_SPACE)
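
The cleanup fixture requested by two of the tests above is not part of this excerpt. A minimal sketch consistent with the docstring on test_tf_ddpg_agent_observe_episode might look like this; the "./model" path is a hypothetical placeholder, since the real save location is not shown in the source:

import shutil


@pytest.fixture
def cleanup():
    # Let the test run first, then delete whatever folder model saving
    # created ("./model" is an assumed path, not from the source).
    yield
    shutil.rmtree("./model", ignore_errors=True)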