コード例 #1
0
 def testPolicy(self):
     """Check the shape of the action produced by the agent's policy.

     Builds a RandomAgent from the time-step and action specs stored on the
     test instance, feeds its policy a restart time step holding a single
     two-element observation, and asserts that the evaluated action has
     shape [1, 1].
     """
     random_policy_agent = random_agent.RandomAgent(
         self._time_step_spec, self._action_spec)

     # A batch containing exactly one observation.
     single_observation = tf.constant([[1, 2]], dtype=tf.float32)
     first_step = ts.restart(single_observation, batch_size=1)
     policy_step = random_policy_agent.policy.action(first_step)

     # Run the global-variables initializer before reading the action value.
     self.evaluate(tf.compat.v1.global_variables_initializer())
     sampled_action = self.evaluate(policy_step.action)

     self.assertEqual(list(sampled_action.shape), [1, 1])
コード例 #2
0
    def testTrain(self):
        """Check that the train step counter goes from 0 to 1 after train().

        A RandomAgent is created with an explicit train step counter and
        num_outer_dims=2, then handed a hand-built trajectory made of two
        identical rows of three steps. The counter is asserted to be zero
        before the train() call and one afterwards.
        """
        # Counter handed to the agent so the test can observe it directly.
        step_counter = common.create_variable('test_train_counter')
        trainer = random_agent.RandomAgent(
            self._time_step_spec,
            self._action_spec,
            train_step_counter=step_counter,
            num_outer_dims=2,
        )

        # Two identical rows, each with three 2-element observations.
        observation_row = [[1, 2], [3, 4], [5, 6]]
        observations = tf.constant(
            [observation_row, observation_row], dtype=tf.float32)

        # Step type, reward and discount are all filled with ones (2 x 3).
        ones = [[1] * 3] * 2
        steps = ts.TimeStep(
            step_type=tf.constant(ones, dtype=tf.int32),
            reward=tf.constant(ones, dtype=tf.float32),
            discount=tf.constant(ones, dtype=tf.float32),
            observation=observations,
        )

        action_row = [[0], [1], [1]]
        actions = tf.constant([action_row, action_row], dtype=tf.float32)

        # The same step-type tensor is passed in two positions; the fourth
        # positional argument is an empty tuple.
        experience = trajectory.Trajectory(
            steps.step_type,
            observations,
            actions,
            (),
            steps.step_type,
            steps.reward,
            steps.discount,
        )

        # Before training the counter must read zero.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertEqual(0, self.evaluate(step_counter))

        trainer.train(experience)

        # A single train() call is expected to leave the counter at one.
        self.assertEqual(1, self.evaluate(step_counter))
コード例 #3
0
 def testCreateAgent(self):
     """Smoke test: build a RandomAgent and call initialize() on it.

     The test makes no assertions; it passes as long as constructing the
     agent from the instance's time-step and action specs and calling
     initialize() do not raise.
     """
     spec_args = (self._time_step_spec, self._action_spec)
     random_agent.RandomAgent(*spec_args).initialize()