Example #1
0
def test_adagrad_hinge_multiclass():
    """Hinge-loss AdaGrad on the multiclass data scores 0.960 (3 decimals)."""
    params = dict(alpha=1e-2, n_iter=100, loss="hinge", random_state=0)
    clf = AdaGradClassifier(**params)
    clf.fit(X, y)
    train_score = clf.score(X, y)
    assert_almost_equal(train_score, 0.960, 3)
Example #2
0
def test_adagrad_elastic_hinge():
    """Elastic-net AdaGrad with the default loss scores 1.0 on the binary data."""
    settings = {"alpha": 0.5, "l1_ratio": 0.85, "n_iter": 10, "random_state": 0}
    clf = AdaGradClassifier(**settings)
    clf.fit(X_bin, y_bin)
    score = clf.score(X_bin, y_bin)
    assert_equal(score, 1.0)
Example #3
0
def test_adagrad_elastic_hinge():
    """Elastic-net AdaGrad with the default loss exposes no predict_proba
    and scores 1.0 on the binary training data."""
    clf = AdaGradClassifier(
        alpha=0.5, l1_ratio=0.85, n_iter=10, random_state=0,
    )
    clf.fit(X_bin, y_bin)
    # The default loss is expected not to provide probability estimates.
    assert not hasattr(clf, "predict_proba")
    accuracy = clf.score(X_bin, y_bin)
    assert accuracy == 1.0
Example #4
0
def test_adagrad_hinge_multiclass():
    """Hinge-loss AdaGrad on multiclass data: no predict_proba, and a
    training score of 0.940 to 3 decimals."""
    model = AdaGradClassifier(loss="hinge", alpha=1e-2, n_iter=100,
                              random_state=0)
    model.fit(X, y)
    assert not hasattr(model, "predict_proba")
    observed = model.score(X, y)
    np.testing.assert_almost_equal(observed, 0.940, 3)
Example #5
0
def test_adagrad_elastic_log():
    """Log-loss elastic-net AdaGrad scores 1.0 on the binary data and its
    predict_proba passes the shared probability checks."""
    model = AdaGradClassifier(loss="log", alpha=0.1, l1_ratio=0.85,
                              n_iter=10, random_state=0)
    model.fit(X_bin, y_bin)
    assert model.score(X_bin, y_bin) == 1.0
    # check_predict_proba is a helper defined outside this block.
    check_predict_proba(model, X_bin)
Example #6
0
def test_adagrad_elastic_smooth_hinge(bin_train_data):
    """Smooth-hinge elastic-net AdaGrad exposes no predict_proba and scores
    1.0 on the binary fixture data."""
    features, labels = bin_train_data
    model = AdaGradClassifier(loss="smooth_hinge", alpha=0.5, l1_ratio=0.85,
                              n_iter=10, random_state=0)
    model.fit(features, labels)
    assert not hasattr(model, "predict_proba")
    assert model.score(features, labels) == 1.0
def test_adagrad_callback():
    """AdaGrad invokes the user callback during fit, and the score recorded
    at the last invocation is 1.0 on the binary data."""

    class Callback(object):
        """Record the score on (X, y) each time AdaGrad calls back."""

        def __init__(self, X, y):
            self.X = X
            self.y = y
            self.acc = []

        def __call__(self, clf, t):
            # Split alpha into its L1 / L2 parts according to l1_ratio.
            alpha1 = clf.l1_ratio * clf.alpha
            alpha2 = (1 - clf.l1_ratio) * clf.alpha
            # Project coef_[0] using the accumulated gradient statistics
            # before scoring.
            # NOTE(review): assumes _proj_elastic_all updates clf.coef_[0]
            # in place — confirm against its definition.
            _proj_elastic_all(clf.eta, t, clf.g_sum_[0], clf.g_norms_[0],
                              alpha1, alpha2, 0, clf.coef_[0])
            score = clf.score(self.X, self.y)
            self.acc.append(score)

    cb = Callback(X_bin, y_bin)
    clf = AdaGradClassifier(alpha=0.5, l1_ratio=0.85, n_iter=10, callback=cb,
                            random_state=0)
    clf.fit(X_bin, y_bin)
    # Fail with a clear message (not an IndexError) if the callback never ran.
    assert cb.acc, "callback was never invoked during fit"
    assert_equal(cb.acc[-1], 1.0)
