예제 #1
0
    def test_supply_missing_tensor_inputs(self, batch_size, dtype):
        """Checks `supply_missing_inputs` for a single missing dense input.

        A `tf.function` taking a dict of two `[None]`-shaped tensor specs is
        traced, `tf2_utils.supply_missing_inputs` is asked for 'x_2' only, and
        the returned entry is checked for type, shape and dtype.
        """
        test_case.skip_if_not_tf2('Tensorflow 2.x required.')

        input_spec = {
            'x_1': tf.TensorSpec([None], dtype=tf.int32),
            'x_2': tf.TensorSpec([None], dtype=dtype),
        }

        @tf.function(input_signature=[input_spec])
        def foo(inputs):
            return inputs

        concrete_foo = foo.get_concrete_function()
        # `structured_input_signature` is an (args, kwargs) pair. `foo` has a
        # single positional argument, so [0][0] is the structure of `inputs`.
        inputs_structure = concrete_foo.structured_input_signature[0][0]
        packed_inputs = tf.nest.pack_sequence_as(
            inputs_structure, concrete_foo.inputs, expand_composites=True)

        expected_keys = ['x_2']
        supplied = tf2_utils.supply_missing_inputs(
            packed_inputs, batch_size, expected_keys)

        # Only the requested key may be present in the output.
        self.assertCountEqual(expected_keys, supplied.keys())
        x_2 = supplied['x_2']
        self.assertIsInstance(x_2, tf.Tensor)
        self.assertEqual((batch_size, ), x_2.shape)
        self.assertEqual(dtype, x_2.dtype)
예제 #2
0
 def _get_missing_inputs(self, unfed_input_keys, batch_size):
     """Supplies inputs for `unfed_input_keys`."""
     result = {}
     if unfed_input_keys:
         result = (tf2_utils.supply_missing_inputs(self._structured_inputs,
                                                   batch_size,
                                                   unfed_input_keys))
     return result
예제 #3
0
 def _get_missing_inputs(
     self, unfed_input_keys: Iterable[str],
     batch_size: int) -> Dict[str, common_types.TensorType]:
   """Supplies inputs for `unfed_input_keys`.

   When `unfed_input_keys` is truthy the call is forwarded to
   `tf2_utils.supply_missing_inputs` together with this object's
   `_structured_inputs` and `batch_size`; otherwise an empty dict is returned.
   """
   if unfed_input_keys:
     return tf2_utils.supply_missing_inputs(
         self._structured_inputs, batch_size, unfed_input_keys)
   # No keys were requested.
   return {}
예제 #4
0
def infer_feature_schema_v2(features, concrete_metadata_fn,
                            evaluate_schema_overrides):
    """Creates a `Schema` from a dict of tensors.

    The schema is inferred, in the format of a tf.Transform `Schema`, for the
    given dictionary of tensors.

    When an override is specified for a feature's tensor, it replaces the
    inferred schema for that tensor. An override means that is_categorical=True
    should be set. With `evaluate_schema_overrides` set to False only
    is_categorical=True is set; with it set to True the values of the tensors
    representing the min and max values are also computed and stored in the
    schema.

    Annotations that have been specified are added to the output schema.

    Args:
      features: A dict mapping column names to `Tensor` or `SparseTensor`s.
        The `Tensor` or `SparseTensor`s should have a 0'th dimension which is
        interpreted as the batch dimension.
      concrete_metadata_fn: A `tf.ConcreteFunction` that returns a dictionary
        containing the deferred annotations added to the graph when invoked
        with any valid input.
      evaluate_schema_overrides: A Boolean used to compute schema overrides.
        If `False`, schema overrides will not be computed.

    Returns:
      A `Schema` proto.
    """
    graph_inputs = tf2_utils.get_structured_inputs_from_func_graph(
        concrete_metadata_fn.graph)
    # The metadata function needs arguments to be invoked, so feed it dummy
    # data with a batch size of 1.
    dummy_inputs = tf2_utils.supply_missing_inputs(graph_inputs, batch_size=1)
    flat_dummy_inputs = tf.nest.flatten(dummy_inputs, expand_composites=True)
    fn_outputs = concrete_metadata_fn(*flat_dummy_inputs)
    # Collections absent from the function's output read as empty lists.
    metadata = collections.defaultdict(list, fn_outputs)

    if not evaluate_schema_overrides:
        # Overrides are not evaluated: record each overridden name with an
        # unknown (None, None) range and attach no annotations.
        unevaluated_ranges = {
            name_tensor.numpy().decode(): (None, None)
            for name_tensor in metadata[_TF_METADATA_TENSOR_COLLECTION]
        }
        return _infer_feature_schema_common(features, unevaluated_ranges, {},
                                            [])

    evaluated_ranges = _get_tensor_ranges_v2(metadata)
    per_tensor_annotations, schema_wide_annotations = (
        _get_schema_annotations_v2(metadata))
    return _infer_feature_schema_common(features, evaluated_ranges,
                                        per_tensor_annotations,
                                        schema_wide_annotations)
  def _apply_v2_transform_model(self, logical_input_map):
    """Runs the V2 transform graph on the tensors in `logical_input_map`.

    Inputs of the transform graph that are absent from `logical_input_map` are
    filled in through `tf2_utils.supply_missing_inputs`, and only the outputs
    that can be computed from the keys provided in `logical_input_map` are
    returned.

    Args:
      logical_input_map: a dict of logical name to Tensor.  The logical names
        must be a subset of those in the input signature of the transform graph,
        and the corresponding Tensors must have the expected types and shapes.

    Returns:
      A dict of logical name to Tensor, as provided by the output signature of
      the transform graph.
    """
    # TODO(b/160550490): Remove local import.
    from tensorflow_transform import tf2_utils  # pylint: disable=g-import-not-at-top

    fed_tensors = object_identity.ObjectIdentitySet(self._func_graph.inputs)
    missing_keys = (
        set(self._structured_inputs.keys()) - set(logical_input_map.keys()))
    # Remove the component tensors of every input the caller did not provide,
    # leaving only the graph inputs that are actually fed.
    for missing_key in missing_keys:
      components = self._get_component_tensors(
          self._structured_inputs[missing_key])
      fed_tensors = fed_tensors.difference(components)

    # Shallow copy so the caller's dict is not modified by the update below.
    all_inputs = copy.copy(logical_input_map)
    if missing_keys:
      # Default batch size; replaced by the leading dimension of the first
      # provided input when there is one.
      batch_size = 1
      if logical_input_map:
        first_input = next(iter(logical_input_map.values()))
        # NOTE(review): `tf.shape` yields a Tensor, so this comparison with
        # None looks like it is always true -- confirm before simplifying.
        if tf.shape(first_input)[0] is not None:
          batch_size = tf.shape(first_input)[0]
      all_inputs.update(
          tf2_utils.supply_missing_inputs(self._structured_inputs, batch_size,
                                          missing_keys))

    fetches = self._get_fetches(fed_tensors)
    outputs = self._wrapped(all_inputs)
    return {name: outputs[name] for name in fetches.keys()}
