コード例 #1
0
  def testDo(self):
    """Runs the Avro executor and verifies the generated train/eval outputs.

    The executor is invoked with an empty input dict, one `Examples` output
    artifact, and JSON-serialized input/output configs. The test then checks
    the split names recorded on the artifact, that one gzipped TFRecord shard
    exists per split, and that the train shard is larger than the eval shard.
    """
    # Prefer the undeclared-outputs directory from the environment when set;
    # otherwise fall back to the test case's temp directory.
    base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
    output_data_dir = os.path.join(base_dir, self._testMethodName)

    # Output dict: a single Examples artifact rooted at the output directory.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {utils.EXAMPLES_KEY: [examples]}

    # Exec properties: input base dir plus JSON-encoded input/output configs.
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='avro', pattern='avro/*.avro'),
    ])
    output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
        ]))
    exec_properties = {
        utils.INPUT_BASE_KEY: self._input_data_dir,
        utils.INPUT_CONFIG_KEY: json_format.MessageToJson(
            input_config, preserving_proto_field_name=True),
        utils.OUTPUT_CONFIG_KEY: json_format.MessageToJson(
            output_config, preserving_proto_field_name=True),
    }

    # Run the executor under test.
    executor = avro_executor.Executor()
    executor.Do({}, output_dict, exec_properties)

    # The artifact must record both configured splits.
    expected_split_names = artifact_utils.encode_split_names(['train', 'eval'])
    self.assertEqual(expected_split_names, examples.split_names)

    # Each split directory is expected to hold exactly this shard file.
    shard_name = 'data_tfrecord-00000-of-00001.gz'
    train_output_file = os.path.join(examples.uri, 'train', shard_name)
    eval_output_file = os.path.join(examples.uri, 'eval', shard_name)
    for output_file in (train_output_file, eval_output_file):
      self.assertTrue(fileio.exists(output_file))

    # The train shard should be larger than the eval shard.
    train_size = fileio.open(train_output_file).size()
    eval_size = fileio.open(eval_output_file).size()
    self.assertGreater(train_size, eval_size)
コード例 #2
0
    def testDo(self):
        """Runs the Avro executor and verifies the per-split output files.

        Two `Examples` artifacts (train and eval) are placed under a
        per-test output directory, the executor is run with JSON-serialized
        input/output configs, and the test checks that each artifact's
        directory holds one gzipped TFRecord shard, with the train shard
        larger than the eval shard.
        """
        # Prefer the undeclared-outputs directory from the environment when
        # set; otherwise fall back to the test case's temp directory.
        base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
        output_data_dir = os.path.join(base_dir, self._testMethodName)

        # Output dict: one Examples artifact per split, each in its own
        # subdirectory of the output directory.
        train_examples = standard_artifacts.Examples(split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = standard_artifacts.Examples(split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Exec properties: JSON-encoded input and output configs.
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='avro', pattern='avro/*.avro'),
        ])
        split_config = example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
        ])
        output_config = example_gen_pb2.Output(split_config=split_config)
        exec_properties = {
            'input_config': json_format.MessageToJson(input_config),
            'output_config': json_format.MessageToJson(output_config),
        }

        # Run the executor under test.
        executor = avro_executor.Executor()
        executor.Do(self._input_dict, output_dict, exec_properties)

        # Each artifact directory is expected to hold exactly this shard file.
        shard_name = 'data_tfrecord-00000-of-00001.gz'
        train_output_file = os.path.join(train_examples.uri, shard_name)
        eval_output_file = os.path.join(eval_examples.uri, shard_name)
        for output_file in (train_output_file, eval_output_file):
            self.assertTrue(tf.gfile.Exists(output_file))

        # The train shard should be larger than the eval shard.
        train_size = tf.gfile.GFile(train_output_file).size()
        eval_size = tf.gfile.GFile(eval_output_file).size()
        self.assertGreater(train_size, eval_size)