def test_data_prep_beam_params(self, tfds, input_format):
    """Check the data-prep parameters that are derived from flags.

    Args:
      tfds: If truthy, the `tfds_dataset` flag is set; otherwise the train,
        validation and test input glob flags are set to the same glob.
      input_format: Not read by this test body.
    """
    fv = flags.FLAGS
    if tfds:
        fv.tfds_dataset = 'savee'
    else:
        # All three splits read from the same test glob.
        shared_glob = os.path.join(
            absltest.get_default_test_srcdir(), TESTDIR, 'test.tfrecord*')
        fv.train_input_glob = shared_glob
        fv.validation_input_glob = shared_glob
        fv.test_input_glob = shared_glob
    fv.skip_existing_error = False
    fv.output_filename = os.path.join(
        absltest.get_default_test_tmpdir(), f'data_prep_test_{tfds}')
    fv.embedding_modules = ['mod1', 'mod2']
    fv.embedding_names = ['emb1', 'emb2']
    fv.module_output_keys = ['k1', 'k2']

    params_from_flags = (
        data_prep_and_eval_beam_main._get_data_prep_params_from_flags())
    (prep_params, input_filenames_list, output_filenames,
     run_data_prep) = params_from_flags

    self.assertTrue(run_data_prep)
    self.assertLen(input_filenames_list, 3)
    self.assertLen(output_filenames, 3)
    # Output names are expected in train / validation / test order.
    splits = ('train', 'validation', 'test')
    for out_name, split in zip(output_filenames, splits):
        expected_ending = f'{flags.FLAGS.output_filename}.{split}'
        self.assertTrue(out_name.endswith(expected_ending), out_name)
    self.assertIsInstance(prep_params, dict)
def test_read_flags_and_create_pipeline(self, data_prep_behavior):
    """Test that the read-from-flags and pipeline creation are synced.

    Args:
      data_prep_behavior: Value stored in the `data_prep_behavior` flag and
        used to name the output file.
    """
    flag_overrides = {
        'input_glob': os.path.join(
            absltest.get_default_test_srcdir(), TEST_DIR, '*'),
        'output_filename': os.path.join(
            absltest.get_default_test_tmpdir(),
            f'{data_prep_behavior}.tfrecord'),
        'data_prep_behavior': data_prep_behavior,
        'embedding_names': ['em1', 'em2'],
        'embedding_modules': ['dummy_mod_loc'],
        'module_output_keys': ['k1', 'k2'],
        'sample_rate': 5,
        'audio_key': 'audio_key',
        'label_key': 'label_key',
    }
    for flag_name, flag_value in flag_overrides.items():
        setattr(FLAGS, flag_name, flag_value)

    params = audio_to_embeddings_beam_utils.get_beam_params_from_flags()
    input_filenames_list, output_filenames, beam_params = params

    # Use the defaults, unless we are using TFLite models.
    for absent_key in ('module_call_fn', 'setup_fn'):
        self.assertNotIn(absent_key, beam_params)

    # Check that the arguments run through.
    audio_to_embeddings_beam_utils.data_prep_pipeline(
        root=beam.Pipeline(),
        input_filenames_or_glob=input_filenames_list[0],
        output_filename=output_filenames[0],
        data_prep_behavior=FLAGS.data_prep_behavior,
        beam_params=beam_params,
        suffix='s')
def setUp(self):
    """Parse the version config text proto into `self._version_config`."""
    super(BundleToSeqexTest, self).setUp()
    config_path = os.path.join(
        absltest.get_default_test_srcdir(), _VERSION_CONFIG_PATH)
    with open(config_path) as config_file:
        config_text = config_file.read()
    self._version_config = text_format.Parse(
        config_text, version_config_pb2.VersionConfig())
def test_overwrite_b_factors(self):
    """Check that `overwrite_b_factors` only alters the B-factor field.

    ATOM lines of the input and output PDB strings are compared with
    characters 60-65 excluded, then the parsed B-factors are compared with
    the requested values wherever the atom mask is positive.
    """
    pdb_path = os.path.join(
        absltest.get_default_test_srcdir(),
        'alphafold/relax/testdata/multiple_disulfides_target.pdb')
    with open(pdb_path) as pdb_file:
        test_pdb = pdb_file.read()

    n_residues = 191
    per_residue_values = np.arange(0, n_residues)
    bfactors = np.stack([per_residue_values] * 37, axis=-1)

    output_pdb = utils.overwrite_b_factors(test_pdb, bfactors)

    def atom_lines(pdb_text):
        return [row for row in pdb_text.split('\n') if row.startswith('ATOM')]

    # Check that the atom lines are unchanged apart from the B-factors.
    line_pairs = zip(atom_lines(test_pdb), atom_lines(output_pdb))
    for before, after in line_pairs:
        self.assertEqual(before[:60].strip(), after[:60].strip())
        self.assertEqual(before[66:].strip(), after[66:].strip())

    # Check B-factors are correctly set for all atoms present.
    as_protein = protein.from_pdb_string(output_pdb)
    actual = np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0)
    expected = np.where(as_protein.atom_mask > 0, bfactors, 0)
    np.testing.assert_almost_equal(actual, expected)
def testConfig(self, config_path):
    """Parse the gin config at `config_path` and run the train/eval loop.

    Args:
      config_path: Config file path, joined onto the default test source
        directory before parsing.
    """
    overrides = train_eval.get_gin_override_params(self.create_tempdir())
    full_config_path = os.path.join(
        absltest.get_default_test_srcdir(), config_path)
    gin.parse_config_files_and_bindings([full_config_path], overrides)
    train_eval.train_eval_loop()
def setUp(self):
    """Load the encounter fixture and wrap it in a one-entry bundle.

    Sets `self._enc` from `encounter_1.pbtxt` in the test data directory and
    `self._bundle` to a bundle whose single entry holds a copy of it.
    """
    # Run the base-class setup first so any state it prepares exists before
    # the fixtures below are built.
    super().setUp()
    self._test_data_dir = os.path.join(absltest.get_default_test_srcdir(),
                                       _TESTDATA_PATH)
    self._enc = resources_pb2.Encounter()
    with open(os.path.join(self._test_data_dir, 'encounter_1.pbtxt')) as f:
        text_format.Parse(f.read(), self._enc)
    self._bundle = resources_pb2.Bundle()
    self._bundle.entry.add().resource.encounter.CopyFrom(self._enc)
def test_file_equivalence_after_rewrite(self):
    """Check that each module's static file matches its generated file.

    For every module in `FLAGS.modules_to_check`, the static file under
    `gen/` and the `<module>_gen.py` file are read and compared. When they
    differ and `FLAGS.update` is set, the generated content is written to a
    path derived from the static file's path; otherwise the module is
    reported as needing an update.

    Raises:
      ValueError: If neither file can be read, or if an update is requested
        but the generated file cannot be read.
      AssertionError: If any module differs and `FLAGS.update` is not set.
    """

    def read_or_none(path):
        """Return the text at `path`, or None if it cannot be read."""
        try:
            with open(path, 'r') as f:
                return f.read()
        except IOError:
            return None

    srcdir = absltest.get_default_test_srcdir()
    wrong_modules = []
    for module in FLAGS.modules_to_check:
        static_file = os.path.join(
            srcdir,
            '{}/internal/backend/numpy/gen/{}.py'.format(
                TFP_PYTHON_DIR, module))
        gen_file = os.path.join(
            srcdir,
            '{}/internal/backend/numpy/{}_gen.py'.format(
                TFP_PYTHON_DIR, module))
        static_content = read_or_none(static_file)
        gen_content = read_or_none(gen_file)

        if gen_content is None and static_content is None:
            raise ValueError(
                'Could not load content for {}'.format(static_file))
        if gen_content == static_content:
            continue
        if not FLAGS.update:
            wrong_modules.append(module)
            continue

        # Fail before opening the target: opening with 'w' truncates it, so a
        # missing generated file would otherwise wipe the static file and
        # then crash on `f.write(None)`.
        if gen_content is None:
            raise ValueError(
                'Could not load generated content for {}'.format(gen_file))
        # Map the runfiles path back to a workspace-relative path by dropping
        # everything up to 'runfiles/' and the first remaining component.
        to_update = static_file.split('runfiles/')[-1]
        to_update = '/'.join(to_update.split('/')[1:])
        with open(to_update, 'w') as f:
            f.write(gen_content)
        logging.info('Updating file %s', to_update)

    if wrong_modules:
        msg = '\n'.join([
            'Modules `{}` require updates. To update them, run'.format(
                repr(wrong_modules)),
            'bazel build -c opt :rewrite_equivalence_test',
            'bazel-py3/bin/.../rewrite_equivalence_test --update '
            '--modules_to_check={}'.format(','.join(wrong_modules)),
            'It may be necessary to adapt the generator programs.'
        ])
        raise AssertionError(msg)
