Example #1
File: executor.py  Project: luvneries/tfx
  def _ReadSchema(self, data_format,
                  schema_path):
    """Returns a TFT schema for the input data.

    Args:
      data_format: name of the input data format.
      schema_path: path to schema file.

    Returns:
      A schema representing the provided set of columns.
    """

    if self._ShouldDecodeAsRawExample(data_format):
      return _RAW_EXAMPLE_SCHEMA
    schema = self._GetSchema(schema_path)
    # TODO(b/77351671): Remove this conversion to tf.Transform's internal
    # schema format.
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
    return dataset_schema.from_feature_spec(feature_spec)
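Note: `_ReadSchema` above relies on the round trip between feature specs and `Schema` protos. Below is a minimal sketch of that round trip, assuming only that tensorflow and tensorflow_transform are installed; the feature names are made up.

import tensorflow as tf
from tensorflow_transform.tf_metadata import schema_utils

# Hypothetical feature spec; FixedLen and VarLen features round-trip the same way.
feature_spec = {
    'x': tf.io.FixedLenFeature([], tf.float32),
    'tokens': tf.io.VarLenFeature(tf.string),
}
# feature spec -> Schema proto -> feature spec.
schema = schema_utils.schema_from_feature_spec(feature_spec)
recovered = schema_utils.schema_as_feature_spec(schema).feature_spec
assert recovered == feature_spec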
Example #2
def _get_raw_feature_spec(schema):
  return schema_utils.schema_as_feature_spec(schema).feature_spec
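A helper like this is usually handed a `Schema` proto parsed from a schema.pbtxt file. A hedged sketch of that step (the path is hypothetical):

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_transform.tf_metadata import schema_utils

# Hypothetical path to a schema produced by, e.g., TFDV.
with open('/path/to/schema.pbtxt') as f:
    schema = text_format.Parse(f.read(), schema_pb2.Schema())
feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec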
Example #3
def get_raw_feature_spec():
    return schema_utils.schema_as_feature_spec(
        RAW_DATA_METADATA.schema).feature_spec
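For context, `RAW_DATA_METADATA` in examples like this is typically a `DatasetMetadata` built from a raw feature spec; a sketch with made-up features:

import tensorflow as tf
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

RAW_DATA_FEATURE_SPEC = {
    'x': tf.io.FixedLenFeature([], tf.float32),
    'label': tf.io.FixedLenFeature([], tf.int64),
}
RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(RAW_DATA_FEATURE_SPEC))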
Example #4
def _get_raw_feature_spec(schema):
    return schema_utils.schema_as_feature_spec(schema).feature_spec
Example #5
def to_instance_dicts(schema, fetches):
  """Maps the values fetched by `tf.Session.run` to the internal batch format.

  Args:
    schema: A `Schema` proto.
    fetches: A dict representing a batch of data, as returned by `Session.run`.

  Returns:
    A list of dicts where each dict is an in-memory representation of an
        instance.

  Raises:
    ValueError: If `schema` is invalid.
  """

  def decompose_sparse_batch(sparse_value):
    """Decomposes a sparse batch into a list of sparse instances.

    Args:
      sparse_value: A `SparseTensorValue` representing a batch of N sparse
        instances. The indices of the SparseTensorValue are expected to be
        sorted by row order.

    Returns:
      A tuple (instance_indices, instance_values) where the elements are lists
      of N lists representing the indices and values, respectively, of the
      instances in the batch.

    Raises:
      ValueError: If `sparse_value` contains out-of-order indices.
    """
    batch_indices, batch_values, batch_shape = sparse_value
    # Preallocate lists of length batch_size, initialized to empty ndarrays,
    # representing the indices and values of instances. We can reuse the return
    # value of _get_empty_array here because it is immutable.
    instance_indices = [_get_empty_array(batch_indices.dtype)] * batch_shape[0]
    instance_values = [_get_empty_array(batch_values.dtype)] * batch_shape[0]
    instance_rank = len(batch_shape[1:])

    # Iterate over the rows in the batch. At each row, consume all the elements
    # that belong to that row.
    current_offset = 0
    for current_row in range(batch_shape[0]):
      start_offset = current_offset

      # Scan forward until we reach an element that does not belong to the
      # current row.
      while current_offset < len(batch_indices):
        row = batch_indices[current_offset][0]
        if row == current_row:
          # This element belongs to the current row.
          current_offset += 1
        elif row > current_row:
          # We've reached the end of the current row.
          break
        else:
          raise ValueError('Encountered out-of-order sparse index: {}.'.format(
              batch_indices[current_offset]))

      if current_offset == start_offset:
        # If the current row is empty, leave the default value, which is an
        # empty array.
        pass
      else:
        instance_indices[current_row] = batch_indices[
            start_offset:current_offset, 1:]
        if instance_rank == 1:
          # In this case indices will have length 1, so for convenience we
          # reshape from [-1, 1] to [-1].
          instance_indices[current_row] = (
              instance_indices[current_row].reshape([-1]))
        instance_values[current_row] = batch_values[start_offset:current_offset]

    return instance_indices, instance_values

  batch_dict = {}
  batch_sizes = {}
  feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
  for name, value in six.iteritems(fetches):
    spec = feature_spec[name]
    if isinstance(spec, tf.io.FixedLenFeature):
      batch_dict[name] = [value[i] for i in range(value.shape[0])]
      batch_sizes[name] = value.shape[0]

    elif isinstance(spec, tf.io.VarLenFeature):
      if not isinstance(value, tf.compat.v1.SparseTensorValue):
        raise ValueError(
            'Expected a SparseTensorValue, but got {}'.format(value))
      instance_indices, instance_values = decompose_sparse_batch(value)
      for indices in instance_indices:
        if len(indices.shape) > 1 or np.any(indices != np.arange(len(indices))):
          raise ValueError('Encountered a SparseTensorValue that cannot be '
                           'decoded by ListColumnRepresentation.\n'
                           '"{}" : {}'.format(name, value))
      batch_dict[name] = instance_values
      batch_sizes[name] = len(instance_values)

    elif isinstance(spec, tf.io.SparseFeature):
      if not isinstance(value, tf.compat.v1.SparseTensorValue):
        raise ValueError(
            'Expected a SparseTensorValue, but got {}'.format(value))
      # TODO(abrao): Add support for N-d SparseFeatures.
      instance_indices, instance_values = decompose_sparse_batch(value)
      batch_dict[spec.index_key] = instance_indices
      batch_dict[spec.value_key] = instance_values
      batch_sizes[name] = len(instance_values)

    else:
      raise ValueError('Invalid feature spec {}.'.format(spec))

  # Check batch size is the same for each output.  Note this assumes that
  # fetches is not empty.
  batch_size = next(six.itervalues(batch_sizes))
  for name, batch_size_for_name in six.iteritems(batch_sizes):
    if batch_size_for_name != batch_size:
      raise ValueError(
          'Inconsistent batch sizes: "{}" had batch dimension {}, "{}" had'
          ' batch dimension {}'.format(name, batch_size_for_name,
                                       next(six.iterkeys(batch_sizes)),
                                       batch_size))

  # The following is the simplest way to convert batch_dict from a dict of
  # iterables to a list of dicts.  It does this by first extracting the values
  # of batch_dict, and reversing the order of iteration, then recombining with
  # the keys of batch_dict to create a dict.
  return [dict(zip(six.iterkeys(batch_dict), instance_values))
          for instance_values in zip(*six.itervalues(batch_dict))]
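A minimal usage sketch for `to_instance_dicts`, assuming the function above (and its module-level imports) is in scope; the fetches mimic what `tf.Session.run` returns for a batch of two instances with one fixed-length feature:

import numpy as np
import tensorflow as tf
from tensorflow_transform.tf_metadata import schema_utils

schema = schema_utils.schema_from_feature_spec(
    {'x': tf.io.FixedLenFeature([], tf.int64)})
fetches = {'x': np.array([1, 2], dtype=np.int64)}
# Unbatches into per-instance dicts: [{'x': 1}, {'x': 2}].
print(to_instance_dicts(schema, fetches))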
Example #6
def to_instance_dicts(schema, fetches):
    """Converts fetches to the internal batch format.

  Maps the values fetched by `tf.Session.run` or returned by a tf.function to
  the internal batch format.

  Args:
    schema: A `Schema` proto.
    fetches: A dict representing a batch of data, either as returned by
      `Session.run` or eager tensors.

  Returns:
    A list of dicts where each dict is an in-memory representation of an
        instance.

  Raises:
    ValueError: If `schema` is invalid.
  """

    batch_dict = {}
    batch_sizes = {}
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
    for name, tensor_or_value in six.iteritems(fetches):
        spec = feature_spec[name]
        if isinstance(spec, tf.io.FixedLenFeature):
            value = tensor_or_value.numpy() if isinstance(
                tensor_or_value, tf.Tensor) else tensor_or_value
            batch_dict[name] = [value[i] for i in range(value.shape[0])]
            batch_sizes[name] = value.shape[0]

        elif isinstance(spec, tf.io.VarLenFeature):
            instance_values = _handle_varlen_batch(tensor_or_value, name)
            batch_dict[name] = instance_values
            batch_sizes[name] = len(instance_values)

        elif isinstance(spec, tf.io.SparseFeature):
            batch_dict_update = _handle_sparse_batch(tensor_or_value, spec,
                                                     name)
            batch_dict.update(batch_dict_update)
            batch_sizes[name] = len(batch_dict_update[spec.value_key])

        else:
            raise ValueError('Invalid feature spec {}.'.format(spec))

    # Check batch size is the same for each output.  Note this assumes that
    # fetches is not empty.
    batch_size = next(six.itervalues(batch_sizes))
    for name, batch_size_for_name in six.iteritems(batch_sizes):
        if batch_size_for_name != batch_size:
            raise ValueError(
                'Inconsistent batch sizes: "{}" had batch dimension {}, "{}" had'
                ' batch dimension {}'.format(name, batch_size_for_name,
                                             next(six.iterkeys(batch_sizes)),
                                             batch_size))

    # The following is the simplest way to convert batch_dict from a dict of
    # iterables to a list of dicts.  It does this by first extracting the values
    # of batch_dict, and reversing the order of iteration, then recombining with
    # the keys of batch_dict to create a dict.
    return [
        dict(zip(six.iterkeys(batch_dict), instance_values))
        for instance_values in zip(*six.itervalues(batch_dict))
    ]
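The same call works with eager tensors in this variant, since `tf.Tensor` fetches are converted via `.numpy()`. A sketch exercising only the `FixedLenFeature` path; the VarLen/Sparse helpers it references (`_handle_varlen_batch`, `_handle_sparse_batch`) are assumed to be defined in the same module:

import tensorflow as tf
from tensorflow_transform.tf_metadata import schema_utils

schema = schema_utils.schema_from_feature_spec(
    {'x': tf.io.FixedLenFeature([], tf.float32)})
fetches = {'x': tf.constant([1.0, 2.0])}  # eager tensor instead of a Session.run result
# Expected: [{'x': 1.0}, {'x': 2.0}]
print(to_instance_dicts(schema, fetches))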
Example #7
    def expand(self, dataset):
        """Analyze the dataset.

    Args:
      dataset: A dataset.

    Returns:
      A TransformFn containing the deferred transform function.

    Raises:
      ValueError: If preprocessing_fn has no outputs.
    """
        (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
         input_metadata) = dataset
        input_schema = input_metadata.schema

        input_values_pcoll_dict = input_values_pcoll_dict or dict()

        with tf.Graph().as_default() as graph:

            with tf.compat.v1.name_scope('inputs'):
                feature_spec = schema_utils.schema_as_feature_spec(
                    input_schema).feature_spec
                input_signature = impl_helper.feature_spec_as_batched_placeholders(
                    feature_spec)
                # In order to avoid a bug where import_graph_def fails when the
                # input_map and return_elements of an imported graph are the same
                # (b/34288791), we avoid using the placeholder of an input column as an
                # output of a graph. We do this by applying tf.identity to all inputs of
                # the preprocessing_fn.  Note this applies at the level of raw tensors.
                # TODO(b/34288791): Remove this workaround and use a shallow copy of
                # inputs instead.  A shallow copy is needed in case
                # self._preprocessing_fn mutates its input.
                copied_inputs = impl_helper.copy_tensors(input_signature)

            output_signature = self._preprocessing_fn(copied_inputs)

        # At this point we check that the preprocessing_fn has at least one
        # output. This is because if we allowed the output of preprocessing_fn to
        # be empty, we wouldn't be able to determine how many instances to
        # "unbatch" the output into.
        if not output_signature:
            raise ValueError(
                'The preprocessing function returned an empty dict')

        if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
            raise ValueError(
                'The preprocessing function contained trainable variables '
                '{}'.format(
                    graph.get_collection_ref(
                        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))

        pipeline = self.pipeline or (flattened_pcoll or next(
            v for v in input_values_pcoll_dict.values()
            if v is not None)).pipeline
        tf_config = common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
            pipeline.runner)
        extra_args = common.ConstructBeamPipelineVisitor.ExtraArgs(
            base_temp_dir=Context.create_base_temp_dir(),
            tf_config=tf_config,
            pipeline=pipeline,
            flat_pcollection=flattened_pcoll,
            pcollection_dict=input_values_pcoll_dict,
            graph=graph,
            input_signature=input_signature,
            input_schema=input_schema,
            cache_pcoll_dict=dataset_cache_dict)

        transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
            graph,
            input_signature,
            output_signature,
            input_values_pcoll_dict.keys(),
            cache_dict=dataset_cache_dict)

        traverser = nodes.Traverser(
            common.ConstructBeamPipelineVisitor(extra_args))
        transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)

        if cache_value_nodes is not None:
            output_cache_pcoll_dict = {}
            for (dataset_key,
                 cache_key), value_node in six.iteritems(cache_value_nodes):
                if dataset_key not in output_cache_pcoll_dict:
                    output_cache_pcoll_dict[dataset_key] = {}
                output_cache_pcoll_dict[dataset_key][cache_key] = (
                    traverser.visit_value_node(value_node))
        else:
            output_cache_pcoll_dict = None

        # Infer metadata.  We take the inferred metadata and apply overrides that
        # refer to values of tensors in the graph.  The override tensors must
        # be "constant" in that they don't depend on input data.  The tensors can
        # depend on analyzer outputs though.  This allows us to set metadata that
        # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
        # analyzer outputs stored in `transform_fn` to compute the metadata in a
        # deferred manner, once the analyzer outputs are known.
        metadata = dataset_metadata.DatasetMetadata(
            schema=schema_inference.infer_feature_schema(
                output_signature, graph))

        deferred_metadata = (transform_fn_pcoll
                             | 'ComputeDeferredMetadata' >>
                             beam.Map(_infer_metadata_from_saved_model))

        full_metadata = beam_metadata_io.BeamDatasetMetadata(
            metadata, deferred_metadata)

        _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

        return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict
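An analysis transform like the one above is normally driven through `tft_beam.AnalyzeDataset` in a Beam pipeline. A hedged end-to-end sketch with a made-up preprocessing_fn and an arbitrary temp dir:

import apache_beam as beam
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

def preprocessing_fn(inputs):
    # Requires a full pass over the data to find the min/max of 'x'.
    return {'x_scaled': tft.scale_to_0_1(inputs['x'])}

metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(
        {'x': tf.io.FixedLenFeature([], tf.float32)}))

with beam.Pipeline() as p:
    with tft_beam.Context(temp_dir='/tmp/tft_tmp'):  # arbitrary temp dir
        data = p | beam.Create([{'x': 1.0}, {'x': 2.0}, {'x': 3.0}])
        transform_fn = ((data, metadata)
                        | tft_beam.AnalyzeDataset(preprocessing_fn))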
Example #8
def make_feed_list(column_names, schema, instances):
  """Creates a feed list for passing data to the graph.

  Converts a list of instances in the in-memory representation to a batch
  suitable for passing to `tf.Session.run`.

  Args:
    column_names: A list of column names.
    schema: A `Schema` proto.
    instances: A list of instances, each of which is a map from column name to a
      python primitive, list, or ndarray.

  Returns:
    A list of batches in the format required by a tf `Callable`.

  Raises:
    ValueError: If `schema` is invalid.
  """
  def make_batch_indices(instance_indices):
    """Converts a list of instance indices to the corresponding batch indices.

    Given a list of iterables representing the indices of N sparse tensors,
    creates a single list of indices representing the result of concatenating
    the sparse tensors along the 0'th dimension into a batch of size N.

    Args:
      instance_indices: A list of N iterables, each containing the sparse tensor
        indices for an instance.

    Returns:
      A list of indices with a batch dimension prepended.
    """
    batch_indices = list(itertools.chain.from_iterable([
        [(row_number, index) for index in indices]
        for row_number, indices in enumerate(instance_indices)
    ]))
    # Indices must have shape (?, 2). Therefore if we encounter an empty
    # batch, we return an empty ndarray with shape (0, 2).
    return batch_indices if batch_indices else np.empty([0, 2], dtype=np.int64)

  def make_sparse_batch(instance_indices, instance_values, max_index):
    """Converts a list of sparse instances into a sparse batch.

    Takes lists representing the indices and values of N sparse instances and
    concatenates them along the 0'th dimension into a sparse batch of size N.

    Args:
      instance_indices: A list of N iterables, each containing the sparse tensor
        indices for an instance.
      instance_values: A list of N iterables, each containing the sparse tensor
        values for an instance.
      max_index: An int representing the maximum index in `instance_indices`.

    Returns:
      A `SparseTensorValue` representing a batch of N sparse instances.
    """
    batch_indices = make_batch_indices(instance_indices)
    batch_values = list(itertools.chain.from_iterable(instance_values))
    batch_shape = (len(instance_indices), max_index)
    return tf.compat.v1.SparseTensorValue(batch_indices, batch_values,
                                          batch_shape)

  result = []
  feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
  for name in column_names:
    spec = feature_spec[name]
    # TODO(abrao): Validate dtypes, shapes etc.
    if isinstance(spec, tf.io.FixedLenFeature):
      feed_value = [instance[name] for instance in instances]

    elif isinstance(spec, tf.io.VarLenFeature):
      values = [[] if instance[name] is None else instance[name]
                for instance in instances]
      indices = [range(len(value)) for value in values]
      max_index = max([len(value) for value in values])
      feed_value = make_sparse_batch(indices, values, max_index)

    elif isinstance(spec, tf.io.SparseFeature):
      # TODO(KesterTong): Add support for N-d SparseFeatures.
      max_index = spec.size
      indices, values = [], []
      for instance in instances:
        instance_indices = instance[spec.index_key]
        instance_values = instance[spec.value_key]
        check_valid_sparse_tensor(
            instance_indices, instance_values, max_index, name)
        indices.append(instance_indices)
        values.append(instance_values)
      feed_value = make_sparse_batch(indices, values, max_index)

    else:
      raise ValueError('Invalid feature spec {}.'.format(spec))
    result.append(feed_value)

  return result
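A minimal usage sketch for `make_feed_list`, assuming the function above (and its module-level imports) is importable; two instances with one fixed-length and one variable-length column, where `None` stands for a missing value:

import tensorflow as tf
from tensorflow_transform.tf_metadata import schema_utils

schema = schema_utils.schema_from_feature_spec({
    'age': tf.io.FixedLenFeature([], tf.int64),
    'tags': tf.io.VarLenFeature(tf.string),
})
instances = [{'age': 29, 'tags': ['a', 'b']},
             {'age': 41, 'tags': None}]
feed_list = make_feed_list(['age', 'tags'], schema, instances)
# feed_list[0] == [29, 41]; feed_list[1] is a SparseTensorValue batching
# the two 'tags' lists.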
Example #9
def _get_raw_feature_spec(schema):
    # tf.Transform considers these features "raw".
    return schema_utils.schema_as_feature_spec(schema).feature_spec
Example #10
    def __init__(self,
                 column_names,
                 schema,
                 delimiter=',',
                 secondary_delimiter=None,
                 multivalent_columns=None):
        """Initializes CsvCoder.

    Args:
      column_names: Tuple of strings. Order must match the order in the file.
      schema: A `Schema` proto.
      delimiter: A one-character string used to separate fields.
      secondary_delimiter: A one-character string used to separate values within
        the same field.
      multivalent_columns: A list of names for multivalent columns that need to
        be split based on secondary delimiter.

    Raises:
      ValueError: If `schema` is invalid.
    """
        self._column_names = column_names
        self._schema = schema
        self._delimiter = delimiter
        self._secondary_delimiter = secondary_delimiter
        self._encoder = self._WriterWrapper(delimiter)

        if multivalent_columns is None:
            multivalent_columns = []
        self._multivalent_columns = multivalent_columns

        if secondary_delimiter:
            secondary_encoder = self._WriterWrapper(secondary_delimiter)
        elif multivalent_columns:
            raise ValueError(
                'secondary_delimiter unspecified for multivalent columns "{}"'.
                format(multivalent_columns))
        secondary_encoder_by_name = {
            name: secondary_encoder
            for name in multivalent_columns
        }
        indices_by_name = {
            name: index
            for index, name in enumerate(self._column_names)
        }

        def index(name):
            index = indices_by_name.get(name)
            if index is None:
                raise ValueError('Column not found: "{}"'.format(name))
            else:
                return index

        self._feature_handlers = []
        for name, feature_spec in schema_utils.schema_as_feature_spec(
                schema).feature_spec.items():
            if isinstance(feature_spec, tf.io.FixedLenFeature):
                self._feature_handlers.append(
                    _FixedLenFeatureHandler(
                        name, feature_spec, index(name),
                        secondary_encoder_by_name.get(name)))
            elif isinstance(feature_spec, tf.io.VarLenFeature):
                self._feature_handlers.append(
                    _VarLenFeatureHandler(name, feature_spec.dtype,
                                          index(name),
                                          secondary_encoder_by_name.get(name)))
            elif isinstance(feature_spec, tf.io.SparseFeature):
                index_keys = (feature_spec.index_key if isinstance(
                    feature_spec.index_key, list) else
                              [feature_spec.index_key])
                for key in index_keys:
                    self._feature_handlers.append(
                        _VarLenFeatureHandler(
                            key, tf.int64, index(key),
                            secondary_encoder_by_name.get(name)))
                self._feature_handlers.append(
                    _VarLenFeatureHandler(feature_spec.value_key,
                                          feature_spec.dtype,
                                          index(feature_spec.value_key),
                                          secondary_encoder_by_name.get(name)))
            else:
                raise ValueError(
                    'feature_spec should be one of tf.FixedLenFeature, '
                    'tf.VarLenFeature or tf.SparseFeature: {!r} was {!r}'.
                    format(name, type(feature_spec)))
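A hedged construction sketch for the `CsvCoder` above (the rest of the class, including `_WriterWrapper` and the feature handlers, is assumed to be defined); column order must match the CSV field order, and multivalent columns require a secondary delimiter:

import tensorflow as tf
from tensorflow_transform.tf_metadata import schema_utils

schema = schema_utils.schema_from_feature_spec({
    'name': tf.io.FixedLenFeature([], tf.string),
    'scores': tf.io.VarLenFeature(tf.float32),
})
coder = CsvCoder(
    column_names=('name', 'scores'),
    schema=schema,
    delimiter=',',
    secondary_delimiter='|',        # splits values inside one field
    multivalent_columns=['scores'])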
Example #11
def _get_raw_feature_spec(schema: schema_pb2.Schema):
    """Get the feature spec from the schema."""
    return schema_utils.schema_as_feature_spec(schema).feature_spec
Example #12
File: executor.py  Project: rummens/tfx
    def _RunBeamImpl(self, inputs: Mapping[Text, Any],
                     outputs: Mapping[Text, Any], preprocessing_fn: Any,
                     input_dataset_metadata: dataset_metadata.DatasetMetadata,
                     raw_examples_data_format: Text,
                     transform_output_path: Text, compute_statistics: bool,
                     materialize_output_paths: Sequence[Text]) -> _Status:
        """Perform data preprocessing with FlumeC++ runner.

    Args:
      inputs: A dictionary of labelled input values.
      outputs: A dictionary of labelled output values.
      preprocessing_fn: The tf.Transform preprocessing_fn.
      input_dataset_metadata: A DatasetMetadata object for the input data.
      raw_examples_data_format: A string describing the raw data format.
      transform_output_path: An absolute path to write the output to.
      compute_statistics: A bool indicating whether or not compute statistics.
      materialize_output_paths: Paths to materialized outputs.

    Raises:
      RuntimeError: If reset() is not being invoked between two run().
      ValueError: If the schema is empty.

    Returns:
      Status of the execution.
    """
        raw_examples_file_format = common.GetSoleValue(
            inputs, labels.EXAMPLES_FILE_FORMAT_LABEL, strict=False)
        analyze_and_transform_data_paths = common.GetValues(
            inputs, labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL)
        transform_only_data_paths = common.GetValues(
            inputs, labels.TRANSFORM_ONLY_DATA_PATHS_LABEL)
        stats_use_tfdv = common.GetSoleValue(
            inputs, labels.TFT_STATISTICS_USE_TFDV_LABEL)
        per_set_stats_output_paths = common.GetValues(
            outputs, labels.PER_SET_STATS_OUTPUT_PATHS_LABEL)
        temp_path = common.GetSoleValue(outputs, labels.TEMP_OUTPUT_LABEL)

        tf.logging.info('Analyze and transform data patterns: %s',
                        list(enumerate(analyze_and_transform_data_paths)))
        tf.logging.info('Transform data patterns: %s',
                        list(enumerate(transform_only_data_paths)))
        tf.logging.info('Transform materialization output paths: %s',
                        list(enumerate(materialize_output_paths)))
        tf.logging.info('Transform output path: %s', transform_output_path)

        feature_spec = schema_utils.schema_as_feature_spec(
            _GetSchemaProto(input_dataset_metadata)).feature_spec
        try:
            analyze_input_columns = tft.get_analyze_input_columns(
                preprocessing_fn, feature_spec)
            transform_input_columns = (tft.get_transform_input_columns(
                preprocessing_fn, feature_spec))
        except AttributeError:
            # If using TFT 1.12, fall back to assuming all features are used.
            analyze_input_columns = feature_spec.keys()
            transform_input_columns = feature_spec.keys()
        # Use the same dataset (same columns) for AnalyzeDataset and computing
        # pre-transform stats so that the data will only be read once for these
        # two operations.
        if compute_statistics:
            analyze_input_columns = list(
                set(
                    list(analyze_input_columns) +
                    list(transform_input_columns)))
        if input_dataset_metadata.schema is _RAW_EXAMPLE_SCHEMA:
            analyze_input_dataset_metadata = input_dataset_metadata
            transform_input_dataset_metadata = input_dataset_metadata
        else:
            analyze_input_dataset_metadata = dataset_metadata.DatasetMetadata(
                dataset_schema.from_feature_spec({
                    feature: feature_spec[feature]
                    for feature in analyze_input_columns
                }))
            transform_input_dataset_metadata = dataset_metadata.DatasetMetadata(
                dataset_schema.from_feature_spec({
                    feature: feature_spec[feature]
                    for feature in transform_input_columns
                }))

        can_process_jointly = not bool(per_set_stats_output_paths
                                       or materialize_output_paths)
        analyze_data_list = self._MakeDatasetList(
            analyze_and_transform_data_paths, raw_examples_file_format,
            raw_examples_data_format, analyze_input_dataset_metadata,
            can_process_jointly)
        transform_data_list = self._MakeDatasetList(
            list(analyze_and_transform_data_paths) +
            list(transform_only_data_paths), raw_examples_file_format,
            raw_examples_data_format, transform_input_dataset_metadata,
            can_process_jointly)

        desired_batch_size = self._GetDesiredBatchSize(
            raw_examples_data_format)

        with self._CreatePipeline(outputs) as p:
            with tft_beam.Context(
                    temp_dir=temp_path,
                    desired_batch_size=desired_batch_size,
                    passthrough_keys={_TRANSFORM_INTERNAL_FEATURE_FOR_KEY},
                    use_deep_copy_optimization=True):
                # pylint: disable=expression-not-assigned
                # pylint: disable=no-value-for-parameter

                _ = (p | self._IncrementColumnUsageCounter(
                    len(feature_spec.keys()), len(analyze_input_columns),
                    len(transform_input_columns)))

                analyze_decode_fn = (self._GetDecodeFunction(
                    raw_examples_data_format,
                    analyze_input_dataset_metadata.schema))

                for (idx, dataset) in enumerate(analyze_data_list):
                    dataset.encoded = (p
                                       | 'ReadAnalysisDataset[{}]'.format(idx)
                                       >> self._ReadExamples(dataset))
                    dataset.decoded = (
                        dataset.encoded
                        | 'DecodeAnalysisDataset[{}]'.format(idx) >>
                        self._DecodeInputs(analyze_decode_fn))

                input_analysis_data = (
                    [dataset.decoded for dataset in analyze_data_list]
                    | 'FlattenAnalysisDatasets' >> beam.Flatten())
                transform_fn = ((input_analysis_data, input_dataset_metadata)
                                | 'AnalyzeDataset' >>
                                tft_beam.AnalyzeDataset(preprocessing_fn))
                # Write the raw/input metadata.
                (input_dataset_metadata
                 | 'WriteMetadata' >> tft_beam.WriteMetadata(
                     os.path.join(transform_output_path,
                                  tft.TFTransformOutput.RAW_METADATA_DIR), p))

                # WriteTransformFn writes transform_fn and metadata to subdirectories
                # tensorflow_transform.SAVED_MODEL_DIR and
                # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
                (transform_fn | 'WriteTransformFn' >>
                 tft_beam.WriteTransformFn(transform_output_path))

                if compute_statistics or materialize_output_paths:
                    # Do not compute pre-transform stats if the input format is raw proto,
                    # as StatsGen would treat any input as tf.Example.
                    if (compute_statistics and not self._IsDataFormatProto(
                            raw_examples_data_format)):
                        # Aggregated feature stats before transformation.
                        pre_transform_feature_stats_path = os.path.join(
                            transform_output_path, tft.TFTransformOutput.
                            PRE_TRANSFORM_FEATURE_STATS_PATH)

                        schema_proto = _GetSchemaProto(
                            analyze_input_dataset_metadata)
                        ([
                            dataset.decoded
                            if stats_use_tfdv else dataset.encoded
                            for dataset in analyze_data_list
                        ]
                         | 'FlattenPreTransformAnalysisDatasets' >>
                         beam.Flatten()
                         | 'GenerateAggregatePreTransformAnalysisStats' >>
                         self._GenerateStats(pre_transform_feature_stats_path,
                                             schema_proto,
                                             use_deep_copy_optimization=True,
                                             use_tfdv=stats_use_tfdv))

                    transform_decode_fn = (self._GetDecodeFunction(
                        raw_examples_data_format,
                        transform_input_dataset_metadata.schema))
                    # transform_data_list is a superset of analyze_data_list, we pay the
                    # cost to read the same dataset (analyze_data_list) again here to
                    # prevent certain beam runner from doing large temp materialization.
                    for (idx, dataset) in enumerate(transform_data_list):
                        dataset.encoded = (
                            p
                            | 'ReadTransformDataset[{}]'.format(idx) >>
                            self._ReadExamples(dataset))
                        dataset.decoded = (
                            dataset.encoded
                            | 'DecodeTransformDataset[{}]'.format(idx) >>
                            self._DecodeInputs(transform_decode_fn))
                        (dataset.transformed, metadata) = (
                            ((dataset.decoded,
                              transform_input_dataset_metadata), transform_fn)
                            | 'TransformDataset[{}]'.format(idx) >>
                            tft_beam.TransformDataset())

                        if materialize_output_paths or not stats_use_tfdv:
                            dataset.transformed_and_encoded = (
                                dataset.transformed
                                | 'EncodeTransformedDataset[{}]'.format(idx) >>
                                beam.ParDo(self._EncodeAsExamples(), metadata))

                    if compute_statistics:
                        # Aggregated feature stats after transformation.
                        _, metadata = transform_fn
                        post_transform_feature_stats_path = os.path.join(
                            transform_output_path, tft.TFTransformOutput.
                            POST_TRANSFORM_FEATURE_STATS_PATH)

                        # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
                        # schema. Currently input dataset schema only contains dtypes,
                        # and other metadata is dropped due to roundtrip to tensors.
                        transformed_schema_proto = _GetSchemaProto(metadata)

                        ([(dataset.transformed if stats_use_tfdv else
                           dataset.transformed_and_encoded)
                          for dataset in transform_data_list]
                         | 'FlattenPostTransformAnalysisDatasets' >>
                         beam.Flatten()
                         | 'GenerateAggregatePostTransformAnalysisStats' >>
                         self._GenerateStats(post_transform_feature_stats_path,
                                             transformed_schema_proto,
                                             use_tfdv=stats_use_tfdv))

                        if per_set_stats_output_paths:
                            assert len(transform_data_list) == len(
                                per_set_stats_output_paths)
                            # TODO(b/67632871): Remove duplicate stats gen compute that is
                            # done both on a flattened view of the data, and on each span
                            # below.
                            bundles = zip(transform_data_list,
                                          per_set_stats_output_paths)
                            for (idx, (dataset,
                                       output_path)) in enumerate(bundles):
                                if stats_use_tfdv:
                                    data = dataset.transformed
                                else:
                                    data = dataset.transformed_and_encoded
                                (data
                                 | 'GeneratePostTransformStats[{}]'.format(idx)
                                 >> self._GenerateStats(
                                     output_path,
                                     transformed_schema_proto,
                                     use_tfdv=stats_use_tfdv))

                    if materialize_output_paths:
                        assert len(transform_data_list) == len(
                            materialize_output_paths)
                        bundles = zip(transform_data_list,
                                      materialize_output_paths)
                        for (idx, (dataset,
                                   output_path)) in enumerate(bundles):
                            (dataset.transformed_and_encoded
                             | 'Materialize[{}]'.format(idx) >>
                             self._WriteExamples(raw_examples_file_format,
                                                 output_path))

        return _Status.OK()
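The column-inspection step used above can be sketched in isolation. With a made-up preprocessing_fn, `tft.get_analyze_input_columns` reports only the features whose analyzers need a full pass over the data, while `tft.get_transform_input_columns` reports everything the function reads (order of the returned columns may vary):

import tensorflow as tf
import tensorflow_transform as tft

def preprocessing_fn(inputs):
    return {
        'x_scaled': tft.scale_to_0_1(inputs['x']),  # needs analysis of 'x'
        'y_copy': inputs['y'],                      # pure transform of 'y'
    }

feature_spec = {
    'x': tf.io.FixedLenFeature([], tf.float32),
    'y': tf.io.FixedLenFeature([], tf.float32),
}
print(tft.get_analyze_input_columns(preprocessing_fn, feature_spec))    # ['x']
print(tft.get_transform_input_columns(preprocessing_fn, feature_spec))  # ['x', 'y']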
Example #13
File: executor.py  Project: rummens/tfx
    def Transform(self, inputs: Mapping[Text, Any],
                  outputs: Mapping[Text, Any], status_file: Text) -> None:
        """Executes on request.

    This is the implementation part of transform executor. This is intended for
    using or extending the executor without artifact dependency.

    Args:
      inputs: A dictionary of labelled input values, including:
        - labels.COMPUTE_STATISTICS_LABEL: Whether compute statistics.
        - labels.SCHEMA_PATH_LABEL: Path to schema file.
        - labels.EXAMPLES_FILE_FORMAT_LABEL: Example file format, optional.
        - labels.EXAMPLES_DATA_FORMAT_LABEL: Example data format.
        - labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL: Paths or path patterns
          to analyze and transform data.
        - labels.TRANSFORM_DATA_PATHS_LABEL: Paths or path patterns to transform
          only data.
        - labels.TFT_STATISTICS_USE_TFDV_LABEL: Whether use tfdv to compute
          statistics.
        - labels.PREPROCESSING_FN: Path to a Python module that contains the
          preprocessing_fn, optional.
      outputs: A dictionary of labelled output values, including:
        - labels.PER_SET_STATS_OUTPUT_PATHS_LABEL: Paths to statistics output,
          optional.
        - labels.TRANSFORM_METADATA_OUTPUT_PATH_LABEL: A path to
          TFTransformOutput output.
        - labels.TRANSFORM_MATERIALIZE_OUTPUT_PATHS_LABEL: Paths to transform
          materialization.
        - labels.TEMP_OUTPUT_LABEL: A path to temporary directory.
      status_file: Where the status should be written (not yet implemented)
    """

        del status_file  # unused
        compute_statistics = common.GetSoleValue(
            inputs, labels.COMPUTE_STATISTICS_LABEL)
        transform_output_path = common.GetSoleValue(
            outputs, labels.TRANSFORM_METADATA_OUTPUT_PATH_LABEL)
        raw_examples_data_format = common.GetSoleValue(
            inputs, labels.EXAMPLES_DATA_FORMAT_LABEL)
        schema = common.GetSoleValue(inputs, labels.SCHEMA_PATH_LABEL)
        input_dataset_schema = self._ReadSchema(raw_examples_data_format,
                                                schema)
        input_dataset_metadata = dataset_metadata.DatasetMetadata(
            input_dataset_schema)

        tf.logging.info(
            'Inputs to executor.Transform function: {}'.format(inputs))
        tf.logging.info(
            'Outputs to executor.Transform function: {}'.format(outputs))

        feature_spec = schema_utils.schema_as_feature_spec(
            _GetSchemaProto(input_dataset_metadata)).feature_spec

        # NOTE: We disallow an empty schema, which we detect by testing the
        # number of columns.  While in principle an empty schema is valid, in
        # practice this is a sign of a user error, and this is a convenient
        # place to catch that error.
        if (not feature_spec and
                not self._ShouldDecodeAsRawExample(raw_examples_data_format)):
            raise ValueError(messages.SCHEMA_EMPTY)

        preprocessing_fn = self._GetPreprocessingFn(inputs, outputs)

        materialize_output_paths = common.GetValues(
            outputs, labels.TRANSFORM_MATERIALIZE_OUTPUT_PATHS_LABEL)

        # Inspect the preprocessing_fn even when we already know a full pass is
        # needed, so that a broken preprocessing_fn fails fast.
        try:
            analyze_input_columns = tft.get_analyze_input_columns(
                preprocessing_fn, feature_spec)
        except AttributeError:
            # If using TFT 1.12, fall back to assuming all features are used.
            analyze_input_columns = feature_spec.keys()

        if not compute_statistics and not materialize_output_paths:
            if analyze_input_columns:
                tf.logging.warning(
                    'Not using the in-place Transform because the following features '
                    'require analyzing: {}'.format(
                        tuple(c for c in analyze_input_columns)))
            else:
                tf.logging.warning(
                    'Using the in-place Transform since compute_statistics=False, '
                    'it does not materialize transformed data, and the configured '
                    'preprocessing_fn appears to not require analyzing the data.'
                )
                self._RunInPlaceImpl(preprocessing_fn, input_dataset_metadata,
                                     transform_output_path)
                # TODO(b/122478841): Writes status to status file.
                return
        self._RunBeamImpl(inputs, outputs, preprocessing_fn,
                          input_dataset_metadata, raw_examples_data_format,
                          transform_output_path, compute_statistics,
                          materialize_output_paths)