  def testStepTrain(self):
    """Tests the functionality of agent.step() in train mode.

    Specifically, checks the action returned and confirms that training is
    happening.
    """
    agent = self._create_test_agent()
    agent.eval_mode = False
    base_observation = onp.ones(self.observation_shape + (1,))
    # We mock the replay buffer to verify how the agent interacts with it.
    agent._replay = test_utils.MockReplayBuffer(is_jax=True)
    # This will reset state and choose a first action.
    agent.begin_episode(base_observation)
    expected_state = self.zero_state
    num_steps = 10
    for step in range(1, num_steps + 1):
      # We make observation a multiple of step for testing purposes (to
      # uniquely identify each observation).
      observation = base_observation * step
      self.assertEqual(agent.step(reward=1, observation=observation), 0)
      stack_pos = step - num_steps - 1
      if stack_pos >= -self.stack_size:
        expected_state[:, :, :, stack_pos] = onp.full(
            (1,) + self.observation_shape, step)
    self.assertTrue(onp.array_equal(agent.state, expected_state))
    self.assertTrue(
        onp.array_equal(
            agent._last_observation,
            onp.full(self.observation_shape, num_steps - 1)))
    self.assertTrue(onp.array_equal(agent._observation, observation[:, :, 0]))
    # We expect one more than num_steps because of the call to begin_episode.
    self.assertEqual(agent.training_steps, num_steps + 1)
    self.assertEqual(agent._replay.add.call_count, num_steps)
    agent.end_episode(reward=1)
    self.assertEqual(agent._replay.add.call_count, num_steps + 1)
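
  # Note on the counts asserted above: begin_episode() performs a training
  # step and so does every call to step(), which is why training_steps ends
  # at num_steps + 1, while _replay.add runs once per step() plus once more
  # in end_episode().
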
  def testStepEval(self):
    """Tests the functionality of agent.step() in eval mode.

    Specifically, the action returned, and confirms that no training happens.
    """
    agent = self._create_test_agent()
    base_observation = onp.ones(self.observation_shape + (1,))
    # This will reset state and choose a first action.
    agent.begin_episode(base_observation)
    # We mock the replay buffer to verify how the agent interacts with it.
    agent._replay = test_utils.MockReplayBuffer()

    expected_state = self.zero_state
    num_steps = 10
    for step in range(1, num_steps + 1):
      # We make observation a multiple of step for testing purposes (to
      # uniquely identify each observation).
      observation = base_observation * step
      self.assertEqual(agent.step(reward=1, observation=observation), 0)
      stack_pos = step - num_steps - 1
      if stack_pos >= -self.stack_size:
        expected_state[:, :, :, stack_pos] = onp.full(
            (1,) + self.observation_shape, step)
    self.assertTrue(onp.array_equal(agent.state, expected_state))
    self.assertTrue(
        onp.array_equal(
            agent._last_observation,
            onp.ones(self.observation_shape) * (num_steps - 1)))
    self.assertTrue(onp.array_equal(agent._observation, observation[:, :, 0]))
    # No training happens in eval mode.
    self.assertEqual(agent.training_steps, 0)
    # No transitions are added in eval mode.
    self.assertEqual(agent._replay.add.call_count, 0)
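
These tests only need the replay stand-in to record calls to add(). A minimal
sketch of what test_utils.MockReplayBuffer could look like under that
assumption (the real helper may expose more of the replay buffer interface):

from unittest import mock

class MockReplayBuffer:
  """Replay stand-in whose add() records calls for later inspection."""

  def __init__(self, is_jax=False):
    # mock.Mock() tracks call_count and call_args, which the tests above
    # assert on directly.
    self.add = mock.Mock()
    self.is_jax = is_jax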
Example #3
    def testStepTrain(self):
        """Tests the functionality of agent.step() in train mode.

        Specifically, checks the action returned and confirms that training
        is happening.
        """
        with tf.Session() as sess:
            agent = self._create_test_agent(sess)
            agent.eval_mode = False
            base_observation = np.ones(
                [self.observation_shape, self.observation_shape, 1])
            # We mock the replay buffer to verify how the agent interacts with it.
            agent._replay = test_utils.MockReplayBuffer()
            self.evaluate(tf.global_variables_initializer())
            # This will reset state and choose a first action.
            agent.begin_episode(base_observation)
            observation = base_observation

            expected_state = self.zero_state
            num_steps = 10
            for step in range(1, num_steps + 1):
                # We make observation a multiple of step for testing purposes (to
                # uniquely identify each observation).
                last_observation = observation
                observation = base_observation * step
                self.assertEqual(agent.step(reward=1, observation=observation),
                                 0)
                stack_pos = step - num_steps - 1
                if stack_pos >= -self.stack_size:
                    expected_state[:, :, :, stack_pos] = np.full(
                        (1, self.observation_shape, self.observation_shape),
                        step)
                self.assertEqual(agent._replay.add.call_count, step)
                mock_args, _ = agent._replay.add.call_args
                self.assertAllEqual(last_observation[:, :, 0], mock_args[0])
                self.assertAllEqual(0, mock_args[1])  # Action selected.
                self.assertAllEqual(1, mock_args[2])  # Reward received.
                self.assertFalse(mock_args[3])  # is_terminal
            self.assertAllEqual(agent.state, expected_state)
            self.assertAllEqual(
                agent._last_observation,
                np.full((self.observation_shape, self.observation_shape),
                        num_steps - 1))
            self.assertAllEqual(agent._observation, observation[:, :, 0])
            # We expect one more than num_steps because of the call to begin_episode.
            self.assertEqual(agent.training_steps, num_steps + 1)
            self.assertEqual(agent._replay.add.call_count, num_steps)

            agent.end_episode(reward=1)
            self.assertEqual(agent._replay.add.call_count, num_steps + 1)
            mock_args, _ = agent._replay.add.call_args
            self.assertAllEqual(observation[:, :, 0], mock_args[0])
            self.assertAllEqual(0, mock_args[1])  # Action selected.
            self.assertAllEqual(1, mock_args[2])  # Reward received.
            self.assertTrue(mock_args[3])  # is_terminal
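
The positional call_args checks above imply an add(observation, action,
reward, is_terminal) ordering. A hedged sketch of a buffer with that
signature (the class and attribute names are illustrative, not taken from
Dopamine's replay code):

class SketchReplayBuffer:
    def __init__(self):
        self.transitions = []

    def add(self, observation, action, reward, is_terminal):
        # Store one transition in the positional order the tests assert on.
        self.transitions.append((observation, action, reward, is_terminal))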
Example #4
    def _custom_shapes_test(self, shape, dtype, stack_size):
        self.observation_shape = shape
        self.observation_dtype = dtype
        self.stack_size = stack_size
        self.zero_state = onp.zeros(shape + (stack_size,))
        agent = self._create_test_agent()
        agent.eval_mode = False
        base_observation = onp.ones(self.observation_shape + (1,))
        # We mock the replay buffer to verify how the agent interacts with it.
        agent._replay = test_utils.MockReplayBuffer(is_jax=True)
        # This will reset state and choose a first action.
        agent.begin_episode(base_observation)
        observation = base_observation

        expected_state = self.zero_state
        num_steps = 10
        for step in range(1, num_steps + 1):
            # We make observation a multiple of step for testing purposes (to
            # uniquely identify each observation).
            last_observation = observation
            observation = base_observation * step
            self.assertEqual(agent.step(reward=1, observation=observation), 0)
            stack_pos = step - num_steps - 1
            if stack_pos >= -self.stack_size:
                expected_state[..., stack_pos] = onp.full(
                    self.observation_shape, step)
            self.assertEqual(agent._replay.add.call_count, step)
            mock_args, _ = agent._replay.add.call_args
            self.assertTrue(
                onp.array_equal(last_observation[..., 0], mock_args[0]))
            self.assertEqual(0, mock_args[1])  # Action selected.
            self.assertEqual(1, mock_args[2])  # Reward received.
            self.assertFalse(mock_args[3])  # is_terminal
        self.assertTrue(onp.array_equal(agent.state, expected_state))
        self.assertTrue(
            onp.array_equal(agent._last_observation,
                            onp.full(self.observation_shape, num_steps - 1)))
        self.assertTrue(
            onp.array_equal(agent._observation, observation[..., 0]))
        # We expect one more than num_steps because of the call to begin_episode.
        self.assertEqual(agent.training_steps, num_steps + 1)
        self.assertEqual(agent._replay.add.call_count, num_steps)

        agent.end_episode(reward=1)
        self.assertEqual(agent._replay.add.call_count, num_steps + 1)
        mock_args, _ = agent._replay.add.call_args
        self.assertTrue(onp.array_equal(observation[..., 0], mock_args[0]))
        self.assertEqual(0, mock_args[1])  # Action selected.
        self.assertEqual(1, mock_args[2])  # Reward received.
        self.assertTrue(mock_args[3])  # is_terminal
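
A helper like _custom_shapes_test is typically driven by small wrapper tests.
A hypothetical example of how it might be invoked (the shapes, dtypes, and
stack sizes are illustrative, not taken from the original suite):

    def testCustomObservationShapes(self):
        # 2-D observations with a shallow stack.
        self._custom_shapes_test((4, 4), onp.uint8, stack_size=2)
        # 3-D observations with a deeper stack.
        self._custom_shapes_test((6, 6, 3), onp.float32, stack_size=4)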
Example #5
    def testStepEval(self):
        """Tests the functionality of agent.step() in eval mode.

        Specifically, checks the action returned and confirms that no
        training happens.
        """
        with tf.Session() as sess:
            agent = self._create_test_agent(sess)
            base_observation = np.ones(
                [self.observation_shape, self.observation_shape, 1])
            # This will reset state and choose a first action.
            agent.begin_episode(base_observation)
            # We mock the replay buffer to verify how the agent interacts with it.
            agent._replay = test_utils.MockReplayBuffer()
            self.evaluate(tf.global_variables_initializer())

            expected_state = self.zero_state
            num_steps = 10
            for step in range(1, num_steps + 1):
                # We make observation a multiple of step for testing purposes (to
                # uniquely identify each observation).
                observation = base_observation * step
                self.assertEqual(agent.step(reward=1, observation=observation),
                                 0)
                stack_pos = step - num_steps - 1
                if stack_pos >= -self.stack_size:
                    expected_state[:, :, :, stack_pos] = np.full(
                        (1, self.observation_shape, self.observation_shape),
                        step)
            self.assertAllEqual(agent.state, expected_state)
            self.assertAllEqual(
                agent._last_observation,
                np.ones([self.observation_shape, self.observation_shape]) *
                (num_steps - 1))
            self.assertAllEqual(agent._observation, observation[:, :, 0])
            # No training happens in eval mode.
            self.assertEqual(agent.training_steps, 0)
            # No transitions are added in eval mode.
            self.assertEqual(agent._replay.add.call_count, 0)
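
All of these tests rely on the same index arithmetic: stack_pos = step -
num_steps - 1 maps only the final stack_size observations into the trailing
(negative) indices of the state. A standalone check of that math, assuming
the usual stack size of 4:

num_steps, stack_size = 10, 4
for step in range(1, num_steps + 1):
    stack_pos = step - num_steps - 1  # ranges over -10 .. -1
    if stack_pos >= -stack_size:
        # Only steps 7..10 survive the filter, landing at positions -4..-1,
        # i.e. the last stack_size frames of the state.
        print(step, stack_pos)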