def test_final_state_from_segment():
    """Check ``final_state_from_segment`` on a 3-environment, 4-step segment.

    Sub-segments 0 and 2 are marked done on their last step, and the expected
    final state is the ``terminal_observation`` stored in the info of that
    step. Sub-segment 1 is marked done at step 2 only, and the expected final
    state is the last stored observation rather than the recorded terminal
    observation.
    """
    env = make_vec_env(SerialVecEnv, make_gym_env, 'CartPole-v1', 3, 0)
    env_spec = EnvSpec(env)

    # A plain list is not an accepted input.
    with pytest.raises(AssertionError):
        final_state_from_segment([1, 2, 3])

    D = BatchSegment(env_spec, 4)
    # Random observations make it unlikely that the "last observation" check
    # below passes by coincidence.
    D.obs = np.random.randn(*D.obs.shape)

    D.done[0, -1] = True
    D.info[0] = [{}, {}, {}, {'terminal_observation': [0.1, 0.2, 0.3, 0.4]}]

    D.done[1, 2] = True
    D.info[1] = [{}, {}, {'terminal_observation': [1, 2, 3, 4]}, {}]

    D.done[2, -1] = True
    D.info[2] = [{}, {}, {}, {'terminal_observation': [10, 20, 30, 40]}]

    final_states = final_state_from_segment(D)
    assert final_states.shape == (3, ) + env_spec.observation_space.shape
    assert np.allclose(final_states[0], [0.1, 0.2, 0.3, 0.4])
    assert np.allclose(final_states[1], D.numpy_observations[1, -1, ...])
    assert not np.allclose(final_states[1], [1, 2, 3, 4])
    assert np.allclose(final_states[2], [10, 20, 30, 40])

    # NOTE(review): the body of this `with` was missing, which made the file
    # unparsable. It is restored by analogy with the sibling tests, which all
    # end by feeding the segment to the `*_from_episode` counterpart; confirm
    # that `final_state_from_episode` is the intended callee.
    with pytest.raises(AssertionError):
        final_state_from_episode(D)
def test_bootstrapped_returns_from_segment():
    """Check ``bootstrapped_returns_from_segment`` on a 3x5 segment.

    Every sub-segment carries the rewards 1..5; they differ only in where the
    done flags are set. The expected returns are checked for the discount
    factors 1.0 and 0.1, using 10, 20 and 30 as the last-state values.
    """
    vec_env = make_vec_env(SerialVecEnv, make_gym_env, 'CartPole-v1', 3, 0)
    spec = EnvSpec(vec_env)

    segment = BatchSegment(spec, 5)
    done_rows = [
        [False, False, False, False, False],
        [False, False, True, False, False],
        [True, False, False, False, True],
    ]
    for n, done_row in enumerate(done_rows):
        segment.r[n] = [1, 2, 3, 4, 5]
        segment.done[n] = done_row

    bootstrap_values = torch.tensor([10, 20, 30]).unsqueeze(1)

    # (discount factor, expected returns for each sub-segment)
    cases = [
        (1.0, [
            [25, 24, 22, 19, 15],
            [6, 5, 3, 29, 25],
            [1, 14, 12, 9, 5],
        ]),
        (0.1, [
            [1.2346, 2.346, 3.46, 4.6, 6],
            [1.23, 2.3, 3, 4.7, 7],
            [1, 2.345, 3.45, 4.5, 5],
        ]),
    ]
    for gamma, expected_rows in cases:
        returns = bootstrapped_returns_from_segment(segment, bootstrap_values, gamma)
        assert returns.shape == (3, 5)
        for n, expected in enumerate(expected_rows):
            assert np.allclose(returns[n], expected)

    # The episode variant must refuse a segment.
    with pytest.raises(AssertionError):
        bootstrapped_returns_from_episode(segment, bootstrap_values, 0.1)
