def test_epsilon_greedy(self):
        # With epsilon=0 the policy must act greedily: it has to return the
        # highest-valued action among those offered by the action space.

        def greedy_pick(state, *actions):
            # Each query gets its own freshly built action space.
            return self.policy.epsilon_greedy(state,
                                              MockActionSpace(set(actions)),
                                              epsilon=0)

        all_actions = (self.mock_action1, self.mock_action2,
                       self.mock_action3)
        for action, value in zip(all_actions, (1., 2., 3.)):
            self.action_value[self.mock_state1, action] = value

        # action3 carries the largest value (3.).
        self.assertEqual(greedy_pick(self.mock_state1, *all_actions),
                         MockAction([5, 6, 7]))

        # Raise action2 to 10. through a separately constructed state object;
        # the assertion expects the change to be visible via mock_state1.
        self.action_value[MockState([1, 2, 3, 4, 5]), self.mock_action2] = 10.
        self.assertEqual(greedy_pick(self.mock_state1, *all_actions),
                         MockAction([2, 3, 4]))

        # A single-action space leaves only one possible answer, including for
        # mock_state2, for which this test records no values at all.
        self.assertEqual(greedy_pick(self.mock_state1, self.mock_action1),
                         MockAction([1, 2, 3]))
        self.assertEqual(greedy_pick(self.mock_state2, self.mock_action3),
                         MockAction([5, 6, 7]))
# Example #2 (rating: 0)
    def setUp(self):
        """Create a LinearRegressionActionValue (init_weights_zeros=True)
        together with one 3x3-array state fixture and two action fixtures."""
        self.action_value = LinearRegressionActionValue(
            init_weights_zeros=True,
        )

        board = np.array([[1, 2, 3],
                          [4, 5, 6],
                          [7, 8, 9]])
        self.mock_state1 = MockState(board)

        self.mock_action1, self.mock_action2 = (
            MockAction(1, 2),
            MockAction(5, 6),
        )
# Example #3 (rating: 0)
    def setUp(self):
        """Create an empty LazyTabularActionValue plus state/action fixtures.

        ``mock_state1_copy`` is a separate object built from the same values
        as ``mock_state1``.
        """
        self.action_value = LazyTabularActionValue()

        self.mock_state1 = MockState(np.array([1, 2, 3, 4, 5]))
        self.mock_state2 = MockState(np.array([6, 5, 4, 3]))

        self.mock_action1, self.mock_action2, self.mock_action3 = (
            MockAction(first, second)
            for first, second in ((1, 2), (5, 6), (7, 8))
        )

        self.mock_state1_copy = MockState(np.array([1, 2, 3, 4, 5]))

    def setUp(self):
        """Build an ActionValueDerivedPolicy on top of a fresh
        LazyTabularActionValue, plus list-based state/action fixtures."""
        self.action_value = LazyTabularActionValue()
        self.policy = ActionValueDerivedPolicy(self.action_value)

        self.mock_state1, self.mock_state2 = (
            MockState([1, 2, 3, 4, 5]),
            MockState([6, 5, 4, 3]),
        )
        (self.mock_action1, self.mock_action2,
         self.mock_action3, self.mock_action4) = (
            MockAction(components)
            for components in ([1, 2, 3], [2, 3, 4], [5, 6, 7], [5, 6, 8])
        )

        # Distinct object holding the same contents as mock_state1.
        self.mock_state1_copy = MockState([1, 2, 3, 4, 5])
# Example #5 (rating: 0)
        return graph

    def _repr_svg_(self):
        """IPython rich-display hook: return the SVG of the built graph.

        Delegates to ``_repr_svg_`` on whatever ``_get_graph()`` returns
        (presumably a graphviz graph object).
        """
        graph = self._get_graph()
        return graph._repr_svg_()

    def view(self):
        """Delegate to ``view()`` on the object returned by ``_get_graph()``.

        Returns whatever that underlying ``view()`` call returns.
        """
        graph = self._get_graph()
        return graph.view()


if __name__ == '__main__':
    # Model test: manual smoke test that feeds a handful of transitions into
    # a StochasticModel.

    m = StochasticModel()

    s1 = MockState([[-1, 0], [-1, -1]])
    a1 = MockAction([1, 0])
    s2 = MockState([[-1, 0], [0, -1]])
    s3 = MockState([[-1, 0], [-1, 0]])
    # The (s1, a1) key is assigned five times: once with s2, then four times
    # with s3. NOTE(review): presumably StochasticModel records every
    # assignment as an observed transition (so s3 ends up ~4x as likely as
    # s2) rather than overwriting - confirm against its __setitem__.
    m[s1, a1] = s2
    m[s1, a1] = s3
    m[s1, a1] = s3
    m[s1, a1] = s3
    m[s1, a1] = s3

    a2 = MockAction([0, 0])
    s4 = MockState([[0, 0], [0, -1]])
    m[s2, a2] = s4

    # NOTE(review): there is no a3, and a4, s5 and s6 are never used in the
    # code visible here; the script appears truncated. s6 has the same grid
    # as s3.
    a4 = MockAction([1, 1])
    s5 = MockState([[-1, 0], [0, 0]])
    s6 = MockState([[-1, 0], [-1, 0]])
# Example #6 (rating: 0)
                           str(hash(greedy_action)),
                           label=str(greedy_action_value),
                           color="red",
                           penwidth="2")
        return graph

    def _repr_svg_(self):
        """IPython rich-display hook: return the SVG of the built graph.

        Delegates to ``_repr_svg_`` on whatever ``_get_graph()`` returns
        (presumably a graphviz graph object).
        """
        graph = self._get_graph()
        return graph._repr_svg_()

    def view(self):
        """Delegate to ``view()`` on the object returned by ``_get_graph()``.

        Returns whatever that underlying ``view()`` call returns.
        """
        graph = self._get_graph()
        return graph.view()


if __name__ == '__main__':
    # Small demo: give one state three actions with hand-picked values, then
    # print the policy derived from that action-value table.
    av = LazyTabularActionValue()
    state = MockState([[-1, -1], [-1, 1]])

    demo_values = (
        (MockAction([0, 0]), 6),
        (MockAction([0, 1]), 2.9),
        (MockAction([1, 0]), -10),
    )
    for action, value in demo_values:
        av[state, action] = value

    policy = ActionValueDerivedPolicy(av)
    print(policy)
    # Optional: call ``policy.view()`` for a graphical rendering.