Ejemplo n.º 1
0
    def test_singletask_index_k_fold_split(self):
        """Verify IndexSplitter.k_fold_split on a singletask dataset.

        Checks fold sizes, fold membership, pairwise disjointness, and
        that merging the folds recovers the original dataset.
        """
        solubility_dataset = self.load_solubility_data()
        splitter = IndexSplitter()
        all_ids = set(solubility_dataset.ids)

        n_folds = 5
        directories = [tempfile.mkdtemp() for _ in range(n_folds)]
        folds = splitter.k_fold_split(solubility_dataset, directories)

        for i in range(n_folds):
            current = folds[i]
            # Each fold holds 10 / n_folds == 2 compounds.
            assert len(current) == 2
            # Every compound in this fold must come from the original set.
            current_ids = set(current.ids)
            assert current_ids.issubset(all_ids)
            # No compound may appear in two different folds.
            for j in range(n_folds):
                if i == j:
                    continue
                assert current_ids.isdisjoint(set(folds[j].ids))

        # Merging all folds back together must reconstruct the dataset.
        merge_dir = tempfile.mkdtemp()
        merged = DiskDataset.merge(merge_dir, folds)
        assert len(merged) == len(solubility_dataset)
        assert sorted(merged.ids) == sorted(solubility_dataset.ids)
Ejemplo n.º 2
0
 def test_multitask_index_split(self):
     """Verify IndexSplitter's train/valid/test split on a multitask dataset.

     With 10 compounds an 80/10/10 split must yield sizes 8, 1, and 1.
     """
     multitask_dataset = self.load_multitask_data()
     splitter = IndexSplitter()
     splits = splitter.train_valid_test_split(
         multitask_dataset,
         self.train_dir, self.valid_dir, self.test_dir,
         frac_train=0.8, frac_valid=0.1, frac_test=0.1)
     # splits is (train, valid, test); check each expected size in turn.
     for subset, expected_size in zip(splits, (8, 1, 1)):
         assert len(subset) == expected_size
Ejemplo n.º 3
0
  def test_k_fold_splitter(self):
    """
    Test that we can 5-fold index-wise over 5 points.

    IndexSplitter assigns points in order, so fold i's validation set
    should contain exactly the point with index i, and its train set
    the remaining K - 1 points.
    """
    K = 5
    ds = NumpyDataset(np.array(range(K)), np.array(range(K)))
    index_splitter = IndexSplitter()

    fold_datasets = index_splitter.k_fold_split(ds, K)

    for fold in range(K):
      # Validation subset (index 1) of fold i holds exactly point i.
      self.assertTrue(fold_datasets[fold][1].X[0] == fold)
      # set() accepts any iterable; the list() wrapper was redundant.
      train_data = set(fold_datasets[fold][0].X)
      self.assertFalse(fold in train_data)
      self.assertEqual(K - 1, len(train_data))
Ejemplo n.º 4
0
  def test_k_fold_splitter(self):
    """
    Test that we can 5-fold index-wise over 5 points.

    Each fold is a (train, cv) pair: the cv set holds the single point
    whose index equals the fold number, the train set the other K - 1.
    """
    K = 5
    ds = NumpyDataset(np.array(range(K)), np.array(range(K)))
    index_splitter = IndexSplitter()

    fold_datasets = index_splitter.k_fold_split(ds, K)

    for fold in range(K):
      train, cv = fold_datasets[fold][0], fold_datasets[fold][1]
      # Fold i's single validation point is the point with index i.
      self.assertTrue(cv.X[0] == fold)
      # set() accepts any iterable; the list() wrapper was redundant.
      train_data = set(train.X)
      self.assertFalse(fold in train_data)
      self.assertEqual(K - 1, len(train))
      self.assertEqual(1, len(cv))
Ejemplo n.º 5
0
    def test_singletask_index_split(self):
        """
        Test singletask IndexSplitter class.

        Note: the original docstring said "RandomSplitter", but the
        splitter under test is IndexSplitter; the local variable name
        is fixed to match as well.
        """
        solubility_dataset = self.load_solubility_data()
        index_splitter = IndexSplitter()
        train_data, valid_data, test_data = \
            index_splitter.train_valid_test_split(
                solubility_dataset,
                self.train_dir, self.valid_dir, self.test_dir,
                frac_train=0.8, frac_valid=0.1, frac_test=0.1)
        # 10 compounds -> 8 / 1 / 1 under an 80/10/10 split.
        assert len(train_data) == 8
        assert len(valid_data) == 1
        assert len(test_data) == 1

        # Merging the three splits must recover the full set of ids.
        merge_dir = tempfile.mkdtemp()
        merged_dataset = DiskDataset.merge(merge_dir,
                                           [train_data, valid_data, test_data])
        assert sorted(merged_dataset.ids) == (sorted(solubility_dataset.ids))
Ejemplo n.º 6
0
def _load_mol_dataset(dataset_file,
                      tasks,
                      split="stratified",
                      test_size=0.1,
                      valid_size=0.1,
                      min_size=0,
                      max_size=None,
                      **kwargs):
    """Load a CSV molecule dataset, split, balance and graph-transform it.

    Parameters
    ----------
    dataset_file : path to a CSV file with a "smiles" column.
    tasks : task (label) column names passed to CSVLoader.
    split : splitter to use; one of 'index', 'random', 'scaffold',
        'butina', 'stratified'. An unknown name raises KeyError.
    test_size, valid_size : fractions held out; the remainder trains.
    min_size, max_size : molecule-size bounds forwarded to GraphTransformer.
        max_size also pads the resulting MolDatasets.
    **kwargs : extra GraphTransformer options.

    Returns
    -------
    ([train, valid, test] MolDatasets, in_size, out_size) where in_size and
    out_size are read from the last transformed subset (the test split).
    # NOTE(review): assumes transformer returns (features, kept-indices) and
    # that y/w are 2-D — confirm against GraphTransformer's contract.
    """
    # Whatever is not held out for validation/testing goes to training.
    train_size = 1.0 - (test_size + valid_size)

    loader = CSVLoader(tasks=tasks,
                       smiles_field="smiles",
                       featurizer=RawFeaturizer(),
                       verbose=False,
                       log_every_n=10000)
    dataset = loader.featurize(dataset_file)

    # Dispatch table from split name to splitter instance.
    splitter = {
        'index': IndexSplitter(),
        'random': RandomSplitter(),
        'scaffold': ScaffoldSplitter(),
        'butina': ButinaSplitter(),
        'stratified': RandomStratifiedSplitter()
    }[split]

    train, valid, test = splitter.train_valid_test_split(dataset,
                                                         frac_train=train_size,
                                                         frac_valid=valid_size,
                                                         frac_test=test_size)

    # Balance weights are fit on the train split, then applied to all three.
    balancer = BalancingTransformer(transform_w=True, dataset=train)
    train, valid, test = (balancer.transform(d) for d in (train, valid, test))

    transformer = GraphTransformer(mol_size=[min_size, max_size], **kwargs)
    datasets = []
    for subset in (train, valid, test):
        # ignore_errors=False: a molecule that fails to transform raises.
        X, kept = transformer(subset.ids, dtype=np.float32, ignore_errors=False)
        y = subset.y[kept, :]
        w = subset.w[kept, :]
        raw_mols = subset.X[kept]
        datasets.append(MolDataset(X, y, raw_mols, w=w, pad_to=max_size))

    # X and y deliberately leak from the loop: sizes come from the test split.
    in_size = X[0][-1].shape[-1]
    out_size = 1 if len(y.shape) == 1 else y.shape[-1]
    return datasets, in_size, out_size