示例#1
0
class TestCrossValidation(unittest.TestCase):
    """Unit tests for ``CrossValidation`` on the ``./tests/data.txt`` fixture."""

    def setUp(self):
        # unittest calls setUp before every test method, so each test gets a
        # fresh splitter; fold state advanced by ``_next_fold`` cannot leak
        # from one test into another.
        self.data = Reader().read('./tests/data.txt')
        self.n_folds = 5
        self.cv = CrossValidation(data=self.data,
                                  n_folds=self.n_folds,
                                  exclude_unknowns=False)

    def test_partition_data(self):
        """The generated partition covers every rating with balanced folds.

        Checks that there is one fold id per rating, that every id in
        ``0 .. n_folds-1`` is used, and that each fold holds exactly 2
        ratings (which implies the fixture contains ``2 * n_folds`` ratings).
        """
        ref_set = set(range(self.n_folds))
        res_set = set(self.cv._partition)
        # np.unique(..., return_counts=True) returns (unique_values, counts).
        fold_sizes = np.unique(self.cv._partition, return_counts=True)[1]

        self.assertEqual(len(self.data), len(self.cv._partition))
        self.assertEqual(res_set, ref_set)
        np.testing.assert_array_equal(fold_sizes, 2)

    def test_validate_partition(self):
        """``_validate_partition`` must reject malformed partitions.

        The previous try/except-with-``assert True`` form could never fail.
        ``assertRaises`` fails the test if no ``ValueError`` is raised.
        """
        # Wrong length: 4 fold ids, but test_partition_data shows the
        # fixture has 2 * n_folds = 10 ratings.
        with self.assertRaises(ValueError):
            self.cv._validate_partition([0, 0, 1, 1])

        # Right length (10) but only 4 distinct fold ids while n_folds is 5.
        with self.assertRaises(ValueError):
            self.cv._validate_partition([0, 0, 1, 1, 2, 2, 2, 2, 3, 3])

    def test_get_train_test_sets_next_fold(self):
        """Iterating the folds updates ``current_fold`` and rebuilds the train set."""
        for n in range(self.cv.n_folds):
            self.cv._get_train_test()
            self.assertEqual(self.cv.current_fold, n)
            # (8, 8) is presumably users x items of the fixture's train
            # matrix -- TODO confirm against the data file.
            self.assertSequenceEqual(self.cv.train_set.matrix.shape, (8, 8))
            self.cv._next_fold()
示例#2
0
 def test_with_cross_validation(self):
     """Smoke-test an end-to-end ``Experiment`` driven by ``CrossValidation``.

     Evaluates a ``PMF(1, 0)`` model over ``self.data`` with MAE, RMSE,
     Recall(1) and FMeasure(1). The test passes as long as ``run()``
     completes without raising.
     """
     cross_val = CrossValidation(self.data)
     candidate_models = [PMF(1, 0)]
     metric_list = [MAE(), RMSE(), Recall(1), FMeasure(1)]
     experiment = Experiment(eval_method=cross_val,
                             models=candidate_models,
                             metrics=metric_list,
                             verbose=True)
     experiment.run()
示例#3
0
 def test_with_cross_validation(self):
     """Smoke-test ``Experiment`` + ``CrossValidation`` with one extra rating.

     A synthetic record is appended to ``self.data`` before splitting. It is
     built from the first field of record 0, the second field of record 1
     and the value 5.0. This assumes records are (user, item, rating)
     triples, as the appended tuple suggests. ``exclude_unknowns=False`` is
     forwarded to the splitter; presumably this keeps users/items unseen in
     a training fold -- confirm against ``CrossValidation``.
     """
     first_user = self.data[0][0]
     second_item = self.data[1][1]
     augmented = self.data + [(first_user, second_item, 5.0)]

     splitter = CrossValidation(augmented,
                                exclude_unknowns=False,
                                verbose=True)
     Experiment(eval_method=splitter,
                models=[PMF(1, 0)],
                metrics=[Recall(1), FMeasure(1)],
                verbose=True).run()
示例#4
0
def test_with_cross_validation():
    """Smoke-test a full ``Experiment`` run using ``CrossValidation``.

    Loads ``./tests/data.txt`` via ``reader.read_uir``. It then evaluates a
    ``PMF(1, 0)`` model with MAE, RMSE, Recall(1) and FMeasure(1), and
    passes as long as ``run()`` does not raise.
    """
    ratings = reader.read_uir('./tests/data.txt')
    experiment = Experiment(
        eval_method=CrossValidation(ratings),
        models=[PMF(1, 0)],
        metrics=[MAE(), RMSE(), Recall(1), FMeasure(1)],
        verbose=True,
    )
    experiment.run()
