Code example #1
  def test_validate_dataset_keys(self):
    analyzer_cache.validate_dataset_keys(
        {'foo', 'Foo', 'A1', 'A_1', 'A.1', 'A-1'})

    for key in {'foo 1', 'foo@1', 'foo*', 'foo[]', 'foo/goo'}:
      with self.assertRaisesRegexp(
          ValueError, 'Dataset key .* does not match allowed pattern:'):
        analyzer_cache.validate_dataset_keys({key})
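
Judging from the accepted keys ('foo', 'Foo', 'A1', 'A_1', 'A.1', 'A-1') and the rejected ones ('foo 1', 'foo@1', 'foo*', 'foo[]', 'foo/goo'), the allowed pattern appears to be letters, digits, '.', '_' and '-'. Below is a minimal standalone sketch of an equivalent check; the regex and the helper name are illustrative assumptions, not the library's actual implementation.

import re

# Assumed pattern inferred from the test above: letters, digits, '.', '_', '-'.
_ASSUMED_KEY_PATTERN = re.compile(r'^[a-zA-Z0-9._-]+$')


def validate_dataset_keys_sketch(keys):
  """Illustrative stand-in: raises ValueError for keys outside the assumed pattern."""
  for key in keys:
    if not _ASSUMED_KEY_PATTERN.match(key):
      raise ValueError(
          'Dataset key {!r} does not match allowed pattern: {}'.format(
              key, _ASSUMED_KEY_PATTERN.pattern))


validate_dataset_keys_sketch({'foo', 'A_1', 'A.1', 'A-1'})  # passes silently
# validate_dataset_keys_sketch({'foo/goo'})  # would raise ValueError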
Code example #2
File: impl.py  Project: jiwidi/transform
  def expand(self, dataset):
    input_values_pcoll_dict = dataset[1] or dict()
    analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys())
    return super(AnalyzeDatasetWithCache, self).expand(dataset)
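
Here dataset[1] is the dict of per-dataset-key PCollections (code example #3 below unpacks the full tuple), and its keys are validated before delegating to the parent expand. The snippet below is a minimal sketch of the key shapes this check expects; the span-style names are illustrative, and it assumes the TFT version used in these snippets, where plain string keys are accepted as in the test above.

from tensorflow_transform.beam import analyzer_cache

# Hypothetical per-dataset dict; only the keys matter for validation, so the
# values are placeholders standing in for beam.PCollection objects.
input_values_pcoll_dict = {
    'span-0': None,
    'span-1': None,
}

# Keys such as 'span-0' pass the check; 'span 0' or 'span/0' would raise ValueError.
analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys())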
Code example #3
    def expand(self, dataset):
        """Analyze the dataset.

        Args:
          dataset: A dataset.

        Returns:
          A TransformFn containing the deferred transform function.

        Raises:
          ValueError: If preprocessing_fn has no outputs.
        """
        (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
         input_metadata) = dataset
        input_schema = input_metadata.schema

        input_values_pcoll_dict = input_values_pcoll_dict or dict()

        analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys())

        with tf.Graph().as_default() as graph:

            with tf.compat.v1.name_scope('inputs'):
                feature_spec = input_schema.as_feature_spec()
                input_signature = impl_helper.feature_spec_as_batched_placeholders(
                    feature_spec)
                # In order to avoid a bug where import_graph_def fails when the
                # input_map and return_elements of an imported graph are the same
                # (b/34288791), we avoid using the placeholder of an input column as an
                # output of a graph. We do this by applying tf.identity to all inputs of
                # the preprocessing_fn.  Note this applies at the level of raw tensors.
                # TODO(b/34288791): Remove this workaround and use a shallow copy of
                # inputs instead.  A shallow copy is needed in case
                # self._preprocessing_fn mutates its input.
                copied_inputs = impl_helper.copy_tensors(input_signature)

            output_signature = self._preprocessing_fn(copied_inputs)

        # At this point we check that the preprocessing_fn has at least one
        # output. This is because if we allowed the output of preprocessing_fn to
        # be empty, we wouldn't be able to determine how many instances to
        # "unbatch" the output into.
        if not output_signature:
            raise ValueError(
                'The preprocessing function returned an empty dict')

        if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
            raise ValueError(
                'The preprocessing function contained trainable variables '
                '{}'.format(
                    graph.get_collection_ref(
                        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))

        pipeline = flattened_pcoll.pipeline
        tf_config = common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
            pipeline.runner)
        extra_args = common.ConstructBeamPipelineVisitor.ExtraArgs(
            base_temp_dir=Context.create_base_temp_dir(),
            tf_config=tf_config,
            pipeline=pipeline,
            flat_pcollection=flattened_pcoll,
            pcollection_dict=input_values_pcoll_dict,
            graph=graph,
            input_signature=input_signature,
            input_schema=input_schema,
            cache_pcoll_dict=dataset_cache_dict)

        transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
            graph,
            input_signature,
            output_signature,
            input_values_pcoll_dict.keys(),
            cache_dict=dataset_cache_dict)

        traverser = nodes.Traverser(
            common.ConstructBeamPipelineVisitor(extra_args))
        transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)

        if cache_value_nodes is not None:
            output_cache_pcoll_dict = {}
            for (dataset_key,
                 cache_key), value_node in six.iteritems(cache_value_nodes):
                if dataset_key not in output_cache_pcoll_dict:
                    output_cache_pcoll_dict[dataset_key] = {}
                output_cache_pcoll_dict[dataset_key][cache_key] = (
                    traverser.visit_value_node(value_node))
        else:
            output_cache_pcoll_dict = None

        # Infer metadata.  We take the inferred metadata and apply overrides that
        # refer to values of tensors in the graph.  The override tensors must
        # be "constant" in that they don't depend on input data.  The tensors can
        # depend on analyzer outputs though.  This allows us to set metadata that
        # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
        # analyzer outputs stored in `transform_fn` to compute the metadata in a
        # deferred manner, once the analyzer outputs are known.
        metadata = dataset_metadata.DatasetMetadata(
            schema=schema_inference.infer_feature_schema(
                output_signature, graph))

        deferred_metadata = (transform_fn_pcoll
                             | 'ComputeDeferredMetadata' >>
                             beam.Map(_infer_metadata_from_saved_model))

        full_metadata = beam_metadata_io.BeamDatasetMetadata(
            metadata, deferred_metadata)

        _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

        return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict
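
The value returned here is a nested tuple: the deferred transform function paired with its (partly deferred) metadata, plus a per-dataset dict of output-cache PCollections, or None when no cache nodes were produced. Below is a minimal sketch of how a caller might unpack that shape; the payloads are placeholder strings standing in for the real PCollection and metadata objects.

# Illustrative shape of the value returned by expand() above.
result = (('transform_fn_pcoll', 'full_metadata'),
          {'span-0': {'cache_key_a': 'cache_pcoll'}})

(transform_fn_pcoll, full_metadata), output_cache_pcoll_dict = result

if output_cache_pcoll_dict is not None:
  for dataset_key, cache_entries in output_cache_pcoll_dict.items():
    # Each dataset key maps cache keys to their materialized cache PCollections.
    print(dataset_key, sorted(cache_entries))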