コード例 #1
0
ファイル: saga_test.py プロジェクト: rafael-glima/tick
    def test_variance_reduction_setting(self):
        """...Test SAGA variance_reduction parameter is correctly set

        Checks that the Python-level ``variance_reduction`` attribute and the
        value reported by the wrapped ``_solver`` stay in sync: for the
        default ('last'), for a value passed to the constructor, and for
        values reassigned afterwards. An unknown name must raise
        ``ValueError``.
        """
        def check(solver, name, expected):
            # The public attribute and the wrapped solver must agree.
            self.assertEqual(solver.variance_reduction, name)
            self.assertEqual(solver._solver.get_variance_reduction(), expected,
                             msg="variance_reduction=%r" % name)

        # Default value.
        saga = SAGA()
        check(saga, 'last', _SAGA.VarianceReductionMethod_Last)

        # Value given at construction time.
        saga = SAGA(variance_reduction='rand')
        check(saga, 'rand', _SAGA.VarianceReductionMethod_Random)

        # Values reassigned after construction; order matters because the
        # same instance is mutated each time.
        for name, expected in (
                ('avg', _SAGA.VarianceReductionMethod_Average),
                ('rand', _SAGA.VarianceReductionMethod_Random),
                ('last', _SAGA.VarianceReductionMethod_Last),
        ):
            saga.variance_reduction = name
            check(saga, name, expected)

        with self.assertRaises(ValueError):
            saga.variance_reduction = 'wrong_name'
コード例 #2
0
    def test_variance_reduction_setting(self):
        """...Test SAGA variance_reduction parameter is correctly set

        Fits a logistic-regression model on simulated data and attaches it to
        SAGA solvers cast to ``self.dtype``. It then checks that the
        Python-level ``variance_reduction`` attribute and the value reported
        by the wrapped ``_solver`` stay in sync: for the default ('last'), for
        a value passed to the constructor, and for values reassigned
        afterwards. An unknown name must raise ``ValueError``.
        """
        coeffs0 = weights_sparse_gauss(20, nnz=5, dtype=self.dtype)
        interc0 = None

        X, y = SimuLogReg(coeffs0,
                          interc0,
                          n_samples=3000,
                          verbose=False,
                          seed=123,
                          dtype=self.dtype).simulate()
        model = ModelLogReg().fit(X, y)

        def make_solver(**kwargs):
            # NOTE(review): set_model/astype presumably build the
            # dtype-specific ``_solver`` inspected below. The return value of
            # astype is not used; the checks rely on ``solver`` itself being
            # updated -- TODO confirm astype mutates in place.
            solver = SAGA(**kwargs)
            solver.set_model(model)
            solver.astype(self.dtype)
            return solver

        def check(solver, name, expected):
            # The public attribute and the wrapped solver must agree.
            self.assertEqual(solver.variance_reduction, name)
            self.assertEqual(solver._solver.get_variance_reduction(), expected,
                             msg="variance_reduction=%r" % name)

        # Default value.
        check(make_solver(), 'last', SAGA_VarianceReductionMethod_Last)

        # Value given at construction time.
        saga = make_solver(variance_reduction='rand')
        check(saga, 'rand', SAGA_VarianceReductionMethod_Random)

        # Values reassigned after construction; order matters because the
        # same instance is mutated each time.
        for name, expected in (
                ('avg', SAGA_VarianceReductionMethod_Average),
                ('rand', SAGA_VarianceReductionMethod_Random),
                ('last', SAGA_VarianceReductionMethod_Last),
        ):
            saga.variance_reduction = name
            check(saga, name, expected)

        with self.assertRaises(ValueError):
            saga.variance_reduction = 'wrong_name'