Beispiel #1
0
    def test_stack_frames_done(self):
        """Checks that a `done` flag resets the frame stack at that step.

        When `done` is True for a time step, the stack emitted for that step
        holds only the current frame followed by zero padding. Frames from
        before the reset must not appear in that step or in later steps.
        """
        zero_state = networks.DuelingLSTMDQNNet(
            2, [1], stack_size=4).initial_state(1).frame_stacking_state
        # frames: [time=1, batch_size=1, channels=1].
        # done: [time=1, batch_size=1].
        output, state = stack_frames(frames=[[[1]]],
                                     done=[[False]],
                                     frame_stacking_state=zero_state,
                                     stack_size=4)

        # [time=1, batch_size=1, frame_stack=4]
        # 3 zero frames and the last one coming from the latest input.
        self.assertAllEqual(output, [[[1, 0, 0, 0]]])

        # Episode is done, so stacking should be reset: frame 1 is dropped.
        output, state = stack_frames(frames=[[[2]]],
                                     done=[[True]],
                                     frame_stacking_state=state,
                                     stack_size=4)
        self.assertAllEqual(output, [[[2, 0, 0, 0]]])

        # A longer unroll with done in the middle should be handled correctly.
        # frames: [time=6, batch_size=1, channels=1].
        output, state = stack_frames(frames=[[[3]], [[4]], [[5]], [[6]], [[7]],
                                             [[8]]],
                                     done=[[False], [False], [False], [False],
                                           [True], [False]],
                                     frame_stacking_state=state,
                                     stack_size=4)

        self.assertEqual(output.shape[0], 6)
        # Check every time step, not only the ends. The stack fills up over
        # steps 0-3, is reset at step 4 (where done=True) and starts refilling
        # at step 5. Layout: [time=6, batch_size=1, frame_stack=4].
        self.assertAllEqual(output, [[[3, 2, 0, 0]],
                                     [[4, 3, 2, 0]],
                                     [[5, 4, 3, 2]],
                                     [[6, 5, 4, 3]],
                                     [[7, 0, 0, 0]],
                                     [[8, 7, 0, 0]]])
Beispiel #2
0
 def test_compute_loss_basic(self):
   """Basic test to exercise learner.compute_loss_and_priorities().

   Runs the loss computation under `tf.function` on:
   - a freshly built training network and a second, separate network;
   - random previous actions;
   - the outputs of the `_create_env_output` / `_create_agent_outputs`
     helpers.
   It checks that the computation traces and executes without error.
   """
   batch_size = 32
   num_actions = 3
   unroll_length = 10
   training_agent = networks.DuelingLSTMDQNNet(num_actions, OBS_SHAPE)
   # `maxval` is exclusive for integer dtypes, so sample from
   # [0, num_actions) to cover every valid action id. A hard-coded 2 would
   # never produce the last action.
   prev_actions = tf.random.uniform(
       [unroll_length, batch_size], maxval=num_actions, dtype=tf.int32)
   tf.function(learner.compute_loss_and_priorities)(
       training_agent,
       # A second, independently constructed network -- presumably the
       # target network; confirm against the learner signature.
       networks.DuelingLSTMDQNNet(num_actions, OBS_SHAPE),
       training_agent.initial_state(batch_size),
       prev_actions,
       self._create_env_output(batch_size, unroll_length),
       self._create_agent_outputs(batch_size, unroll_length, num_actions),
       0.99,  # NOTE(review): presumably the discount factor -- confirm.
       burn_in=5)
Beispiel #3
0
 def test_basic(self):
     """Smoke test: an unrolled forward pass with a 4-frame stack runs.

     The call result is unpacked into exactly two values (output, state).
     """
     obs_shape = [OBS_DIM, OBS_DIM, 1]
     net = networks.DuelingLSTMDQNNet(2, obs_shape, stack_size=4)
     batch_size = 16
     state = net.initial_state(batch_size)
     agent_input = self._create_agent_input(batch_size, 80)
     _output, _next_state = net(agent_input, state, unroll=True)
