Code Example #1
    def test_flexibility(self):
        """
        Test that the model is flexible enough to
        represent our expected result
        """

        ant_states = [AntState(position) for position in range(10)]
        ant_states = StateList(ant_states)
        # Finding Home Targets
        targets0 = np.array([
            [7, 8, 0, 10, 9, 8, 7, 6, 5, 4],  # Left
            [9, 10, 0, 8, 7, 6, 5, 4, 3, 2],  # Right
        ]).T
        # Finding Food Targets
        targets1 = np.array([
            [1, 2, 3, 4, 5, 6, 7, 8, 0, 10],  # Left
            [3, 4, 5, 6, 7, 8, 9, 10, 0, 9],  # Right
        ]).T
        targets = [targets0, targets1]

        value_function = AntActionValueFunction()
        value_function.vectorized_fit(ant_states, targets, epochs=200)

        score = value_function.evaluate(ant_states, targets)
        self.assertLess(score[0], 1.0)
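The test fits against two (10, 2) target arrays, one per internal state, with column 0 holding the Left-action values and column 1 the Right-action values. The StateList wrapper itself is not shown in these snippets; since vectorized_fit in code example #7 only calls states.as_array(), a rough, hypothetical sketch of the interface it relies on is:

import numpy as np

class StateListSketch:
    """Hypothetical stand-in for rl.core.state.StateList."""

    def __init__(self, states):
        self.states = list(states)

    def as_array(self):
        # Stack each state's vector form into one 2-D array,
        # one row per state.
        return np.vstack([s.as_array() for s in self.states])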
Code Example #2
    def test___call__(self):
        state = AntState()
        int_ext_state = IntExtState(0, state)
        value_function = AntActionValueFunction()

        value = value_function(int_ext_state)

        self.assertIsInstance(value, np.ndarray)
        self.assertEqual(2, len(value))
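The contract this test pins down is small: calling the value function with an IntExtState yields a length-2 ndarray, one value per action. As a minimal, hypothetical numpy stand-in (the real AntActionValueFunction wraps a trainable model, as code example #7 shows; the state size of 11 is inferred from the reshape((1, 11)) there):

import numpy as np

class ValueFunctionSketch:
    """Hypothetical stand-in for AntActionValueFunction."""

    def __init__(self, state_size=11, num_actions=2):
        # One small random weight matrix per internal state.
        self.weights = [np.random.randn(state_size, num_actions) * 0.1
                        for _ in range(2)]

    def __call__(self, int_ext_state):
        x = int_ext_state.external_state.as_array().ravel()
        w = self.weights[int_ext_state.internal_state]
        return x @ w  # ndarray of length 2, one value per action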
Code Example #3
    def calculate_action_values(self):
        """Return one (10, 2) array of action values per internal state."""
        result = []
        for internal_state in (0, 1):
            action_values = []
            for _position in range(10):
                _external_state = AntState(position=_position)
                _state = IntExtState(internal_state, _external_state)
                avs = self.action_value_function(_state)
                action_values.append(avs)
            # Stack the ten length-2 value arrays into a (10, 2) array.
            result.append(np.r_[action_values])
        return result
Code Example #4
def calculate_greedy_actions(ant_world):
    """Render the greedy action for every (internal state, position) cell."""
    rows = []
    # Cells with no defined greedy action (the home and food positions
    # for the respective internal state); these are rendered as 'x'.
    null_cells = [(0, 2), (1, 8)]
    for internal_state in (0, 1):
        row = []
        for _position in range(10):
            _external_state = AntState(position=_position)
            _state = IntExtState(internal_state, _external_state)
            _action = ant_world.choose_action(_state)
            if (internal_state, _position) in null_cells:
                row.append('x')
            else:
                row.append(ACTION_STRINGS[_action])
        rows.append(''.join(row))

    row0 = 'FINDING HOME : %s' % rows[0]
    row1 = 'FINDING FOOD : %s' % rows[1]
    return row0 + '\n' + row1
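ACTION_STRINGS comes from elsewhere in the package. Judging from the expected '>>x<<<<<<<' / '>>>>>>>>x<' strings in code example #5, the mapping is presumably along these lines; the numeric constant values below are guesses, not the library's definitions:

MOVE_LEFT, MOVE_RIGHT = 0, 1  # hypothetical values
ACTION_STRINGS = {MOVE_LEFT: '<', MOVE_RIGHT: '>'}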
Code Example #5
    def test_something(self):

        world = AntWorld()

        learner = build_learner(world,
                                calculator_type='modelbasedstatemachine')

        states_list = StateList(
            [AntState(position=position) for position in range(10)])

        for _ in range(1000):
            learner.learn(states_list, epochs=1)

        greedy_actions = calculate_greedy_actions(world)

        print(greedy_actions)
        expected_greedy_actions = r"""
FINDING HOME : >>x<<<<<<<
FINDING FOOD : >>>>>>>>x<
""".strip()
        self.assertEqual(expected_greedy_actions, greedy_actions)

        action_values = world.calculate_action_values()
        print(action_values)

        expected_action_values = [
            np.array([[8, 9], [8, 10], [np.nan, np.nan], [10, 8], [9, 7],
                      [8, 6], [7, 5], [6, 5], [5, 4], [4, 3]]),
            np.array([[6, 8], [7, 9], [8, 10], [9, 11], [10, 12], [11, 13],
                      [12, 14], [13, 15], [np.nan, np.nan], [15, 14]]),
        ]

        e0 = expected_action_values[0]
        r0 = action_values[0]
        r0[np.isnan(e0)] = np.nan
        np.testing.assert_array_almost_equal(e0, r0, decimal=0)

        e1 = expected_action_values[1]
        r1 = action_values[1]
        r1[np.isnan(e1)] = np.nan
        np.testing.assert_array_almost_equal(e1, r1, decimal=0)
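The NaN masking in the assertions deserves a note: the result arrays are stamped with the expected arrays' NaNs so that the cells with no defined action value (the home and food positions) drop out of the numeric comparison; np.testing.assert_array_almost_equal treats NaNs in matching positions as equal. A standalone illustration:

import numpy as np

expected = np.array([[1.0, np.nan], [2.0, 3.0]])
result = np.array([[1.2, 9.9], [2.1, 3.0]])

# Copy the NaN mask from expected into result so the undefined
# cell no longer takes part in the comparison.
result[np.isnan(expected)] = np.nan
np.testing.assert_array_almost_equal(expected, result, decimal=0)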
Code Example #6
from rl.core.learning.learner import build_learner
from rl.core.state import StateList
from rl.environments.line_world.rl_system import AntWorld
from rl.environments.line_world.rl_system import calculate_greedy_actions
from rl.environments.line_world.state import AntState
from rl.environments.line_world.constants import HOME_POSITION, FOOD_POSITION

world = AntWorld()
learner = build_learner(world, calculator_type='modelbasedstatemachine')

states = [AntState(position=position) for position in range(10)]
states_list = StateList(states)

initial_greedy_actions = calculate_greedy_actions(world)

for _ in range(500):
    learner.learn(states_list, epochs=1)

greedy_actions = calculate_greedy_actions(world)

print('Initial Greedy Actions (should be random):')
print(initial_greedy_actions)

print(
    'Optimised Greedy Actions (should point at home(%s) and food(%s) positions):'
    % (HOME_POSITION, FOOD_POSITION))
print(greedy_actions)

action_values = world.calculate_action_values()

print('Home is at position %s' % HOME_POSITION)
print(action_values)
Code Example #7
        self.model = model

    def evaluate(self, states, targets, **kwargs):
        return self.model.evaluate(states.as_array(), targets, **kwargs)

    def vectorized_fit(self, states, targets, **kwargs):
        x = states.as_array()
        return self.model.fit(x, targets, **kwargs)

    def scalar_fit(self, states, actions, rewards, **kwargs):
        pass


if __name__ == '__main__':
    from rl.environments.line_world.state import AntState
    from rl.core.state import IntExtState
    value_function = AntActionValueFunction()

    state = IntExtState(0, AntState(position=1))

    print(value_function(state))

    state = IntExtState(1, AntState(position=1))

    print(value_function(state))

    # Query the wrapped model directly with the raw state vector
    # (an 11-element array, reshaped to a single-row batch).
    print(value_function.model.predict(
        state.external_state.as_array().reshape((1, 11))))
Code Example #8
            new_state.external_state.num_homecomings += 1

        return new_state

    def apply_movement(self, state, action):
        # Copy the state so the caller's state is untouched, then move,
        # clamping the position to the ends of the ten-cell line.
        new_state = IntExtState(state.internal_state,
                                state.external_state.copy())
        external_state = state.external_state

        if action == MOVE_RIGHT:
            new_state.external_state.position = min(
                9, external_state.position + 1)
        elif action == MOVE_LEFT:
            new_state.external_state.position = max(
                0, external_state.position - 1)

        return new_state

    def is_terminal(self, state):
        return state.external_state.num_homecomings >= self.max_homecomings


if __name__ == '__main__':
    from rl.environments.line_world.state import AntState

    model = AntModel()
    ant_state = AntState()

    print(ant_state)
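
Continuing the __main__ block, a quick check of the clamping in apply_movement (assuming MOVE_LEFT and MOVE_RIGHT are defined in this module, as the method body implies):

    from rl.core.state import IntExtState

    state = IntExtState(0, AntState(position=9))
    moved = model.apply_movement(state, MOVE_RIGHT)
    print(moved.external_state.position)  # still 9: clamped at the right edge

    state = IntExtState(0, AntState(position=0))
    moved = model.apply_movement(state, MOVE_LEFT)
    print(moved.external_state.position)  # still 0: clamped at the left edge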