예제 #1
0
def compare_anomalies(output_uri: Text, expected_uri: Text) -> bool:
    """Compares anomalies files in output uri and recorded uri.

    Every file found by walking `expected_uri` is paired with the file at the
    same relative location under `output_uri`. Both are parsed as binary
    `anomalies_pb2.Anomalies` protos and their `anomaly_info` fields compared.

    Args:
      output_uri: pipeline output artifact uri.
      expected_uri: recorded pipeline output artifact uri.

    Returns:
      True if every recorded anomalies file has a counterpart under
      `output_uri` with equal `anomaly_info`; False as soon as one differs.

    Raises:
      Whatever `io_utils.read_bytes_file` raises when an output file is
      missing or unreadable; this function does not catch it.
    """
    for dir_name, _, leaf_files in fileio.walk(expected_uri):
        # Map the recorded directory onto the output tree by swapping the
        # first occurrence of the expected root for the output root. This is
        # the same for every file in the directory, so compute it once.
        output_dir_name = dir_name.replace(expected_uri, output_uri, 1)
        for leaf_file in leaf_files:
            # `dir_name` comes from walking `expected_uri` and already contains
            # the root, so these are complete paths. They must NOT be joined
            # onto the root URIs again: for relative roots or scheme URIs such
            # as "gs://..." that duplicates the root and names a missing file.
            expected_file_name = os.path.join(dir_name, leaf_file)
            file_name = os.path.join(output_dir_name, leaf_file)

            anomalies = anomalies_pb2.Anomalies()
            anomalies.ParseFromString(io_utils.read_bytes_file(file_name))
            expected_anomalies = anomalies_pb2.Anomalies()
            expected_anomalies.ParseFromString(
                io_utils.read_bytes_file(expected_file_name))
            if expected_anomalies.anomaly_info != anomalies.anomaly_info:
                return False
    return True
예제 #2
0
    def testDo(self):
        """Runs the ExampleValidator executor and checks its per-split output.

        Statistics for 'train', 'eval' and 'test' are validated against the
        recorded schema while 'test' is excluded through exec properties. The
        test asserts that only the train/eval anomalies files are written and
        that neither of them reports any anomaly.
        """
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        # One statistics artifact carries all three splits.
        stats_artifact = standard_artifacts.ExampleStatistics()
        stats_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
        stats_artifact.split_names = artifact_utils.encode_split_names(
            ['train', 'eval', 'test'])

        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        validation_output = standard_artifacts.ExampleAnomalies()
        validation_output.uri = os.path.join(output_data_dir, 'output')

        input_dict = {
            STATISTICS_KEY: [stats_artifact],
            SCHEMA_KEY: [schema_artifact],
        }
        exec_properties = {
            # List needs to be serialized before being passed into Do function.
            EXCLUDE_SPLITS_KEY: json_utils.dumps(['test'])
        }
        output_dict = {
            ANOMALIES_KEY: [validation_output],
        }

        example_validator_executor = executor.Executor()
        example_validator_executor.Do(input_dict, output_dict, exec_properties)

        # The excluded split must not be listed on the output artifact.
        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         validation_output.split_names)

        # Check example_validator outputs: every validated split gets its own
        # SchemaDiff.pb, and no anomalies are expected for this test data.
        for split in ('train', 'eval'):
            anomalies_path = os.path.join(validation_output.uri,
                                          'Split-' + split, 'SchemaDiff.pb')
            self.assertTrue(fileio.exists(anomalies_path))
            anomalies = anomalies_pb2.Anomalies()
            anomalies.ParseFromString(io_utils.read_bytes_file(anomalies_path))
            self.assertEqual(0, len(anomalies.anomaly_info))

        # Assert 'test' split is excluded.
        test_anomalies_path = os.path.join(validation_output.uri, 'Split-test',
                                           'SchemaDiff.pb')
        self.assertFalse(fileio.exists(test_anomalies_path))
예제 #3
0
 def display(self, artifact: types.Artifact):
   """Renders the anomalies of every split of `artifact` in a notebook.

   For each split, an HTML header with the split name is displayed, followed
   by `tfdv.display_anomalies` on the anomalies file in that split's directory.

   Args:
     artifact: artifact whose `split_names` and split directories hold the
       anomalies to display.
   """
   # Imported locally (presumably so IPython is only needed when rendering).
   # `IPython.display` is the public module; importing `display` from
   # `IPython.core.display` is deprecated since IPython 7.14.
   from IPython.display import display  # pylint: disable=g-import-not-at-top
   from IPython.display import HTML  # pylint: disable=g-import-not-at-top
   for split in artifact_utils.decode_split_names(artifact.split_names):
     display(HTML('<div><b>%r split:</b></div><br/>' % split))
     anomalies_path = io_utils.get_only_uri_in_dir(
         artifact_utils.get_split_uri([artifact], split))
     # Artifacts older than the anomalies-format update store a text proto;
     # newer ones store a serialized binary Anomalies proto.
     if artifact_utils.is_artifact_version_older_than(
         artifact, artifact_utils._ARTIFACT_VERSION_FOR_ANOMALIES_UPDATE):  # pylint: disable=protected-access
       anomalies = tfdv.load_anomalies_text(anomalies_path)
     else:
       anomalies = anomalies_pb2.Anomalies()
       anomalies_bytes = io_utils.read_bytes_file(anomalies_path)
       anomalies.ParseFromString(anomalies_bytes)
     tfdv.display_anomalies(anomalies)
예제 #4
0
 def testReadWriteBytes(self):
     """Bytes written with write_bytes_file are read back unchanged."""
     payload = b'testing read/write'
     target = os.path.join(self._base_dir, 'test_file')
     io_utils.write_bytes_file(target, payload)
     self.assertEqual(payload, io_utils.read_bytes_file(target))