def testStepTrain(self):
  """Test the functionality of agent.step() in train mode.

  Specifically, the action returned, and confirms training is happening.
  """
  agent = self._create_test_agent()
  agent.eval_mode = False
  base_observation = onp.ones(self.observation_shape + (1,))
  # We mock the replay buffer to verify how the agent interacts with it.
  agent._replay = test_utils.MockReplayBuffer(is_jax=True)
  # This will reset state and choose a first action.
  agent.begin_episode(base_observation)
  expected_state = self.zero_state
  num_steps = 10
  for step in range(1, num_steps + 1):
    # We make observation a multiple of step for testing purposes (to
    # uniquely identify each observation).
    observation = base_observation * step
    self.assertEqual(agent.step(reward=1, observation=observation), 0)
    stack_pos = step - num_steps - 1
    if stack_pos >= -self.stack_size:
      expected_state[:, :, :, stack_pos] = onp.full(
          (1,) + self.observation_shape, step)
  # Bug fix: onp.array_equal returns a bool — previously its result was
  # discarded, so these checks never asserted anything. Wrap in assertTrue
  # (same pattern as _custom_shapes_test).
  self.assertTrue(onp.array_equal(agent.state, expected_state))
  self.assertTrue(
      onp.array_equal(agent._last_observation,
                      onp.full(self.observation_shape, num_steps - 1)))
  self.assertTrue(onp.array_equal(agent._observation, observation[:, :, 0]))
  # We expect one more than num_steps because of the call to begin_episode.
  self.assertEqual(agent.training_steps, num_steps + 1)
  self.assertEqual(agent._replay.add.call_count, num_steps)
  agent.end_episode(reward=1)
  self.assertEqual(agent._replay.add.call_count, num_steps + 1)
def testStepEval(self):
  """Tests the functionality of agent.step() in eval mode.

  Specifically, the action returned, and confirms that no training happens.
  """
  agent = self._create_test_agent()
  base_observation = onp.ones(self.observation_shape + (1,))
  # This will reset state and choose a first action.
  agent.begin_episode(base_observation)
  # We mock the replay buffer to verify how the agent interacts with it.
  agent._replay = test_utils.MockReplayBuffer()
  expected_state = self.zero_state
  num_steps = 10
  for step in range(1, num_steps + 1):
    # We make observation a multiple of step for testing purposes (to
    # uniquely identify each observation).
    observation = base_observation * step
    self.assertEqual(agent.step(reward=1, observation=observation), 0)
    stack_pos = step - num_steps - 1
    if stack_pos >= -self.stack_size:
      expected_state[:, :, :, stack_pos] = onp.full(
          (1,) + self.observation_shape, step)
  # Bug fix: onp.array_equal returns a bool — previously its result was
  # discarded, so these checks never asserted anything. Wrap in assertTrue
  # (same pattern as _custom_shapes_test).
  self.assertTrue(onp.array_equal(agent.state, expected_state))
  self.assertTrue(
      onp.array_equal(agent._last_observation,
                      onp.ones(self.observation_shape) * (num_steps - 1)))
  self.assertTrue(onp.array_equal(agent._observation, observation[:, :, 0]))
  # No training happens in eval mode.
  self.assertEqual(agent.training_steps, 0)
  # No transitions are added in eval mode.
  self.assertEqual(agent._replay.add.call_count, 0)
def testStepTrain(self):
  """Verify agent.step() behavior when training is enabled.

  Checks the action returned, the transitions written to the replay
  buffer, and that training steps are counted.
  """
  with tf.Session() as sess:
    agent = self._create_test_agent(sess)
    agent.eval_mode = False
    base_observation = np.ones(
        [self.observation_shape, self.observation_shape, 1])
    # Swap in a mock replay buffer so every add() call can be inspected.
    agent._replay = test_utils.MockReplayBuffer()
    self.evaluate(tf.global_variables_initializer())
    # begin_episode resets the frame stack and selects the first action.
    agent.begin_episode(base_observation)
    observation = base_observation
    expected_state = self.zero_state
    num_steps = 10
    for i in range(1, num_steps + 1):
      prev_observation = observation
      # Fill each observation with its step index so that every frame in
      # the stack is uniquely identifiable.
      observation = base_observation * i
      self.assertEqual(agent.step(reward=1, observation=observation), 0)
      frame_slot = i - num_steps - 1
      if frame_slot >= -self.stack_size:
        expected_state[:, :, :, frame_slot] = np.full(
            (1, self.observation_shape, self.observation_shape), i)
      # One transition per step: (previous frame, action, reward,
      # is_terminal).
      self.assertEqual(agent._replay.add.call_count, i)
      add_args, _ = agent._replay.add.call_args
      self.assertAllEqual(prev_observation[:, :, 0], add_args[0])
      self.assertAllEqual(0, add_args[1])  # Action selected.
      self.assertAllEqual(1, add_args[2])  # Reward received.
      self.assertFalse(add_args[3])  # is_terminal
    self.assertAllEqual(agent.state, expected_state)
    self.assertAllEqual(
        agent._last_observation,
        np.full((self.observation_shape, self.observation_shape),
                num_steps - 1))
    self.assertAllEqual(agent._observation, observation[:, :, 0])
    # begin_episode also trains once, hence num_steps + 1.
    self.assertEqual(agent.training_steps, num_steps + 1)
    self.assertEqual(agent._replay.add.call_count, num_steps)
    # Ending the episode stores one final transition, marked terminal.
    agent.end_episode(reward=1)
    self.assertEqual(agent._replay.add.call_count, num_steps + 1)
    add_args, _ = agent._replay.add.call_args
    self.assertAllEqual(observation[:, :, 0], add_args[0])
    self.assertAllEqual(0, add_args[1])  # Action selected.
    self.assertAllEqual(1, add_args[2])  # Reward received.
    self.assertTrue(add_args[3])  # is_terminal
