Exemplo n.º 1
0
  def testDo(self):
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {'examples': [examples]}

    # Create exe properties.
    exec_properties = {
        'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, f, s FROM `fake`'),
                ]),
                preserving_proto_field_name=True),
        'custom_config':
            json_format.MessageToJson(example_gen_pb2.CustomConfig()),
        'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ]))),
    }

    # Run executor.
    presto_example_gen = executor.Executor()
    presto_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        examples.split_names)

    # Check Presto example gen outputs.
    train_output_file = os.path.join(examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(tf.io.gfile.exists(train_output_file))
    self.assertTrue(tf.io.gfile.exists(eval_output_file))
    self.assertGreater(
        tf.io.gfile.GFile(train_output_file).size(),
        tf.io.gfile.GFile(eval_output_file).size())
Exemplo n.º 2
0
    def testDo(self):
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        train_examples = types.TfxArtifact(type_name='ExamplesPath',
                                           split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxArtifact(type_name='ExamplesPath',
                                          split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Create exe properties.
        exec_properties = {
            'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, f, s FROM `fake`'),
                ])),
            'custom_config':
            json_format.MessageToJson(example_gen_pb2.CustomConfig()),
            'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ]))),
        }

        # Run executor.
        presto_example_gen = executor.Executor()
        presto_example_gen.Do({}, output_dict, exec_properties)

        # Check Presto example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        self.assertGreater(
            tf.gfile.GFile(train_output_file).size(),
            tf.gfile.GFile(eval_output_file).size())