def test_returns_from_segment():
    """Check ``returns_from_segment`` on a 3x5 segment.

    Every sub-segment carries the rewards 1..5; they differ only in where the
    done flags are set. The expected returns are checked for the discount
    factors 1.0 and 0.1.
    """
    vec_env = make_vec_env(SerialVecEnv, make_gym_env, 'CartPole-v1', 3, 0)
    spec = EnvSpec(vec_env)

    segment = BatchSegment(spec, 5)
    done_rows = [
        [False, False, False, False, False],
        [False, False, True, False, False],
        [True, False, False, False, True],
    ]
    for n, done_row in enumerate(done_rows):
        segment.r[n] = [1, 2, 3, 4, 5]
        segment.done[n] = done_row

    # (discount factor, expected returns for each sub-segment)
    cases = [
        (1.0, [
            [15, 14, 12, 9, 5],
            [6, 5, 3, 9, 5],
            [1, 14, 12, 9, 5],
        ]),
        (0.1, [
            [1.2345, 2.345, 3.45, 4.5, 5],
            [1.23, 2.3, 3, 4.5, 5],
            [1, 2.345, 3.45, 4.5, 5],
        ]),
    ]
    for gamma, expected_rows in cases:
        returns = returns_from_segment(segment, gamma)
        assert returns.shape == (3, 5)
        for n, expected in enumerate(expected_rows):
            assert np.allclose(returns[n], expected)

    # The episode variant must refuse a segment.
    with pytest.raises(AssertionError):
        returns_from_episode(segment, 0.1)
def test_gae_from_segment():
    """Check ``gae_from_segment`` on a 3-environment, 5-step segment.

    Every sub-segment carries the rewards 1..5; they differ only in where the
    done flags are set. The expected values are consistent with the fourth
    argument being the discount factor and the fifth the GAE lambda.
    """
    env = make_vec_env(SerialVecEnv, make_gym_env, 'CartPole-v1', 3, 0)
    env_spec = EnvSpec(env)

    D = BatchSegment(env_spec, 5)
    D.r[0] = [1, 2, 3, 4, 5]
    D.done[0] = [False, False, False, False, False]
    D.r[1] = [1, 2, 3, 4, 5]
    D.done[1] = [False, False, True, False, False]
    D.r[2] = [1, 2, 3, 4, 5]
    D.done[2] = [True, False, False, False, True]

    # One (3, 1) tensor per time step: the value of each sub-segment's state.
    all_Vs = [
        torch.tensor([[0.1], [0.5], [1.0]]),
        torch.tensor([[1.1], [1.5], [2.0]]),
        torch.tensor([[2.1], [2.5], [3.0]]),
        torch.tensor([[3.1], [3.5], [4.0]]),
        torch.tensor([[4.1], [4.5], [5.0]]),
    ]  # closing bracket was missing, which swallowed the next statement
    last_Vs = torch.tensor([10, 20, 30]).unsqueeze(1)

    # Stack along a new dimension 1 so time becomes the second axis: (3, 5, 1).
    all_Vs = torch.stack(all_Vs, 1)

    out = gae_from_segment(D, all_Vs, last_Vs, 1.0, 0.5)
    assert out.shape == (3, 5)
    assert np.allclose(out[0], [5.80625, 7.6125, 9.225, 10.45, 10.9])
    assert np.allclose(out[1], [3.625, 3.25, 0.5, 15.25, 20.5])
    assert np.allclose(out[2], [0, 6.25, 6.5, 5, 0])
    del out

    out = gae_from_segment(D, all_Vs, last_Vs, 0.1, 0.2)
    assert out.shape == (3, 5)
    assert np.allclose(out[0], [1.03269478, 1.1347393, 1.23696, 1.348, 1.9])
    assert np.allclose(out[1], [0.6652, 0.76, 0.5, 1, 2.5])
    assert np.allclose(out[2], [0, 0.3082, 0.41, 0.5, 0])

    # The episode variant must refuse a segment.
    with pytest.raises(AssertionError):
        gae_from_episode(D, all_Vs, last_Vs, 0.1, 0.2)
