コード例 #1
0
 def _build_saver(
     self, policy: tf_policy.TFPolicy
 ) -> Union[policy_saver.PolicySaver, async_policy_saver.AsyncPolicySaver]:
   """Creates the saver used to export `policy`.

   A `policy_saver.PolicySaver` is always built with this object's
   `_train_step` and `_metadata`. When `self._async_saving` is truthy, that
   saver is wrapped in an `async_policy_saver.AsyncPolicySaver`.

   Args:
     policy: The TF policy to export.

   Returns:
     The synchronous saver, or the asynchronous wrapper around it.
   """
   base_saver = policy_saver.PolicySaver(
       policy, train_step=self._train_step, metadata=self._metadata)
   if not self._async_saving:
     return base_saver
   return async_policy_saver.AsyncPolicySaver(base_saver)
コード例 #2
0
    def testRegisterFunction(self):
        """Checks that a function registered on an AsyncPolicySaver survives export.

        A QNetwork is registered under the name 'q_network', and the policy is
        saved through the async saver. The reloaded SavedModel's `q_network`
        must then match the live network's output on sampled observations.
        """
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

        # One bounded spec per TimeStep field, built up front for readability.
        step_type_spec = tensor_spec.BoundedTensorSpec(
            dtype=tf.int32, shape=(), name='st', minimum=0, maximum=2)
        reward_spec = tensor_spec.BoundedTensorSpec(
            dtype=tf.float32, shape=(), name='reward', minimum=0.0,
            maximum=5.0)
        discount_spec = tensor_spec.BoundedTensorSpec(
            dtype=tf.float32, shape=(), name='discount', minimum=0.0,
            maximum=1.0)
        observation_spec = tensor_spec.BoundedTensorSpec(
            dtype=tf.float32, shape=(4,), name='obs', minimum=-10.0,
            maximum=10.0)
        time_step_spec = ts.TimeStep(
            step_type=step_type_spec,
            reward=reward_spec,
            discount=discount_spec,
            observation=observation_spec)
        action_spec = tensor_spec.BoundedTensorSpec(
            dtype=tf.int32, shape=(), minimum=0, maximum=10, name='act_0')

        q_net = q_network.QNetwork(
            input_tensor_spec=observation_spec, action_spec=action_spec)
        policy = q_policy.QPolicy(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            q_network=q_net)

        async_saver = async_policy_saver.AsyncPolicySaver(
            policy_saver.PolicySaver(policy, batch_size=None))
        async_saver.register_function('q_network', q_net, observation_spec)

        export_dir = os.path.join(self.get_temp_dir(), 'save_model')
        async_saver.save(export_dir)
        async_saver.flush()
        async_saver.close()
        # close() must have stopped the background save thread.
        self.assertFalse(async_saver._save_thread.is_alive())
        reloaded = tf.compat.v2.saved_model.load(export_dir)

        observations = self.evaluate(
            tensor_spec.sample_spec_nest(observation_spec, outer_dims=(3,)))
        expected_out, _ = q_net(observations)
        reloaded_out, _ = reloaded.q_network(observations)

        self.assertAllClose(expected_out, reloaded_out)
コード例 #3
0
    def testSave(self):
        """Checks that an async save is forwarded to the wrapped saver once."""
        saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
        async_saver = async_policy_saver.AsyncPolicySaver(saver)
        try:
            self.evaluate(tf.compat.v1.global_variables_initializer())
            save_path = os.path.join(self.get_temp_dir(), 'policy')
            async_saver.save(save_path)
            # Flush before asserting so the request has reached the wrapped saver.
            async_saver.flush()

            saver.save.assert_called_once_with(save_path)
        finally:
            # Always close the async saver, even if an assertion fails; a live
            # background save thread can keep the test process from exiting.
            async_saver.close()
コード例 #4
0
    def testBlockingSave(self):
        """Checks that a blocking save waits for earlier queued saves.

        Once `save(..., blocking=True)` returns, both the earlier non-blocking
        save and the blocking one must already have reached the wrapped saver,
        in order. No explicit flush() is called.
        """
        saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
        async_saver = async_policy_saver.AsyncPolicySaver(saver)
        try:
            path1 = os.path.join(self.get_temp_dir(), 'save_model')
            path2 = os.path.join(self.get_temp_dir(), 'save_model2')

            self.evaluate(tf.compat.v1.global_variables_initializer())
            async_saver.save(path1)
            async_saver.save(path2, blocking=True)

            saver.save.assert_has_calls([mock.call(path1), mock.call(path2)])
        finally:
            # Shut down the background save thread even if an assertion fails;
            # a live thread can keep the test process from exiting.
            async_saver.close()
コード例 #5
0
  def testSave(self):
    """Verifies that AsyncPolicySaver delegates `save` to the wrapped saver."""
    wrapped = mock.create_autospec(policy_saver.PolicySaver, instance=True)
    async_saver = async_policy_saver.AsyncPolicySaver(wrapped)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    target = os.path.join(self.get_temp_dir(), 'policy')
    async_saver.save(target)
    # Flush before asserting so the request has reached the wrapped saver.
    async_saver.flush()

    wrapped.save.assert_called_once_with(target)
    # Have to close the saver to avoid hanging threads that will prevent OSS
    # tests from finishing.
    async_saver.close()
