示例#1
0
 def testConstructSubclassQueryBased(self):
     """Builds a query-based ExampleGen and checks its default wiring.

     The component must take no input channels, use the base driver, carry no
     custom_config and emit a single Examples artifact whose split names are
     'train' and 'eval'.
     """
     query_split = example_gen_pb2.Input.Split(name='single', pattern='query')
     example_gen = TestQueryBasedExampleGenComponent(
         input_config=example_gen_pb2.Input(splits=[query_split]))

     # Query-based generators read nothing from upstream channels.
     self.assertEqual({}, example_gen.inputs.get_all())
     self.assertEqual(base_driver.BaseDriver, example_gen.driver_class)

     examples_channel = example_gen.outputs['examples']
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)
     self.assertIsNone(example_gen.exec_properties.get('custom_config'))

     artifacts = examples_channel.get()
     self.assertEqual(1, len(artifacts))
     split_names = artifact_utils.decode_split_names(artifacts[0].split_names)
     self.assertEqual(['train', 'eval'], split_names)
  def testDoInputSplit(self):
    """Sets exec properties for pre-split train/eval input, then runs _testDo."""
    # Create exec properties for input split: 'train' and 'eval' are read
    # from separate directories, with the default output config.
    train_split = example_gen_pb2.Input.Split(name='train', pattern='train/*')
    eval_split = example_gen_pb2.Input.Split(name='eval', pattern='eval/*')
    input_config = example_gen_pb2.Input(splits=[train_split, eval_split])
    self._exec_properties = {
        utils.INPUT_CONFIG_KEY: proto_utils.proto_to_json(input_config),
        utils.OUTPUT_CONFIG_KEY:
            proto_utils.proto_to_json(example_gen_pb2.Output()),
    }

    self._testDo()
    def testDoInputSplit(self):
        """Sets exec properties for pre-split train/eval input, then runs _testDo."""

        def _to_json(message):
            # Keep field names as declared in the .proto (no camelCase).
            return json_format.MessageToJson(message,
                                             preserving_proto_field_name=True)

        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='train', pattern='train/*'),
            example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
        ])
        # Create exec properties for input split.
        self._exec_properties = {
            utils.INPUT_CONFIG_KEY: _to_json(input_config),
            utils.OUTPUT_CONFIG_KEY: _to_json(example_gen_pb2.Output()),
        }

        self._testDo()
示例#4
0
 def test_construct_with_input_config(self):
     """Each configured input split should map to one ExamplesPath artifact."""
     input_base = types.TfxArtifact(type_name='ExternalPath')
     split_names = ['train', 'eval', 'test']
     input_config = example_gen_pb2.Input(splits=[
         example_gen_pb2.Input.Split(name=name, pattern='%s/*' % name)
         for name in split_names
     ])
     example_gen = component._FileBasedExampleGen(
         input_base=channel.as_channel([input_base]),
         input_config=input_config)

     examples = example_gen.outputs.examples
     self.assertEqual('ExamplesPath', examples.type_name)
     artifact_collection = examples.get()
     # Artifacts are expected in the same order as the configured splits.
     for index, name in enumerate(split_names):
         self.assertEqual(name, artifact_collection[index].split)
示例#5
0
 def testConstructWithInputConfig(self):
     """Three input splits should yield three split-tagged output artifacts."""
     split_names = ('train', 'eval', 'test')
     input_config = example_gen_pb2.Input(splits=[
         example_gen_pb2.Input.Split(name=split, pattern=split + '/*')
         for split in split_names
     ])
     input_base = standard_artifacts.ExternalArtifact()
     example_gen = TestFileBasedExampleGenComponent(
         input=channel_utils.as_channel([input_base]),
         input_config=input_config)

     examples_channel = example_gen.outputs['examples']
     self.assertEqual('ExamplesPath', examples_channel.type_name)
     artifacts = examples_channel.get()
     for position, split in enumerate(split_names):
         self.assertEqual(split, artifacts[position].split)
