def get_analyze_input_columns(preprocessing_fn,
                              specs,
                              force_tf_compat_v1=False):
    """Return columns that are required inputs of `AnalyzeDataset`.

    The preprocessing function is traced, every entry of the graph's
    `analyzer_nodes.TENSOR_REPLACEMENTS` collection has its `future` traversed
    with a `_SourcedTensorsVisitor`, and the structured inputs that the
    visitor's `sourced_tensors` depend on are reported.

    Args:
      preprocessing_fn: A tf.transform preprocessing_fn.
      specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
        True, this can also be feature specifications.
      force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1
        mode. Defaults to `False`.

    Returns:
      A list of columns that are required inputs of analyzers.
    """
    compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
    if not compat_v1:
        # Outside compat.v1 mode every spec value must be a tf.TypeSpec.
        assert all(
            isinstance(spec, tf.TypeSpec) for spec in specs.values()), specs

    traced = impl_helper.trace_preprocessing_function(
        preprocessing_fn, specs, use_tf_compat_v1=compat_v1)
    graph, structured_inputs, _ = traced

    # A single visitor accumulates sourced tensors across all sinks; a fresh
    # traverser is created for each sink.
    visitor = _SourcedTensorsVisitor()
    for sink in graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS):
        traverser = nodes.Traverser(visitor)
        traverser.visit_value_node(sink.future)

    required_inputs = graph_tools.get_dependent_inputs(
        graph, structured_inputs, visitor.sourced_tensors)
    return list(required_inputs.keys())
def get_transform_input_columns(preprocessing_fn,
                                specs,
                                force_tf_compat_v1=False):
    """Return columns that are required inputs of `TransformDataset`.

    The preprocessing function is traced and the structured inputs that its
    structured outputs depend on are reported.

    Args:
      preprocessing_fn: A tf.transform preprocessing_fn.
      specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
        True, this can also be feature specifications.
      force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1
        mode. Defaults to `False`.

    Returns:
      A list of columns that are required inputs of the transform `tf.Graph`
      defined by `preprocessing_fn`.
    """
    compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
    if not compat_v1:
        # Outside compat.v1 mode every spec value must be a tf.TypeSpec.
        assert all(
            isinstance(spec, tf.TypeSpec) for spec in specs.values()), specs

    traced = impl_helper.trace_preprocessing_function(
        preprocessing_fn, specs, use_tf_compat_v1=compat_v1)
    graph, structured_inputs, structured_outputs = traced

    required_inputs = graph_tools.get_dependent_inputs(
        graph, structured_inputs, structured_outputs)
    return list(required_inputs.keys())
Exemple #3
0
  def test_perform_combiner_packing_optimization(
      self, feature_spec, preprocessing_fn, num_phases,
      expected_dot_graph_str_before_packing,
      expected_dot_graph_str_after_packing):
    """Compares the analysis graph's dot rendering before and after packing.

    The graph is first built with
    `combiner_packing_util.perform_combiner_packing_optimization` patched to
    hand back its inputs unchanged; the real function is then applied to that
    result, and both renderings are compared with the expected strings.
    """
    traced = impl_helper.trace_preprocessing_function(
        preprocessing_fn, feature_spec, use_tf_compat_v1=True)
    graph, structured_inputs, structured_outputs = traced

    def _passthrough(saved_model_future, cache_value_nodes, unused_num_phases):
      # Stand-in for the optimization: return the inputs as they came in.
      return (saved_model_future, cache_value_nodes)

    packing_patch = mock.patch.object(
        combiner_packing_util,
        'perform_combiner_packing_optimization',
        side_effect=_passthrough)
    with packing_patch:
      future_before, cache = analysis_graph_builder.build(
          graph, structured_inputs, structured_outputs)

    # Outside the patch, so this is the real optimization.
    future_after, _ = (
        combiner_packing_util.perform_combiner_packing_optimization(
            future_before, cache, num_phases))

    # The "before" rendering is deliberately taken after the optimization call
    # above, exactly where the original statement order placed it.
    rendered_before = nodes.get_dot_graph([future_before]).to_string()
    self.assertMultiLineEqual(
        rendered_before,
        expected_dot_graph_str_before_packing,
        msg='Result dot graph is:\n{}'.format(rendered_before))

    rendered_after = nodes.get_dot_graph([future_after]).to_string()
    self.WriteRenderedDotFile(rendered_after)
    self.assertMultiLineEqual(
        rendered_after,
        expected_dot_graph_str_after_packing,
        msg='Result dot graph is:\n{}'.format(rendered_after))
Exemple #4
0
  def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str):
    """Builds the analysis graph in compat.v1 mode and checks its dot string."""
    traced = impl_helper.trace_preprocessing_function(
        preprocessing_fn, feature_spec, use_tf_compat_v1=True)
    graph, structured_inputs, structured_outputs = traced

    built = analysis_graph_builder.build(
        graph, structured_inputs, structured_outputs)
    # The second element (the cache) is not inspected by this test.
    transform_fn_future, _ = built

    rendered = nodes.get_dot_graph([transform_fn_future]).to_string()
    self.WriteRenderedDotFile(rendered)
    self.assertMultiLineEqual(
        rendered,
        expected_dot_graph_str,
        msg='Result dot graph is:\n{}'.format(rendered))
