Exemplo n.º 1
0
    def test_make_train_val_test_split_mol_lists(self):
        """Checks split sizes and disjointness for a full and a truncated list."""
        main_train_test_split = train_test_split_utils.TrainValTestFractions(
            0.5, 0.25, 0.25)

        inchikey_list_of_lists = (
            train_test_split_utils.make_train_val_test_split_inchikey_lists(
                self.inchikey_list_large, self.inchikey_dict_large,
                main_train_test_split))

        expected_lengths_of_inchikey_lists = [5, 2, 4]
        # zip() stops at the shorter input, so pin the number of splits first;
        # otherwise a missing or extra split would go unnoticed.
        self.assertLen(inchikey_list_of_lists,
                       len(expected_lengths_of_inchikey_lists))
        for expected_length, inchikey_list in zip(
                expected_lengths_of_inchikey_lists, inchikey_list_of_lists):
            self.assertLen(inchikey_list, expected_length)

        train_test_split_utils.assert_all_lists_mutally_exclusive(
            inchikey_list_of_lists)

        trunc_inchikey_list_large = self.inchikey_list_large[:6]
        # Use the split result directly (parentheses only). Wrapping it in
        # [...] would produce a one-element list, and the length and
        # exclusivity checks below would then pass vacuously.
        inchikey_list_of_lists = (
            train_test_split_utils.make_train_val_test_split_inchikey_lists(
                trunc_inchikey_list_large, self.inchikey_dict_large,
                main_train_test_split))

        expected_lengths_of_inchikey_lists = [3, 1, 2]
        self.assertLen(inchikey_list_of_lists,
                       len(expected_lengths_of_inchikey_lists))
        for expected_length, inchikey_list in zip(
                expected_lengths_of_inchikey_lists, inchikey_list_of_lists):
            self.assertLen(inchikey_list, expected_length)

        train_test_split_utils.assert_all_lists_mutally_exclusive(
            inchikey_list_of_lists)
Exemplo n.º 2
0
 def test_all_lists_mutually_exclusive(self):
     """Overlapping lists must raise ValueError; disjoint lists must not."""
     list1 = ['1', '2', '3']
     list2 = ['2', '3', '4']
     # assertRaises fails the test when no ValueError is raised. A hand-written
     # "raise ValueError" inside try/except ValueError would be swallowed by
     # its own except clause and could never fail.
     with self.assertRaises(ValueError):
         train_test_split_utils.assert_all_lists_mutally_exclusive(
             [list1, list2])

     # Disjoint lists are the passing case and must not raise.
     train_test_split_utils.assert_all_lists_mutally_exclusive(
         [['1'], ['2'], ['3']])
Exemplo n.º 3
0
    def test_make_train_val_test_split_mol_lists_holdout(self):
        """Checks split sizes and disjointness when a holdout list is passed."""
        main_train_test_split = train_test_split_utils.TrainValTestFractions(
            0.5, 0.25, 0.25)
        holdout_inchikey_list_of_lists = (
            train_test_split_utils.make_train_val_test_split_inchikey_lists(
                self.inchikey_list_large,
                self.inchikey_dict_large,
                main_train_test_split,
                holdout_inchikey_list=self.inchikey_list_small))

        expected_lengths_of_inchikey_lists = [4, 2, 3]
        # zip() stops at the shorter input, so pin the number of splits first;
        # otherwise a missing or extra split would go unnoticed.
        self.assertLen(holdout_inchikey_list_of_lists,
                       len(expected_lengths_of_inchikey_lists))
        for expected_length, inchikey_list in zip(
                expected_lengths_of_inchikey_lists,
                holdout_inchikey_list_of_lists):
            self.assertLen(inchikey_list, expected_length)

        train_test_split_utils.assert_all_lists_mutally_exclusive(
            holdout_inchikey_list_of_lists)