示例#6
0
    def testDo(self):
        """End-to-end check of the Avro executor on the 'avro' input split.

        The input is re-split into 'train' and 'eval' with 2 and 1 hash
        buckets respectively. The test expects one shard per split, with the
        train shard larger than the eval shard.
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # A single Examples artifact receives every output split.
        examples = standard_artifacts.Examples()
        examples.uri = output_data_dir
        output_dict = {utils.EXAMPLES_KEY: [examples]}

        # Exec properties: one Avro input split, hashed into train/eval (2:1).
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='avro', pattern='avro/*.avro'),
        ])
        split_config = example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
        ])
        output_config = example_gen_pb2.Output(split_config=split_config)
        exec_properties = {
            utils.INPUT_BASE_KEY: self._input_data_dir,
            utils.INPUT_CONFIG_KEY: proto_utils.proto_to_json(input_config),
            utils.OUTPUT_CONFIG_KEY: proto_utils.proto_to_json(output_config),
        }

        avro_executor.Executor().Do({}, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         examples.split_names)

        # The executor is expected to write each split to its own
        # sub-directory of the artifact URI.
        shard_name = 'data_tfrecord-00000-of-00001.gz'
        train_output_file = os.path.join(examples.uri, 'train', shard_name)
        eval_output_file = os.path.join(examples.uri, 'eval', shard_name)
        for output_file in (train_output_file, eval_output_file):
            self.assertTrue(fileio.exists(output_file))
        self.assertGreater(
            fileio.open(train_output_file).size(),
            fileio.open(eval_output_file).size())
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     log_root: Text = '/var/tmp/tfx/logs') -> pipeline.Pipeline:
    """Builds the iris TFX pipeline, from CSV ingestion to model push.

    Args:
      pipeline_name: passed through as the pipeline's name.
      pipeline_root: passed through as the pipeline's root directory.
      data_root: directory containing ``iris_training.csv`` and
        ``iris_test.csv``.
      module_file: user module given to both Transform and Trainer.
      serving_model_dir: filesystem base directory the Pusher pushes to.
      log_root: passed through as the pipeline's ``log_root``; defaults to
        ``/var/tmp/tfx/logs``.

    Returns:
      A ``pipeline.Pipeline`` wiring the nine components below.
    """
    examples = external_input(data_root)
    # The CSV files are already split on disk, so map each file to a split.
    input_split = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='iris_training.csv'),
        example_gen_pb2.Input.Split(name='eval', pattern='iris_test.csv')
    ])
    example_gen = CsvExampleGen(input_base=examples, input_config=input_split)
    statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples)
    infer_schema = SchemaGen(stats=statistics_gen.outputs.output)
    validate_stats = ExampleValidator(
        stats=statistics_gen.outputs.output,
        schema=infer_schema.outputs.output)
    transform = Transform(
        input_data=example_gen.outputs.examples,
        schema=infer_schema.outputs.output,
        module_file=module_file)
    trainer = Trainer(
        module_file=module_file,
        examples=transform.outputs.transformed_examples,
        schema=infer_schema.outputs.output,
        transform_output=transform.outputs.transform_output,
        train_args=trainer_pb2.TrainArgs(num_steps=1000),
        eval_args=trainer_pb2.EvalArgs(num_steps=500))
    # An empty SingleSlicingSpec requests the overall (unsliced) metrics.
    model_analyzer = Evaluator(
        examples=example_gen.outputs.examples,
        model_exports=trainer.outputs.output,
        feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
            evaluator_pb2.SingleSlicingSpec()
        ]))
    model_validator = ModelValidator(
        examples=example_gen.outputs.examples, model=trainer.outputs.output)
    # The Pusher is gated on the ModelValidator's blessing output.
    pusher = Pusher(
        model_export=trainer.outputs.output,
        model_blessing=model_validator.outputs.blessing,
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, infer_schema, validate_stats, transform,
            trainer, model_analyzer, model_validator, pusher
        ],
        log_root=log_root,
    )
示例#8
0
    def testComponentspecBasic(self):
        """Checks ComponentSpec property encoding and argument type checking.

        A proto exec property is serialized to a JSON string. Plain properties
        and channels are stored as given. A wrongly typed parameter or a channel
        with an unexpected type_name raises TypeError.
        """
        proto = example_gen_pb2.Input()
        proto.splits.extend([
            example_gen_pb2.Input.Split(name='name1', pattern='pattern1'),
            example_gen_pb2.Input.Split(name='name2', pattern='pattern2'),
            example_gen_pb2.Input.Split(name='name3', pattern='pattern3'),
        ])
        input_channel = Channel(type_name='InputType')
        output_channel = Channel(type_name='OutputType')
        spec = _BasicComponentSpec(folds=10,
                                   proto=proto,
                                   input=input_channel,
                                   output=output_channel)
        # Verify proto property: it is stored as a JSON string.
        self.assertIsInstance(spec.exec_properties['proto'], str)
        decoded_proto = json.loads(spec.exec_properties['proto'])
        self.assertCountEqual(['splits'], decoded_proto.keys())
        self.assertEqual(3, len(decoded_proto['splits']))
        self.assertCountEqual(['name1', 'name2', 'name3'],
                              list(s['name'] for s in decoded_proto['splits']))
        self.assertCountEqual(['pattern1', 'pattern2', 'pattern3'],
                              list(s['pattern']
                                   for s in decoded_proto['splits']))

        # Verify other properties.
        self.assertEqual(10, spec.exec_properties['folds'])
        self.assertIs(spec.inputs.input, input_channel)
        self.assertIs(spec.outputs.output, output_channel)

        # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(
                TypeError,
                "Expected type <(class|type) 'int'> for parameter u?'folds' but got "
                'string.'):
            _BasicComponentSpec(folds='string',
                                input=input_channel,
                                output=output_channel)

        with self.assertRaisesRegex(TypeError,
                                    'Expected InputType but found WrongType'):
            _BasicComponentSpec(folds=10,
                                input=Channel(type_name='WrongType'),
                                output=output_channel)

        with self.assertRaisesRegex(
                TypeError, 'Expected OutputType but found WrongType'):
            _BasicComponentSpec(folds=10,
                                input=input_channel,
                                output=Channel(type_name='WrongType'))
示例#9
0
  def testDo(self):
    """End-to-end check of the Presto executor with per-split artifacts.

    The single 'bq' query split is re-split into 'train' and 'eval' with 2
    and 1 hash buckets. Each split has its own Examples artifact, and the
    train shard is expected to be larger than the eval shard.
    """
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # One output artifact per split, each under its own sub-directory.
    train_examples = standard_artifacts.Examples(split='train')
    train_examples.uri = os.path.join(output_data_dir, 'train')
    eval_examples = standard_artifacts.Examples(split='eval')
    eval_examples.uri = os.path.join(output_data_dir, 'eval')
    output_dict = {'examples': [train_examples, eval_examples]}

    # Build exec properties.
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(
            name='bq', pattern='SELECT i, f, s FROM `fake`'),
    ])
    output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
        ]))
    exec_properties = {
        'input_config':
            json_format.MessageToJson(
                input_config, preserving_proto_field_name=True),
        'custom_config':
            json_format.MessageToJson(example_gen_pb2.CustomConfig()),
        'output_config':
            json_format.MessageToJson(output_config),
    }

    executor.Executor().Do({}, output_dict, exec_properties)

    # Check Presto example gen outputs.
    shard_name = 'data_tfrecord-00000-of-00001.gz'
    train_output_file = os.path.join(train_examples.uri, shard_name)
    eval_output_file = os.path.join(eval_examples.uri, shard_name)
    for output_file in (train_output_file, eval_output_file):
      self.assertTrue(tf.io.gfile.exists(output_file))
    self.assertGreater(
        tf.io.gfile.GFile(train_output_file).size(),
        tf.io.gfile.GFile(eval_output_file).size())
示例#10
0
    def testDo(self):
        """End-to-end check of the Parquet executor on the 'parquet' split.

        The input is re-split into 'train' and 'eval' with 2 and 1 hash
        buckets. The test expects one shard per split, with the train shard
        larger than the eval shard.
        """

        def _to_json(message):
            # Keep field names as declared in the .proto (no camelCase).
            return json_format.MessageToJson(message,
                                             preserving_proto_field_name=True)

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # A single Examples artifact receives every output split.
        examples = standard_artifacts.Examples()
        examples.uri = output_data_dir
        output_dict = {utils.EXAMPLES_KEY: [examples]}

        # Build exec properties.
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='parquet', pattern='parquet/*'),
        ])
        output_config = example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=[
                example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
                example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
            ]))
        exec_properties = {
            utils.INPUT_BASE_KEY: self._input_data_dir,
            utils.INPUT_CONFIG_KEY: _to_json(input_config),
            utils.OUTPUT_CONFIG_KEY: _to_json(output_config),
        }

        parquet_executor.Executor().Do({}, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         examples.split_names)

        # Each split is expected under its own sub-directory of the URI.
        shard_name = 'data_tfrecord-00000-of-00001.gz'
        train_output_file = os.path.join(examples.uri, 'train', shard_name)
        eval_output_file = os.path.join(examples.uri, 'eval', shard_name)
        for output_file in (train_output_file, eval_output_file):
            self.assertTrue(tf.io.gfile.exists(output_file))
        self.assertGreater(
            tf.io.gfile.GFile(train_output_file).size(),
            tf.io.gfile.GFile(eval_output_file).size())
示例#11
0
    def testDo(self, mock_client):
        """Runs the BigQuery executor against a mocked BigQuery client.

        The mocked query result exposes ``self._schema``. The single 'bq'
        query split is re-split 2:1 into 'train'/'eval' by hash buckets, so
        the train shard is expected to be larger than the eval shard.

        Args:
          mock_client: mock standing in for the BigQuery client class
            (presumably injected by a ``mock.patch`` decorator outside this
            view -- confirm).
        """
        # Mock query result schema for _BigQueryConverter.
        query_result = mock_client.return_value.query.return_value.result.return_value
        query_result.schema = self._schema

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict: one artifact per split.
        train_examples = standard_artifacts.Examples(split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = standard_artifacts.Examples(split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Create exec properties.
        exec_properties = {
            'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, f, s FROM `fake`'),
                ])),
            'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])))
        }

        # Run executor.
        big_query_example_gen = executor.Executor()
        big_query_example_gen.Do({}, output_dict, exec_properties)

        # Check BigQuery example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        # tf.gfile is the TF1-only API (absent in TF2); tf.io.gfile works in
        # both and matches the other example-gen tests.
        self.assertTrue(tf.io.gfile.exists(train_output_file))
        self.assertTrue(tf.io.gfile.exists(eval_output_file))
        self.assertGreater(
            tf.io.gfile.GFile(train_output_file).size(),
            tf.io.gfile.GFile(eval_output_file).size())
示例#12
0
  def testComponentspecBasic(self):
    """Checks ComponentSpec encoding, aliasing and argument type checking.

    A proto exec property is serialized to a JSON string. Channels are
    reachable under both their current and their compatibility alias names.
    A wrongly typed parameter or a channel of the wrong artifact type raises
    TypeError.
    """
    proto = example_gen_pb2.Input()
    proto.splits.extend([
        example_gen_pb2.Input.Split(name='name1', pattern='pattern1'),
        example_gen_pb2.Input.Split(name='name2', pattern='pattern2'),
        example_gen_pb2.Input.Split(name='name3', pattern='pattern3'),
    ])
    input_channel = Channel(type=_InputArtifact)
    output_channel = Channel(type=_OutputArtifact)
    spec = _BasicComponentSpec(
        folds=10, proto=proto, input=input_channel, output=output_channel)
    # Verify proto property: it is stored as a JSON string.
    self.assertIsInstance(spec.exec_properties['proto'], str)
    decoded_proto = json.loads(spec.exec_properties['proto'])
    self.assertCountEqual(['splits'], decoded_proto.keys())
    self.assertEqual(3, len(decoded_proto['splits']))
    self.assertCountEqual(['name1', 'name2', 'name3'],
                          list(s['name'] for s in decoded_proto['splits']))
    self.assertCountEqual(['pattern1', 'pattern2', 'pattern3'],
                          list(s['pattern'] for s in decoded_proto['splits']))

    # Verify other properties.
    self.assertEqual(10, spec.exec_properties['folds'])
    self.assertIs(spec.inputs['input'], input_channel)
    self.assertIs(spec.outputs['output'], output_channel)

    # Verify compatibility aliasing behavior.
    self.assertIs(spec.inputs['future_input_name'], spec.inputs['input'])
    self.assertIs(spec.outputs['future_output_name'], spec.outputs['output'])

    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(
        TypeError,
        "Expected type <(class|type) 'int'> for parameter u?'folds' but got "
        'string.'):
      _BasicComponentSpec(
          folds='string', input=input_channel, output=output_channel)

    with self.assertRaisesRegex(
        TypeError,
        '.*should be a Channel of .*InputArtifact.*got (.|\\s)*Examples.*'):
      _BasicComponentSpec(
          folds=10, input=Channel(type=Examples), output=output_channel)

    with self.assertRaisesRegex(
        TypeError,
        '.*should be a Channel of .*OutputArtifact.*got (.|\\s)*Examples.*'):
      _BasicComponentSpec(
          folds=10, input=input_channel, output=Channel(type=Examples))
示例#13
0
    def testDo(self):
        """End-to-end check of the Presto executor writing 'Split-*' dirs.

        The single 'bq' query split is re-split into 'train' and 'eval' with
        2 and 1 hash buckets. The test expects one shard under 'Split-<name>'
        for each split, with the train shard larger than the eval shard.
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # A single Examples artifact receives every output split.
        examples = standard_artifacts.Examples()
        examples.uri = output_data_dir
        output_dict = {'examples': [examples]}

        # Build exec properties.
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='bq', pattern='SELECT i, f, s FROM `fake`'),
        ])
        output_config = example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=[
                example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
                example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
            ]))
        exec_properties = {
            'input_config': proto_utils.proto_to_json(input_config),
            'custom_config':
                proto_utils.proto_to_json(example_gen_pb2.CustomConfig()),
            'output_config': proto_utils.proto_to_json(output_config),
        }

        executor.Executor().Do({}, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         examples.split_names)

        # Check Presto example gen outputs.
        shard_name = 'data_tfrecord-00000-of-00001.gz'
        train_output_file = os.path.join(examples.uri, 'Split-train', shard_name)
        eval_output_file = os.path.join(examples.uri, 'Split-eval', shard_name)
        for output_file in (train_output_file, eval_output_file):
            self.assertTrue(fileio.exists(output_file))
        self.assertGreater(
            fileio.open(train_output_file).size(),
            fileio.open(eval_output_file).size())