def test_td0_error_from_segment():
    """Check ``td0_error_from_segment`` on a 3-environment, 5-step segment.

    Every sub-segment carries the rewards 1..5; they differ only in where the
    done flags are set. The expected values are consistent with the fourth
    argument being the discount factor.
    """
    env = make_vec_env(SerialVecEnv, make_gym_env, 'CartPole-v1', 3, 0)
    env_spec = EnvSpec(env)

    D = BatchSegment(env_spec, 5)
    D.r[0] = [1, 2, 3, 4, 5]
    D.done[0] = [False, False, False, False, False]
    D.r[1] = [1, 2, 3, 4, 5]
    D.done[1] = [False, False, True, False, False]
    D.r[2] = [1, 2, 3, 4, 5]
    D.done[2] = [True, False, False, False, True]

    # One (3, 1) tensor per time step: the value of each sub-segment's state.
    all_Vs = [
        torch.tensor([[0.1], [0.5], [1.0]]),
        torch.tensor([[1.1], [1.5], [2.0]]),
        torch.tensor([[2.1], [2.5], [3.0]]),
        torch.tensor([[3.1], [3.5], [4.0]]),
        torch.tensor([[4.1], [4.5], [5.0]]),
    ]  # closing bracket was missing, which swallowed the next statement
    last_Vs = torch.tensor([10, 20, 30]).unsqueeze(1)

    # Stack along a new dimension 1 so time becomes the second axis: (3, 5, 1).
    all_Vs = torch.stack(all_Vs, 1)

    out = td0_error_from_segment(D, all_Vs, last_Vs, 1.0)
    assert out.shape == (3, 5)
    assert np.allclose(out[0], [2.0, 3, 4, 5, 10.9])
    assert np.allclose(out[1], [2, 3, 0.5, 5, 20.5])
    assert np.allclose(out[2], [0, 3, 4, 5, 0])
    del out

    out = td0_error_from_segment(D, all_Vs, last_Vs, 0.1)
    assert out.shape == (3, 5)
    assert np.allclose(out[0], [1.01, 1.11, 1.21, 1.31, 1.9])
    assert np.allclose(out[1], [0.65, 0.75, 0.5, 0.95, 2.5])
    assert np.allclose(out[2], [0, 0.3, 0.4, 0.5, 0])

    # The episode variant must refuse a segment.
    with pytest.raises(AssertionError):
        td0_error_from_episode(D, all_Vs, last_Vs, 0.1)
