def test_update_policy_and_value_function(self):
        """Check the minibatch `_update_networks` hands to the trainer.

        The trainer and `_adjust_learning_rate` are mocked, so the test only
        inspects the first positional argument of the `train_minibatch` call
        and the length of `_input_buffer` afterwards.
        """
        discrete_actions = spaces.Discrete(2)
        unit_box = spaces.Box(0, 1, (1, ))
        sut = ActorCritic('', unit_box, discrete_actions)

        # Arrange: set up and process the trajectory, then mock out training.
        self._setup_trajectory(sut)
        sut._process_accumulated_trajectory(True)
        sut._trainer = MagicMock()
        sut._adjust_learning_rate = MagicMock()

        # Act.
        sut._update_networks()

        # Assert: the trainer was invoked exactly once.
        train_minibatch = sut._trainer.train_minibatch
        self.assertEqual(train_minibatch.call_count, 1)

        # First positional argument of that call, indexed by variable.
        fed = train_minibatch.call_args[0][0]

        expected_inputs = [np.array([x], np.float32) for x in (0.1, 0.2)]
        np.testing.assert_array_equal(
            fed[sut._input_variables], expected_inputs)

        expected_values = [[2.9975], [3.05]]
        np.testing.assert_array_almost_equal(
            fed[sut._value_network_output_variables], expected_values)

        expected_policy_outputs = [
            np.array(row, np.float32) for row in ([1, 0], [0, 1])]
        np.testing.assert_array_equal(
            fed[sut._policy_network_output_variables],
            expected_policy_outputs)

        expected_policy_weights = [[0.9975], [2.05]]
        np.testing.assert_array_almost_equal(
            fed[sut._policy_network_weight_variables],
            expected_policy_weights)

        # Nothing should remain in the input buffer.
        self.assertEqual(len(sut._input_buffer), 0)
# Example no. 2
# 0
    def test_update_policy_and_value_function(self):
        """Verify what `_update_networks` feeds to a mocked trainer.

        Expects exactly one `train_minibatch` call whose first positional
        argument maps the agent's input, value-output, policy-output and
        policy-weight variables to the values asserted below, and expects
        `_input_buffer` to be empty afterwards.
        """
        act_space = spaces.Discrete(2)
        obs_space = spaces.Box(0, 1, (1,))
        sut = ActorCritic('', obs_space, act_space)

        # Prepare the agent, then swap the training hooks for mocks.
        self._setup_trajectory(sut)
        sut._process_accumulated_trajectory(True)
        sut._trainer = MagicMock()
        sut._adjust_learning_rate = MagicMock()

        # Exercise the method under test.
        sut._update_networks()

        assert_same = np.testing.assert_array_equal
        assert_close = np.testing.assert_array_almost_equal

        # A single minibatch must have been trained on.
        mocked_train = sut._trainer.train_minibatch
        self.assertEqual(mocked_train.call_count, 1)
        batch = mocked_train.call_args[0][0]

        assert_same(
            batch[sut._input_variables],
            [np.array([0.1], np.float32), np.array([0.2], np.float32)])
        assert_close(
            batch[sut._value_network_output_variables],
            [[2.9975], [3.05]])
        assert_same(
            batch[sut._policy_network_output_variables],
            [np.array([1, 0], np.float32), np.array([0, 1], np.float32)])
        assert_close(
            batch[sut._policy_network_weight_variables],
            [[0.9975], [2.05]])

        # The input buffer must be empty after the update.
        buffered = len(sut._input_buffer)
        self.assertEqual(buffered, 0)
    def test_rollout_with_update(self, mock_parameters):
        """Walk through start/step/end with `update_frequency` set to 2.

        After each stage the trajectory rewards, actions and states are
        compared with expected lists, together with the number of calls made
        to the mocked `_update_networks`.

        `mock_parameters` is presumably injected by a patch decorator that
        is outside this block -- TODO confirm.
        """
        params = mock_parameters.return_value
        self._setup_parameters(params)
        params.update_frequency = 2

        discrete_actions = spaces.Discrete(2)
        unit_box = spaces.Box(0, 1, (1, ))
        sut = ActorCritic('', unit_box, discrete_actions)
        sut._update_networks = MagicMock()

        # Script the actions returned by successive `_choose_action` calls.
        scripted = [(action, '') for action in (0, 1, 1, 0, 1, 0)]
        sut._choose_action = Mock(side_effect=scripted)

        def obs(value):
            # One-element float32 observation.
            return np.array([value], np.float32)

        def check(rewards, actions, states, update_calls):
            # Same four assertions, in the same order, after every stage.
            self.assertEqual(sut._trajectory_rewards, rewards)
            self.assertEqual(sut._trajectory_actions, actions)
            self.assertEqual(sut._trajectory_states, states)
            self.assertEqual(sut._update_networks.call_count, update_calls)

        # Start plus one step: no update expected yet.
        sut.start(obs(0.1))
        sut.step(0.1, obs(0.2))
        check([0.1], [0, 1], [0.1, 0.2], 0)

        # Second step: one update expected, only the latest state/action
        # remain in the trajectory.
        sut.step(0.2, obs(0.3))
        check([], [1], [0.3], 1)

        # Third step: the trajectory grows again, still one update.
        sut.step(0.3, obs(0.4))
        check([0.3], [1, 0], [0.3, 0.4], 1)

        # A fresh start: trajectory holds only the new state/action.
        sut.start(obs(0.5))
        check([], [1], [0.5], 1)

        # Next step: a second update is expected.
        sut.step(0.4, obs(0.6))
        check([], [0], [0.6], 2)

        # End: the reward is recorded, no further update.
        sut.end(0.5, obs(0.7))
        check([0.5], [0], [0.6], 2)