示例#14
0
 def testConstructWithInputConfig(self):
     """Three input splits should yield one Examples artifact naming all three."""
     expected_splits = ['train', 'eval', 'test']
     input_config = example_gen_pb2.Input(splits=[
         example_gen_pb2.Input.Split(name=name, pattern='{}/*'.format(name))
         for name in expected_splits
     ])
     example_gen = TestFileBasedExampleGenComponent(
         input=channel_utils.as_channel(
             [standard_artifacts.ExternalArtifact()]),
         input_config=input_config)

     examples_channel = example_gen.outputs['examples']
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)
     artifacts = examples_channel.get()
     self.assertEqual(1, len(artifacts))
     decoded = artifact_utils.decode_split_names(artifacts[0].split_names)
     self.assertEqual(expected_splits, decoded)
示例#15
0
 def testConstructWithInputConfig(self):
   """BigQueryElwcExampleGen emits one Examples artifact covering all splits."""
   split_queries = [('train', 'query1'), ('eval', 'query2'), ('test', 'query3')]
   input_config = example_gen_pb2.Input(splits=[
       example_gen_pb2.Input.Split(name=name, pattern=query)
       for name, query in split_queries
   ])
   elwc_config = example_gen_pb2.ElwcConfig(
       context_feature_fields=['query_id', 'query_content'])
   big_query_elwc_example_gen = component.BigQueryElwcExampleGen(
       elwc_config=elwc_config, input_config=input_config)

   examples_channel = big_query_elwc_example_gen.outputs['examples']
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    examples_channel.type_name)
   artifact_collection = examples_channel.get()
   self.assertEqual(1, len(artifact_collection))
   self.assertEqual([name for name, _ in split_queries],
                    artifact_utils.decode_split_names(
                        artifact_collection[0].split_names))