def test_from_pdb_str(self, pdb_file, chain_id, num_res):
    """Parse a PDB test file and check shapes and residue-type bounds.

    Args:
      pdb_file: File name inside the test data directory.
      chain_id: Chain identifier passed to `protein.from_pdb_string`.
      num_res: Residue count passed to `self._check_shapes`.
    """
    full_path = os.path.join(
        absltest.get_default_test_srcdir(), TEST_DATA_DIR, pdb_file)
    with open(full_path) as handle:
        pdb_text = handle.read()

    prot = protein.from_pdb_string(pdb_text, chain_id)
    self._check_shapes(prot, num_res)

    lowest_type = prot.aatype.min()
    self.assertGreaterEqual(lowest_type, 0)
    # Allow equal since unknown restypes have index equal to restype_num.
    highest_type = prot.aatype.max()
    self.assertLessEqual(highest_type, residue_constants.restype_num)
def setUp(self):
    """Build the input TFRecord from the ASCII examples and make a log dir."""
    super(TestInputFn, self).setUp()
    directory = os.path.join(absltest.get_default_test_srcdir(), TESTDATA_DIR)
    seqex_list = []
    for filename in ('example1.ascii', 'example2.ascii'):
        seqex_list.append(self.read_seqex_ascii(filename, directory))
    record_dir = tempfile.mkdtemp()
    self.input_data_dir = self.create_input_tfrecord(
        seqex_list, record_dir, 'input')
    self.log_dir = tempfile.mkdtemp()
def setUp(self):
    """Store the test data directory and the config dict used by the tests."""
    super().setUp()
    srcdir = absltest.get_default_test_srcdir()
    self.test_dir = os.path.join(srcdir, 'alphafold/relax/testdata/')
    self.test_config = dict(
        max_iterations=1,
        tolerance=2.39,
        stiffness=10.0,
        exclude_residues=[],
        max_outer_iterations=1,
    )
def setUp(self):
    """Build the input TFRecord from the ASCII examples and make a log dir."""
    super(ModelTest, self).setUp()
    filedir = os.path.join(absltest.get_default_test_srcdir(), TESTDATA_DIR)
    seqex_list = []
    for filename in ('example1.ascii', 'example2.ascii'):
        seqex_list.append(test_utils.read_seqex_ascii(filename, filedir))
    record_dir = tempfile.mkdtemp()
    self.input_data_dir = test_utils.create_input_tfrecord(
        seqex_list, record_dir, 'input')
    self.log_dir = tempfile.mkdtemp()
def setUp(self):
    """Load the encounter and expected-label fixtures and build a patient.

    Sets `self._enc` from `encounter_1.pbtxt`, `self._patient` with id
    'Patient/1', and `self._expected_label` from `label_1.pbtxt`.
    """
    # Run the base-class setup first so any state it prepares exists before
    # the fixtures below are built.
    super().setUp()
    self._test_data_dir = os.path.join(absltest.get_default_test_srcdir(),
                                       _TESTDATA_PATH)
    self._enc = resources_pb2.Encounter()
    with open(os.path.join(self._test_data_dir, 'encounter_1.pbtxt')) as f:
        text_format.Parse(f.read(), self._enc)
    self._patient = resources_pb2.Patient()
    self._patient.id.value = 'Patient/1'
    self._expected_label = google_extensions_pb2.EventLabel()
    with open(os.path.join(self._test_data_dir, 'label_1.pbtxt')) as f:
        text_format.Parse(f.read(), self._expected_label)