Beispiel #4
0
 def test_basic_frame_stacking(self):
     """The torso should receive observations stacked along the last axis.

     With `stack_size=4` and single-channel observations, the observation
     handed to `_torso` must have 4 channels.
     """
     net = networks.DuelingLSTMDQNNet(2, [OBS_DIM, OBS_DIM, 1], stack_size=4)
     batch_size = 16
     state = net.initial_state(batch_size)
     # Wrap (rather than replace) the torso so that the real computation
     # still runs while its call arguments are recorded.
     with mock.patch.object(net, '_torso', wraps=net._torso) as torso_spy:
         _output, _next_state = net(self._create_agent_input(batch_size, 80),
                                    state,
                                    unroll=True)
         # Second positional argument of the most recent `_torso` call.
         torso_input = torso_spy.call_args[0][1]
         self.assertEqual(torso_input.observation.shape[-1], 4)
Beispiel #5
0
 def test_core_input_shape(self):
     """The core's input width should be conv output + num_actions + reward.

     Named with the `test` prefix so that the unittest loader collects and
     runs it. Under the name `check_core_input_shape` it was never executed.
     """
     num_actions = 37
     agent = networks.DuelingLSTMDQNNet(num_actions, [OBS_DIM, OBS_DIM, 1],
                                        stack_size=4)
     batch_size = 16
     initial_agent_state = agent.initial_state(batch_size)
     # Wrap (rather than replace) the core so that it still runs while its
     # call arguments are recorded.
     with mock.patch.object(agent, '_core', wraps=agent._core):
         _, _ = agent(self._create_agent_input(batch_size, 80),
                      initial_agent_state,
                      unroll=True)
         # conv_output_dim + num_actions + reward.
         self.assertEqual(agent._core.call_args[0][0].shape[-1],
                          512 + num_actions + 1)

 # Backward-compatible alias for the old method name. The loader does not
 # collect it a second time because it lacks the `test` prefix.
 check_core_input_shape = test_core_input_shape
Beispiel #6
0
    def test_stack_frames(self):
        """Without episode ends, stacks hold the newest frames first.

        Missing history is padded with zero frames. The stacking state
        returned by one call carries over into the next call.
        """
        def stack(frames, done, state):
            # Every call in this test uses a stack of 4 frames.
            return stack_frames(frames=frames,
                                done=done,
                                frame_stacking_state=state,
                                stack_size=4)

        network = networks.DuelingLSTMDQNNet(2, [1], stack_size=4)
        zero_state = network.initial_state(1).frame_stacking_state

        # Single step -- frames: [time=1, batch_size=1, channels=1],
        # done: [time=1, batch_size=1].
        output, state = stack([[[1]]], [[False]], zero_state)
        # Output is [time=1, batch_size=1, frame_stack=4]: the new frame
        # followed by 3 zero frames.
        self.assertAllEqual(output, [[[1, 0, 0, 0]]])

        output, state = stack([[[2]]], [[False]], state)
        # The two most recent frames, followed by 2 zero frames.
        self.assertAllEqual(output, [[[2, 1, 0, 0]]])

        # Longer unroll -- frames: [time=6, batch_size=1, channels=1].
        unroll_frames = [[[value]] for value in range(3, 9)]
        output, state = stack(unroll_frames, [[False]] * 6, state)

        self.assertEqual(output.shape[0], 6)
        # First step: 3 real frames and 1 blank frame.
        self.assertAllEqual(output[0], [[3, 2, 1, 0]])
        # Last step: the 4 most recent frames of the unroll.
        self.assertAllEqual(output[5], [[8, 7, 6, 5]])
Beispiel #7
0
def create_agent(env_output_specs, num_actions):
    """Builds a `DuelingLSTMDQNNet` for the given environment.

    Args:
      env_output_specs: Environment output specs. Only
        `env_output_specs.observation.shape` is read.
      num_actions: Number of actions, forwarded to the network constructor.

    Returns:
      A new `networks.DuelingLSTMDQNNet` whose frame-stacking depth is taken
      from `FLAGS.stack_size`.
    """
    # Pass stack_size by keyword so the call does not silently depend on the
    # constructor's positional argument order.
    return networks.DuelingLSTMDQNNet(num_actions,
                                      env_output_specs.observation.shape,
                                      stack_size=FLAGS.stack_size)