示例#1
0
    def test_update_q(self, mock_parameters, mock_replay_memory):
        """Run one episode (start, two steps, end) and check the agent state.

        Checks that a full episode, including the periodic Q update
        (``_update_q_periodically()``), finishes successfully. After every
        call it verifies:

        - the returned action and its ``action_behavior`` tag;
        - the episode and step counters;
        - epsilon and the learning rate of the first parameter learner;
        - the cached last state and last action.

        ``train_minibatch`` is replaced by a mock, so no real training runs.
        """
        self._setup_parameters(mock_parameters.return_value)
        self._setup_replay_memory(mock_replay_memory.return_value)

        action_space = spaces.Discrete(2)
        observation_space = spaces.Box(0, 1, (1, ))
        sut = QLearning('', observation_space, action_space)
        sut._trainer.train_minibatch = MagicMock()
        sut._choose_action = MagicMock(side_effect=[
            (1, 'GREEDY'),
            (0, 'GREEDY'),
            (1, 'RANDOM'),
        ])

        def learning_rate():
            # Current learning rate of the trainer's first parameter learner.
            return sut._trainer.parameter_learners[0].learning_rate()

        # Schedule values (epsilon, learning rate) are floats, so they are
        # compared with a tolerance. Observations are ndarrays, so they are
        # compared elementwise rather than through assertEqual, which relies
        # on the truth value of an array comparison.
        action, debug_info = sut.start(np.array([0.1], np.float32))
        self.assertEqual(action, 1)
        self.assertEqual(debug_info['action_behavior'], 'GREEDY')
        self.assertEqual(sut.episode_count, 1)
        self.assertEqual(sut.step_count, 0)
        self.assertAlmostEqual(sut._epsilon, 0.1)
        self.assertAlmostEqual(learning_rate(), 0.1)
        np.testing.assert_array_equal(sut._last_state,
                                      np.array([0.1], np.float32))
        self.assertEqual(sut._last_action, 1)

        action, debug_info = sut.step(1, np.array([0.2], np.float32))
        self.assertEqual(action, 0)
        self.assertEqual(debug_info['action_behavior'], 'GREEDY')
        self.assertEqual(sut.episode_count, 1)
        self.assertEqual(sut.step_count, 1)
        self.assertAlmostEqual(sut._epsilon, 0.09)
        # learning rate remains 0.1 as Q is not updated during this time step.
        self.assertAlmostEqual(learning_rate(), 0.1)
        np.testing.assert_array_equal(sut._last_state,
                                      np.array([0.2], np.float32))
        self.assertEqual(sut._last_action, 0)

        action, debug_info = sut.step(2, np.array([0.3], np.float32))
        self.assertEqual(action, 1)
        self.assertEqual(debug_info['action_behavior'], 'RANDOM')
        self.assertEqual(sut.episode_count, 1)
        self.assertEqual(sut.step_count, 2)
        self.assertAlmostEqual(sut._epsilon, 0.08)
        self.assertAlmostEqual(learning_rate(), 0.08)
        np.testing.assert_array_equal(sut._last_state,
                                      np.array([0.3], np.float32))
        self.assertEqual(sut._last_action, 1)

        sut.end(3, np.array([0.4], np.float32))
        self.assertEqual(sut.episode_count, 1)
        self.assertEqual(sut.step_count, 3)
        self.assertAlmostEqual(sut._epsilon, 0.08)
        # learning rate remains 0.08 as Q is not updated during this time step.
        self.assertAlmostEqual(learning_rate(), 0.08)
