# Example #1
# 0
def test_higher_level_step():
    """Check that a higher-level environment step keeps the HIRO agent in sync.

    After ``h_env.step(goal)`` the agent's cached higher-level state, next
    state, reward and done flag must equal the values returned by the step.
    The agent's goal must match ``HIRO.goal_transition`` applied to the
    pre-step state, the requested goal and the resulting state. The same
    bookkeeping invariants are then re-checked over 200 further steps.
    """
    hiro_agent = HIRO(config)
    h_env = hiro_agent.higher_level_agent.environment
    # Kept as a plain list so every step receives a fresh array; this avoids
    # any aliasing if the environment were to mutate the goal it is given.
    goal_values = [-1.0, 2.0, 3.0]

    h_env.reset()
    state_before = hiro_agent.higher_level_state
    # Nothing has been stepped yet, so no next state should be cached.
    assert hiro_agent.higher_level_next_state is None

    next_state, reward, done, _info = h_env.step(np.array(goal_values))

    assert np.allclose(
        hiro_agent.goal,
        HIRO.goal_transition(state_before, np.array(goal_values), next_state))

    assert np.array_equal(hiro_agent.higher_level_state, next_state)
    assert np.array_equal(hiro_agent.higher_level_next_state, next_state)
    assert hiro_agent.higher_level_reward == reward
    assert hiro_agent.higher_level_done == done

    assert next_state.shape[0] == 3
    assert isinstance(reward, float)
    assert not done

    for _step in range(200):
        next_state, reward, done, _info = h_env.step(np.array(goal_values))
        # Same invariants as after the first step (the original repeated the
        # next-state check twice; the first of the pair checks the state).
        assert np.array_equal(hiro_agent.higher_level_state, next_state)
        assert np.array_equal(hiro_agent.higher_level_next_state, next_state)
        assert hiro_agent.higher_level_reward == reward
        assert hiro_agent.higher_level_done == done
# Example #2
# 0
def test_goal_transition():
    """Check that HIRO goal transitions are computed and applied correctly.

    The test covers three cases:

    1. ``HIRO.goal_transition`` called directly on scalars.
    2. The same transition applied to the agent through
       ``ll_env.update_goal``.
    3. A real lower-level environment step, after which the agent's goal
       must equal ``state + goal - next_state[0:3]``.

    NOTE(review): ``hiro_agent``, ``ll_env`` and ``h_env`` are not defined
    in this function. They are presumably module-level objects defined
    elsewhere in this test file; confirm.
    """
    # Direct call: 2 + 9 - 3 == 8.
    hiro_agent.higher_level_state = 2
    hiro_agent.goal = 9
    next_state = 3
    assert HIRO.goal_transition(hiro_agent.higher_level_state, hiro_agent.goal,
                                next_state) == 8

    # Through the lower-level environment: update_goal must write the
    # transitioned goal back onto the agent.
    hiro_agent.higher_level_state = 2
    hiro_agent.goal = 9
    next_state = 3
    ll_env.update_goal(next_state)
    assert hiro_agent.goal == 8

    # Full lower-level step with a vector goal. A list is used so the goal
    # handed to the agent and the one used in the expectation are distinct
    # arrays, even if the environment updates the agent's goal in place.
    goal_values = [2.0, 4.0, -3.0]
    h_env.reset()
    hiro_agent.goal = np.array(goal_values)
    # NOTE(review): the reason for zeroing the reward is not visible here;
    # presumably ll_env.step accumulates into it. Confirm.
    hiro_agent.higher_level_reward = 0
    ll_env.reset()
    state = hiro_agent.higher_level_state
    next_state, reward, done, _info = ll_env.step(np.array([random.random()]))
    # Only the first three entries of the lower-level observation are
    # compared (presumably the higher-level part of the state; confirm).
    expected_goal = state + np.array(goal_values) - next_state[0:3]
    # Tolerance-based comparison, as in the higher-level step test: the
    # environment may order the float operations differently.
    assert np.allclose(hiro_agent.goal, expected_goal)