def setUp(self):
  """Prepares the input directory, the mocked driver and shared test inputs."""
  super(DriverTest, self).setUp()

  # Keep each test's files under a directory named after the test method.
  output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                               self.get_temp_dir())
  method_dir = os.path.join(output_root, self._testMethodName)
  self._input_base_path = os.path.join(method_dir, 'input_base')
  tf.io.gfile.makedirs(self._input_base_path)

  # The driver under test talks to a mocked metadata handler.
  self._mock_metadata = tf.test.mock.Mock()
  self._example_gen_driver = driver.Driver(self._mock_metadata)

  # Input channels: a single external artifact pointing at the input dir.
  external_input = standard_artifacts.ExternalArtifact()
  external_input.uri = self._input_base_path
  self._input_channels = {
      'input_base': channel_utils.as_channel([external_input])
  }

  # Exec properties: an input config with two span-templated splits.
  split_specs = (
      ('s1', 'span{SPAN}/split1/*'),
      ('s2', 'span{SPAN}/split2/*'),
  )
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name=split_name, pattern=split_pattern)
      for split_name, split_pattern in split_specs
  ])
  self._exec_properties = {
      'input_config': json_format.MessageToJson(input_config),
  }
def setUp(self):
  """Sets up the per-test directory and a driver backed by mocked metadata."""
  super(DriverTest, self).setUp()

  # Test files go under a directory named after the running test method.
  output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                               self.get_temp_dir())
  self._test_dir = os.path.join(output_root, self._testMethodName)

  # Build the driver on top of a mocked metadata handler.
  metadata_mock = tf.compat.v1.test.mock.Mock()
  self._mock_metadata = metadata_mock
  self._example_gen_driver = driver.Driver(metadata_mock)
def test_prepare_input_for_processing(self):
  """Checks input preparation for a cache miss followed by a cache hit."""
  output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                               self.get_temp_dir())
  output_data_dir = os.path.join(output_root, self._testMethodName)
  self._logger_config = logging_utils.LoggerConfig(
      log_root=os.path.join(output_data_dir, 'log_dir'))

  # Driver under test, wired to mocked metadata.
  mock_metadata = tf.test.mock.Mock()
  example_gen_driver = driver.Driver(self._logger_config, mock_metadata)

  def make_artifact(artifact_id):
    """Returns an artifact whose uri is 'path-1' only for odd ids."""
    proto = metadata_store_pb2.Artifact()
    proto.id = artifact_id
    proto.uri = 'path-{}'.format(artifact_id % 2)
    return proto

  # Known artifacts, in the order ids 4, 3, 2, 1.
  artifacts = [make_artifact(artifact_id) for artifact_id in (4, 3, 2, 1)]

  # Input dict with a single external path at 'path-1'.
  input_base = types.TfxType(type_name='ExternalPath')
  input_base.uri = 'path-1'
  input_dict = {'input-base': [input_base]}

  def prepare_and_check(expected_id):
    """Prepares a copy of the input dict and verifies the resolved artifact."""
    updated_input_dict = example_gen_driver._prepare_input_for_processing(
        copy.deepcopy(input_dict))
    self.assertEqual(1, len(updated_input_dict))
    self.assertEqual(1, len(updated_input_dict['input-base']))
    resolved = updated_input_dict['input-base'][0]
    self.assertEqual(expected_id, resolved.id)
    self.assertEqual('path-1', resolved.uri)

  # Cache not hit: nothing is known yet, the published artifact (id 1) is used.
  mock_metadata.get_all_artifacts.return_value = []
  mock_metadata.publish_artifacts.return_value = [artifacts[3]]
  prepare_and_check(1)

  # Cache hit: all artifacts are known and nothing is published; id 3 is
  # expected.
  mock_metadata.get_all_artifacts.return_value = artifacts
  mock_metadata.publish_artifacts.return_value = []
  prepare_and_check(3)
def test_prepare_input_for_processing(self):
  """Verifies the prepared input artifact on a cache miss and a cache hit."""
  # Driver under test, wired to mocked metadata.
  mock_metadata = tf.test.mock.Mock()
  example_gen_driver = driver.Driver(mock_metadata)

  # Known artifacts with ids 4, 3, 2, 1; only odd ids get the uri 'path-1'
  # that the input below uses.
  artifacts = []
  for artifact_id in (4, 3, 2, 1):
    known = metadata_store_pb2.Artifact()
    known.id = artifact_id
    known.uri = 'path-{}'.format(artifact_id % 2)
    artifacts.append(known)

  # Input dict with a single external path at 'path-1'.
  input_base = types.TfxArtifact(type_name='ExternalPath')
  input_base.uri = 'path-1'
  input_dict = {'input_base': [input_base]}

  # Each scenario: (get_all_artifacts result, publish_artifacts result,
  # expected id of the prepared input artifact).
  scenarios = (
      # Cache not hit.
      ([], [artifacts[3]], 1),
      # Cache hit.
      (artifacts, [], 3),
  )
  for known_artifacts, published_artifacts, expected_id in scenarios:
    mock_metadata.get_all_artifacts.return_value = known_artifacts
    mock_metadata.publish_artifacts.return_value = published_artifacts

    updated_input_dict = example_gen_driver._prepare_input_for_processing(
        copy.deepcopy(input_dict))

    self.assertEqual(1, len(updated_input_dict))
    self.assertEqual(1, len(updated_input_dict['input_base']))
    resolved = updated_input_dict['input_base'][0]
    self.assertEqual(expected_id, resolved.id)
    self.assertEqual('path-1', resolved.uri)