示例#16
0
  def _make_example_gen(self) -> base_component.BaseComponent:
    """Returns a TFX ImportExampleGen which produces the dataset's splits.

    One input split is created per split in ``self._dataset_builder.info``,
    using that split's single file as the split pattern.

    Raises:
      ValueError: if a split is not backed by exactly one file.
    """
    splits = []
    for name, value in self._dataset_builder.info.splits.items():
      # Only single-file splits are supported: the file name becomes the
      # pattern, e.g. `'fashion_mnist-test.tfrecord-00000-of-00001'`.
      # Raise rather than `assert` so the check survives `python -O`.
      if len(value.filenames) != 1:
        raise ValueError(
            'Expected exactly one file for split %r, got %d: %r' %
            (name, len(value.filenames), value.filenames))
      splits.append(
          example_gen_pb2.Input.Split(name=name, pattern=value.filenames[0]))

    logging.info('Splits: %s', splits)
    input_config = example_gen_pb2.Input(splits=splits)
    return tfx.ImportExampleGen(
        input=external_input(self._dataset_builder.data_dir),
        input_config=input_config)
示例#17
0
    def resolve_exec_properties(
        self,
        exec_properties: Dict[Text, Any],
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, Any]:
        """Overrides BaseDriver.resolve_exec_properties().

        Resolves span, version and fingerprint for the configured input
        splits, writes the rewritten input config back, and records the three
        resolved values in ``exec_properties``.

        Args:
          exec_properties: read for the input config, input base and optional
            range config; updated in place.
          pipeline_info: unused.
          component_info: unused.

        Returns:
          The same ``exec_properties`` dict, updated.

        Raises:
          ValueError: if a RangeConfig.static_range has differing start and end
            span numbers.
        """
        del pipeline_info, component_info

        input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(
            exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
            input_config)

        input_base = exec_properties[standard_component_specs.INPUT_BASE_KEY]
        logging.debug('Processing input %s.', input_base)

        range_config = None
        range_config_json = exec_properties.get(
            standard_component_specs.RANGE_CONFIG_KEY)
        if range_config_json:
            range_config = range_config_pb2.RangeConfig()
            proto_utils.json_to_proto(range_config_json, range_config)
            if range_config.HasField('static_range'):
                static_range = range_config.static_range
                start, end = (static_range.start_span_number,
                              static_range.end_span_number)
                # ExampleGen handles a single span per run, so a StaticRange
                # must pin one exact span.
                if start != end:
                    raise ValueError(
                        'Start and end span numbers for RangeConfig.static_range must '
                        'be equal: (%s, %s)' % (start, end))

        # Note that this call also rewrites input_config.splits[*].pattern.
        fingerprint, span, version = (
            utils.calculate_splits_fingerprint_span_and_version(
                input_base, input_config.splits, range_config))

        exec_properties.update({
            standard_component_specs.INPUT_CONFIG_KEY:
                proto_utils.proto_to_json(input_config),
            utils.SPAN_PROPERTY_NAME: span,
            utils.VERSION_PROPERTY_NAME: version,
            utils.FINGERPRINT_PROPERTY_NAME: fingerprint,
        })
        return exec_properties
