Ejemplo n.º 1
0
  def testTwoScopes(self):
    """Each tracker records only the resources created inside its own scope."""
    resource_tracker1 = tracking.ResourceTracker()
    with tracking.resource_tracker_scope(resource_tracker1):
      # Kept in a local (otherwise unused) so the object stays referenced for
      # the whole test.
      dummy_resource1 = _DummyResource("test1")

    resource_tracker2 = tracking.ResourceTracker()
    with tracking.resource_tracker_scope(resource_tracker2):
      dummy_resource2 = _DummyResource("test2")

    self.assertEqual(1, len(resource_tracker1.resources))
    self.assertEqual("test1", resource_tracker1.resources[0].resource_handle)
    # The second tracker must be checked on its own; asserting on
    # resource_tracker1 again would leave its size unverified.
    self.assertEqual(1, len(resource_tracker2.resources))
    self.assertEqual("test2", resource_tracker2.resources[0].resource_handle)
Ejemplo n.º 2
0
  def testBasic(self):
    """Resources created in one scope are recorded in creation order."""
    tracker = tracking.ResourceTracker()
    with tracking.resource_tracker_scope(tracker):
      first = _DummyResource("test1")
      second = _DummyResource("test2")

    tracked = tracker.resources
    self.assertEqual(2, len(tracked))
    # Handles must appear in the same order the resources were created.
    for position, expected_handle in enumerate(("test1", "test2")):
      self.assertEqual(expected_handle, tracked[position].resource_handle)
