예제 #1
0
 def test_make_mol_dict(self):
   """Test generation of molecule dictionaries.

   Parses the short SDF test file, builds one molecule dictionary per parsed
   molecule with parse_sdf_utils.make_mol_dict (using the hparams limits), and
   compares each against the matching entry of self.expected_mol_dicts:
   identifier fields for exact equality, numeric fields approximately.
   """
   mols = parse_sdf_utils.get_sdf_to_mol(self.test_file_short)
   mol_dicts = [
       parse_sdf_utils.make_mol_dict(mol, self.hparams.max_atoms,
                                     self.hparams.max_mass_spec_peak_loc)
       for mol in mols
   ]
   # Fields compared for exact equality. Built once, not per molecule.
   exact_key_names = [
       fmap_constants.NAME, fmap_constants.INCHIKEY,
       fmap_constants.SMILES, fmap_constants.MOLECULAR_FORMULA
   ]
   # Sequence fields compared element-wise with assertSequenceAlmostEqual.
   sequence_key_names = [
       fmap_constants.ATOM_WEIGHTS,
       fmap_constants.ADJACENCY_MATRIX,
       fmap_constants.DENSE_MASS_SPEC,
   ]
   for i, expected in enumerate(self.expected_mol_dicts):
     # Index (rather than zip) so that parsing fewer molecules than expected
     # still fails with an IndexError instead of being silently truncated.
     actual = mol_dicts[i]
     for kwarg in exact_key_names:
       self.assertEqual(expected[kwarg], actual[kwarg])
     self.assertAlmostEqual(expected[fmap_constants.MOLECULE_WEIGHT],
                            actual[fmap_constants.MOLECULE_WEIGHT])
     for key in sequence_key_names:
       self.assertSequenceAlmostEqual(expected[key], actual[key])
예제 #2
0
  def test_record_contents(self):
    """Test the contents of the stored record file to ensure features match.

    Parses the long SDF test file into molecule dictionaries, reads
    'test_14_record.gz' from the test data directory as a single EVAL batch,
    and checks that each stored feature (weights, adjacency matrix, spectrum,
    identifiers, tokenized SMILES, SMILES token length and circular
    fingerprints) matches the value computed directly from the SDF file.

    Raises:
      ValueError: if reading the record with a batch size equal to the number
        of molecules yields more than one batch.
    """
    mol_list = parse_sdf_utils.get_sdf_to_mol(self.test_file_long)

    # NOTE(review): unlike test_make_mol_dict, make_mol_dict is called with
    # its default limits here; presumably those match ms_constants.MAX_ATOMS /
    # ms_constants.MAX_PEAK_LOC used for hparams_main below -- confirm.
    mol_dicts = [parse_sdf_utils.make_mol_dict(mol) for mol in mol_list]
    raw_smiles_tokens = [
        feature_utils.tokenize_smiles(
            np.array([mol_dict[fmap_constants.SMILES]]))
        for mol_dict in mol_dicts
    ]

    token_lengths = [
        np.shape(token_arr)[0] for token_arr in raw_smiles_tokens
    ]
    # Right-pad every token array to MAX_TOKEN_LIST_LENGTH so it can be
    # compared against the SMILES feature read back from the record.
    parsed_smiles_tokens = [
        np.pad(token_arr,
               (0, ms_constants.MAX_TOKEN_LIST_LENGTH - token_length),
               'constant')
        for token_arr, token_length in zip(raw_smiles_tokens, token_lengths)
    ]

    # Batch size equals the number of molecules, so one sess.run below is
    # expected to return every molecule.
    hparams_main = tf.contrib.training.HParams(
        max_atoms=ms_constants.MAX_ATOMS,
        max_mass_spec_peak_loc=ms_constants.MAX_PEAK_LOC,
        eval_batch_size=len(mol_list),
        intensity_power=1.0)

    dataset = parse_sdf_utils.get_dataset_from_record(
        [os.path.join(self.test_data_directory, 'test_14_record.gz')],
        hparams_main,
        mode=tf.estimator.ModeKeys.EVAL)

    # Every (length, radius, type) fingerprint key, built once so the
    # requested features and the checked features cannot drift apart.
    fp_keys = [
        ms_constants.CircularFingerprintKey(fp_type, fp_len, rad)
        for fp_len in ms_constants.NUM_CIRCULAR_FP_BITS_LIST
        for rad in ms_constants.CIRCULAR_FP_RADII_LIST
        for fp_type in fmap_constants.FP_TYPE_LIST
    ]
    feature_names = [
        fmap_constants.ATOM_WEIGHTS,
        fmap_constants.MOLECULE_WEIGHT,
        fmap_constants.DENSE_MASS_SPEC,
        fmap_constants.INCHIKEY, fmap_constants.NAME,
        fmap_constants.MOLECULAR_FORMULA,
        fmap_constants.ADJACENCY_MATRIX,
        fmap_constants.ATOM_IDS, fmap_constants.SMILES
    ] + [str(fp_key) for fp_key in fp_keys]
    label_names = [fmap_constants.INCHIKEY]

    features, _ = parse_sdf_utils.make_features_and_labels(
        dataset, feature_names, label_names, mode=tf.estimator.ModeKeys.EVAL)

    with tf.Session() as sess:
      feature_values = sess.run(features)

      # Check that the dataset was consumed: a second run must be exhausted.
      try:
        sess.run(features)
      except tf.errors.OutOfRangeError:  # expected behavior
        pass
      else:
        raise ValueError('Dataset parsing using batch size of length of the'
                         ' dataset resulted in more than one batch.')

    # Fields stored encoded and compared exactly against self.encode(...).
    encoded_keys = (fmap_constants.NAME, fmap_constants.INCHIKEY,
                    fmap_constants.MOLECULAR_FORMULA)
    # Sequence fields compared element-wise with a 1e-4 tolerance.
    approx_sequence_keys = (fmap_constants.DENSE_MASS_SPEC,
                            fmap_constants.ATOM_WEIGHTS,
                            fmap_constants.ATOM_IDS)

    for i, (mol_dict, smiles_tokens, token_length) in enumerate(
        zip(mol_dicts, parsed_smiles_tokens, token_lengths)):
      self.assertAlmostEqual(
          feature_values[fmap_constants.MOLECULE_WEIGHT][i],
          mol_dict[fmap_constants.MOLECULE_WEIGHT])
      # The record holds the adjacency matrix in 2-D form; the dict holds it
      # flat, hence flatten().
      self.assertSequenceAlmostEqual(
          feature_values[fmap_constants.ADJACENCY_MATRIX][i].flatten(),
          mol_dict[fmap_constants.ADJACENCY_MATRIX],
          delta=0.0001)
      for key in encoded_keys:
        self.assertEqual(feature_values[key][i], self.encode(mol_dict[key]))
      for key in approx_sequence_keys:
        self.assertSequenceAlmostEqual(
            feature_values[key][i], mol_dict[key], delta=0.0001)
      self.assertAllEqual(feature_values[fmap_constants.SMILES][i],
                          smiles_tokens)
      # NOTE(review): SMILES_TOKEN_LIST_LENGTH is not listed in feature_names;
      # presumably make_features_and_labels adds it alongside SMILES -- confirm.
      self.assertAllEqual(
          feature_values[fmap_constants.SMILES_TOKEN_LIST_LENGTH][i],
          token_length)
      for fp_key in fp_keys:
        self.assertSequenceAlmostEqual(
            feature_values[str(fp_key)][i],
            mol_dict[fp_key],
            delta=0.0001)