Esempio n. 1
0
  def test_data_prep_beam_params(self, tfds, input_format):
    """Checks the data-prep parameters that are derived from flags.

    Args:
      tfds: If truthy, the `tfds_dataset` flag is set to 'savee'; otherwise
        the train/validation/test input globs all point at the test records.
      input_format: Not referenced in this test body.
    """
    if tfds:
      flags.FLAGS.tfds_dataset = 'savee'
    else:
      # The three splits all read the same test records.
      record_glob = os.path.join(
          absltest.get_default_test_srcdir(), TESTDIR, 'test.tfrecord*')
      flags.FLAGS.train_input_glob = record_glob
      flags.FLAGS.validation_input_glob = record_glob
      flags.FLAGS.test_input_glob = record_glob
    flags.FLAGS.skip_existing_error = False
    flags.FLAGS.output_filename = os.path.join(
        absltest.get_default_test_tmpdir(), f'data_prep_test_{tfds}')

    flags.FLAGS.embedding_modules = ['mod1', 'mod2']
    flags.FLAGS.embedding_names = ['emb1', 'emb2']
    flags.FLAGS.module_output_keys = ['k1', 'k2']

    params = data_prep_and_eval_beam_main._get_data_prep_params_from_flags()
    prep_params, input_filenames_list, output_filenames, run_data_prep = params

    self.assertTrue(run_data_prep)
    self.assertLen(input_filenames_list, 3)
    self.assertLen(output_filenames, 3)
    # Output names are expected in train / validation / test order.
    split_names = ('train', 'validation', 'test')
    for out_name, split in zip(output_filenames, split_names):
      expected_suffix = f'{flags.FLAGS.output_filename}.{split}'
      self.assertTrue(out_name.endswith(expected_suffix), out_name)
    self.assertIsInstance(prep_params, dict)
Esempio n. 2
0
    def test_read_flags_and_create_pipeline(self, data_prep_behavior):
        """Check that flag-derived beam params are accepted by the pipeline.

        Args:
          data_prep_behavior: Value stored in `FLAGS.data_prep_behavior`; it
            is also used to name the output file.
        """
        srcdir = absltest.get_default_test_srcdir()
        FLAGS.input_glob = os.path.join(srcdir, TEST_DIR, '*')
        tmpdir = absltest.get_default_test_tmpdir()
        FLAGS.output_filename = os.path.join(
            tmpdir, f'{data_prep_behavior}.tfrecord')
        FLAGS.data_prep_behavior = data_prep_behavior
        FLAGS.embedding_names = ['em1', 'em2']
        FLAGS.embedding_modules = ['dummy_mod_loc']
        FLAGS.module_output_keys = ['k1', 'k2']
        FLAGS.sample_rate = 5
        FLAGS.audio_key = 'audio_key'
        FLAGS.label_key = 'label_key'

        from_flags = audio_to_embeddings_beam_utils.get_beam_params_from_flags()
        input_filenames_list, output_filenames, beam_params = from_flags

        # Use the defaults, unless we are using TFLite models.
        for absent_key in ('module_call_fn', 'setup_fn'):
            self.assertNotIn(absent_key, beam_params)

        # Check that the arguments run through.
        pipeline_kwargs = dict(
            root=beam.Pipeline(),
            input_filenames_or_glob=input_filenames_list[0],
            output_filename=output_filenames[0],
            data_prep_behavior=FLAGS.data_prep_behavior,
            beam_params=beam_params,
            suffix='s')
        audio_to_embeddings_beam_utils.data_prep_pipeline(**pipeline_kwargs)
Esempio n. 3
0
 def setUp(self):
   """Parse the version config text proto into `self._version_config`."""
   super(BundleToSeqexTest, self).setUp()
   config_path = os.path.join(absltest.get_default_test_srcdir(),
                              _VERSION_CONFIG_PATH)
   with open(config_path) as config_file:
     config_text = config_file.read()
   self._version_config = text_format.Parse(
       config_text, version_config_pb2.VersionConfig())
