def _build_saver(
    self, policy: tf_policy.TFPolicy
) -> Union[policy_saver.PolicySaver, async_policy_saver.AsyncPolicySaver]:
  """Creates the saver used to export `policy`.

  A `policy_saver.PolicySaver` is always built with this object's
  `_train_step` and `_metadata`. When `self._async_saving` is truthy, that
  saver is wrapped in an `async_policy_saver.AsyncPolicySaver` and the wrapper
  is returned instead.

  Args:
    policy: The policy to export.

  Returns:
    The synchronous `PolicySaver`, or an `AsyncPolicySaver` wrapping it.
  """
  base_saver = policy_saver.PolicySaver(
      policy, train_step=self._train_step, metadata=self._metadata)
  if not self._async_saving:
    return base_saver
  return async_policy_saver.AsyncPolicySaver(base_saver)
def testRegisterFunction(self):
  """A function registered on an AsyncPolicySaver survives save + reload.

  The Q-network is registered under the name 'q_network'. The policy is
  saved asynchronously and then reloaded. The reloaded function must
  reproduce the in-memory network's output.
  """
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

  step_type_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int32, shape=(), name='st', minimum=0, maximum=2)
  reward_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.float32, shape=(), name='reward', minimum=0.0, maximum=5.0)
  discount_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.float32, shape=(), name='discount', minimum=0.0, maximum=1.0)
  observation_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.float32, shape=(4,), name='obs', minimum=-10.0, maximum=10.0)
  time_step_spec = ts.TimeStep(
      step_type=step_type_spec,
      reward=reward_spec,
      discount=discount_spec,
      observation=observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int32, shape=(), minimum=0, maximum=10, name='act_0')

  q_net = q_network.QNetwork(
      input_tensor_spec=observation_spec, action_spec=action_spec)
  q_pol = q_policy.QPolicy(
      time_step_spec=time_step_spec, action_spec=action_spec, q_network=q_net)

  async_saver = async_policy_saver.AsyncPolicySaver(
      policy_saver.PolicySaver(q_pol, batch_size=None))
  async_saver.register_function('q_network', q_net, observation_spec)

  # Save, wait for the background write, then shut the worker thread down.
  export_dir = os.path.join(self.get_temp_dir(), 'save_model')
  async_saver.save(export_dir)
  async_saver.flush()
  async_saver.close()
  self.assertFalse(async_saver._save_thread.is_alive())

  loaded = tf.compat.v2.saved_model.load(export_dir)
  inputs = self.evaluate(
      tensor_spec.sample_spec_nest(observation_spec, outer_dims=(3,)))
  expected, _ = q_net(inputs)
  actual, _ = loaded.q_network(inputs)
  self.assertAllClose(expected, actual)
def testSave(self):
  """`save()` + `flush()` forwards the export path to the wrapped saver."""
  saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
  async_saver = async_policy_saver.AsyncPolicySaver(saver)
  try:
    self.evaluate(tf.compat.v1.global_variables_initializer())
    save_path = os.path.join(self.get_temp_dir(), 'policy')
    async_saver.save(save_path)
    async_saver.flush()
    saver.save.assert_called_once_with(save_path)
  finally:
    # Always close the saver, even if an assertion above fails: a live save
    # thread would otherwise prevent OSS tests from finishing.
    async_saver.close()
def testBlockingSave(self):
  """A blocking save returns only after it and earlier saves reached the saver.

  No `flush()` is issued. Both the earlier non-blocking save and the blocking
  one must already have been forwarded, in order, when `save(...,
  blocking=True)` returns.
  """
  saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
  async_saver = async_policy_saver.AsyncPolicySaver(saver)
  try:
    path1 = os.path.join(self.get_temp_dir(), 'save_model')
    path2 = os.path.join(self.get_temp_dir(), 'save_model2')
    self.evaluate(tf.compat.v1.global_variables_initializer())
    async_saver.save(path1)
    async_saver.save(path2, blocking=True)
    saver.save.assert_has_calls([mock.call(path1), mock.call(path2)])
  finally:
    # Close the saver so its background thread does not keep the test
    # process alive (which would prevent OSS tests from finishing).
    async_saver.close()
def testSave(self):
  """`save()` + `flush()` forwards the export path to the wrapped saver."""
  saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
  async_saver = async_policy_saver.AsyncPolicySaver(saver)
  try:
    self.evaluate(tf.compat.v1.global_variables_initializer())
    save_path = os.path.join(self.get_temp_dir(), 'policy')
    async_saver.save(save_path)
    async_saver.flush()
    saver.save.assert_called_once_with(save_path)
  finally:
    # Have to close the saver to avoid hanging threads that will prevent OSS
    # tests from finishing. Doing it in `finally` keeps that true even when
    # the assertion above fails.
    async_saver.close()
