def __init__(self):
     """Build a shared Basis fixture of six basis functions.

     The basis is sized from the column count of ``self.X`` (set up by the
     parent ``__init__``).  Everything hangs off one ConstantBasisFunction,
     except ``bf3``, whose parent is ``bf2``.
     """
     super(Container, self).__init__()
     self.basis = Basis(self.X.shape[1])
     root = ConstantBasisFunction()
     self.parent = root
     self.bf1 = HingeBasisFunction(root, 1.0, 10, 1, False)
     self.bf2 = HingeBasisFunction(root, 1.0, 4, 2, True)
     # bf3 is nested under bf2 rather than the constant root.
     self.bf3 = HingeBasisFunction(self.bf2, 1.0, 4, 3, True)
     self.bf4 = LinearBasisFunction(root, 2)
     self.bf5 = HingeBasisFunction(root, 1.5, 8, 2, True)
     # Register in a fixed order: root first, then bf1..bf5.
     for member in (root, self.bf1, self.bf2, self.bf3, self.bf4, self.bf5):
         self.basis.append(member)
Beispiel #2
0
import numpy
from scipy.sparse import csr_matrix
from pyearth._types import BOOL
from pyearth._basis import (Basis, ConstantBasisFunction, HingeBasisFunction,
                            LinearBasisFunction)
from pyearth import Earth
import pyearth
# ``numpy.testing.utils`` is a deprecated alias; the public location of the
# helper is ``numpy.testing`` itself.
from numpy.testing import assert_array_almost_equal

# NOTE(review): presumably toggles rewriting of stored expected-output files
# used by tests elsewhere in this module -- confirm.
regenerate_target_files = False

# Fixed seed so the synthetic data below is identical on every run.
numpy.random.seed(1)

# Basis over 10 variables: a constant term, a pair of hinges on variable 1
# ('x1') at knot 0.1 (one normal, one reversed), and a linear term on
# variable 2 ('x2') whose parent is the first hinge.
basis = Basis(10)
constant = ConstantBasisFunction()
basis.append(constant)
bf1 = HingeBasisFunction(constant, 0.1, 10, 1, False, 'x1')
bf2 = HingeBasisFunction(constant, 0.1, 10, 1, True, 'x1')
bf3 = LinearBasisFunction(bf1, 2, 'x2')
basis.append(bf1)
basis.append(bf2)
basis.append(bf3)

# Synthetic regression problem: 1000 rows of 10 standard-normal features,
# no missing values, and a response that is a random linear combination of
# the 4 basis columns plus standard-normal noise.  The random draws happen
# in a fixed order (X, beta, noise) so the seeded data stays reproducible.
X = numpy.random.normal(size=(1000, 10))
missing = numpy.zeros_like(X, dtype=BOOL)
B = numpy.empty(shape=(1000, 4), dtype=numpy.float64)
basis.transform(X, missing, B)  # fills the preallocated B in place
beta = numpy.random.normal(size=4)
y = numpy.dot(B, beta) + numpy.random.normal(size=1000)
default_params = {"penalty": 1}

Beispiel #3
0
 def __init__(self):
     """Set up a constant root with a single non-reversed hinge under it."""
     super(Container, self).__init__()
     root = ConstantBasisFunction()
     self.parent = root
     self.bf = HingeBasisFunction(root, 1.0, 10, 1, False)
Beispiel #4
0
 def __init__(self):
     """Set up constant -> missingness -> hinge, a three-level chain."""
     super(Container, self).__init__()
     root = ConstantBasisFunction()
     missingness = MissingnessBasisFunction(root, 1, True)
     self.parent = root
     self.bf = missingness
     # The hinge's parent is the missingness function, not the root.
     self.child = HingeBasisFunction(missingness, 1.0, 10, 1, False)
Beispiel #5
0
 def __init__(self):
     """Set up a constant root with a single non-reversed hinge under it."""
     # NOTE(review): super(self.__class__, self) recurses forever if this
     # class is ever subclassed; it should name the enclosing class
     # explicitly (not visible from here).
     super(self.__class__, self).__init__()
     root = ConstantBasisFunction()
     self.parent = root
     self.bf = HingeBasisFunction(root, 1.0, 10, 1, False)
Beispiel #6
0
class TestHingeBasisFunction(BaseTestClass):
    """Tests for HingeBasisFunction with a ConstantBasisFunction parent.

    The fixture hinge has knot 1.0 (knot index 10) on variable 1 and is not
    reversed, so on column 1 of ``self.X`` it evaluates to
    ``max(x - 1.0, 0)``.
    """

    def __init__(self):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this test class is subclassed.
        super(TestHingeBasisFunction, self).__init__()
        self.parent = ConstantBasisFunction()
        self.bf = HingeBasisFunction(self.parent, 1.0, 10, 1, False)

    def test_getters(self):
        """Accessors return the constructor arguments."""
        assert_true(not self.bf.get_reverse())
        assert_equal(self.bf.get_knot(), 1.0)
        assert_equal(self.bf.get_variable(), 1)
        assert_equal(self.bf.get_knot_idx(), 10)
        assert_equal(self.bf.get_parent(), self.parent)

    def test_apply(self):
        """apply() writes max(x1 - knot, 0) into the output column."""
        m, _ = self.X.shape
        B = numpy.ones(shape=(m, 10))
        # Output goes into column 0 of a larger buffer (a strided view).
        self.bf.apply(self.X, B[:, 0])
        numpy.testing.assert_almost_equal(
            B[:, 0],
            (self.X[:, 1] - 1.0) * (self.X[:, 1] > 1.0)
        )

    def test_apply_deriv(self):
        """apply_deriv() fills both the value and its derivative."""
        m, _ = self.X.shape
        b = numpy.empty(shape=m)
        j = numpy.empty(shape=m)
        # NOTE(review): the trailing 1 presumably selects variable 1 as the
        # differentiation variable -- confirm against the apply_deriv API.
        self.bf.apply_deriv(self.X, b, j, 1)
        numpy.testing.assert_almost_equal(
            (self.X[:, 1] - 1.0) * (self.X[:, 1] > 1.0),
            b
        )
        # The hinge's slope is 1 above the knot and 0 below it.
        numpy.testing.assert_almost_equal(1.0 * (self.X[:, 1] > 1.0),
                                          j)

    def test_degree(self):
        """A hinge directly under the constant has degree 1."""
        assert_equal(self.bf.degree(), 1)

    def test_pickle_compatibility(self):
        """A pickle round-trip yields an equal basis function."""
        bf_copy = pickle.loads(pickle.dumps(self.bf))
        assert_equal(self.bf, bf_copy)

    def test_smoothed_version(self):
        """The smoothed hinge keeps knot/variable and uses the given bounds."""
        knot_dict = {self.bf: (.5, 1.5)}
        # Map the original parent to its smoothed counterpart so the smoothed
        # hinge can be re-parented onto it.
        translation = {self.parent: self.parent._smoothed_version(None, {}, {})}
        smoothed = self.bf._smoothed_version(self.parent,
                                             knot_dict, translation)
        assert_true(type(smoothed) is SmoothedHingeBasisFunction)
        assert_true(translation[self.parent] is smoothed.get_parent())
        assert_equal(smoothed.get_knot_minus(), 0.5)
        assert_equal(smoothed.get_knot_plus(), 1.5)
        assert_equal(smoothed.get_knot(), self.bf.get_knot())
        assert_equal(smoothed.get_variable(), self.bf.get_variable())