def setUp(self):
    """Set up a per-test scratch directory and a driver using mock metadata."""
    super(DriverTest, self).setUp()

    # Use TEST_UNDECLARED_OUTPUTS_DIR when it is set in the environment and
    # fall back to the test case's temp dir otherwise. The fallback is
    # evaluated eagerly, so get_temp_dir() is called either way.
    output_root = os.environ.get(
        'TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir())
    # Key the directory on the test method name.
    self._test_dir = os.path.join(output_root, self._testMethodName)

    # The driver under test is constructed on top of a mocked metadata object.
    self._mock_metadata = tf.compat.v1.test.mock.Mock()
    self._file_based_driver = driver.FileBasedDriver(self._mock_metadata)
def testDriverRunFn(self):
    """Run FileBasedDriver on a one-span input and check what it returns.

    One data file per split is written under a `span01` directory. The test
    then asserts that the returned exec properties carry span 1, an input
    config whose `{SPAN:2}` placeholders read `01`, and a per-split
    fingerprint; and that the single output Examples artifact has the same
    URI as the artifact passed in, plus span and fingerprint custom
    properties.
    """
    examples_key = standard_component_specs.EXAMPLES_KEY
    input_config_key = standard_component_specs.INPUT_CONFIG_KEY
    # One file of 9 characters ('testing11' / 'testing12') is written per
    # split below; the same pattern is checked on the exec properties and on
    # the output artifact.
    fingerprint_regex = (
        r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
        r'\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
    )
    # Text-format proto expected after the driver has processed the config.
    expected_input_config = (
        """ splits { name: "s1" pattern: "span01/split1/*" }"""
        """ splits { name: "s2" pattern: "span01/split2/*" }""")

    # Create the input directory and fake the outputs of a previous run.
    self._input_base_path = os.path.join(self._test_dir, 'input_base')
    fileio.makedirs(self._input_base_path)
    for split_dir, payload in (('split1', 'testing11'),
                               ('split2', 'testing12')):
        data_path = os.path.join(
            self._input_base_path, 'span01', split_dir, 'data')
        io_utils.write_string_file(data_path, payload)

    ir_driver = driver.FileBasedDriver(self._mock_metadata)

    # Output dict: a single Examples artifact whose URI is set up front so it
    # can be compared against the driver's output afterwards.
    example = standard_artifacts.Examples()
    example.uri = 'my_uri'
    output_dic = {examples_key: [example]}

    # Exec properties: the input base path plus a JSON-serialized input
    # config containing two splits with a {SPAN:2} placeholder each.
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(
            name='s1', pattern='span{SPAN:2}/split1/*'),
        example_gen_pb2.Input.Split(
            name='s2', pattern='span{SPAN:2}/split2/*'),
    ])
    exec_properties = {
        standard_component_specs.INPUT_BASE_KEY: self._input_base_path,
        input_config_key: proto_utils.proto_to_json(input_config),
    }

    execution_info = portable_data_types.ExecutionInfo(
        output_dict=output_dic, exec_properties=exec_properties)
    result = ir_driver.run(execution_info)

    # Check the exec properties returned by the driver.
    result_properties = result.exec_properties
    self.assertEqual(result_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
    updated_input_config = example_gen_pb2.Input()
    proto_utils.json_to_proto(
        result_properties[input_config_key].string_value,
        updated_input_config)
    self.assertProtoEquals(expected_input_config, updated_input_config)
    self.assertRegex(
        result_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
        fingerprint_regex)

    # Check the output artifact returned by the driver.
    output_artifacts = result.output_artifacts[examples_key].artifacts
    self.assertLen(output_artifacts, 1)
    output_example = output_artifacts[0]
    self.assertEqual(output_example.uri, example.uri)
    custom_properties = output_example.custom_properties
    self.assertEqual(
        custom_properties[utils.SPAN_PROPERTY_NAME].string_value, '1')
    self.assertRegex(
        custom_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
        fingerprint_regex)