示例#1
0
 def testConstructWithOutputConfig(self):
     """BigQueryExampleGen accepts an explicit 2:1 train/eval output config."""
     train_split = example_gen_pb2.SplitConfig.Split(name='train',
                                                     hash_buckets=2)
     eval_split = example_gen_pb2.SplitConfig.Split(name='eval',
                                                    hash_buckets=1)
     custom_output = example_gen_pb2.Output(
         split_config=example_gen_pb2.SplitConfig(
             splits=[train_split, eval_split]))
     example_gen = component.BigQueryExampleGen(query='query',
                                                output_config=custom_output)
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      example_gen.outputs['examples'].type_name)
 def _testFeatureBasedPartition(self, partition_feature_name):
   """Stores a 2:1 train/eval output config keyed on a partition feature.

   Args:
     partition_feature_name: Value written to
       `SplitConfig.partition_feature_name` in the serialized output config.
   """
   split_config = example_gen_pb2.SplitConfig(
       splits=[
           example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
           example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
       ],
       partition_feature_name=partition_feature_name)
   output_config_json = proto_utils.proto_to_json(
       example_gen_pb2.Output(split_config=split_config))
   self._exec_properties[
       standard_component_specs.OUTPUT_CONFIG_KEY] = output_config_json
示例#3
0
    def testDo(self, mock_client):
        """Runs the BigQuery executor end to end against a mocked client.

        Verifies three things:
        - the client is created with the project from the beam pipeline args;
        - the 'train'/'eval' split names are recorded on the output artifact;
        - each split gets a gzipped TFRecord file, and the train file (2 hash
          buckets) is larger than the eval file (1 hash bucket).
        """
        # Mock query result schema for _BigQueryConverter.
        mock_client.return_value.query.return_value.result.return_value.schema = self._schema

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        examples = standard_artifacts.Examples()
        examples.uri = output_data_dir
        output_dict = {'examples': [examples]}

        # Create exec properties.
        exec_properties = {
            'input_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, b, f, s FROM `fake`'),
                ])),
            'output_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])))
        }

        # Run executor.
        big_query_example_gen = executor.Executor(
            base_executor.BaseExecutor.Context(
                beam_pipeline_args=['--project=test-project']))
        big_query_example_gen.Do({}, output_dict, exec_properties)

        mock_client.assert_called_with(project='test-project')

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         examples.split_names)

        # Check BigQuery example gen outputs.
        train_output_file = os.path.join(examples.uri, 'Split-train',
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(examples.uri, 'Split-eval',
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(fileio.exists(train_output_file))
        self.assertTrue(fileio.exists(eval_output_file))
        # Open the files in `with` blocks so the handles used only to read
        # their sizes are closed deterministically instead of leaking.
        with fileio.open(train_output_file) as train_file:
            with fileio.open(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
示例#4
0
 def _testFeatureBasedPartition(self, partition_feature_name):
     """Sets a feature-partitioned 2:1 train/eval output config.

     Args:
       partition_feature_name: Value written to
         `SplitConfig.partition_feature_name` before the config is
         serialized with `json_format.MessageToJson`.
     """
     feature_partitioned = example_gen_pb2.SplitConfig(
         splits=[
             example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
             example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
         ],
         partition_feature_name=partition_feature_name)
     serialized = json_format.MessageToJson(
         example_gen_pb2.Output(split_config=feature_partitioned))
     self._exec_properties[utils.OUTPUT_CONFIG_KEY] = serialized
示例#5
0
  def testDo(self, mock_client):
    """Runs the BigQuery executor end to end against a mocked client.

    Verifies that the 'train'/'eval' split names are recorded on the output
    artifact, and that each split gets a gzipped TFRecord file. The train
    file (2 hash buckets) must be larger than the eval file (1 hash bucket).
    """
    # Mock query result schema for _BigQueryConverter.
    mock_client.return_value.query.return_value.result.return_value.schema = self._schema

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {'examples': [examples]}

    # Create exec properties.
    exec_properties = {
        'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, b, f, s FROM `fake`'),
                ]),
                preserving_proto_field_name=True),
        'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ])),
                preserving_proto_field_name=True)
    }

    # Run executor.
    big_query_example_gen = executor.Executor()
    big_query_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        examples.split_names)

    # Check BigQuery example gen outputs.
    train_output_file = os.path.join(examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(tf.io.gfile.exists(train_output_file))
    self.assertTrue(tf.io.gfile.exists(eval_output_file))
    # GFile handles are opened only to query their size; close them via
    # context managers instead of leaving them to the garbage collector.
    with tf.io.gfile.GFile(train_output_file) as train_file:
      with tf.io.gfile.GFile(eval_output_file) as eval_file:
        self.assertGreater(train_file.size(), eval_file.size())
示例#6
0
def make_default_output_config(input_config: example_gen_pb2.Input
                              ) -> example_gen_pb2.Output:
  """Derives the default output config from an input config.

  Args:
    input_config: The ExampleGen input config whose `splits` are inspected.

  Returns:
    If the input defines at most one split, an Output whose SplitConfig holds
    a 'train' split (2 hash buckets) and an 'eval' split (1 hash bucket),
    i.e. a 2:1 ratio. Otherwise an empty Output, so the output splits stay
    the same as the input splits.
  """
  if len(input_config.splits) <= 1:
    two_to_one = [
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
    ]
    return example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=two_to_one))
  # Several input splits: leave split_config unset so outputs mirror inputs.
  return example_gen_pb2.Output()