# Example no. 4
# 0
    def test_rollout_with_update(self, mock_parameters):
        """Drive the agent through a rollout with `update_frequency` of 2.

        `_update_networks` and `_choose_action` are mocked. After each stage
        the test compares `_trajectory_rewards`, `_trajectory_actions`,
        `_trajectory_states` and the `_update_networks` call count with the
        expected values.
        """
        parameters = mock_parameters.return_value
        self._setup_parameters(parameters)
        parameters.update_frequency = 2

        act_space = spaces.Discrete(2)
        obs_space = spaces.Box(0, 1, (1,))
        sut = ActorCritic('', obs_space, act_space)
        sut._update_networks = MagicMock()

        # Actions handed out, in order, by the mocked `_choose_action`.
        action_script = [
            (0, ''), (1, ''), (1, ''),
            (0, ''), (1, ''), (0, ''),
        ]
        sut._choose_action = Mock(side_effect=action_script)

        def state(x):
            # Wrap a scalar as a one-element float32 array.
            return np.array([x], np.float32)

        def expect(rewards, actions, states, updates):
            # Assertion order mirrors: rewards, actions, states, call count.
            self.assertEqual(sut._trajectory_rewards, rewards)
            self.assertEqual(sut._trajectory_actions, actions)
            self.assertEqual(sut._trajectory_states, states)
            self.assertEqual(sut._update_networks.call_count, updates)

        # start + first step: nothing has been updated so far.
        sut.start(state(0.1))
        sut.step(0.1, state(0.2))
        expect(rewards=[0.1], actions=[0, 1], states=[0.1, 0.2], updates=0)

        # Second step: first update expected.
        sut.step(0.2, state(0.3))
        expect(rewards=[], actions=[1], states=[0.3], updates=1)

        # Third step: accumulating again, no new update.
        sut.step(0.3, state(0.4))
        expect(rewards=[0.3], actions=[1, 0], states=[0.3, 0.4], updates=1)

        # Calling start again leaves only the new state and action.
        sut.start(state(0.5))
        expect(rewards=[], actions=[1], states=[0.5], updates=1)

        # Step after the restart: second update expected.
        sut.step(0.4, state(0.6))
        expect(rewards=[], actions=[0], states=[0.6], updates=2)

        # end records the reward; the update count is unchanged.
        sut.end(0.5, state(0.7))
        expect(rewards=[0.5], actions=[0], states=[0.6], updates=2)