def _perform_combiner_packing_optimization(saved_model_future,
                                           cache_value_nodes):
    """Optimizes the graph by packing possible combine nodes."""
    # Inspect the graph to identify all the packable combines.
    inspect_combine_visitor = _InspectCombineVisitor()
    inspect_combine_traverser = nodes.Traverser(inspect_combine_visitor)
    _ = inspect_combine_traverser.visit_value_node(saved_model_future)
    if cache_value_nodes:
        for value_node in cache_value_nodes.values():
            _ = inspect_combine_traverser.visit_value_node(value_node)

    packable_combines = inspect_combine_visitor.packable_combines
    # Do not pack if we have only a single combine in the group.
    packable_combines = {
        label: group
        for label, group in packable_combines.items() if len(group) > 1
    }

    # Do another pass over the graph and pack the grouped combines.
    pack_combine_visitor = _PackCombineVisitor(packable_combines)
    pack_combine_traverser = nodes.Traverser(pack_combine_visitor)
    saved_model_future = pack_combine_traverser.visit_value_node(
        saved_model_future)
    # Replace cache nodes to point to the corresponding new nodes.
    if cache_value_nodes:
        cache_value_nodes = {
            key: pack_combine_traverser.visit_value_node(value_node)
            for key, value_node in cache_value_nodes.items()
        }
    return (saved_model_future, cache_value_nodes)
def get_analysis_dataset_keys(preprocessing_fn, specs, dataset_keys,
                              input_cache):
    """Computes the dataset keys that are required in order to perform analysis.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    specs: A dict of feature name to feature specification or tf.TypeSpecs.
    dataset_keys: A set of strings which are dataset keys; they uniquely
      identify these datasets across analysis runs.
    input_cache: A cache dictionary.

  Returns:
    A pair of:
      - A set of dataset keys that are required for analysis.
      - A boolean indicating whether or not a flattened version of the entire
        dataset is required. See the `flat_data` input to
        `AnalyzeDatasetWithCache`.
  """
    transform_fn_future, _ = _build_analysis_graph_for_inspection(
        preprocessing_fn, specs, dataset_keys, input_cache)

    required_dataset_keys_result = set()
    inspect_visitor = _InspectVisitor(required_dataset_keys_result)
    inspect_traverser = nodes.Traverser(inspect_visitor)
    _ = inspect_traverser.visit_value_node(transform_fn_future)

    # If None is present, a flattened version of the entire dataset is
    # required, so all of the given dataset_keys are returned.
    flat_data_required = None in required_dataset_keys_result
    if flat_data_required:
        required_dataset_keys_result = dataset_keys
    return required_dataset_keys_result, flat_data_required
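A minimal usage sketch for the pair-returning version above, assuming it is importable from tensorflow_transform.beam.analysis_graph_builder (module paths vary across tf.Transform releases) and that plain feature specs are acceptable for `specs`, as the docstring allows:

import tensorflow as tf
import tensorflow_transform as tft
# Assumed import path; adjust for your tf.Transform version.
from tensorflow_transform.beam import analysis_graph_builder


def preprocessing_fn(inputs):
    # scale_to_0_1 internally uses tft.min/tft.max analyzers on 'x'.
    return {'x_scaled': tft.scale_to_0_1(inputs['x'])}


specs = {'x': tf.io.FixedLenFeature([], tf.float32)}
dataset_keys = {'span-0', 'span-1'}

required_keys, flat_data_required = (
    analysis_graph_builder.get_analysis_dataset_keys(
        preprocessing_fn, specs, dataset_keys, input_cache=None))
# With no input cache, the analyzers typically need to see all the data, so
# expect required_keys == dataset_keys and flat_data_required == True.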
def get_analyze_input_columns(preprocessing_fn,
                              specs,
                              force_tf_compat_v1=False):
    """Return columns that are required inputs of `AnalyzeDataset`.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
      True, this can also be feature specifications.
    force_tf_compat_v1: (Optional) If `True`, use TensorFlow in compat.v1 mode.
      Defaults to `False`.

  Returns:
    A list of columns that are required inputs of analyzers.
  """
    use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
    if not use_tf_compat_v1:
        assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs
    graph, structured_inputs, _ = (impl_helper.trace_preprocessing_function(
        preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1))

    tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
    visitor = _SourcedTensorsVisitor()
    for tensor_sink in tensor_sinks:
        nodes.Traverser(visitor).visit_value_node(tensor_sink.future)

    analyze_input_tensors = graph_tools.get_dependent_inputs(
        graph, structured_inputs, visitor.sourced_tensors)
    return list(analyze_input_tensors.keys())
Example #4
    def testTraverserComplexGraphMultipleCalls(self):
        a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
        b = nodes.apply_operation(_Constant, value='b', label='Constant[b]')
        c = nodes.apply_operation(_Constant, value='c', label='Constant[c]')
        b_copy, a_copy = nodes.apply_multi_output_operation(_Swap,
                                                            a,
                                                            b,
                                                            label='Swap')
        b_a = nodes.apply_operation(_Concat, b_copy, a_copy, label='Concat[0]')
        b_a_c = nodes.apply_operation(_Concat, b_a, c, label='Concat[1]')

        mock_visitor = mock.MagicMock()
        mock_visitor.visit.side_effect = [('a', ), ('b', ), ('b', 'a'),
                                          ('ba', ), ('c', ), ('bac', )]

        traverser = nodes.Traverser(mock_visitor)
        traverser.visit_value_node(b_a)
        traverser.visit_value_node(b_a_c)

        mock_visitor.assert_has_calls([
            mock.call.visit(_Constant('a', 'Constant[a]'), ()),
            mock.call.validate_value('a'),
            mock.call.visit(_Constant('b', 'Constant[b]'), ()),
            mock.call.validate_value('b'),
            mock.call.visit(_Swap('Swap'), ('a', 'b')),
            mock.call.validate_value('b'),
            mock.call.validate_value('a'),
            mock.call.visit(_Concat('Concat[0]'), ('b', 'a')),
            mock.call.validate_value('ba'),
            mock.call.visit(_Constant('c', 'Constant[c]'), ()),
            mock.call.validate_value('c'),
            mock.call.visit(_Concat('Concat[1]'), ('ba', 'c')),
            mock.call.validate_value('bac'),
        ])
def get_analysis_dataset_keys(preprocessing_fn, specs, dataset_keys,
                              input_cache, force_tf_compat_v1):
    """Computes the dataset keys that are required in order to perform analysis.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
      True, this can also be feature specifications.
    dataset_keys: A set of strings which are dataset keys; they uniquely
      identify these datasets across analysis runs.
    input_cache: A cache dictionary.
    force_tf_compat_v1: If `True`, use TensorFlow in compat.v1 mode.

  Returns:
    A set of dataset keys that are required for analysis.
  """
    transform_fn_future, _ = _build_analysis_graph_for_inspection(
        preprocessing_fn, specs, dataset_keys, input_cache, force_tf_compat_v1)

    result = set()
    inspect_visitor = _InspectVisitor(result)
    inspect_traverser = nodes.Traverser(inspect_visitor)
    _ = inspect_traverser.visit_value_node(transform_fn_future)

    # If a flattened dataset key is present, a flattened version of the entire
    # dataset is required, so all of the given dataset_keys are returned.
    if any(k.is_flattened_dataset_key() for k in result):
        result = dataset_keys
    return result
Example #6
    def testTraverserComplexGraph(self):
        a = nodes.apply_operation(_Constant, value='a')
        b = nodes.apply_operation(_Constant, value='b')
        c = nodes.apply_operation(_Constant, value='c')
        b_copy, a_copy = nodes.apply_multi_output_operation(_Swap, a, b)
        b_a = nodes.apply_operation(_Concat, b_copy, a_copy)
        b_a_c = nodes.apply_operation(_Concat, b_a, c)

        mock_visitor = mock.MagicMock()
        mock_visitor.visit.side_effect = [('a', ), ('b', ), ('b', 'a'),
                                          ('ba', ), ('c', ), ('bac', )]

        nodes.Traverser(mock_visitor).visit_value_node(b_a_c)

        mock_visitor.assert_has_calls([
            mock.call.visit(_Constant('a'), ()),
            mock.call.validate_value('a'),
            mock.call.visit(_Constant('b'), ()),
            mock.call.validate_value('b'),
            mock.call.visit(_Swap(), ('a', 'b')),
            mock.call.validate_value('b'),
            mock.call.validate_value('a'),
            mock.call.visit(_Concat(), ('b', 'a')),
            mock.call.validate_value('ba'),
            mock.call.visit(_Constant('c'), ()),
            mock.call.validate_value('c'),
            mock.call.visit(_Concat(), ('ba', 'c')),
            mock.call.validate_value('bac'),
        ])
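For reference, a non-mock visitor producing the same values as the mocked side effects in these tests might look like the following sketch. It assumes `nodes.Visitor` is the abstract base class whose `visit` and `validate_value` methods the mocks stand in for, and it reuses the test module's `_Constant`, `_Swap` and `_Concat` helpers (not shown on this page):

class _EvalVisitor(nodes.Visitor):
    """Evaluates the toy graph of constants, swaps and string concatenation."""

    def visit(self, operation_def, input_values):
        # Must return a tuple with one entry per output of the operation.
        if isinstance(operation_def, _Constant):
            return (operation_def.value, )
        if isinstance(operation_def, _Swap):
            a, b = input_values
            return (b, a)
        if isinstance(operation_def, _Concat):
            return (''.join(input_values), )
        raise TypeError('Unknown operation: {}'.format(operation_def))

    def validate_value(self, value):
        assert isinstance(value, str)


# Traversing b_a_c from the test above with this visitor evaluates to 'bac',
# matching the mocked side effects.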
Example #7
    def testTraverserBadNumOutputs(self):
        a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
        mock_visitor = mock.MagicMock()
        mock_visitor.visit.side_effect = [('a', 'b')]

        with self.assertRaisesRegexp(
                ValueError, 'has 1 outputs but visitor returned 2 values: '):
            nodes.Traverser(mock_visitor).visit_value_node(a)
Example #8
    def testTraverserSimpleGraph(self):
        a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
        mock_visitor = mock.MagicMock()
        mock_visitor.visit.side_effect = [('a', )]
        nodes.Traverser(mock_visitor).visit_value_node(a)
        mock_visitor.assert_has_calls([
            mock.call.visit(_Constant('a', 'Constant[a]'), ()),
            mock.call.validate_value('a'),
        ])
Example #9
    def testTraverserOutputsNotATuple(self):
        a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')

        mock_visitor = mock.MagicMock()
        mock_visitor.visit.side_effect = ['not a tuple']

        with self.assertRaisesRegexp(
                ValueError, r'expected visitor to return a tuple, got'):
            nodes.Traverser(mock_visitor).visit_value_node(a)
Example #10
def _perform_cache_optimization(saved_model_future, dataset_keys, cache_dict):
    """Performs cache optimization on the given graph."""
    cache_output_nodes = {}
    optimize_visitor = _OptimizeVisitor(dataset_keys or {}, cache_dict,
                                        cache_output_nodes)
    optimize_traverser = nodes.Traverser(optimize_visitor)
    optimized = optimize_traverser.visit_value_node(
        saved_model_future).flattened_view

    if cache_dict is None:
        assert not cache_output_nodes
        cache_output_nodes = None

    return optimized, cache_output_nodes
Example #11
    def testTraverserCycle(self):
        a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
        x_0 = nodes.apply_operation(_Identity, a, label='Identity[0]')
        x_1 = nodes.apply_operation(_Identity, x_0, label='Identity[1]')
        x_2 = nodes.apply_operation(_Identity, x_1, label='Identity[2]')
        x_0.parent_operation._inputs = (x_2, )

        mock_visitor = mock.MagicMock()
        mock_visitor.visit.return_value = ('x', )

        with self.assertRaisesWithLiteralMatch(
                AssertionError,
                'Cycle detected: [Identity[2], Identity[1], Identity[0], Identity[2]]'
        ):
            nodes.Traverser(mock_visitor).visit_value_node(x_2)
Example #12
    def testTraverserCycle(self):
        x = nodes.apply_operation(_Constant, value='x')
        x_0 = nodes.apply_operation(_Identity, x, name='x_0')
        x_1 = nodes.apply_operation(_Identity, x_0, name='x_1')
        x_2 = nodes.apply_operation(_Identity, x_1, name='x_2')
        x_0.parent_operation._inputs = (x_2, )

        mock_visitor = mock.MagicMock()
        mock_visitor.visit.return_value = ('x', )

        with self.assertRaisesWithLiteralMatch(
                AssertionError,
                'Cycle detected: [_Identity(name=\'x_2\'), _Identity(name=\'x_1\'), '
                '_Identity(name=\'x_0\'), _Identity(name=\'x_2\')]'):
            nodes.Traverser(mock_visitor).visit_value_node(x_2)
def _perform_cache_optimization(saved_model_future, dataset_keys,
                                tensor_keys_to_paths, cache_dict, num_phases):
    """Performs cache optimization on the given graph."""
    cache_output_nodes = {}
    optimize_visitor = _OptimizeVisitor(dataset_keys or {}, cache_dict,
                                        tensor_keys_to_paths,
                                        cache_output_nodes, num_phases)
    optimize_traverser = nodes.Traverser(optimize_visitor)
    optimized = optimize_traverser.visit_value_node(
        saved_model_future).flattened_view

    if cache_dict is None:
        assert not cache_output_nodes
        cache_output_nodes = None

    return (optimized, cache_output_nodes,
            optimize_visitor.get_detached_sideeffect_leafs())
def get_analyze_input_columns(
        preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                                   Mapping[str, common_types.TensorType]],
        specs: Mapping[str, Union[common_types.FeatureSpecType, tf.TypeSpec]],
        force_tf_compat_v1: bool = False) -> List[str]:
    """Return columns that are required inputs of `AnalyzeDataset`.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
      True, this can also be feature specifications.
    force_tf_compat_v1: (Optional) If `True`, use TensorFlow in compat.v1 mode.
      Defaults to `False`.

  Returns:
    A list of columns that are required inputs of analyzers.
  """
    use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
    if not use_tf_compat_v1:
        assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs
    graph, structured_inputs, structured_outputs = (
        impl_helper.trace_preprocessing_function(
            preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1))

    tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
    visitor = graph_tools.SourcedTensorsVisitor()
    for tensor_sink in tensor_sinks:
        nodes.Traverser(visitor).visit_value_node(tensor_sink.future)

    if use_tf_compat_v1:
        control_dependency_ops = []
    else:
        # If traced in TF2 as a tf.function, inputs that end up in control
        # dependencies are required for the function to execute. Return such inputs
        # as required inputs of analyzers as well.
        _, control_dependency_ops = (
            tf2_utils.strip_and_get_tensors_and_control_dependencies(
                tf.nest.flatten(structured_outputs, expand_composites=True)))

    output_tensors = list(
        itertools.chain(visitor.sourced_tensors, control_dependency_ops))
    analyze_input_tensors = graph_tools.get_dependent_inputs(
        graph, structured_inputs, output_tensors)
    return list(analyze_input_tensors.keys())
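A small usage sketch for the typed version above, assuming the function is importable from tensorflow_transform's inspect_preprocessing_fn module (an assumption; adjust the import for your version). When not forcing compat.v1, `specs` must be tf.TypeSpecs:

import tensorflow as tf
import tensorflow_transform as tft
# Assumed import path; adjust for your tf.Transform version.
from tensorflow_transform import inspect_preprocessing_fn


def preprocessing_fn(inputs):
    return {
        # 'x' feeds analyzers (scale_to_0_1 uses tft.min/tft.max internally).
        'x_scaled': tft.scale_to_0_1(inputs['x']),
        # 'y' is only passed through, so no analyzer depends on it.
        'y_copy': tf.identity(inputs['y']),
    }


specs = {
    'x': tf.TensorSpec(shape=[None], dtype=tf.float32),
    'y': tf.TensorSpec(shape=[None], dtype=tf.float32),
}

# Expected to contain 'x' but not 'y', since only 'x' is an analyzer input.
print(inspect_preprocessing_fn.get_analyze_input_columns(
    preprocessing_fn, specs, force_tf_compat_v1=False))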
Example #15
def get_analysis_dataset_keys(preprocessing_fn, feature_spec, dataset_keys,
                              input_cache):
  """Computes the dataset keys that are required in order to perform analysis.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    feature_spec: A dict of feature name to feature specification.
    dataset_keys: A set of strings which are dataset keys; they uniquely
      identify these datasets across analysis runs.
    input_cache: A cache dictionary.

  Returns:
    A pair of:
      - A set of dataset keys that are required for analysis.
      - A boolean indicating whether or not a flattened version of the entire
        dataset is required. See the `flat_data` input to
        `AnalyzeDatasetWithCache`.
  """
  with tf.Graph().as_default() as graph:
    with tf.compat.v1.name_scope('inputs'):
      input_signature = impl_helper.feature_spec_as_batched_placeholders(
          feature_spec)
      # TODO(b/34288791): This needs to be exactly the same as in impl.py
      copied_inputs = impl_helper.copy_tensors(input_signature)

    output_signature = preprocessing_fn(copied_inputs)
  transform_fn_future, _ = build(
      graph,
      input_signature,
      output_signature,
      dataset_keys=dataset_keys,
      cache_dict=input_cache)

  required_dataset_keys_result = set()
  inspect_visitor = _InspectVisitor(required_dataset_keys_result)
  inspect_traverser = nodes.Traverser(inspect_visitor)
  _ = inspect_traverser.visit_value_node(transform_fn_future)

  # If None is present, a flattened version of the entire dataset is
  # required, so all of the given dataset_keys are returned.
  flat_data_required = None in required_dataset_keys_result
  if flat_data_required:
    required_dataset_keys_result = dataset_keys
  return required_dataset_keys_result, flat_data_required
def get_analyze_input_columns(preprocessing_fn, feature_spec):
    """Return columns that are required inputs of `AnalyzeDataset`.

  Args:
    preprocessing_fn: A tf.transform preprocessing_fn.
    feature_spec: A dict of feature name to feature specification.

  Returns:
    A list of columns that are required inputs of analyzers.
  """
    with tf.compat.v1.Graph().as_default() as graph:
        input_signature = impl_helper.feature_spec_as_batched_placeholders(
            feature_spec)
        _ = preprocessing_fn(input_signature.copy())

        tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
        visitor = _SourcedTensorsVisitor()
        for tensor_sink in tensor_sinks:
            nodes.Traverser(visitor).visit_value_node(tensor_sink.future)

        analyze_input_tensors = graph_tools.get_dependent_inputs(
            graph, input_signature, visitor.sourced_tensors)
        return list(analyze_input_tensors.keys())
Example #17
def get_analyzers_fingerprint(
    graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType]
) -> Mapping[str, AnalyzersFingerprint]:
  """Computes fingerprints for all analyzers in `graph`.

  Args:
    graph: a TF Graph.
    structured_inputs: a dict from keys to batches of placeholder graph tensors.

  Returns:
    A mapping from analyzer name to a set of paths that define its fingerprint.
  """
  result = {}
  tensor_sinks = graph.get_collection(analyzer_nodes.ALL_REPLACEMENTS)
  # The values in this dictionary are unused and can be arbitrary.
  sink_tensors_ready = {
      tf_utils.hashable_tensor_or_op(tensor_sink.tensor): False
      for tensor_sink in tensor_sinks
  }
  graph_analyzer = InitializableGraphAnalyzer(
      graph, structured_inputs, list(sink_tensors_ready.items()),
      describe_path_as_analyzer_cache_hash)
  for tensor_sink in tensor_sinks:
    # Retrieve tensors that are inputs to the analyzer's value node.
    visitor = SourcedTensorsVisitor()
    nodes.Traverser(visitor).visit_value_node(tensor_sink.future)
    source_keys = _retrieve_source_keys(visitor.sourced_tensors,
                                        structured_inputs)
    paths = set()
    for tensor in visitor.sourced_tensors:
      # Obtain fingerprint for each tensor that is an input to the value node.
      path = graph_analyzer.get_unique_path(tensor)
      if path is not None:
        paths.add(path)
    result[str(tensor_sink.tensor.name)] = AnalyzersFingerprint(
        source_keys, paths)
  return result
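A hedged sketch of feeding this function, reusing `impl_helper.trace_preprocessing_function` as in the get_analyze_input_columns examples on this page; `preprocessing_fn` and `specs` are assumed to be defined as in the earlier sketches, and the import locations are assumptions:

# Trace the preprocessing_fn to obtain the graph and its structured inputs.
graph, structured_inputs, _ = impl_helper.trace_preprocessing_function(
    preprocessing_fn, specs, use_tf_compat_v1=False)

fingerprints = get_analyzers_fingerprint(graph, structured_inputs)
for analyzer_name, fingerprint in fingerprints.items():
    # Each AnalyzersFingerprint carries the source input keys (and path
    # hashes) that identify the analyzer's inputs.
    print(analyzer_name, sorted(fingerprint.source_keys))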
Example #18
def _perform_cache_optimization(saved_model_future, cache_location,
                                dataset_keys):
  optimize_visitor = _OptimizeVisitor(dataset_keys or {}, cache_location)
  optimize_traverser = nodes.Traverser(optimize_visitor)
  return optimize_traverser.visit_value_node(saved_model_future).flattened_view
Example #19
  def expand(self, dataset):
    """Analyze the dataset.

    Args:
      dataset: A dataset.

    Returns:
      A TransformFn containing the deferred transform function.

    Raises:
      ValueError: If preprocessing_fn has no outputs.
    """
    flattened_pcoll, input_values_pcoll_dict, input_metadata = dataset
    input_schema = input_metadata.schema

    input_values_pcoll_dict = input_values_pcoll_dict or dict()

    analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys())

    with tf.Graph().as_default() as graph:

      with tf.name_scope('inputs'):
        feature_spec = input_schema.as_feature_spec()
        input_signature = impl_helper.feature_spec_as_batched_placeholders(
            feature_spec)
        # In order to avoid a bug where import_graph_def fails when the
        # input_map and return_elements of an imported graph are the same
        # (b/34288791), we avoid using the placeholder of an input column as an
        # output of a graph. We do this by applying tf.identity to all inputs of
        # the preprocessing_fn.  Note this applies at the level of raw tensors.
        # TODO(b/34288791): Remove this workaround and use a shallow copy of
        # inputs instead.  A shallow copy is needed in case
        # self._preprocessing_fn mutates its input.
        copied_inputs = impl_helper.copy_tensors(input_signature)

      output_signature = self._preprocessing_fn(copied_inputs)

    # At this point we check that the preprocessing_fn has at least one
    # output. This is because if we allowed the output of preprocessing_fn to
    # be empty, we wouldn't be able to determine how many instances to
    # "unbatch" the output into.
    if not output_signature:
      raise ValueError('The preprocessing function returned an empty dict')

    if graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
      raise ValueError(
          'The preprocessing function contained trainable variables '
          '{}'.format(
              graph.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)))

    pipeline = flattened_pcoll.pipeline
    serialized_tf_config = common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
        pipeline.runner)
    extra_args = common.ConstructBeamPipelineVisitor.ExtraArgs(
        base_temp_dir=Context.create_base_temp_dir(),
        serialized_tf_config=serialized_tf_config,
        pipeline=pipeline,
        flat_pcollection=flattened_pcoll,
        pcollection_dict=input_values_pcoll_dict,
        graph=graph,
        input_signature=input_signature,
        input_schema=input_schema,
        cache_location=self._cache_location)

    transform_fn_future = analysis_graph_builder.build(
        graph, input_signature, output_signature,
        input_values_pcoll_dict.keys(), self._cache_location)

    transform_fn_pcoll = nodes.Traverser(
        common.ConstructBeamPipelineVisitor(extra_args)).visit_value_node(
            transform_fn_future)

    # Infer metadata.  We take the inferred metadata and apply overrides that
    # refer to values of tensors in the graph.  The override tensors must
    # be "constant" in that they don't depend on input data.  The tensors can
    # depend on analyzer outputs though.  This allows us to set metadata that
    # depends on analyzer outputs. _augment_metadata will use the analyzer
    # outputs stored in `transform_fn` to compute the metadata in a
    # deferred manner, once the analyzer outputs are known.
    metadata = dataset_metadata.DatasetMetadata(
        schema=schema_inference.infer_feature_schema(output_signature, graph))

    deferred_metadata = (
        transform_fn_pcoll
        |
        'ComputeDeferredMetadata' >> beam.Map(_infer_metadata_from_saved_model))

    full_metadata = beam_metadata_io.BeamDatasetMetadata(
        metadata, deferred_metadata)

    _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

    return transform_fn_pcoll, full_metadata
def build(graph,
          input_signature,
          output_signature,
          dataset_keys=None,
          cache_dict=None):
    """Returns a list of `Phase`s describing how to execute the pipeline.

  The default graph is assumed to contain some `Analyzer`s which must be
  executed by doing a full pass over the dataset, and passing the inputs for
  that analyzer into some implementation, then taking the results and replacing
  the `Analyzer`s outputs with constants in the graph containing these results.

  The execution plan is described by a list of `Phase`s.  Each phase contains
  a list of `Analyzer`s, which are the `Analyzer`s that are ready to run in
  that phase, together with a list of ops, which are the table initializers that
  are ready to run in that phase.

  An `Analyzer` or op is ready to run when all its dependencies in the graph
  have been computed.  Thus if the graph is constructed by

  def preprocessing_fn(inputs):
    x = inputs['x']
    scaled_0 = x - tft.min(x)
    scaled_0_1 = scaled_0 / tft.max(scaled_0)

  Then the first phase will contain the analyzer corresponding to the call to
  `min`, because `x` is an input and so is ready to compute in the first phase,
  while the second phase will contain the analyzer corresponding to the call to
  `max` since `scaled_0` depends on the result of the call to `tft.min` which
  is computed in the first phase.

  More generally, we define a level for each op and each `Analyzer` by walking
  the graph, assigning to each operation the max level of its inputs, to each
  `Tensor` the level of its operation, unless it's the output of an `Analyzer`
  in which case we assign the level of its `Analyzer` plus one.

  Args:
    graph: A `tf.Graph`.
    input_signature: A dict whose keys are strings and values are `Tensor`s or
      `SparseTensor`s.
    output_signature: A dict whose keys are strings and values are `Tensor`s or
      `SparseTensor`s.
    dataset_keys: (Optional) A set of strings which are dataset keys; they
      uniquely identify these datasets across analysis runs.
    cache_dict: (Optional) A cache dictionary.

  Returns:
    A pair of:
      * list of `Phase`s
      * A dictionary of output cache `ValueNode`s.

  Raises:
    ValueError: if the graph cannot be analyzed.
  """
    tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
    graph.clear_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
    phase = 0
    tensor_bindings = []
    sink_tensors_ready = {
        tf_utils.hashable_tensor_or_op(tensor_sink.tensor): False
        for tensor_sink in tensor_sinks
    }
    translate_visitor = _TranslateVisitor()
    translate_traverser = nodes.Traverser(translate_visitor)

    analyzers_input_signature = {}
    graph_analyzer = None

    extracted_input_node = nodes.apply_operation(
        beam_nodes.ExtractInputForSavedModel,
        dataset_key=analyzer_cache._make_flattened_dataset_key(),  # pylint: disable=protected-access
        label='ExtractInputForSavedModel[FlattenedDataset]')

    while not all(sink_tensors_ready.values()):
        infix = 'Phase{}'.format(phase)
        # Determine which table init ops are ready to run in this phase.
        # Determine which keys of pending_tensor_replacements are ready to run
        # in this phase, based on whether their dependencies are ready.
        graph_analyzer = graph_tools.InitializableGraphAnalyzer(
            graph, input_signature, list(sink_tensors_ready.items()),
            graph_tools.describe_path_as_analyzer_cache_hash)
        ready_traverser = nodes.Traverser(_ReadyVisitor(graph_analyzer))

        # Now create and apply a SavedModel with all tensors in tensor_bindings
        # bound, which outputs all the tensors in the required tensor tuples.
        intermediate_output_signature = collections.OrderedDict()
        saved_model_future = nodes.apply_operation(
            beam_nodes.CreateSavedModel,
            *tensor_bindings,
            table_initializers=tuple(graph_analyzer.ready_table_initializers),
            output_signature=intermediate_output_signature,
            label='CreateSavedModelForAnalyzerInputs[{}]'.format(infix))

        extracted_values_dict = nodes.apply_operation(
            beam_nodes.ApplySavedModel,
            saved_model_future,
            extracted_input_node,
            phase=phase,
            label='ApplySavedModel[{}]'.format(infix))

        translate_visitor.phase = phase
        translate_visitor.intermediate_output_signature = (
            intermediate_output_signature)
        translate_visitor.extracted_values_dict = extracted_values_dict
        for tensor, value_node, is_asset_filepath in tensor_sinks:
            hashable_tensor = tf_utils.hashable_tensor_or_op(tensor)
            # Don't compute a binding/sink/replacement that's already been computed
            if sink_tensors_ready[hashable_tensor]:
                continue

            if not ready_traverser.visit_value_node(value_node):
                continue

            translated_value_node = translate_traverser.visit_value_node(
                value_node)

            name = _tensor_name(tensor)
            tensor_bindings.append(
                nodes.apply_operation(
                    beam_nodes.CreateTensorBinding,
                    translated_value_node,
                    tensor_name=str(tensor.name),
                    dtype_enum=tensor.dtype.as_datatype_enum,
                    is_asset_filepath=is_asset_filepath,
                    label=analyzer_nodes.sanitize_label(
                        'CreateTensorBinding[{}]'.format(name))))
            sink_tensors_ready[hashable_tensor] = True

        analyzers_input_signature.update(intermediate_output_signature)
        phase += 1

    # We need to make sure that the representation of this output_signature is
    # deterministic.
    output_signature = collections.OrderedDict(
        sorted(output_signature.items(), key=lambda t: t[0]))

    # TODO(KesterTong): check all table initializers are ready, check all output
    # tensors are ready.
    saved_model_future = nodes.apply_operation(
        beam_nodes.CreateSavedModel,
        *tensor_bindings,
        table_initializers=tuple(
            graph.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)),
        output_signature=output_signature,
        label='CreateSavedModel')

    tensor_keys_to_paths = {
        tensor_key:
        graph_analyzer.get_unique_path(analyzers_input_signature[tensor_key])
        for tensor_key in analyzers_input_signature
    }
    (optimized_saved_model_future, output_cache_value_nodes,
     detached_sideeffect_leafs) = _perform_cache_optimization(
         saved_model_future, dataset_keys, tensor_keys_to_paths, cache_dict,
         phase)

    (optimized_saved_model_future, output_cache_value_nodes) = (
        combiner_packing_util.perform_combiner_packing_optimization(
            optimized_saved_model_future, output_cache_value_nodes, phase))

    global _ANALYSIS_GRAPH
    _ANALYSIS_GRAPH = optimized_saved_model_future
    return (optimized_saved_model_future, output_cache_value_nodes,
            detached_sideeffect_leafs)
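To make the phase example in `build`'s docstring concrete, a runnable two-phase preprocessing_fn would look like this sketch: `tft.min` can run in phase 0 directly on the raw input, while `tft.max` must wait for the next phase because its input depends on the result of `tft.min`:

import tensorflow_transform as tft


def preprocessing_fn(inputs):
    x = inputs['x']
    # Phase 0 analyzer: only needs the raw input 'x'.
    scaled_0 = x - tft.min(x)
    # Phase 1 analyzer: tft.max's input depends on the phase-0 result.
    scaled_0_1 = scaled_0 / tft.max(scaled_0)
    return {'x_scaled': scaled_0_1}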
Example #21
def perform_combiner_packing_optimization(saved_model_future,
                                          cache_value_nodes, num_phases):
  """Optimizes the graph by packing possible combine nodes."""
  # Inspect the graph to identify all the packable combines.
  inspect_acc_combine_visitor = _InspectAccumulateCombineVisitor()
  inspect_acc_combine_traverser = nodes.Traverser(inspect_acc_combine_visitor)
  _ = inspect_acc_combine_traverser.visit_value_node(saved_model_future)

  packable_combines = inspect_acc_combine_visitor.packable_combines
  # Do not pack if we have only a single combine in the group.
  packable_combines = {
      label: group for label, group in packable_combines.items()
      if len(group) > 1
  }

  pack_acc_combine_visitor = _PackAccumulateCombineVisitor(packable_combines)
  pack_acc_combine_traverser = nodes.Traverser(pack_acc_combine_visitor)
  saved_model_future = pack_acc_combine_traverser.visit_value_node(
      saved_model_future)

  # Replace cache nodes to point to the corresponding new nodes.
  cache_value_nodes = _update_cache_value_node_references(
      cache_value_nodes, pack_acc_combine_traverser)

  # TODO(b/134414978): Consider also packing the merges even when we have
  # multiple phases.
  if num_phases > 1:
    return (saved_model_future, cache_value_nodes)

  # Identify the merge combines that can be packed together.
  inspect_merge_combine_visitor = _InspectMergeCombineVisitor()
  inspect_merge_combine_traverser = nodes.Traverser(
      inspect_merge_combine_visitor)
  _ = inspect_merge_combine_traverser.visit_value_node(saved_model_future)

  # Only pack if we have more than one merge combine.
  if len(inspect_merge_combine_visitor.packable_combine_extract_outputs) <= 1:
    return (saved_model_future, cache_value_nodes)

  # Add flatten and packed merge nodes.
  pack_merge_combine_visitor = _PackMergeCombineVisitor(
      packable_combine_extract_outputs=
      inspect_merge_combine_visitor.packable_combine_extract_outputs)
  pack_merge_combine_traverser = nodes.Traverser(pack_merge_combine_visitor)
  saved_model_future = pack_merge_combine_traverser.visit_value_node(
      saved_model_future)
  # Replace cache nodes to point to the corresponding new nodes.
  cache_value_nodes = _update_cache_value_node_references(
      cache_value_nodes, pack_merge_combine_traverser)

  # Remove redundant flatten and packed merge nodes.
  remove_redundant_visitor = _RemoveRedundantPackedMergeCombineVisitor(
      final_packed_merge_combine_label=
      pack_merge_combine_visitor.final_packed_merge_combine_label)
  remove_redundant_traverser = nodes.Traverser(remove_redundant_visitor)
  saved_model_future = remove_redundant_traverser.visit_value_node(
      saved_model_future)
  # Replace cache nodes to point to the corresponding new nodes.
  cache_value_nodes = _update_cache_value_node_references(
      cache_value_nodes, remove_redundant_traverser)

  return (saved_model_future, cache_value_nodes)
Example #22
def build(graph, input_signature, output_signature):
    """Returns a list of `Phase`s describing how to execute the pipeline.

  The default graph is assumed to contain some `Analyzer`s which must be
  executed by doing a full pass over the dataset, and passing the inputs for
  that analyzer into some implementation, then taking the results and replacing
  the `Analyzer`s outputs with constants in the graph containing these results.

  The execution plan is described by a list of `Phase`s.  Each phase contains
  a list of `Analyzer`s, which are the `Analyzer`s that are ready to run in
  that phase, together with a list of ops, which are the table initializers that
  are ready to run in that phase.

  An `Analyzer` or op is ready to run when all its dependencies in the graph
  have been computed.  Thus if the graph is constructed by

  def preprocessing_fn(inputs):
    x = inputs['x']
    scaled_0 = x - tft.min(x)
    scaled_0_1 = scaled_0 / tft.max(scaled_0)

  Then the first phase will contain the analyzer corresponding to the call to
  `min`, because `x` is an input and so is ready to compute in the first phase,
  while the second phase will contain the analyzer corresponding to the call to
  `max` since `scaled_0` depends on the result of the call to `tft.min` which
  is computed in the first phase.

  More generally, we define a level for each op and each `Analyzer` by walking
  the graph, assigning to each operation the max level of its inputs, to each
  `Tensor` the level of its operation, unless it's the output of an `Analyzer`
  in which case we assign the level of its `Analyzer` plus one.

  Args:
    graph: A `tf.Graph`.
    input_signature: A dict whose keys are strings and values are `Tensor`s or
        `SparseTensor`s.
    output_signature: A dict whose keys are strings and values are `Tensor`s or
        `SparseTensor`s.

  Returns:
    A list of `Phase`s.

  Raises:
    ValueError: if the graph cannot be analyzed.
  """
    tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
    graph.clear_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
    phase = 0
    tensor_bindings = []
    sink_tensors_ready = {
        tensor_sink.tensor: False
        for tensor_sink in tensor_sinks
    }
    translate_visitor = _TranslateVisitor()
    translate_traverser = nodes.Traverser(translate_visitor)

    while not all(sink_tensors_ready.values()):
        # Determine which table init ops are ready to run in this phase.
        # Determine which keys of pending_tensor_replacements are ready to run
        # in this phase, based on whether their dependencies are ready.
        graph_analyzer = graph_tools.InitializableGraphAnalyzer(
            graph, input_signature.values(), sink_tensors_ready)
        ready_traverser = nodes.Traverser(_ReadyVisitor(graph_analyzer))

        # Now create and apply a SavedModel with all tensors in tensor_bindings
        # bound, which outputs all the tensors in the required tensor tuples.
        intermediate_output_signature = collections.OrderedDict()
        saved_model_future = nodes.apply_operation(
            beam_nodes.CreateSavedModel,
            *tensor_bindings,
            table_initializers=tuple(graph_analyzer.ready_table_initializers),
            output_signature=intermediate_output_signature,
            label='CreateSavedModelForAnalyzerInputs[{}]'.format(phase))
        extracted_values_dict = nodes.apply_operation(
            beam_nodes.ApplySavedModel,
            saved_model_future,
            phase=phase,
            label='ApplySavedModel[{}]'.format(phase))

        translate_visitor.phase = phase
        translate_visitor.intermediate_output_signature = (
            intermediate_output_signature)
        translate_visitor.extracted_values_dict = extracted_values_dict
        for tensor, value_node, is_asset_filepath in tensor_sinks:
            # Don't compute a binding/sink/replacement that's already been computed
            if sink_tensors_ready[tensor]:
                continue

            if not ready_traverser.visit_value_node(value_node):
                continue

            translated_value_node = translate_traverser.visit_value_node(
                value_node)

            name = _tensor_name(tensor)
            tensor_bindings.append(
                nodes.apply_operation(
                    beam_nodes.CreateTensorBinding,
                    translated_value_node,
                    tensor=str(tensor.name),
                    is_asset_filepath=is_asset_filepath,
                    label='CreateTensorBinding[{}]'.format(name)))
            sink_tensors_ready[tensor] = True

        phase += 1

    # We need to make sure that the representation of this output_signature is
    # deterministic.
    output_signature = collections.OrderedDict(
        sorted(output_signature.items(), key=lambda t: t[0]))

    return nodes.apply_operation(beam_nodes.CreateSavedModel,
                                 *tensor_bindings,
                                 table_initializers=tuple(
                                     graph.get_collection(
                                         tf.GraphKeys.TABLE_INITIALIZERS)),
                                 output_signature=output_signature,
                                 label='CreateSavedModel')
Example #23
  def expand(self, dataset):
    """Analyze the dataset.

    Args:
      dataset: A dataset.

    Returns:
      A TransformFn containing the deferred transform function.

    Raises:
      ValueError: If preprocessing_fn has no outputs.
    """
    (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
     input_metadata) = dataset
    if self._use_tfxio:
      input_schema = None
      input_tensor_adapter_config = input_metadata
    else:
      input_schema = input_metadata.schema
      input_tensor_adapter_config = None

    input_values_pcoll_dict = input_values_pcoll_dict or dict()

    with tf.compat.v1.Graph().as_default() as graph:

      with tf.compat.v1.name_scope('inputs'):
        if self._use_tfxio:
          specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs()
        else:
          specs = schema_utils.schema_as_feature_spec(input_schema).feature_spec
        input_signature = impl_helper.batched_placeholders_from_specs(specs)
        # In order to avoid a bug where import_graph_def fails when the
        # input_map and return_elements of an imported graph are the same
        # (b/34288791), we avoid using the placeholder of an input column as an
        # output of a graph. We do this by applying tf.identity to all inputs of
        # the preprocessing_fn.  Note this applies at the level of raw tensors.
        # TODO(b/34288791): Remove this workaround and use a shallow copy of
        # inputs instead.  A shallow copy is needed in case
        # self._preprocessing_fn mutates its input.
        copied_inputs = impl_helper.copy_tensors(input_signature)

      output_signature = self._preprocessing_fn(copied_inputs)

    # At this point we check that the preprocessing_fn has at least one
    # output. This is because if we allowed the output of preprocessing_fn to
    # be empty, we wouldn't be able to determine how many instances to
    # "unbatch" the output into.
    if not output_signature:
      raise ValueError('The preprocessing function returned an empty dict')

    if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
      raise ValueError(
          'The preprocessing function contained trainable variables '
          '{}'.format(
              graph.get_collection_ref(
                  tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))

    pipeline = self.pipeline or (flattened_pcoll or next(
        v for v in input_values_pcoll_dict.values() if v is not None)).pipeline

    # Add a stage that inspects graph collections for API use counts and logs
    # them as a beam metric.
    _ = (pipeline | 'InstrumentAPI' >> _InstrumentAPI(graph))

    tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
        type(pipeline.runner))
    extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs(
        base_temp_dir=Context.create_base_temp_dir(),
        tf_config=tf_config,
        pipeline=pipeline,
        flat_pcollection=flattened_pcoll,
        pcollection_dict=input_values_pcoll_dict,
        graph=graph,
        input_signature=input_signature,
        input_schema=input_schema,
        input_tensor_adapter_config=input_tensor_adapter_config,
        use_tfxio=self._use_tfxio,
        cache_pcoll_dict=dataset_cache_dict)

    transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
        graph,
        input_signature,
        output_signature,
        input_values_pcoll_dict.keys(),
        cache_dict=dataset_cache_dict)
    traverser = nodes.Traverser(
        beam_common.ConstructBeamPipelineVisitor(extra_args))
    transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)

    if cache_value_nodes is not None:
      output_cache_pcoll_dict = {}
      for (dataset_key,
           cache_key), value_node in six.iteritems(cache_value_nodes):
        if dataset_key not in output_cache_pcoll_dict:
          output_cache_pcoll_dict[dataset_key] = {}
        output_cache_pcoll_dict[dataset_key][cache_key] = (
            traverser.visit_value_node(value_node))
    else:
      output_cache_pcoll_dict = None

    # Infer metadata.  We take the inferred metadata and apply overrides that
    # refer to values of tensors in the graph.  The override tensors must
    # be "constant" in that they don't depend on input data.  The tensors can
    # depend on analyzer outputs though.  This allows us to set metadata that
    # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
    # analyzer outputs stored in `transform_fn` to compute the metadata in a
    # deferred manner, once the analyzer outputs are known.
    metadata = dataset_metadata.DatasetMetadata(
        schema=schema_inference.infer_feature_schema(output_signature, graph))

    deferred_metadata = (
        transform_fn_pcoll
        |
        'ComputeDeferredMetadata' >> beam.Map(_infer_metadata_from_saved_model))

    full_metadata = beam_metadata_io.BeamDatasetMetadata(
        metadata, deferred_metadata)

    _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

    return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict