def _input_fn(record_fnames,
              all_data_in_one_batch,
              load_training_spectrum_library=False):
  """Reads TFRecords from a list of record file names.

  Modified by Zhongsheng Chen ([email protected]) on Sept. 30, 2019.
  A set is used to remove duplicate feature keys before loading.
  """
  # Note: data_dir, hparams, mode, features_to_load, and filenames are
  # closure variables supplied by the enclosing scope.
  if not record_fnames:
    return None
  record_fnames = [
      os.path.join(data_dir, r_name) for r_name in record_fnames
  ]
  dataset = parse_sdf_utils.get_dataset_from_record(
      record_fnames,
      hparams,
      mode=mode,
      features_to_load=set(features_to_load + hparams.label_names),
      all_data_in_one_batch=all_data_in_one_batch)
  dict_to_return = parse_sdf_utils.make_features_and_labels(
      dataset, features_to_load, hparams.label_names, mode=mode)[0]
  if load_training_spectrum_library:
    library_file = os.path.join(
        '/readahead/128M/',
        filenames[ds_constants.TRAINING_SPECTRA_ARRAY_KEY])
    train_library = parse_sdf_utils.load_training_spectra_array(library_file)
    train_library = tf.convert_to_tensor(train_library, dtype=tf.float32)
    dict_to_return['SPECTRUM_PREDICTION_LIBRARY'] = train_library
  return dict_to_return
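# A minimal sketch (not part of the original module) of how the dict returned
# by _input_fn can be inspected. `_peek_at_batch` is a hypothetical debugging
# helper; it assumes the TF1 graph-mode setup used throughout this file.
def _peek_at_batch(features_dict):
  """Runs one batch from `features_dict` and logs each tensor's shape."""
  with tf.Session() as sess:
    batch = sess.run(features_dict)
  for key, value in batch.items():
    tf.logging.info('%s: shape=%s', key, np.shape(value))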
def test_dict_tfexample(self):
  """Checks that the contents of the tf.Records match the input molecule info.

  Writes tf.Examples as a tf.Record to disk, then reads them back.
  """
  mol_list = parse_sdf_utils.get_sdf_to_mol(self.test_file_short)
  fd, fpath = tempfile.mkstemp(dir=self.temp_dir)
  os.close(fd)
  parse_sdf_utils.write_dicts_to_example(mol_list, fpath,
                                         self.hparams.max_atoms,
                                         self.hparams.max_mass_spec_peak_loc)
  parse_sdf_utils.write_info_file(mol_list, fpath)
  self._validate_info_file(mol_list, fpath)
  dataset = parse_sdf_utils.get_dataset_from_record(
      [fpath], self.hparams, mode=tf.estimator.ModeKeys.EVAL)

  feature_names = [
      fmap_constants.ATOM_WEIGHTS, fmap_constants.MOLECULE_WEIGHT,
      fmap_constants.DENSE_MASS_SPEC, fmap_constants.INCHIKEY,
      fmap_constants.NAME, fmap_constants.MOLECULAR_FORMULA,
      fmap_constants.ADJACENCY_MATRIX, fmap_constants.ATOM_IDS,
      fmap_constants.SMILES
  ]
  label_names = [fmap_constants.INCHIKEY]

  features, _ = parse_sdf_utils.make_features_and_labels(
      dataset, feature_names, label_names, mode=tf.estimator.ModeKeys.EVAL)

  with tf.Session() as sess:
    feature_values = sess.run(features)

    # Check that the dataset was consumed in a single batch.
    try:
      sess.run(features)
      raise ValueError('Dataset parsing using batch size of length of the'
                       ' dataset resulted in more than one batch.')
    except tf.errors.OutOfRangeError:
      # Expected behavior
      pass

  for i in range(len(self.expected_mol_dicts)):
    self.assertAlmostEqual(
        feature_values[fmap_constants.MOLECULE_WEIGHT][i],
        self.expected_mol_dicts[i][fmap_constants.MOLECULE_WEIGHT])
    self.assertSequenceAlmostEqual(
        feature_values[fmap_constants.ADJACENCY_MATRIX][i].flatten(),
        self.expected_mol_dicts[i][fmap_constants.ADJACENCY_MATRIX],
        delta=0.0001)
    self.assertSequenceAlmostEqual(
        feature_values[fmap_constants.DENSE_MASS_SPEC][i],
        self.expected_mol_dicts[i][fmap_constants.DENSE_MASS_SPEC],
        delta=0.0001)
    self.assertSequenceAlmostEqual(
        feature_values[fmap_constants.ATOM_WEIGHTS][i],
        self.expected_mol_dicts[i][fmap_constants.ATOM_WEIGHTS],
        delta=0.0001)
    self.assertSequenceAlmostEqual(
        feature_values[fmap_constants.ATOM_IDS][i],
        self.expected_mol_dicts[i][fmap_constants.ATOM_IDS],
        delta=0.0001)
    self.assertEqual(
        feature_values[fmap_constants.NAME][i],
        self.encode(self.expected_mol_dicts[i][fmap_constants.NAME]))
    self.assertEqual(
        feature_values[fmap_constants.INCHIKEY][i],
        self.encode(self.expected_mol_dicts[i][fmap_constants.INCHIKEY]))
    self.assertEqual(
        feature_values[fmap_constants.MOLECULAR_FORMULA][i],
        self.encode(
            self.expected_mol_dicts[i][fmap_constants.MOLECULAR_FORMULA]))
    self.assertAllEqual(feature_values[fmap_constants.SMILES][i],
                        self.expected_mol_dicts[i]['parsed_smiles'])
    self.assertAllEqual(
        feature_values[fmap_constants.SMILES_TOKEN_LIST_LENGTH][i],
        self.expected_mol_dicts[i][fmap_constants.SMILES_TOKEN_LIST_LENGTH])
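# The try/except above is a recurring pattern in these tests: with a batch
# size equal to the dataset length, a second sess.run must raise
# tf.errors.OutOfRangeError. A hypothetical helper that factors this out
# could look like the sketch below (the tests here keep the logic inline).
def _assert_dataset_exhausted(sess, features):
  """Asserts that fetching `features` again raises OutOfRangeError."""
  try:
    sess.run(features)
  except tf.errors.OutOfRangeError:
    return  # Expected: the single batch consumed the whole dataset.
  raise ValueError('Dataset parsing using batch size of length of the'
                   ' dataset resulted in more than one batch.')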
def test_record_contents(self):
  """Tests the contents of the stored record file to ensure features match."""
  mol_list = parse_sdf_utils.get_sdf_to_mol(self.test_file_long)
  mol_dicts = [parse_sdf_utils.make_mol_dict(mol) for mol in mol_list]
  parsed_smiles_tokens = [
      feature_utils.tokenize_smiles(
          np.array([mol_dict[fmap_constants.SMILES]]))
      for mol_dict in mol_dicts
  ]
  token_lengths = [
      np.shape(token_arr)[0] for token_arr in parsed_smiles_tokens
  ]
  parsed_smiles_tokens = [
      np.pad(token_arr,
             (0, ms_constants.MAX_TOKEN_LIST_LENGTH - token_length),
             'constant')
      for token_arr, token_length in zip(parsed_smiles_tokens, token_lengths)
  ]

  hparams_main = tf.contrib.training.HParams(
      max_atoms=ms_constants.MAX_ATOMS,
      max_mass_spec_peak_loc=ms_constants.MAX_PEAK_LOC,
      eval_batch_size=len(mol_list),
      intensity_power=1.0)

  dataset = parse_sdf_utils.get_dataset_from_record(
      [os.path.join(self.test_data_directory, 'test_14_record.gz')],
      hparams_main,
      mode=tf.estimator.ModeKeys.EVAL)

  feature_names = [
      fmap_constants.ATOM_WEIGHTS, fmap_constants.MOLECULE_WEIGHT,
      fmap_constants.DENSE_MASS_SPEC, fmap_constants.INCHIKEY,
      fmap_constants.NAME, fmap_constants.MOLECULAR_FORMULA,
      fmap_constants.ADJACENCY_MATRIX, fmap_constants.ATOM_IDS,
      fmap_constants.SMILES
  ]
  for fp_len in ms_constants.NUM_CIRCULAR_FP_BITS_LIST:
    for rad in ms_constants.CIRCULAR_FP_RADII_LIST:
      for fp_type in fmap_constants.FP_TYPE_LIST:
        feature_names.append(
            str(ms_constants.CircularFingerprintKey(fp_type, fp_len, rad)))
  label_names = [fmap_constants.INCHIKEY]

  features, _ = parse_sdf_utils.make_features_and_labels(
      dataset, feature_names, label_names, mode=tf.estimator.ModeKeys.EVAL)

  with tf.Session() as sess:
    feature_values = sess.run(features)

    # Check that the dataset was consumed in a single batch.
    try:
      sess.run(features)
      raise ValueError('Dataset parsing using batch size of length of the'
                       ' dataset resulted in more than one batch.')
    except tf.errors.OutOfRangeError:
      # Expected behavior
      pass

  for i in range(len(mol_list)):
    self.assertAlmostEqual(
        feature_values[fmap_constants.MOLECULE_WEIGHT][i],
        mol_dicts[i][fmap_constants.MOLECULE_WEIGHT])
    self.assertSequenceAlmostEqual(
        feature_values[fmap_constants.ADJACENCY_MATRIX][i].flatten(),
        mol_dicts[i][fmap_constants.ADJACENCY_MATRIX],
        delta=0.0001)
    self.assertEqual(feature_values[fmap_constants.NAME][i],
                     self.encode(mol_dicts[i][fmap_constants.NAME]))
    self.assertEqual(feature_values[fmap_constants.INCHIKEY][i],
                     self.encode(mol_dicts[i][fmap_constants.INCHIKEY]))
    self.assertEqual(
        feature_values[fmap_constants.MOLECULAR_FORMULA][i],
        self.encode(mol_dicts[i][fmap_constants.MOLECULAR_FORMULA]))
    self.assertSequenceAlmostEqual(
        feature_values[fmap_constants.DENSE_MASS_SPEC][i],
        mol_dicts[i][fmap_constants.DENSE_MASS_SPEC],
        delta=0.0001)
    self.assertSequenceAlmostEqual(
        feature_values[fmap_constants.ATOM_WEIGHTS][i],
        mol_dicts[i][fmap_constants.ATOM_WEIGHTS],
        delta=0.0001)
    self.assertSequenceAlmostEqual(
        feature_values[fmap_constants.ATOM_IDS][i],
        mol_dicts[i][fmap_constants.ATOM_IDS],
        delta=0.0001)
    self.assertAllEqual(feature_values[fmap_constants.SMILES][i],
                        parsed_smiles_tokens[i])
    self.assertAllEqual(
        feature_values[fmap_constants.SMILES_TOKEN_LIST_LENGTH][i],
        token_lengths[i])
    for fp_len in ms_constants.NUM_CIRCULAR_FP_BITS_LIST:
      for rad in ms_constants.CIRCULAR_FP_RADII_LIST:
        for fp_type in fmap_constants.FP_TYPE_LIST:
          fp_key = ms_constants.CircularFingerprintKey(fp_type, fp_len, rad)
          self.assertSequenceAlmostEqual(
              feature_values[str(fp_key)][i],
              mol_dicts[i][fp_key],
              delta=0.0001)
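# A small worked example (values assumed for illustration, not part of the
# original suite) of the SMILES-token padding done in test_record_contents:
# a short token array is right-padded with zeros up to
# ms_constants.MAX_TOKEN_LIST_LENGTH so every example shares the same shape.
def _pad_tokens_example():
  """Illustrates the SMILES-token padding used above (assumed values)."""
  tokens = np.array([12, 7, 3, 3, 9])
  # Result: [12, 7, 3, 3, 9, 0, 0, ..., 0], of length MAX_TOKEN_LIST_LENGTH.
  return np.pad(tokens,
                (0, ms_constants.MAX_TOKEN_LIST_LENGTH - len(tokens)),
                'constant')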
def test_make_train_test_split(self, splitting_type):
  """An integration test on a small dataset."""
  fpath = self.temp_dir

  # Create component datasets from the two library files.
  main_train_val_test_fractions = (
      train_test_split_utils.TrainValTestFractions(0.5, 0.25, 0.25))
  replicates_val_test_fractions = (
      train_test_split_utils.TrainValTestFractions(0.0, 0.5, 0.5))

  (mainlib_inchikey_dict, replicates_inchikey_dict,
   component_inchikey_dict) = (
       make_train_test_split.make_mainlib_replicates_train_test_split(
           self.mol_list_large, self.mol_list_small, splitting_type,
           main_train_val_test_fractions, replicates_val_test_fractions))

  make_train_test_split.write_mainlib_split_datasets(
      component_inchikey_dict, mainlib_inchikey_dict, fpath,
      ms_constants.MAX_ATOMS, ms_constants.MAX_PEAK_LOC)
  make_train_test_split.write_replicates_split_datasets(
      component_inchikey_dict, replicates_inchikey_dict, fpath,
      ms_constants.MAX_ATOMS, ms_constants.MAX_PEAK_LOC)

  for experiment_setup in ds_constants.EXPERIMENT_SETUPS_LIST:
    # Create the experiment json files.
    tf.logging.info('Writing experiment setup for %s',
                    experiment_setup.json_name)
    make_train_test_split.check_experiment_setup(
        experiment_setup.experiment_setup_dataset_dict,
        component_inchikey_dict)
    make_train_test_split.write_json_for_experiment(experiment_setup, fpath)

    # Check that the physical files for library matching contain all
    # inchikeys.
    dict_from_json = json.load(
        tf.gfile.Open(os.path.join(fpath, experiment_setup.json_name)))
    tf.logging.info(dict_from_json)
    library_files = (
        dict_from_json[ds_constants.LIBRARY_MATCHING_OBSERVED_KEY] +
        dict_from_json[ds_constants.LIBRARY_MATCHING_PREDICTED_KEY])
    library_files = [os.path.join(fpath, fname) for fname in library_files]

    hparams = tf.contrib.training.HParams(
        max_atoms=ms_constants.MAX_ATOMS,
        max_mass_spec_peak_loc=ms_constants.MAX_PEAK_LOC,
        intensity_power=1.0,
        batch_size=5)

    parse_sdf_utils.validate_spectra_array_contents(
        os.path.join(
            fpath,
            dict_from_json[ds_constants.SPECTRUM_PREDICTION_TRAIN_KEY][0]),
        hparams,
        os.path.join(fpath,
                     dict_from_json[ds_constants.TRAINING_SPECTRA_ARRAY_KEY]))

    dataset = parse_sdf_utils.get_dataset_from_record(
        library_files,
        hparams,
        mode=tf.estimator.ModeKeys.EVAL,
        all_data_in_one_batch=True)

    feature_names = [fmap_constants.INCHIKEY]
    label_names = [fmap_constants.ATOM_WEIGHTS]

    features, labels = parse_sdf_utils.make_features_and_labels(
        dataset, feature_names, label_names, mode=tf.estimator.ModeKeys.EVAL)

    with tf.Session() as sess:
      feature_values, _ = sess.run([features, labels])

    inchikeys_from_file = [
        ikey[0] for ikey in feature_values[fmap_constants.INCHIKEY].tolist()
    ]
    length_from_info_file = sum([
        parse_sdf_utils.parse_info_file(library_fname)['num_examples']
        for library_fname in library_files
    ])

    # Check that the info files report the correct total number of examples.
    self.assertLen(inchikeys_from_file, length_from_info_file)

    # Check that the TFRecord contains all of the inchikeys in our list.
    inchikey_list_large = [
        self.encode(ikey) for ikey in self.inchikey_list_large
    ]
    self.assertSetEqual(set(inchikeys_from_file), set(inchikey_list_large))
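# A hypothetical sanity-check helper (sketch only, not in the original
# suite): the TrainValTestFractions triples used above, (0.5, 0.25, 0.25)
# and (0.0, 0.5, 0.5), each sum to 1. TrainValTestFractions is assumed to
# unpack like a tuple here.
def _assert_fractions_sum_to_one(fractions):
  """Asserts that a (train, val, test) fraction triple sums to 1."""
  total = sum(fractions)
  if abs(total - 1.0) > 1e-6:
    raise ValueError('Fractions must sum to 1, got %f' % total)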