示例#7
0
  def testDo(self):
    """Runs the Avro executor over the 'avro/*.avro' test inputs.

    Verifies that the 'train'/'eval' split names are recorded on the output
    artifact, and that each split gets a gzipped TFRecord file. The train
    file (2 hash buckets) must be larger than the eval file (1 hash bucket).
    """
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {utils.EXAMPLES_KEY: [examples]}

    # Create exec properties.
    exec_properties = {
        utils.INPUT_BASE_KEY:
            self._input_data_dir,
        utils.INPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='avro', pattern='avro/*.avro'),
                ]),
                preserving_proto_field_name=True),
        utils.OUTPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ])),
                preserving_proto_field_name=True)
    }

    # Run executor.
    avro_example_gen = avro_executor.Executor()
    avro_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        examples.split_names)

    # Check Avro example gen outputs.
    train_output_file = os.path.join(examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(fileio.exists(train_output_file))
    self.assertTrue(fileio.exists(eval_output_file))
    # Close the size-probing handles deterministically instead of leaking them.
    with fileio.open(train_output_file) as train_file:
      with fileio.open(eval_output_file) as eval_file:
        self.assertGreater(train_file.size(), eval_file.size())
示例#8
0
    def testDo(self, mock_client):
        """Runs the BigQuery executor with one output artifact per split.

        Verifies that each split artifact's directory receives a gzipped
        TFRecord file. The train file (2 hash buckets) must be larger than
        the eval file (1 hash bucket).
        """
        # Mock query result schema for _BigQueryConverter.
        mock_client.return_value.query.return_value.result.return_value.schema = self._schema

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        train_examples = types.TfxArtifact(type_name='ExamplesPath',
                                           split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxArtifact(type_name='ExamplesPath',
                                          split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Create exec properties.
        exec_properties = {
            'input':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, f, s FROM `fake`'),
                ])),
            'output':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])))
        }

        # Run executor.
        big_query_example_gen = executor.Executor()
        big_query_example_gen.Do({}, output_dict, exec_properties)

        # Check BigQuery example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        # Close the GFile handles opened only to read their sizes.
        with tf.gfile.GFile(train_output_file) as train_file:
            with tf.gfile.GFile(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
示例#9
0
    def testDo(self):
        """Runs the Presto executor with one output artifact per split.

        Verifies that each split artifact's directory receives a gzipped
        TFRecord file. The train file (2 hash buckets) must be larger than
        the eval file (1 hash bucket).
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        train_examples = types.TfxArtifact(type_name='ExamplesPath',
                                           split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxArtifact(type_name='ExamplesPath',
                                          split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Create exec properties.
        exec_properties = {
            'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, f, s FROM `fake`'),
                ])),
            'custom_config':
            json_format.MessageToJson(example_gen_pb2.CustomConfig()),
            'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ]))),
        }

        # Run executor.
        presto_example_gen = executor.Executor()
        presto_example_gen.Do({}, output_dict, exec_properties)

        # Check Presto example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        # Close the GFile handles opened only to read their sizes.
        with tf.gfile.GFile(train_output_file) as train_file:
            with tf.gfile.GFile(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
示例#10
0
 def test_construct_with_output_config(self):
   """Output artifacts follow the train/eval/test order of output_config."""
   split_names = ['train', 'eval', 'test']
   bucket_counts = [2, 1, 1]
   splits = [
       example_gen_pb2.SplitConfig.Split(name=name, hash_buckets=count)
       for name, count in zip(split_names, bucket_counts)
   ]
   input_base = types.TfxArtifact(type_name='ExternalPath')
   example_gen = TestFileBasedExampleGenComponent(
       input_base=channel.as_channel([input_base]),
       output_config=example_gen_pb2.Output(
           split_config=example_gen_pb2.SplitConfig(splits=splits)))
   examples_channel = example_gen.outputs.examples
   self.assertEqual('ExamplesPath', examples_channel.type_name)
   artifacts = examples_channel.get()
   for position, expected_name in enumerate(split_names):
     self.assertEqual(expected_name, artifacts[position].split)
示例#11
0
 def testConstructWithOutputConfig(self):
   """Each configured output split maps to an artifact tagged with its name."""
   ordered_splits = (('train', 2), ('eval', 1), ('test', 1))
   output_config = example_gen_pb2.Output(
       split_config=example_gen_pb2.SplitConfig(splits=[
           example_gen_pb2.SplitConfig.Split(name=name, hash_buckets=buckets)
           for name, buckets in ordered_splits
       ]))
   external_input = standard_artifacts.ExternalArtifact()
   example_gen = TestFileBasedExampleGenComponent(
       input_base=channel_utils.as_channel([external_input]),
       output_config=output_config)
   examples_channel = example_gen.outputs['examples']
   self.assertEqual('ExamplesPath', examples_channel.type_name)
   artifacts = examples_channel.get()
   for index, (expected_name, _) in enumerate(ordered_splits):
     self.assertEqual(expected_name, artifacts[index].split)
示例#12
0
 def test_construct_with_output_config(self):
   """BigQueryExampleGen exposes one split-tagged artifact per output split."""
   expected = ('train', 'eval', 'test')
   split_config = example_gen_pb2.SplitConfig(splits=[
       example_gen_pb2.SplitConfig.Split(name=name, hash_buckets=buckets)
       for name, buckets in zip(expected, (2, 1, 1))
   ])
   big_query_example_gen = component.BigQueryExampleGen(
       query='',
       output_config=example_gen_pb2.Output(split_config=split_config))
   examples_channel = big_query_example_gen.outputs.examples
   self.assertEqual('ExamplesPath', examples_channel.type_name)
   artifact_collection = examples_channel.get()
   for index, split_name in enumerate(expected):
     self.assertEqual(split_name, artifact_collection[index].split)
示例#13
0
 def testConstructWithOutputConfig(self):
     """BigQueryToElwcExampleGen accepts an ELWC config plus 3-way splits."""
     elwc_config = elwc_config_pb2.ElwcConfig(
         context_feature_fields=['query_id', 'query_content'])
     three_way_splits = example_gen_pb2.SplitConfig(splits=[
         example_gen_pb2.SplitConfig.Split(name=split_name, hash_buckets=count)
         for split_name, count in (('train', 2), ('eval', 1), ('test', 1))
     ])
     elwc_gen = component.BigQueryToElwcExampleGen(
         query='query',
         elwc_config=elwc_config,
         output_config=example_gen_pb2.Output(split_config=three_way_splits))
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      elwc_gen.outputs['examples'].type_name)
示例#14
0
    def testDo(self):
        """Runs the Presto executor with a single multi-split Examples output.

        Verifies that the 'train'/'eval' split names are recorded on the
        artifact, and that each 'Split-<name>' directory gets a gzipped
        TFRecord file. The train file (2 hash buckets) must be larger than
        the eval file (1 hash bucket).
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        examples = standard_artifacts.Examples()
        examples.uri = output_data_dir
        output_dict = {'examples': [examples]}

        # Create exec properties.
        exec_properties = {
            'input_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, f, s FROM `fake`'),
                ])),
            'custom_config':
            proto_utils.proto_to_json(example_gen_pb2.CustomConfig()),
            'output_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ]))),
        }

        # Run executor.
        presto_example_gen = executor.Executor()
        presto_example_gen.Do({}, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         examples.split_names)

        # Check Presto example gen outputs.
        train_output_file = os.path.join(examples.uri, 'Split-train',
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(examples.uri, 'Split-eval',
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(fileio.exists(train_output_file))
        self.assertTrue(fileio.exists(eval_output_file))
        # Close the size-probing handles deterministically instead of leaking.
        with fileio.open(train_output_file) as train_file:
            with fileio.open(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
示例#15
0
    def testDo(self):
        """Runs the Parquet executor over the 'parquet/*' test inputs.

        Verifies that the 'train'/'eval' split names are recorded on the
        artifact, and that each 'Split-<name>' directory gets a gzipped
        TFRecord file. The train file (2 hash buckets) must be larger than
        the eval file (1 hash bucket).
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        examples = standard_artifacts.Examples()
        examples.uri = output_data_dir
        output_dict = {standard_component_specs.EXAMPLES_KEY: [examples]}

        # Create exec properties.
        exec_properties = {
            standard_component_specs.INPUT_BASE_KEY:
            self._input_data_dir,
            standard_component_specs.INPUT_CONFIG_KEY:
            proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='parquet',
                                                pattern='parquet/*'),
                ])),
            standard_component_specs.OUTPUT_CONFIG_KEY:
            proto_utils.proto_to_json(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])))
        }

        # Run executor.
        parquet_example_gen = parquet_executor.Executor()
        parquet_example_gen.Do({}, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         examples.split_names)

        # Check Parquet example gen outputs.
        train_output_file = os.path.join(examples.uri, 'Split-train',
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(examples.uri, 'Split-eval',
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(fileio.exists(train_output_file))
        self.assertTrue(fileio.exists(eval_output_file))
        # Close the size-probing handles deterministically instead of leaking.
        with fileio.open(train_output_file) as train_file:
            with fileio.open(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
示例#16
0
 def testConstructWithOutputConfig(self):
   """A single Examples artifact carries all train/eval/test split names."""
   expected_split_names = ['train', 'eval', 'test']
   split_config = example_gen_pb2.SplitConfig(splits=[
       example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
       example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
       example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=1),
   ])
   big_query_example_gen = component.BigQueryExampleGen(
       query='query',
       output_config=example_gen_pb2.Output(split_config=split_config))
   examples_channel = big_query_example_gen.outputs['examples']
   self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                    examples_channel.type_name)
   artifacts = examples_channel.get()
   self.assertEqual(1, len(artifacts))
   decoded = artifact_utils.decode_split_names(artifacts[0].split_names)
   self.assertEqual(expected_split_names, decoded)
示例#17
0
 def testConstructWithOutputConfig(self):
     """File-based ExampleGen yields one artifact listing all output splits."""
     requested = [('train', 2), ('eval', 1), ('test', 1)]
     output_config = example_gen_pb2.Output(
         split_config=example_gen_pb2.SplitConfig(splits=[
             example_gen_pb2.SplitConfig.Split(name=name, hash_buckets=count)
             for name, count in requested
         ]))
     example_gen = TestFileBasedExampleGenComponent(
         input_base='path', output_config=output_config)
     examples_channel = example_gen.outputs['examples']
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      examples_channel.type_name)
     artifacts = examples_channel.get()
     self.assertEqual(1, len(artifacts))
     self.assertEqual(
         ['train', 'eval', 'test'],
         artifact_utils.decode_split_names(artifacts[0].split_names))
示例#18
0
  def setUp(self):
    """Locates the external test data and serializes default configs."""
    super(ExecutorTest, self).setUp()
    # Walk three directory levels up from this file to reach 'testdata'.
    package_dir = __file__
    for _ in range(3):
      package_dir = os.path.dirname(package_dir)
    self._input_data_dir = os.path.join(package_dir, 'testdata', 'external')

    # Create values in exec_properties: a single 'tfrecord' input split and a
    # 2:1 train/eval output split.
    tfrecord_split = example_gen_pb2.Input.Split(
        name='tfrecord', pattern='tfrecord/*')
    self._input_config = proto_utils.proto_to_json(
        example_gen_pb2.Input(splits=[tfrecord_split]))
    output_splits = [
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
    ]
    self._output_config = proto_utils.proto_to_json(
        example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=output_splits)))
示例#19
0
    def testConstructWithOutputConfig(self):
        """The output config survives the JSON round trip in exec_properties."""
        split_specs = (('train', 2), ('eval', 1), ('test', 1))
        output_config = example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=[
                example_gen_pb2.SplitConfig.Split(name=name,
                                                  hash_buckets=buckets)
                for name, buckets in split_specs
            ]))
        example_gen = TestFileBasedExampleGenComponent(
            input_base='path', output_config=output_config)
        self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                         example_gen.outputs['examples'].type_name)

        # Parse the stored JSON back into a fresh message and compare.
        round_tripped = example_gen_pb2.Output()
        json_format.Parse(example_gen.exec_properties['output_config'],
                          round_tripped)
        self.assertEqual(output_config, round_tripped)
示例#20
0
    def testDo(self):
        """Runs the Avro executor with one output artifact per split.

        Verifies that each split artifact's directory receives a gzipped
        TFRecord file. The train file (2 hash buckets) must be larger than
        the eval file (1 hash bucket).
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        train_examples = standard_artifacts.Examples(split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = standard_artifacts.Examples(split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Create exec properties.
        exec_properties = {
            'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='avro',
                                                pattern='avro/*.avro'),
                ])),
            'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])))
        }

        # Run executor.
        avro_example_gen = avro_executor.Executor()
        avro_example_gen.Do(self._input_dict, output_dict, exec_properties)

        # Check Avro example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        # Close the GFile handles opened only to read their sizes.
        with tf.gfile.GFile(train_output_file) as train_file:
            with tf.gfile.GFile(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
示例#21
0
    def testDo(self):
        """Runs the Parquet executor with pre-declared train/eval splits.

        Verifies that each split subdirectory receives a gzipped TFRecord
        file. The train file (2 hash buckets) must be larger than the eval
        file (1 hash bucket).
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        examples = standard_artifacts.Examples()
        examples.uri = output_data_dir
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        output_dict = {'examples': [examples]}

        # Create exec properties.
        exec_properties = {
            'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='parquet',
                                                pattern='parquet/*'),
                ]),
                preserving_proto_field_name=True),
            'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])),
                preserving_proto_field_name=True)
        }

        # Run executor.
        parquet_example_gen = parquet_executor.Executor()
        parquet_example_gen.Do(self._input_dict, output_dict, exec_properties)

        # Check Parquet example gen outputs.
        train_output_file = os.path.join(examples.uri, 'train',
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(examples.uri, 'eval',
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.io.gfile.exists(train_output_file))
        self.assertTrue(tf.io.gfile.exists(eval_output_file))
        # Close the GFile handles opened only to read their sizes.
        with tf.io.gfile.GFile(train_output_file) as train_file:
            with tf.io.gfile.GFile(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
示例#22
0
  def setUp(self):
    """Builds the Avro input location plus input/output config messages."""
    super(ExampleGenComponentWithAvroExecutorTest, self).setUp()
    # Create input_base: 'testdata/external' three directories above here.
    root_dir = __file__
    for _ in range(3):
      root_dir = os.path.dirname(root_dir)
    self.avro_dir_path = os.path.join(root_dir, 'testdata', 'external')

    # Create input_config: a single split matching every .avro file.
    avro_split = example_gen_pb2.Input.Split(name='avro', pattern='avro/*.avro')
    self.input_config = example_gen_pb2.Input(splits=[avro_split])

    # Create output_config: train/eval hashed into buckets at a 2:1 ratio.
    hashed_splits = [
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
    ]
    self.output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=hashed_splits))
示例#23
0
  def testEmptyFeature(self):
    """Partitioning on feature 'i' makes the executor raise RuntimeError.

    The expected error message is 'Partition feature does not contain any
    value.'
    """
    # Add output config to exec properties.
    self._exec_properties['output_config'] = json_format.MessageToJson(
        example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(
                splits=[
                    example_gen_pb2.SplitConfig.Split(
                        name='train', hash_buckets=2),
                    example_gen_pb2.SplitConfig.Split(
                        name='eval', hash_buckets=1)
                ],
                partition_feature_name='i')))

    # Run executor.
    example_gen = TestExampleGenExecutor()
    # Use assertRaisesRegex: the assertRaisesRegexp alias was deprecated and
    # removed in Python 3.12.
    with self.assertRaisesRegex(
        RuntimeError, 'Partition feature does not contain any value.'):
      example_gen.Do({}, self._output_dict, self._exec_properties)
示例#24
0
  def testConstructWithOutputConfig(self):
    """The output config is preserved through exec_properties serialization."""
    split_plan = (('train', 2), ('eval', 1), ('test', 1))
    output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name=name, hash_buckets=buckets)
            for name, buckets in split_plan
        ]))
    example_gen = TestFileBasedExampleGenComponent(
        input_base='path', output_config=output_config)
    examples_channel = example_gen.outputs[standard_component_specs.EXAMPLES_KEY]
    self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                     examples_channel.type_name)

    # Decode the stored JSON and compare it with the original message.
    stored_json = example_gen.exec_properties[
        standard_component_specs.OUTPUT_CONFIG_KEY]
    decoded_output_config = example_gen_pb2.Output()
    proto_utils.json_to_proto(stored_json, decoded_output_config)
    self.assertEqual(output_config, decoded_output_config)
示例#25
0
    def testInvalidFeatureName(self):
        """Partitioning on a nonexistent feature raises RuntimeError.

        The config uses partition_feature_name='invalid'; the raised message
        must match 'Feature name `.*` does not exist.'
        """
        # Add output config to exec properties.
        self._exec_properties[
            utils.OUTPUT_CONFIG_KEY] = json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(
                        splits=[
                            example_gen_pb2.SplitConfig.Split(name='train',
                                                              hash_buckets=2),
                            example_gen_pb2.SplitConfig.Split(name='eval',
                                                              hash_buckets=1)
                        ],
                        partition_feature_name='invalid')))

        # Run executor.
        example_gen = TestExampleGenExecutor()
        # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(RuntimeError,
                                    'Feature name `.*` does not exist.'):
            example_gen.Do({}, self._output_dict, self._exec_properties)
示例#26
0
  def testMakeOutputSplitNames(self):
    """Split names come from the input splits or the output split config."""
    # Several input splits with an empty output config.
    multi_split_input = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='train/*'),
        example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
    ])
    names_from_input = utils.generate_output_split_names(
        input_config=multi_split_input,
        output_config=example_gen_pb2.Output())
    self.assertListEqual(['train', 'eval'], names_from_input)

    # One input split with an explicit 2:1 train/eval output split config.
    single_split_input = example_gen_pb2.Input(
        splits=[example_gen_pb2.Input.Split(name='single', pattern='single/*')])
    hashed_output = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
        ]))
    names_from_output = utils.generate_output_split_names(
        input_config=single_split_input, output_config=hashed_output)
    self.assertListEqual(['train', 'eval'], names_from_output)
示例#27
0
def make_default_output_config(
    input_config: Union[example_gen_pb2.Input, Dict[Text, Any]]
) -> example_gen_pb2.Output:
    """Returns default output config based on input config.

    Args:
      input_config: An `example_gen_pb2.Input` proto, or a dict form of it.
        Only the number of entries under 'splits' is inspected.

    Returns:
      An empty `example_gen_pb2.Output` when the input has more than one
      split; the output splits are then the same as the input splits.
      Otherwise, an output config with 'train' and 'eval' hash-bucket splits
      in a 2:1 ratio.
    """
    if isinstance(input_config, example_gen_pb2.Input):
        # Read the repeated field directly instead of round-tripping through
        # json_format.MessageToDict. Its `including_default_value_fields`
        # keyword was removed in newer protobuf releases.
        num_splits = len(input_config.splits)
    else:
        # A dict without a 'splits' key is treated like an Input with no
        # splits, matching the proto branch above.
        num_splits = len(input_config.get('splits', []))

    if num_splits > 1:
        # Returns empty output split config as output split will be same as input.
        return example_gen_pb2.Output()

    # Returns 'train' and 'eval' splits with size 2:1.
    return example_gen_pb2.Output(split_config=example_gen_pb2.SplitConfig(
        splits=[
            example_gen_pb2.SplitConfig.Split(name='train',
                                              hash_buckets=2),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
        ]))
示例#28
0
  def setUp(self):
    """Builds the Parquet fixtures shared by the tests.

    Sets `parquet_dir_path` to the 'external' folder under the 'testdata'
    directory three levels above this file. Builds an input config with one
    'parquet' split and an output config with 'train'/'eval' hash-bucket
    splits in a 2:1 ratio.
    """
    super().setUp()
    # Climb three directory levels from this file to reach the testdata parent.
    package_root = __file__
    for _ in range(3):
      package_root = os.path.dirname(package_root)
    testdata_dir = os.path.join(package_root, 'testdata')
    self.parquet_dir_path = os.path.join(testdata_dir, 'external')

    parquet_split = example_gen_pb2.Input.Split(
        name='parquet', pattern='parquet/*.parquet')
    self.input_config = example_gen_pb2.Input(splits=[parquet_split])

    train_split = example_gen_pb2.SplitConfig.Split(
        name='train', hash_buckets=2)
    eval_split = example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
    self.output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(
            splits=[train_split, eval_split]))
示例#29
0
 def testConstructWithOutputConfig(self):
   """Checks PrestoExampleGen construction with an explicit output config.

   Verifies three things: the connection config round-trips through
   'custom_config', the 'examples' output has the expected type name, and the
   output artifacts carry the configured split names in order.
   """
   # (split name, hash buckets) for each configured output split.
   split_specs = [('train', 2), ('eval', 1), ('test', 1)]
   output_config = example_gen_pb2.Output(
       split_config=example_gen_pb2.SplitConfig(splits=[
           example_gen_pb2.SplitConfig.Split(name=name, hash_buckets=buckets)
           for name, buckets in split_specs
       ]))
   presto_example_gen = component.PrestoExampleGen(
       self.conn_config, query='query', output_config=output_config)

   stored_conn_config = self._extract_conn_config(
       presto_example_gen.exec_properties['custom_config'])
   self.assertEqual(self.conn_config, stored_conn_config)

   examples_channel = presto_example_gen.outputs['examples']
   self.assertEqual('ExamplesPath', examples_channel.type_name)
   artifacts = examples_channel.get()
   # Index explicitly so a missing artifact still raises, rather than being
   # silently skipped.
   for index, (name, _) in enumerate(split_specs):
     self.assertEqual(name, artifacts[index].split)
示例#30
0
  def testInvalidFloatListFeature(self):
    """Checks that partitioning on feature 'f' makes Do() raise.

    The expected error says only `bytes_list` and `int64_list` features can
    be used for partitioning. Per the test name, 'f' is presumably a
    float-list feature in the test data -- confirm against the fixture.
    """
    # Add output config to exec properties, partitioning on feature 'f'.
    self._exec_properties['output_config'] = json_format.MessageToJson(
        example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(
                splits=[
                    example_gen_pb2.SplitConfig.Split(
                        name='train', hash_buckets=2),
                    example_gen_pb2.SplitConfig.Split(
                        name='eval', hash_buckets=1)
                ],
                partition_feature_name='f')))
    self._exec_properties['has_empty'] = False

    # Run executor. assertRaisesRegex replaces the deprecated
    # assertRaisesRegexp alias, which was removed in Python 3.12.
    example_gen = TestExampleGenExecutor()
    with self.assertRaisesRegex(
        RuntimeError,
        'Only `bytes_list` and `int64_list` features are supported for partition.'
    ):
      example_gen.Do({}, self._output_dict, self._exec_properties)