예제 #1
0
 def test_from_dict_nofit(self):
     """Rebuild an unfitted RBF from a dict for every kernel/tail pairing.

     Each kernel class in ``kernels_dict`` is instantiated with parameter 2
     and each tail class in ``tails_dict`` with no arguments. The dict
     handed to ``RBF.from_dict`` holds only ``eta``, ``kernel`` and ``tail``.

     When the tail degree is at least the kernel's ``dmin``, the restored
     model must carry the same eta, a kernel of the same class with
     ``param == 2`` and a tail of the same class. Otherwise ``from_dict``
     is expected to raise ``ValueError``.
     """
     expected_eta = 1
     for kernel_cls in kernels_dict.values():
         kernel_obj = kernel_cls(2)
         for tail_cls in tails_dict.values():
             tail_obj = tail_cls()
             kernel_spec = kernel_obj.to_dict()
             tail_spec = tail_obj.to_dict()
             payload = {
                 "eta": expected_eta,
                 "kernel": kernel_spec,
                 "tail": tail_spec,
             }
             tail_is_sufficient = tail_obj.degree >= kernel_obj.dmin

             if not tail_is_sufficient:
                 # Incompatible pairing: construction itself must fail.
                 with pytest.raises(ValueError):
                     RBF.from_dict(payload)
                 continue

             restored = RBF.from_dict(payload)
             assert restored._eta == expected_eta
             assert isinstance(restored._kernel, kernel_cls)
             assert restored._kernel.param == 2
             assert isinstance(restored._tail, tail_cls)
예제 #2
0
 def test_from_dict_fit_1d(self):
     """Rebuild a fitted 1-D RBF from a dict and compare it to the original.

     For every kernel/tail pairing whose tail degree is at least the
     kernel's ``dmin``, a reference model is fitted on 8 samples of
     ``func_1d`` over ``[0, 2*pi]``. Its state is exported into a plain
     dict (lists and nested dicts) and fed to ``RBF.from_dict``. The
     restored model must match the reference in data, eta, kernel, solver
     arrays, leave-one-out residuals and tail. Incompatible pairings are
     skipped.
     """
     n_samples = 8
     samples = np.linspace(0, 2 * np.pi, num=n_samples)
     targets = func_1d(samples)
     eta = 1e-3
     p = 3
     for kernel_cls in kernels_dict.values():
         kernel_obj = kernel_cls(2)
         for tail_cls in tails_dict.values():
             tail_obj = tail_cls()
             if not (tail_obj.degree >= kernel_obj.dmin):
                 continue

             reference = RBF(kernel_obj, tail_obj, eta)
             reference.fit(samples, targets, p)

             # The kernel is serialized only after fit(); the check below
             # expects its param to equal p, so fit() presumably updates the
             # kernel in place -- TODO confirm against RBF.fit.
             snapshot = {
                 "X": samples.reshape((-1, 1)).tolist(),
                 "y": targets.tolist(),
                 "eta": eta,
                 "kernel": kernel_obj.to_dict(),
                 "lambda": reference._lambda.tolist(),
                 "LU": reference._LU.tolist(),
                 "piv": reference._piv.tolist(),
                 "loo": reference._loo_residuals.tolist(),
                 "tail": reference._tail.to_dict(),
             }
             restored = RBF.from_dict(snapshot)

             # Training data comes back as a column matrix and a flat vector.
             assert restored._X.shape == (n_samples, 1)
             assert np.allclose(restored._X.ravel(), samples)
             assert restored._y.shape == (n_samples,)
             assert np.allclose(restored._y, targets)

             assert restored._eta == eta
             assert isinstance(restored._kernel, kernel_cls)
             assert restored._kernel.param == p

             # Array-valued state must agree with the reference model.
             for attr in ("_lambda", "_LU", "_piv", "_loo_residuals"):
                 assert np.allclose(getattr(restored, attr),
                                    getattr(reference, attr))

             reference_tail = reference._tail
             assert isinstance(restored._tail, reference_tail.__class__)
             assert np.all(restored._tail.params == reference_tail.params)