Example #8
0
def test_adagrad_callback():
    """AdaGrad invokes the user callback during fit, and the score recorded
    at the last invocation is 1.0 on the binary data."""

    class Callback(object):
        """Record the score on (X, y) each time AdaGrad calls back."""

        def __init__(self, X, y):
            self.X = X
            self.y = y
            self.acc = []

        def __call__(self, clf, t):
            # Split alpha into its L1 / L2 parts according to l1_ratio.
            alpha1 = clf.l1_ratio * clf.alpha
            alpha2 = (1 - clf.l1_ratio) * clf.alpha
            # Project coef_[0] using the accumulated gradient statistics
            # before scoring.
            # NOTE(review): assumes _proj_elastic_all updates clf.coef_[0]
            # in place — confirm against its definition.
            _proj_elastic_all(clf.eta, t, clf.g_sum_[0], clf.g_norms_[0],
                              alpha1, alpha2, 0, clf.coef_[0])
            score = clf.score(self.X, self.y)
            self.acc.append(score)

    cb = Callback(X_bin, y_bin)
    clf = AdaGradClassifier(alpha=0.5,
                            l1_ratio=0.85,
                            n_iter=10,
                            callback=cb,
                            random_state=0)
    clf.fit(X_bin, y_bin)
    # Fail with a clear message (not an IndexError) if the callback never ran.
    assert cb.acc, "callback was never invoked during fit"
    assert_equal(cb.acc[-1], 1.0)
