def testTwoScopes(self):
  """Each tracker scope records only the resources created under it."""
  resource_tracker1 = tracking.ResourceTracker()
  with tracking.resource_tracker_scope(resource_tracker1):
    dummy_resource1 = _DummyResource("test1")

  resource_tracker2 = tracking.ResourceTracker()
  with tracking.resource_tracker_scope(resource_tracker2):
    dummy_resource2 = _DummyResource("test2")

  self.assertEqual(1, len(resource_tracker1.resources))
  self.assertEqual("test1", resource_tracker1.resources[0].resource_handle)
  # Bug fix: the original asserted on `resource_tracker1` twice; the second
  # tracker is the one whose length must be checked here (its contents are
  # already asserted on the following line).
  self.assertEqual(1, len(resource_tracker2.resources))
  self.assertEqual("test2", resource_tracker2.resources[0].resource_handle)
def testBasic(self):
  """A single tracker scope records created resources in creation order."""
  tracker = tracking.ResourceTracker()
  with tracking.resource_tracker_scope(tracker):
    created = [_DummyResource(handle) for handle in ("test1", "test2")]

  self.assertEqual(2, len(tracker.resources))
  for position, expected_handle in enumerate(("test1", "test2")):
    self.assertEqual(expected_handle,
                     tracker.resources[position].resource_handle)
def write_v2_saved_model(tf_function: function.Function, name: str,
                         saved_model_dir: str) -> function.ConcreteFunction:
  """Writes `tf_function` under attr `name` to `saved_model_dir`.

  Traces `tf_function` once while tracking any resources, trackable objects
  and variables it creates, attaches them all to a `tf.Module`, and saves the
  module as a SavedModel.

  Args:
    tf_function: The `tf.function` to trace and export.
    name: Attribute name under which the (possibly pruned) function is
      attached to the exported module.
    saved_model_dir: Directory to write the SavedModel to.

  Returns:
    The `ConcreteFunction` obtained by tracing `tf_function`.
  """
  module = tf.Module()

  resource_tracker = tracking.ResourceTracker()
  object_tracker = annotators.ObjectTracker()
  created_variables = []

  # Variable-creator hook: records every variable created during tracing so
  # they can be explicitly attached to `module` below.
  def _variable_creator(next_creator, **kwargs):
    var = next_creator(**kwargs)
    created_variables.append(var)
    return var

  # TODO(b/164921571): Handle generic Trackable objects.
  # Trace `tf_function` to gather any resources in it using the
  # resource_tracker. These are then assigned to `module.resources` and tracked
  # before exporting to SavedModel.
  with tracking.resource_tracker_scope(resource_tracker), \
       annotators.object_tracker_scope(object_tracker), \
       tf.variable_creator_scope(_variable_creator):
    concrete_fn = tf_function.get_concrete_function()

  # Prior to 2020/10/08, saving a tf.function with a concrete function signature
  # would ensure that the function was not re-traced in a round-trip to a
  # SavedModel. Since this is no longer the case, we save the concrete function
  # directly.
  if tf.compat.forward_compatible(2020, 10, 8):
    pruned_function = optimize_concrete_function(concrete_fn)
    module.pruned_variables = pruned_function.variables
    setattr(module, name, pruned_function)
  else:
    setattr(module, name, tf_function)

  # Any variables created need to be explicitly tracked.
  module.created_variables = created_variables
  # Resources need to be explicitly tracked.
  module.resources = resource_tracker.resources
  module.trackable_objects = object_tracker.trackable_objects
  # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
  # table should be sufficient.
  initializers = []
  for resource in module.resources:
    if isinstance(resource, lookup_ops.InitializableLookupTableBase):
      initializers.append(resource._initializer)  # pylint: disable=protected-access
  module.initializers = initializers
  # Asset files referenced by the traced graph must ride along with the
  # SavedModel so they are available after reload.
  module.assets = [
      common_types.Asset(asset_filepath) for asset_filepath in
      concrete_fn.graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
  ]
  tf.saved_model.save(module, saved_model_dir)
  return concrete_fn
