Example #1
 def test_basic(self):
   agent = agents.DuelingLSTMDQNNet(2, [OBS_DIM, OBS_DIM, 1], stack_size=4)
   batch_size = 16
   initial_agent_state = agent.initial_state(batch_size)
   _, _ = agent(self._create_agent_input(batch_size, 80),
                initial_agent_state,
                unroll=True)
Example #2
  def test_stack_frames_done(self):
    zero_state = agents.DuelingLSTMDQNNet(2, [1], stack_size=4).initial_state(
        1).frame_stacking_state
    # frames: [time=1, batch_size=1, channels=1].
    # done: [time=1, batch_size=1].
    output, state = stack_frames(
        frames=[[[1]]], done=[[False]], frame_stacking_state=zero_state,
        stack_size=4)

    # [time=1, batch_size=1, frame_stack=4]
    # 3 zero frames, with the most recent entry coming from the current input.
    self.assertAllEqual(output, [[[1, 0, 0, 0]]])

    # Episode is done, stacking should be reset.
    output, state = stack_frames(
        frames=[[[2]]], done=[[True]], frame_stacking_state=state,
        stack_size=4)
    self.assertAllEqual(output, [[[2, 0, 0, 0]]])

    # A longer unroll with done in the middle should be handled correctly.
    # frames: [time=6, batch_size=1, channels=1].
    output, state = stack_frames(
        frames=[[[3]], [[4]], [[5]], [[6]], [[7]], [[8]]],
        done=[[False], [False], [False], [False], [True], [False]],
        frame_stacking_state=state, stack_size=4)

    self.assertEqual(output.shape[0], 6)
    self.assertAllEqual(output[0], [[3, 2, 0, 0]])
    self.assertAllEqual(output[5], [[8, 7, 0, 0]])
Example #3
  def test_stack_frames(self):
    zero_state = agents.DuelingLSTMDQNNet(2, [1], stack_size=4).initial_state(
        1).frame_stacking_state
    # frames: [time=1, batch_size=1, channels=1].
    # done: [time=1, batch_size=1].
    output, state = stack_frames(
        frames=[[[1]]], done=[[False]], frame_stacking_state=zero_state,
        stack_size=4)

    # [time=1, batch_size=1, frame_stack=4]
    # 3 zero frames, with the most recent entry coming from the current input.
    self.assertAllEqual(output, [[[1, 0, 0, 0]]])

    output, state = stack_frames(
        frames=[[[2]]], done=[[False]], frame_stacking_state=state,
        stack_size=4)
    # 2 zero frames and the 2 most recent entries coming from the last two inputs.
    self.assertAllEqual(output, [[[2, 1, 0, 0]]])

    # A longer unroll should be handled correctly.
    # frames: [time=6, batch_size=1, channels=1].
    output, state = stack_frames(
        frames=[[[3]], [[4]], [[5]], [[6]], [[7]], [[8]]], done=[[False]] * 6,
        frame_stacking_state=state, stack_size=4)

    self.assertEqual(output.shape[0], 6)
    # The first element of the output should be a stack with 1 blank frame and
    # 3 real frames.
    self.assertAllEqual(output[0], [[3, 2, 1, 0]])
    # The last element of the output should contain the last 4 frames from the
    # last inputs.
    self.assertAllEqual(output[5], [[8, 7, 6, 5]])
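
For reference, the two stacking tests above pin down the semantics of stack_frames: the newest frame occupies the first slot of the stacked output, the remaining slots hold the previous frames, and a done flag resets the history so frames never leak across episode boundaries. Below is a minimal NumPy sketch of that behavior; stack_frames_ref and its flat [batch, channels * (stack_size - 1)] state layout are assumptions made for illustration, not the library's actual stack_frames implementation.

import numpy as np

def stack_frames_ref(frames, done, frame_stacking_state, stack_size):
  """Illustrative frame stacking: newest frame first, history reset on done."""
  frames = np.asarray(frames, dtype=np.float32)   # [time, batch, channels]
  done = np.asarray(done)                         # [time, batch]
  state = np.asarray(frame_stacking_state, dtype=np.float32)
  channels = frames.shape[-1]
  outputs = []
  for t in range(frames.shape[0]):
    # Zero the stacked history for batch entries whose episode just ended.
    state = np.where(done[t][:, None], 0.0, state)
    # Stack: current frame followed by the (stack_size - 1) previous frames.
    stacked = np.concatenate([frames[t], state], axis=-1)
    outputs.append(stacked)
    # The next state keeps the newest (stack_size - 1) frames of the stack.
    state = stacked[:, :channels * (stack_size - 1)]
  return np.stack(outputs), state

# Reproduces the assertions above, e.g.:
zero_state = np.zeros([1, 3], dtype=np.float32)  # batch=1, channels=1, stack_size=4
out, state = stack_frames_ref([[[1]]], [[False]], zero_state, stack_size=4)
assert out.tolist() == [[[1, 0, 0, 0]]]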
Example #4
 def test_compute_loss_basic(self):
   """Basic test to exercise learner.compute_loss_and_priorities()."""
   batch_size = 32
   num_actions = 3
   unroll_length = 10
   training_agent = agents.DuelingLSTMDQNNet(num_actions, OBS_SHAPE)
   prev_actions = tf.random.uniform(
       [unroll_length, batch_size], maxval=2, dtype=tf.int32)
   tf.function(learner.compute_loss_and_priorities)(
       training_agent,
       agents.DuelingLSTMDQNNet(num_actions, OBS_SHAPE),
       training_agent.initial_state(batch_size),
       prev_actions,
       self._create_env_output(batch_size, unroll_length),
       self._create_agent_outputs(batch_size, unroll_length, num_actions),
       0.99,
       burn_in=5)
Example #5
 def test_basic_frame_stacking(self):
   agent = agents.DuelingLSTMDQNNet(2, [OBS_DIM, OBS_DIM, 1], stack_size=4)
   batch_size = 16
   initial_agent_state = agent.initial_state(batch_size)
   with mock.patch.object(agent, '_torso', wraps=agent._torso):
     _, _ = agent(self._create_agent_input(batch_size, 80),
                  initial_agent_state,
                  unroll=True)
     self.assertEqual(agent._torso.call_args[0][1].observation.shape[-1], 4)
Example #6
 def test_core_input_shape(self):
   num_actions = 37
   agent = agents.DuelingLSTMDQNNet(num_actions, [OBS_DIM, OBS_DIM, 1],
                                    stack_size=4)
   batch_size = 16
   initial_agent_state = agent.initial_state(batch_size)
   with mock.patch.object(agent, '_core', wraps=agent._core):
     _, _ = agent(self._create_agent_input(batch_size, 80),
                  initial_agent_state,
                  unroll=True)
     # conv_output_dim + num_actions + reward.
     self.assertEqual(agent._core.call_args[0][0].shape[-1],
                      512 + num_actions + 1)
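
The assertion above encodes the expected layout of the recurrent core's input: the 512-dimensional torso (conv) embedding concatenated with a one-hot encoding of the previous action and the scalar reward. The sketch below shows one way such an input could be assembled; make_core_input is a hypothetical helper written for illustration, not part of the agent's API.

import tensorflow as tf

def make_core_input(conv_output, prev_action, reward, num_actions):
  """Concatenates the torso output, one-hot previous action and reward."""
  one_hot_action = tf.one_hot(prev_action, num_actions)          # [batch, num_actions]
  reward = tf.expand_dims(tf.cast(reward, tf.float32), axis=-1)  # [batch, 1]
  # Last dimension of the result: conv_output_dim + num_actions + 1.
  return tf.concat([conv_output, one_hot_action, reward], axis=-1)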
Example #7
def create_agent(env_output_specs, num_actions):
    return agents.DuelingLSTMDQNNet(num_actions,
                                    env_output_specs.observation.shape,
                                    FLAGS.stack_size)