def _build_analysis_graph_for_inspection(preprocessing_fn, specs, dataset_keys,
                                         input_cache, force_tf_compat_v1):
    """Builds the analysis graph for inspection.

    Args:
      preprocessing_fn: A tf.transform preprocessing_fn.
      specs: A dict of feature name to specs. When the resolved mode is not
        compat.v1, every value must be a `tf.TypeSpec`.
      dataset_keys: Forwarded to `build` as `dataset_keys`.
      input_cache: Forwarded to `build` as `cache_dict`.
      force_tf_compat_v1: If `True`, use Tensorflow in compat.v1 mode.

    Returns:
      A `(transform_fn_future, cache_dict)` tuple made of the first two values
      returned by `build`; the third returned value is discarded.
    """
    # Resolve the effective mode once and use it for both the spec check and
    # the tracing call. Checking the raw `force_tf_compat_v1` flag instead
    # could reject non-TypeSpec specs even though tracing then runs in
    # compat.v1 mode.
    use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
    if not use_tf_compat_v1:
        assert all(isinstance(s, tf.TypeSpec) for s in specs.values()), specs
    graph, structured_inputs, structured_outputs = (
        impl_helper.trace_preprocessing_function(
            preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1))

    transform_fn_future, cache_dict, _ = build(graph,
                                               structured_inputs,
                                               structured_outputs,
                                               dataset_keys=dataset_keys,
                                               cache_dict=input_cache)
    return transform_fn_future, cache_dict
def get_analyze_input_columns(
        preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                                   Mapping[str, common_types.TensorType]],
        specs: Mapping[str, Union[common_types.FeatureSpecType, tf.TypeSpec]],
        force_tf_compat_v1: bool = False) -> List[str]:
    """Return columns that are required inputs of `AnalyzeDataset`.

    The preprocessing function is traced, every entry of the graph's
    `analyzer_nodes.TENSOR_REPLACEMENTS` collection has its `future` traversed
    with a `graph_tools.SourcedTensorsVisitor`, and the structured inputs that
    the collected tensors depend on are reported. Outside compat.v1 mode, the
    control dependency ops extracted from the structured outputs are included
    in that dependency lookup as well.

    Args:
      preprocessing_fn: A tf.transform preprocessing_fn.
      specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is
        True, this can also be feature specifications.
      force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1
        mode. Defaults to `False`.

    Returns:
      A list of columns that are required inputs of analyzers.
    """
    compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
    if not compat_v1:
        # Outside compat.v1 mode every spec value must be a tf.TypeSpec.
        assert all(
            isinstance(spec, tf.TypeSpec) for spec in specs.values()), specs

    traced = impl_helper.trace_preprocessing_function(
        preprocessing_fn, specs, use_tf_compat_v1=compat_v1)
    graph, structured_inputs, structured_outputs = traced

    # A single visitor accumulates sourced tensors across all sinks; a fresh
    # traverser is created for each sink.
    visitor = graph_tools.SourcedTensorsVisitor()
    for sink in graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS):
        traverser = nodes.Traverser(visitor)
        traverser.visit_value_node(sink.future)

    control_dependency_ops = []
    if not compat_v1:
        # If traced in TF2 as a tf.function, inputs that end up in control
        # dependencies are required for the function to execute. Return such
        # inputs as required inputs of analyzers as well.
        flat_outputs = tf.nest.flatten(
            structured_outputs, expand_composites=True)
        _, control_dependency_ops = (
            tf2_utils.strip_and_get_tensors_and_control_dependencies(
                flat_outputs))

    output_tensors = [*visitor.sourced_tensors, *control_dependency_ops]
    required_inputs = graph_tools.get_dependent_inputs(
        graph, structured_inputs, output_tensors)
    return list(required_inputs.keys())
    def test_build(self, feature_spec, preprocessing_fn,
                   expected_dot_graph_str, expected_dot_graph_str_tf2,
                   use_tf_compat_v1):
        """Builds the analysis graph and checks its dot rendering.

        In compat.v1 mode the feature spec is traced directly and the result
        is compared with `expected_dot_graph_str`. Otherwise the feature spec
        is first converted to type specs and the result is compared with
        `expected_dot_graph_str_tf2`.
        """
        if use_tf_compat_v1:
            specs = feature_spec
        else:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
            specs = impl_helper.get_type_specs_from_feature_specs(feature_spec)

        # Each test method traces into its own sub-directory of the temp dir.
        trace_dir = os.path.join(self.get_temp_dir(), self._testMethodName)
        traced = impl_helper.trace_preprocessing_function(
            preprocessing_fn,
            specs,
            use_tf_compat_v1=use_tf_compat_v1,
            base_temp_dir=trace_dir)
        graph, structured_inputs, structured_outputs = traced

        built = analysis_graph_builder.build(
            graph, structured_inputs, structured_outputs)
        # The second element (the cache) is not inspected by this test.
        transform_fn_future, _ = built

        rendered = nodes.get_dot_graph([transform_fn_future]).to_string()
        self.WriteRenderedDotFile(rendered)

        if use_tf_compat_v1:
            expected = expected_dot_graph_str
        else:
            expected = expected_dot_graph_str_tf2
        self.assertMultiLineEqual(
            rendered,
            expected,
            msg='Result dot graph is:\n{}'.format(rendered))