def test_batch_segment(vec_env, env_id):
    """Fill a ``BatchSegment`` by hand and compare it with single-env replays.

    A vectorized environment with three sub-environments is stepped ``T``
    times with a constant action while every transition is written into the
    segment. The stored arrays are then checked for type, dtype and shape, and
    finally compared step by step against three individually seeded
    environments driven with the same action.

    Args:
        vec_env: vectorized environment class handed to ``make_vec_env``.
        env_id (str): environment ID; only 'CartPole-v1' and 'Pendulum-v0'
            are handled.
    """
    env = make_vec_env(vec_env, make_gym_env, env_id, 3, 0)
    env_spec = EnvSpec(env)

    T = 30

    D = BatchSegment(env_spec, T)

    if env_id == 'CartPole-v1':
        sticky_action = 1
        action_shape = ()
        action_dtype = np.int32
    elif env_id == 'Pendulum-v0':
        sticky_action = [0.1]
        action_shape = env_spec.action_space.shape
        action_dtype = np.float32
    else:
        # Fail with a clear message instead of an unbound-local error below.
        raise ValueError('unsupported env_id for this test: {!r}'.format(env_id))

    obs = env.reset()
    D.add_observation(0, obs)
    for t in range(T):
        action = [sticky_action] * env.num_env
        obs, reward, done, info = env.step(action)
        D.add_observation(t + 1, obs)
        D.add_action(t, action)
        D.add_reward(t, reward)
        D.add_done(t, done)
        D.add_batch_info({'V': [0.1 * (t + 1), (t + 1), 10 * (t + 1)]})

    assert D.N == 3
    assert D.T == T
    assert all(
        isinstance(x, np.ndarray) for x in [
            D.numpy_observations, D.numpy_actions, D.numpy_rewards,
            D.numpy_dones, D.numpy_masks
        ])
    assert all(
        x.dtype == np.float32
        for x in [D.numpy_observations, D.numpy_rewards, D.numpy_masks])
    assert D.numpy_actions.dtype == action_dtype
    # `np.bool_` instead of the `np.bool` alias, which NumPy 1.24 removed.
    assert D.numpy_dones.dtype == np.bool_
    # Observations hold one extra entry: the initial one written at index 0.
    assert D.numpy_observations.shape[:2] == (3, T + 1)
    assert D.numpy_actions.shape == (3, T) + action_shape
    assert all(
        x.shape == (3, T)
        for x in [D.numpy_rewards, D.numpy_dones, D.numpy_masks])
    assert isinstance(D.batch_infos, list) and len(D.batch_infos) == T
    assert np.allclose([0.1 * (x + 1) for x in range(T)],
                       [info['V'][0] for info in D.batch_infos])
    assert np.allclose([1 * (x + 1) for x in range(T)],
                       [info['V'][1] for info in D.batch_infos])
    assert np.allclose([10 * (x + 1) for x in range(T)],
                       [info['V'][2] for info in D.batch_infos])

    # Replay each sub-environment on its own.
    # presumably these are the seeds make_vec_env derives from its last
    # argument (0) - TODO confirm
    seeder = Seeder(0)
    seed1, seed2, seed3 = seeder(3)
    env1 = make_gym_env(env_id, seed1)
    env2 = make_gym_env(env_id, seed2)
    env3 = make_gym_env(env_id, seed3)

    for n, ev in enumerate([env1, env2, env3]):
        obs = ev.reset()
        assert np.allclose(obs, D.numpy_observations[n, 0, ...])
        for t in range(T):
            obs, reward, done, info = ev.step(sticky_action)
            if done:
                # Keep the terminal observation and restart, so the stored
                # next observation is compared with the post-reset one.
                info['terminal_observation'] = obs
                obs = ev.reset()

            assert np.allclose(obs, D.numpy_observations[n, t + 1, ...])
            assert np.allclose(sticky_action, D.numpy_actions[n, t, ...])
            assert np.allclose(reward, D.numpy_rewards[n, t])
            assert done == D.numpy_dones[n, t]
            assert int(not done) == D.numpy_masks[n, t]

            if done:
                # NOTE(review): the second argument of this call was cut off
                # in the source. It is reconstructed from the per-step info
                # layout `D.info[n][t]`; the rollout loop above does not
                # visibly store `info` on D, so confirm that the segment
                # records it before relying on this assertion.
                assert np.allclose(info['terminal_observation'],
                                   D.info[n][t]['terminal_observation'])
    def __call__(self, T, reset=False):
        """Roll out ``T`` time steps and return them as a ``BatchSegment``.

        Args:
            T (int): number of time steps to collect.
            reset (bool): when True, the environment is reset before the
                rollout. Otherwise the rollout resumes from the observation
                buffered by the previous call, if there is one.

        Returns:
            BatchSegment: segment holding the collected observations,
            actions, rewards, done flags and the remaining agent outputs.
        """
        D = BatchSegment(self.env_spec, T)

        if self.obs_buffer is None or reset:
            obs = self.env.reset()
            # reset agent: e.g. RNN states because initial observation
            # NOTE(review): no agent reset call is present here; confirm
            # whether one is required.
        else:
            # Resume from where the previous call stopped. Without this
            # `else`, the fresh observation was always overwritten by the
            # buffer, which is None on the first call.
            obs = self.obs_buffer
        D.add_observation(0, obs)

        for t in range(T):
            info = {}
            out_agent = self.agent.choose_action(obs, info=info)

            action = out_agent.pop('action')
            if torch.is_tensor(action):
                # Convert to a list of per-environment NumPy values.
                raw_action = list(action.detach().cpu().numpy())
            else:
                # Non-tensor actions are forwarded unchanged; previously this
                # assignment also clobbered the converted tensor action.
                raw_action = action
            D.add_action(t, raw_action)

            obs, reward, done, info = self.env.step(raw_action)
            D.add_observation(t + 1, obs)
            D.add_reward(t, reward)
            D.add_done(t, done)

            # Record other information: e.g. log-probability of action, policy entropy
            # `out_agent` now holds everything the agent returned except the
            # action; the statement storing it was missing.
            D.add_batch_info(out_agent)

        # Remember where this rollout stopped so the next call can resume.
        self.obs_buffer = obs
        self.done_buffer = done

        return D