def test_split(self):
    """Check that the three splits together contain as many items as the dataset.

    The fixture is three SMILES strings repeated 4, 5 and 4 times
    (13 records in total), wrapped in an ``InMemoryDataset`` and split
    with ``RandomScaffoldSplitter`` using fractions 0.34 / 0.33 / 0.33.
    """
    # (SMILES string, number of repetitions), in the order the records appear.
    smiles_counts = [
        ('CCOc1ccc2nc(S(N)(=O)=O)sc2c1', 4),
        ('CC(C)CCCCCCCOP(OCCCCCCCC(C)C)Oc1ccccc1', 5),
        ('CCCCCCCCCCOCC(O)CN', 4),
    ]
    # Build one fresh dict per record so that no two records share an object.
    raw_data_list = [
        {'smiles': smiles}
        for smiles, repeat in smiles_counts
        for _ in range(repeat)
    ]

    dataset = InMemoryDataset(raw_data_list)
    splitter = RandomScaffoldSplitter()
    train_part, valid_part, test_part = splitter.split(
        dataset,
        frac_train=0.34,
        frac_valid=0.33,
        frac_test=0.33)

    # Only the combined size is checked, not the individual split sizes.
    total = len(train_part) + len(valid_part) + len(test_part)
    self.assertEqual(total, len(dataset))
def create_splitter(split_type):
    """Create a splitter instance for the given split type name.

    Supported names and the class instantiated (with no arguments):

    - ``'random'``: ``RandomSplitter``
    - ``'index'``: ``IndexSplitter``
    - ``'scaffold'``: ``ScaffoldSplitter``
    - ``'random_scaffold'``: ``RandomScaffoldSplitter``

    Args:
        split_type: name of the splitter to create.

    Returns:
        A newly constructed splitter object.

    Raises:
        ValueError: if ``split_type`` is not one of the supported names.
    """
    if split_type == 'random':
        return RandomSplitter()
    if split_type == 'index':
        return IndexSplitter()
    if split_type == 'scaffold':
        return ScaffoldSplitter()
    if split_type == 'random_scaffold':
        return RandomScaffoldSplitter()
    # Wrap the operand in a 1-tuple: with a bare ``% split_type`` a tuple
    # argument would be unpacked as format arguments (raising TypeError or
    # printing only its single element) instead of being reported as-is.
    raise ValueError('%s not supported' % (split_type,))