def testRun(self):
  """Runs the IR driver on one span/version and checks what it returns.

  Two split files are written under span01/version01. The test asserts that
  the returned exec properties carry span 1, version 1, an input config whose
  {SPAN} and {VERSION} placeholders are filled in, and a per-split
  fingerprint; and that the single output Examples artifact keeps its uri and
  carries the same span, version and fingerprint as custom properties.
  """
  # Both the exec properties and the output artifact must carry a fingerprint
  # of this shape: one 9-byte file per split ('testing11' / 'testing12').
  fingerprint_regex = (
      r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\n'
      r'split:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*')

  # Create input dir.
  self._input_base_path = os.path.join(self._test_dir, 'input_base')
  tf.io.gfile.makedirs(self._input_base_path)

  # Create PipelineInfo and PipelineNode.
  pipeline_info = pipeline_pb2.PipelineInfo()
  pipeline_node = pipeline_pb2.PipelineNode()

  # Fake previous outputs: one data file per split under span01/version01.
  span1_v1_split1 = os.path.join(self._input_base_path, 'span01', 'version01',
                                 'split1', 'data')
  io_utils.write_string_file(span1_v1_split1, 'testing11')
  span1_v1_split2 = os.path.join(self._input_base_path, 'span01', 'version01',
                                 'split2', 'data')
  io_utils.write_string_file(span1_v1_split2, 'testing12')

  ir_driver = driver.Driver(self._mock_metadata, pipeline_info, pipeline_node)

  # Prepare output_dic.
  example = standard_artifacts.Examples()
  example.uri = 'my_uri'  # Will verify that this uri is not changed.
  output_dic = {utils.EXAMPLES_KEY: [example]}

  # Prepare exec_properties.
  exec_properties = {
      utils.INPUT_BASE_KEY:
          self._input_base_path,
      utils.INPUT_CONFIG_KEY:
          json_format.MessageToJson(
              example_gen_pb2.Input(splits=[
                  example_gen_pb2.Input.Split(
                      name='s1',
                      pattern='span{SPAN}/version{VERSION}/split1/*'),
                  example_gen_pb2.Input.Split(
                      name='s2',
                      pattern='span{SPAN}/version{VERSION}/split2/*')
              ]),
              preserving_proto_field_name=True),
  }

  result = ir_driver.run(None, output_dic, exec_properties)

  # Assert exec_properties' values.
  exec_properties = result.exec_properties
  self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
  self.assertEqual(exec_properties[utils.VERSION_PROPERTY_NAME].int_value, 1)
  updated_input_config = example_gen_pb2.Input()
  json_format.Parse(exec_properties[utils.INPUT_CONFIG_KEY].string_value,
                    updated_input_config)
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "span01/version01/split1/*"
      }
      splits {
        name: "s2"
        pattern: "span01/version01/split2/*"
      }""", updated_input_config)
  self.assertRegex(
      exec_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
      fingerprint_regex)

  # Assert output_artifacts' values.
  self.assertLen(result.output_artifacts[utils.EXAMPLES_KEY].artifacts, 1)
  output_example = result.output_artifacts[utils.EXAMPLES_KEY].artifacts[0]
  self.assertEqual(output_example.uri, example.uri)
  self.assertEqual(
      output_example.custom_properties[utils.SPAN_PROPERTY_NAME].string_value,
      '1')
  self.assertEqual(
      output_example.custom_properties[
          utils.VERSION_PROPERTY_NAME].string_value, '1')
  self.assertRegex(
      output_example.custom_properties[
          utils.FINGERPRINT_PROPERTY_NAME].string_value, fingerprint_regex)
def test_prepare_input_for_processing(self):
  """Checks fingerprint-based input preparation on a cache miss and a hit."""
  # Create input splits: one file per split, each with a fixed mtime.
  output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                               self.get_temp_dir())
  test_dir = os.path.join(output_root, self._testMethodName)
  input_base_path = os.path.join(test_dir, 'input_base')
  split_files = (
      ('split1', 'testing', (0, 1)),
      ('split2', 'testing2', (0, 3)),
  )
  for split_dir, content, times in split_files:
    data_file = os.path.join(input_base_path, split_dir, 'data')
    io_utils.write_string_file(data_file, content)
    os.utime(data_file, times)

  # Driver under test, wired to mocked metadata.
  mock_metadata = tf.test.mock.Mock()
  example_gen_driver = driver.Driver(mock_metadata)

  # Known artifacts with ids 4, 3, 2, 1, all at the input uri. Only odd ids
  # carry the fingerprint the test treats as matching.
  matching_fingerprint = (
      'split:s1,num_files:1,total_bytes:7,xor_checksum:1,sum_checksum:1\n'
      'split:s2,num_files:1,total_bytes:8,xor_checksum:3,sum_checksum:3')
  artifacts = []
  for artifact_id in (4, 3, 2, 1):
    known = metadata_store_pb2.Artifact()
    known.id = artifact_id
    known.uri = input_base_path
    known.custom_properties['input_fingerprint'].string_value = (
        matching_fingerprint if artifact_id % 2 == 1 else 'not_match')
    artifacts.append(known)

  # Create input dict.
  input_base = types.TfxArtifact(type_name='ExternalPath')
  input_base.uri = input_base_path
  input_dict = {'input_base': [input_base]}

  # Create exec properties.
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='s1', pattern='split1/*'),
      example_gen_pb2.Input.Split(name='s2', pattern='split2/*')
  ])
  exec_properties = {
      'input_config': json_format.MessageToJson(input_config),
  }

  def prepare_and_check(expected_id):
    """Prepares a copy of the input dict and verifies the resolved artifact."""
    updated_input_dict = example_gen_driver._prepare_input_for_processing(
        copy.deepcopy(input_dict), exec_properties)
    self.assertEqual(1, len(updated_input_dict))
    self.assertEqual(1, len(updated_input_dict['input_base']))
    resolved = updated_input_dict['input_base'][0]
    self.assertEqual(expected_id, resolved.id)
    self.assertEqual(input_base_path, resolved.uri)

  # Cache not hit: the only known artifact at the uri (id 4) carries
  # 'not_match', so the published artifact (id 1) is expected.
  mock_metadata.get_artifacts_by_uri.return_value = [artifacts[0]]
  mock_metadata.publish_artifacts.return_value = [artifacts[3]]
  prepare_and_check(1)

  # Cache hit: all artifacts are known and nothing is published; id 3 is
  # expected.
  mock_metadata.get_artifacts_by_uri.return_value = artifacts
  mock_metadata.publish_artifacts.return_value = []
  prepare_and_check(3)
def testDriverRunFn(self):
  """Runs the driver on a single span and checks the resolved properties."""
  # Create the input dir.
  self._input_base_path = os.path.join(self._test_dir, 'input_base')
  fileio.makedirs(self._input_base_path)

  # Fake previous outputs: one data file per split under span01.
  split1_file = os.path.join(self._input_base_path, 'span01', 'split1',
                             'data')
  io_utils.write_string_file(split1_file, 'testing11')
  split2_file = os.path.join(self._input_base_path, 'span01', 'split2',
                             'data')
  io_utils.write_string_file(split2_file, 'testing12')

  ir_driver = driver.Driver(self._mock_metadata)

  # Output dict: one Examples artifact whose uri must come back unchanged.
  example = standard_artifacts.Examples()
  example.uri = 'my_uri'
  output_dic = {standard_component_specs.EXAMPLES_KEY: [example]}

  # Exec properties: input base plus a config with span-templated splits.
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN}/split1/*'),
      example_gen_pb2.Input.Split(name='s2', pattern='span{SPAN}/split2/*')
  ])
  exec_properties = {
      standard_component_specs.INPUT_BASE_KEY:
          self._input_base_path,
      standard_component_specs.INPUT_CONFIG_KEY:
          proto_utils.proto_to_json(input_config),
  }

  execution_info = portable_data_types.ExecutionInfo(
      output_dict=output_dic, exec_properties=exec_properties)
  result = ir_driver.run(execution_info)

  # Expected fingerprint shape: one 9-byte file in each split.
  fingerprint_regex = (
      r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\n'
      r'split:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*')

  # Assert the returned exec properties.
  returned_properties = result.exec_properties
  self.assertEqual(returned_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
  updated_input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(
      returned_properties[
          standard_component_specs.INPUT_CONFIG_KEY].string_value,
      updated_input_config)
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "span01/split1/*"
      }
      splits {
        name: "s2"
        pattern: "span01/split2/*"
      }""", updated_input_config)
  self.assertRegex(
      returned_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
      fingerprint_regex)

  # Assert the returned output artifacts.
  returned_examples = result.output_artifacts[
      standard_component_specs.EXAMPLES_KEY].artifacts
  self.assertLen(returned_examples, 1)
  output_example = result.output_artifacts[
      standard_component_specs.EXAMPLES_KEY].artifacts[0]
  self.assertEqual(output_example.uri, example.uri)
  self.assertEqual(
      output_example.custom_properties[
          utils.SPAN_PROPERTY_NAME].string_value, '1')
  self.assertRegex(
      output_example.custom_properties[
          utils.FINGERPRINT_PROPERTY_NAME].string_value, fingerprint_regex)