예제 #1
0
 def test_multitask_index_split(self):
     """
     Check IndexSplitter on the multitask dataset.

     Requests an 80/10/10 train/valid/test split and expects the
     resulting splits to hold 8, 1 and 1 entries respectively.
     """
     dataset = self.load_multitask_data()
     splitter = IndexSplitter()
     splits = splitter.train_valid_test_split(
         dataset,
         self.train_dir,
         self.valid_dir,
         self.test_dir,
         frac_train=0.8,
         frac_valid=0.1,
         frac_test=0.1)
     train_split, valid_split, test_split = splits

     # Sizes expected from the 0.8 / 0.1 / 0.1 fractions requested above.
     assert len(train_split) == 8
     assert len(valid_split) == 1
     assert len(test_split) == 1
예제 #2
0
    def test_singletask_index_split(self):
        """
        Test singletask IndexSplitter class.

        Requests an 80/10/10 train/valid/test split of the solubility
        dataset and expects 8, 1 and 1 entries respectively. The three
        splits are then merged back into a single dataset, whose ids must
        match the original dataset's ids (compared order-insensitively).
        """
        # Function-scope import: only needed to clean up the merge directory.
        import shutil

        solubility_dataset = self.load_solubility_data()
        index_splitter = IndexSplitter()
        train_data, valid_data, test_data = \
            index_splitter.train_valid_test_split(
                solubility_dataset,
                self.train_dir, self.valid_dir, self.test_dir,
                frac_train=0.8, frac_valid=0.1, frac_test=0.1)
        assert len(train_data) == 8
        assert len(valid_data) == 1
        assert len(test_data) == 1

        # mkdtemp() never cleans up after itself, so remove the directory
        # explicitly once the merged dataset has been checked.
        merge_dir = tempfile.mkdtemp()
        try:
            merged_dataset = DiskDataset.merge(
                merge_dir, [train_data, valid_data, test_data])
            # Sort both sides: the check is about which ids are present,
            # not the order in which the merge emits them.
            assert sorted(merged_dataset.ids) == sorted(solubility_dataset.ids)
        finally:
            # ignore_errors so a cleanup problem cannot mask a failed assert.
            shutil.rmtree(merge_dir, ignore_errors=True)