Esempio n. 4
0
  def test_overwrite_b_factors(self):
    """Checks that overwrite_b_factors changes only the B-factor field.

    Reads the multiple-disulfides test PDB, overwrites its B-factors with a
    per-residue ramp, and verifies that (a) the ATOM lines are preserved
    outside character positions 60-65 and (b) the B-factors parsed back from
    the output equal the requested values wherever the atom mask is set.
    """
    pdb_path = os.path.join(
        absltest.get_default_test_srcdir(),
        'alphafold/relax/testdata/'
        'multiple_disulfides_target.pdb')
    with open(pdb_path) as f:
      test_pdb = f.read()
    n_residues = 191
    # Row i holds the value i repeated across 37 columns.
    bfactors = np.stack([np.arange(0, n_residues)] * 37, axis=-1)

    output_pdb = utils.overwrite_b_factors(test_pdb, bfactors)

    # Check that the atom lines are unchanged apart from the B-factors.
    atom_lines_original = [
        line for line in test_pdb.split('\n') if line[:4] == 'ATOM']
    atom_lines_new = [
        line for line in output_pdb.split('\n') if line[:4] == 'ATOM']
    # zip() stops at the shorter sequence, so without this check a dropped or
    # extra ATOM line in the output would go unnoticed.
    self.assertEqual(len(atom_lines_original), len(atom_lines_new))
    for line_original, line_new in zip(atom_lines_original, atom_lines_new):
      self.assertEqual(line_original[:60].strip(), line_new[:60].strip())
      self.assertEqual(line_original[66:].strip(), line_new[66:].strip())

    # Check B-factors are correctly set for all atoms present.
    as_protein = protein.from_pdb_string(output_pdb)
    np.testing.assert_almost_equal(
        np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0),
        np.where(as_protein.atom_mask > 0, bfactors, 0))
Esempio n. 5
0
 def testConfig(self, config_path):
     """Parse the gin config at `config_path` and run the train/eval loop.

     Args:
       config_path: Config file path, joined onto the default test srcdir.
     """
     override_bindings = train_eval.get_gin_override_params(
         self.create_tempdir())
     absolute_config = os.path.join(absltest.get_default_test_srcdir(),
                                    config_path)
     gin.parse_config_files_and_bindings([absolute_config],
                                         override_bindings)
     train_eval.train_eval_loop()
