예제 #1
0
class TestBasis(BaseTestClass):
    """Tests for ``Basis``: ANOVA decomposition, knot smoothing, length, pickling.

    The fixture basis holds a constant parent plus five functions: four built
    directly on the parent, and ``bf3``, which is built on ``bf2``.
    """

    def __init__(self):
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed, because
        # ``self.__class__`` is then the subclass and ``super`` resolves back
        # to this very ``__init__``.
        super(TestBasis, self).__init__()
        # ``self.X`` comes from BaseTestClass; only its column count is used.
        self.basis = Basis(self.X.shape[1])
        self.parent = ConstantBasisFunction()
        # NOTE(review): positional args look like
        # (parent, knot, knot_idx, variable, reverse); confirm against the
        # HingeBasisFunction signature.
        self.bf1 = HingeBasisFunction(self.parent, 1.0, 10, 1, False)
        self.bf2 = HingeBasisFunction(self.parent, 1.0, 4, 2, True)
        self.bf3 = HingeBasisFunction(self.bf2, 1.0, 4, 3, True)
        self.bf4 = LinearBasisFunction(self.parent, 2)
        self.bf5 = HingeBasisFunction(self.parent, 1.5, 8, 2, True)
        # Append order matters: test_smooth pairs functions by position.
        for bf in (self.parent, self.bf1, self.bf2, self.bf3, self.bf4,
                   self.bf5):
            self.basis.append(bf)

    def test_anova_decomp(self):
        """Basis functions are grouped by the set of variables they involve."""
        anova = self.basis.anova_decomp()
        assert_equal(set(anova[frozenset([1])]), {self.bf1})
        assert_equal(set(anova[frozenset([2])]),
                     {self.bf2, self.bf4, self.bf5})
        # bf3 depends on variable 3 directly and on variable 2 via bf2.
        assert_equal(set(anova[frozenset([2, 3])]), {self.bf3})
        # The constant parent involves no variables.
        assert_equal(set(anova[frozenset()]), {self.parent})
        assert_equal(len(anova), 4)

    def test_smooth_knots(self):
        """``smooth_knots`` yields a (lower, upper) pair for each hinge.

        NOTE(review): the expected pairs match midpoints between a knot and
        its neighbouring knot (or the column min/max) on the same variable.
        That reading is inferred from the numbers alone; confirm against
        ``Basis.smooth_knots``.
        """
        mins = [0.0, -1.0, 0.1, 0.2]
        maxes = [2.5, 3.5, 3.0, 2.0]
        knots = self.basis.smooth_knots(mins, maxes)
        assert_equal(knots[self.bf1], (0.0, 2.25))
        assert_equal(knots[self.bf2], (0.55, 1.25))
        assert_equal(knots[self.bf3], (0.6, 1.5))
        # The linear function has no knot, so it must not get an entry.
        assert_true(self.bf4 not in knots)
        assert_equal(knots[self.bf5], (1.25, 2.25))

    def test_smooth(self):
        """``smooth`` maps each function to its smoothed type and keeps knots."""
        # A local seeded generator keeps the test reproducible and leaves the
        # global NumPy random state untouched.
        X = numpy.random.RandomState(0).uniform(-2.0, 4.0, size=(20, 4))
        smooth_bfs = list(self.basis.smooth(X))
        # zip() would silently truncate a mismatch, so check the lengths
        # before pairing by position.
        assert_equal(len(smooth_bfs), len(self.basis))
        # Exact-type lookup (not isinstance) is intentional: each original
        # type must map to exactly this smoothed type.
        expected_types = {
            HingeBasisFunction: SmoothedHingeBasisFunction,
            ConstantBasisFunction: ConstantBasisFunction,
            LinearBasisFunction: LinearBasisFunction,
        }
        for bf, smooth_bf in zip(self.basis, smooth_bfs):
            expected = expected_types.get(type(bf))
            if expected is None:
                raise AssertionError('Basis function is of an unexpected type.')
            assert_true(type(smooth_bf) is expected)
            if bf.has_knot():
                assert_equal(bf.get_knot(), smooth_bf.get_knot())

    def test_add(self):
        """Every appended function is counted, including the constant parent."""
        assert_equal(len(self.basis), 6)

    def test_pickle_compat(self):
        """A pickle round-trip yields a basis equal to the original."""
        # Safe: the pickle is produced in-process from trusted data.
        basis_copy = pickle.loads(pickle.dumps(self.basis))
        # assert_equal reports both operands on failure, unlike assert_true.
        assert_equal(self.basis, basis_copy)