def _custom_shapes_test(self, shape, dtype, stack_size):
  """Run the train-mode step test with a custom shape/dtype/stack size.

  Overrides the fixture's observation settings, then drives the agent
  through several steps while checking frame stacking and the transitions
  recorded in the replay buffer.
  """
  self.observation_shape = shape
  self.observation_dtype = dtype
  self.stack_size = stack_size
  self.zero_state = onp.zeros(shape + (stack_size,))
  agent = self._create_test_agent()
  agent.eval_mode = False
  base_observation = onp.ones(self.observation_shape + (1,))
  # A mock replay buffer lets us inspect every add() call.
  agent._replay = test_utils.MockReplayBuffer(is_jax=True)
  # begin_episode resets the frame stack and selects the first action.
  agent.begin_episode(base_observation)
  observation = base_observation
  expected_state = self.zero_state
  num_steps = 10
  for i in range(1, num_steps + 1):
    prev_observation = observation
    # Fill each observation with its step index so every frame in the
    # stack is uniquely identifiable.
    observation = base_observation * i
    self.assertEqual(agent.step(reward=1, observation=observation), 0)
    slot = i - num_steps - 1
    if slot >= -self.stack_size:
      expected_state[..., slot] = onp.full(self.observation_shape, i)
    # One transition per step: (previous frame, action, reward,
    # is_terminal).
    self.assertEqual(agent._replay.add.call_count, i)
    add_args, _ = agent._replay.add.call_args
    self.assertTrue(onp.array_equal(prev_observation[..., 0], add_args[0]))
    self.assertEqual(0, add_args[1])  # Action selected.
    self.assertEqual(1, add_args[2])  # Reward received.
    self.assertFalse(add_args[3])  # is_terminal
  self.assertTrue(onp.array_equal(agent.state, expected_state))
  self.assertTrue(
      onp.array_equal(agent._last_observation,
                      onp.full(self.observation_shape, num_steps - 1)))
  self.assertTrue(onp.array_equal(agent._observation, observation[..., 0]))
  # begin_episode also trains once, hence num_steps + 1.
  self.assertEqual(agent.training_steps, num_steps + 1)
  self.assertEqual(agent._replay.add.call_count, num_steps)
  # Ending the episode stores one final transition, marked terminal.
  agent.end_episode(reward=1)
  self.assertEqual(agent._replay.add.call_count, num_steps + 1)
  add_args, _ = agent._replay.add.call_args
  self.assertTrue(onp.array_equal(observation[..., 0], add_args[0]))
  self.assertEqual(0, add_args[1])  # Action selected.
  self.assertEqual(1, add_args[2])  # Reward received.
  self.assertTrue(add_args[3])  # is_terminal
def testStepEval(self):
  """Verify agent.step() behavior in eval mode.

  Checks the action returned and confirms that no training occurs and no
  transitions are written to the replay buffer.
  """
  with tf.Session() as sess:
    agent = self._create_test_agent(sess)
    base_observation = np.ones(
        [self.observation_shape, self.observation_shape, 1])
    # begin_episode resets the frame stack and selects the first action.
    agent.begin_episode(base_observation)
    # Swap in a mock replay buffer so add() calls can be counted.
    agent._replay = test_utils.MockReplayBuffer()
    self.evaluate(tf.global_variables_initializer())
    expected_state = self.zero_state
    num_steps = 10
    for i in range(1, num_steps + 1):
      # Fill each observation with its step index so every frame in the
      # stack is uniquely identifiable.
      observation = base_observation * i
      self.assertEqual(agent.step(reward=1, observation=observation), 0)
      slot = i - num_steps - 1
      if slot >= -self.stack_size:
        expected_state[:, :, :, slot] = np.full(
            (1, self.observation_shape, self.observation_shape), i)
    self.assertAllEqual(agent.state, expected_state)
    self.assertAllEqual(
        agent._last_observation,
        np.ones([self.observation_shape, self.observation_shape]) *
        (num_steps - 1))
    self.assertAllEqual(agent._observation, observation[:, :, 0])
    # No training happens in eval mode.
    self.assertEqual(agent.training_steps, 0)
    # No transitions are added in eval mode.
    self.assertEqual(agent._replay.add.call_count, 0)