Example #1
    def test_do_with_cache(self):
        # First run that creates cache.
        self._exec_properties['module_file'] = self._module_file
        metrics = self._run_pipeline_get_metrics()

        # The test data has 9909 instances in the train dataset, and 5091 instances
        # in the eval dataset. Since the analysis dataset (train) is read twice when
        # no input cache is present (once for analysis and once for transform), the
        # expected value of the num_instances counter is: 9909 * 2 + 5091 = 24909.
        self.assertMetricsCounterEqual(metrics, 'num_instances', 24909)
        self._verify_transform_outputs(store_cache=True)

        # Second run from cache.
        self._output_data_dir = self._get_output_data_dir('2nd_run')
        analyzer_cache_artifact = standard_artifacts.TransformCache()
        analyzer_cache_artifact.uri = self._updated_analyzer_cache_artifact.uri

        self._make_base_do_params(self._source_data_dir, self._output_data_dir)

        self._input_dict[executor.ANALYZER_CACHE_KEY] = [
            analyzer_cache_artifact
        ]

        self._exec_properties['module_file'] = self._module_file
        metrics = self._run_pipeline_get_metrics()

        # Since input cache should now cover all analysis (train) paths, the train
        # and eval sets are each read exactly once for transform. Thus, the
        # expected value of the num_instances counter is: 9909 + 5091 = 15000.
        self.assertMetricsCounterEqual(metrics, 'num_instances', 15000)
        self._verify_transform_outputs(store_cache=True)
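
The read-count arithmetic behind these assertions generalizes; below is a minimal, self-contained sketch (the 9909/5091 split sizes come from the test comments above, and the helper name is illustrative, not part of TFX):

def expected_num_instances(num_train, num_eval, input_cache_covers_analysis):
    # Without an input cache, the train (analysis) split is read twice:
    # once for analysis and once for transform.
    extra_analysis_reads = 0 if input_cache_covers_analysis else num_train
    return num_train + num_eval + extra_analysis_reads

assert expected_num_instances(9909, 5091, False) == 24909  # first run
assert expected_num_instances(9909, 5091, True) == 15000   # cached run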
Example #2
    def test_construct_with_cache_disabled_but_input_cache(self):
        with self.assertRaises(ValueError):
            _ = component.Transform(examples=self.examples,
                                    schema=self.schema,
                                    preprocessing_fn='my_preprocessing_fn',
                                    disable_analyzer_cache=True,
                                    analyzer_cache=channel_utils.as_channel(
                                        [standard_artifacts.TransformCache()]))
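
For contrast, the construction below should pass validation, since disabling the analyzer cache is only an error when an 'analyzer_cache' channel is also supplied (a sketch reusing the fixture names from the test above):

# Disabling the cache is valid as long as no analyzer_cache channel is passed.
_ = component.Transform(examples=self.examples,
                        schema=self.schema,
                        preprocessing_fn='my_preprocessing_fn',
                        disable_analyzer_cache=True)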
Example #3
    def _make_base_do_params(self, source_data_dir, output_data_dir):
        # Create input dict.
        example1 = standard_artifacts.Examples()
        example1.uri = self._ARTIFACT1_URI
        example1.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        example2 = copy.deepcopy(example1)
        example2.uri = self._ARTIFACT2_URI

        self._example_artifacts = [example1, example2]

        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        self._input_dict = {
            standard_component_specs.EXAMPLES_KEY: self._example_artifacts[:1],
            standard_component_specs.SCHEMA_KEY: [schema_artifact],
        }

        # Create output dict.
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_graph')
        transformed1 = standard_artifacts.Examples()
        transformed1.uri = os.path.join(output_data_dir,
                                        'transformed_examples', '0')
        transformed2 = standard_artifacts.Examples()
        transformed2.uri = os.path.join(output_data_dir,
                                        'transformed_examples', '1')

        self._transformed_example_artifacts = [transformed1, transformed2]

        temp_path_output = _TempPath()
        temp_path_output.uri = tempfile.mkdtemp()
        self._updated_analyzer_cache_artifact = (
            standard_artifacts.TransformCache())
        self._updated_analyzer_cache_artifact.uri = os.path.join(
            output_data_dir, 'CACHE')

        self._output_dict = {
            standard_component_specs.TRANSFORM_GRAPH_KEY:
            [self._transformed_output],
            standard_component_specs.TRANSFORMED_EXAMPLES_KEY:
            self._transformed_example_artifacts[:1],
            executor.TEMP_PATH_KEY: [temp_path_output],
            standard_component_specs.UPDATED_ANALYZER_CACHE_KEY:
            [self._updated_analyzer_cache_artifact],
        }

        # Create exec properties skeleton.
        self._exec_properties = {}
Example #4
    def test_do_with_cache(self, provide_first_input_cache):
        # First run that creates cache.
        self._exec_properties[
            standard_component_specs.MODULE_FILE_KEY] = self._module_file
        if provide_first_input_cache:
            self._input_dict[standard_component_specs.ANALYZER_CACHE_KEY] = []
        metrics = self._run_pipeline_get_metrics()

        # The test data has 9909 instances in the train dataset, and 5091 instances
        # in the eval dataset. Since the analysis dataset (train) is read twice when
        # no input cache is present (once for analysis and once for transform), the
        # expected value of the num_instances counter is: 9909 * 2 + 5091 = 24909.
        self.assertMetricsCounterEqual(metrics, 'num_instances', 24909,
                                       ['tfx.Transform'])

        # Additionally, we have 24909 instances due to generating statistics.
        self.assertMetricsCounterEqual(metrics, 'num_instances', 24909,
                                       ['tfx.DataValidation'])
        self._verify_transform_outputs(store_cache=True)

        # Second run from cache.
        self._output_data_dir = self._get_output_data_dir('2nd_run')
        analyzer_cache_artifact = standard_artifacts.TransformCache()
        analyzer_cache_artifact.uri = self._updated_analyzer_cache_artifact.uri

        self._make_base_do_params(self._SOURCE_DATA_DIR, self._output_data_dir)

        self._input_dict[standard_component_specs.ANALYZER_CACHE_KEY] = [
            analyzer_cache_artifact
        ]

        self._exec_properties[
            standard_component_specs.MODULE_FILE_KEY] = self._module_file
        metrics = self._run_pipeline_get_metrics()

        # Since input cache should now cover all analysis (train) paths, the train
        # and eval sets are each read exactly once for transform. Thus, the
        # expected value of the num_instances counter is: 9909 + 5091 = 15000.
        self.assertMetricsCounterEqual(metrics, 'num_instances', 15000,
                                       ['tfx.Transform'])

        # Additionally, we have 24909 instances due to generating statistics.
        self.assertMetricsCounterEqual(metrics, 'num_instances', 24909,
                                       ['tfx.DataValidation'])
        self._verify_transform_outputs(store_cache=True)
Example #5
  def testGetCachePathEntry(self):
    # Empty case.
    self.assertEmpty(
        executor_utils.GetCachePathEntry(
            standard_component_specs.ANALYZER_CACHE_KEY, {}))

    cache_artifact = standard_artifacts.TransformCache()
    cache_artifact.uri = '/dummy'
    # Input case.
    result = executor_utils.GetCachePathEntry(
        standard_component_specs.ANALYZER_CACHE_KEY,
        {standard_component_specs.ANALYZER_CACHE_KEY: [cache_artifact]})
    self.assertEqual({labels.CACHE_INPUT_PATH_LABEL: '/dummy'}, result)

    # Output case.
    result = executor_utils.GetCachePathEntry(
        standard_component_specs.UPDATED_ANALYZER_CACHE_KEY,
        {standard_component_specs.UPDATED_ANALYZER_CACHE_KEY: [cache_artifact]})
    self.assertEqual({labels.CACHE_OUTPUT_PATH_LABEL: '/dummy'}, result)
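
For illustration, the contract exercised above could be sketched as below. This is a hypothetical reimplementation of the tested behavior, not the actual executor_utils source:

# Hypothetical sketch of GetCachePathEntry's tested contract (not TFX source).
def get_cache_path_entry(key, artifact_dict):
    artifacts = artifact_dict.get(key)
    if not artifacts:
        return {}  # Empty case: no artifact for the key yields an empty dict.
    label = (labels.CACHE_INPUT_PATH_LABEL
             if key == standard_component_specs.ANALYZER_CACHE_KEY
             else labels.CACHE_OUTPUT_PATH_LABEL)
    return {label: artifacts[0].uri}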
Example #6
    def _make_base_do_params(self, source_data_dir, output_data_dir):
        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        self._input_dict = {
            executor.EXAMPLES_KEY: [examples],
            executor.SCHEMA_KEY: [schema_artifact],
        }

        # Create output dict.
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_graph')
        self._transformed_examples = standard_artifacts.Examples()
        self._transformed_examples.uri = os.path.join(output_data_dir,
                                                      'transformed_examples')
        temp_path_output = _TempPath()
        temp_path_output.uri = tempfile.mkdtemp()
        self._updated_analyzer_cache_artifact = (
            standard_artifacts.TransformCache())
        self._updated_analyzer_cache_artifact.uri = os.path.join(
            output_data_dir, 'CACHE')

        self._output_dict = {
            executor.TRANSFORM_GRAPH_KEY: [self._transformed_output],
            executor.TRANSFORMED_EXAMPLES_KEY: [self._transformed_examples],
            executor.TEMP_PATH_KEY: [temp_path_output],
            executor.UPDATED_ANALYZER_CACHE_KEY:
            [self._updated_analyzer_cache_artifact],
        }

        # Create exec properties skeleton.
        self._exec_properties = {}
Example #7
  def __init__(
      self,
      examples: types.Channel = None,
      schema: types.Channel = None,
      module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
      preprocessing_fn: Optional[Union[Text,
                                       data_types.RuntimeParameter]] = None,
      transform_graph: Optional[types.Channel] = None,
      transformed_examples: Optional[types.Channel] = None,
      input_data: Optional[types.Channel] = None,
      analyzer_cache: Optional[types.Channel] = None,
      instance_name: Optional[Text] = None,
      materialize: bool = True,
      disable_analyzer_cache: bool = False,
      custom_config: Optional[Dict[Text, Any]] = None):
    """Construct a Transform component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples` (required).
        This should contain the two splits 'train' and 'eval'.
      schema: A Channel of type `standard_artifacts.Schema`. This should
        contain a single schema artifact.
      module_file: The file path to a Python module file, from which the
        'preprocessing_fn' function will be loaded.
        Exactly one of 'module_file' or 'preprocessing_fn' must be supplied.

        The function needs to have the following signature:
        ```
        def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
          ...
        ```
        where the values of the input and returned Dicts are either tf.Tensor
        or tf.SparseTensor.

        If additional inputs are needed for preprocessing_fn, they can be passed
        in custom_config:

        ```
        def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
                             Dict[Text, Any]) -> Dict[Text, Any]:
          ...
        ```
      preprocessing_fn: The path to a Python function that implements a
        'preprocessing_fn'. See 'module_file' for the expected signature of
        the function. Exactly one of 'module_file' or 'preprocessing_fn' must
        be supplied.
      transform_graph: Optional output 'TransformPath' channel for the output
        of 'tf.Transform', which includes an exported TensorFlow graph
        suitable for both training and serving.
      transformed_examples: Optional output 'ExamplesPath' channel for
        materialized transformed examples, which includes both 'train' and
        'eval' splits.
      input_data: Backwards compatibility alias for the 'examples' argument.
      analyzer_cache: Optional input 'TransformCache' channel containing
        cached information from previous Transform runs. When provided,
        Transform will try to use the cached calculation if possible.
      instance_name: Optional unique instance name. Necessary iff multiple
        transform components are declared in the same pipeline.
      materialize: If True, write transformed examples as an output. If False,
        `transformed_examples` must not be provided.
      disable_analyzer_cache: If False, Transform will use input cache if
        provided and write cache output. If True, `analyzer_cache` must not be
        provided.
      custom_config: A dict which contains additional parameters that will be
        passed to preprocessing_fn.

    Raises:
      ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
        is supplied.
    """
    if input_data:
      absl.logging.warning(
          'The "input_data" argument to the Transform component has '
          'been renamed to "examples" and is deprecated. Please update your '
          'usage as support for this argument will be removed soon.')
      examples = input_data
    if bool(module_file) == bool(preprocessing_fn):
      raise ValueError(
          "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
      )

    transform_graph = transform_graph or types.Channel(
        type=standard_artifacts.TransformGraph,
        artifacts=[standard_artifacts.TransformGraph()])

    if materialize and transformed_examples is None:
      transformed_examples = types.Channel(
          type=standard_artifacts.Examples,
          # TODO(b/161548528): remove the hardcode artifact.
          artifacts=[standard_artifacts.Examples()],
          matching_channel_name='examples')
    elif not materialize and transformed_examples is not None:
      raise ValueError(
          'Must not specify transformed_examples when materialize is False.')

    if disable_analyzer_cache:
      updated_analyzer_cache = None
      if analyzer_cache:
        raise ValueError(
            '`analyzer_cache` is set when disable_analyzer_cache is True.')
    else:
      updated_analyzer_cache = types.Channel(
          type=standard_artifacts.TransformCache,
          artifacts=[standard_artifacts.TransformCache()])

    spec = TransformSpec(
        examples=examples,
        schema=schema,
        module_file=module_file,
        preprocessing_fn=preprocessing_fn,
        transform_graph=transform_graph,
        transformed_examples=transformed_examples,
        analyzer_cache=analyzer_cache,
        updated_analyzer_cache=updated_analyzer_cache,
        custom_config=json.dumps(custom_config))
    super(Transform, self).__init__(spec=spec, instance_name=instance_name)
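
Putting the docstring together, a minimal pipeline-side usage sketch follows; the module path, upstream component handles, and cache channel are hypothetical placeholders, not names from the snippets above:

# 'my_module.py', example_gen, schema_gen, and previous_run_cache are
# illustrative placeholders for this sketch.
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='my_module.py',         # must define preprocessing_fn
    analyzer_cache=previous_run_cache,  # optional TransformCache channel
    materialize=True)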