Example 1
    def test_optimize_traversal(self, feature_spec, preprocessing_fn,
                                dataset_input_cache_dict,
                                expected_dot_graph_str):
        span_0_key, span_1_key = 'span-0', 'span-1'
        if dataset_input_cache_dict is not None:
            cache = {span_0_key: dataset_input_cache_dict}
        else:
            cache = {}

        with tf.compat.v1.name_scope('inputs'):
            input_signature = impl_helper.feature_spec_as_batched_placeholders(
                feature_spec)
        output_signature = preprocessing_fn(input_signature)
        transform_fn_future, cache_output_dict = analysis_graph_builder.build(
            tf.compat.v1.get_default_graph(), input_signature,
            output_signature, {span_0_key, span_1_key}, cache)

        leaf_nodes = [transform_fn_future] + sorted(cache_output_dict.values(),
                                                    key=str)
        dot_string = nodes.get_dot_graph(leaf_nodes).to_string()
        self.WriteRenderedDotFile(dot_string)

        self.assertSameElements(
            dot_string.split('\n'),
            expected_dot_graph_str.split('\n'),
            msg='Result dot graph is:\n{}'.format(dot_string))
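The test above is parameterized; a minimal, hypothetical parameterization is sketched below. The feature spec and preprocessing_fn are illustrative (not taken from the original suite), but tft.scale_to_0_1 is a real mapper whose underlying min/max analyzers give the per-span cache dict something to hold.

import tensorflow as tf
import tensorflow_transform as tft

# Hypothetical parameters: one dense float feature, and a preprocessing_fn
# whose tft.scale_to_0_1 call introduces min/max analyzers, i.e. cacheable
# per-span computations.
_FEATURE_SPEC = {'x': tf.io.FixedLenFeature([], tf.float32)}

def _preprocessing_fn(inputs):
  return {'x_scaled': tft.scale_to_0_1(inputs['x'])}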
Example 2
  def test_perform_combiner_packing_optimization(
      self, feature_spec, preprocessing_fn, num_phases,
      expected_dot_graph_str_before_packing,
      expected_dot_graph_str_after_packing):

    graph, structured_inputs, structured_outputs = (
        impl_helper.trace_preprocessing_function(
            preprocessing_fn, feature_spec, use_tf_compat_v1=True))

    def _side_effect_fn(saved_model_future, cache_value_nodes,
                        unused_num_phases):
      return (saved_model_future, cache_value_nodes)

    with mock.patch.object(
        combiner_packing_util,
        'perform_combiner_packing_optimization',
        side_effect=_side_effect_fn):
      transform_fn_future_before, unused_cache = analysis_graph_builder.build(
          graph, structured_inputs, structured_outputs)
    transform_fn_future_after, unused_cache = (
        combiner_packing_util.perform_combiner_packing_optimization(
            transform_fn_future_before, unused_cache, num_phases))
    dot_string_before = nodes.get_dot_graph(
        [transform_fn_future_before]).to_string()
    self.assertMultiLineEqual(
        msg='Result dot graph is:\n{}'.format(dot_string_before),
        first=dot_string_before,
        second=expected_dot_graph_str_before_packing)
    dot_string_after = nodes.get_dot_graph(
        [transform_fn_future_after]).to_string()
    self.WriteRenderedDotFile(dot_string_after)
    self.assertMultiLineEqual(
        msg='Result dot graph is:\n{}'.format(dot_string_after),
        first=dot_string_after,
        second=expected_dot_graph_str_after_packing)
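Combiner packing fuses independent combiner-based analyzers over the same input into a single packed combine stage, which is what the before/after dot graphs in this test assert. A hypothetical preprocessing_fn that would exercise it (names are illustrative; tft.mean/min/max are real analyzers):

import tensorflow as tf
import tensorflow_transform as tft

def _multi_combiner_preprocessing_fn(inputs):
  # Three independent combiner-based analyzers over the same tensor are
  # candidates for packing into one combine. tf.zeros_like broadcasts the
  # scalar analyzer outputs back to batch shape.
  x = inputs['x']
  return {
      'x_mean': tft.mean(x) + tf.zeros_like(x),
      'x_min': tft.min(x) + tf.zeros_like(x),
      'x_max': tft.max(x) + tf.zeros_like(x),
  }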
Example 3
    def test_perform_combiner_packing_optimization(
            self, feature_spec, preprocessing_fn, num_phases,
            expected_dot_graph_str_before_packing,
            expected_dot_graph_str_after_packing):
        with tf.compat.v1.Graph().as_default() as graph:
            with tf.compat.v1.name_scope('inputs'):
                input_signature = impl_helper.feature_spec_as_batched_placeholders(
                    feature_spec)
            output_signature = preprocessing_fn(input_signature)

            def _side_effect_fn(saved_model_future, cache_value_nodes,
                                unused_num_phases):
                return (saved_model_future, cache_value_nodes)

            with mock.patch.object(combiner_packing_util,
                                   'perform_combiner_packing_optimization',
                                   side_effect=_side_effect_fn):
                transform_fn_future_before, unused_cache = analysis_graph_builder.build(
                    graph, input_signature, output_signature)
            transform_fn_future_after, unused_cache = (
                combiner_packing_util.perform_combiner_packing_optimization(
                    transform_fn_future_before, unused_cache, num_phases))
        dot_string_before = nodes.get_dot_graph([transform_fn_future_before
                                                 ]).to_string()
        self.assertMultiLineEqual(
            msg='Result dot graph is:\n{}'.format(dot_string_before),
            first=dot_string_before,
            second=expected_dot_graph_str_before_packing)
        dot_string_after = nodes.get_dot_graph([transform_fn_future_after
                                                ]).to_string()
        self.WriteRenderedDotFile(dot_string_after)
        self.assertMultiLineEqual(
            msg='Result dot graph is:\n{}'.format(dot_string_after),
            first=dot_string_after,
            second=expected_dot_graph_str_after_packing)
Example 4
  def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str):
    graph, structured_inputs, structured_outputs = (
        impl_helper.trace_preprocessing_function(
            preprocessing_fn, feature_spec, use_tf_compat_v1=True))
    transform_fn_future, unused_cache = analysis_graph_builder.build(
        graph, structured_inputs, structured_outputs)

    dot_string = nodes.get_dot_graph([transform_fn_future]).to_string()
    self.WriteRenderedDotFile(dot_string)
    self.assertMultiLineEqual(
        msg='Result dot graph is:\n{}'.format(dot_string),
        first=dot_string,
        second=expected_dot_graph_str)
Example 5
  def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str):
    with tf.name_scope('inputs'):
      input_signature = impl_helper.feature_spec_as_batched_placeholders(
          feature_spec)
    output_signature = preprocessing_fn(input_signature)
    transform_fn_future = analysis_graph_builder.build(
        tf.get_default_graph(), input_signature, output_signature)

    dot_string = nodes.get_dot_graph([transform_fn_future]).to_string()
    self.WriteRenderedDotFile(dot_string)

    self.assertMultiLineEqual(
        msg='Result dot graph is:\n{}'.format(dot_string),
        first=dot_string,
        second=expected_dot_graph_str)
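feature_spec_as_batched_placeholders creates one batched placeholder per feature; for a simple spec, the resulting input_signature is roughly the hand-built dict below (a sketch assuming TF1 graph mode; the placeholder name is arbitrary):

import tensorflow as tf

feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
with tf.compat.v1.Graph().as_default():
  # A scalar feature becomes a rank-1 placeholder with a None batch dim.
  input_signature = {
      'x': tf.compat.v1.placeholder(tf.float32, shape=[None], name='x'),
  }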
Example 6
  def test_optimize_traversal(self, feature_spec, preprocessing_fn,
                              write_cache_fn, expected_dot_graph_str):
    cache_location = self._make_cache_location()
    span_0_key, span_1_key = 'span-0', 'span-1'
    if write_cache_fn is not None:
      write_cache_fn(cache_location.input_cache_dir, [span_0_key, span_1_key])

    with tf.name_scope('inputs'):
      input_signature = impl_helper.feature_spec_as_batched_placeholders(
          feature_spec)
    output_signature = preprocessing_fn(input_signature)
    transform_fn_future = analysis_graph_builder.build(
        tf.get_default_graph(), input_signature, output_signature,
        {span_0_key, span_1_key}, cache_location)

    dot_string = nodes.get_dot_graph([transform_fn_future]).to_string()
    self.WriteRenderedDotFile(dot_string)

    self.assertSameElements(
        dot_string.split('\n'),
        expected_dot_graph_str.split('\n'),
        msg='Result dot graph is:\n{}'.format(dot_string))
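write_cache_fn is a test-supplied callback that receives the input cache directory and the span keys. A hypothetical implementation only needs to materialize an entry per span; the directory layout here is illustrative, not the real cache format:

import os

def _write_cache_fn(input_cache_dir, dataset_keys):
  # Create a per-span subdirectory so the traversal sees input cache as
  # present for each dataset key.
  for key in dataset_keys:
    os.makedirs(os.path.join(input_cache_dir, key))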
Example 7
    def test_build(self, feature_spec, preprocessing_fn,
                   expected_dot_graph_str, expected_dot_graph_str_tf2,
                   use_tf_compat_v1):
        if not use_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
        specs = (feature_spec if use_tf_compat_v1 else
                 impl_helper.get_type_specs_from_feature_specs(feature_spec))
        graph, structured_inputs, structured_outputs = (
            impl_helper.trace_preprocessing_function(
                preprocessing_fn,
                specs,
                use_tf_compat_v1=use_tf_compat_v1,
                base_temp_dir=os.path.join(self.get_temp_dir(),
                                           self._testMethodName)))
        transform_fn_future, unused_cache = analysis_graph_builder.build(
            graph, structured_inputs, structured_outputs)

        dot_string = nodes.get_dot_graph([transform_fn_future]).to_string()
        self.WriteRenderedDotFile(dot_string)
        self.assertMultiLineEqual(
            msg='Result dot graph is:\n{}'.format(dot_string),
            first=dot_string,
            second=(expected_dot_graph_str
                    if use_tf_compat_v1 else expected_dot_graph_str_tf2))
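On the native TF2 path the preprocessing_fn is traced against TypeSpecs rather than placeholders. Assuming impl_helper.get_type_specs_from_feature_specs behaves as its name suggests, a dense scalar feature corresponds to a batched TensorSpec:

import tensorflow as tf

feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
# Assumed TF2 equivalent consumed by trace_preprocessing_function:
type_specs = {'x': tf.TensorSpec(shape=[None], dtype=tf.float32)}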
Example 8
  def expand(self, dataset):
    """Analyze the dataset.

    Args:
      dataset: A dataset.

    Returns:
      A TransformFn containing the deferred transform function.

    Raises:
      ValueError: If preprocessing_fn has no outputs.
    """
    flattened_pcoll, input_values_pcoll_dict, input_metadata = dataset
    input_schema = input_metadata.schema

    input_values_pcoll_dict = input_values_pcoll_dict or dict()

    analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys())

    with tf.Graph().as_default() as graph:

      with tf.name_scope('inputs'):
        feature_spec = input_schema.as_feature_spec()
        input_signature = impl_helper.feature_spec_as_batched_placeholders(
            feature_spec)
        # In order to avoid a bug where import_graph_def fails when the
        # input_map and return_elements of an imported graph are the same
        # (b/34288791), we avoid using the placeholder of an input column as an
        # output of a graph. We do this by applying tf.identity to all inputs of
        # the preprocessing_fn.  Note this applies at the level of raw tensors.
        # TODO(b/34288791): Remove this workaround and use a shallow copy of
        # inputs instead.  A shallow copy is needed in case
        # self._preprocessing_fn mutates its input.
        copied_inputs = impl_helper.copy_tensors(input_signature)

      output_signature = self._preprocessing_fn(copied_inputs)

    # At this point we check that the preprocessing_fn has at least one
    # output. This is because if we allowed the output of preprocessing_fn to
    # be empty, we wouldn't be able to determine how many instances to
    # "unbatch" the output into.
    if not output_signature:
      raise ValueError('The preprocessing function returned an empty dict')

    if graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
      raise ValueError(
          'The preprocessing function contained trainable variables '
          '{}'.format(
              graph.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)))

    pipeline = flattened_pcoll.pipeline
    serialized_tf_config = common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
        pipeline.runner)
    extra_args = common.ConstructBeamPipelineVisitor.ExtraArgs(
        base_temp_dir=Context.create_base_temp_dir(),
        serialized_tf_config=serialized_tf_config,
        pipeline=pipeline,
        flat_pcollection=flattened_pcoll,
        pcollection_dict=input_values_pcoll_dict,
        graph=graph,
        input_signature=input_signature,
        input_schema=input_schema,
        cache_location=self._cache_location)

    transform_fn_future = analysis_graph_builder.build(
        graph, input_signature, output_signature,
        input_values_pcoll_dict.keys(), self._cache_location)

    transform_fn_pcoll = nodes.Traverser(
        common.ConstructBeamPipelineVisitor(extra_args)).visit_value_node(
            transform_fn_future)

    # Infer metadata.  We take the inferred metadata and apply overrides that
    # refer to values of tensors in the graph.  The override tensors must
    # be "constant" in that they don't depend on input data.  The tensors can
    # depend on analyzer outputs though.  This allows us to set metadata that
    # depends on analyzer outputs. _augment_metadata will use the analyzer
    # outputs stored in `transform_fn` to compute the metadata in a
    # deferred manner, once the analyzer outputs are known.
    metadata = dataset_metadata.DatasetMetadata(
        schema=schema_inference.infer_feature_schema(output_signature, graph))

    deferred_metadata = (
        transform_fn_pcoll
        |
        'ComputeDeferredMetadata' >> beam.Map(_infer_metadata_from_saved_model))

    full_metadata = beam_metadata_io.BeamDatasetMetadata(
        metadata, deferred_metadata)

    _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

    return transform_fn_pcoll, full_metadata
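This expand resembles the analysis phase behind what tensorflow_transform exposes publicly as tft_beam.AnalyzeDataset. A minimal usage sketch of that public API, with illustrative data and temp dir (not taken from this module):

import apache_beam as beam
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

def _preprocessing_fn(inputs):
  return {'x_scaled': tft.scale_to_0_1(inputs['x'])}

raw_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(
        {'x': tf.io.FixedLenFeature([], tf.float32)}))

with beam.Pipeline() as pipeline:
  with tft_beam.Context(temp_dir='/tmp/tft_tmp'):
    raw_data = pipeline | beam.Create([{'x': 1.0}, {'x': 2.0}])
    transform_fn = (
        (raw_data, raw_metadata)
        | 'Analyze' >> tft_beam.AnalyzeDataset(_preprocessing_fn))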
Example 9
  def expand(self, dataset):
    """Analyze the dataset.

    Args:
      dataset: A dataset.

    Returns:
      A TransformFn containing the deferred transform function.

    Raises:
      ValueError: If preprocessing_fn has no outputs.
    """
    (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
     input_metadata) = dataset
    if self._use_tfxio:
      input_schema = None
      input_tensor_adapter_config = input_metadata
    else:
      input_schema = input_metadata.schema
      input_tensor_adapter_config = None

    input_values_pcoll_dict = input_values_pcoll_dict or dict()

    with tf.compat.v1.Graph().as_default() as graph:

      with tf.compat.v1.name_scope('inputs'):
        if self._use_tfxio:
          specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs()
        else:
          specs = schema_utils.schema_as_feature_spec(input_schema).feature_spec
        input_signature = impl_helper.batched_placeholders_from_specs(specs)
        # In order to avoid a bug where import_graph_def fails when the
        # input_map and return_elements of an imported graph are the same
        # (b/34288791), we avoid using the placeholder of an input column as an
        # output of a graph. We do this by applying tf.identity to all inputs of
        # the preprocessing_fn.  Note this applies at the level of raw tensors.
        # TODO(b/34288791): Remove this workaround and use a shallow copy of
        # inputs instead.  A shallow copy is needed in case
        # self._preprocessing_fn mutates its input.
        copied_inputs = impl_helper.copy_tensors(input_signature)

      output_signature = self._preprocessing_fn(copied_inputs)

    # At this point we check that the preprocessing_fn has at least one
    # output. This is because if we allowed the output of preprocessing_fn to
    # be empty, we wouldn't be able to determine how many instances to
    # "unbatch" the output into.
    if not output_signature:
      raise ValueError('The preprocessing function returned an empty dict')

    if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
      raise ValueError(
          'The preprocessing function contained trainable variables '
          '{}'.format(
              graph.get_collection_ref(
                  tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))

    pipeline = self.pipeline or (flattened_pcoll or next(
        v for v in input_values_pcoll_dict.values() if v is not None)).pipeline

    # Add a stage that inspects graph collections for API use counts and logs
    # them as a beam metric.
    _ = (pipeline | 'InstrumentAPI' >> _InstrumentAPI(graph))

    tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
        type(pipeline.runner))
    extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs(
        base_temp_dir=Context.create_base_temp_dir(),
        tf_config=tf_config,
        pipeline=pipeline,
        flat_pcollection=flattened_pcoll,
        pcollection_dict=input_values_pcoll_dict,
        graph=graph,
        input_signature=input_signature,
        input_schema=input_schema,
        input_tensor_adapter_config=input_tensor_adapter_config,
        use_tfxio=self._use_tfxio,
        cache_pcoll_dict=dataset_cache_dict)

    transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
        graph,
        input_signature,
        output_signature,
        input_values_pcoll_dict.keys(),
        cache_dict=dataset_cache_dict)
    traverser = nodes.Traverser(
        beam_common.ConstructBeamPipelineVisitor(extra_args))
    transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)

    if cache_value_nodes is not None:
      output_cache_pcoll_dict = {}
      for (dataset_key,
           cache_key), value_node in six.iteritems(cache_value_nodes):
        if dataset_key not in output_cache_pcoll_dict:
          output_cache_pcoll_dict[dataset_key] = {}
        output_cache_pcoll_dict[dataset_key][cache_key] = (
            traverser.visit_value_node(value_node))
    else:
      output_cache_pcoll_dict = None

    # Infer metadata.  We take the inferred metadata and apply overrides that
    # refer to values of tensors in the graph.  The override tensors must
    # be "constant" in that they don't depend on input data.  The tensors can
    # depend on analyzer outputs though.  This allows us to set metadata that
    # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
    # analyzer outputs stored in `transform_fn` to compute the metadata in a
    # deferred manner, once the analyzer outputs are known.
    metadata = dataset_metadata.DatasetMetadata(
        schema=schema_inference.infer_feature_schema(output_signature, graph))

    deferred_metadata = (
        transform_fn_pcoll
        |
        'ComputeDeferredMetadata' >> beam.Map(_infer_metadata_from_saved_model))

    full_metadata = beam_metadata_io.BeamDatasetMetadata(
        metadata, deferred_metadata)

    _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

    return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict
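The returned output_cache_pcoll_dict maps dataset keys to inner {cache_key: PCollection} dicts. A hedged sketch of persisting it with tft_beam's analyzer_cache module, assuming the WriteAnalysisCacheToFS transform available in recent releases (the cache directory argument is illustrative):

from tensorflow_transform.beam import analyzer_cache

def _write_output_cache(pipeline, output_cache_pcoll_dict, cache_base_dir):
  # output_cache_pcoll_dict: {dataset_key: {cache_key: PCollection}}, as
  # produced by the expand above.
  if output_cache_pcoll_dict is not None:
    _ = (output_cache_pcoll_dict
         | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
             pipeline, cache_base_dir))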