예제 #6
0
    def metadata_fn():
        """Runs `preprocessing_fn` and gathers schema metadata from the graph.

        `preprocessing_fn` is called on inputs obtained from
        `tf2_utils.supply_missing_inputs` with a batch size of 1. Depending on
        `evaluate_schema_overrides`, either only the feature keys of the
        overridden tensors are collected, or the full schema overrides and
        extra annotations are.

        Returns:
          A `collections.defaultdict(list)` keyed by metadata collection name.
        """
        default_graph = ops.get_default_graph()
        dummy_inputs = tf2_utils.supply_missing_inputs(structured_inputs,
                                                       batch_size=1)
        with graph_context.TFGraphContext(
                temp_dir=base_temp_dir,
                evaluated_replacements=tensor_replacement_map):
            outputs = preprocessing_fn(dummy_inputs)

        # Map from tensor value names to feature keys.
        tensor_name_to_key = _get_tensor_value_to_key_map(outputs)

        collected = collections.defaultdict(list)
        if evaluate_schema_overrides:
            # Schema overrides for feature tensor ranges.
            range_overrides = _get_schema_overrides(
                default_graph, tensor_name_to_key,
                _TF_METADATA_TENSOR_COLLECTION, [
                    _TF_METADATA_TENSOR_MIN_COLLECTION,
                    _TF_METADATA_TENSOR_MAX_COLLECTION
                ])
            collected.update(range_overrides)
            # Schema overrides for feature protos. If no feature tensor is in
            # the `_TF_METADATA_EXTRA_ANNOTATION` collection for a specified
            # annotation, `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL` is used as the
            # feature name to indicate that this annotation should be added to
            # the global schema.
            annotation_overrides = _get_schema_overrides(
                default_graph, tensor_name_to_key,
                _TF_METADATA_EXTRA_ANNOTATION, [
                    _TF_METADATA_EXTRA_ANNOTATION_TYPE_URL,
                    _TF_METADATA_EXTRA_ANNOTATION_PROTO
                ], _TF_METADATA_EXTRA_ANNOTATION_GLOBAL)
            collected.update(annotation_overrides)
            return collected

        # Overrides are not evaluated: only record which feature keys have an
        # override tensor. The collection key is created lazily, on the first
        # match, so it stays absent when nothing matches.
        for override_tensor in default_graph.get_collection(
                _TF_METADATA_TENSOR_COLLECTION):
            tensor_name = override_tensor.name
            if tensor_name in tensor_name_to_key:
                collected[_TF_METADATA_TENSOR_COLLECTION].append(
                    tensor_name_to_key[tensor_name])
        return collected