Example #1
    def setUp(self):
        super(ExecutorTest, self).setUp()

        # Create input_dict.
        self._input_data_dir = os.path.join(os.path.dirname(__file__),
                                            'testdata')
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(self._input_data_dir, 'example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(self._input_data_dir, 'schema_gen')
        self._input_dict = {
            standard_component_specs.EXAMPLES_KEY: [examples],
            standard_component_specs.SCHEMA_KEY: [schema_artifact],
        }

        # Create output_dict.
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                           tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir)),
            self._testMethodName)
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_output')
        self._transformed_examples = standard_artifacts.Examples()
        self._transformed_examples.uri = output_data_dir
        self._transformed_examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        temp_path_output = _TempPath()
        temp_path_output.uri = tempfile.mkdtemp()
        self._output_dict = {
            standard_component_specs.TRANSFORM_GRAPH_KEY:
            [self._transformed_output],
            standard_component_specs.TRANSFORMED_EXAMPLES_KEY:
            [self._transformed_examples],
            tfx_executor.TEMP_PATH_KEY: [temp_path_output],
        }

        # Create exec properties.
        self._exec_properties = {
            'custom_config':
            json.dumps({'problem_statement_path': '/some/fake/path'})
        }
Example #2
 def testGetFromSplitsMultipleArtifacts(self):
     """Test split retrieval utility on a multiple list of split Artifacts."""
     artifacts = [
         standard_artifacts.Examples(),
         standard_artifacts.Examples()
     ]
     artifacts[0].uri = '/tmp1'
     artifacts[0].split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     artifacts[1].uri = '/tmp2'
     artifacts[1].split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     # When creating new splits, use 'Split-<split_name>' format.
     self.assertEqual(['/tmp1/Split-train', '/tmp2/Split-train'],
                      artifact_utils.get_split_uris(artifacts, 'train'))
     self.assertEqual(['/tmp1/Split-eval', '/tmp2/Split-eval'],
                      artifact_utils.get_split_uris(artifacts, 'eval'))
     # When reading published (LIVE) artifacts that carry no TFX version
     # property, the legacy '<split_name>' directory layout is expected.
     artifacts[0].mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE
     artifacts[1].mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE
     self.assertEqual(['/tmp1/train', '/tmp2/train'],
                      artifact_utils.get_split_uris(artifacts, 'train'))
     self.assertEqual(['/tmp1/eval', '/tmp2/eval'],
                      artifact_utils.get_split_uris(artifacts, 'eval'))
     # When reading artifacts with old version.
     artifacts[0].set_string_custom_property(
         artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY, '0.1')
     artifacts[1].set_string_custom_property(
         artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY, '0.1')
     self.assertEqual(['/tmp1/train', '/tmp2/train'],
                      artifact_utils.get_split_uris(artifacts, 'train'))
     self.assertEqual(['/tmp1/eval', '/tmp2/eval'],
                      artifact_utils.get_split_uris(artifacts, 'eval'))
     # When reading artifacts with new version.
     artifacts[0].set_string_custom_property(
         artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY,
         artifact_utils._ARTIFACT_VERSION_FOR_SPLIT_UPDATE)
     artifacts[1].set_string_custom_property(
         artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY,
         artifact_utils._ARTIFACT_VERSION_FOR_SPLIT_UPDATE)
     self.assertEqual(['/tmp1/Split-train', '/tmp2/Split-train'],
                      artifact_utils.get_split_uris(artifacts, 'train'))
     self.assertEqual(['/tmp1/Split-eval', '/tmp2/Split-eval'],
                      artifact_utils.get_split_uris(artifacts, 'eval'))
Example #3
  def testDoLegacySingleEvalSavedModelWFairness(self, exec_properties):
    source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create input dict.
    examples = standard_artifacts.Examples()
    examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
    examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
    model = standard_artifacts.Model()
    model.uri = os.path.join(source_data_dir, 'trainer/current')
    input_dict = {
        constants.EXAMPLES_KEY: [examples],
        constants.MODEL_KEY: [model],
    }

    # Create output dict.
    eval_output = standard_artifacts.ModelEvaluation()
    eval_output.uri = os.path.join(output_data_dir, 'eval_output')
    blessing_output = standard_artifacts.ModelBlessing()
    blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
    output_dict = {
        constants.EVALUATION_KEY: [eval_output],
        constants.BLESSING_KEY: [blessing_output],
    }

    try:
      # Need to import the following module so that the fairness indicator
      # post-export metric is registered.  This may raise an ImportError if the
      # currently-installed version of TFMA does not support fairness
      # indicators.
      import tensorflow_model_analysis.addons.fairness.post_export_metrics.fairness_indicators  # pylint: disable=g-import-not-at-top, unused-variable
      exec_properties['fairness_indicator_thresholds'] = [
          0.1, 0.3, 0.5, 0.7, 0.9
      ]
    except ImportError:
      logging.warning(
          'Not testing fairness indicators because a compatible TFMA version '
          'is not installed.')

    # List needs to be serialized before being passed into Do function.
    exec_properties[constants.EXAMPLE_SPLITS_KEY] = json_utils.dumps(None)

    # Run executor.
    evaluator = executor.Executor()
    evaluator.Do(input_dict, output_dict, exec_properties)

    # Check evaluator outputs.
    self.assertTrue(
        fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
    self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'metrics')))
    self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
    self.assertFalse(
        fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
Example #4
    def __init__(
            self,
            input: types.Channel = None,  # pylint: disable=redefined-builtin
            input_config: Optional[Union[example_gen_pb2.Input,
                                         Dict[Text, Any]]] = None,
            output_config: Optional[Union[example_gen_pb2.Output,
                                          Dict[Text, Any]]] = None,
            custom_config: Optional[Union[example_gen_pb2.CustomConfig,
                                          Dict[Text, Any]]] = None,
            example_artifacts: Optional[types.Channel] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            input_base: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None):
        """Construct a FileBasedExampleGen component.

    Args:
      input: A Channel of type `standard_artifacts.ExternalArtifact`, which
        includes one artifact whose uri is an external directory containing
        the data files. _required_
      input_config: An
        [`example_gen_pb2.Input`](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
          instance, providing input configuration. If unset, the files under
          input_base will be treated as a single dataset.
      output_config: An example_gen_pb2.Output instance, providing the output
        configuration. If unset, default splits will be 'train' and
        'eval' with size 2:1.
      custom_config: An optional example_gen_pb2.CustomConfig instance,
        providing custom configuration for executor.
      example_artifacts: Channel of 'ExamplesPath' for output train and eval
        examples.
      custom_executor_spec: Optional custom executor spec overriding the default
        executor spec specified in the component attribute.
      input_base: Backwards compatibility alias for the 'input' argument.
      instance_name: Optional unique instance name. Required only if multiple
        ExampleGen components are declared in the same pipeline.  Either
        `input_base` or `input` must be present in the input arguments.
    """
        input = input or input_base
        # Configure inputs and outputs.
        input_config = input_config or utils.make_default_input_config()
        output_config = output_config or utils.make_default_output_config(
            input_config)
        if not example_artifacts:
            artifact = standard_artifacts.Examples()
            artifact.split_names = artifact_utils.encode_split_names(
                utils.generate_output_split_names(input_config, output_config))
            example_artifacts = channel_utils.as_channel([artifact])
        spec = FileBasedExampleGenSpec(input_base=input,
                                       input_config=input_config,
                                       output_config=output_config,
                                       custom_config=custom_config,
                                       examples=example_artifacts)
        super(FileBasedExampleGen,
              self).__init__(spec=spec,
                             custom_executor_spec=custom_executor_spec,
                             instance_name=instance_name)
Example #5
  def testDoValidation(self, exec_properties, blessed, has_baseline):
    source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create input dict.
    examples = standard_artifacts.Examples()
    examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
    examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
    model = standard_artifacts.Model()
    baseline_model = standard_artifacts.Model()
    model.uri = os.path.join(source_data_dir, 'trainer/current')
    baseline_model.uri = os.path.join(source_data_dir, 'trainer/previous/')
    schema = standard_artifacts.Schema()
    schema.uri = os.path.join(source_data_dir, 'schema_gen')
    input_dict = {
        constants.EXAMPLES_KEY: [examples],
        constants.MODEL_KEY: [model],
        constants.SCHEMA_KEY: [schema],
    }
    if has_baseline:
      input_dict[constants.BASELINE_MODEL_KEY] = [baseline_model]

    # Create output dict.
    eval_output = standard_artifacts.ModelEvaluation()
    eval_output.uri = os.path.join(output_data_dir, 'eval_output')
    blessing_output = standard_artifacts.ModelBlessing()
    blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
    output_dict = {
        constants.EVALUATION_KEY: [eval_output],
        constants.BLESSING_KEY: [blessing_output],
    }

    # List needs to be serialized before being passed into Do function.
    exec_properties[constants.EXAMPLE_SPLITS_KEY] = json_utils.dumps(None)

    # Run executor.
    evaluator = executor.Executor()
    evaluator.Do(input_dict, output_dict, exec_properties)

    # Check evaluator outputs.
    self.assertTrue(
        fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
    self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'metrics')))
    self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
    self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'validations')))
    if blessed:
      self.assertTrue(
          fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
    else:
      self.assertTrue(
          fileio.exists(os.path.join(blessing_output.uri, 'NOT_BLESSED')))
Example #6
  def setUp(self):
    super(ExecutorTest, self).setUp()
    self._source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    self._output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create input dict.
    examples = standard_artifacts.Examples()
    examples.uri = os.path.join(self._source_data_dir,
                                'transform/transformed_examples')
    examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
    transform_output = standard_artifacts.TransformGraph()
    transform_output.uri = os.path.join(self._source_data_dir,
                                        'transform/transform_output')
    schema = standard_artifacts.Schema()
    schema.uri = os.path.join(self._source_data_dir, 'schema_gen')
    previous_model = standard_artifacts.Model()
    previous_model.uri = os.path.join(self._source_data_dir, 'trainer/previous')

    self._input_dict = {
        executor.EXAMPLES_KEY: [examples],
        executor.TRANSFORM_GRAPH_KEY: [transform_output],
        executor.SCHEMA_KEY: [schema],
        executor.BASE_MODEL_KEY: [previous_model]
    }

    # Create output dict.
    self._model_exports = standard_artifacts.Model()
    self._model_exports.uri = os.path.join(self._output_data_dir,
                                           'model_export_path')
    self._output_dict = {executor.OUTPUT_MODEL_KEY: [self._model_exports]}

    # Create exec properties skeleton.
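    # Proto-valued properties (train_args, eval_args) are passed to the
    # executor as JSON-serialized strings.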
    self._exec_properties = {
        'train_args':
            json_format.MessageToJson(
                trainer_pb2.TrainArgs(num_steps=1000),
                preserving_proto_field_name=True),
        'eval_args':
            json_format.MessageToJson(
                trainer_pb2.EvalArgs(num_steps=500),
                preserving_proto_field_name=True),
        'warm_starting':
            False,
    }

    self._module_file = os.path.join(self._source_data_dir, 'module_file',
                                     'trainer_module.py')
    self._trainer_fn = '%s.%s' % (trainer_module.trainer_fn.__module__,
                                  trainer_module.trainer_fn.__name__)

    # Executors for test.
    self._trainer_executor = executor.Executor()
    self._generic_trainer_executor = executor.GenericExecutor()
Example #7
 def testEnableCache(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     statistics_gen_1 = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]))
     self.assertEqual(None, statistics_gen_1.enable_cache)
     statistics_gen_2 = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]), enable_cache=True)
     self.assertEqual(True, statistics_gen_2.enable_cache)
Example #8
    def testDoWithTwoSchemas(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        tf.io.gfile.makedirs(output_data_dir)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(source_data_dir, 'schema_gen')

        input_dict = {
            executor.EXAMPLES_KEY: [examples],
            executor.SCHEMA_KEY: [schema]
        }

        exec_properties = {
            executor.STATS_OPTIONS_JSON_KEY:
            tfdv.StatsOptions(label_feature='company',
                              schema=schema_pb2.Schema()).to_json(),
        }

        # Create output dict.
        stats = standard_artifacts.ExampleStatistics()
        stats.uri = output_data_dir
        stats.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        output_dict = {
            executor.STATISTICS_KEY: [stats],
        }

        # Run executor.
        stats_gen_executor = executor.Executor()
        with self.assertRaises(ValueError):
            stats_gen_executor.Do(input_dict,
                                  output_dict,
                                  exec_properties=exec_properties)
Example #9
 def testConstruct(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['eval'])
     example_validator = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
     )
     self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                      example_validator.outputs['anomalies'].type_name)
Example #10
 def testConstruct(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['train'])
     schema_gen = component.SchemaGen(
         statistics=channel_utils.as_channel([statistics_artifact]))
     self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                      schema_gen.outputs['schema'].type_name)
     self.assertFalse(
         schema_gen.spec.exec_properties['infer_feature_shape'])
Example #11
    def testEvaluation(self, exec_properties, model_agnostic=False):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        baseline_model = standard_artifacts.Model()
        baseline_model.uri = os.path.join(source_data_dir, 'trainer/previous/')
        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(source_data_dir, 'schema_gen')
        input_dict = {
            EXAMPLES_KEY: [examples],
            SCHEMA_KEY: [schema],
        }
        if not model_agnostic:
            model = standard_artifacts.Model()
            model.uri = os.path.join(source_data_dir, 'trainer/current')
            input_dict[MODEL_KEY] = [model]

        # Create output dict.
        eval_output = standard_artifacts.ModelEvaluation()
        eval_output.uri = os.path.join(output_data_dir, 'eval_output')
        blessing_output = standard_artifacts.ModelBlessing()
        blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
        output_dict = {
            EVALUATION_KEY: [eval_output],
            BLESSING_KEY: [blessing_output],
        }

        # Test multiple splits.
        exec_properties[EXAMPLE_SPLITS_KEY] = json_utils.dumps(
            ['train', 'eval'])

        if MODULE_FILE_KEY in exec_properties:
            exec_properties[MODULE_FILE_KEY] = os.path.join(
                source_data_dir, 'module_file', 'evaluator_module.py')

        # Run executor.
        evaluator = executor.Executor()
        evaluator.Do(input_dict, output_dict, exec_properties)

        # Check evaluator outputs.
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri,
                                                   'metrics')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
        self.assertFalse(
            fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
Example #12
  def __init__(self,
               statistics: types.Channel = None,
               schema: types.Channel = None,
               exclude_splits: Optional[List[Text]] = None,
               output: Optional[types.Channel] = None,
               stats: Optional[types.Channel] = None,
               instance_name: Optional[Text] = None):
    """Construct an ExampleValidator component.

    Args:
      statistics: A Channel of type `standard_artifacts.ExampleStatistics`. This
        should contain at least the 'eval' split. Other splits are currently
        ignored.
      schema: A Channel of type `standard_artifacts.Schema`. _required_
      exclude_splits: Names of splits that the example validator should not
        validate. Default behavior (when exclude_splits is set to None)
        is excluding no splits.
      output: Output channel of type `standard_artifacts.ExampleAnomalies`.
      stats: Backwards compatibility alias for the 'statistics' argument.
      instance_name: Optional name assigned to this specific instance of
        ExampleValidator. Required only if multiple ExampleValidator components
        are declared in the same pipeline.  Either `stats` or `statistics` must
        be present in the arguments.
    """
    if stats:
      logging.warning(
          'The "stats" argument to the ExampleValidator component has '
          'been renamed to "statistics" and is deprecated. Please update your '
          'usage as support for this argument will be removed soon.')
      statistics = stats
    if exclude_splits is None:
      exclude_splits = []
      logging.info('Excluding no splits because exclude_splits is not set.')
    anomalies = output
    if not anomalies:
      anomalies_artifact = standard_artifacts.ExampleAnomalies()
      statistics_split_names = artifact_utils.decode_split_names(
          artifact_utils.get_single_instance(list(
              statistics.get())).split_names)
      split_names = [
          split for split in statistics_split_names
          if split not in exclude_splits
      ]
      anomalies_artifact.split_names = artifact_utils.encode_split_names(
          split_names)
      anomalies = types.Channel(
          type=standard_artifacts.ExampleAnomalies,
          artifacts=[anomalies_artifact])
    spec = ExampleValidatorSpec(
        statistics=statistics,
        schema=schema,
        exclude_splits=json_utils.dumps(exclude_splits),
        anomalies=anomalies)
    super(ExampleValidator, self).__init__(
        spec=spec, instance_name=instance_name)
Example #13
    def _import_artifacts(self, source_uri: List[Text], reimport: bool,
                          destination_channel: types.Channel,
                          split_names: List[Text]) -> List[types.Artifact]:
        """Imports external resource in MLMD."""
        results = []
        for uri, s in zip(source_uri, split_names):
            absl.logging.info('Processing source uri: %s, split: %s' %
                              (uri, s or 'NO_SPLIT'))

            result = destination_channel.type()

            # TODO(ccy): refactor importer to treat split name just like any other
            # property.
            unfiltered_previous_artifacts = self._metadata_handler.get_artifacts_by_uri(
                uri)
            # Filter by split name.
            desired_split_names = artifact_utils.encode_split_names([s or ''])
            previous_artifacts = []
            for previous_artifact in unfiltered_previous_artifacts:
                # TODO(ccy): refactor importer to treat split name just like any other
                # property.
                if result.PROPERTIES and SPLIT_KEY in result.PROPERTIES:
                    # Consider the previous artifact only if the split_names match.
                    # Use a distinct name to avoid shadowing the 'split_names' argument.
                    previous_split_names = previous_artifact.properties.get(
                        'split_names', None)
                    if (previous_split_names and
                            previous_split_names.string_value == desired_split_names):
                        previous_artifacts.append(previous_artifact)
                else:
                    # Unconditionally add the previous artifact for consideration.
                    previous_artifacts.append(previous_artifact)

            # TODO(ccy): refactor importer to treat split name just like any other
            # property.
            if SPLIT_KEY in result.artifact_type.properties:
                result.split_names = desired_split_names
            result.uri = uri

            # If any registered artifact with the same uri also has the same
            # fingerprint and user does not ask for re-import, just reuse the latest.
            # Otherwise, register the external resource into MLMD using the type info
            # in the destination channel.
            if bool(previous_artifacts) and not reimport:
                absl.logging.info('Reusing existing artifact')
                result.set_mlmd_artifact(
                    max(previous_artifacts, key=lambda m: m.id))
            else:
                [registered_artifact
                 ] = self._metadata_handler.publish_artifacts([result])
                absl.logging.info('Registered new artifact: %s' %
                                  registered_artifact)
                result.set_mlmd_artifact(registered_artifact)

            results.append(result)

        return results
Example #14
  def testDo(self, mock_client):
    # Mock query result schema for _BigQueryConverter.
    mock_client.return_value.query.return_value.result.return_value.schema = self._schema

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {'examples': [examples]}

    # Create exec properties.
    exec_properties = {
        'input_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, b, f, s FROM `fake`'),
                ])),
        'output_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ])))
    }

    # Run executor.
    big_query_example_gen = executor.Executor(
        base_beam_executor.BaseBeamExecutor.Context(
            beam_pipeline_args=['--project=test-project']))
    big_query_example_gen.Do({}, output_dict, exec_properties)

    mock_client.assert_called_with(project='test-project')

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        examples.split_names)

    # Check BigQuery example gen outputs.
    train_output_file = os.path.join(examples.uri, 'Split-train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(examples.uri, 'Split-eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(fileio.exists(train_output_file))
    self.assertTrue(fileio.exists(eval_output_file))
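    # The output config hashes examples into 2 train buckets vs. 1 eval bucket,
    # so the train split should contain roughly twice as much data.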
    self.assertGreater(
        fileio.open(train_output_file).size(),
        fileio.open(eval_output_file).size())
Example #15
 def testEnableCache(self):
   statistics_artifact = standard_artifacts.ExampleStatistics()
   statistics_artifact.split_names = artifact_utils.encode_split_names(
       ['train'])
   schema_gen_1 = component.SchemaGen(
       statistics=channel_utils.as_channel([statistics_artifact]))
   schema_gen_2 = component.SchemaGen(
       statistics=channel_utils.as_channel([statistics_artifact]),
       enable_cache=True)
   self.assertEqual(None, schema_gen_1.enable_cache)
   self.assertEqual(True, schema_gen_2.enable_cache)
Example #16
 def testConstruct(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     exclude_splits = ['eval']
     statistics_gen = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]),
         exclude_splits=exclude_splits)
     self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
                      statistics_gen.outputs['statistics'].type_name)
     self.assertEqual(statistics_gen.spec.exec_properties['exclude_splits'],
                      '["eval"]')
Example #17
    def __init__(self,
                 input_config: Union[example_gen_pb2.Input, Dict[Text, Any]],
                 output_config: Optional[Union[example_gen_pb2.Output,
                                               Dict[Text, Any]]] = None,
                 custom_config: Optional[Union[example_gen_pb2.CustomConfig,
                                               Dict[Text, Any]]] = None,
                 example_artifacts: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None,
                 enable_cache: Optional[bool] = None):
        """Construct an QueryBasedExampleGen component.

    Args:
      input_config: An
        [example_gen_pb2.Input](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
          instance, providing input configuration. If any field is provided as a
        RuntimeParameter, input_config should be constructed as a dict with the
        same field names as Input proto message. _required_
      output_config: An
        [example_gen_pb2.Output](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
          instance, providing output configuration. If unset, the default splits
        will be labeled as 'train' and 'eval' with a distribution ratio of 2:1.
        If any field is provided as a RuntimeParameter, output_config should be
        constructed as a dict with the same field names as Output proto message.
      custom_config: An
        [example_gen_pb2.CustomConfig](https://github.com/tensorflow/tfx/blob/master/tfx/proto/example_gen.proto)
          instance, providing custom configuration for ExampleGen. If any field
          is provided as a RuntimeParameter, custom_config should be
          constructed as a dict.
      example_artifacts: Channel of `standard_artifacts.Examples` for output
        train and eval examples.
      instance_name: Optional unique instance name. Required only if multiple
        ExampleGen components are declared in the same pipeline.
      enable_cache: Optional boolean to indicate if cache is enabled for the
        QueryBasedExampleGen component. If not specified, defaults to the value
        specified for pipeline's enable_cache parameter.
    """
        # Configure outputs.
        output_config = output_config or utils.make_default_output_config(
            input_config)
        if not example_artifacts:
            artifact = standard_artifacts.Examples()
            artifact.split_names = artifact_utils.encode_split_names(
                utils.generate_output_split_names(input_config, output_config))
            example_artifacts = channel_utils.as_channel([artifact])
        spec = QueryBasedExampleGenSpec(input_config=input_config,
                                        output_config=output_config,
                                        custom_config=custom_config,
                                        examples=example_artifacts)
        super(_QueryBasedExampleGen,
              self).__init__(spec=spec,
                             instance_name=instance_name,
                             enable_cache=enable_cache)
Example #18
  def testDo(self, mock_client):
    # Mock query result schema for _BigQueryConverter.
    mock_client.return_value.query.return_value.result.return_value.schema = self._schema

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {'examples': [examples]}

    # Create exec properties.
    exec_properties = {
        'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, b, f, s FROM `fake`'),
                ]),
                preserving_proto_field_name=True),
        'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ])),
                preserving_proto_field_name=True)
    }

    # Run executor.
    big_query_example_gen = executor.Executor()
    big_query_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        examples.split_names)

    # Check BigQuery example gen outputs.
    train_output_file = os.path.join(examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(tf.io.gfile.exists(train_output_file))
    self.assertTrue(tf.io.gfile.exists(eval_output_file))
    self.assertGreater(
        tf.io.gfile.GFile(train_output_file).size(),
        tf.io.gfile.GFile(eval_output_file).size())
Example #19
 def testConstruct(self):
   statistics_artifact = standard_artifacts.ExampleStatistics()
   statistics_artifact.split_names = artifact_utils.encode_split_names(
       ['train', 'eval'])
   exclude_splits = ['eval']
   example_validator = component.ExampleValidator(
       statistics=channel_utils.as_channel([statistics_artifact]),
       schema=channel_utils.as_channel([standard_artifacts.Schema()]),
       exclude_splits=exclude_splits)
   self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                    example_validator.outputs['anomalies'].type_name)
   self.assertEqual(example_validator.spec.exec_properties['exclude_splits'],
                    '["eval"]')
Example #20
 def testGetSplitUriDeprecated(self):
     with mock.patch.object(tf_logging, 'warning'):
         warn_mock = mock.MagicMock()
         tf_logging.warning = warn_mock
         my_artifact = artifact.Artifact('TestType')
         my_artifact.uri = '123'
         my_artifact.split_names = artifact_utils.encode_split_names(
             ['train'])
         self.assertEqual('123/train',
                          types.get_split_uri([my_artifact], 'train'))
         warn_mock.assert_called_once()
         self.assertIn('tfx.utils.types.get_split_uri has been renamed to',
                       warn_mock.call_args[0][5])
Example #21
 def testConstructWithParameter(self):
   statistics_artifact = standard_artifacts.ExampleStatistics()
   statistics_artifact.split_names = artifact_utils.encode_split_names(
       ['train'])
   infer_shape = data_types.RuntimeParameter(name='infer-shape', ptype=bool)
   schema_gen = component.SchemaGen(
       statistics=channel_utils.as_channel([statistics_artifact]),
       infer_feature_shape=infer_shape)
   self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                    schema_gen.outputs['schema'].type_name)
   self.assertJsonEqual(
       str(schema_gen.spec.exec_properties['infer_feature_shape']),
       str(infer_shape))
Example #22
    def _make_base_do_params(self, source_data_dir, output_data_dir):
        # Create input dict.
        example1 = standard_artifacts.Examples()
        example1.uri = self._ARTIFACT1_URI
        example1.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        example2 = copy.deepcopy(example1)
        example2.uri = self._ARTIFACT2_URI

        self._example_artifacts = [example1, example2]

        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        self._input_dict = {
            standard_component_specs.EXAMPLES_KEY: self._example_artifacts[:1],
            standard_component_specs.SCHEMA_KEY: [schema_artifact],
        }

        # Create output dict.
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_graph')
        transformed1 = standard_artifacts.Examples()
        transformed1.uri = os.path.join(output_data_dir,
                                        'transformed_examples', '0')
        transformed2 = standard_artifacts.Examples()
        transformed2.uri = os.path.join(output_data_dir,
                                        'transformed_examples', '1')

        self._transformed_example_artifacts = [transformed1, transformed2]

        temp_path_output = _TempPath()
        temp_path_output.uri = tempfile.mkdtemp()
        self._updated_analyzer_cache_artifact = standard_artifacts.TransformCache(
        )
        self._updated_analyzer_cache_artifact.uri = os.path.join(
            self._output_data_dir, 'CACHE')

        self._output_dict = {
            standard_component_specs.TRANSFORM_GRAPH_KEY:
            [self._transformed_output],
            standard_component_specs.TRANSFORMED_EXAMPLES_KEY:
            self._transformed_example_artifacts[:1],
            executor.TEMP_PATH_KEY: [temp_path_output],
            standard_component_specs.UPDATED_ANALYZER_CACHE_KEY:
            [self._updated_analyzer_cache_artifact],
        }

        # Create exec properties skeleton.
        self._exec_properties = {}
Example #23
 def setUp(self):
     super(PlaceholderUtilsTest, self).setUp()
     examples = [standard_artifacts.Examples()]
     examples[0].uri = "/tmp"
     examples[0].split_names = artifact_utils.encode_split_names(
         ["train", "eval"])
     self._serving_spec = infra_validator_pb2.ServingSpec()
     self._serving_spec.tensorflow_serving.tags.extend(
         ["latest", "1.15.0-gpu"])
     self._resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={
                 "model": [standard_artifacts.Model()],
                 "examples": examples,
             },
             output_dict={"blessing": [standard_artifacts.ModelBlessing()]},
             exec_properties={
                 "proto_property":
                 json_format.MessageToJson(message=self._serving_spec,
                                           sort_keys=True,
                                           preserving_proto_field_name=True,
                                           indent=0)
             },
             execution_output_uri="test_executor_output_uri",
             stateful_working_dir="test_stateful_working_dir",
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=executable_spec_pb2.PythonClassExecutableSpec(
             class_path="test_class_path"),
     )
     # Resolution context to simulate missing optional values.
     self._none_resolution_context = placeholder_utils.ResolutionContext(
         exec_info=data_types.ExecutionInfo(
             input_dict={
                 "model": [],
                 "examples": [],
             },
             output_dict={"blessing": []},
             exec_properties={},
             pipeline_node=pipeline_pb2.PipelineNode(
                 node_info=pipeline_pb2.NodeInfo(
                     type=metadata_store_pb2.ExecutionType(
                         name="infra_validator"))),
             pipeline_info=pipeline_pb2.PipelineInfo(
                 id="test_pipeline_id")),
         executor_spec=None,
         platform_config=None)
Example #24
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:

        source = exec_properties[StepKeys.SOURCE]
        args = exec_properties[StepKeys.ARGS]

        c = source_utils.load_source_path_class(source)
        tokenizer_step: BaseTokenizer = c(**args)

        tokenizer_location = artifact_utils.get_single_uri(
            output_dict["tokenizer"])

        split_uris, split_names, all_files = [], [], []
        for artifact in input_dict["examples"]:
            for split in artifact_utils.decode_split_names(
                    artifact.split_names):
                split_names.append(split)
                uri = os.path.join(artifact.uri, split)
                split_uris.append((split, uri))
                all_files += path_utils.list_dir(uri)

        # Set the split names on the output examples artifact.
        output_examples = artifact_utils.get_single_instance(
            output_dict["output_examples"])
        output_examples.split_names = artifact_utils.encode_split_names(
            split_names)

        if not tokenizer_step.skip_training:
            tokenizer_step.train(files=all_files)

            tokenizer_step.save(output_dir=tokenizer_location)

        with self._make_beam_pipeline() as p:
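            # For each split: read the TFRecords, parse them into tf.train.Example
            # protos, append the tokenizer output, and write the serialized
            # results to the corresponding output split URI.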
            for split, uri in split_uris:
                input_uri = io_utils.all_files_pattern(uri)

                _ = (p
                     | 'ReadData.' + split >> beam.io.ReadFromTFRecord(
                            file_pattern=input_uri)
                     | "ParseTFExFromString." + split >> beam.Map(
                            tf.train.Example.FromString)
                     | "AddTokens." + split >> beam.Map(
                            append_tf_example,
                            tokenizer_step=tokenizer_step)
                     | 'Serialize.' + split >> beam.Map(
                            lambda x: x.SerializeToString())
                     | 'WriteSplit.' + split >> WriteSplit(
                            get_split_uri(
                                output_dict["output_examples"],
                                split)))
Example #25
  def testDo(self):
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {utils.EXAMPLES_KEY: [examples]}

    # Create exec properties.
    exec_properties = {
        utils.INPUT_BASE_KEY:
            self._input_data_dir,
        utils.INPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='avro', pattern='avro/*.avro'),
                ]),
                preserving_proto_field_name=True),
        utils.OUTPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ])),
                preserving_proto_field_name=True)
    }

    # Run executor.
    avro_example_gen = avro_executor.Executor()
    avro_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        examples.split_names)

    # Check Avro example gen outputs.
    train_output_file = os.path.join(examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(fileio.exists(train_output_file))
    self.assertTrue(fileio.exists(eval_output_file))
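    # With 2 train hash buckets vs. 1 for eval, the train output file is
    # expected to be larger than the eval output file.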
    self.assertGreater(
        fileio.open(train_output_file).size(),
        fileio.open(eval_output_file).size())
Example #26
  def setUp(self):
    super(ExecutorTest, self).setUp()

    source_data_dir = os.path.dirname(os.path.dirname(__file__))
    input_data_dir = os.path.join(source_data_dir, 'testdata')

    statistics = standard_artifacts.ExampleStatistics()
    statistics.uri = os.path.join(input_data_dir,
                                  'StatisticsGen.train_mockdata_1',
                                  'statistics', '5')
    statistics.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    transformed_examples = standard_artifacts.Examples()
    transformed_examples.uri = os.path.join(input_data_dir,
                                            'Transform.train_mockdata_1',
                                            'transformed_examples', '10')
    transformed_examples.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    self._input_dict = {
        executor.EXAMPLES_KEY: [transformed_examples],
        executor.STATISTICS_KEY: [statistics],
    }

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                       tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir)),
        self._testMethodName)
    self._metafeatures = artifacts.MetaFeatures()
    self._metafeatures.uri = output_data_dir
    self._output_dict = {
        executor.METAFEATURES_KEY: [self._metafeatures],
    }

    self._exec_properties = {
        'custom_config': {
            'problem_statement_path': '/some/fake/path'
        }
    }
Example #27
 def testConstruct(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     examples.span = 1
     binder = binder_component.DataViewBinder(
         input_examples=channel_utils.as_channel([examples]),
         data_view=channel_utils.as_channel([standard_artifacts.DataView()
                                             ]))
     output_examples = binder.outputs['output_examples']
     self.assertIsNotNone(output_examples)
     output_examples = output_examples.get()
     self.assertLen(output_examples, 1)
     self._assert_example_artifact_equal(output_examples[0], examples)
Example #28
    def setUp(self):
        super().setUp()
        self._testdata_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        self._output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        self._context = executor.Executor.Context(
            tmp_dir=self._output_data_dir, unique_id='1')

        # Create input dict.
        e1 = standard_artifacts.Examples()
        e1.uri = os.path.join(self._testdata_dir, 'penguin', 'data')
        e1.split_names = artifact_utils.encode_split_names(['train', 'eval'])

        e2 = copy.deepcopy(e1)

        self._single_artifact = [e1]
        self._multiple_artifacts = [e1, e2]

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(self._testdata_dir, 'penguin', 'schema')

        base_model = standard_artifacts.Model()
        base_model.uri = os.path.join(self._testdata_dir, 'trainer/previous')

        self._input_dict = {
            standard_component_specs.EXAMPLES_KEY: self._single_artifact,
            standard_component_specs.SCHEMA_KEY: [schema],
            standard_component_specs.BASE_MODEL_KEY: [base_model]
        }

        # Create output dict.
        self._best_hparams = standard_artifacts.Model()
        self._best_hparams.uri = os.path.join(self._output_data_dir,
                                              'best_hparams')

        self._output_dict = {
            standard_component_specs.BEST_HYPERPARAMETERS_KEY:
            [self._best_hparams],
        }

        # Create exec properties.
        self._exec_properties = {
            standard_component_specs.TRAIN_ARGS_KEY:
            proto_utils.proto_to_json(trainer_pb2.TrainArgs(num_steps=100)),
            standard_component_specs.EVAL_ARGS_KEY:
            proto_utils.proto_to_json(trainer_pb2.EvalArgs(num_steps=50)),
        }
Example #29
 def testEnableCache(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['eval'])
     example_validator_1 = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
     )
     self.assertEqual(None, example_validator_1.enable_cache)
     example_validator_2 = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
         enable_cache=True)
     self.assertEqual(True, example_validator_2.enable_cache)
Example #30
 def testConstructWithSchemaAndStatsOptions(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     schema = standard_artifacts.Schema()
     stats_options = tfdv.StatsOptions(weight_feature='weight')
     statistics_gen = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]),
         schema=channel_utils.as_channel([schema]),
         stats_options=stats_options)
     self.assertEqual(
         standard_artifacts.ExampleStatistics.TYPE_NAME,
         statistics_gen.outputs[
             standard_component_specs.STATISTICS_KEY].type_name)