    def test_update_q_dqn(self,
                          mock_parameters,
                          mock_replay_memory):
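        """DQN target computation: with _q returning [0.2, 0.1] and
        _target_q returning [0.3, 0.4], only the stored action's Q-value
        should be replaced by reward + gamma * max(q_target)."""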
        self._setup_parameters(mock_parameters.return_value)
        self._setup_replay_memory(mock_replay_memory.return_value)

        action_space = spaces.Discrete(2)
        observation_space = spaces.Box(0, 1, (1,))
        sut = QLearning('', observation_space, action_space)

        sut._q.eval = \
            MagicMock(return_value=np.array([[[0.2, 0.1]]], np.float32))
        sut._target_q.eval = \
            MagicMock(return_value=np.array([[[0.3, 0.4]]], np.float32))
        sut._trainer = MagicMock()

        sut._update_q_periodically()

        np.testing.assert_array_equal(
            sut._trainer.train_minibatch.call_args[0][0][sut._input_variables],
            [np.array([0.1], np.float32)])
        # 10 (reward) + 0.9 (gamma) x 0.4 (max q_target) -> update action 0
        np.testing.assert_array_equal(
            sut._trainer.train_minibatch.call_args[0][0][sut._output_variables],
            [np.array([10.36, 0.1], np.float32)])

    def test_update_q_dqn_prioritized_replay(self,
                                             mock_parameters,
                                             mock_replay_memory):
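        """Prioritized replay: the minibatch carries one state per sampled
        transition plus importance-sampling weights, and the squared TD
        errors are fed back via update_priority."""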
        self._setup_parameters(mock_parameters.return_value)
        mock_parameters.return_value.use_prioritized_replay = True
        self._setup_prioritized_replay_memory(mock_replay_memory.return_value)

        action_space = spaces.Discrete(2)
        observation_space = spaces.Box(0, 1, (1,))
        sut = QLearning('', observation_space, action_space)

        def new_q_value(*args, **kwargs):
            # side_effect (not return_value) returns a fresh array per eval
            # call, so every sampled transition starts from an unmodified
            # [0.2, 0.1] even if the update mutates the array in place.
            return np.array([[[0.2, 0.1]]], np.float32)
        sut._q.eval = MagicMock(side_effect=new_q_value)
        sut._target_q.eval = MagicMock(
            return_value=np.array([[[0.3, 0.4]]], np.float32))
        sut._trainer = MagicMock()

        sut._update_q_periodically()

        self.assertEqual(sut._trainer.train_minibatch.call_count, 1)
        np.testing.assert_array_equal(
            sut._trainer.train_minibatch.call_args[0][0][sut._input_variables],
            [
                np.array([0.1], np.float32),
                np.array([0.3], np.float32),
                np.array([0.1], np.float32)
            ])
        np.testing.assert_array_equal(
            sut._trainer.train_minibatch.call_args[0][0][sut._output_variables],
            [
                # 10 (reward) + 0.9 (gamma) x 0.4 (max q_target)
                np.array([10.36, 0.1], np.float32),
                # 11 (reward) + 0.9 (gamma) x 0.4 (max q_target)
                np.array([0.2, 11.36], np.float32),
                np.array([10.36, 0.1], np.float32)
            ])
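        # Importance-sampling weights; [1/6, 2/3, 1/6] is consistent with
        # inverse priorities in a 4:1:4 ratio, normalized to sum to 1
        # (presumably fixed by _setup_prioritized_replay_memory above).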
        np.testing.assert_almost_equal(
            sut._trainer.train_minibatch.call_args[0][0][sut._weight_variables],
            [
                [0.16666667],
                [0.66666667],
                [0.16666667]
            ])
        self.assertAlmostEqual(
            sut._replay_memory.update_priority.call_args[0][0][3],
            105.2676)  # (10.16 [td error] + 0.1)^2
        self.assertAlmostEqual(
            sut._replay_memory.update_priority.call_args[0][0][4],
            129.0496,
            places=6)  # (11.26 [td error] + 0.1)^2

    def test_replay_start_size(self, mock_parameters):
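        """No training and forced RANDOM actions until the replay memory
        has collected replay_start_size samples."""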
        self._setup_parameters(mock_parameters.return_value)
        # Exploration rate pinned to 0: any RANDOM action below can only
        # come from the replay_start_size gate, not from epsilon-greedy.
        mock_parameters.return_value.initial_epsilon = 0
        mock_parameters.return_value.epsilon_decay_step_count = 100
        mock_parameters.return_value.epsilon_minimum = 0
        mock_parameters.return_value.replay_start_size = 3

        action_space = spaces.Discrete(2)
        observation_space = spaces.Box(0, 1, (1,))
        sut = QLearning('', observation_space, action_space)
        sut._trainer = MagicMock()
        sut._replay_memory = MagicMock()

        # Fewer than replay_start_size samples stored: action is RANDOM
        # and no minibatch is trained.
        _, debug = sut.start(np.array([0.1], np.float32))
        self.assertEqual(sut.step_count, 0)
        self.assertEqual(sut._trainer.train_minibatch.call_count, 0)
        self.assertEqual(debug['action_behavior'], 'RANDOM')

        _, debug = sut.step(0.1, np.array([0.2], np.float32))
        self.assertEqual(sut.step_count, 1)
        self.assertEqual(sut._trainer.train_minibatch.call_count, 0)
        self.assertEqual(debug['action_behavior'], 'RANDOM')

        sut.end(0.2, np.array([0.3], np.float32))
        self.assertEqual(sut.step_count, 2)
        self.assertEqual(sut._trainer.train_minibatch.call_count, 0)

        _, debug = sut.start(np.array([0.4], np.float32))
        self.assertEqual(sut.step_count, 2)
        self.assertEqual(sut._trainer.train_minibatch.call_count, 0)
        self.assertEqual(debug['action_behavior'], 'RANDOM')

        # step_count reaches replay_start_size: action selection switches
        # to GREEDY, though no training has happened yet.
        a, debug = sut.step(0.3, np.array([0.5], np.float32))
        self.assertEqual(sut.step_count, 3)
        self.assertEqual(sut._trainer.train_minibatch.call_count, 0)
        self.assertEqual(debug['action_behavior'], 'GREEDY')

        a, debug = sut.start(np.array([0.6], np.float32))
        self.assertEqual(sut.step_count, 3)
        self.assertEqual(sut._trainer.train_minibatch.call_count, 0)
        self.assertEqual(debug['action_behavior'], 'GREEDY')

        # With enough samples in memory, the first minibatch is trained.
        a, debug = sut.step(0.4, np.array([0.7], np.float32))
        self.assertEqual(sut.step_count, 4)
        self.assertEqual(sut._trainer.train_minibatch.call_count, 1)
        self.assertEqual(debug['action_behavior'], 'GREEDY')