Ejemplo n.º 1
0
 def test_split(self):
     """ScaffoldSplitter must assign every sample to exactly one split."""
     # Three scaffold groups, expressed as (smiles, repeat_count) pairs so
     # the duplicated entries are generated rather than written out by hand.
     smiles_counts = [
         ('CCOc1ccc2nc(S(N)(=O)=O)sc2c1', 4),
         ('CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1', 5),
         ('CCCCCCCCCCOCC(O)CN', 4),
     ]
     raw_data_list = [{'smiles': smiles}
                      for smiles, count in smiles_counts
                      for _ in range(count)]
     dataset = InMemoryDataset(raw_data_list)
     splitter = ScaffoldSplitter()
     train_dataset, valid_dataset, test_dataset = splitter.split(
         dataset, frac_train=0.34, frac_valid=0.33, frac_test=0.33)
     # The three sub-datasets together must cover the whole input.
     total = len(train_dataset) + len(valid_dataset) + len(test_dataset)
     self.assertEqual(total, len(dataset))
Ejemplo n.º 2
0
def create_splitter(split_type):
    """Create a dataset splitter by name.

    Args:
        split_type(str): the desired splitter, one of 'random', 'index',
            'scaffold' or 'random_scaffold'.

    Returns:
        A splitter instance of the class matching ``split_type``.

    Raises:
        ValueError: if ``split_type`` is not one of the supported names.
    """
    if split_type == 'random':
        splitter = RandomSplitter()
    elif split_type == 'index':
        splitter = IndexSplitter()
    elif split_type == 'scaffold':
        splitter = ScaffoldSplitter()
    elif split_type == 'random_scaffold':
        splitter = RandomScaffoldSplitter()
    else:
        raise ValueError('%s not supported' % split_type)
    return splitter
