def testPolicy(self):
  agent = random_agent.RandomAgent(
      self._time_step_spec,
      self._action_spec,
  )
  observations = tf.constant([[1, 2]], dtype=tf.float32)
  time_steps = ts.restart(observations, batch_size=1)
  action_step = agent.policy.action(time_steps)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  actions = self.evaluate(action_step.action)
  # A batch of one observation should yield one action of shape [1].
  self.assertEqual(list(actions.shape), [1, 1])
def testTrain(self):
  # Define the train step counter.
  counter = common.create_variable('test_train_counter')
  agent = random_agent.RandomAgent(
      self._time_step_spec,
      self._action_spec,
      train_step_counter=counter,
      num_outer_dims=2)
  observations = tf.constant([
      [[1, 2], [3, 4], [5, 6]],
      [[1, 2], [3, 4], [5, 6]],
  ], dtype=tf.float32)
  time_steps = ts.TimeStep(
      step_type=tf.constant([[1] * 3] * 2, dtype=tf.int32),
      reward=tf.constant([[1] * 3] * 2, dtype=tf.float32),
      discount=tf.constant([[1] * 3] * 2, dtype=tf.float32),
      observation=observations)
  actions = tf.constant([[[0], [1], [1]], [[0], [1], [1]]], dtype=tf.float32)
  experience = trajectory.Trajectory(time_steps.step_type, observations,
                                     actions, (), time_steps.step_type,
                                     time_steps.reward, time_steps.discount)
  # Assert that counter starts out at zero.
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertEqual(0, self.evaluate(counter))
  agent.train(experience)
  # Now we should have one iteration.
  self.assertEqual(1, self.evaluate(counter))
def testCreateAgent(self):
  agent = random_agent.RandomAgent(
      self._time_step_spec,
      self._action_spec,
  )
  agent.initialize()