Esempio n. 6
0
 def setUp(self):
     """Parse encounter_1.pbtxt and wrap the encounter in a one-entry Bundle."""
     testdata_root = os.path.join(absltest.get_default_test_srcdir(),
                                  _TESTDATA_PATH)
     self._test_data_dir = testdata_root

     encounter = resources_pb2.Encounter()
     encounter_path = os.path.join(testdata_root, 'encounter_1.pbtxt')
     with open(encounter_path) as proto_file:
         text_format.Parse(proto_file.read(), encounter)
     self._enc = encounter

     bundle = resources_pb2.Bundle()
     bundle.entry.add().resource.encounter.CopyFrom(encounter)
     self._bundle = bundle
 def test_file_equivalence_after_rewrite(self):
     """Check each static numpy-backend file against its generated version.

     For every module in `FLAGS.modules_to_check`, the file
     `.../numpy/gen/<module>.py` is compared with `.../numpy/<module>_gen.py`.
     When `FLAGS.update` is set, a differing static file is overwritten with
     the generated content; otherwise the differing modules are collected
     and reported together.

     Raises:
       ValueError: If neither file can be read, or if an update is requested
         but the generated file could not be read.
       AssertionError: If any module differs and `FLAGS.update` is not set.
     """

     def _read_or_none(path):
         """Return the text of `path`, or None if reading raises IOError."""
         try:
             with open(path, 'r') as f:
                 return f.read()
         except IOError:
             return None

     srcdir = absltest.get_default_test_srcdir()
     wrong_modules = []
     for module in FLAGS.modules_to_check:
         static_file = os.path.join(
             srcdir,
             '{}/internal/backend/numpy/gen/{}.py'.format(
                 TFP_PYTHON_DIR, module))
         gen_file = os.path.join(
             srcdir,
             '{}/internal/backend/numpy/{}_gen.py'.format(
                 TFP_PYTHON_DIR, module))
         static_content = _read_or_none(static_file)
         gen_content = _read_or_none(gen_file)
         if gen_content is None and static_content is None:
             raise ValueError(
                 'Could not load content for {}'.format(static_file))
         if gen_content == static_content:
             continue
         if not FLAGS.update:
             wrong_modules.append(module)
             continue
         if gen_content is None:
             # Opening the target with 'w' truncates it, and writing None
             # then fails with TypeError; refuse before touching the file.
             raise ValueError(
                 'Could not load content for {}'.format(gen_file))
         # Strip everything up to and including the first path component
         # after 'runfiles/' to obtain the path that is written to.
         to_update = static_file.split('runfiles/')[-1]
         to_update = '/'.join(to_update.split('/')[1:])
         with open(to_update, 'w') as f:
             f.write(gen_content)
         logging.info('Updating file %s', to_update)
     if wrong_modules:
         msg = '\n'.join([
             'Modules `{}` require updates.  To update them, run'.format(
                 repr(wrong_modules)),
             'bazel build -c opt :rewrite_equivalence_test',
             'bazel-py3/bin/.../rewrite_equivalence_test --update '
             '--modules_to_check={}'.format(','.join(wrong_modules)),
             'It may be necessary to adapt the generator programs.'
         ])
         raise AssertionError(msg)
Esempio n. 8
0
 def test_from_pdb_str(self, pdb_file, chain_id, num_res):
     """Parse a chain from a test PDB file and check shapes and aatype range.

     Args:
       pdb_file: File name inside the test data directory.
       chain_id: Passed through to `protein.from_pdb_string`.
       num_res: Passed to `self._check_shapes` together with the protein.
     """
     pdb_path = os.path.join(absltest.get_default_test_srcdir(),
                             TEST_DATA_DIR, pdb_file)
     with open(pdb_path) as pdb_handle:
         pdb_text = pdb_handle.read()
     parsed = protein.from_pdb_string(pdb_text, chain_id)
     self._check_shapes(parsed, num_res)
     lowest = parsed.aatype.min()
     self.assertGreaterEqual(lowest, 0)
     # The upper bound is inclusive because unknown restypes have index
     # equal to restype_num.
     highest = parsed.aatype.max()
     self.assertLessEqual(highest, residue_constants.restype_num)
 def setUp(self):
     """Read the two example ascii files and write them as tfrecord input."""
     super(TestInputFn, self).setUp()
     testdata_root = os.path.join(absltest.get_default_test_srcdir(),
                                  TESTDATA_DIR)
     examples = []
     for ascii_name in ('example1.ascii', 'example2.ascii'):
         examples.append(self.read_seqex_ascii(ascii_name, testdata_root))
     record_dir = tempfile.mkdtemp()
     self.input_data_dir = self.create_input_tfrecord(
         examples, record_dir, 'input')
     self.log_dir = tempfile.mkdtemp()
Esempio n. 10
0
 def setUp(self):
   """Store the relax test-data directory and a small test config."""
   super().setUp()
   srcdir = absltest.get_default_test_srcdir()
   self.test_dir = os.path.join(srcdir, 'alphafold/relax/testdata/')
   self.test_config = dict(
       max_iterations=1,
       tolerance=2.39,
       stiffness=10.0,
       exclude_residues=[],
       max_outer_iterations=1)