def test_saved_model_from_disk(self):
    """Load the SavedModel from test data and check forward outputs.

    The outputs of `make_bij(b).forward(2.)` are checked once after
    initializing `scale` and again after adding `[1., 2.]` to it.
    """
    model_dir = os.path.join(
        absltest.get_default_test_srcdir(),
        TFP_PYTHON_DIR,
        'internal/testdata/auto_composite_tensor')
    m = tf.saved_model.load(model_dir)
    self.evaluate(m.scale.initializer)

    b = tfb.Scale([5., 9.], validate_args=True)
    initial_output = self.evaluate(m.make_bij(b).forward(2.))
    self.assertAllClose(initial_output, [10., 20.])

    self.evaluate(m.scale.assign(m.scale + [1., 2.]))
    updated_output = self.evaluate(m.make_bij(b).forward(2.))
    self.assertAllClose(updated_output, [12., 24.])
def test_full_flow(self, _):
    """Set the flags and run `model_conversion_beam_main.main` once."""
    fv = flags.FLAGS
    fv.xids = ['12321']
    fv.base_experiment_dir = os.path.join(
        absltest.get_default_test_srcdir(), TESTDIR)
    fv.output_dir = os.path.join(
        absltest.get_default_test_tmpdir(), 'dummy_out')

    # Frontend args.
    fv.frame_hop = 5
    fv.frame_width = 5
    fv.num_mel_bins = 80
    fv.n_required = 8000

    model_conversion_beam_main.main(None)
def setUp(self):
    """Write train and validation TFRecords from the ASCII examples.

    Also creates a temporary log directory in `self.log_dir`.
    """
    super(ExperimentTest, self).setUp()
    testdata_dir = os.path.join(
        absltest.get_default_test_srcdir(), TESTDATA_DIR)
    self.seqex_list = []
    for filename in ('example1.ascii', 'example2.ascii'):
        self.seqex_list.append(
            test_utils.read_seqex_ascii(filename, testdata_dir))

    self.input_data_dir = tempfile.mkdtemp()
    # The same examples back both splits.
    for split in ('train', 'validation'):
        test_utils.create_input_tfrecord(
            self.seqex_list, self.input_data_dir, split)
    self.log_dir = tempfile.mkdtemp()
def testParseConfigFromPythonSystemPath(self):
    """Load config only found in python system path."""
    test_srcdir = absltest.get_default_test_srcdir()
    relative_testdata_path = 'gin/testdata'
    absolute_testdata_path = os.path.join(test_srcdir, relative_testdata_path,
                                          'fake_package')
    config_file = 'fake_gin_package/config/foo.gin'

    sys.path.append(absolute_testdata_path)
    try:
        result = config.parse_config_file(config_file,
                                          print_includes_and_imports=True,
                                          skip_unknown=True)
    finally:
        # Always undo the sys.path change, even if parsing raises, so the
        # extra entry cannot leak into other tests in the same process.
        sys.path.remove(absolute_testdata_path)

    self.assertEqual(result.includes[0].filename,
                     'fake_gin_package/parent.gin')
def testTask(self, task):
    """Compare per-metric language averages for `task` with expected values.

    Args:
      task: Key into `GROUP2TASK`, `TASK2LANGS` and `TASK2AVG_SCORES`.
    """
    data_dir = os.path.join(absltest.get_default_test_srcdir(), DATA_DIR)
    if task in GROUP2TASK['qa']:
        suffix = 'json'
    else:
        suffix = 'tsv'

    # Maps metric name -> {language: score}.
    score = collections.defaultdict(dict)
    for lg in TASK2LANGS[task]:
        file_name = f'test-{lg}.{suffix}'
        pred_file = os.path.join(data_dir, 'predictions', task, file_name)
        label_file = os.path.join(data_dir, 'labels', task, file_name)
        score_lg = evaluate_one_task(pred_file, label_file, task, language=lg)
        for metric in score_lg:
            score[metric][lg] = score_lg[metric]

    avg_score = {
        f'avg_{m}': sum(per_lang.values()) / len(per_lang)
        for m, per_lang in score.items()
    }
    self.assertEqual(avg_score, TASK2AVG_SCORES[task])
