예제 #1
0
    def test_get_analysis_dataset_keys(self, preprocessing_fn,
                                       full_dataset_keys, cached_dataset_keys,
                                       expected_dataset_keys,
                                       expected_flat_data_required):
        """Checks which dataset keys still need analysis given a partial cache.

        Every key in `cached_dataset_keys` gets a cache entry stored under one
        fixed entry key. `make_cache_entry_key` is patched to return that same
        key, which forces every one of those datasets to be a cache hit. The
        analysis graph is converted to a DOT string and handed to
        `self.WriteRenderedDotFile`. Then the returned dataset keys and the
        flat-data flag are compared against `expected_dataset_keys` and
        `expected_flat_data_required`.
        """
        fake_entry_key = b'M'
        # Build a separate inner dict for each cached dataset key.
        cache_by_dataset = dict(
            (dataset_key, {fake_entry_key: 'C'})
            for dataset_key in cached_dataset_keys)
        spec_by_feature = {'x': tf.io.FixedLenFeature([], tf.float32)}
        patch_target = ('tensorflow_transform.beam.analysis_graph_builder.'
                        'analyzer_cache.make_cache_entry_key')
        with mock.patch(patch_target, return_value=fake_entry_key):
            analysis_result = analysis_graph_builder.get_analysis_dataset_keys(
                preprocessing_fn, spec_by_feature, full_dataset_keys,
                cache_by_dataset)
            dataset_keys, flat_data_required = analysis_result

        # `_ANALYSIS_GRAPH` is presumably populated by the call above, so it is
        # only rendered afterwards.
        graph_dot = nodes.get_dot_graph(
            [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
        self.WriteRenderedDotFile(graph_dot)

        self.assertCountEqual(expected_dataset_keys, dataset_keys)
        self.assertEqual(expected_flat_data_required, flat_data_required)
    def test_get_analysis_dataset_keys(self, preprocessing_fn,
                                       full_dataset_keys, cached_dataset_keys,
                                       expected_dataset_keys,
                                       use_tf_compat_v1):
        """Checks which dataset keys still need analysis given a partial cache.

        The test runs in one of two modes:

        - `use_tf_compat_v1` is True: the raw feature spec is passed through
          unchanged.
        - `use_tf_compat_v1` is False: `test_case.skip_if_not_tf2` is called
          first. The feature spec is then converted with
          `impl_helper.get_type_specs_from_feature_specs`.

        The flag is forwarded as `force_tf_compat_v1`. Each raw entry of
        `full_dataset_keys` is wrapped in `analyzer_cache.DatasetKey` before
        the call. Every key in `cached_dataset_keys` gets a cache entry stored
        under one fixed entry key. `make_cache_entry_key` is patched to return
        that same key, which forces those datasets to be cache hits.

        NOTE(review): `cached_dataset_keys` and `expected_dataset_keys` are
        used as given, not wrapped in `DatasetKey`. Presumably the
        parameterization supplies values that compare equal to the keys the
        builder uses; confirm.
        """
        if not use_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
        full_dataset_keys = [
            analysis_graph_builder.analyzer_cache.DatasetKey(raw_key)
            for raw_key in full_dataset_keys
        ]
        fake_entry_key = b'M'
        # Build a separate inner dict for each cached dataset key.
        cache_by_dataset = dict(
            (dataset_key, {fake_entry_key: 'C'})
            for dataset_key in cached_dataset_keys)
        spec_by_feature = {'x': tf.io.FixedLenFeature([], tf.float32)}
        if use_tf_compat_v1:
            input_specs = spec_by_feature
        else:
            input_specs = impl_helper.get_type_specs_from_feature_specs(
                spec_by_feature)
        patch_target = ('tensorflow_transform.beam.analysis_graph_builder.'
                        'analyzer_cache.make_cache_entry_key')
        with mock.patch(patch_target, return_value=fake_entry_key):
            dataset_keys = analysis_graph_builder.get_analysis_dataset_keys(
                preprocessing_fn,
                input_specs,
                full_dataset_keys,
                cache_by_dataset,
                force_tf_compat_v1=use_tf_compat_v1)

        # `_ANALYSIS_GRAPH` is presumably populated by the call above, so it is
        # only rendered afterwards.
        graph_dot = nodes.get_dot_graph(
            [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
        self.WriteRenderedDotFile(graph_dot)
        self.assertCountEqual(expected_dataset_keys, dataset_keys)