def test_basic(self):
    """Smoke test: an unrolled forward pass with frame stacking completes."""
    net = agents.DuelingLSTMDQNNet(2, [OBS_DIM, OBS_DIM, 1], stack_size=4)
    batch_size = 16
    start_state = net.initial_state(batch_size)
    # Only verifies that the call runs and yields a 2-element result; the
    # returned values themselves are not inspected.
    unused_output, unused_state = net(
        self._create_agent_input(batch_size, 80), start_state, unroll=True)
def test_stack_frames_done(self):
    """stack_frames must reset the frame stack at steps where done is set."""
    zero_state = agents.DuelingLSTMDQNNet(2, [1], stack_size=4).initial_state(
        1).frame_stacking_state
    # frames: [time=1, batch_size=1, channels=1].
    # done: [time=1, batch_size=1].
    output, state = stack_frames(
        frames=[[[1]]], done=[[False]], frame_stacking_state=zero_state,
        stack_size=4)
    # [time=1, batch_size=1, frame_stack=4]
    # The newest frame comes first, followed by 3 zero frames.
    self.assertAllEqual(output, [[[1, 0, 0, 0]]])

    # Episode is done, stacking should be reset.
    output, state = stack_frames(
        frames=[[[2]]], done=[[True]], frame_stacking_state=state,
        stack_size=4)
    self.assertAllEqual(output, [[[2, 0, 0, 0]]])

    # A longer unroll with done in the middle: the stack keeps growing from
    # the state carried over from the previous call until step 4, where
    # done=True resets it so that frame 7 starts a fresh stack.
    # frames: [time=6, batch_size=1, channels=1].
    output, _ = stack_frames(
        frames=[[[3]], [[4]], [[5]], [[6]], [[7]], [[8]]],
        done=[[False], [False], [False], [False], [True], [False]],
        frame_stacking_state=state,
        stack_size=4)
    self.assertEqual(output.shape[0], 6)
    # Check every time step (not just the endpoints) so that both the
    # carried-over state and the mid-unroll reset are verified directly.
    self.assertAllEqual(output, [
        [[3, 2, 0, 0]],
        [[4, 3, 2, 0]],
        [[5, 4, 3, 2]],
        [[6, 5, 4, 3]],
        [[7, 0, 0, 0]],  # done=True at this step: stack is reset.
        [[8, 7, 0, 0]],
    ])
def test_stack_frames(self):
    """stack_frames accumulates frames across calls, newest frame first."""
    zero_state = agents.DuelingLSTMDQNNet(2, [1], stack_size=4).initial_state(
        1).frame_stacking_state
    # frames: [time=1, batch_size=1, channels=1].
    # done: [time=1, batch_size=1].
    output, state = stack_frames(
        frames=[[[1]]], done=[[False]], frame_stacking_state=zero_state,
        stack_size=4)
    # [time=1, batch_size=1, frame_stack=4]
    # The newest frame comes first, followed by 3 zero frames.
    self.assertAllEqual(output, [[[1, 0, 0, 0]]])

    output, state = stack_frames(
        frames=[[[2]]], done=[[False]], frame_stacking_state=state,
        stack_size=4)
    # The last 2 inputs (newest first), followed by 2 zero frames.
    self.assertAllEqual(output, [[[2, 1, 0, 0]]])

    # A longer unroll should continue from the state of the previous call.
    # frames: [time=6, batch_size=1, channels=1].
    output, _ = stack_frames(
        frames=[[[3]], [[4]], [[5]], [[6]], [[7]], [[8]]],
        done=[[False]] * 6,
        frame_stacking_state=state,
        stack_size=4)
    self.assertEqual(output.shape[0], 6)
    # Check every time step: step 0 still has 1 blank frame and 3 real
    # frames; from step 1 on the stack is full and slides by one frame.
    self.assertAllEqual(output, [
        [[3, 2, 1, 0]],
        [[4, 3, 2, 1]],
        [[5, 4, 3, 2]],
        [[6, 5, 4, 3]],
        [[7, 6, 5, 4]],
        [[8, 7, 6, 5]],
    ])
def test_compute_loss_basic(self):
    """Basic test to exercise learner.compute_loss_and_priorities()."""
    batch_size = 32
    num_actions = 3
    unroll_length = 10
    training_agent = agents.DuelingLSTMDQNNet(num_actions, OBS_SHAPE)
    # For integer dtypes tf.random.uniform samples from [minval, maxval), so
    # maxval=num_actions covers every valid action (previously maxval=2 meant
    # action num_actions - 1 was never exercised).
    prev_actions = tf.random.uniform(
        [unroll_length, batch_size], maxval=num_actions, dtype=tf.int32)
    tf.function(learner.compute_loss_and_priorities)(
        training_agent,
        # A second, independently constructed agent; presumably the target
        # network -- confirm against learner.compute_loss_and_priorities.
        agents.DuelingLSTMDQNNet(num_actions, OBS_SHAPE),
        training_agent.initial_state(batch_size),
        prev_actions,
        self._create_env_output(batch_size, unroll_length),
        self._create_agent_outputs(batch_size, unroll_length, num_actions),
        0.99,  # NOTE(review): presumably the discount factor -- confirm.
        burn_in=5)
def test_basic_frame_stacking(self):
    """The torso must receive observations whose last axis holds 4 frames."""
    stack_size = 4
    net = agents.DuelingLSTMDQNNet(
        2, [OBS_DIM, OBS_DIM, 1], stack_size=stack_size)
    batch_size = 16
    start_state = net.initial_state(batch_size)
    # Spy on _torso while still running the real implementation.
    with mock.patch.object(net, '_torso', wraps=net._torso) as torso_spy:
        unused_output, unused_state = net(
            self._create_agent_input(batch_size, 80), start_state, unroll=True)
        # Positional argument 1 of the most recent _torso call; it exposes an
        # `.observation` attribute.
        torso_env_output = torso_spy.call_args[0][1]
        self.assertEqual(torso_env_output.observation.shape[-1], stack_size)
def test_core_input_shape(self):
    """The LSTM core input width is torso features + one-hot action + reward.

    Previously named `check_core_input_shape`; without the `test` prefix the
    test loader never collected it, so this check silently never ran.
    """
    num_actions = 37
    agent = agents.DuelingLSTMDQNNet(num_actions, [OBS_DIM, OBS_DIM, 1],
                                     stack_size=4)
    batch_size = 16
    initial_agent_state = agent.initial_state(batch_size)
    # Spy on _core while still running the real implementation.
    with mock.patch.object(agent, '_core', wraps=agent._core) as core_spy:
        _, _ = agent(self._create_agent_input(batch_size, 80),
                     initial_agent_state, unroll=True)
        # conv_output_dim + num_actions + reward.
        # NOTE(review): 512 is assumed to be the torso's output width.
        self.assertEqual(core_spy.call_args[0][0].shape[-1],
                         512 + num_actions + 1)


# Backward-compatible alias for the old name; it does not start with `test`,
# so the loader does not collect the check twice.
check_core_input_shape = test_core_input_shape
def create_agent(env_output_specs, num_actions):
    """Construct a DuelingLSTMDQNNet sized for the given environment.

    Args:
      env_output_specs: Environment output specs; only
        `env_output_specs.observation.shape` is read.
      num_actions: Number of discrete actions, passed through to the network.

    Returns:
      A new `agents.DuelingLSTMDQNNet` built with `FLAGS.stack_size`.
    """
    observation_shape = env_output_specs.observation.shape
    return agents.DuelingLSTMDQNNet(
        num_actions, observation_shape, FLAGS.stack_size)