示例#18
0
    def testDo(self):
        """Runs the Parquet executor and checks hash-bucketed train/eval output.

        The single 'parquet' input split is re-split 2:1 into 'train'/'eval'
        by hash buckets, so the train shard is expected to be larger than the
        eval shard.
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict: one ExamplesPath artifact per split.
        train_examples = types.TfxArtifact(type_name='ExamplesPath',
                                           split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxArtifact(type_name='ExamplesPath',
                                          split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Create exec properties.
        exec_properties = {
            'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='parquet',
                                                pattern='parquet/*'),
                ])),
            'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])))
        }

        # Run executor.
        parquet_example_gen = parquet_executor.Executor()
        parquet_example_gen.Do(self._input_dict, output_dict, exec_properties)

        # Check Parquet example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        # tf.gfile is the TF1-only API (absent in TF2); tf.io.gfile works in
        # both and matches the other example-gen tests.
        self.assertTrue(tf.io.gfile.exists(train_output_file))
        self.assertTrue(tf.io.gfile.exists(eval_output_file))
        self.assertGreater(
            tf.io.gfile.GFile(train_output_file).size(),
            tf.io.gfile.GFile(eval_output_file).size())
示例#19
0
 def testConstructSubclassQueryBased(self):
   """A query-based subclass uses QueryBasedDriver and tf.Example output."""
   input_config = example_gen_pb2.Input(splits=[
       example_gen_pb2.Input.Split(name='single', pattern='query'),
   ])
   example_gen = TestQueryBasedExampleGenComponent(input_config=input_config)
   exec_properties = example_gen.exec_properties

   self.assertEqual({}, example_gen.inputs.get_all())
   self.assertEqual(driver.QueryBasedDriver, example_gen.driver_class)
   examples_channel = example_gen.outputs[standard_component_specs.EXAMPLES_KEY]
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    examples_channel.type_name)
   # No output format was passed to the constructor; tf.Example is expected.
   self.assertEqual(
       exec_properties[standard_component_specs.OUTPUT_DATA_FORMAT_KEY],
       example_gen_pb2.FORMAT_TF_EXAMPLE)
   self.assertIsNone(
       exec_properties.get(standard_component_specs.CUSTOM_CONFIG_KEY))
示例#20
0
  def setUp(self):
    super(ExecutorTest, self).setUp()
    # Test data lives in 'testdata/external', two directories above the one
    # containing this file.
    package_root = __file__
    for _ in range(3):
      package_root = os.path.dirname(package_root)
    self._input_data_dir = os.path.join(package_root, 'testdata', 'external')

    # Create values in exec_properties: a single 'tfrecord' input split,
    # re-split into train/eval with 2 and 1 hash buckets.
    input_splits = [
        example_gen_pb2.Input.Split(name='tfrecord', pattern='tfrecord/*'),
    ]
    self._input_config = proto_utils.proto_to_json(
        example_gen_pb2.Input(splits=input_splits))
    output_splits = [
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
    ]
    self._output_config = proto_utils.proto_to_json(
        example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=output_splits)))
示例#21
0
    def testExecutionParameterTypeCheck(self):
        """Exercises ExecutionParameter.type_check for several declared types.

        Covers ``int``, ``List[int]``, ``Dict[str, int]`` and a proto message
        type. Proto parameters also accept dicts, including ones with unknown
        keys, but a dict that cannot be parsed raises
        ``json_format.ParseError``.
        """
        # Plain int.
        int_param = ExecutionParameter(type=int)
        int_param.type_check('int_parameter', 8)
        with self.assertRaisesRegex(
                TypeError, "Expected type <(class|type) 'int'>"
                " for parameter u?'int_parameter'"):
            int_param.type_check('int_parameter', 'string')

        # List[int]: both the container and each item are validated.
        list_param = ExecutionParameter(type=List[int])
        for good_list in ([], [42]):
            list_param.type_check('list_parameter', good_list)
        with self.assertRaisesRegex(TypeError,
                                    'Expecting a list for parameter'):
            list_param.type_check('list_parameter', 42)
        with self.assertRaisesRegex(
                TypeError, "Expecting item type <(class|type) "
                "'int'> for parameter u?'list_parameter'"):
            list_param.type_check('list_parameter', [42, 'wrong item'])

        # Dict[str, int]: both the container and each value are validated.
        dict_param = ExecutionParameter(type=Dict[str, int])
        for good_dict in ({}, {'key1': 1, 'key2': 2}):
            dict_param.type_check('dict_parameter', good_dict)
        with self.assertRaisesRegex(TypeError,
                                    'Expecting a dict for parameter'):
            dict_param.type_check('dict_parameter', 'simple string')
        with self.assertRaisesRegex(
                TypeError, "Expecting value type "
                "<(class|type) 'int'>"):
            dict_param.type_check('dict_parameter', {'key1': '1'})

        # Proto message type: instances and parseable dicts are accepted.
        proto_param = ExecutionParameter(type=example_gen_pb2.Input)
        accepted_values = (
            example_gen_pb2.Input(),
            {'splits': [{'name': 'hello'}]},
            {'wrong_field': 42},
        )
        for accepted in accepted_values:
            proto_param.type_check('proto_parameter', accepted)
        with self.assertRaisesRegex(
                TypeError,
                "Expected type <class 'tfx.proto.example_gen_pb2.Input'>"):
            proto_param.type_check('proto_parameter', 42)
        with self.assertRaises(json_format.ParseError):
            proto_param.type_check('proto_parameter', {'splits': 42})
示例#22
0
    def resolve_exec_properties(
        self,
        exec_properties: Dict[Text, Any],
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> Dict[Text, Any]:
        """Overrides BaseDriver.resolve_exec_properties().

        Resolves the input span/version and rewrites every split pattern of
        the input config for it. Also records span, version and input
        fingerprint in `exec_properties`.

        Args:
          exec_properties: Execution properties. Must hold the JSON input
            config under INPUT_CONFIG_KEY. May hold INPUT_BASE_KEY and
            RANGE_CONFIG_KEY. Updated in place.
          pipeline_info: Unused.
          component_info: Unused.

        Returns:
          The same `exec_properties` dict, after the in-place updates.
        """
        del pipeline_info, component_info  # Unused by this driver.

        specs = standard_component_specs

        input_config = example_gen_pb2.Input()
        proto_utils.json_to_proto(exec_properties[specs.INPUT_CONFIG_KEY],
                                  input_config)

        input_base = exec_properties.get(specs.INPUT_BASE_KEY)
        logging.debug('Processing input %s.', input_base)

        # A missing or empty range config means "no range restriction".
        serialized_range = exec_properties.get(specs.RANGE_CONFIG_KEY)
        if serialized_range:
            range_config = range_config_pb2.RangeConfig()
            proto_utils.json_to_proto(serialized_range, range_config)
        else:
            range_config = None

        processor = self.get_input_processor(
            splits=input_config.splits,
            range_config=range_config,
            input_base_uri=input_base)

        span, version = processor.resolve_span_and_version()
        fingerprint = processor.get_input_fingerprint(span, version)

        # Point every split pattern at the resolved span/version.
        for split in input_config.splits:
            split.pattern = processor.get_pattern_for_span_version(
                split.pattern, span, version)

        exec_properties[specs.INPUT_CONFIG_KEY] = proto_utils.proto_to_json(
            input_config)
        exec_properties[utils.SPAN_PROPERTY_NAME] = span
        exec_properties[utils.VERSION_PROPERTY_NAME] = version
        exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint
        return exec_properties
示例#23
0
 def testConstructWithInputConfig(self):
   """PrestoExampleGen keeps conn config and exposes each split in order."""
   split_names = ['train', 'eval', 'test']
   split_queries = ['query1', 'query2', 'query3']
   input_config = example_gen_pb2.Input(splits=[
       example_gen_pb2.Input.Split(name=name, pattern=query)
       for name, query in zip(split_names, split_queries)
   ])
   example_gen = component.PrestoExampleGen(
       self.conn_config, input_config=input_config)

   custom_config = example_gen.exec_properties['custom_config']
   self.assertEqual(self.conn_config, self._extract_conn_config(custom_config))

   examples_channel = example_gen.outputs['examples']
   self.assertEqual('ExamplesPath', examples_channel.type_name)
   # Artifact i is expected to carry the i-th configured split name.
   artifacts = examples_channel.get()
   for index, expected_split in enumerate(split_names):
     self.assertEqual(expected_split, artifacts[index].split)
示例#24
0
  def setUp(self):
    super().setUp()
    # Test data lives in <grandparent of this file's directory>/testdata.
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    self.avro_dir_path = os.path.join(base_dir, 'testdata', 'external')

    # A single input split that reads every Avro file under avro/.
    self.input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='avro', pattern='avro/*.avro'),
    ])

    # Output is split into train (2 hash buckets) and eval (1 hash bucket).
    train_split = example_gen_pb2.SplitConfig.Split(
        name='train', hash_buckets=2)
    eval_split = example_gen_pb2.SplitConfig.Split(
        name='eval', hash_buckets=1)
    self.output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(
            splits=[train_split, eval_split]))
示例#25
0
  def testResolveInputArtifactsWithSpan(self):
    """Driver rejects misaligned spans and otherwise picks the latest span."""

    def write_data(span_dir, split_dir, content):
      # Writes one data file at <input_base>/<span_dir>/<split_dir>/data.
      io_utils.write_string_file(
          os.path.join(self._input_base_path, span_dir, split_dir, 'data'),
          content)

    # span02 exists only for split1, so the latest spans differ per split.
    write_data('span01', 'split1', 'testing11')
    write_data('span01', 'split2', 'testing12')
    write_data('span02', 'split1', 'testing21')

    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
    with self.assertRaisesRegex(
        ValueError, 'Latest span should be the same for each split'):
      self._example_gen_driver.resolve_input_artifacts(
          self._input_channels, self._exec_properties, None, None)

    # Once span02 exists for every split, the latest span is selected.
    write_data('span02', 'split2', 'testing22')

    self._mock_metadata.get_artifacts_by_uri.return_value = []
    self._mock_metadata.publish_artifacts.return_value = [
        metadata_store_pb2.Artifact()
    ]
    self._example_gen_driver.resolve_input_artifacts(
        self._input_channels, self._exec_properties, None, None)

    # The driver rewrites 'input_config' in exec_properties; parse it back.
    updated_input_config = example_gen_pb2.Input()
    json_format.Parse(self._exec_properties['input_config'],
                      updated_input_config)
    self.assertProtoEquals(
        """
        splits {
          name: "s1"
          pattern: "span02/split1/*"
        }
        splits {
          name: "s2"
          pattern: "span02/split2/*"
        }""", updated_input_config)
示例#26
0
    def testDriverWithSpan(self):
        """Driver rejects misaligned spans and otherwise emits the latest."""

        def write_data(span_dir, split_dir, content):
            # Writes one data file at <_TEST_INPUT_DIR>/<span>/<split>/data.
            io_utils.write_string_file(
                os.path.join(_TEST_INPUT_DIR, span_dir, split_dir, 'data'),
                content)

        # span02 exists only for split1, so the latest spans differ.
        write_data('span01', 'split1', 'testing11')
        write_data('span01', 'split2', 'testing12')
        write_data('span02', 'split1', 'testing21')

        serialized_args = [
            'driver.py', '--json_serialized_invocation_args',
            json_format.MessageToJson(message=self._executor_invocation)
        ]
        # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(
                ValueError, 'Latest span should be the same for each split'):
            driver.main(serialized_args)

        # Once span02 exists for every split, the latest span is selected.
        write_data('span02', 'split2', 'testing22')

        driver.main(serialized_args)

        # Check the output metadata file for the expected outputs.
        with open(_TEST_OUTPUT_METADATA_JSON) as output_meta_json:
            serialized_output = output_meta_json.read()
        output_metadata = pipeline_pb2.ExecutorOutput()
        json_format.Parse(serialized_output,
                          output_metadata,
                          ignore_unknown_fields=True)
        self.assertEqual(output_metadata.parameters['span'].string_value,
                         '2')
        self.assertEqual(
            output_metadata.parameters['input_config'].string_value,
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='s1',
                                                pattern='span02/split1/*'),
                    example_gen_pb2.Input.Split(name='s2',
                                                pattern='span02/split2/*')
                ])))
示例#27
0
  def setUp(self):
    super().setUp()
    # Test data lives in <grandparent of this file's directory>/testdata.
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    self.parquet_dir_path = os.path.join(base_dir, 'testdata', 'external')

    # A single input split that reads every Parquet file under parquet/.
    parquet_split = example_gen_pb2.Input.Split(
        name='parquet', pattern='parquet/*.parquet')
    self.input_config = example_gen_pb2.Input(splits=[parquet_split])

    # Output is split into train (2 hash buckets) and eval (1 hash bucket).
    output_splits = [
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
    ]
    self.output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=output_splits))
示例#28
0
 def testConstructWithInputConfig(self):
   """PrestoExampleGen keeps conn config and records all splits in one artifact."""
   split_queries = [('train', 'query1'), ('eval', 'query2'), ('test', 'query3')]
   input_config = example_gen_pb2.Input(splits=[
       example_gen_pb2.Input.Split(name=name, pattern=query)
       for name, query in split_queries
   ])
   example_gen = component.PrestoExampleGen(
       self.conn_config, input_config=input_config)

   self.assertEqual(
       self.conn_config,
       self._extract_conn_config(example_gen.exec_properties['custom_config']))

   examples = example_gen.outputs['examples']
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME, examples.type_name)
   # A single Examples artifact carries every split name, in config order.
   artifact_collection = examples.get()
   self.assertEqual(1, len(artifact_collection))
   expected_split_names = [name for name, _ in split_queries]
   self.assertEqual(
       expected_split_names,
       artifact_utils.decode_split_names(artifact_collection[0].split_names))
示例#29
0
 def testBuildFileBasedExampleGenWithInputConfig(self):
   """StepBuilder output for ImportExampleGen matches the golden pbtxt files."""
   train_split = example_gen_pb2.Input.Split(name='train', pattern='*train.tfr')
   eval_split = example_gen_pb2.Input.Split(name='eval', pattern='*test.tfr')
   example_gen = components.ImportExampleGen(
       input_base='path/to/data/root',
       input_config=example_gen_pb2.Input(splits=[train_split, eval_split]))

   # Starts empty; compared against a golden file after build().
   deployment_config = pipeline_pb2.PipelineDeploymentConfig()
   builder = step_builder.StepBuilder(
       node=example_gen,
       image='gcr.io/tensorflow/tfx:latest',
       deployment_config=deployment_config)
   step_spec = self._sole(builder.build())

   expected_step_spec = test_utils.get_proto_from_test_data(
       'expected_import_example_gen.pbtxt', pipeline_pb2.PipelineTaskSpec())
   self.assertProtoEquals(expected_step_spec, step_spec)

   expected_deployment_config = test_utils.get_proto_from_test_data(
       'expected_import_example_gen_executor.pbtxt',
       pipeline_pb2.PipelineDeploymentConfig())
   self.assertProtoEquals(expected_deployment_config, deployment_config)
示例#30
0
    def testPrestoToExample(self):
        """_PrestoToExample turns a query result row into a tf.train.Example."""
        with beam.Pipeline() as pipeline:
            exec_properties = {
                'input_config':
                proto_utils.proto_to_json(example_gen_pb2.Input()),
                'custom_config':
                proto_utils.proto_to_json(example_gen_pb2.CustomConfig()),
            }
            examples = pipeline | 'ToTFExample' >> executor._PrestoToExample(
                exec_properties=exec_properties,
                split_pattern='SELECT i, f, s FROM `fake`')

            # NOTE(review): assumes the fake Presto client set up outside this
            # view yields exactly one row with i=1, f=2.0, s='abc'.
            int_feature = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[1]))
            float_feature = tf.train.Feature(
                float_list=tf.train.FloatList(value=[2.0]))
            bytes_feature = tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[tf.compat.as_bytes('abc')]))
            expected_example = tf.train.Example(
                features=tf.train.Features(feature={
                    'i': int_feature,
                    'f': float_feature,
                    's': bytes_feature,
                }))
            util.assert_that(examples, util.equal_to([expected_example]))