Example #9
0
def test_lightning_binary_class():
    """The assembled binary AdaGradClassifier matches feature_i * coef_i terms
    (pinned coefficients below) combined with ADD from NumVal(0.0)."""
    estimator = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # Expected coefficient for each feature, in feature order 0..29.
    coefficients = [
        0.1605265174, 0.1045225083, 0.6237391536, 0.1680225811,
        0.0011013688, -0.0027528486, -0.0058878714, -0.0023719811,
        0.0019944105, 0.0009924456, 0.0003994860, 0.0124697033,
        -0.0123674096, -0.2844204905, 0.0000273704, -0.0007498013,
        -0.0010784399, -0.0001848694, 0.0000632254, -0.0000369618,
        0.1520223021, 0.0925348635, 0.4861047372, -0.2798670185,
        0.0009925506, -0.0103414976, -0.0155024577, -0.0038881538,
        0.0010126166, 0.0002312558,
    ]
    feature_weight_mul = [
        ast.BinNumExpr(ast.FeatureRef(idx), ast.NumVal(coef),
                       ast.BinNumOpType.MUL)
        for idx, coef in enumerate(coefficients)
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, ast.NumVal(0.0), *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
Example #10
0
def test_adagrad_elastic_hinge():
    """Default-loss elastic-net AdaGrad scores 1.0 on the binary training data."""
    classifier = AdaGradClassifier(
        alpha=0.5,
        l1_ratio=0.85,
        n_iter=10,
        random_state=0,
    )
    classifier.fit(X_bin, y_bin)
    assert_equal(classifier.score(X_bin, y_bin), 1.0)
Example #11
0
def test_adagrad_classes_multiclass(train_data):
    """classes_ is absent before fitting and is [0, 1, 2] afterwards."""
    features, labels = train_data
    model = AdaGradClassifier()
    assert not hasattr(model, 'classes_')
    model.fit(features, labels)
    fitted_classes = list(model.classes_)
    assert fitted_classes == [0, 1, 2]
Example #12
0
def test_lightning_multi_class():
    """The assembled multiclass AdaGradClassifier is a VectorVal holding, per
    class, the left-nested chain ((((0.0 + f0*w0) + f1*w1) + f2*w2) + f3*w3)
    with the pinned weights below."""
    estimator = AdaGradClassifier(random_state=1)
    utils.get_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    def linear_chain(weights):
        # Left fold: start at 0.0 and ADD one feature*weight term at a time.
        acc = ast.NumVal(0.0)
        for idx, weight in enumerate(weights):
            term = ast.BinNumExpr(ast.FeatureRef(idx), ast.NumVal(weight),
                                  ast.BinNumOpType.MUL)
            acc = ast.BinNumExpr(acc, term, ast.BinNumOpType.ADD)
        return acc

    # One row of expected weights per class, in feature order 0..3.
    weight_rows = [
        [0.0895848274, 0.3258329434, -0.4900856238, -0.2214482506],
        [-0.1074247041, -0.1693225196, 0.0357417324, -0.0161614171],
        [-0.1825063678, -0.2185655665, 0.3053017646, 0.2175198459],
    ]
    expected = ast.VectorVal([linear_chain(row) for row in weight_rows])

    assert utils.cmp_exprs(actual, expected)
def test_adagrad_elastic_smooth_hinge():
    """Smooth-hinge elastic-net AdaGrad exposes no predict_proba and scores
    1.0 on the binary training data."""
    model = AdaGradClassifier(
        loss="smooth_hinge",
        alpha=0.5,
        l1_ratio=0.85,
        n_iter=10,
        random_state=0,
    )
    model.fit(X_bin, y_bin)
    assert not hasattr(model, "predict_proba")
    assert_equal(model.score(X_bin, y_bin), 1.0)
Example #14
0
def test_adagrad_classes_binary(bin_train_data):
    """classes_ is absent before fitting and is [-1, 1] on the binary data."""
    features, labels = bin_train_data
    model = AdaGradClassifier()
    assert not hasattr(model, 'classes_')
    model.fit(features, labels)
    fitted_classes = list(model.classes_)
    assert fitted_classes == [-1, 1]
Example #15
0
def test_lightning_multi_class():
    """The assembled multiclass AdaGradClassifier is a VectorVal with one
    left-nested ADD chain per class, seeded with 0.0 and adding
    FeatureRef(i) * weight_i for features 0..3 (pinned weights below)."""
    estimator = AdaGradClassifier(random_state=1)
    utils.get_classification_model_trainer()(estimator)

    assembler = assemblers.SklearnLinearModelAssembler(estimator)
    actual = assembler.assemble()

    # Rows are classes; columns are the weights for features 0..3.
    class_weights = (
        (0.0935146297, 0.3213921354, -0.4855914264, -0.2214295302),
        (-0.1103262586, -0.1662457692, 0.0379823341, -0.0128634938),
        (-0.1685751402, -0.2045901693, 0.2932121798, 0.2138148665),
    )

    per_class = []
    for weights in class_weights:
        node = ast.NumVal(0.0)
        for feature_idx, w in enumerate(weights):
            product = ast.BinNumExpr(ast.FeatureRef(feature_idx),
                                     ast.NumVal(w), ast.BinNumOpType.MUL)
            node = ast.BinNumExpr(node, product, ast.BinNumOpType.ADD)
        per_class.append(node)
    expected = ast.VectorVal(per_class)

    assert utils.cmp_exprs(actual, expected)
Example #16
0
def test_lightning_multi_class():
    """The assembled multiclass AdaGradClassifier is a VectorVal with one
    left-nested ADD chain per class, seeded with 0.0 and adding
    FeatureRef(i) * weight_i for features 0..3 (full-precision weights)."""
    estimator = AdaGradClassifier(random_state=1)
    utils.get_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    def chain_for(weights):
        # Build ((((0.0 + f0*w0) + f1*w1) + f2*w2) + f3*w3).
        expr = ast.NumVal(0.0)
        for pos, weight in enumerate(weights):
            expr = ast.BinNumExpr(
                expr,
                ast.BinNumExpr(
                    ast.FeatureRef(pos),
                    ast.NumVal(weight),
                    ast.BinNumOpType.MUL),
                ast.BinNumOpType.ADD)
        return expr

    weight_rows = [
        [0.09686334892116512, 0.32572202133211947,
         -0.48444233646554424, -0.219719145605816],
        [-0.1089136473832088, -0.16956003333433572,
         0.0365531256007199, -0.01016100116780896],
        [-0.16690339219780817, -0.19466284646233858,
         0.2953585236360389, 0.21288203082531384],
    ]
    expected = ast.VectorVal([chain_for(row) for row in weight_rows])

    assert utils.cmp_exprs(actual, expected)
Example #17
0
def test_adagrad_classes_multiclass():
    """classes_ does not exist before fitting and equals [0, 1, 2] after."""
    estimator = AdaGradClassifier()
    assert not hasattr(estimator, 'classes_')
    estimator.fit(X, y)
    learned = list(estimator.classes_)
    assert_equal(learned, [0, 1, 2])
def test_adagrad_elastic_log():
    """Log-loss elastic-net AdaGrad scores 1.0 on the binary data and its
    predict_proba passes the shared probability checks."""
    estimator = AdaGradClassifier(
        loss="log",
        alpha=0.1,
        l1_ratio=0.85,
        n_iter=10,
        random_state=0,
    )
    estimator.fit(X_bin, y_bin)
    assert_equal(estimator.score(X_bin, y_bin), 1.0)
    # check_predict_proba is a helper defined outside this block.
    check_predict_proba(estimator, X_bin)
def test_adagrad_classes_multiclass():
    """An unfitted classifier has no classes_; after fitting on the
    multiclass data, classes_ is [0, 1, 2]."""
    model = AdaGradClassifier()
    assert not hasattr(model, "classes_")
    model.fit(X, y)
    observed = list(model.classes_)
    assert_equal(observed, [0, 1, 2])
def test_adagrad_classes_binary():
    """An unfitted classifier has no classes_; after fitting on the binary
    data, classes_ is [-1, 1]."""
    model = AdaGradClassifier()
    assert not hasattr(model, "classes_")
    model.fit(X_bin, y_bin)
    observed = list(model.classes_)
    assert_equal(observed, [-1, 1])
def test_adagrad_hinge_multiclass():
    """Hinge-loss AdaGrad on multiclass data: no predict_proba, and a
    training score of 0.960 to 3 decimals."""
    model = AdaGradClassifier(
        loss="hinge",
        alpha=1e-2,
        n_iter=100,
        random_state=0,
    )
    model.fit(X, y)
    assert not hasattr(model, "predict_proba")
    assert_almost_equal(model.score(X, y), 0.960, 3)
Example #22
0
def test_adagrad_classes_binary():
    """classes_ is missing until fit; on the binary data it becomes [-1, 1]."""
    estimator = AdaGradClassifier()
    assert not hasattr(estimator, 'classes_')
    estimator.fit(X_bin, y_bin)
    learned = list(estimator.classes_)
    assert_equal(learned, [-1, 1])
Example #23
0
# Configure the squared-hinge solvers being compared.
# n_calls is presumably the callback frequency (an iteration count), so it
# is computed with floor division: on Python 3, ``/`` always yields a float,
# which is not a valid count.
clf2 = SDCAClassifier(loss="squared_hinge",
                      alpha=alpha,
                      max_iter=100,
                      n_calls=X.shape[0] // 2,
                      random_state=0,
                      tol=tol)
clf3 = CDClassifier(loss="squared_hinge",
                    alpha=alpha,
                    C=1.0 / X.shape[0],
                    max_iter=50,
                    n_calls=X.shape[1] // 3,
                    random_state=0,
                    tol=tol)
clf4 = AdaGradClassifier(loss="squared_hinge",
                         alpha=alpha,
                         eta=eta_adagrad,
                         n_iter=100,
                         n_calls=X.shape[0] // 2,
                         random_state=0)
clf5 = SAGAClassifier(loss="squared_hinge",
                      alpha=alpha,
                      max_iter=100,
                      random_state=0,
                      tol=tol)
clf6 = SAGClassifier(loss="squared_hinge",
                     alpha=alpha,
                     max_iter=100,
                     random_state=0,
                     tol=tol)

plt.figure()
Example #24
0
def test_lightning_binary_class():
    """The assembled binary AdaGradClassifier matches feature_i * coef_i terms
    (full-precision pinned coefficients below) combined with ADD from
    NumVal(0.0)."""
    estimator = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(estimator)

    assembler = assemblers.SklearnLinearModelAssembler(estimator)
    actual = assembler.assemble()

    # Expected coefficient for each feature, in feature order 0..29.
    coefficients = [
        0.16218889967390476,
        0.10012761963766906,
        0.6289276652681673,
        0.17618420156072845,
        0.0010492096607182045,
        -0.0029135563693806913,
        -0.005923882409142498,
        -0.0023293599172479755,
        0.0020808828960210517,
        0.0009846430705550103,
        0.0010399810925427265,
        0.011203056917272093,
        -0.007271351370867731,
        -0.26333437096804224,
        1.8533543368532444e-05,
        -0.0008266341686278445,
        -0.0011090316301215724,
        -0.0001910857095336291,
        0.00010735116208006556,
        -4.076097659514017e-05,
        0.15300712110146406,
        0.06316277258339074,
        0.495291178977687,
        -0.29589136204657845,
        0.000771932729567487,
        -0.011877978242492428,
        -0.01678004536869617,
        -0.004070431062579625,
        0.001158641497209262,
        0.00010737287732588742,
    ]
    feature_weight_mul = []
    for feature_idx, coef in enumerate(coefficients):
        feature_weight_mul.append(
            ast.BinNumExpr(
                ast.FeatureRef(feature_idx),
                ast.NumVal(coef),
                ast.BinNumOpType.MUL))

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD,
        ast.NumVal(0.0),
        *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
Example #25
0
def test_lightning_binary_class():
    """The assembled binary AdaGradClassifier matches feature_i * coef_i terms
    (pinned coefficients below) combined with ADD from NumVal(0.0)."""
    estimator = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # Expected coefficient for each feature, in feature order 0..29.
    coefficients = (
        0.1617602138, 0.0931034793, 0.6279180888, 0.1856722189,
        0.0009999878, -0.0028974470, -0.0059948515, -0.0024173728,
        0.0020429247, 0.0009604400, 0.0010933747, 0.0078588761,
        -0.0069150246, -0.2583249885, 0.0000097479, -0.0007210600,
        -0.0011295195, -0.0001966115, 0.0001358314, -0.0000378118,
        0.1555921773, 0.0621307817, 0.5138354949, -0.2418579612,
        0.0007953821, -0.0110760214, -0.0162178044, -0.0040277699,
        0.0015067033, 0.0001536614,
    )
    feature_weight_mul = [
        ast.BinNumExpr(ast.FeatureRef(i), ast.NumVal(w), ast.BinNumOpType.MUL)
        for i, w in enumerate(coefficients)
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, ast.NumVal(0.0), *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)