class TestBasis(BaseTestClass):
    """Tests for a ``Basis`` populated with a small fixed set of basis functions.

    The fixture holds one constant function, four hinge functions (one of
    which, ``bf3``, has another hinge as its parent) and one linear function.
    """

    def __init__(self):
        # Name the class explicitly: super(self.__class__, self) resolves to
        # the runtime type and would recurse forever if this class were ever
        # subclassed.
        super(TestBasis, self).__init__()
        self.basis = Basis(self.X.shape[1])
        self.parent = ConstantBasisFunction()
        self.bf1 = HingeBasisFunction(self.parent, 1.0, 10, 1, False)
        self.bf2 = HingeBasisFunction(self.parent, 1.0, 4, 2, True)
        # bf3 is built on top of bf2 rather than on the constant function.
        self.bf3 = HingeBasisFunction(self.bf2, 1.0, 4, 3, True)
        self.bf4 = LinearBasisFunction(self.parent, 2)
        self.bf5 = HingeBasisFunction(self.parent, 1.5, 8, 2, True)
        # Insertion order is preserved: parent first, then bf1..bf5.
        for basis_function in (self.parent, self.bf1, self.bf2,
                               self.bf3, self.bf4, self.bf5):
            self.basis.append(basis_function)

    def test_anova_decomp(self):
        """Check the grouping returned by ``Basis.anova_decomp``.

        NOTE(review): the frozenset keys appear to be the sets of variable
        indices each basis function depends on -- confirm against the
        ``anova_decomp`` implementation.
        """
        anova = self.basis.anova_decomp()
        assert_equal(set(anova[frozenset([1])]), set([self.bf1]))
        assert_equal(set(anova[frozenset([2])]),
                     set([self.bf2, self.bf4, self.bf5]))
        assert_equal(set(anova[frozenset([2, 3])]), set([self.bf3]))
        assert_equal(set(anova[frozenset()]), set([self.parent]))
        # Exactly the four groups above, nothing else.
        assert_equal(len(anova), 4)

    def test_smooth_knots(self):
        """Check the per-function pairs returned by ``Basis.smooth_knots``.

        Every hinge function must map to the expected pair of values, and the
        linear function must be absent from the result.
        """
        mins = [0.0, -1.0, 0.1, 0.2]
        maxes = [2.5, 3.5, 3.0, 2.0]
        knots = self.basis.smooth_knots(mins, maxes)
        assert_equal(knots[self.bf1], (0.0, 2.25))
        assert_equal(knots[self.bf2], (0.55, 1.25))
        assert_equal(knots[self.bf3], (0.6, 1.5))
        assert_true(self.bf4 not in knots)
        assert_equal(knots[self.bf5], (1.25, 2.25))

    def test_smooth(self):
        """Check that ``Basis.smooth`` maps each function to the expected type.

        Hinge functions must become ``SmoothedHingeBasisFunction``; constant
        and linear functions must keep their type. Functions that have a knot
        must report the same knot after smoothing.
        """
        X = numpy.random.uniform(-2.0, 4.0, size=(20, 4))
        smooth_basis = self.basis.smooth(X)
        for bf, smooth_bf in zip(self.basis, smooth_basis):
            # Exact type comparison on purpose: the check is about the
            # concrete class, not about subclass relationships.
            if type(bf) is HingeBasisFunction:
                assert_true(type(smooth_bf) is SmoothedHingeBasisFunction)
            elif type(bf) is ConstantBasisFunction:
                assert_true(type(smooth_bf) is ConstantBasisFunction)
            elif type(bf) is LinearBasisFunction:
                assert_true(type(smooth_bf) is LinearBasisFunction)
            else:
                raise AssertionError(
                    'Basis function is of an unexpected type.')
            # Check the smoothed type on its own, independent of the source
            # function's type.
            assert_true(type(smooth_bf) in {SmoothedHingeBasisFunction,
                                            ConstantBasisFunction,
                                            LinearBasisFunction})
            if bf.has_knot():
                assert_equal(bf.get_knot(), smooth_bf.get_knot())

    def test_add(self):
        """Check that all six appended basis functions are counted."""
        assert_equal(len(self.basis), 6)

    def test_pickle_compat(self):
        """Check that a pickle round trip yields a basis equal to the original."""
        basis_copy = pickle.loads(pickle.dumps(self.basis))
        assert_true(self.basis == basis_copy)