Esempio n. 11
0
 def setUp(self):
     """Read the two example ascii files and write them as tfrecord input."""
     super(ModelTest, self).setUp()
     testdata_root = os.path.join(absltest.get_default_test_srcdir(),
                                  TESTDATA_DIR)
     examples = []
     for ascii_name in ('example1.ascii', 'example2.ascii'):
         examples.append(
             test_utils.read_seqex_ascii(ascii_name, testdata_root))
     record_dir = tempfile.mkdtemp()
     self.input_data_dir = test_utils.create_input_tfrecord(
         examples, record_dir, 'input')
     self.log_dir = tempfile.mkdtemp()
Esempio n. 12
0
    def setUp(self):
        """Load the encounter and expected-label protos and build a patient."""
        data_dir = os.path.join(absltest.get_default_test_srcdir(),
                                _TESTDATA_PATH)
        self._test_data_dir = data_dir

        def parse_into(filename, message):
            """Parse text proto `filename` from the data dir into `message`."""
            with open(os.path.join(data_dir, filename)) as proto_file:
                text_format.Parse(proto_file.read(), message)
            return message

        self._enc = parse_into('encounter_1.pbtxt', resources_pb2.Encounter())

        patient = resources_pb2.Patient()
        patient.id.value = 'Patient/1'
        self._patient = patient

        self._expected_label = parse_into(
            'label_1.pbtxt', google_extensions_pb2.EventLabel())
  def test_saved_model_from_disk(self):
    """Load the saved test model and check `make_bij` before/after an assign."""
    model_dir = os.path.join(
        absltest.get_default_test_srcdir(),
        TFP_PYTHON_DIR,
        'internal/testdata/auto_composite_tensor')

    model = tf.saved_model.load(model_dir)
    self.evaluate(model.scale.initializer)
    bijector = tfb.Scale([5., 9.], validate_args=True)

    def forward_of_two():
      """Evaluate the model-built bijector's forward() at 2."""
      return self.evaluate(model.make_bij(bijector).forward(2.))

    self.assertAllClose(forward_of_two(), [10., 20.])
    # Change the model's `scale` variable and check the output follows it.
    self.evaluate(model.scale.assign(model.scale + [1., 2.]))
    self.assertAllClose(forward_of_two(), [12., 24.])
    def test_full_flow(self, _):
        """Set the experiment and frontend flags, then run the beam main."""
        test_flags = flags.FLAGS
        test_flags.xids = ['12321']
        test_flags.base_experiment_dir = os.path.join(
            absltest.get_default_test_srcdir(), TESTDIR)
        test_flags.output_dir = os.path.join(
            absltest.get_default_test_tmpdir(), 'dummy_out')

        # Frontend args.
        frontend_args = (
            ('frame_hop', 5),
            ('frame_width', 5),
            ('num_mel_bins', 80),
            ('n_required', 8000),
        )
        for flag_name, flag_value in frontend_args:
            setattr(test_flags, flag_name, flag_value)

        model_conversion_beam_main.main(None)
Esempio n. 15
0
 def setUp(self):
     """Read the example ascii files and write train/validation tfrecords."""
     super(ExperimentTest, self).setUp()
     seqex_list = []
     for ascii_name in ('example1.ascii', 'example2.ascii'):
         # The test-data directory is recomputed for each file.
         testdata_root = os.path.join(absltest.get_default_test_srcdir(),
                                      TESTDATA_DIR)
         seqex_list.append(
             test_utils.read_seqex_ascii(ascii_name, testdata_root))
     self.seqex_list = seqex_list
     self.input_data_dir = tempfile.mkdtemp()
     for split_name in ('train', 'validation'):
         test_utils.create_input_tfrecord(self.seqex_list,
                                          self.input_data_dir, split_name)
     self.log_dir = tempfile.mkdtemp()
