Example #1
    def test_singletask_random_k_fold_split(self):
        """
    Test singletask RandomSplitter class.
    """
        solubility_dataset = self.load_solubility_data()
        random_splitter = RandomSplitter()
        ids_set = set(solubility_dataset.ids)

        K = 5
        fold_dirs = [tempfile.mkdtemp() for _ in range(K)]
        fold_datasets = random_splitter.k_fold_split(solubility_dataset,
                                                     fold_dirs)
        for fold in range(K):
            fold_dataset = fold_datasets[fold]
            # Verify each fold's length is 10/K == 2
            assert len(fold_dataset) == 2
            # Verify that compounds in this fold are subset of original compounds
            fold_ids_set = set(fold_dataset.ids)
            assert fold_ids_set.issubset(ids_set)
            # Verify that no two folds have overlapping compounds.
            for other_fold in range(K):
                if fold == other_fold:
                    continue
                other_fold_dataset = fold_datasets[other_fold]
                other_fold_ids_set = set(other_fold_dataset.ids)
                assert fold_ids_set.isdisjoint(other_fold_ids_set)

        merge_dir = tempfile.mkdtemp()
        merged_dataset = DiskDataset.merge(merge_dir, fold_datasets)
        assert len(merged_dataset) == len(solubility_dataset)
        assert sorted(merged_dataset.ids) == sorted(solubility_dataset.ids)
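
For reference, a minimal self-contained sketch of the same k-fold pattern outside the test harness. It assumes the older DeepChem API exercised above (explicit fold directories, DiskDataset.from_numpy(data_dir, X, y, w, ids)); the deepchem.data / deepchem.splits import paths and the toy arrays are assumptions for illustration, not taken from the test.

    import tempfile
    import numpy as np
    from deepchem.data import DiskDataset        # import path may differ by DeepChem version
    from deepchem.splits import RandomSplitter   # import path may differ by DeepChem version

    # Small single-task dataset written to a scratch directory.
    X = np.random.rand(10, 4)
    y = np.random.rand(10)
    dataset = DiskDataset.from_numpy(tempfile.mkdtemp(), X, y,
                                     np.ones(10), np.arange(10))

    # One scratch directory per fold, exactly as in the test above.
    fold_dirs = [tempfile.mkdtemp() for _ in range(5)]
    folds = RandomSplitter().k_fold_split(dataset, fold_dirs)

    # Re-merging the folds should recover the original compounds.
    merged = DiskDataset.merge(tempfile.mkdtemp(), folds)
    assert sorted(merged.ids) == sorted(dataset.ids)
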
Example #2
    def test_singletask_stratified_k_fold_split(self):
        """
    Test RandomStratifiedSplitter k-fold class.
    """
        n_samples = 100
        n_positives = 20
        n_features = 10
        n_tasks = 1

        X = np.random.rand(n_samples, n_features)
        y = np.zeros(n_samples)
        y[:n_positives] = 1
        w = np.ones(n_samples)
        ids = np.arange(n_samples)

        data_dir = tempfile.mkdtemp()
        dataset = DiskDataset.from_numpy(data_dir, X, y, w, ids)

        stratified_splitter = RandomStratifiedSplitter()
        ids_set = set(dataset.ids)

        K = 5
        fold_dirs = [tempfile.mkdtemp() for _ in range(K)]
        fold_datasets = stratified_splitter.k_fold_split(dataset, fold_dirs)

        for fold in range(K):
            fold_dataset = fold_datasets[fold]
            # Verify each fold's length is 100/K == 20. This check wouldn't
            # hold for multitask stratified splits, so it is left commented out.
            # assert len(fold_dataset) == n_samples / K
            fold_labels = fold_dataset.y
            # Verify that each fold has n_positives/K = 4 positive examples.
            assert np.count_nonzero(fold_labels == 1) == n_positives / K
            # Verify that compounds in this fold are subset of original compounds
            fold_ids_set = set(fold_dataset.ids)
            assert fold_ids_set.issubset(ids_set)
            # Verify that no two folds have overlapping compounds.
            for other_fold in range(K):
                if fold == other_fold:
                    continue
                other_fold_dataset = fold_datasets[other_fold]
                other_fold_ids_set = set(other_fold_dataset.ids)
                assert fold_ids_set.isdisjoint(other_fold_ids_set)

        merge_dir = tempfile.mkdtemp()
        merged_dataset = DiskDataset.merge(merge_dir, fold_datasets)
        assert len(merged_dataset) == len(dataset)
        assert sorted(merged_dataset.ids) == sorted(dataset.ids)
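
A condensed sketch of the stratification invariant checked above: with 20 positives spread over 5 folds, each fold should receive 20/5 == 4 positive labels. Same caveat as the earlier sketch, the import paths are assumed and the toy data mirrors the test.

    import tempfile
    import numpy as np
    from deepchem.data import DiskDataset                  # assumed import path
    from deepchem.splits import RandomStratifiedSplitter   # assumed import path

    X = np.random.rand(100, 10)
    y = np.zeros(100)
    y[:20] = 1  # 20 positive labels, as in the test above
    dataset = DiskDataset.from_numpy(tempfile.mkdtemp(), X, y,
                                     np.ones(100), np.arange(100))

    folds = RandomStratifiedSplitter().k_fold_split(
        dataset, [tempfile.mkdtemp() for _ in range(5)])

    # Each fold should hold an equal share of the positive labels.
    for fold in folds:
        assert np.count_nonzero(fold.y == 1) == 4
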
Example #3
    def test_singletask_index_split(self):
        """
    Test singletask RandomSplitter class.
    """
        solubility_dataset = self.load_solubility_data()
        index_splitter = IndexSplitter()
        train_data, valid_data, test_data = \
            index_splitter.train_valid_test_split(
                solubility_dataset,
                self.train_dir, self.valid_dir, self.test_dir,
                frac_train=0.8, frac_valid=0.1, frac_test=0.1)
        assert len(train_data) == 8
        assert len(valid_data) == 1
        assert len(test_data) == 1

        merge_dir = tempfile.mkdtemp()
        merged_dataset = DiskDataset.merge(merge_dir,
                                           [train_data, valid_data, test_data])
        assert sorted(merged_dataset.ids) == sorted(solubility_dataset.ids)
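
IndexSplitter assigns compounds by position rather than at random, so with frac_train=0.8 on a ten-compound dataset the first eight entries are expected to land in the training split, the ninth in validation and the tenth in test. A hedged sketch of that check on a toy dataset; the positional-ordering expectation and the import paths are assumptions of mine, not asserted by the test above.

    import tempfile
    import numpy as np
    from deepchem.data import DiskDataset       # assumed import path
    from deepchem.splits import IndexSplitter   # assumed import path

    # Ten compounds with ids 0..9 written to a scratch DiskDataset.
    dataset = DiskDataset.from_numpy(tempfile.mkdtemp(),
                                     np.random.rand(10, 4), np.zeros(10),
                                     np.ones(10), np.arange(10))

    train, valid, test = IndexSplitter().train_valid_test_split(
        dataset, tempfile.mkdtemp(), tempfile.mkdtemp(), tempfile.mkdtemp(),
        frac_train=0.8, frac_valid=0.1, frac_test=0.1)

    # Expected under the positional-split assumption: the splits keep the
    # original ordering, so train holds the first eight ids.
    assert list(train.ids) == list(dataset.ids[:8])
    assert len(valid) == 1 and len(test) == 1
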
Example #4
    def test_merge(self):
        """Test that datasets can be merged."""
        verbosity = "high"
        current_dir = os.path.dirname(os.path.realpath(__file__))
        first_data_dir = os.path.join(self.base_dir, "first_dataset")
        second_data_dir = os.path.join(self.base_dir, "second_dataset")
        merged_data_dir = os.path.join(self.base_dir, "merged_data")

        dataset_file = os.path.join(current_dir,
                                    "../../models/tests/example.csv")

        featurizer = CircularFingerprint(size=1024)
        tasks = ["log-solubility"]
        loader = DataLoader(tasks=tasks,
                            smiles_field="smiles",
                            featurizer=featurizer,
                            verbosity=verbosity)
        first_dataset = loader.featurize(dataset_file, first_data_dir)
        second_dataset = loader.featurize(dataset_file, second_data_dir)

        merged_dataset = DiskDataset.merge(merged_data_dir,
                                           [first_dataset, second_dataset])

        assert len(merged_dataset) == len(first_dataset) + len(second_dataset)
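
Because both datasets above are featurized from the same example.csv, the merge simply concatenates them and every compound appears twice in the result. A small self-contained sketch of that no-deduplication behavior, merging a toy dataset with itself under the same API and import-path assumptions as the earlier sketches:

    import tempfile
    from collections import Counter
    import numpy as np
    from deepchem.data import DiskDataset   # assumed import path

    base = DiskDataset.from_numpy(tempfile.mkdtemp(),
                                  np.random.rand(6, 3), np.zeros(6),
                                  np.ones(6), np.arange(6))
    doubled = DiskDataset.merge(tempfile.mkdtemp(), [base, base])

    # Merge concatenates the input datasets without deduplicating, so the
    # length doubles and every id now occurs exactly twice.
    assert len(doubled) == 2 * len(base)
    assert all(count == 2 for count in Counter(doubled.ids).values())
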