Example #1
0
def test_multi_class():
    """Check assembly of a LogisticRegression with hand-set 3x2 coefficients.

    ``coef_`` and ``intercept_`` are assigned directly instead of being
    fitted, so the expected result is one
    ``(intercept + x0 * w0) + x1 * w1`` expression per row of ``coef_``,
    wrapped in an ``ast.VectorVal``.
    """
    estimator = linear_model.LogisticRegression()
    estimator.coef_ = np.array([[1, 2], [3, 4], [5, 6]])
    estimator.intercept_ = np.array([7, 8, 9])

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    def linear_expr(bias, weights):
        # Left fold: ((bias + x0 * w0) + x1 * w1) + ...
        expr = ast.NumVal(bias)
        for idx, weight in enumerate(weights):
            term = ast.BinNumExpr(
                ast.FeatureRef(idx), ast.NumVal(weight),
                ast.BinNumOpType.MUL)
            expr = ast.BinNumExpr(expr, term, ast.BinNumOpType.ADD)
        return expr

    expected = ast.VectorVal([
        linear_expr(7, [1, 2]),
        linear_expr(8, [3, 4]),
        linear_expr(9, [5, 6]),
    ])

    assert utils.cmp_exprs(actual, expected)
Example #2
0
def test_single_feature():
    """Check assembly of a LinearRegression with one hand-set coefficient.

    With ``coef_ = [1]`` and ``intercept_ = [3]`` the expected expression
    is ``3 + x0 * 1``.
    """
    estimator = linear_model.LinearRegression()
    estimator.coef_ = np.array([1])
    estimator.intercept_ = np.array([3])

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # Build the expected tree piece by piece: intercept first, then the
    # single weighted feature, then their sum.
    bias = ast.NumVal(3)
    weighted_feature = ast.BinNumExpr(
        ast.FeatureRef(0), ast.NumVal(1), ast.BinNumOpType.MUL)
    expected = ast.BinNumExpr(bias, weighted_feature, ast.BinNumOpType.ADD)

    assert utils.cmp_exprs(actual, expected)