def _create_test_saved_model(export_in_tf1,
                             input_specs,
                             foo,
                             export_path_suffix=None):
  """Exports `foo` as a SavedModel for tests, via a TF1 or TF2 code path.

  Args:
    export_in_tf1: If True, builds a TF1 graph/session and exports with
      `saved_transform_io`; otherwise wraps `foo` in a `tf.function` on a
      `tf.Module` and uses `tf.saved_model.save`.
    input_specs: Dict mapping input keys to `tf.TensorSpec`,
      `tf.SparseTensorSpec` or `tf.RaggedTensorSpec` instances.
    foo: Callable taking the inputs dict and returning the outputs dict.
    export_path_suffix: Optional final path component for the export
      directory; defaults to 'export'.

  Returns:
    The path the SavedModel was written to.

  Raises:
    ValueError: If a spec in `input_specs` is not one of the supported
      TypeSpec classes (TF1 path only).
  """
  if not export_path_suffix:
    export_path = os.path.join(tempfile.mkdtemp(), 'export')
  else:
    export_path = os.path.join(tempfile.mkdtemp(), export_path_suffix)

  if export_in_tf1:
    with tf.compat.v1.Graph().as_default():
      with tf.compat.v1.Session().as_default() as session:
        # Build one placeholder per input spec, dispatching on the spec type.
        inputs = {}
        for key in six.iterkeys(input_specs):
          tensor_spec = input_specs[key]
          if isinstance(tensor_spec, tf.TensorSpec):
            inputs[key] = tf.compat.v1.placeholder(
                tensor_spec.dtype, shape=tensor_spec.shape)
          elif isinstance(tensor_spec, tf.SparseTensorSpec):
            inputs[key] = tf.compat.v1.sparse_placeholder(
                tensor_spec.dtype, shape=tensor_spec.shape)
          elif isinstance(tensor_spec, tf.RaggedTensorSpec):
            # NOTE(review): reaches into private RaggedTensorSpec attributes
            # (`_dtype`, `_ragged_rank`) — no public accessors used here.
            inputs[key] = tf.compat.v1.ragged.placeholder(
                tensor_spec._dtype, tensor_spec._ragged_rank, [])
          else:
            raise ValueError(
                'TypeSpecs specified should be one of `tf.TensorSpec`, '
                '`tf.SparseTensorSpec`, `tf.RaggedTensorSpec`')
        outputs = foo(inputs)
        # show that unrelated & unmapped placeholders do not interfere
        tf.compat.v1.placeholder(tf.int64)
        saved_transform_io.write_saved_transform_from_session(
            session, inputs, outputs, export_path)
  else:
    module = tf.Module()
    module.transform_fn = tf.function(foo, input_signature=[input_specs])
    # Trace once under a resource tracker so any resources (e.g. lookup
    # tables) created by `foo` can be attached to the module for export.
    resource_tracker = tracking.ResourceTracker()
    with tracking.resource_tracker_scope(resource_tracker):
      _ = module.transform_fn.get_concrete_function()
    module.resources = resource_tracker.resources
    # TODO(b/158011374) - Stop explicitly tracking initializers once tables
    # track their initializers.
    initializers = []
    for resource in module.resources:
      if isinstance(resource, lookup_ops.InitializableLookupTableBase):
        initializers.append(resource._initializer)
    module.initializers = initializers
    tf.saved_model.save(module, export_path)
  return export_path
def trace_and_write_v2_saved_model(saved_model_dir, preprocessing_fn,
                                   input_signature, base_temp_dir,
                                   tensor_replacement_map,
                                   output_keys_to_name_map):
  """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.

  The SavedModel written contains a method called `transform_fn` that
  represents the traced `preprocessing_fn`. Additionally, if this is the final
  SavedModel being written out, it will contain a method called `metadata_fn`
  that provides deferred schema annotations.

  Args:
    saved_model_dir: Path to write SavedModel to.
    preprocessing_fn: A user defined python function to be traced.
    input_signature: TypeSpecs describing the inputs to the `preprocessing_fn`.
    base_temp_dir: Base path to write temporary artifacts to.
    tensor_replacement_map: A map from placeholder tensor names to their
      evaluated replacement tensors.
    output_keys_to_name_map: A map from output dictionary keys to the names
      of the tensors that they represent.

  Returns:
    A tuple containing a pair of `tf.ConcreteFunction`s:
      1. The traced preprocessing_fn.
      2. A metadata_fn that returns a dictionary containing the deferred
      annotations added to the graph when invoked with any valid input.
  """
  module = tf.Module()
  transform_fn = get_traced_transform_fn(
      preprocessing_fn,
      input_signature,
      base_temp_dir,
      tensor_replacement_map=tensor_replacement_map,
      output_keys_to_name_map=output_keys_to_name_map)
  metadata_fn = None

  resource_tracker = tracking.ResourceTracker()
  created_variables = []

  # Variable-creator hook: records every variable created while tracing so
  # it can be explicitly attached to `module` below.
  def _variable_creator(next_creator, **kwargs):
    var = next_creator(**kwargs)
    created_variables.append(var)
    return var

  # TODO(b/164921571): Handle generic Trackable objects.
  # Trace the `transform_fn` and `metadata_fn` to gather any resources in it
  # using the resource_tracker. These are then assigned to `module.resources`
  # and tracked before exporting to SavedModel.
  with tracking.resource_tracker_scope(
      resource_tracker), tf.variable_creator_scope(_variable_creator):
    concrete_transform_fn = transform_fn.get_concrete_function()
    concrete_metadata_fn = None
    # If the `TENSOR_REPLACEMENTS` graph collection is empty, all TFT analyzers
    # in the `preprocessing_fn` have already been evaluated.
    if not concrete_transform_fn.graph.get_collection(
        analyzer_nodes.TENSOR_REPLACEMENTS):
      metadata_fn = schema_inference.get_traced_metadata_fn(
          tensor_replacement_map,
          preprocessing_fn,
          input_signature,
          base_temp_dir,
          evaluate_schema_overrides=True)
      concrete_metadata_fn = metadata_fn.get_concrete_function()

  # Save ConcreteFunction when possible since the above workaround won't work if
  # the tf.function is retraced.
  if tf.compat.forward_compatible(2020, 10, 8):
    module.transform_fn = concrete_transform_fn
    module.metadata_fn = concrete_metadata_fn
  else:
    module.transform_fn = transform_fn
    module.metadata_fn = metadata_fn

  # Any variables created need to be explicitly tracked.
  module.created_variables = created_variables
  # Resources need to be explicitly tracked.
  module.resources = resource_tracker.resources
  # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
  # table should be sufficient.
  initializers = []
  for resource in module.resources:
    if isinstance(resource, lookup_ops.InitializableLookupTableBase):
      initializers.append(resource._initializer)  # pylint: disable=protected-access
  module.initializers = initializers
  # Asset files referenced by the traced graph must ride along with the
  # SavedModel so they are available after reload.
  module.assets = [
      common_types.Asset(asset_filepath)
      for asset_filepath in concrete_transform_fn.graph.get_collection(
          tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
  ]
  tf.saved_model.save(module, saved_model_dir)
  return concrete_transform_fn, concrete_metadata_fn