コード例 #6
0
 def _build_saver(
     self,
     policy: tf_policy.TFPolicy,
     batch_size: Optional[int] = None
 ) -> Union[policy_saver.PolicySaver, async_policy_saver.AsyncPolicySaver]:
   """Creates the saver used to export `policy`.

   Args:
     policy: The TF policy to export.
     batch_size: Forwarded unchanged to `policy_saver.PolicySaver`.

   Returns:
     A `policy_saver.PolicySaver` configured with this object's `_train_step`
     and `_metadata`. If `self._async_saving` is truthy, that saver is returned
     wrapped in an `async_policy_saver.AsyncPolicySaver`.
   """
   sync_saver = policy_saver.PolicySaver(
       policy,
       batch_size=batch_size,
       train_step=self._train_step,
       metadata=self._metadata)
   if not self._async_saving:
     return sync_saver
   return async_policy_saver.AsyncPolicySaver(sync_saver)
コード例 #7
0
    def testCheckpointSave(self):
        """Checks that model saves and checkpoint saves reach the wrapped saver.

        A model save is issued and flushed, then a checkpoint save. Each must
        be forwarded exactly once, with its own path.
        """
        saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
        async_saver = async_policy_saver.AsyncPolicySaver(saver)
        try:
            path = os.path.join(self.get_temp_dir(), 'save_model')

            self.evaluate(tf.compat.v1.global_variables_initializer())
            async_saver.save(path)
            async_saver.flush()
            checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
            async_saver.save_checkpoint(checkpoint_path)
            async_saver.flush()

            # The earlier model save must also have been forwarded.
            saver.save.assert_called_once_with(path)
            saver.save_checkpoint.assert_called_once_with(checkpoint_path)
        finally:
            # Always close the async saver so its background save thread does
            # not keep the test process alive, even if an assertion fails.
            async_saver.close()
コード例 #8
0
  def testBlockingCheckpointSave(self):
    """Verifies that a blocking checkpoint save waits for earlier queued ones.

    No flush() is called. After a blocking `save_checkpoint`, both checkpoint
    requests must already have reached the wrapped saver, in order.
    """
    wrapped = mock.create_autospec(policy_saver.PolicySaver, instance=True)
    async_saver = async_policy_saver.AsyncPolicySaver(wrapped)
    first_path = os.path.join(self.get_temp_dir(), 'save_model')
    second_path = os.path.join(self.get_temp_dir(), 'save_model2')

    self.evaluate(tf.compat.v1.global_variables_initializer())
    async_saver.save_checkpoint(first_path)
    async_saver.save_checkpoint(second_path, blocking=True)

    expected_calls = [mock.call(first_path), mock.call(second_path)]
    wrapped.save_checkpoint.assert_has_calls(expected_calls)
    # Have to close the saver to avoid hanging threads that will prevent OSS
    # tests from finishing.
    async_saver.close()
コード例 #9
0
 def _build_saver(
     self,
     policy: tf_policy.TFPolicy,
     batch_size: Optional[int] = None,
     use_nest_path_signatures: bool = True,
 ) -> Union[policy_saver.PolicySaver, async_policy_saver.AsyncPolicySaver]:
     """Creates the saver used to export `policy`.

     Args:
       policy: The TF policy to export.
       batch_size: Forwarded unchanged to `policy_saver.PolicySaver`.
       use_nest_path_signatures: Forwarded unchanged to
         `policy_saver.PolicySaver`.

     Returns:
       A `policy_saver.PolicySaver` carrying this object's `_train_step` and
       `_metadata`. If `self._async_saving` is truthy, that saver is returned
       wrapped in an `async_policy_saver.AsyncPolicySaver`.
     """
     exporter = policy_saver.PolicySaver(
         policy,
         batch_size=batch_size,
         train_step=self._train_step,
         metadata=self._metadata,
         use_nest_path_signatures=use_nest_path_signatures,
     )
     return (
         async_policy_saver.AsyncPolicySaver(exporter)
         if self._async_saving
         else exporter
     )
コード例 #10
0
  def testClose(self):
    """Checks that close() finishes a pending save and shuts the saver down.

    After close(), three things must hold:
      * the pending save has reached the wrapped saver;
      * the background save thread is no longer alive;
      * any further save() raises ValueError.
    """
    saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
    async_saver = async_policy_saver.AsyncPolicySaver(saver)
    path = os.path.join(self.get_temp_dir(), 'save_model')

    self.evaluate(tf.compat.v1.global_variables_initializer())
    async_saver.save(path)
    # The save thread is running before close() is called.
    self.assertTrue(async_saver._save_thread.is_alive())

    # No flush() here: close() itself must let the pending save reach the
    # wrapped saver before the thread is stopped.
    async_saver.close()
    saver.save.assert_called_once()

    self.assertFalse(async_saver._save_thread.is_alive())

    # A closed saver rejects further save requests.
    with self.assertRaises(ValueError):
      async_saver.save(path)