def test_main_with_updates(self):
    """Runs main() with --update and verifies the written text dataset.

    NOTE(review): a later method in this file has the same name; if both
    live in the same test class, this one is shadowed — confirm.
    """
    destination = os.path.join(self.test_subdirectory, 'output.pbtxt')
    run_flags = {
        'input_pattern': self.dataset1_filename,
        'update': True,
        'output': destination,
    }
    with flagsaver.flagsaver(**run_flags):
        process_dataset.main(())
    self.assertTrue(os.path.exists(destination))
    written = message_helpers.load_message(destination, dataset_pb2.Dataset)
    self.assertLen(written.reactions, 1)
    # --update assigns ORD-style reaction IDs.
    self.assertStartsWith(written.reactions[0].reaction_id, 'ord-')
def test_bad_dataset_id(self):
    """A malformed dataset_id must raise ValueError during an update run."""
    bad_dataset = dataset_pb2.Dataset(
        reactions=[reaction_pb2.Reaction()],
        dataset_id='not-a-real-dataset-id')
    path = os.path.join(self.test_subdirectory, 'test.pbtxt')
    message_helpers.write_message(bad_dataset, path)
    run_flags = {
        'root': self.test_subdirectory,
        'input_pattern': path,
        'validate': False,
        'update': True,
    }
    with flagsaver.flagsaver(**run_flags):
        with self.assertRaisesRegex(ValueError, 'malformed dataset ID'):
            process_dataset.main(())
def test_main_with_updates(self):
    """Runs main() with --update and verifies the written binary dataset."""
    destination = os.path.join(self.test_subdirectory, 'output.pb')
    with flagsaver.flagsaver(input_pattern=self.dataset1_filename,
                             update=True,
                             output=destination):
        process_dataset.main(())
    self.assertTrue(os.path.exists(destination))
    # Binary output: parse the serialized proto directly.
    with open(destination, 'rb') as handle:
        written = dataset_pb2.Dataset.FromString(handle.read())
    self.assertLen(written.reactions, 1)
    # --update stamps provenance with an ORD-style record ID.
    self.assertStartsWith(written.reactions[0].provenance.record_id, 'ord-')
def test_bad_dataset_id(self):
    """A malformed dataset_id must raise ValueError during an update run."""
    bad_dataset = dataset_pb2.Dataset(
        reactions=[reaction_pb2.Reaction()],
        dataset_id='not-a-real-dataset-id')
    path = os.path.join(self.test_subdirectory, 'test.pb')
    # Binary input: serialize the proto by hand.
    with open(path, 'wb') as handle:
        handle.write(bad_dataset.SerializeToString())
    run_flags = {
        'root': self.test_subdirectory,
        'input_pattern': path,
        'validate': False,
        'update': True,
    }
    with flagsaver.flagsaver(**run_flags):
        with self.assertRaisesRegex(ValueError, 'malformed dataset ID'):
            process_dataset.main(())
def test_main_with_validation_errors(self):
    """Validation failures raise and are written to a sibling .error file."""
    with flagsaver.flagsaver(input_pattern=self.dataset2_filename,
                             write_errors=True):
        with self.assertRaisesRegex(ValueError,
                                    'validation encountered errors'):
            process_dataset.main(())
    error_filename = f'{self.dataset2_filename}.error'
    self.assertTrue(os.path.exists(error_filename))
    expected_lines = [
        'Reactions should have at least 1 reaction input\n',
        'Reactions should have at least 1 reaction outcome\n',
    ]
    with open(error_filename) as handle:
        self.assertEqual(handle.readlines(), expected_lines)
def test_preserve_existing_dataset_id(self):
    """A valid existing dataset_id determines the final output location."""
    existing = dataset_pb2.Dataset(
        reactions=[reaction_pb2.Reaction()],
        dataset_id='64b14868c5cd46dd8e75560fd3589a6b')
    path = os.path.join(self.test_subdirectory, 'test.pb')
    with open(path, 'wb') as handle:
        handle.write(existing.SerializeToString())
    # Datasets are filed under data/<first two hex chars of the ID>/.
    expected_filename = os.path.join(
        self.test_subdirectory, 'data', '64',
        '64b14868c5cd46dd8e75560fd3589a6b.pb')
    self.assertFalse(os.path.exists(expected_filename))
    with flagsaver.flagsaver(root=self.test_subdirectory,
                             input_pattern=path,
                             validate=False,
                             update=True):
        process_dataset.main(())
    self.assertTrue(os.path.exists(expected_filename))
def _run_main(self, **kwargs):
    """Stages and commits *.pbtxt files, then runs main() on the staged diff.

    Args:
        **kwargs: Flag overrides merged over the default run flags.

    Returns:
        List of *.pbtxt filenames remaining under the test directory.
    """
    subprocess.run(['git', 'add', '*.pbtxt', 'data/*/*.pbtxt'], check=True)
    staged = subprocess.run(['git', 'diff', '--name-status', '--staged'],
                            check=True,
                            capture_output=True)
    # main() consumes the name-status diff as its --input_file.
    with open('changed.txt', 'wb') as handle:
        handle.write(staged.stdout)
    subprocess.run(['git', 'commit', '-m', 'Submission'], check=True)
    run_flags = {'input_file': 'changed.txt', 'update': True, 'cleanup': True}
    run_flags.update(kwargs)
    with flagsaver.flagsaver(**run_flags):
        process_dataset.main(())
    return glob.glob(os.path.join(self.test_subdirectory, '**/*.pbtxt'),
                     recursive=True)
def test_main_with_too_many_flags(self):
    """Passing both --input_pattern and --input_file is rejected."""
    with flagsaver.flagsaver(input_pattern=self.dataset1_filename,
                             input_file=self.dataset2_filename):
        with self.assertRaisesRegex(ValueError, 'not both'):
            process_dataset.main(())
def test_main_with_input_file(self):
    """main() accepts a text file listing input dataset filenames.

    Writes a one-line file naming an existing dataset and runs main()
    with --input_file pointing at it.
    """
    input_file = os.path.join(self.test_subdirectory, 'input_file.txt')
    # Explicit encoding: without it, open() uses the platform default
    # (PEP 597), which can differ between CI and developer machines.
    with open(input_file, 'w', encoding='utf-8') as f:
        f.write(f'{self.dataset1_filename}\n')
    with flagsaver.flagsaver(input_file=input_file):
        process_dataset.main(())
def test_main_with_input_pattern(self):
    """Smoke test: main() runs cleanly with only --input_pattern set."""
    with flagsaver.flagsaver(input_pattern=self.dataset1_filename):
        process_dataset.main(())