Exemplo n.º 4
0
def make_mainlib_replicates_train_test_split(
        mainlib_mol_list,
        replicates_mol_list,
        splitting_type,
        mainlib_fractions,
        replicates_fractions,
        mainlib_maximum_num_molecules_to_use=None,
        replicates_maximum_num_molecules_to_use=None,
        rseed=42):
    """Makes train/validation/test inchikey lists from two lists of rdkit.Mol.

    The mainlib inchikeys are split with the replicates inchikeys passed as a
    holdout list. The replicates inchikeys are then split on their own. The six
    resulting component lists are checked for mutual exclusivity. They are also
    checked to jointly contain exactly the selected mainlib and replicates
    inchikeys.

    Args:
      mainlib_mol_list : list of molecules from main library
      replicates_mol_list : list of molecules from replicates library
      splitting_type : type of splitting to use for validation splits.
      mainlib_fractions : TrainValTestFractions namedtuple
          holding desired fractions for train/val/test split of mainlib
      replicates_fractions : TrainValTestFractions namedtuple
          holding desired fractions for train/val/test split of replicates.
          For the replicates set, the train fraction should be set to 0.
      mainlib_maximum_num_molecules_to_use : Largest number of inchikeys to
         randomly keep from mainlib. None keeps all of them, and so does a
         value larger than the number available.
      replicates_maximum_num_molecules_to_use : Largest number of inchikeys to
         randomly keep from replicates. None keeps all of them, and so does a
         value larger than the number available.
      rseed : random seed for shuffling; used to seed the module-level
         `random` generator.

    Returns:
      main_inchikey_dict : Dict that is keyed by inchikey, containing a list of
          rdkit.Mol objects corresponding to that inchikey from the mainlib
      replicates_inchikey_dict : Dict that is keyed by inchikey, containing a list
          of rdkit.Mol objects corresponding to that inchikey from the replicates
          library
      main_replicates_split_inchikey_lists_dict : dict with keys :
        'mainlib_train', 'mainlib_validation', 'mainlib_test',
        'replicates_train', 'replicates_validation', 'replicates_test'
        Values are lists of inchikeys corresponding to each dataset.

    Raises:
      AssertionError: if the component lists together do not contain exactly
        the selected mainlib and replicates inchikeys.
      Whatever `assert_all_lists_mutally_exclusive` raises when two component
        lists share an inchikey.
    """
    # Seed the module-level generator rather than a private random.Random.
    # The sampling below draws from it, and presumably the splitting helper
    # does too (TODO confirm before switching to a local generator).
    random.seed(rseed)
    main_inchikey_dict = train_test_split_utils.make_inchikey_dict(
        mainlib_mol_list)
    # Materialize the keys: random.sample requires a sequence, and dict views
    # are rejected since Python 3.11.
    main_inchikey_list = list(main_inchikey_dict.keys())

    if mainlib_maximum_num_molecules_to_use is not None:
        # Treat the limit as a cap; random.sample raises if k > population.
        main_inchikey_list = random.sample(
            main_inchikey_list,
            min(mainlib_maximum_num_molecules_to_use, len(main_inchikey_list)))

    replicates_inchikey_dict = train_test_split_utils.make_inchikey_dict(
        replicates_mol_list)
    replicates_inchikey_list = list(replicates_inchikey_dict.keys())

    if replicates_maximum_num_molecules_to_use is not None:
        replicates_inchikey_list = random.sample(
            replicates_inchikey_list,
            min(replicates_maximum_num_molecules_to_use,
                len(replicates_inchikey_list)))

    # Make train/val/test splits for main dataset. Replicates inchikeys are
    # passed as the holdout list.
    main_train_validation_test_inchikeys = (
        train_test_split_utils.make_train_val_test_split_inchikey_lists(
            main_inchikey_list,
            main_inchikey_dict,
            mainlib_fractions,
            holdout_inchikey_list=replicates_inchikey_list,
            splitting_type=splitting_type))

    # Make train/val/test splits for replicates dataset.
    replicates_validation_test_inchikeys = (
        train_test_split_utils.make_train_val_test_split_inchikey_lists(
            replicates_inchikey_list,
            replicates_inchikey_dict,
            replicates_fractions,
            splitting_type=splitting_type))

    component_inchikey_dict = {
        ds_constants.MAINLIB_TRAIN_BASENAME:
        main_train_validation_test_inchikeys.train,
        ds_constants.MAINLIB_VALIDATION_BASENAME:
        main_train_validation_test_inchikeys.validation,
        ds_constants.MAINLIB_TEST_BASENAME:
        main_train_validation_test_inchikeys.test,
        ds_constants.REPLICATES_TRAIN_BASENAME:
        replicates_validation_test_inchikeys.train,
        ds_constants.REPLICATES_VALIDATION_BASENAME:
        replicates_validation_test_inchikeys.validation,
        ds_constants.REPLICATES_TEST_BASENAME:
        replicates_validation_test_inchikeys.test
    }

    # Pass a list rather than a dict view so the helper gets an indexable
    # sequence.
    train_test_split_utils.assert_all_lists_mutally_exclusive(
        list(component_inchikey_dict.values()))

    # The union of the 6 component lists must equal the union of the selected
    # mainlib and replicates inchikeys. The check raises explicitly (same
    # AssertionError type as a bare assert) so it survives `python -O`.
    all_inchikeys_in_components = {
        ikey for ikey_list in component_inchikey_dict.values()
        for ikey in ikey_list
    }
    selected_inchikeys = set(main_inchikey_list) | set(replicates_inchikey_list)
    if selected_inchikeys != all_inchikeys_in_components:
        raise AssertionError(
            'The inchikeys in the original inchikey dictionary are not all included'
            ' in the train/val/test component libraries')

    return (main_inchikey_dict, replicates_inchikey_dict,
            component_inchikey_dict)