Example #1
    def _input_fn(record_fnames,
                  all_data_in_one_batch,
                  load_training_spectrum_library=False):
        """Reads TFRecord from a list of record file names. Modified by Zhongsheng Chen ([email protected])
        on Sept. 30, 2019. A set function was used to remove repeated keys of fields in features.
        """
        if not record_fnames:
            return None

        record_fnames = [
            os.path.join(data_dir, r_name) for r_name in record_fnames
        ]
        dataset = parse_sdf_utils.get_dataset_from_record(
            record_fnames,
            hparams,
            mode=mode,
            features_to_load=set(features_to_load + hparams.label_names),
            all_data_in_one_batch=all_data_in_one_batch)
        dict_to_return = parse_sdf_utils.make_features_and_labels(
            dataset, features_to_load, hparams.label_names, mode=mode)[0]

        if load_training_spectrum_library:
            # Note: '/readahead/128M/' appears to be a caching/readahead
            # filesystem prefix; the library location itself comes from the
            # filenames dict in the enclosing scope.
            library_file = os.path.join(
                '/readahead/128M/',
                filenames[ds_constants.TRAINING_SPECTRA_ARRAY_KEY])
            train_library = parse_sdf_utils.load_training_spectra_array(
                library_file)
            train_library = tf.convert_to_tensor(train_library,
                                                 dtype=tf.float32)

            dict_to_return['SPECTRUM_PREDICTION_LIBRARY'] = train_library

        return dict_to_return
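
The free variables data_dir, hparams, mode, features_to_load and filenames are bound in the enclosing scope, so _input_fn is used as a closure. A minimal usage sketch, assuming those variables are already bound and that 'train.tfrecord' is a hypothetical record name under data_dir:

    features = _input_fn(['train.tfrecord'], all_data_in_one_batch=True)
    with tf.Session() as sess:
        feature_values = sess.run(features)  # dict of NumPy arrays, one batch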
Example #2
  def test_dict_tfexample(self):
    """Check if the contents of tf.Records is the same as input molecule info.

       Writes tf.example as tf.record to disk, then reads from disk.
    """
    mol_list = parse_sdf_utils.get_sdf_to_mol(self.test_file_short)

    fd, fpath = tempfile.mkstemp(dir=self.temp_dir)
    os.close(fd)

    parse_sdf_utils.write_dicts_to_example(mol_list, fpath,
                                           self.hparams.max_atoms,
                                           self.hparams.max_mass_spec_peak_loc)
    parse_sdf_utils.write_info_file(mol_list, fpath)
    self._validate_info_file(mol_list, fpath)

    dataset = parse_sdf_utils.get_dataset_from_record(
        [fpath], self.hparams, mode=tf.estimator.ModeKeys.EVAL)

    feature_names = [
        fmap_constants.ATOM_WEIGHTS,
        fmap_constants.MOLECULE_WEIGHT,
        fmap_constants.DENSE_MASS_SPEC,
        fmap_constants.INCHIKEY, fmap_constants.NAME,
        fmap_constants.MOLECULAR_FORMULA,
        fmap_constants.ADJACENCY_MATRIX,
        fmap_constants.ATOM_IDS, fmap_constants.SMILES
    ]
    label_names = [fmap_constants.INCHIKEY]

    features, _ = parse_sdf_utils.make_features_and_labels(
        dataset, feature_names, label_names, mode=tf.estimator.ModeKeys.EVAL)

    with tf.Session() as sess:
      feature_values = sess.run(features)

      # Check that the dataset was consumed
      try:
        sess.run(features)
        raise ValueError('Dataset parsing with a batch size equal to the'
                         ' dataset length resulted in more than one batch.')
      except tf.errors.OutOfRangeError:  # expected behavior
        pass

    for i in range(len(self.expected_mol_dicts)):
      self.assertAlmostEqual(
          feature_values[fmap_constants.MOLECULE_WEIGHT][i],
          self.expected_mol_dicts[i][fmap_constants.MOLECULE_WEIGHT])
      self.assertSequenceAlmostEqual(
          feature_values[fmap_constants.ADJACENCY_MATRIX][i].flatten(),
          self.expected_mol_dicts[i][fmap_constants.ADJACENCY_MATRIX],
          delta=0.0001)
      self.assertSequenceAlmostEqual(
          feature_values[fmap_constants.DENSE_MASS_SPEC][i],
          self.expected_mol_dicts[i][fmap_constants.DENSE_MASS_SPEC],
          delta=0.0001)
      self.assertSequenceAlmostEqual(
          feature_values[fmap_constants.ATOM_WEIGHTS][i],
          self.expected_mol_dicts[i][fmap_constants.ATOM_WEIGHTS],
          delta=0.0001)
      self.assertSequenceAlmostEqual(
          feature_values[fmap_constants.ATOM_IDS][i],
          self.expected_mol_dicts[i][fmap_constants.ATOM_IDS],
          delta=0.0001)
      self.assertEqual(
          feature_values[fmap_constants.NAME][i],
          self.encode(self.expected_mol_dicts[i][fmap_constants.NAME]))
      self.assertEqual(
          feature_values[fmap_constants.INCHIKEY][i],
          self.encode(
              self.expected_mol_dicts[i][fmap_constants.INCHIKEY]))
      self.assertEqual(
          feature_values[fmap_constants.MOLECULAR_FORMULA][i],
          self.encode(
              self.expected_mol_dicts[i][fmap_constants.MOLECULAR_FORMULA]))
      self.assertAllEqual(feature_values[fmap_constants.SMILES][i],
                          self.expected_mol_dicts[i]['parsed_smiles'])
      self.assertAllEqual(
          feature_values[fmap_constants.SMILES_TOKEN_LIST_LENGTH][i],
          self.expected_mol_dicts[i][fmap_constants.SMILES_TOKEN_LIST_LENGTH])
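
The try/except exhaustion check above (repeated in Example #3) can also be written with the standard assertRaises idiom; a behavior-equivalent sketch for the same spot inside the tf.Session block:

      # After the first sess.run(features), the one-batch dataset must be
      # exhausted.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(features)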
Example #3
  def test_record_contents(self):
    """Test the contents of the stored record file to ensure features match."""
    mol_list = parse_sdf_utils.get_sdf_to_mol(self.test_file_long)

    mol_dicts = [parse_sdf_utils.make_mol_dict(mol) for mol in mol_list]
    parsed_smiles_tokens = [
        feature_utils.tokenize_smiles(
            np.array([mol_dict[fmap_constants.SMILES]]))
        for mol_dict in mol_dicts
    ]

    token_lengths = [
        np.shape(token_arr)[0] for token_arr in parsed_smiles_tokens
    ]
    parsed_smiles_tokens = [
        np.pad(token_arr,
               (0, ms_constants.MAX_TOKEN_LIST_LENGTH - token_length),
               'constant')
        for token_arr, token_length in zip(parsed_smiles_tokens, token_lengths)
    ]

    # eval_batch_size is set to the dataset length so that a single
    # sess.run() consumes the entire dataset in one batch.
    hparams_main = tf.contrib.training.HParams(
        max_atoms=ms_constants.MAX_ATOMS,
        max_mass_spec_peak_loc=ms_constants.MAX_PEAK_LOC,
        eval_batch_size=len(mol_list),
        intensity_power=1.0)

    dataset = parse_sdf_utils.get_dataset_from_record(
        [os.path.join(self.test_data_directory, 'test_14_record.gz')],
        hparams_main,
        mode=tf.estimator.ModeKeys.EVAL)

    feature_names = [
        fmap_constants.ATOM_WEIGHTS,
        fmap_constants.MOLECULE_WEIGHT,
        fmap_constants.DENSE_MASS_SPEC,
        fmap_constants.INCHIKEY, fmap_constants.NAME,
        fmap_constants.MOLECULAR_FORMULA,
        fmap_constants.ADJACENCY_MATRIX,
        fmap_constants.ATOM_IDS, fmap_constants.SMILES
    ]
    for fp_len in ms_constants.NUM_CIRCULAR_FP_BITS_LIST:
      for rad in ms_constants.CIRCULAR_FP_RADII_LIST:
        for fp_type in fmap_constants.FP_TYPE_LIST:
          feature_names.append(
              str(ms_constants.CircularFingerprintKey(fp_type, fp_len, rad)))
    label_names = [fmap_constants.INCHIKEY]

    features, _ = parse_sdf_utils.make_features_and_labels(
        dataset, feature_names, label_names, mode=tf.estimator.ModeKeys.EVAL)

    with tf.Session() as sess:
      feature_values = sess.run(features)

      # Check that the dataset was consumed
      try:
        sess.run(features)
        raise ValueError('Dataset parsing with a batch size equal to the'
                         ' dataset length resulted in more than one batch.')
      except tf.errors.OutOfRangeError:  # expected behavior
        pass

    for i in range(len(mol_list)):
      self.assertAlmostEqual(
          feature_values[fmap_constants.MOLECULE_WEIGHT][i],
          mol_dicts[i][fmap_constants.MOLECULE_WEIGHT])
      self.assertSequenceAlmostEqual(
          feature_values[fmap_constants.ADJACENCY_MATRIX][i].flatten(),
          mol_dicts[i][fmap_constants.ADJACENCY_MATRIX],
          delta=0.0001)
      self.assertEqual(feature_values[fmap_constants.NAME][i],
                       self.encode(mol_dicts[i][fmap_constants.NAME]))
      self.assertEqual(feature_values[fmap_constants.INCHIKEY][i],
                       self.encode(mol_dicts[i][fmap_constants.INCHIKEY]))
      self.assertEqual(
          feature_values[fmap_constants.MOLECULAR_FORMULA][i],
          self.encode(mol_dicts[i][fmap_constants.MOLECULAR_FORMULA]))
      self.assertSequenceAlmostEqual(
          feature_values[fmap_constants.DENSE_MASS_SPEC][i],
          mol_dicts[i][fmap_constants.DENSE_MASS_SPEC],
          delta=0.0001)
      self.assertSequenceAlmostEqual(
          feature_values[fmap_constants.ATOM_WEIGHTS][i],
          mol_dicts[i][fmap_constants.ATOM_WEIGHTS],
          delta=0.0001)
      self.assertSequenceAlmostEqual(
          feature_values[fmap_constants.ATOM_IDS][i],
          mol_dicts[i][fmap_constants.ATOM_IDS],
          delta=0.0001)
      self.assertAllEqual(feature_values[fmap_constants.SMILES][i],
                          parsed_smiles_tokens[i])
      self.assertAllEqual(
          feature_values[fmap_constants.SMILES_TOKEN_LIST_LENGTH][i],
          token_lengths[i])
      for fp_len in ms_constants.NUM_CIRCULAR_FP_BITS_LIST:
        for rad in ms_constants.CIRCULAR_FP_RADII_LIST:
          for fp_type in fmap_constants.FP_TYPE_LIST:
            fp_key = ms_constants.CircularFingerprintKey(fp_type, fp_len, rad)
            self.assertSequenceAlmostEqual(
                feature_values[str(fp_key)][i],
                mol_dicts[i][fp_key],
                delta=0.0001)
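
The triple loop that builds the circular-fingerprint feature names (and its twin in the assertion section) can be collapsed with itertools.product; a behavior-equivalent sketch that preserves the original iteration order:

    import itertools

    fp_keys = [
        ms_constants.CircularFingerprintKey(fp_type, fp_len, rad)
        for fp_len, rad, fp_type in itertools.product(
            ms_constants.NUM_CIRCULAR_FP_BITS_LIST,
            ms_constants.CIRCULAR_FP_RADII_LIST,
            fmap_constants.FP_TYPE_LIST)
    ]
    feature_names.extend(str(fp_key) for fp_key in fp_keys)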
Example #4
    def test_make_train_test_split(self, splitting_type):
        """An integration test on a small dataset."""

        fpath = self.temp_dir

        # Create component datasets from two library files.
        main_train_val_test_fractions = (
            train_test_split_utils.TrainValTestFractions(0.5, 0.25, 0.25))
        replicates_val_test_fractions = (
            train_test_split_utils.TrainValTestFractions(0.0, 0.5, 0.5))

        (mainlib_inchikey_dict, replicates_inchikey_dict,
         component_inchikey_dict) = (
             make_train_test_split.make_mainlib_replicates_train_test_split(
                 self.mol_list_large, self.mol_list_small, splitting_type,
                 main_train_val_test_fractions, replicates_val_test_fractions))

        make_train_test_split.write_mainlib_split_datasets(
            component_inchikey_dict, mainlib_inchikey_dict, fpath,
            ms_constants.MAX_ATOMS, ms_constants.MAX_PEAK_LOC)

        make_train_test_split.write_replicates_split_datasets(
            component_inchikey_dict, replicates_inchikey_dict, fpath,
            ms_constants.MAX_ATOMS, ms_constants.MAX_PEAK_LOC)

        for experiment_setup in ds_constants.EXPERIMENT_SETUPS_LIST:
            # Create experiment json files
            tf.logging.info('Writing experiment setup for %s',
                            experiment_setup.json_name)
            make_train_test_split.check_experiment_setup(
                experiment_setup.experiment_setup_dataset_dict,
                component_inchikey_dict)

            make_train_test_split.write_json_for_experiment(
                experiment_setup, fpath)

            # Check that physical files for library matching contain all inchikeys
            dict_from_json = json.load(
                tf.gfile.Open(os.path.join(fpath, experiment_setup.json_name)))

            tf.logging.info(dict_from_json)
            library_files = (
                dict_from_json[ds_constants.LIBRARY_MATCHING_OBSERVED_KEY] +
                dict_from_json[ds_constants.LIBRARY_MATCHING_PREDICTED_KEY])
            library_files = [
                os.path.join(fpath, fname) for fname in library_files
            ]

            hparams = tf.contrib.training.HParams(
                max_atoms=ms_constants.MAX_ATOMS,
                max_mass_spec_peak_loc=ms_constants.MAX_PEAK_LOC,
                intensity_power=1.0,
                batch_size=5)

            parse_sdf_utils.validate_spectra_array_contents(
                os.path.join(
                    fpath, dict_from_json[
                        ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY][0]),
                hparams,
                os.path.join(
                    fpath,
                    dict_from_json[ds_constants.TRAINING_SPECTRA_ARRAY_KEY]))

            dataset = parse_sdf_utils.get_dataset_from_record(
                library_files,
                hparams,
                mode=tf.estimator.ModeKeys.EVAL,
                all_data_in_one_batch=True)

            feature_names = [fmap_constants.INCHIKEY]
            label_names = [fmap_constants.ATOM_WEIGHTS]

            features, labels = parse_sdf_utils.make_features_and_labels(
                dataset,
                feature_names,
                label_names,
                mode=tf.estimator.ModeKeys.EVAL)

            with tf.Session() as sess:
                feature_values, _ = sess.run([features, labels])

            inchikeys_from_file = [
                ikey[0]
                for ikey in feature_values[fmap_constants.INCHIKEY].tolist()
            ]

            length_from_info_file = sum([
                parse_sdf_utils.parse_info_file(library_fname)['num_examples']
                for library_fname in library_files
            ])
            # Check that the info files report the correct number of examples.
            self.assertLen(inchikeys_from_file, length_from_info_file)
            # Check that the TFRecords contain all of the inchikeys in our list.
            inchikey_list_large = [
                self.encode(ikey) for ikey in self.inchikey_list_large
            ]
            self.assertSetEqual(set(inchikeys_from_file),
                                set(inchikey_list_large))
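
The two TrainValTestFractions values above are constructed positionally, which suggests a tuple-like type. A small sanity-check sketch (fractions_sum_to_one is a hypothetical helper, not part of train_test_split_utils), assuming the fractions are meant to partition each library:

    def fractions_sum_to_one(fractions, tol=1e-9):
        """Hypothetical helper: check that train/val/test fractions sum to 1."""
        return abs(sum(fractions) - 1.0) < tol

    assert fractions_sum_to_one((0.5, 0.25, 0.25))  # main library split
    assert fractions_sum_to_one((0.0, 0.5, 0.5))    # replicates split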