Exemplo n.º 1
0
def test_take_action_returns_nones_on_missing_values():
    """Check get_action's result when the policy's evaluate returns an empty dict.

    ``evaluate`` is replaced with a mock returning ``{}`` and the policy is
    queried with a BrainInfo listing one agent id. The result must equal
    ``ActionInfo(None, None, None, None, {})``.
    """
    seed = 3
    tf_policy = TFPolicy(seed, basic_mock_brain(), basic_params())
    # Stub out evaluation so that it produces no outputs at all.
    tf_policy.evaluate = MagicMock(return_value={})
    info = BrainInfo([], [], [], agents=["an-agent-id"])
    actual = tf_policy.get_action(info)
    assert actual == ActionInfo(None, None, None, None, {})
Exemplo n.º 2
0
def test_take_action_returns_nones_on_missing_values():
    """Check get_action's result when the policy's evaluate returns an empty dict.

    ``evaluate`` is mocked to return ``{}`` and ``save_memories`` is mocked out.
    The policy is then queried with a BatchedStepResult and ``worker_id=0``;
    the result must equal ``ActionInfo(None, None, {}, [0])``.
    """
    test_seed = 3
    policy = TFPolicy(test_seed, basic_mock_brain(), basic_params())
    policy.evaluate = MagicMock(return_value={})
    policy.save_memories = MagicMock()
    # NOTE(review): positional fields are presumably obs, reward, done,
    # max_step, agent_id, action_mask -- confirm against BatchedStepResult.
    # The builtin ``bool`` is used as dtype because the ``np.bool`` alias was
    # deprecated in NumPy 1.20 and removed in NumPy 1.24.
    step_with_agents = BatchedStepResult(
        [],
        np.array([], dtype=np.float32),
        np.array([False], dtype=bool),
        np.array([], dtype=bool),
        np.array([0]),
        None,
    )
    result = policy.get_action(step_with_agents, worker_id=0)
    assert result == ActionInfo(None, None, {}, [0])
Exemplo n.º 3
0
def test_take_action_returns_action_info_when_available():
    """Check that get_action wraps the evaluate output in an ActionInfo.

    ``evaluate`` is mocked to return a dict with "action", "memory_out" and
    "value" arrays. The result must equal an ActionInfo built from the
    "action" entry, the "value" entry and the whole dict.
    """
    seed = 3
    tf_policy = TFPolicy(seed, basic_mock_brain(), basic_params())

    action_out = np.array([1.0], dtype=np.float32)
    memory_out = np.array([[2.5]], dtype=np.float32)
    value_out = np.array([1.1], dtype=np.float32)
    eval_output = {
        "action": action_out,
        "memory_out": memory_out,
        "value": value_out,
    }
    # Every call to evaluate hands back this exact dict object.
    tf_policy.evaluate = MagicMock(return_value=eval_output)

    info = BrainInfo([], [], [], agents=["an-agent-id"], local_done=[False])
    actual = tf_policy.get_action(info)

    wanted = ActionInfo(eval_output["action"], eval_output["value"], eval_output)
    assert actual == wanted
Exemplo n.º 4
0
def test_take_action_returns_action_info_when_available():
    """Check that get_action wraps the evaluate output in an ActionInfo.

    ``evaluate`` is mocked to return a dict with "action", "memory_out" and
    "value" arrays, and the policy is queried with a BatchedStepResult. The
    result must equal an ActionInfo built from the "action" entry, the
    "value" entry, the whole dict and ``[0]``.
    """
    test_seed = 3
    policy = TFPolicy(test_seed, basic_mock_brain(), basic_params())
    policy_eval_out = {
        "action": np.array([1.0], dtype=np.float32),
        "memory_out": np.array([[2.5]], dtype=np.float32),
        "value": np.array([1.1], dtype=np.float32),
    }
    policy.evaluate = MagicMock(return_value=policy_eval_out)
    # NOTE(review): positional fields are presumably obs, reward, done,
    # max_step, agent_id, action_mask -- confirm against BatchedStepResult.
    # The builtin ``bool`` is used as dtype because the ``np.bool`` alias was
    # deprecated in NumPy 1.20 and removed in NumPy 1.24.
    step_with_agents = BatchedStepResult(
        [],
        np.array([], dtype=np.float32),
        np.array([False], dtype=bool),
        np.array([], dtype=bool),
        np.array([0]),
        None,
    )
    result = policy.get_action(step_with_agents)
    expected = ActionInfo(
        policy_eval_out["action"], policy_eval_out["value"], policy_eval_out, [0]
    )
    assert result == expected