Example No. 1
  def testDo(self):
    source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)
    tf.io.gfile.makedirs(output_data_dir)

    # Create input dict.
    train_examples = standard_artifacts.Examples(split='train')
    train_examples.uri = os.path.join(source_data_dir, 'csv_example_gen/train/')
    eval_examples = standard_artifacts.Examples(split='eval')
    eval_examples.uri = os.path.join(source_data_dir, 'csv_example_gen/eval/')

    train_stats = standard_artifacts.ExampleStatistics(split='train')
    train_stats.uri = os.path.join(output_data_dir, 'train', '')
    eval_stats = standard_artifacts.ExampleStatistics(split='eval')
    eval_stats.uri = os.path.join(output_data_dir, 'eval', '')
    input_dict = {
        'input_data': [train_examples, eval_examples],
    }

    output_dict = {
        'output': [train_stats, eval_stats],
    }

    # Run executor.
    stats_gen_executor = executor.Executor()
    stats_gen_executor.Do(input_dict, output_dict, exec_properties={})

    # Check statistics_gen outputs.
    self._validate_stats_output(os.path.join(train_stats.uri, 'stats_tfrecord'))
    self._validate_stats_output(os.path.join(eval_stats.uri, 'stats_tfrecord'))
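The `_validate_stats_output` helper these tests call is not shown on this page. A minimal sketch of what such a helper might look like, assuming `import tensorflow as tf` and `import tensorflow_data_validation as tfdv` and that the path points at a serialized DatasetFeatureStatisticsList (both assumptions, not taken from the examples):

  def _validate_stats_output(self, stats_path):
    # Hypothetical helper body; the real implementation is not shown here.
    # Check that the statistics file was written...
    self.assertTrue(tf.io.gfile.exists(stats_path))
    # ...and that it parses into statistics for exactly one dataset.
    stats = tfdv.load_statistics(stats_path)
    self.assertLen(stats.datasets, 1)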
Example No. 2
 def testConstruct(self):
   schema_gen = component.SchemaGen(
       stats=channel_utils.as_channel(
           [standard_artifacts.ExampleStatistics(split='train')]),
       infer_feature_shape=True)
   self.assertEqual('SchemaPath', schema_gen.outputs.output.type_name)
   self.assertTrue(schema_gen.spec.exec_properties['infer_feature_shape'])
Example No. 3
 def testConstruct(self):
     schema_gen = component.SchemaGen(statistics=channel_utils.as_channel(
         [standard_artifacts.ExampleStatistics(split='train')]))
     self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                      schema_gen.outputs['schema'].type_name)
     self.assertFalse(
         schema_gen.spec.exec_properties['infer_feature_shape'])
Example No. 4
    def testDo(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        tf.io.gfile.makedirs(output_data_dir)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])

        stats = standard_artifacts.ExampleStatistics()
        stats.uri = output_data_dir
        stats.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        input_dict = {
            executor.EXAMPLES_KEY: [examples],
        }

        output_dict = {
            executor.STATISTICS_KEY: [stats],
        }

        # Run executor.
        stats_gen_executor = executor.Executor()
        stats_gen_executor.Do(input_dict, output_dict, exec_properties={})

        # Check statistics_gen outputs.
        self._validate_stats_output(
            os.path.join(stats.uri, 'train', 'stats_tfrecord'))
        self._validate_stats_output(
            os.path.join(stats.uri, 'eval', 'stats_tfrecord'))
Example No. 5
    def testDo(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        statistics_artifact = standard_artifacts.ExampleStatistics()
        statistics_artifact.uri = os.path.join(source_data_dir,
                                               'statistics_gen')
        statistics_artifact.split_names = artifact_utils.encode_split_names(
            ['train'])

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        schema_output = standard_artifacts.Schema()
        schema_output.uri = os.path.join(output_data_dir, 'schema_output')

        input_dict = {
            'stats': [statistics_artifact],
        }
        output_dict = {
            'output': [schema_output],
        }

        exec_properties = {'infer_feature_shape': False}

        schema_gen_executor = executor.Executor()
        schema_gen_executor.Do(input_dict, output_dict, exec_properties)
        self.assertNotEqual(0, len(tf.io.gfile.listdir(schema_output.uri)))
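The assertion above only checks that the executor wrote something. One way to inspect the result, assuming the executor writes a text-format schema named 'schema.pbtxt' (the file name is an assumption, not confirmed by this test):

import os
import tensorflow_data_validation as tfdv

# Hedged follow-up: load and inspect the generated schema proto.
schema = tfdv.load_schema_text(os.path.join(schema_output.uri, 'schema.pbtxt'))
print([feature.name for feature in schema.feature])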
Example No. 6
  def __init__(self,
               input_data: Optional[types.Channel] = None,
               output: Optional[types.Channel] = None,
               examples: Optional[types.Channel] = None,
               name: Optional[Text] = None):
    """Construct a StatisticsGen component.

    Args:
      input_data: A Channel of 'ExamplesPath' type. This should contain two
        splits 'train' and 'eval' (required).
      output: Optional 'ExampleStatisticsPath' channel for statistics of each
        split provided in input examples.
      examples: Forwards compatibility alias for the 'input_data' argument.
      name: Optional unique name. Necessary iff multiple StatisticsGen
        components are declared in the same pipeline.
    """
    input_data = input_data or examples
    output = output or types.Channel(
        type=standard_artifacts.ExampleStatistics,
        artifacts=[
            standard_artifacts.ExampleStatistics(split=split)
            for split in artifact.DEFAULT_EXAMPLE_SPLITS
        ])
    spec = StatisticsGenSpec(
        input_data=input_data, output=output)
    super(StatisticsGen, self).__init__(spec=spec, name=name)
Example No. 7
    def __init__(self,
                 examples: Optional[types.Channel] = None,
                 output: Optional[types.Channel] = None,
                 input_data: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None):
        """Construct a StatisticsGen component.

        Args:
          examples: A Channel of `ExamplesPath` type, likely generated by the
            [ExampleGen component](https://www.tensorflow.org/tfx/guide/examplegen).
            This needs to contain two splits labeled `train` and `eval`. _required_
          output: `ExampleStatisticsPath` channel for statistics of each split
            provided in the input examples.
          input_data: Backwards compatibility alias for the `examples` argument.
          instance_name: Optional name assigned to this specific instance of
            StatisticsGen. Required only if multiple StatisticsGen components are
            declared in the same pipeline.
        """
        examples = examples or input_data
        if not output:
            statistics_artifact = standard_artifacts.ExampleStatistics()
            statistics_artifact.split_names = artifact_utils.encode_split_names(
                artifact.DEFAULT_EXAMPLE_SPLITS)
            output = types.Channel(type=standard_artifacts.ExampleStatistics,
                                   artifacts=[statistics_artifact])
        spec = StatisticsGenSpec(input_data=examples, output=output)
        super(StatisticsGen, self).__init__(spec=spec,
                                            instance_name=instance_name)
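Examples 6 and 7 are two revisions of the same constructor. For context, a minimal wiring sketch of how the component is typically instantiated in a pipeline; the `CsvExampleGen` argument shape varies across TFX releases, so `input_base` and the path are assumptions here:

from tfx.components import CsvExampleGen, StatisticsGen

# Hypothetical upstream component feeding StatisticsGen.
example_gen = CsvExampleGen(input_base='/path/to/csv_data')
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])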
Example No. 8
    def setUp(self):
        super(ExecutorTest, self).setUp()

        self.source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        self.train_stats_artifact = standard_artifacts.ExampleStatistics(
            split='train')
        self.train_stats_artifact.uri = os.path.join(self.source_data_dir,
                                                     'statistics_gen/train/')

        self.output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        self.schema_output = standard_artifacts.Schema()
        self.schema_output.uri = os.path.join(self.output_data_dir,
                                              'schema_output')

        self.schema = standard_artifacts.Schema()
        self.schema.uri = os.path.join(self.source_data_dir, 'fixed_schema/')

        self.expected_schema = standard_artifacts.Schema()
        self.expected_schema.uri = os.path.join(self.source_data_dir,
                                                'schema_gen/')

        self.input_dict = {
            'stats': [self.train_stats_artifact],
            'schema': None
        }
        self.output_dict = {
            'output': [self.schema_output],
        }
        self.exec_properties = {'infer_feature_shape': False}
Example No. 9
  def testDo(self):
    source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')

    statistics_artifact = standard_artifacts.ExampleStatistics()
    statistics_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
    statistics_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval', 'test'])

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    schema_output = standard_artifacts.Schema()
    schema_output.uri = os.path.join(output_data_dir, 'schema_output')

    input_dict = {
        standard_component_specs.STATISTICS_KEY: [statistics_artifact],
    }

    exec_properties = {
        # List needs to be serialized before being passed into Do function.
        standard_component_specs.EXCLUDE_SPLITS_KEY:
            json_utils.dumps(['test'])
    }

    output_dict = {
        standard_component_specs.SCHEMA_KEY: [schema_output],
    }

    schema_gen_executor = executor.Executor()
    schema_gen_executor.Do(input_dict, output_dict, exec_properties)
    self.assertNotEqual(0, len(fileio.listdir(schema_output.uri)))
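The comment in the test notes that the split list must be serialized before it reaches `Do`. On the executor side the decoding is presumably symmetric; a sketch under that assumption (the exact call site inside `Executor.Do` is not shown on this page):

from tfx.utils import json_utils

# Hedged sketch: decode the serialized split list back into a Python list.
exclude_splits = json_utils.loads(
    exec_properties.get(standard_component_specs.EXCLUDE_SPLITS_KEY, 'null')) or []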
Example No. 10
    def testDo(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        eval_stats_artifact = standard_artifacts.ExampleStatistics()
        eval_stats_artifact.uri = os.path.join(source_data_dir,
                                               'statistics_gen')
        eval_stats_artifact.split_names = artifact_utils.encode_split_names(
            ['train', 'eval', 'test'])

        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        validation_output = standard_artifacts.ExampleAnomalies()
        validation_output.uri = os.path.join(output_data_dir, 'output')

        input_dict = {
            STATISTICS_KEY: [eval_stats_artifact],
            SCHEMA_KEY: [schema_artifact],
        }

        exec_properties = {
            # List needs to be serialized before being passed into Do function.
            EXCLUDE_SPLITS_KEY: json_utils.dumps(['test'])
        }

        output_dict = {
            ANOMALIES_KEY: [validation_output],
        }

        example_validator_executor = executor.Executor()
        example_validator_executor.Do(input_dict, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         validation_output.split_names)

        # Check example_validator outputs.
        train_anomalies_path = os.path.join(validation_output.uri,
                                            'Split-train', 'SchemaDiff.pb')
        eval_anomalies_path = os.path.join(validation_output.uri, 'Split-eval',
                                           'SchemaDiff.pb')
        self.assertTrue(fileio.exists(train_anomalies_path))
        self.assertTrue(fileio.exists(eval_anomalies_path))
        train_anomalies_bytes = io_utils.read_bytes_file(train_anomalies_path)
        train_anomalies = anomalies_pb2.Anomalies()
        train_anomalies.ParseFromString(train_anomalies_bytes)
        eval_anomalies_bytes = io_utils.read_bytes_file(eval_anomalies_path)
        eval_anomalies = anomalies_pb2.Anomalies()
        eval_anomalies.ParseFromString(eval_anomalies_bytes)
        self.assertEqual(0, len(train_anomalies.anomaly_info))
        self.assertEqual(0, len(eval_anomalies.anomaly_info))

        # Assert 'test' split is excluded.
        test_anomalies_path = os.path.join(validation_output.uri, 'Split-test',
                                           'SchemaDiff.pb')
        self.assertFalse(fileio.exists(test_anomalies_path))
Example No. 11
 def testConstruct(self):
   example_validator = component.ExampleValidator(
       statistics=channel_utils.as_channel(
           [standard_artifacts.ExampleStatistics(split='eval')]),
       schema=channel_utils.as_channel([standard_artifacts.Schema()]),
   )
   self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                    example_validator.outputs['anomalies'].type_name)
Example No. 12
 def testConstruct(self):
   example_validator = component.ExampleValidator(
       stats=channel_utils.as_channel(
           [standard_artifacts.ExampleStatistics(split='eval')]),
       schema=channel_utils.as_channel([standard_artifacts.Schema()]),
   )
   self.assertEqual('ExampleValidationPath',
                    example_validator.outputs['output'].type_name)
Example No. 13
 def testConstructWithParameter(self):
   infer_shape = data_types.RuntimeParameter(name='infer-shape', ptype=bool)
   schema_gen = component.SchemaGen(
       statistics=channel_utils.as_channel(
           [standard_artifacts.ExampleStatistics(split='train')]),
       infer_feature_shape=infer_shape)
   self.assertEqual('SchemaPath', schema_gen.outputs['schema'].type_name)
   self.assertJsonEqual(
       str(schema_gen.spec.exec_properties['infer_feature_shape']),
       str(infer_shape))
Example No. 14
 def testConstruct(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['eval'])
     example_validator = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
     )
     self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                      example_validator.outputs['anomalies'].type_name)
Example No. 15
  def testGetStatsOutputPathEntriesMissingArtifact(self):
    pre_transform_stats = standard_artifacts.ExampleStatistics()
    pre_transform_stats.uri = '/pre_transform_stats'

    with self.assertRaisesRegex(
        ValueError, 'all stats_output_paths should be specified or none'):
      executor_utils.GetStatsOutputPathEntries(False, {
          standard_component_specs.PRE_TRANSFORM_STATS_KEY:
              [pre_transform_stats]
      })
Example No. 16
  def testGetStatsOutputPathEntries(self):
    # Stats disabled: no path entries are generated.
    self.assertEmpty(executor_utils.GetStatsOutputPathEntries(True, {}))

    # Stats enabled: each provided artifact yields a label-to-path entry.
    pre_transform_stats = standard_artifacts.ExampleStatistics()
    pre_transform_stats.uri = '/pre_transform_stats'
    pre_transform_schema = standard_artifacts.Schema()
    pre_transform_schema.uri = '/pre_transform_schema'
    post_transform_anomalies = standard_artifacts.ExampleAnomalies()
    post_transform_anomalies.uri = '/post_transform_anomalies'
    post_transform_stats = standard_artifacts.ExampleStatistics()
    post_transform_stats.uri = '/post_transform_stats'
    post_transform_schema = standard_artifacts.Schema()
    post_transform_schema.uri = '/post_transform_schema'

    result = executor_utils.GetStatsOutputPathEntries(
        False, {
            standard_component_specs.PRE_TRANSFORM_STATS_KEY:
                [pre_transform_stats],
            standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY:
                [pre_transform_schema],
            standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY:
                [post_transform_anomalies],
            standard_component_specs.POST_TRANSFORM_STATS_KEY:
                [post_transform_stats],
            standard_component_specs.POST_TRANSFORM_SCHEMA_KEY:
                [post_transform_schema],
        })
    self.assertEqual(
        {
            labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL:
                '/pre_transform_stats',
            labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL:
                '/pre_transform_schema',
            labels.POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL:
                '/post_transform_anomalies',
            labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL:
                '/post_transform_stats',
            labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL:
                '/post_transform_schema',
        }, result)
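Taken together, Examples 15 and 16 pin down the contract of `GetStatsOutputPathEntries`: when stats are disabled no entries are produced, when enabled each artifact yields one label-to-path entry, and partially specified artifact sets are rejected. A hypothetical simplification consistent with both tests (the real implementation may differ; the first parameter's meaning and the `artifact_utils.get_single_uri` call are assumptions):

# Hypothetical reconstruction, inferred only from the two tests above.
_STATS_KEY_TO_LABEL = {
    standard_component_specs.PRE_TRANSFORM_STATS_KEY:
        labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL,
    standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY:
        labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL,
    standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY:
        labels.POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL,
    standard_component_specs.POST_TRANSFORM_STATS_KEY:
        labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL,
    standard_component_specs.POST_TRANSFORM_SCHEMA_KEY:
        labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL,
}

def GetStatsOutputPathEntries(disable_statistics, output_dict):
  if disable_statistics:
    return {}
  entries = {
      label: artifact_utils.get_single_uri(output_dict[key])
      for key, label in _STATS_KEY_TO_LABEL.items()
      if key in output_dict
  }
  # Either all stats artifacts are provided, or none of them.
  if entries and len(entries) != len(_STATS_KEY_TO_LABEL):
    raise ValueError('all stats_output_paths should be specified or none')
  return entries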
Example No. 17
    def __init__(self,
                 examples: Optional[types.Channel] = None,
                 schema: Optional[types.Channel] = None,
                 stats_options: Optional[tfdv.StatsOptions] = None,
                 exclude_splits: Optional[List[Text]] = None,
                 output: Optional[types.Channel] = None,
                 input_data: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None):
        """Construct a StatisticsGen component.

        Args:
          examples: A Channel of `ExamplesPath` type, likely generated by the
            [ExampleGen component](https://www.tensorflow.org/tfx/guide/examplegen).
            This needs to contain two splits labeled `train` and `eval`. _required_
          schema: A `Schema` channel to use for automatically configuring the value
            of stats options passed to TFDV.
          stats_options: The StatsOptions instance to configure optional TFDV
            behavior. When stats_options.schema is set, it is used instead of the
            `schema` channel input. Because stats_options must be serialized,
            slicer functions and custom stats generators are dropped and are
            therefore not usable.
          exclude_splits: Names of splits for which statistics and sample records
            should not be generated. The default behavior (when exclude_splits is
            None) is to exclude no splits.
          output: `ExampleStatisticsPath` channel for statistics of each split
            provided in the input examples.
          input_data: Backwards compatibility alias for the `examples` argument.
          instance_name: Optional name assigned to this specific instance of
            StatisticsGen. Required only if multiple StatisticsGen components are
            declared in the same pipeline.
        """
        if input_data:
            logging.warning(
                'The "input_data" argument to the StatisticsGen component has '
                'been renamed to "examples" and is deprecated. Please update your '
                'usage as support for this argument will be removed soon.')
            examples = input_data
        if exclude_splits is None:
            exclude_splits = []
            logging.info(
                'Excluding no splits because exclude_splits is not set.')
        if not output:
            output = channel_utils.as_channel(
                [standard_artifacts.ExampleStatistics()])
        # TODO(b/150802589): Move jsonable interface to tfx_bsl and use json_utils.
        stats_options_json = stats_options.to_json() if stats_options else None
        spec = StatisticsGenSpec(
            examples=examples,
            schema=schema,
            stats_options_json=stats_options_json,
            exclude_splits=json_utils.dumps(exclude_splits),
            statistics=output)
        super(StatisticsGen, self).__init__(spec=spec,
                                            instance_name=instance_name)
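A usage sketch for the `stats_options` and `exclude_splits` arguments documented above; `example_gen` and the feature name 'label' are placeholders, not taken from this page:

import tensorflow_data_validation as tfdv
from tfx.components import StatisticsGen

# Only serializable options survive, per the docstring above: slicer functions
# and custom stats generators would be dropped.
stats_options = tfdv.StatsOptions(label_feature='label')
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'],
    stats_options=stats_options,
    exclude_splits=['test'])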
Example No. 18
 def testEnableCache(self):
   statistics_artifact = standard_artifacts.ExampleStatistics()
   statistics_artifact.split_names = artifact_utils.encode_split_names(
       ['train'])
   schema_gen_1 = component.SchemaGen(
       statistics=channel_utils.as_channel([statistics_artifact]))
   schema_gen_2 = component.SchemaGen(
       statistics=channel_utils.as_channel([statistics_artifact]),
       enable_cache=True)
   self.assertIsNone(schema_gen_1.enable_cache)
   self.assertTrue(schema_gen_2.enable_cache)
Example No. 19
    def __init__(self,
                 examples: Optional[types.Channel] = None,
                 schema: Optional[types.Channel] = None,
                 stats_options: Optional[tfdv.StatsOptions] = None,
                 output: Optional[types.Channel] = None,
                 input_data: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None,
                 enable_cache: Optional[bool] = None):
        """Construct a StatisticsGen component.

        Args:
          examples: A Channel of `ExamplesPath` type, likely generated by the
            [ExampleGen component](https://www.tensorflow.org/tfx/guide/examplegen).
            This needs to contain two splits labeled `train` and `eval`. _required_
          schema: A `Schema` channel to use for automatically configuring the value
            of stats options passed to TFDV.
          stats_options: The StatsOptions instance to configure optional TFDV
            behavior. When stats_options.schema is set, it is used instead of the
            `schema` channel input. Because stats_options must be serialized,
            slicer functions and custom stats generators are dropped and are
            therefore not usable.
          output: `ExampleStatisticsPath` channel for statistics of each split
            provided in the input examples.
          input_data: Backwards compatibility alias for the `examples` argument.
          instance_name: Optional name assigned to this specific instance of
            StatisticsGen. Required only if multiple StatisticsGen components are
            declared in the same pipeline.
          enable_cache: Optional boolean to indicate whether caching is enabled
            for the StatisticsGen component. If not specified, defaults to the
            value of the pipeline's enable_cache parameter.
        """
        if input_data:
            absl.logging.warning(
                'The "input_data" argument to the StatisticsGen component has '
                'been renamed to "examples" and is deprecated. Please update your '
                'usage as support for this argument will be removed soon.')
            examples = input_data
        if not output:
            statistics_artifact = standard_artifacts.ExampleStatistics()
            statistics_artifact.split_names = artifact_utils.get_single_instance(
                list(examples.get())).split_names
            output = types.Channel(type=standard_artifacts.ExampleStatistics,
                                   artifacts=[statistics_artifact])
        # TODO(b/150802589): Move jsonable interface to tfx_bsl and use json_utils.
        stats_options_json = stats_options.to_json() if stats_options else None
        spec = StatisticsGenSpec(examples=examples,
                                 schema=schema,
                                 stats_options_json=stats_options_json,
                                 statistics=output)
        super(StatisticsGen, self).__init__(spec=spec,
                                            instance_name=instance_name,
                                            enable_cache=enable_cache)
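A sketch of the caching interplay described in the `enable_cache` docstring: the pipeline-level flag is the default, and a component-level value overrides it. All names and paths below are placeholders:

from tfx.orchestration import pipeline as pipeline_lib

statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'],
    enable_cache=False)  # overrides the pipeline-level default below
p = pipeline_lib.Pipeline(
    pipeline_name='stats_demo',
    pipeline_root='/tmp/pipeline_root',
    components=[example_gen, statistics_gen],
    enable_cache=True)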
Example No. 20
 def testConstruct(self):
   statistics_artifact = standard_artifacts.ExampleStatistics()
   statistics_artifact.split_names = artifact_utils.encode_split_names(
       ['train', 'eval'])
   exclude_splits = ['eval']
   example_validator = component.ExampleValidator(
       statistics=channel_utils.as_channel([statistics_artifact]),
       schema=channel_utils.as_channel([standard_artifacts.Schema()]),
       exclude_splits=exclude_splits)
   self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                    example_validator.outputs['anomalies'].type_name)
   self.assertEqual(example_validator.spec.exec_properties['exclude_splits'],
                    '["eval"]')
Example No. 21
 def testConstructWithParameter(self):
   statistics_artifact = standard_artifacts.ExampleStatistics()
   statistics_artifact.split_names = artifact_utils.encode_split_names(
       ['train'])
   infer_shape = data_types.RuntimeParameter(name='infer-shape', ptype=bool)
   schema_gen = component.SchemaGen(
       statistics=channel_utils.as_channel([statistics_artifact]),
       infer_feature_shape=infer_shape)
   self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                    schema_gen.outputs['schema'].type_name)
   self.assertJsonEqual(
       str(schema_gen.spec.exec_properties['infer_feature_shape']),
       str(infer_shape))
Example No. 22
  def setUp(self):
    super(ComponentTest, self).setUp()

    examples_artifact = standard_artifacts.Examples()
    examples_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    statistics_artifact = standard_artifacts.ExampleStatistics()
    statistics_artifact.split_names = artifact_utils.encode_split_names(
        ['train'])

    self.examples = channel_utils.as_channel([examples_artifact])
    self.statistics = channel_utils.as_channel([statistics_artifact])
    self.custom_config = {'some': 'thing', 'some other': 1, 'thing': 2}
Example No. 23
 def testConstruct(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     exclude_splits = ['eval']
     schema_gen = component.SchemaGen(
         statistics=channel_utils.as_channel([statistics_artifact]),
         exclude_splits=exclude_splits)
     self.assertEqual(standard_artifacts.Schema.TYPE_NAME,
                      schema_gen.outputs['schema'].type_name)
     self.assertTrue(schema_gen.spec.exec_properties['infer_feature_shape'])
     self.assertEqual(schema_gen.spec.exec_properties['exclude_splits'],
                      '["eval"]')
Example No. 24
 def testEnableCache(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['eval'])
     example_validator_1 = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
     )
      self.assertIsNone(example_validator_1.enable_cache)
     example_validator_2 = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
         enable_cache=True)
      self.assertTrue(example_validator_2.enable_cache)
Example No. 25
 def testConstruct(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     exclude_splits = ['eval']
     schema_gen = component.SchemaGen(
         statistics=channel_utils.as_channel([statistics_artifact]),
         exclude_splits=exclude_splits)
     self.assertEqual(
         standard_artifacts.Schema.TYPE_NAME,
         schema_gen.outputs[standard_component_specs.SCHEMA_KEY].type_name)
     self.assertTrue(schema_gen.spec.exec_properties[
         standard_component_specs.INFER_FEATURE_SHAPE_KEY])
     self.assertEqual(
         schema_gen.spec.exec_properties[
             standard_component_specs.EXCLUDE_SPLITS_KEY], '["eval"]')
Example No. 26
    def testDo(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        fileio.makedirs(output_data_dir)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval', 'test'])

        input_dict = {
            standard_component_specs.EXAMPLES_KEY: [examples],
        }

        exec_properties = {
            # List needs to be serialized before being passed into Do function.
            standard_component_specs.EXCLUDE_SPLITS_KEY:
            json_utils.dumps(['test']),
        }

        # Create output dict.
        stats = standard_artifacts.ExampleStatistics()
        stats.uri = output_data_dir
        output_dict = {
            standard_component_specs.STATISTICS_KEY: [stats],
        }

        # Run executor.
        stats_gen_executor = executor.Executor()
        stats_gen_executor.Do(input_dict, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         stats.split_names)

        # Check statistics_gen outputs.
        self._validate_stats_output(
            os.path.join(stats.uri, 'train', 'stats_tfrecord'))
        self._validate_stats_output(
            os.path.join(stats.uri, 'eval', 'stats_tfrecord'))

        # Assert 'test' split is excluded.
        self.assertFalse(
            fileio.exists(os.path.join(stats.uri, 'test', 'stats_tfrecord')))
Example No. 27
    def testDoWithSchemaAndStatsOptions(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        tf.io.gfile.makedirs(output_data_dir)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(source_data_dir, 'schema_gen')

        input_dict = {
            executor.EXAMPLES_KEY: [examples],
            executor.SCHEMA_KEY: [schema]
        }

        exec_properties = {
            executor.STATS_OPTIONS_JSON_KEY:
            tfdv.StatsOptions(label_feature='company').to_json(),
        }

        # Create output dict.
        stats = standard_artifacts.ExampleStatistics()
        stats.uri = output_data_dir
        stats.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        output_dict = {
            executor.STATISTICS_KEY: [stats],
        }

        # Run executor.
        stats_gen_executor = executor.Executor()
        stats_gen_executor.Do(input_dict,
                              output_dict,
                              exec_properties=exec_properties)

        # Check statistics_gen outputs.
        self._validate_stats_output(
            os.path.join(stats.uri, 'train', 'stats_tfrecord'))
        self._validate_stats_output(
            os.path.join(stats.uri, 'eval', 'stats_tfrecord'))
Example No. 28
    def testDoWithTwoSchemas(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        tf.io.gfile.makedirs(output_data_dir)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(source_data_dir, 'schema_gen')

        input_dict = {
            executor.EXAMPLES_KEY: [examples],
            executor.SCHEMA_KEY: [schema]
        }

        exec_properties = {
            executor.STATS_OPTIONS_JSON_KEY:
            tfdv.StatsOptions(label_feature='company',
                              schema=schema_pb2.Schema()).to_json(),
            executor.EXCLUDE_SPLITS_KEY:
            json_utils.dumps([])
        }

        # Create output dict.
        stats = standard_artifacts.ExampleStatistics()
        stats.uri = output_data_dir
        stats.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        output_dict = {
            executor.STATISTICS_KEY: [stats],
        }

        # Run executor.
        stats_gen_executor = executor.Executor()
        with self.assertRaises(ValueError):
            stats_gen_executor.Do(input_dict, output_dict, exec_properties)
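Read together with Example 27, this test shows the executor rejects receiving a schema twice: once through the schema input channel and once inside `StatsOptions`. A sketch of the two valid alternatives, reusing the names from the tests above:

# Alternative 1: pass the schema through the input channel only
# (this is the configuration Example 27 runs successfully).
input_dict = {executor.EXAMPLES_KEY: [examples], executor.SCHEMA_KEY: [schema]}
exec_properties = {
    executor.STATS_OPTIONS_JSON_KEY:
        tfdv.StatsOptions(label_feature='company').to_json(),
}

# Alternative 2: embed the schema in StatsOptions and omit the schema channel.
input_dict = {executor.EXAMPLES_KEY: [examples]}
exec_properties = {
    executor.STATS_OPTIONS_JSON_KEY:
        tfdv.StatsOptions(label_feature='company',
                          schema=schema_pb2.Schema()).to_json(),
}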
Example No. 29
    def build(self, context: Context) -> BaseNode:
        from tfx.components import StatisticsGen

        statistics_artifact = standard_artifacts.ExampleStatistics()
        statistics_artifact.split_names = artifact_utils.encode_split_names(
            splits_or_example_defaults(self._config.params.split_names))

        output = Channel(type=standard_artifacts.ExampleStatistics,
                         artifacts=[statistics_artifact])

        examples = context.get(self._config.inputs.examples)
        component = StatisticsGen(
            examples=examples,
            stats_options=None,
            output=output,
            instance_name=context.abs_current_url_friendly)

        put_outputs_to_context(context, self._config.outputs, component)
        return component
Example No. 30
  def setUp(self):
    super(ExecutorTest, self).setUp()

    source_data_dir = os.path.dirname(os.path.dirname(__file__))
    input_data_dir = os.path.join(source_data_dir, 'testdata')

    statistics = standard_artifacts.ExampleStatistics()
    statistics.uri = os.path.join(input_data_dir,
                                  'StatisticsGen.train_mockdata_1',
                                  'statistics', '5')
    statistics.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    transformed_examples = standard_artifacts.Examples()
    transformed_examples.uri = os.path.join(input_data_dir,
                                            'Transform.train_mockdata_1',
                                            'transformed_examples', '10')
    transformed_examples.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])
    self._input_dict = {
        executor.EXAMPLES_KEY: [transformed_examples],
        executor.STATISTICS_KEY: [statistics],
    }

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                       tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir)),
        self._testMethodName)
    self._metafeatures = artifacts.MetaFeatures()
    self._metafeatures.uri = output_data_dir
    self._output_dict = {
        executor.METAFEATURES_KEY: [self._metafeatures],
    }

    self._exec_properties = {
        'custom_config': {
            'problem_statement_path': '/some/fake/path'
        }
    }