def setUp(self):
    """Load the bundle fixture and compose the expected label from it.

    Sets `self._bundle` from `bundle_1.pbtxt` and `self._expected_label`
    from the bundle's first entry (encounter) and second entry (patient).
    """
    # Run the base-class setup first so any state it prepares exists before
    # the fixtures below are built.
    super().setUp()
    self._test_data_dir = os.path.join(absltest.get_default_test_srcdir(),
                                       _TESTDATA_PATH)
    self._bundle = resources_pb2.Bundle()
    with open(os.path.join(self._test_data_dir, 'bundle_1.pbtxt')) as f:
        text_format.Parse(f.read(), self._bundle)
    enc = self._bundle.entry[0].resource.encounter
    patient = self._bundle.entry[1].resource.patient
    self._expected_label = label.ComposeLabel(
        patient,
        enc,
        label.LOS_RANGE_LABEL,
        'above_14',
        # 24 hours after admission
        datetime.datetime(2009, 2, 14, 23, 31, 30))
def test_ideal_atom_mask(self):
    """Check where the ideal atom mask of chain A of `2rbg.pdb` matches.

    The actual and ideal masks must differ for the residues listed as
    non-ideal and be identical for every other residue.
    """
    pdb_path = os.path.join(
        absltest.get_default_test_srcdir(), TEST_DATA_DIR, '2rbg.pdb')
    with open(pdb_path) as handle:
        pdb_string = handle.read()

    prot = protein.from_pdb_string(pdb_string, chain_id='A')
    ideal_mask = protein.ideal_atom_mask(prot)
    non_ideal_residues = {102, *range(127, 285)}

    rows = zip(prot.residue_index, prot.atom_mask)
    for i, (res, atom_mask) in enumerate(rows):
        if res in non_ideal_residues:
            check = self.assertFalse
        else:
            check = self.assertTrue
        check(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
def test_to_pdb(self):
    """Round-trip chain A of `2rbg.pdb` through `to_pdb` and compare fields."""
    pdb_path = os.path.join(
        absltest.get_default_test_srcdir(), TEST_DATA_DIR, '2rbg.pdb')
    with open(pdb_path) as handle:
        pdb_string = handle.read()

    prot = protein.from_pdb_string(pdb_string, chain_id='A')
    prot_reconstr = protein.from_pdb_string(protein.to_pdb(prot))

    exact = np.testing.assert_array_equal
    approx = np.testing.assert_array_almost_equal
    # (comparison, attribute) pairs, in the order they are checked.
    field_checks = (
        (exact, 'aatype'),
        (approx, 'atom_positions'),
        (approx, 'atom_mask'),
        (exact, 'residue_index'),
        (approx, 'b_factors'),
    )
    for compare, field in field_checks:
        compare(getattr(prot_reconstr, field), getattr(prot, field))
def setUp(self, **kwargs):
    """Point credentials at a fake file and patch encryption and the context.

    Args:
      **kwargs: Forwarded unchanged to the base-class `setUp`.
    """
    super(Bit9TestCase, self).setUp(**kwargs)

    # Set up a fake Bit9ApiAuth entity in Datastore.
    credentials_path = os.path.join(
        absltest.get_default_test_srcdir(),
        'upvote/gae/lib/bit9',
        'fake_credentials.json')
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path
    for method_name in ('_Encrypt', '_Decrypt'):
        self.Patch(
            bit9_utils.bit9.kms_ndb.EncryptedBlobProperty,
            method_name,
            return_value='blah')
    bit9_models.Bit9ApiAuth.SetInstance(api_key='blah')

    self.mock_ctx = mock.Mock(spec=bit9_utils.api.Context)
    self.Patch(bit9_utils.api, 'Context', return_value=self.mock_ctx)
def setUp(self):
    """Build a bundle of one encounter and one patient plus its label.

    The encounter is parsed from `encounter_1.pbtxt`, the patient has id
    'Patient/1', and both are copied into `self._bundle`.
    `self._expected_label` is composed from the same two resources.
    """
    # Run the base-class setup first so any state it prepares exists before
    # the fixtures below are built.
    super().setUp()
    self._test_data_dir = os.path.join(absltest.get_default_test_srcdir(),
                                       _TESTDATA_PATH)
    enc = resources_pb2.Encounter()
    with open(os.path.join(self._test_data_dir, 'encounter_1.pbtxt')) as f:
        text_format.Parse(f.read(), enc)
    patient = resources_pb2.Patient()
    patient.id.value = 'Patient/1'
    self._bundle = resources_pb2.Bundle()
    self._bundle.entry.add().resource.encounter.CopyFrom(enc)
    self._bundle.entry.add().resource.patient.CopyFrom(patient)
    self._expected_label = label.ComposeLabel(
        patient,
        enc,
        label.LOS_RANGE_LABEL,
        'above_14',
        # 24 hours after admission
        datetime.datetime(2009, 2, 14, 23, 31, 30))
def test_validate_inputs(self, input_glob):
    """Call `validate_inputs` with either a glob or an expanded file list.

    Args:
      input_glob: If truthy, the glob string itself is passed as the only
        input; otherwise the glob is expanded first.
    """
    file_glob = os.path.join(absltest.get_default_test_srcdir(), TEST_DIR, '*')
    inputs = [file_glob] if input_glob else tf.io.gfile.glob(file_glob)
    fake_output = os.path.join(absltest.get_default_test_tmpdir(), 'fake1')

    # Check that inputs and flags are formatted correctly.
    audio_to_embeddings_beam_utils.validate_inputs(
        input_filenames_list=[inputs],
        output_filenames=[fake_output],
        embedding_modules=['m1', 'm2'],
        embedding_names=['n1', 'n2'],
        module_output_keys=['k1', 'k2'])
def test_tflite_inference(self, feature_inputs):
    """Run one TFLite embedding inference on an all-zeros input.

    Args:
      feature_inputs: If truthy, the model file without a frontend is used
        and the input is first passed through `_default_feature_fn`.
    """
    test_dir = 'non_semantic_speech_benchmark/data_prep/testdata'
    test_file = (
        'model1_woutfrontend.tflite' if feature_inputs
        else 'model1_wfrontend.tflite')
    model_path = os.path.join(
        absltest.get_default_test_srcdir(), test_dir, test_file)
    interpreter = audio_to_embeddings_beam_utils._build_tflite_interpreter(
        tflite_model_path=model_path)

    sample_rate = 16000
    model_input = np.zeros([32000], dtype=np.float32)
    if feature_inputs:
        model_input = audio_to_embeddings_beam_utils._default_feature_fn(
            model_input, sample_rate)

    output_key = '0'
    audio_to_embeddings_beam_utils._samples_to_embedding_tflite(
        model_input, sample_rate, interpreter, output_key)
def test_full_flow(self, include_frontend):
    """Run the TFLite conversion and check the model's input shape.

    Args:
      include_frontend: Stored in the `include_frontend` flag; selects the
        expected input shape.
    """
    fv = flags.FLAGS
    fv.experiment_dir = os.path.join(
        absltest.get_default_test_srcdir(), TESTDIR)
    fv.checkpoint_number = '1000'
    fv.output_dir = absltest.get_default_test_tmpdir()
    fv.include_frontend = include_frontend

    tflite_conversion.main(None)

    tflite_model = os.path.join(flags.FLAGS.output_dir, 'model_1.tflite')
    self.assertTrue(tf.io.gfile.exists(tflite_model))

    # Check that input signature is as expected.
    with tf.io.gfile.GFile(tflite_model, 'rb') as model_file:
        model_bytes = model_file.read()
    interpreter = tf.lite.Interpreter(model_content=model_bytes)
    interpreter.allocate_tensors()
    if include_frontend:
        expected_input_shape = (1, 1)
    else:
        expected_input_shape = (1, 96, 64, 1)
    actual_input_shape = interpreter.get_input_details()[0]['shape']
    np.testing.assert_array_equal(actual_input_shape, expected_input_shape)
def test_tflite_inference(self, feature_inputs):
    """Run one TFLite embedding inference on an all-zeros input.

    Args:
      feature_inputs: If truthy, the model file without a frontend is used
        and the input is passed through `default_feature_fn`; otherwise a
        leading axis is added to the raw samples.
    """
    test_file = (
        'model1_woutfrontend.tflite' if feature_inputs
        else 'model1_wfrontend.tflite')
    model_path = os.path.join(
        absltest.get_default_test_srcdir(), TEST_DIR, test_file)
    interpreter = data_prep_utils.build_tflite_interpreter(
        tflite_model_path=model_path)

    sample_rate = 16000
    samples = np.zeros([32000], dtype=np.float32)
    if feature_inputs:
        model_input = data_prep_utils.default_feature_fn(samples, sample_rate)
    else:
        model_input = np.expand_dims(samples, axis=0)

    output_key = '0'
    data_prep_utils.samples_to_embedding_tflite(
        model_input, sample_rate, interpreter, output_key, 'name')
def test_restore_text(self):
    """Run `restore_text.main` and compare its output with the expected file."""
    srcdir = pathlib.Path(absltest.get_default_test_srcdir())
    testdata_dir = srcdir / TESTDATA_DIR
    expected_file = testdata_dir / "example_mentions.jsonl"
    output_file = (
        pathlib.Path(self.create_tempdir().full_path) / "restored.jsonl")

    flag_overrides = {
        "input": str(testdata_dir / "example_mentions_no_text.jsonl"),
        "index_dir": str(testdata_dir),
        "output": str(output_file),
    }
    with flagsaver.flagsaver(**flag_overrides):
        restore_text.main([])
    self.assertTrue(output_file.exists(), msg=str(output_file))

    # Compare the dataclass representations of the output and expected files
    # rather than their string contents to be robust against fluctuations in
    # the serialization (e.g. due to dictionary order, etc).
    got = schema.load_jsonl(output_file, schema.ContextualMentions)
    expected = schema.load_jsonl(expected_file, schema.ContextualMentions)
    self.assertEqual(got, expected)
def test_main_flow(self, mock_ds_dict):
    """Run the duration-count pipeline on one dummy dataset and check output.

    Args:
      mock_ds_dict: Mock whose return value is set to a one-entry dataset
        dict pointing at the test TFRecord.
    """
    flags.FLAGS.output_file = os.path.join(
        absltest.get_default_test_tmpdir(), 'dummy_out.txt')
    dummy_fn = os.path.join(
        absltest.get_default_test_srcdir(), TESTDIR, 'test.tfrecord')
    mock_ds_dict.return_value = {'dummy': ([[dummy_fn]], 'tfrecord')}

    # Run the beam pipeline, which writes to the output.
    count_duration_beam.main(None)

    # Do a check on the output file.
    matches = tf.io.gfile.glob(f'{flags.FLAGS.output_file}*')
    self.assertLen(matches, 1)
    with tf.io.gfile.GFile(matches[0]) as out_handle:
        contents = out_handle.read()
    # Drop the last element of the split, as the original check does.
    rows = [row.split(',') for row in contents.split('\n')[:-1]]
    self.assertLen(rows, 1)

    savee_out = rows[0]
    self.assertEqual(savee_out[0], 'dummy')
    self.assertGreater(float(savee_out[1]), 0)
    self.assertEqual(int(savee_out[2]), 2)
def setUp(self):
    """Install fake credentials and patches, then create a binary and rules.

    Creates `self.binary`, a host-specific `self.local_rule` and a
    `self.global_rule` for that binary.
    """
    super(CommitBlockableChangeSetTest, self).setUp()

    # Set up a fake Bit9ApiAuth entity in Datastore.
    credentials_path = os.path.join(
        absltest.get_default_test_srcdir(),
        'upvote/gae/modules/bit9_api',
        'fake_credentials.json')
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path
    for method_name in ('_Encrypt', '_Decrypt'):
        self.Patch(
            change_set.bit9_utils.bit9.kms_ndb.EncryptedBlobProperty,
            method_name,
            return_value='blah')
    bit9.Bit9ApiAuth.SetInstance(api_key='blah')

    self.mock_ctx = mock.Mock(spec=change_set.bit9_utils.api.Context)
    self.Patch(
        change_set.bit9_utils.api, 'Context', return_value=self.mock_ctx)

    self.binary = test_utils.CreateBit9Binary(file_catalog_id='1234')
    binary_key = self.binary.key
    self.local_rule = test_utils.CreateBit9Rule(binary_key, host_id='5678')
    self.global_rule = test_utils.CreateBit9Rule(self.binary.key)
def _get_test_data_path():
    """Return `_TEST_FILE_PATH` joined onto the default test source dir."""
    srcdir = absltest.get_default_test_srcdir()
    return os.path.join(srcdir, _TEST_FILE_PATH)