Ejemplo n.º 3
0
def write_v2_saved_model(tf_function: function.Function, name: str,
                         saved_model_dir: str) -> function.ConcreteFunction:
    """Traces `tf_function` and exports it as attribute `name` of a SavedModel.

    Args:
      tf_function: The `tf.function` to trace and export.
      name: Attribute name under which the function is stored on the exported
        `tf.Module`.
      saved_model_dir: Directory the SavedModel is written to.

    Returns:
      The concrete function produced by `tf_function.get_concrete_function()`.
    """
    export_module = tf.Module()

    tracker_of_resources = tracking.ResourceTracker()
    tracker_of_objects = annotators.ObjectTracker()
    traced_variables = []

    def _record_variable(next_creator, **kwargs):
        # Delegate creation, then remember the variable so it can be attached
        # to the module explicitly below.
        new_variable = next_creator(**kwargs)
        traced_variables.append(new_variable)
        return new_variable

    # TODO(b/164921571): Handle generic Trackable objects.
    # Tracing inside these scopes collects the resources, annotated trackable
    # objects and variables created by `tf_function`; they are attached to the
    # module below so they are tracked when exporting to SavedModel.
    with tracking.resource_tracker_scope(tracker_of_resources), \
         annotators.object_tracker_scope(tracker_of_objects), \
         tf.variable_creator_scope(_record_variable):
        concrete_fn = tf_function.get_concrete_function()

    # Prior to 2020/10/08, saving a tf.function with a concrete signature kept
    # it from being re-traced in a SavedModel round-trip. That no longer holds,
    # so inside the forward-compatibility window the (optimized) concrete
    # function is saved directly.
    exported_fn = tf_function
    if tf.compat.forward_compatible(2020, 10, 8):
        exported_fn = optimize_concrete_function(concrete_fn)
        export_module.pruned_variables = exported_fn.variables
    setattr(export_module, name, exported_fn)

    # Variables, resources and annotated objects need to be tracked explicitly.
    export_module.created_variables = traced_variables
    export_module.resources = tracker_of_resources.resources
    export_module.trackable_objects = tracker_of_objects.trackable_objects
    # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
    # table should be sufficient.
    export_module.initializers = [
        table._initializer  # pylint: disable=protected-access
        for table in export_module.resources
        if isinstance(table, lookup_ops.InitializableLookupTableBase)
    ]
    asset_paths = concrete_fn.graph.get_collection(
        tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
    export_module.assets = [common_types.Asset(path) for path in asset_paths]
    tf.saved_model.save(export_module, saved_model_dir)
    return concrete_fn
Ejemplo n.º 4
0
def _create_test_saved_model(export_in_tf1,
                             input_specs,
                             foo,
                             export_path_suffix=None):
    """Builds and saves a small transform SavedModel for tests.

    Args:
      export_in_tf1: If true, build the graph from TF1 placeholders inside a
        `tf.compat.v1.Session` and export it with
        `saved_transform_io.write_saved_transform_from_session`. Otherwise wrap
        `foo` in a `tf.function` on a `tf.Module` and export it with
        `tf.saved_model.save`.
      input_specs: Dict mapping input names to `tf.TensorSpec`,
        `tf.SparseTensorSpec` or `tf.RaggedTensorSpec`.
      foo: Callable taking a dict of input tensors and returning the outputs to
        export.
      export_path_suffix: Optional last path component of the export
        directory; 'export' is used when it is falsy.

    Returns:
      The export path, located inside a fresh temporary directory.

    Raises:
      ValueError: If `export_in_tf1` is true and an entry of `input_specs` is
        not one of the supported TypeSpec classes.
    """
    export_path = os.path.join(tempfile.mkdtemp(),
                               export_path_suffix or 'export')
    if export_in_tf1:
        with tf.compat.v1.Graph().as_default():
            # Use the Session itself as the context manager: it is made the
            # default session and closed on exit. `Session.as_default()` alone
            # does not close the session when its context exits.
            with tf.compat.v1.Session() as session:
                inputs = {}
                for key, tensor_spec in six.iteritems(input_specs):
                    if isinstance(tensor_spec, tf.TensorSpec):
                        inputs[key] = tf.compat.v1.placeholder(
                            tensor_spec.dtype, shape=tensor_spec.shape)
                    elif isinstance(tensor_spec, tf.SparseTensorSpec):
                        inputs[key] = tf.compat.v1.sparse_placeholder(
                            tensor_spec.dtype, shape=tensor_spec.shape)
                    elif isinstance(tensor_spec, tf.RaggedTensorSpec):
                        # Public properties rather than private attributes.
                        inputs[key] = tf.compat.v1.ragged.placeholder(
                            tensor_spec.dtype, tensor_spec.ragged_rank, [])
                    else:
                        raise ValueError(
                            'TypeSpecs specified should be one of `tf.TensorSpec`, '
                            '`tf.SparseTensorSpec`, `tf.RaggedTensorSpec`')
                outputs = foo(inputs)
                # show that unrelated & unmapped placeholders do not interfere
                tf.compat.v1.placeholder(tf.int64)
                saved_transform_io.write_saved_transform_from_session(
                    session, inputs, outputs, export_path)
    else:
        module = tf.Module()
        module.transform_fn = tf.function(foo, input_signature=[input_specs])
        resource_tracker = tracking.ResourceTracker()
        # Trace once so that resources created by `foo` are captured by the
        # tracker and can be attached to the module for export.
        with tracking.resource_tracker_scope(resource_tracker):
            _ = module.transform_fn.get_concrete_function()
        module.resources = resource_tracker.resources
        # TODO(b/158011374) - Stop explicitly tracking initializers once tables
        # track their initializers.
        module.initializers = [
            resource._initializer  # pylint: disable=protected-access
            for resource in module.resources
            if isinstance(resource, lookup_ops.InitializableLookupTableBase)
        ]
        tf.saved_model.save(module, export_path)
    return export_path
Ejemplo n.º 5
0
def trace_and_write_v2_saved_model(saved_model_dir, preprocessing_fn,
                                   input_signature, base_temp_dir,
                                   tensor_replacement_map,
                                   output_keys_to_name_map):
    """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.

    The SavedModel written contains a method called `transform_fn` that
    represents the traced `preprocessing_fn`. Additionally, if this is the final
    SavedModel being written out (the traced graph has no remaining
    `analyzer_nodes.TENSOR_REPLACEMENTS`), it will contain a method called
    `metadata_fn` that provides deferred schema annotations; otherwise
    `metadata_fn` is `None`.

    Args:
      saved_model_dir: Path to write SavedModel to.
      preprocessing_fn: A user defined python function to be traced.
      input_signature: TypeSpecs describing the inputs to the
        `preprocessing_fn`.
      base_temp_dir: Base path to write temporary artifacts to.
      tensor_replacement_map: A map from placeholder tensor names to their
        evaluated replacement tensors.
      output_keys_to_name_map: A map from output dictionary keys to the names of
        the tensors that they represent.

    Returns:
      A tuple `(concrete_transform_fn, concrete_metadata_fn)`:
        1. The traced preprocessing_fn as a `tf.ConcreteFunction`.
        2. A `tf.ConcreteFunction` metadata_fn that returns a dictionary
           containing the deferred annotations added to the graph when invoked
           with any valid input, or `None` if the traced `transform_fn` graph
           still contains `analyzer_nodes.TENSOR_REPLACEMENTS`.
    """

    module = tf.Module()
    transform_fn = get_traced_transform_fn(
        preprocessing_fn,
        input_signature,
        base_temp_dir,
        tensor_replacement_map=tensor_replacement_map,
        output_keys_to_name_map=output_keys_to_name_map)
    metadata_fn = None

    resource_tracker = tracking.ResourceTracker()
    created_variables = []

    def _variable_creator(next_creator, **kwargs):
        # Record every variable created while tracing so it can be tracked
        # explicitly on `module` below.
        var = next_creator(**kwargs)
        created_variables.append(var)
        return var

    # TODO(b/164921571): Handle generic Trackable objects.
    # Trace the `transform_fn` and `metadata_fn` to gather any resources in it
    # using the resource_tracker. These are then assigned to `module.resources`
    # and tracked before exporting to SavedModel.
    with tracking.resource_tracker_scope(
            resource_tracker), tf.variable_creator_scope(_variable_creator):
        concrete_transform_fn = transform_fn.get_concrete_function()
        concrete_metadata_fn = None
        # If the `TENSOR_REPLACEMENTS` graph collection is empty, all TFT
        # analyzers in the `preprocessing_fn` have already been evaluated, so the
        # deferred-metadata function can be traced as well.
        if not concrete_transform_fn.graph.get_collection(
                analyzer_nodes.TENSOR_REPLACEMENTS):
            metadata_fn = schema_inference.get_traced_metadata_fn(
                tensor_replacement_map,
                preprocessing_fn,
                input_signature,
                base_temp_dir,
                evaluate_schema_overrides=True)
            concrete_metadata_fn = metadata_fn.get_concrete_function()

    # Save ConcreteFunction when possible since the above workaround won't work if
    # the tf.function is retraced.
    if tf.compat.forward_compatible(2020, 10, 8):
        module.transform_fn = concrete_transform_fn
        module.metadata_fn = concrete_metadata_fn
    else:
        module.transform_fn = transform_fn
        module.metadata_fn = metadata_fn

    # Any variables created need to be explicitly tracked.
    module.created_variables = created_variables
    # Resources need to be explicitly tracked.
    module.resources = resource_tracker.resources
    # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
    # table should be sufficient.
    module.initializers = [
        resource._initializer  # pylint: disable=protected-access
        for resource in module.resources
        if isinstance(resource, lookup_ops.InitializableLookupTableBase)
    ]
    module.assets = [
        common_types.Asset(asset_filepath)
        for asset_filepath in concrete_transform_fn.graph.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
    ]
    tf.saved_model.save(module, saved_model_dir)
    return concrete_transform_fn, concrete_metadata_fn