def test_get_analysis_cache_entry_keys(self, use_tf_compat_v1):
        """Checks the cache entry keys returned for a tiny preprocessing_fn.

        `analyzer_cache.make_cache_entry_key` is patched so that it always
        returns the same constant. The test then asserts that
        `get_analysis_cache_entry_keys` yields exactly one key, equal to that
        constant. The dot rendering of the resulting analysis graph is handed
        to `self.WriteRenderedDotFile`.

        Args:
          use_tf_compat_v1: When False, `test_case.skip_if_not_tf2` is invoked
            and the feature spec is converted to type specs before being passed
            on. The flag is also forwarded as `force_tf_compat_v1`.
        """
        if not use_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        expected_key = 'A'
        full_dataset_keys = ['a', 'b']
        feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
        if use_tf_compat_v1:
            input_specs = feature_spec
        else:
            # Outside compat-v1 mode the inputs are described by type specs
            # derived from the feature spec.
            input_specs = impl_helper.get_type_specs_from_feature_specs(
                feature_spec)

        def preprocessing_fn(inputs):
            return {'x': tft.scale_to_0_1(inputs['x'])}

        # Every cache-key computation is forced to return `expected_key`, so
        # the assertion below does not depend on the real key format.
        patch_target = ('tensorflow_transform.beam.analysis_graph_builder.'
                        'analyzer_cache.make_cache_entry_key')
        with mock.patch(patch_target, side_effect=lambda _: expected_key):
            cache_entry_keys = (
                analysis_graph_builder.get_analysis_cache_entry_keys(
                    preprocessing_fn, input_specs, full_dataset_keys,
                    force_tf_compat_v1=use_tf_compat_v1))

        # `_ANALYSIS_GRAPH` is read after the call above; presumably that call
        # populates it (TODO confirm against analysis_graph_builder).
        analysis_graph = analysis_graph_builder._ANALYSIS_GRAPH
        self.WriteRenderedDotFile(
            nodes.get_dot_graph([analysis_graph]).to_string())
        self.assertCountEqual(cache_entry_keys, [expected_key])
# Example #2: a variant of the same test without the use_tf_compat_v1 parameter.
  def test_get_analysis_cache_entry_keys(self):
    """Checks the cache entry keys returned for a tiny preprocessing_fn.

    The preprocessing_fn applies `tft.scale_to_0_1` to feature 'x'.
    `analyzer_cache.make_cache_entry_key` is patched to always return one
    constant, and the test asserts that exactly that single key is returned.
    The dot rendering of the analysis graph is passed to
    `self.WriteRenderedDotFile`.
    """
    expected_key = 'A'
    feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
    full_dataset_keys = ['a', 'b']

    def preprocessing_fn(inputs):
      return {'x': tft.scale_to_0_1(inputs['x'])}

    # Pin the cache key so the assertion is independent of the real key format.
    patch_target = ('tensorflow_transform.beam.analysis_graph_builder.'
                    'analyzer_cache.make_cache_entry_key')
    with mock.patch(patch_target, side_effect=lambda _: expected_key):
      cache_entry_keys = analysis_graph_builder.get_analysis_cache_entry_keys(
          preprocessing_fn, feature_spec, full_dataset_keys)

    graph_dot = nodes.get_dot_graph(
        [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
    self.WriteRenderedDotFile(graph_dot)
    self.assertCountEqual(cache_entry_keys, [expected_key])