Ejemplo n.º 1
0
 def test_split(self):
     """Splitting with RandomScaffoldSplitter keeps the total sample count.

     The fixture is 13 records built from three SMILES strings repeated
     4, 5 and 4 times, in that order. The test only checks that the sizes
     of the train/valid/test parts add up to the size of the input dataset.
     """
     smiles_with_counts = [
         ('CCOc1ccc2nc(S(N)(=O)=O)sc2c1', 4),
         ('CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1', 5),
         ('CCCCCCCCCCOCC(O)CN', 4),
     ]
     raw_data_list = [
         {'smiles': smiles}
         for smiles, count in smiles_with_counts
         for _ in range(count)
     ]
     dataset = InMemoryDataset(raw_data_list)
     scaffold_splitter = RandomScaffoldSplitter()
     parts = scaffold_splitter.split(
         dataset, frac_train=0.34, frac_valid=0.33, frac_test=0.33)
     train_dataset, valid_dataset, test_dataset = parts
     total = sum(
         len(part) for part in (train_dataset, valid_dataset, test_dataset))
     self.assertEqual(total, len(dataset))
Ejemplo n.º 2
0
def create_splitter(split_type):
    """Build the dataset splitter that corresponds to ``split_type``.

    Args:
        split_type: name of the splitting strategy. Supported values are
            ``'random'``, ``'index'``, ``'scaffold'`` and
            ``'random_scaffold'``.

    Returns:
        A new ``RandomSplitter``, ``IndexSplitter``, ``ScaffoldSplitter`` or
        ``RandomScaffoldSplitter`` instance respectively, each constructed
        with no arguments.

    Raises:
        ValueError: if ``split_type`` is not one of the supported names.
    """
    if split_type == 'random':
        return RandomSplitter()
    if split_type == 'index':
        return IndexSplitter()
    if split_type == 'scaffold':
        return ScaffoldSplitter()
    if split_type == 'random_scaffold':
        return RandomScaffoldSplitter()
    # Wrap the value in a 1-tuple so %-formatting never treats a tuple
    # argument as several format arguments (that would raise TypeError
    # instead of the intended ValueError).
    raise ValueError('%s not supported' % (split_type,))