Esempio n. 16
0
 def testParseConfigFromPythonSystemPath(self):
     """Load config only found in python system path.

     The fake package directory is appended to `sys.path` for the duration
     of the parse and removed again afterwards, even when parsing raises,
     so that a failure here does not leave the entry behind for later tests.
     """
     test_srcdir = absltest.get_default_test_srcdir()
     relative_testdata_path = 'gin/testdata'
     absolute_testdata_path = os.path.join(test_srcdir,
                                           relative_testdata_path,
                                           'fake_package')
     config_file = 'fake_gin_package/config/foo.gin'
     sys.path.append(absolute_testdata_path)
     try:
         result = config.parse_config_file(config_file,
                                           print_includes_and_imports=True,
                                           skip_unknown=True)
     finally:
         # Always undo the sys.path mutation made above.
         sys.path.remove(absolute_testdata_path)
     self.assertEqual(result.includes[0].filename,
                      'fake_gin_package/parent.gin')
Esempio n. 17
0
 def testTask(self, task):
   """Score `task` for each of its languages and compare metric averages.

   Args:
     task: Key into TASK2LANGS and TASK2AVG_SCORES; also selects the
       prediction/label file suffix ('json' for tasks in GROUP2TASK['qa'],
       'tsv' otherwise).
   """
   data_dir = os.path.join(absltest.get_default_test_srcdir(), DATA_DIR)
   if task in GROUP2TASK['qa']:
     suffix = 'json'
   else:
     suffix = 'tsv'
   per_metric = collections.defaultdict(dict)
   for lang in TASK2LANGS[task]:
     filename = f'test-{lang}.{suffix}'
     predictions = os.path.join(data_dir, 'predictions', task, filename)
     labels = os.path.join(data_dir, 'labels', task, filename)
     lang_scores = evaluate_one_task(predictions, labels, task, language=lang)
     for metric in lang_scores:
       per_metric[metric][lang] = lang_scores[metric]
   # Average every metric over the languages that reported it.
   averages = {
       f'avg_{metric}': sum(by_lang.values()) / len(by_lang)
       for metric, by_lang in per_metric.items()
   }
   self.assertEqual(averages, TASK2AVG_SCORES[task])
Esempio n. 18
0
  def setUp(self):
    """Parse bundle_1.pbtxt and compose the expected LOS_RANGE_LABEL label."""
    data_dir = os.path.join(absltest.get_default_test_srcdir(),
                            _TESTDATA_PATH)
    self._test_data_dir = data_dir

    bundle = resources_pb2.Bundle()
    with open(os.path.join(data_dir, 'bundle_1.pbtxt')) as bundle_file:
      text_format.Parse(bundle_file.read(), bundle)
    self._bundle = bundle

    # Entry 0 is read as the encounter and entry 1 as the patient.
    encounter = bundle.entry[0].resource.encounter
    patient_resource = bundle.entry[1].resource.patient

    # 24 hours after admission
    label_time = datetime.datetime(2009, 2, 14, 23, 31, 30)
    self._expected_label = label.ComposeLabel(
        patient_resource, encounter, label.LOS_RANGE_LABEL, 'above_14',
        label_time)
Esempio n. 19
0
 def test_ideal_atom_mask(self):
     """Compare each residue's atom mask with the ideal mask for 2rbg chain A.

     Residues listed as non-ideal must differ from the ideal mask in at
     least one position; all other residues must match it exactly.
     """
     pdb_path = os.path.join(absltest.get_default_test_srcdir(),
                             TEST_DATA_DIR, '2rbg.pdb')
     with open(pdb_path) as pdb_handle:
         pdb_string = pdb_handle.read()
     prot = protein.from_pdb_string(pdb_string, chain_id='A')
     ideal_mask = protein.ideal_atom_mask(prot)
     non_ideal_residues = {102} | set(range(127, 285))
     pairs = zip(prot.residue_index, prot.atom_mask)
     for i, (res, atom_mask) in enumerate(pairs):
         matches_ideal = np.all(atom_mask == ideal_mask[i])
         if res in non_ideal_residues:
             check = self.assertFalse
         else:
             check = self.assertTrue
         check(matches_ideal, msg=f'{res}')
