Ejemplo n.º 1
0
 def test_circular_tuning_model(self):
     # Fit a "circular" tuning model to a one-record Series and compare the
     # fitted 'center' and 'spread' parameters with hard-coded expected values.
     response = array([1.5, 2.3, 6.2, 5.1, 3.4, 2.1])
     records = self.sc.parallelize([(1, response)])
     data = Series(records)

     # One value per response sample, passed to the model as its first argument.
     stimulus = array([-pi/2, -pi/3, -pi/4, pi/4, pi/3, pi/2])
     fitted = TuningModel.load(stimulus, "circular").fit(data)

     tolerance = 1E-4  # absolute tolerance, absorbs rounding errors

     # Check the center first; the spread is only collected if this passes.
     center = fitted.select('center').values().collect()[0]
     assert allclose(center, array([0.10692]), atol=tolerance)

     spread = fitted.select('spread').values().collect()[0]
     assert allclose(spread, array([1.61944]), atol=tolerance)
Ejemplo n.º 2
0
 def test_gaussian_tuning_model(self):
     # Fit a "gaussian" tuning model to a one-record Series and compare the
     # fitted 'center' and 'spread' parameters with hard-coded expected values.
     response = array([1.5, 2.3, 6.2, 5.1, 3.4, 2.1])
     records = self.sc.parallelize([(1, response)])
     data = Series(records)

     # One value per response sample, passed to the model as its first argument.
     stimulus = array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
     fitted = TuningModel.load(stimulus, "gaussian").fit(data)

     tolerance = 1E-4  # absolute tolerance, absorbs rounding errors

     # Check the center first; the spread is only collected if this passes.
     center = fitted.select('center').values().collect()[0]
     assert allclose(center, array([0.36262]), atol=tolerance)

     spread = fitted.select('spread').values().collect()[0]
     assert allclose(spread, array([0.01836]), atol=tolerance)
Ejemplo n.º 3
0
 def test_circularTuningModel(self):
     # Builds a single keyed record, fits the "circular" TuningModel to it,
     # and verifies the resulting 'center' and 'spread' against fixed values.
     record = (1, array([1.5, 2.3, 6.2, 5.1, 3.4, 2.1]))
     data = Series(self.sc.parallelize([record]))

     angles = array([-pi / 2, -pi / 3, -pi / 4, pi / 4, pi / 3, pi / 2])
     tuning = TuningModel.load(angles, "circular")
     estimates = tuning.fit(data)

     atol = 1E-4  # slack for rounding errors in the fitted parameters

     # Each parameter is collected right before its own assertion, so a
     # failing center check stops the test before the spread is fetched.
     center_values = estimates.select('center').values().collect()
     assert allclose(center_values[0], array([0.10692]), atol=atol)

     spread_values = estimates.select('spread').values().collect()
     assert allclose(spread_values[0], array([1.61944]), atol=atol)
Ejemplo n.º 4
0
 def test_gaussianTuningModel(self):
     # Builds a single keyed record, fits the "gaussian" TuningModel to it,
     # and verifies the resulting 'center' and 'spread' against fixed values.
     record = (1, array([1.5, 2.3, 6.2, 5.1, 3.4, 2.1]))
     data = Series(self.sc.parallelize([record]))

     positions = array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
     tuning = TuningModel.load(positions, "gaussian")
     estimates = tuning.fit(data)

     atol = 1E-4  # slack for rounding errors in the fitted parameters

     # Each parameter is collected right before its own assertion, so a
     # failing center check stops the test before the spread is fetched.
     center_values = estimates.select('center').values().collect()
     assert allclose(center_values[0], array([0.36262]), atol=atol)

     spread_values = estimates.select('spread').values().collect()
     assert allclose(spread_values[0], array([0.01836]), atol=atol)