示例#5
0
def test_partition_data():
    """Check the default fold assignment produced by ``CrossValidation``.

    Every rating must receive a fold id, and all ids ``0 .. num_folds-1``
    must be used. Each fold must contain exactly two ratings.
    """
    ratings = reader.read_uir('./tests/data.txt')

    num_folds = 5
    splitter = CrossValidation(data=ratings, n_folds=num_folds)
    assignment = splitter.partition

    # np.unique(..., return_counts=True) yields (unique_ids, counts).
    _, counts = np.unique(assignment, return_counts=True)

    assert len(assignment) == len(ratings)
    assert set(assignment) == set(range(num_folds))
    assert (counts == 2).all()
示例#6
0
def test_get_train_test_sets_next_fold():
    """Walk every fold and check the train split built for each one.

    For each fold index, the splitter must report that index as
    ``current_fold``. It must also expose a training matrix of shape
    ``(8, 8)`` before advancing to the next fold.
    """
    ratings = reader.read_uir('./tests/data.txt')

    num_folds = 5
    splitter = CrossValidation(data=ratings, n_folds=num_folds)

    for fold_idx in range(splitter.n_folds):
        splitter._get_train_test()
        assert splitter.current_fold == fold_idx
        # Presumably users x items of the fixture -- TODO confirm.
        assert splitter.train_set.matrix.shape == (8, 8)
        splitter._next_fold()
        
示例#7
0
def test_validate_partition():
    """``CrossValidation._validate_partition`` must reject malformed partitions.

    The original bare ``except: assert True`` form could never fail: it
    passed whether or not an error was raised, and it also swallowed
    unrelated errors such as ``NameError``. Each bad partition must now
    raise ``ValueError``; otherwise the test fails with ``AssertionError``.
    """
    data = reader.read_uir('./tests/data.txt')

    nfolds = 5
    cv = CrossValidation(data=data, n_folds=nfolds)

    bad_partitions = [
        # Too short: 4 fold ids; presumably fewer than len(data) -- the
        # sibling partition test implies 10 ratings (5 folds x 2 each).
        [0, 0, 1, 1],
        # 10 entries but only 4 distinct fold ids while nfolds is 5.
        [0, 0, 1, 1, 2, 2, 2, 2, 3, 3],
    ]

    for partition in bad_partitions:
        try:
            cv._validate_partition(partition)
        except ValueError:
            pass
        else:
            raise AssertionError(
                'expected ValueError for partition %r' % (partition,))
示例#8
0
 def setUp(self):
     """Create a fresh 5-fold ``CrossValidation`` over the test fixture.

     Stores the raw data, the fold count and the splitter on the instance.
     The splitter is built with ``exclude_unknowns=False``.
     """
     ratings = Reader().read('./tests/data.txt')
     folds = 5
     self.data = ratings
     self.n_folds = folds
     self.cv = CrossValidation(data=ratings, n_folds=folds,
                               exclude_unknowns=False)
示例#9
0
 def setUp(self):
     """Create a fresh 5-fold ``CrossValidation`` over the test fixture.

     Stores the raw data, the fold count and the splitter on the instance.
     All other splitter options use the library defaults.
     """
     ratings = Reader().read('./tests/data.txt')
     folds = 5
     self.data = ratings
     self.n_folds = folds
     self.cv = CrossValidation(data=ratings, n_folds=folds)
示例#10
0
def select_eval(method, dataset):
    """Build the evaluation strategy named by *method* for *dataset*.

    Args:
        method: Either ``"ratio_split"`` or ``"cross_validation"``.
        dataset: Passed unchanged as ``data=`` to the chosen evaluator.

    Returns:
        For ``"ratio_split"``: a ``RatioSplit`` with ``test_size=0.2`` and
        ``rating_threshold=4.0``.
        For ``"cross_validation"``: a ``CrossValidation`` with
        ``rating_threshold=1.0``.
        Both are built with ``exclude_unknowns=False``.

    Raises:
        ValueError: If *method* is not one of the supported names.
    """
    if method == "ratio_split":
        return RatioSplit(data=dataset, test_size=0.2,
                          rating_threshold=4.0, exclude_unknowns=False)
    if method == "cross_validation":
        return CrossValidation(data=dataset, rating_threshold=1.0,
                               exclude_unknowns=False)
    # Fail fast instead of implicitly returning None for an unsupported name.
    raise ValueError(
        "Unknown evaluation method %r; expected 'ratio_split' or "
        "'cross_validation'" % (method,))