Ejemplo n.º 1
0
    def test_singletask_random_k_fold_split(self):
        """Test k-fold splitting with the singletask RandomSplitter.

        Splits the solubility dataset into K folds and checks that:

        * every fold has the expected number of compounds,
        * every fold's ids are a subset of the original dataset's ids,
        * no two folds share a compound,
        * merging all folds yields the same ids as the original dataset.

        All scratch directories created for the folds and for the merged
        dataset are removed when the test finishes, whether it passes or
        fails.
        """
        # Local import: only needed for scratch-directory cleanup.
        import shutil

        solubility_dataset = self.load_solubility_data()
        random_splitter = RandomSplitter()
        ids_set = set(solubility_dataset.ids)

        K = 5
        # tempfile.mkdtemp() never deletes what it creates, so every
        # directory is tracked here and removed in the ``finally`` below.
        scratch_dirs = []
        try:
            # K directories for the folds plus one for the merged dataset.
            for _ in range(K + 1):
                scratch_dirs.append(tempfile.mkdtemp())
            fold_dirs = scratch_dirs[:K]
            merge_dir = scratch_dirs[K]

            fold_datasets = random_splitter.k_fold_split(solubility_dataset,
                                                         fold_dirs)

            # Ids seen in earlier folds. A fold that is disjoint from the
            # union of all previous folds is disjoint from each of them, so
            # one pass establishes pairwise disjointness.
            seen_ids = set()
            for fold in range(K):
                fold_dataset = fold_datasets[fold]
                # Verify lengths is 10/k == 2
                assert len(fold_dataset) == 2
                # Verify that compounds in this fold are a subset of the
                # original compounds.
                fold_ids_set = set(fold_dataset.ids)
                assert fold_ids_set.issubset(ids_set)
                # Verify that no two folds have overlapping compounds.
                assert fold_ids_set.isdisjoint(seen_ids)
                seen_ids |= fold_ids_set

            merged_dataset = DiskDataset.merge(merge_dir, fold_datasets)
            assert len(merged_dataset) == len(solubility_dataset)
            assert sorted(merged_dataset.ids) == (sorted(
                solubility_dataset.ids))
        finally:
            # Best-effort cleanup: a failure to delete a scratch directory
            # must not mask the real test outcome.
            for scratch_dir in scratch_dirs:
                shutil.rmtree(scratch_dir, ignore_errors=True)
Ejemplo n.º 2
0
            logging.info(f"Scaffolds sets size: {len(scaffold_sets)}")
            logging.info(
                f"Scaffolds length: {[len(sfd) for sfd in scaffold_sets]}")
            logging.info(f"Raw scaffolds: {scaffold_sets}")
            scaffold_sets_filt = [
                sfd for sfd in scaffold_sets if len(sfd) >= 100
            ]

            for sfd_filt in scaffold_sets_filt:
                sfd_name = "scaffold_" + str(len(sfd_filt))
                results[splitter_name][model_name][sfd_name] = {}
                logging.info(f"Scaffold size: {len(sfd_filt)}")

                data_subset = wang_train.select(indices=sfd_filt)
                k_fold = splitter_rand.k_fold_split(data_subset, k=10)
                for i, fold in enumerate(k_fold):
                    model = get_model(model_name)
                    train, valid = fold
                    logging.info(
                        f"Train size: {len(train)}, Valid size: {len(valid)}")

                    model.fit(train)
                    train_scores = model.evaluate(train, [metric],
                                                  wang_transformers)
                    valid_scores = model.evaluate(valid, [metric],
                                                  wang_transformers)

                    fold_name = "fold_" + str(i)
                    results[splitter_name][model_name][sfd_name][
                        fold_name] = {}