def load_chembl_filtered_dataset(data_path):
    """Load chembl_filtered dataset, process the classification labels and the input information.

    Introduction:

        Note that, in order to load this dataset, you should have other datasets (bace, bbbp, clintox,
        esol, freesolv, hiv, lipophilicity, muv, sider, tox21, toxcast) downloaded. Since the chembl
        dataset may overlap with the above listed datasets, the overlapped smiles for test will be
        filtered for a fair evaluation.

    Description:

        The data file contains a csv table, in which columns below are used:

            It contains the ID, SMILES/CTAB, InChI and InChIKey compound information

            smiles: SMILES representation of the molecular structure

    Args:
        data_path(str): the path to the cached npz path

    Returns:
        an InMemoryDataset instance.

    Example:
        .. code-block:: python

            dataset = load_chembl_filtered_dataset('./chembl_filtered')
            print(len(dataset))

    References:

    [1] Gaulton, A; et al. (2011). “ChEMBL: a large-scale bioactivity database for drug discovery”. Nucleic Acids Research. 40 (Database issue): D1100-7.

    """
    # The downstream datasets are expected to live as siblings of data_path.
    downstream_datasets = [
        load_bace_dataset(join(dirname(data_path), 'bace')),
        load_bbbp_dataset(join(dirname(data_path), 'bbbp')),
        load_clintox_dataset(join(dirname(data_path), 'clintox')),
        load_esol_dataset(join(dirname(data_path), 'esol')),
        load_freesolv_dataset(join(dirname(data_path), 'freesolv')),
        load_hiv_dataset(join(dirname(data_path), 'hiv')),
        load_lipophilicity_dataset(join(dirname(data_path), 'lipophilicity')),
        load_muv_dataset(join(dirname(data_path), 'muv')),
        load_sider_dataset(join(dirname(data_path), 'sider')),
        load_tox21_dataset(join(dirname(data_path), 'tox21')),
        load_toxcast_dataset(join(dirname(data_path), 'toxcast')),
    ]
    downstream_inchi_set = set()
    splitter = ScaffoldSplitter()
    for c_dataset in downstream_datasets:
        # Re-create each downstream scaffold split so its validation and test
        # molecules can be excluded from the pretraining data.
        train_dataset, valid_dataset, test_dataset = splitter.split(
            c_dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
        ### remove both test and validation molecules
        remove_smiles = [d['smiles'] for d in valid_dataset] + \
                        [d['smiles'] for d in test_dataset]

        # Record the inchi of every species ('.'-separated component), not
        # just the largest one (create_standardized_mol_id keeps only the
        # largest species by default when the input has multiple species).
        downstream_inchi_set.update(
            create_standardized_mol_id(s)
            for smiles in remove_smiles
            for s in smiles.split('.'))

    smiles_list, rdkit_mol_objs, folds, labels = \
            _load_chembl_filtered_dataset(data_path)
    data_list = []
    for smiles, rdkit_mol, label in zip(smiles_list, rdkit_mol_objs, labels):
        if rdkit_mol is None:
            continue
        # Keep only molecules inside a drug-like molecular-weight window.
        mw = Descriptors.MolWt(rdkit_mol)
        if not 50 <= mw <= 900:
            continue
        inchi = create_standardized_mol_id(smiles)
        # Drop molecules overlapping any downstream valid/test split.
        if inchi is None or inchi in downstream_inchi_set:
            continue
        data_list.append({
            'smiles': smiles,
            'label': label.reshape([-1]),
        })

    dataset = InMemoryDataset(data_list)
    return dataset
Ejemplo n.º 4
0
def load_chembl_filtered_dataset(data_path, featurizer=None):
    """Load the chembl_filtered dataset and build an InMemoryDataset.

    The downstream datasets (bace, bbbp, clintox, esol, freesolv, hiv,
    lipophilicity, muv, sider, tox21, toxcast) must be available under the
    grandparent directory of ``data_path``; molecules that appear in their
    scaffold-split validation or test sets are filtered out so pretraining
    does not leak downstream evaluation data.

    Args:
        data_path(str): the path to the raw chembl_filtered data.
        featurizer: optional featurizer; when given, each raw sample is
            transformed by ``featurizer.gen_features`` and samples for which
            it returns None are dropped.

    Returns:
        an InMemoryDataset instance.
    """
    downstream_datasets = [
        load_bace_dataset(join(dirname(dirname(data_path)), 'bace/raw')),
        load_bbbp_dataset(join(dirname(dirname(data_path)), 'bbbp/raw')),
        load_clintox_dataset(join(dirname(dirname(data_path)), 'clintox/raw')),
        load_esol_dataset(join(dirname(dirname(data_path)), 'esol/raw')),
        load_freesolv_dataset(join(dirname(dirname(data_path)),
                                   'freesolv/raw')),
        load_hiv_dataset(join(dirname(dirname(data_path)), 'hiv/raw')),
        load_lipophilicity_dataset(
            join(dirname(dirname(data_path)), 'lipophilicity/raw')),
        load_muv_dataset(join(dirname(dirname(data_path)), 'muv/raw')),
        load_sider_dataset(join(dirname(dirname(data_path)), 'sider/raw')),
        load_tox21_dataset(join(dirname(dirname(data_path)), 'tox21/raw')),
        load_toxcast_dataset(join(dirname(dirname(data_path)), 'toxcast/raw')),
    ]
    downstream_inchi_set = set()
    splitter = ScaffoldSplitter()
    for c_dataset in downstream_datasets:
        # Re-create each downstream scaffold split so its validation and test
        # molecules can be excluded from the pretraining data.
        train_dataset, valid_dataset, test_dataset = splitter.split(
            c_dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)

        ### remove both test and validation molecules
        remove_smiles = [d['smiles'] for d in valid_dataset] + \
                        [d['smiles'] for d in test_dataset]

        # Record the inchi of every species ('.'-separated component), not
        # just the largest one (create_standardized_mol_id keeps only the
        # largest species by default when the input has multiple species).
        downstream_inchi_set.update(
            create_standardized_mol_id(s)
            for smiles in remove_smiles
            for s in smiles.split('.'))

    smiles_list, rdkit_mol_objs, folds, labels = \
            _load_chembl_filtered_dataset(data_path)
    data_list = []
    for smiles, rdkit_mol, label in zip(smiles_list, rdkit_mol_objs, labels):
        if rdkit_mol is None:
            continue
        # Keep only molecules inside a drug-like molecular-weight window.
        mw = Descriptors.MolWt(rdkit_mol)
        if not 50 <= mw <= 900:
            continue
        inchi = create_standardized_mol_id(smiles)
        # Drop molecules overlapping any downstream valid/test split.
        if inchi is None or inchi in downstream_inchi_set:
            continue
        raw_data = {
            'smiles': smiles,
            'label': label.reshape([-1]),
        }

        data = featurizer.gen_features(raw_data) if featurizer is not None \
                else raw_data
        if data is not None:
            data_list.append(data)

    dataset = InMemoryDataset(data_list)
    return dataset
Ejemplo n.º 5
0
def load_chembl_filtered_dataset(data_path, featurizer=None):
    """Load chembl_filtered dataset, process the classification labels and the input information.

    The data file contains a csv table, in which columns below are used:

    :It contains the ID, SMILES/CTAB, InChI and InChIKey compound information.
    :smiles: SMILES representation of the molecular structure

    Args:
        data_path(str): the path to the cached npz path.
        featurizer: the featurizer to use for processing the data; samples
            for which ``featurizer.gen_features`` returns None are dropped.

    Returns:
        dataset(InMemoryDataset): the data_list (list of dict of numpy ndarray).

    References:
    -- Gaulton, A; et al. (2011). “ChEMBL: a large-scale bioactivity database for drug discovery”. Nucleic Acids Research. 40 (Database issue): D1100-7.

    """
    downstream_datasets = [
        load_bace_dataset(join(dirname(dirname(data_path)), 'bace/raw')),
        load_bbbp_dataset(join(dirname(dirname(data_path)), 'bbbp/raw')),
        load_clintox_dataset(join(dirname(dirname(data_path)), 'clintox/raw')),
        load_esol_dataset(join(dirname(dirname(data_path)), 'esol/raw')),
        load_freesolv_dataset(join(dirname(dirname(data_path)),
                                   'freesolv/raw')),
        load_hiv_dataset(join(dirname(dirname(data_path)), 'hiv/raw')),
        load_lipophilicity_dataset(
            join(dirname(dirname(data_path)), 'lipophilicity/raw')),
        load_muv_dataset(join(dirname(dirname(data_path)), 'muv/raw')),
        load_sider_dataset(join(dirname(dirname(data_path)), 'sider/raw')),
        load_tox21_dataset(join(dirname(dirname(data_path)), 'tox21/raw')),
        load_toxcast_dataset(join(dirname(dirname(data_path)), 'toxcast/raw')),
    ]
    downstream_inchi_set = set()
    splitter = ScaffoldSplitter()
    for c_dataset in downstream_datasets:
        # Re-create each downstream scaffold split so its validation and test
        # molecules can be excluded from the pretraining data.
        train_dataset, valid_dataset, test_dataset = splitter.split(
            c_dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
        ### remove both test and validation molecules
        remove_smiles = [d['smiles'] for d in valid_dataset] + \
                        [d['smiles'] for d in test_dataset]

        # Record the inchi of every species ('.'-separated component), not
        # just the largest one (create_standardized_mol_id keeps only the
        # largest species by default when the input has multiple species).
        downstream_inchi_set.update(
            create_standardized_mol_id(s)
            for smiles in remove_smiles
            for s in smiles.split('.'))

    smiles_list, rdkit_mol_objs, folds, labels = \
            _load_chembl_filtered_dataset(data_path)
    data_list = []
    for smiles, rdkit_mol, label in zip(smiles_list, rdkit_mol_objs, labels):
        if rdkit_mol is None:
            continue
        # Keep only molecules inside a drug-like molecular-weight window.
        mw = Descriptors.MolWt(rdkit_mol)
        if not 50 <= mw <= 900:
            continue
        inchi = create_standardized_mol_id(smiles)
        # Drop molecules overlapping any downstream valid/test split.
        if inchi is None or inchi in downstream_inchi_set:
            continue
        raw_data = {
            'smiles': smiles,
            'label': label.reshape([-1]),
        }

        data = featurizer.gen_features(raw_data) if featurizer is not None \
                else raw_data
        if data is not None:
            data_list.append(data)

    dataset = InMemoryDataset(data_list)
    return dataset