Example #3
0
def test_lightning_regression():
    """Check assembly of an AdaGradRegressor against 13 recorded weights.

    The estimator is handed to the callable returned by
    ``utils.get_regression_model_trainer()`` (presumably this fits it --
    confirm against the helper), then the assembled expression is compared
    with ``0.0 + sum(x_i * w_i)`` built from the hard-coded weights below.
    """
    estimator = AdaGradRegressor(random_state=1)
    utils.get_regression_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # Expected per-feature weights, indexed by feature position.
    weights = [
        -0.0610645819,
        0.0856563713,
        -0.0562044566,
        0.2804204925,
        0.1359261760,
        1.6307305501,
        0.0866147265,
        -0.0726894150,
        0.0435440193,
        -0.0077364839,
        0.2902775116,
        0.0229879957,
        -0.7614706871,
    ]
    feature_weight_mul = [
        ast.BinNumExpr(
            ast.FeatureRef(idx), ast.NumVal(weight), ast.BinNumOpType.MUL)
        for idx, weight in enumerate(weights)
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, ast.NumVal(0.0), *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
Example #4
0
def test_lightning_regression():
    """Compare an assembled AdaGradRegressor with full-precision weights.

    The estimator goes through the callable returned by
    ``utils.get_regression_model_trainer()`` (presumably a fit step --
    confirm), and the assembler output must match a zero intercept plus
    the 13 weighted feature terms listed here.
    """
    estimator = AdaGradRegressor(random_state=1)
    utils.get_regression_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    expected_weights = (
        -0.08558826944690746,
        0.0803724696787377,
        -0.03516743076774846,
        0.26469178947134087,
        0.15651985221012488,
        1.5186399078028587,
        0.10089874009193693,
        -0.011426237067026246,
        0.0861987777487293,
        -0.0057791506839322574,
        0.3357752757550913,
        0.020189965076849486,
        -0.7390647599317609,
    )

    # One ``x_i * w_i`` node per feature, in feature order.
    feature_weight_mul = []
    for position, weight in enumerate(expected_weights):
        feature_weight_mul.append(
            ast.BinNumExpr(
                ast.FeatureRef(position),
                ast.NumVal(weight),
                ast.BinNumOpType.MUL))

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, ast.NumVal(0.0), *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
Example #5
0
def test_lightning_regression():
    """Verify the expression assembled from an AdaGradRegressor.

    After the estimator is passed to the callable from
    ``utils.get_regression_model_trainer()`` (assumed to train it --
    confirm), the assembled AST must equal ``0.0`` combined via ADD with
    13 ``FeatureRef * NumVal`` terms using the recorded weights.
    """
    estimator = AdaGradRegressor(random_state=1)
    utils.get_regression_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    mul = ast.BinNumOpType.MUL
    recorded_weights = (
        -0.0961163452,
        0.1574398180,
        -0.0251799219,
        0.1975142192,
        0.1189621635,
        1.2977018274,
        0.1192977978,
        0.0331955333,
        0.1433964513,
        0.0014943531,
        0.3116036672,
        0.0258421936,
        -0.7386996349,
    )
    feature_weight_mul = [
        ast.BinNumExpr(ast.FeatureRef(i), ast.NumVal(w), mul)
        for i, w in enumerate(recorded_weights)
    ]

    intercept = ast.NumVal(0.0)
    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, intercept, *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
Example #6
0
def test_lightning_multi_class():
    """Check assembly of an AdaGradClassifier with three output expressions.

    The estimator is passed to the callable returned by
    ``utils.get_classification_model_trainer()`` (presumably fitting it --
    confirm). The expected result is an ``ast.VectorVal`` holding, per
    class, the left-nested sum ``(((0.0 + x0*w0) + x1*w1) + x2*w2) + x3*w3``.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # One row of four feature weights per class.
    class_weights = [
        [0.09686334892116512,
         0.32572202133211947,
         -0.48444233646554424,
         -0.219719145605816],
        [-0.1089136473832088,
         -0.16956003333433572,
         0.0365531256007199,
         -0.01016100116780896],
        [-0.16690339219780817,
         -0.19466284646233858,
         0.2953585236360389,
         0.21288203082531384],
    ]

    def linear_expr(weights):
        # Fold the weighted features onto a zero intercept, left to right.
        expr = ast.NumVal(0.0)
        for idx, weight in enumerate(weights):
            term = ast.BinNumExpr(
                ast.FeatureRef(idx),
                ast.NumVal(weight),
                ast.BinNumOpType.MUL)
            expr = ast.BinNumExpr(expr, term, ast.BinNumOpType.ADD)
        return expr

    expected = ast.VectorVal([linear_expr(row) for row in class_weights])

    assert utils.cmp_exprs(actual, expected)
Example #7
0
def test_lightning_binary_class():
    """Check assembly of an AdaGradClassifier producing a single expression.

    The estimator is passed to the callable returned by
    ``utils.get_binary_classification_model_trainer()`` (presumably fitting
    it -- confirm). The assembled AST must match a ``0.0`` intercept
    combined via ADD with 30 weighted feature terms.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # Expected weight for each of the 30 features, in feature order.
    weights = [
        0.16218889967390476,
        0.10012761963766906,
        0.6289276652681673,
        0.17618420156072845,
        0.0010492096607182045,
        -0.0029135563693806913,
        -0.005923882409142498,
        -0.0023293599172479755,
        0.0020808828960210517,
        0.0009846430705550103,
        0.0010399810925427265,
        0.011203056917272093,
        -0.007271351370867731,
        -0.26333437096804224,
        1.8533543368532444e-05,
        -0.0008266341686278445,
        -0.0011090316301215724,
        -0.0001910857095336291,
        0.00010735116208006556,
        -4.076097659514017e-05,
        0.15300712110146406,
        0.06316277258339074,
        0.495291178977687,
        -0.29589136204657845,
        0.000771932729567487,
        -0.011877978242492428,
        -0.01678004536869617,
        -0.004070431062579625,
        0.001158641497209262,
        0.00010737287732588742,
    ]
    feature_weight_mul = [
        ast.BinNumExpr(
            ast.FeatureRef(idx),
            ast.NumVal(weight),
            ast.BinNumOpType.MUL)
        for idx, weight in enumerate(weights)
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD,
        ast.NumVal(0.0),
        *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
Example #8
0
def test_lightning_multi_class():
    """Compare an assembled AdaGradClassifier with three per-class sums.

    The estimator is given to the callable from
    ``utils.get_classification_model_trainer()`` (assumed to train it --
    confirm). Each expected vector element is a zero intercept with four
    weighted features added one at a time, nesting to the left.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    weight_rows = (
        (0.0895848274, 0.3258329434, -0.4900856238, -0.2214482506),
        (-0.1074247041, -0.1693225196, 0.0357417324, -0.0161614171),
        (-0.1825063678, -0.2185655665, 0.3053017646, 0.2175198459),
    )

    per_class_exprs = []
    for row in weight_rows:
        # Start from the zero intercept and add each weighted feature.
        total = ast.NumVal(0.0)
        for feature_idx, weight in enumerate(row):
            product = ast.BinNumExpr(
                ast.FeatureRef(feature_idx), ast.NumVal(weight),
                ast.BinNumOpType.MUL)
            total = ast.BinNumExpr(total, product, ast.BinNumOpType.ADD)
        per_class_exprs.append(total)

    expected = ast.VectorVal(per_class_exprs)

    assert utils.cmp_exprs(actual, expected)
Example #9
0
def test_lightning_binary_class():
    """Compare an assembled AdaGradClassifier with 30 recorded weights.

    The estimator is handed to the callable from
    ``utils.get_binary_classification_model_trainer()`` (presumably a fit
    step -- confirm). The expected AST is ``0.0`` combined via ADD with
    one ``FeatureRef * NumVal`` term per weight.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    expected_weights = (
        0.1605265174,
        0.1045225083,
        0.6237391536,
        0.1680225811,
        0.0011013688,
        -0.0027528486,
        -0.0058878714,
        -0.0023719811,
        0.0019944105,
        0.0009924456,
        0.0003994860,
        0.0124697033,
        -0.0123674096,
        -0.2844204905,
        0.0000273704,
        -0.0007498013,
        -0.0010784399,
        -0.0001848694,
        0.0000632254,
        -0.0000369618,
        0.1520223021,
        0.0925348635,
        0.4861047372,
        -0.2798670185,
        0.0009925506,
        -0.0103414976,
        -0.0155024577,
        -0.0038881538,
        0.0010126166,
        0.0002312558,
    )

    # Build the weighted-feature nodes in feature order.
    feature_weight_mul = []
    for position, weight in enumerate(expected_weights):
        feature_weight_mul.append(
            ast.BinNumExpr(ast.FeatureRef(position), ast.NumVal(weight),
                           ast.BinNumOpType.MUL))

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, ast.NumVal(0.0), *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
Example #10
0
def test_lightning_multi_class():
    """Verify the vector expression assembled from an AdaGradClassifier.

    The estimator is passed to the callable returned by
    ``utils.get_classification_model_trainer()`` (presumably fitting it --
    confirm). The expected ``ast.VectorVal`` has three elements, each a
    left-nested sum of a ``0.0`` intercept and four weighted features.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    add, mul = ast.BinNumOpType.ADD, ast.BinNumOpType.MUL

    def weighted_sum(weights):
        # ((0.0 + x0*w0) + x1*w1) + ... accumulated left to right.
        acc = ast.NumVal(0.0)
        for i, w in enumerate(weights):
            acc = ast.BinNumExpr(
                acc,
                ast.BinNumExpr(ast.FeatureRef(i), ast.NumVal(w), mul),
                add)
        return acc

    expected = ast.VectorVal([
        weighted_sum(
            [0.0935146297, 0.3213921354, -0.4855914264, -0.2214295302]),
        weighted_sum(
            [-0.1103262586, -0.1662457692, 0.0379823341, -0.0128634938]),
        weighted_sum(
            [-0.1685751402, -0.2045901693, 0.2932121798, 0.2138148665]),
    ])

    assert utils.cmp_exprs(actual, expected)
Example #11
0
def test_lightning_binary_class():
    """Verify the expression assembled from a two-class AdaGradClassifier run.

    The estimator is passed to the callable returned by
    ``utils.get_binary_classification_model_trainer()`` (assumed to train
    it -- confirm). The assembled AST must equal a ``0.0`` intercept
    combined via ADD with the 30 weighted feature terms below.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    mul = ast.BinNumOpType.MUL
    recorded_weights = [
        0.1617602138,
        0.0931034793,
        0.6279180888,
        0.1856722189,
        0.0009999878,
        -0.0028974470,
        -0.0059948515,
        -0.0024173728,
        0.0020429247,
        0.0009604400,
        0.0010933747,
        0.0078588761,
        -0.0069150246,
        -0.2583249885,
        0.0000097479,
        -0.0007210600,
        -0.0011295195,
        -0.0001966115,
        0.0001358314,
        -0.0000378118,
        0.1555921773,
        0.0621307817,
        0.5138354949,
        -0.2418579612,
        0.0007953821,
        -0.0110760214,
        -0.0162178044,
        -0.0040277699,
        0.0015067033,
        0.0001536614,
    ]
    feature_weight_mul = [
        ast.BinNumExpr(ast.FeatureRef(i), ast.NumVal(w), mul)
        for i, w in enumerate(recorded_weights)
    ]

    intercept = ast.NumVal(0.0)
    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, intercept, *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)