Esempio n. 20
0
    def test_to_pdb(self):
        """Round-trip 2rbg chain A through to_pdb and compare the fields."""
        pdb_path = os.path.join(absltest.get_default_test_srcdir(),
                                TEST_DATA_DIR, '2rbg.pdb')
        with open(pdb_path) as pdb_handle:
            pdb_string = pdb_handle.read()
        prot = protein.from_pdb_string(pdb_string, chain_id='A')
        prot_reconstr = protein.from_pdb_string(protein.to_pdb(prot))

        exact = np.testing.assert_array_equal
        approximate = np.testing.assert_array_almost_equal
        # Each field is compared between the reconstructed and original
        # protein using the listed comparison.
        field_checks = (
            (exact, 'aatype'),
            (approximate, 'atom_positions'),
            (approximate, 'atom_mask'),
            (exact, 'residue_index'),
            (approximate, 'b_factors'),
        )
        for compare, field in field_checks:
            compare(getattr(prot_reconstr, field), getattr(prot, field))
Esempio n. 21
0
    def setUp(self, **kwargs):
        """Point credentials at the fake file and stub encryption and Context.

        Args:
          **kwargs: Forwarded unchanged to the parent class's setUp.
        """
        super(Bit9TestCase, self).setUp(**kwargs)

        # Set up a fake Bit9ApiAuth entity in Datastore.
        credentials_path = os.path.join(
            absltest.get_default_test_srcdir(), 'upvote/gae/lib/bit9',
            'fake_credentials.json')
        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path
        blob_property = bit9_utils.bit9.kms_ndb.EncryptedBlobProperty
        for method_name in ('_Encrypt', '_Decrypt'):
            self.Patch(blob_property, method_name, return_value='blah')
        bit9_models.Bit9ApiAuth.SetInstance(api_key='blah')

        # Replace the API Context with a mock that is kept on the test case.
        context_mock = mock.Mock(spec=bit9_utils.api.Context)
        self.mock_ctx = context_mock
        self.Patch(bit9_utils.api, 'Context', return_value=context_mock)
Esempio n. 22
0
    def setUp(self):
        """Build an encounter+patient Bundle and the expected label for it."""
        data_dir = os.path.join(absltest.get_default_test_srcdir(),
                                _TESTDATA_PATH)
        self._test_data_dir = data_dir

        encounter = resources_pb2.Encounter()
        encounter_path = os.path.join(data_dir, 'encounter_1.pbtxt')
        with open(encounter_path) as proto_file:
            text_format.Parse(proto_file.read(), encounter)

        patient_resource = resources_pb2.Patient()
        patient_resource.id.value = 'Patient/1'

        # The encounter entry is added first, then the patient entry.
        bundle = resources_pb2.Bundle()
        bundle.entry.add().resource.encounter.CopyFrom(encounter)
        bundle.entry.add().resource.patient.CopyFrom(patient_resource)
        self._bundle = bundle

        # 24 hours after admission
        label_time = datetime.datetime(2009, 2, 14, 23, 31, 30)
        self._expected_label = label.ComposeLabel(
            patient_resource, encounter, label.LOS_RANGE_LABEL, 'above_14',
            label_time)
Esempio n. 23
0
 def test_validate_inputs(self, input_glob):
   """Call validate_inputs with either a glob or the expanded file list.

   Args:
     input_glob: If truthy, pass the glob string itself; otherwise expand it
       with `tf.io.gfile.glob` and pass the resulting filenames.
   """
   file_glob = os.path.join(absltest.get_default_test_srcdir(), TEST_DIR, '*')
   if input_glob:
     first_input = [file_glob]
   else:
     first_input = tf.io.gfile.glob(file_glob)
   fake_output = os.path.join(absltest.get_default_test_tmpdir(), 'fake1')
   # Check that inputs and flags are formatted correctly.
   audio_to_embeddings_beam_utils.validate_inputs(
       input_filenames_list=[first_input],
       output_filenames=[fake_output],
       embedding_modules=['m1', 'm2'],
       embedding_names=['n1', 'n2'],
       module_output_keys=['k1', 'k2'])