def _build_saver(
    self,
    policy: tf_policy.TFPolicy,
    batch_size: Optional[int] = None
) -> Union[policy_saver.PolicySaver, async_policy_saver.AsyncPolicySaver]:
  """Creates the saver used to export `policy`.

  Args:
    policy: The policy to export.
    batch_size: Forwarded unchanged to `policy_saver.PolicySaver`.

  Returns:
    A `PolicySaver` built with this object's `_train_step` and `_metadata`.
    When `self._async_saving` is truthy, it is returned wrapped in an
    `AsyncPolicySaver`.
  """
  sync_saver = policy_saver.PolicySaver(
      policy,
      batch_size=batch_size,
      train_step=self._train_step,
      metadata=self._metadata)
  if not self._async_saving:
    return sync_saver
  return async_policy_saver.AsyncPolicySaver(sync_saver)
def testCheckpointSave(self):
  """`save_checkpoint()` + `flush()` forwards the path to the wrapped saver.

  A regular save is queued and flushed first. The checkpoint request is then
  checked to be forwarded exactly once.
  """
  saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
  async_saver = async_policy_saver.AsyncPolicySaver(saver)
  try:
    path = os.path.join(self.get_temp_dir(), 'save_model')
    self.evaluate(tf.compat.v1.global_variables_initializer())
    async_saver.save(path)
    async_saver.flush()
    checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
    async_saver.save_checkpoint(checkpoint_path)
    async_saver.flush()
    saver.save_checkpoint.assert_called_once_with(checkpoint_path)
  finally:
    # Close the saver so its background thread does not keep the test
    # process alive (which would prevent OSS tests from finishing).
    async_saver.close()
def testBlockingCheckpointSave(self):
  """A blocking checkpoint returns only after earlier checkpoints were written.

  No `flush()` is issued. Both the earlier non-blocking checkpoint request and
  the blocking one must already have been forwarded, in order, when
  `save_checkpoint(..., blocking=True)` returns.
  """
  saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
  async_saver = async_policy_saver.AsyncPolicySaver(saver)
  try:
    path1 = os.path.join(self.get_temp_dir(), 'save_model')
    path2 = os.path.join(self.get_temp_dir(), 'save_model2')
    self.evaluate(tf.compat.v1.global_variables_initializer())
    async_saver.save_checkpoint(path1)
    async_saver.save_checkpoint(path2, blocking=True)
    saver.save_checkpoint.assert_has_calls(
        [mock.call(path1), mock.call(path2)])
  finally:
    # Have to close the saver to avoid hanging threads that will prevent OSS
    # tests from finishing. Doing it in `finally` keeps that true even when
    # the assertion above fails.
    async_saver.close()
def _build_saver(
    self,
    policy: tf_policy.TFPolicy,
    batch_size: Optional[int] = None,
    use_nest_path_signatures: bool = True,
) -> Union[policy_saver.PolicySaver, async_policy_saver.AsyncPolicySaver]:
  """Creates the saver used to export `policy`.

  Args:
    policy: The policy to export.
    batch_size: Forwarded unchanged to `policy_saver.PolicySaver`.
    use_nest_path_signatures: Forwarded unchanged to
      `policy_saver.PolicySaver`.

  Returns:
    A `PolicySaver` configured with this object's `_train_step` and
    `_metadata`. When `self._async_saving` is truthy, it is returned wrapped
    in an `AsyncPolicySaver`.
  """
  exporter = policy_saver.PolicySaver(
      policy,
      batch_size=batch_size,
      train_step=self._train_step,
      metadata=self._metadata,
      use_nest_path_signatures=use_nest_path_signatures,
  )
  return (async_policy_saver.AsyncPolicySaver(exporter)
          if self._async_saving else exporter)
def testClose(self):
  """close() runs the pending save, stops the thread, then rejects saves."""
  mock_saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
  async_saver = async_policy_saver.AsyncPolicySaver(mock_saver)
  export_dir = os.path.join(self.get_temp_dir(), 'save_model')
  self.evaluate(tf.compat.v1.global_variables_initializer())

  # Queue a save while the background thread is still running.
  async_saver.save(export_dir)
  self.assertTrue(async_saver._save_thread.is_alive())

  # Closing must let the queued save reach the wrapped saver and stop the
  # thread.
  async_saver.close()
  mock_saver.save.assert_called_once()
  self.assertFalse(async_saver._save_thread.is_alive())

  # After close(), further save requests are refused.
  with self.assertRaises(ValueError):
    async_saver.save(export_dir)