Exemplo n.º 1
0
def mock_bstep():
    """Return an autospec'd LocalBatchStepper class with canned episodes.

    Every instance produced by the returned mock class answers
    ``run_episode_batch`` with the same six truncated, two-step episodes:

    * episodes 1-3 take action 0, then action 2, with rewards ``[0, 1]``;
    * episodes 4-6 take action 1, then action 2, with rewards ``[0, 0]``,
      ``[0, 0]`` and ``[0, 2]``.

    So action 0 has the higher mean return while action 1 owns the single
    highest return.
    """
    stepper_cls = mock.create_autospec(batch_steppers.LocalBatchStepper)

    first_actions = [0, 0, 0, 1, 1, 1]
    final_rewards = [1, 1, 1, 0, 0, 2]
    canned_episodes = testing.construct_episodes(
        actions=[[first, 2] for first in first_actions],
        rewards=[[0, last] for last in final_rewards],
        truncated=True,
    )

    stepper_cls.return_value.run_episode_batch.return_value = canned_episodes
    return stepper_cls
Exemplo n.º 2
0
def test_bootstrap_return_with_quality_estimator(truncated, x_return):
    """Check ``shooting.bootstrap_return_with_quality`` on one episode.

    ``truncated`` and ``x_return`` are presumably injected by a pytest
    parametrization defined outside this view (TODO confirm).
    """
    # Set up: a single four-step episode and a constant network output.
    # NOTE(review): the (1, 3) array presumably holds one quality value
    # per action — confirm against the network's output spec.
    constructed = testing.construct_episodes(
        actions=[[0, 1, 2, 3]],
        rewards=[[0, 2, 0, -1]],
        truncated=truncated,
    )
    episode = constructed[0]
    network_output = (np.array([[7, 3, 4]]), None)

    # Run the estimator against the fixed prediction.
    estimator = shooting.bootstrap_return_with_quality(episode)
    bootstrap_return = testing.run_with_constant_network_prediction(
        estimator, logits=network_output)

    # Test
    assert bootstrap_return == x_return
Exemplo n.º 3
0
def test_bootstrap_return_with_value_estimator(truncated, x_return):
    """Check ``mc_simulation.bootstrap_return_with_value`` on one episode.

    The estimator takes a list of episodes; a one-element list is passed in
    and the first returned value is compared. ``truncated`` and ``x_return``
    are presumably supplied by a pytest parametrization outside this view
    (TODO confirm).
    """
    # Set up: a single four-step episode and a constant (1, 1) network
    # output, presumably a state-value prediction — verify.
    constructed = testing.construct_episodes(
        actions=[[0, 1, 2, 3]],
        rewards=[[0, 2, 0, -1]],
        truncated=truncated,
    )
    episode = constructed[0]
    network_output = (np.array([[7]]), None)

    # Run the batch estimator on a one-episode batch.
    estimator = mc_simulation.bootstrap_return_with_value([episode])
    batch_returns = testing.run_with_constant_network_prediction(
        estimator, logits=network_output)
    bootstrap_return = batch_returns[0]

    # Test
    assert bootstrap_return == x_return