コード例 #1
0
ファイル: svrg_test.py プロジェクト: tozammel/tick
    def test_variance_reduction_setting(self):
        """...Test that the SVRG variance_reduction attribute and the value
        reported by the wrapped ``_solver`` stay in sync, whether set through
        the constructor or the attribute setter
        """
        def check(instance, name, expected_method):
            # The Python-side attribute and the wrapped solver must agree.
            self.assertEqual(instance.variance_reduction, name)
            self.assertEqual(instance._solver.get_variance_reduction(),
                             expected_method)

        # Default construction
        check(SVRG(), 'last', _SVRG.VarianceReductionMethod_Last)

        # Value given to the constructor
        svrg = SVRG(variance_reduction='rand')
        check(svrg, 'rand', _SVRG.VarianceReductionMethod_Random)

        # Value changed afterwards through the attribute setter
        transitions = (
            ('avg', _SVRG.VarianceReductionMethod_Average),
            ('rand', _SVRG.VarianceReductionMethod_Random),
            ('last', _SVRG.VarianceReductionMethod_Last),
        )
        for name, expected_method in transitions:
            svrg.variance_reduction = name
            check(svrg, name, expected_method)

        # Unknown names are rejected by the setter
        with self.assertRaises(ValueError):
            svrg.variance_reduction = 'wrong_name'
コード例 #2
0
    def test_variance_reduction_setting(self):
        """...Test that SVRG variance_reduction parameter behaves correctly

        Covers the default value, constructor and setter round-trips against
        the wrapped ``_solver``, rejection of unknown names (constructor and
        setter), acceptance of the methods with dense and sparse models, and
        the single UserWarning emitted when 'avg' is set with a sparse model.
        """
        svrg = SVRG()
        self.assertEqual(svrg.variance_reduction, 'last')
        self.assertEqual(svrg._solver.get_variance_reduction(),
                         _SVRG.VarianceReductionMethod_Last)

        svrg = SVRG(variance_reduction='rand')
        self.assertEqual(svrg.variance_reduction, 'rand')
        self.assertEqual(svrg._solver.get_variance_reduction(),
                         _SVRG.VarianceReductionMethod_Random)

        svrg.variance_reduction = 'avg'
        self.assertEqual(svrg.variance_reduction, 'avg')
        self.assertEqual(svrg._solver.get_variance_reduction(),
                         _SVRG.VarianceReductionMethod_Average)

        svrg.variance_reduction = 'rand'
        self.assertEqual(svrg.variance_reduction, 'rand')
        self.assertEqual(svrg._solver.get_variance_reduction(),
                         _SVRG.VarianceReductionMethod_Random)

        svrg.variance_reduction = 'last'
        self.assertEqual(svrg.variance_reduction, 'last')
        self.assertEqual(svrg._solver.get_variance_reduction(),
                         _SVRG.VarianceReductionMethod_Last)

        # Unknown names must be rejected with the same message both at
        # construction time and through the setter.
        msg = ('^variance_reduction should be one of "avg, last, rand", '
               'got "stuff"$')
        with self.assertRaisesRegex(ValueError, msg):
            # Deliberately not bound to ``svrg``: the call is expected to
            # raise, and the instance configured above is reused below.
            SVRG(variance_reduction='stuff')
        with self.assertRaisesRegex(ValueError, msg):
            svrg.variance_reduction = 'stuff'

        # All three methods are accepted with a dense model; 'last' and
        # 'rand' are also accepted with a sparse one.
        X, y = self.simu_linreg_data()
        model_dense, model_spars = self.get_dense_and_sparse_linreg_model(X, y)
        try:
            svrg.set_model(model_dense)
            svrg.variance_reduction = 'avg'
            svrg.variance_reduction = 'last'
            svrg.variance_reduction = 'rand'
            svrg.set_model(model_spars)
            svrg.variance_reduction = 'last'
            svrg.variance_reduction = 'rand'
        except Exception as err:
            # Name the offending exception in the failure message rather
            # than reporting only a generic statement.
            self.fail('Setting variance_reduction in these cases should have '
                      'been ok, got {!r}'.format(err))

        # Setting 'avg' with a sparse model must emit exactly one
        # UserWarning carrying this message.
        msg = "'avg' variance reduction cannot be used with sparse datasets"
        with catch_warnings(record=True) as w:
            # 'always' prevents the warning from being suppressed as a
            # duplicate of an earlier identical one.
            simplefilter('always')
            svrg.set_model(model_spars)
            svrg.variance_reduction = 'avg'
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[0].category, UserWarning))
            self.assertEqual(str(w[0].message), msg)