def testDo(self):
    """Runs the Avro ExampleGen executor and verifies its train/eval outputs.

    The executor is configured with input base ``self._input_data_dir`` and
    a single input split pattern ``avro/*.avro``. The output config
    hash-splits the records into 'train' (2 buckets) and 'eval' (1 bucket).
    The test checks three things:
    - the resulting split names on the output artifact;
    - that each split's ``data_tfrecord-00000-of-00001.gz`` file exists;
    - that the train file is larger than the eval file.
    """
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {utils.EXAMPLES_KEY: [examples]}

    # Create exec properties.
    exec_properties = {
        utils.INPUT_BASE_KEY:
            self._input_data_dir,
        utils.INPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='avro', pattern='avro/*.avro'),
                ]),
                preserving_proto_field_name=True),
        utils.OUTPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ])),
                preserving_proto_field_name=True)
    }

    # Run executor.
    avro_example_gen = avro_executor.Executor()
    avro_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        examples.split_names)

    # Check Avro example gen outputs.
    train_output_file = os.path.join(examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(fileio.exists(train_output_file))
    self.assertTrue(fileio.exists(eval_output_file))

    # Read sizes inside `with` blocks so the file handles are closed rather
    # than leaked, as the previous inline `fileio.open(...).size()` calls did.
    with fileio.open(train_output_file) as train_file:
        train_size = train_file.size()
    with fileio.open(eval_output_file) as eval_file:
        eval_size = eval_file.size()
    # 'train' was assigned 2 hash buckets vs. 1 for 'eval', so it should
    # receive more records and therefore produce the larger file.
    self.assertGreater(train_size, eval_size)
def testDo(self):
    """Runs the Avro ExampleGen executor into per-split output artifacts.

    Separate 'train' and 'eval' ``Examples`` artifacts are created under the
    test's output directory. The executor is run with ``self._input_dict``,
    an input split pattern of ``avro/*.avro``, and an output config that
    hash-splits 2:1 between train and eval. The test then checks that both
    ``data_tfrecord-00000-of-00001.gz`` files exist and that the train file
    is larger than the eval file.
    """
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    train_examples = standard_artifacts.Examples(split='train')
    train_examples.uri = os.path.join(output_data_dir, 'train')
    eval_examples = standard_artifacts.Examples(split='eval')
    eval_examples.uri = os.path.join(output_data_dir, 'eval')
    output_dict = {'examples': [train_examples, eval_examples]}

    # Create exec properties.
    exec_properties = {
        'input_config':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='avro', pattern='avro/*.avro'),
                ])),
        'output_config':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ])))
    }

    # Run executor.
    avro_example_gen = avro_executor.Executor()
    avro_example_gen.Do(self._input_dict, output_dict, exec_properties)

    # Check Avro example gen outputs.
    train_output_file = os.path.join(train_examples.uri,
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(eval_examples.uri,
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(tf.gfile.Exists(train_output_file))
    self.assertTrue(tf.gfile.Exists(eval_output_file))

    # Read sizes inside `with` blocks so the GFile handles are closed rather
    # than leaked, as the previous inline `tf.gfile.GFile(...).size()` did.
    with tf.gfile.GFile(train_output_file) as train_file:
        train_size = train_file.size()
    with tf.gfile.GFile(eval_output_file) as eval_file:
        eval_size = eval_file.size()
    # 'train' was assigned 2 hash buckets vs. 1 for 'eval', so it should
    # receive more records and therefore produce the larger file.
    self.assertGreater(train_size, eval_size)