Esempio n. 24
0
    def test_tflite_inference(self, feature_inputs):
        """Run TFLite embedding inference on an all-zero input.

        Args:
          feature_inputs: If truthy, use the model without a frontend and
            convert the samples with `_default_feature_fn` first; otherwise
            use the model with a frontend and pass the samples directly.
        """
        test_dir = 'non_semantic_speech_benchmark/data_prep/testdata'
        test_file = ('model1_woutfrontend.tflite'
                     if feature_inputs else 'model1_wfrontend.tflite')
        tflite_model_path = os.path.join(absltest.get_default_test_srcdir(),
                                         test_dir, test_file)
        interpreter = audio_to_embeddings_beam_utils._build_tflite_interpreter(
            tflite_model_path=tflite_model_path)

        sample_rate = 16000
        model_input = np.zeros([32000], dtype=np.float32)
        if feature_inputs:
            model_input = audio_to_embeddings_beam_utils._default_feature_fn(
                model_input, sample_rate)

        output_key = '0'
        audio_to_embeddings_beam_utils._samples_to_embedding_tflite(
            model_input, sample_rate, interpreter, output_key)
Esempio n. 25
0
    def test_full_flow(self, include_frontend):
        """Run the TFLite conversion and check the model's first input shape.

        Args:
          include_frontend: Stored in the `include_frontend` flag; selects
            the expected input shape, (1, 1) when truthy and (1, 96, 64, 1)
            otherwise.
        """
        test_flags = flags.FLAGS
        test_flags.experiment_dir = os.path.join(
            absltest.get_default_test_srcdir(), TESTDIR)
        test_flags.checkpoint_number = '1000'
        test_flags.output_dir = absltest.get_default_test_tmpdir()
        test_flags.include_frontend = include_frontend

        tflite_conversion.main(None)

        tflite_model = os.path.join(test_flags.output_dir, 'model_1.tflite')
        self.assertTrue(tf.io.gfile.exists(tflite_model))

        # Check that input signature is as expected.
        with tf.io.gfile.GFile(tflite_model, 'rb') as model_file:
            model_bytes = model_file.read()
        interpreter = tf.lite.Interpreter(model_content=model_bytes)
        interpreter.allocate_tensors()

        if include_frontend:
            expected_input_shape = (1, 1)
        else:
            expected_input_shape = (1, 96, 64, 1)
        first_input = interpreter.get_input_details()[0]
        np.testing.assert_array_equal(first_input['shape'],
                                      expected_input_shape)
  def test_tflite_inference(self, feature_inputs):
    """Run TFLite embedding inference on an all-zero input.

    Args:
      feature_inputs: If truthy, use the model without a frontend and convert
        the samples with `default_feature_fn`; otherwise use the model with a
        frontend and add a leading axis to the samples.
    """
    test_file = ('model1_woutfrontend.tflite'
                 if feature_inputs else 'model1_wfrontend.tflite')
    tflite_model_path = os.path.join(absltest.get_default_test_srcdir(),
                                     TEST_DIR, test_file)
    interpreter = data_prep_utils.build_tflite_interpreter(
        tflite_model_path=tflite_model_path)

    sample_rate = 16000
    samples = np.zeros([32000], dtype=np.float32)
    if feature_inputs:
      model_input = data_prep_utils.default_feature_fn(samples, sample_rate)
    else:
      model_input = np.expand_dims(samples, axis=0)

    output_key = '0'
    data_prep_utils.samples_to_embedding_tflite(
        model_input, sample_rate, interpreter, output_key, 'name')
