Ejemplo n.º 1
0
def test_continuous_action_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:
    """Train the curiosity provider repeatedly and check its action prediction.

    Both NumPy and torch RNGs are seeded from ``seed``. The provider is then
    updated 200 times on one fixed agent buffer. The test asserts that the
    network's ``predict_action`` output for the first entry has a mean squared
    error below 1e-3 against ``buffer["actions"][0]``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    # NOTE(review): positional args (32, 0.1) map onto CuriositySettings' field
    # order, which is not visible here — confirm which fields they set.
    settings = CuriositySettings(32, 0.1)
    provider = CuriosityRewardProvider(behavior_spec, settings)
    agent_buffer = create_agent_buffer(behavior_spec, 5)

    num_updates = 200
    for _step in range(num_updates):
        provider.update(agent_buffer)

    predicted = provider._network.predict_action(agent_buffer)[0]
    expected = torch.tensor(agent_buffer["actions"][0])
    squared_diff = (predicted - expected) ** 2
    mse = torch.mean(squared_diff).item()

    tolerance = 0.001
    assert mse < tolerance
Ejemplo n.º 2
0
def test_next_state_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:
    """Train the curiosity provider repeatedly and check its next-state prediction.

    After seeding NumPy and torch with ``seed``, the provider is updated 100
    times on a fixed agent buffer. The test asserts that, for the first entry,
    ``predict_next_state`` matches ``get_next_state`` with a mean squared
    error below 1e-3.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    # NOTE(review): positional settings args; their field mapping is defined
    # outside this file — confirm.
    settings = CuriositySettings(32, 0.1)
    provider = CuriosityRewardProvider(behavior_spec, settings)
    agent_buffer = create_agent_buffer(behavior_spec, 5)

    num_updates = 100
    for _step in range(num_updates):
        provider.update(agent_buffer)

    network = provider._network
    predicted_state = network.predict_next_state(agent_buffer)[0]
    actual_state = network.get_next_state(agent_buffer)[0]
    mse_tensor = torch.mean((predicted_state - actual_state) ** 2)
    mse = float(ModelUtils.to_numpy(mse_tensor))

    tolerance = 0.001
    assert mse < tolerance
Ejemplo n.º 3
0
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:
    """Check that further training on a fixed buffer lowers the curiosity reward.

    A baseline reward for the first buffer entry is taken after a single
    update. After 20 more updates on the same buffer, the reward for that
    entry must be strictly smaller than the baseline.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    # NOTE(review): positional settings args; their field mapping is defined
    # outside this file — confirm.
    curiosity_settings = CuriositySettings(32, 0.01)
    curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
    buffer = create_agent_buffer(behavior_spec, 5)

    # The baseline is measured after one update rather than before any.
    # NOTE(review): presumably the provider reports no reward until it has
    # been updated at least once — confirm against CuriosityRewardProvider.
    curiosity_rp.update(buffer)
    reward_old = curiosity_rp.evaluate(buffer)[0]

    for _ in range(20):
        curiosity_rp.update(buffer)

    # Only the post-training reward is compared, so evaluate once after the
    # loop instead of on every iteration. This also guarantees reward_new is
    # bound even if the iteration count is changed to 0.
    reward_new = curiosity_rp.evaluate(buffer)[0]
    assert reward_new < reward_old
Ejemplo n.º 4
0
def test_construction(behavior_spec: BehaviorSpec) -> None:
    """Check that a freshly built curiosity provider reflects its settings.

    The test overrides ``strength`` on the settings object before constructing
    the provider. It then verifies that the provider exposes that strength and
    reports the name ``"Curiosity"``.
    """
    expected_strength = 0.1
    settings = CuriositySettings(32, 0.01)
    settings.strength = expected_strength
    provider = CuriosityRewardProvider(behavior_spec, settings)
    assert provider.strength == expected_strength
    assert provider.name == "Curiosity"