示例#2
0
    def test_update_q(self,
                      mock_parameters,
                      mock_replay_memory):
        """Run one episode (start, two steps, end) and check the agent state.

        Checks that a full episode, including the periodic Q update
        (``_update_q_periodically()``), finishes successfully. After every
        call it verifies:

        - the returned action and its ``action_behavior`` tag;
        - the episode and step counters;
        - epsilon and the learning rate of the first parameter learner;
        - the cached last state and last action.

        ``train_minibatch`` is replaced by a mock, so no real training runs.
        """
        self._setup_parameters(mock_parameters.return_value)
        self._setup_replay_memory(mock_replay_memory.return_value)

        action_space = spaces.Discrete(2)
        observation_space = spaces.Box(0, 1, (1,))
        sut = QLearning('', observation_space, action_space)
        sut._trainer.train_minibatch = MagicMock()
        sut._choose_action = MagicMock(side_effect=[
            (1, 'GREEDY'),
            (0, 'GREEDY'),
            (1, 'RANDOM'),
        ])

        def learning_rate():
            # Current learning rate of the trainer's first parameter learner.
            return sut._trainer.parameter_learners[0].learning_rate()

        # Schedule values (epsilon, learning rate) are floats, so they are
        # compared with a tolerance. Observations are ndarrays, so they are
        # compared elementwise rather than through assertEqual, which relies
        # on the truth value of an array comparison.
        action, debug_info = sut.start(np.array([0.1], np.float32))
        self.assertEqual(action, 1)
        self.assertEqual(debug_info['action_behavior'], 'GREEDY')
        self.assertEqual(sut.episode_count, 1)
        self.assertEqual(sut.step_count, 0)
        self.assertAlmostEqual(sut._epsilon, 0.1)
        self.assertAlmostEqual(learning_rate(), 0.1)
        np.testing.assert_array_equal(sut._last_state,
                                      np.array([0.1], np.float32))
        self.assertEqual(sut._last_action, 1)

        action, debug_info = sut.step(1, np.array([0.2], np.float32))
        self.assertEqual(action, 0)
        self.assertEqual(debug_info['action_behavior'], 'GREEDY')
        self.assertEqual(sut.episode_count, 1)
        self.assertEqual(sut.step_count, 1)
        self.assertAlmostEqual(sut._epsilon, 0.09)
        # learning rate remains 0.1 as Q is not updated during this time step.
        self.assertAlmostEqual(learning_rate(), 0.1)
        np.testing.assert_array_equal(sut._last_state,
                                      np.array([0.2], np.float32))
        self.assertEqual(sut._last_action, 0)

        action, debug_info = sut.step(2, np.array([0.3], np.float32))
        self.assertEqual(action, 1)
        self.assertEqual(debug_info['action_behavior'], 'RANDOM')
        self.assertEqual(sut.episode_count, 1)
        self.assertEqual(sut.step_count, 2)
        self.assertAlmostEqual(sut._epsilon, 0.08)
        self.assertAlmostEqual(learning_rate(), 0.08)
        np.testing.assert_array_equal(sut._last_state,
                                      np.array([0.3], np.float32))
        self.assertEqual(sut._last_action, 1)

        sut.end(3, np.array([0.4], np.float32))
        self.assertEqual(sut.episode_count, 1)
        self.assertEqual(sut.step_count, 3)
        self.assertAlmostEqual(sut._epsilon, 0.08)
        # learning rate remains 0.08 as Q is not updated during this time step.
        self.assertAlmostEqual(learning_rate(), 0.08)

    def test_populate_replay_memory(self, mock_parameters):
        """Check every transition stored in replay memory during one episode.

        Preprocessing is a 2-step sliding window, so each stored state stacks
        the last two observations; the first window is zero-padded. Each
        ``store()`` call receives, positionally:

        (state, action, reward, next_state, priority)

        The terminal transition recorded by ``end()`` has ``next_state`` set
        to None.
        """
        self._setup_parameters(mock_parameters.return_value)
        mock_parameters.return_value.preprocessing = \
            'cntk.contrib.deeprl.agent.shared.preprocessing.SlidingWindow'
        mock_parameters.return_value.preprocessing_args = '(2, )'

        action_space = spaces.Discrete(2)
        observation_space = spaces.Box(0, 1, (1,))
        sut = QLearning('', observation_space, action_space)

        sut._compute_priority = Mock(side_effect=[1, 2, 3])
        sut._choose_action = Mock(
            side_effect=[(0, ''), (0, ''), (1, ''), (1, '')])
        sut._replay_memory = MagicMock()
        sut._update_q_periodically = MagicMock()

        sut.start(np.array([0.1], np.float32))
        sut.step(0.1, np.array([0.2], np.float32))
        sut.step(0.2, np.array([0.3], np.float32))
        sut.end(0.3, np.array([0.4], np.float32))

        self.assertEqual(sut._replay_memory.store.call_count, 3)
        store_calls = sut._replay_memory.store.call_args_list

        def assert_stored(index, state, action, reward, next_state, priority):
            """Assert the positional arguments of the index-th store() call."""
            args = store_calls[index][0]
            np.testing.assert_array_equal(args[0], state)
            self.assertEqual(args[1], action)
            self.assertEqual(args[2], reward)
            if next_state is None:
                self.assertIsNone(args[3])
            else:
                np.testing.assert_array_equal(args[3], next_state)
            self.assertEqual(args[4], priority)

        # First transition: action chosen by start(), reward from the first
        # step().
        assert_stored(0,
                      np.array([[0], [0.1]], np.float32), 0, 0.1,
                      np.array([[0.1], [0.2]], np.float32), 1)
        # Middle transition: action chosen by the first step().
        assert_stored(1,
                      np.array([[0.1], [0.2]], np.float32), 0, 0.2,
                      np.array([[0.2], [0.3]], np.float32), 2)
        # Terminal transition recorded by end(): there is no next state.
        assert_stored(2,
                      np.array([[0.2], [0.3]], np.float32), 1, 0.3,
                      None, 3)