Esempio n. 27
0
    def test_restore_text(self):
        """Run restore_text.main and compare its output to the expected file."""
        testdata_dir = pathlib.Path(absltest.get_default_test_srcdir(),
                                    TESTDATA_DIR)
        input_file = testdata_dir / "example_mentions_no_text.jsonl"
        expected_file = testdata_dir / "example_mentions.jsonl"

        temp_root = self.create_tempdir()
        output_file = pathlib.Path(temp_root.full_path, "restored.jsonl")
        flag_overrides = {
            "input": str(input_file),
            "index_dir": str(testdata_dir),
            "output": str(output_file),
        }
        with flagsaver.flagsaver(**flag_overrides):
            restore_text.main([])

        self.assertTrue(output_file.exists(), msg=str(output_file))

        # The loaded dataclass objects are compared instead of the raw file
        # text so that serialization differences (such as dictionary order)
        # do not cause spurious failures.
        expected = schema.load_jsonl(expected_file, schema.ContextualMentions)
        got = schema.load_jsonl(output_file, schema.ContextualMentions)
        self.assertEqual(got, expected)
Esempio n. 28
0
    def test_main_flow(self, mock_ds_dict):
        """Run the duration-count pipeline on one dummy dataset and check it.

        Args:
          mock_ds_dict: Mock whose return value is set to a single 'dummy'
            entry pointing at the test tfrecord.
        """
        flags.FLAGS.output_file = os.path.join(
            absltest.get_default_test_tmpdir(), 'dummy_out.txt')
        record_path = os.path.join(absltest.get_default_test_srcdir(),
                                   TESTDIR, 'test.tfrecord')
        mock_ds_dict.return_value = {'dummy': ([[record_path]], 'tfrecord')}

        # Run the beam pipeline, which writes to the output.
        count_duration_beam.main(None)

        # Do a check on the output file.
        matches = tf.io.gfile.glob(f'{flags.FLAGS.output_file}*')
        self.assertLen(matches, 1)
        with tf.io.gfile.GFile(matches[0]) as out_handle:
            contents = out_handle.read()
        # Drop the last piece of the newline split before parsing the rows.
        rows = [line.split(',') for line in contents.split('\n')[:-1]]
        self.assertLen(rows, 1)
        row = rows[0]
        self.assertEqual(row[0], 'dummy')
        self.assertGreater(float(row[1]), 0)
        self.assertEqual(int(row[2]), 2)
Esempio n. 29
0
    def setUp(self):
        """Stub credentials, encryption and Context; create binary and rules."""
        super(CommitBlockableChangeSetTest, self).setUp()

        # Set up a fake Bit9ApiAuth entity in Datastore.
        credentials_path = os.path.join(
            absltest.get_default_test_srcdir(), 'upvote/gae/modules/bit9_api',
            'fake_credentials.json')
        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path
        blob_property = (
            change_set.bit9_utils.bit9.kms_ndb.EncryptedBlobProperty)
        for method_name in ('_Encrypt', '_Decrypt'):
            self.Patch(blob_property, method_name, return_value='blah')
        bit9.Bit9ApiAuth.SetInstance(api_key='blah')

        # Replace the API Context with a mock that is kept on the test case.
        context_mock = mock.Mock(spec=change_set.bit9_utils.api.Context)
        self.mock_ctx = context_mock
        self.Patch(change_set.bit9_utils.api,
                   'Context',
                   return_value=context_mock)

        # One binary, one rule created with a host_id and one without.
        binary = test_utils.CreateBit9Binary(file_catalog_id='1234')
        self.binary = binary
        self.local_rule = test_utils.CreateBit9Rule(binary.key,
                                                    host_id='5678')
        self.global_rule = test_utils.CreateBit9Rule(binary.key)
Esempio n. 30
0
def _get_test_data_path():
    """Return `_TEST_FILE_PATH` joined onto the default test source dir."""
    srcdir = absltest.get_default_test_srcdir()
    return os.path.join(srcdir, _TEST_FILE_PATH)