def compare_anomalies(output_uri: Text, expected_uri: Text) -> bool:
  """Compares anomalies files in output uri and recorded uri.

  Walks every anomalies file recorded under `expected_uri`, locates the
  corresponding file under `output_uri`, and compares their parsed
  `anomaly_info` maps.

  Args:
    output_uri: pipeline output artifact uri.
    expected_uri: recorded pipeline output artifact uri.

  Returns:
    boolean whether anomalies are same.
  """
  for dir_name, _, leaf_files in fileio.walk(expected_uri):
    for leaf_file in leaf_files:
      # `dir_name` is already rooted at `expected_uri`; swap the prefix
      # (first occurrence only) to get the matching output location.
      expected_file_name = os.path.join(dir_name, leaf_file)
      file_name = os.path.join(
          dir_name.replace(expected_uri, output_uri, 1), leaf_file)
      # Both paths are fully rooted, so read them directly. Re-joining them
      # with their uri (as previously done) is a no-op for absolute local
      # paths but duplicates the prefix for relative or scheme-based uris.
      anomalies = anomalies_pb2.Anomalies()
      anomalies.ParseFromString(io_utils.read_bytes_file(file_name))
      expected_anomalies = anomalies_pb2.Anomalies()
      expected_anomalies.ParseFromString(
          io_utils.read_bytes_file(expected_file_name))
      if expected_anomalies.anomaly_info != anomalies.anomaly_info:
        return False
  return True
def testDo(self):
  """Runs the ExampleValidator executor on recorded statistics and schema.

  Verifies that anomalies outputs are produced for the 'train' and 'eval'
  splits, that both contain no anomaly_info, and that the excluded 'test'
  split produces no output.
  """
  testdata_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')

  # Input statistics cover three splits; 'test' is excluded further below.
  stats = standard_artifacts.ExampleStatistics()
  stats.uri = os.path.join(testdata_dir, 'statistics_gen')
  stats.split_names = artifact_utils.encode_split_names(
      ['train', 'eval', 'test'])

  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(testdata_dir, 'schema_gen')

  out_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  anomalies_artifact = standard_artifacts.ExampleAnomalies()
  anomalies_artifact.uri = os.path.join(out_dir, 'output')

  input_dict = {
      STATISTICS_KEY: [stats],
      SCHEMA_KEY: [schema],
  }
  output_dict = {
      ANOMALIES_KEY: [anomalies_artifact],
  }
  exec_properties = {
      # List needs to be serialized before being passed into Do function.
      EXCLUDE_SPLITS_KEY: json_utils.dumps(['test'])
  }

  executor.Executor().Do(input_dict, output_dict, exec_properties)

  self.assertEqual(
      artifact_utils.encode_split_names(['train', 'eval']),
      anomalies_artifact.split_names)

  # Check example_validator outputs for each included split: the anomalies
  # proto must exist, parse cleanly, and contain no anomaly_info entries.
  for split in ('train', 'eval'):
    diff_path = os.path.join(anomalies_artifact.uri, 'Split-%s' % split,
                             'SchemaDiff.pb')
    self.assertTrue(fileio.exists(diff_path))
    anomalies = anomalies_pb2.Anomalies()
    anomalies.ParseFromString(io_utils.read_bytes_file(diff_path))
    self.assertEqual(0, len(anomalies.anomaly_info))

  # Assert 'test' split is excluded.
  self.assertFalse(
      fileio.exists(
          os.path.join(anomalies_artifact.uri, 'Split-test',
                       'SchemaDiff.pb')))
def display(self, artifact: types.Artifact):
  """Renders the anomalies of each split of `artifact` in a notebook."""
  # Imported lazily so IPython is only required when actually displaying.
  from IPython.core.display import display  # pylint: disable=g-import-not-at-top
  from IPython.core.display import HTML  # pylint: disable=g-import-not-at-top
  for split_name in artifact_utils.decode_split_names(artifact.split_names):
    display(HTML('<div><b>%r split:</b></div><br/>' % split_name))
    split_uri = artifact_utils.get_split_uri([artifact], split_name)
    anomalies_path = io_utils.get_only_uri_in_dir(split_uri)
    is_legacy_format = artifact_utils.is_artifact_version_older_than(
        artifact, artifact_utils._ARTIFACT_VERSION_FOR_ANOMALIES_UPDATE)  # pylint: disable=protected-access
    if is_legacy_format:
      # Older artifact versions stored anomalies as a text proto.
      anomalies = tfdv.load_anomalies_text(anomalies_path)
    else:
      anomalies = anomalies_pb2.Anomalies()
      anomalies.ParseFromString(io_utils.read_bytes_file(anomalies_path))
    tfdv.display_anomalies(anomalies)
def testReadWriteBytes(self):
  """Round-trips raw bytes through the io_utils write/read helpers."""
  payload = b'testing read/write'
  target_path = os.path.join(self._base_dir, 'test_file')
  io_utils.write_bytes_file(target_path, payload)
  self.assertEqual(payload, io_utils.read_bytes_file(target_path))