예제 #1
0
def test_bin_class_sigmoid_output_transform():
    """LightGBM binary classifier trained with a custom ``sigmoid=0.5``.

    The single-tree raw score is scaled by 0.5 before the logistic transform,
    i.e. ``p = 1 / (1 + exp(0 - 0.5 * raw))``, and the assembled output is
    the vector ``[1 - p, p]``.
    """
    model = lightgbm.LGBMClassifier(n_estimators=1, random_state=1,
                                    max_depth=1, sigmoid=0.5)
    utils.get_binary_classification_model_trainer()(model)

    result = assemblers.LightGBMModelAssembler(model).assemble()

    # The only tree: a single split on feature 20.
    raw_score = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(20), ast.NumVal(16.795),
                     ast.CompOpType.GT),
        ast.NumVal(0.5500419366076967),
        ast.NumVal(1.2782342253678096))

    # Logistic transform with the configured sigmoid scale of 0.5.
    scaled = ast.BinNumExpr(ast.NumVal(0.5), raw_score, ast.BinNumOpType.MUL)
    negated = ast.BinNumExpr(ast.NumVal(0), scaled, ast.BinNumOpType.SUB)
    denominator = ast.BinNumExpr(
        ast.NumVal(1), ast.ExpExpr(negated), ast.BinNumOpType.ADD)
    positive_proba = ast.BinNumExpr(
        ast.NumVal(1), denominator, ast.BinNumOpType.DIV, to_reuse=True)

    complement = ast.BinNumExpr(
        ast.NumVal(1), positive_proba, ast.BinNumOpType.SUB)
    expected = ast.VectorVal([complement, positive_proba])

    assert utils.cmp_exprs(result, expected)
예제 #2
0
def test_binary_classification():
    """Two-tree LightGBM binary classifier with subroutine-wrapped trees.

    Each tree is wrapped in a ``SubroutineExpr``. Their sum, seeded with 0,
    is wrapped again and passed through ``1 / (1 + exp(0 - raw))``. The
    expected output vector is ``[1 - p, p]``.
    """
    model = lightgbm.LGBMClassifier(n_estimators=2, random_state=1,
                                    max_depth=1)
    utils.get_binary_classification_model_trainer()(model)

    result = assemblers.LightGBMModelAssembler(model).assemble()

    tree_a = ast.SubroutineExpr(
        ast.IfExpr(
            ast.CompExpr(ast.FeatureRef(23),
                         ast.NumVal(868.2000000000002),
                         ast.CompOpType.GT),
            ast.NumVal(0.25986931215073095),
            ast.NumVal(0.6237178414050242)))
    tree_b = ast.SubroutineExpr(
        ast.IfExpr(
            ast.CompExpr(ast.FeatureRef(7),
                         ast.NumVal(0.05142),
                         ast.CompOpType.GT),
            ast.NumVal(-0.1909605544006228),
            ast.NumVal(0.1293965108676673)))

    # (0 + tree_a) + tree_b, wrapped in its own subroutine.
    partial_sum = ast.BinNumExpr(ast.NumVal(0), tree_a, ast.BinNumOpType.ADD)
    raw_score = ast.SubroutineExpr(
        ast.BinNumExpr(partial_sum, tree_b, ast.BinNumOpType.ADD))

    negated = ast.BinNumExpr(ast.NumVal(0), raw_score, ast.BinNumOpType.SUB)
    denominator = ast.BinNumExpr(
        ast.NumVal(1), ast.ExpExpr(negated), ast.BinNumOpType.ADD)
    positive_proba = ast.BinNumExpr(
        ast.NumVal(1), denominator, ast.BinNumOpType.DIV, to_reuse=True)

    complement = ast.BinNumExpr(
        ast.NumVal(1), positive_proba, ast.BinNumOpType.SUB)
    expected = ast.VectorVal([complement, positive_proba])

    assert utils.cmp_exprs(result, expected)
예제 #3
0
def test_binary_classification():
    """Two-tree XGBoost binary classifier with subroutine-wrapped trees.

    Each tree is wrapped in a ``SubroutineExpr``. Their sum, seeded with
    ``-0.0``, is wrapped again and passed through
    ``1 / (1 + exp(0 - raw))``. The expected output vector is ``[1 - p, p]``.
    """
    model = xgboost.XGBClassifier(n_estimators=2, random_state=1,
                                  max_depth=1)
    utils.get_binary_classification_model_trainer()(model)

    result = assemblers.XGBoostModelAssemblerSelector(model).assemble()

    tree_a = ast.SubroutineExpr(
        ast.IfExpr(
            ast.CompExpr(ast.FeatureRef(20),
                         ast.NumVal(16.7950001),
                         ast.CompOpType.GTE),
            ast.NumVal(-0.173057005),
            ast.NumVal(0.163440868)))
    tree_b = ast.SubroutineExpr(
        ast.IfExpr(
            ast.CompExpr(ast.FeatureRef(27),
                         ast.NumVal(0.142349988),
                         ast.CompOpType.GTE),
            ast.NumVal(-0.161026895),
            ast.NumVal(0.149405137)))

    # (-0.0 + tree_a) + tree_b, wrapped in its own subroutine.
    partial_sum = ast.BinNumExpr(
        ast.NumVal(-0.0), tree_a, ast.BinNumOpType.ADD)
    raw_score = ast.SubroutineExpr(
        ast.BinNumExpr(partial_sum, tree_b, ast.BinNumOpType.ADD))

    negated = ast.BinNumExpr(ast.NumVal(0), raw_score, ast.BinNumOpType.SUB)
    denominator = ast.BinNumExpr(
        ast.NumVal(1), ast.ExpExpr(negated), ast.BinNumOpType.ADD)
    positive_proba = ast.BinNumExpr(
        ast.NumVal(1), denominator, ast.BinNumOpType.DIV, to_reuse=True)

    complement = ast.BinNumExpr(
        ast.NumVal(1), positive_proba, ast.BinNumOpType.SUB)
    expected = ast.VectorVal([complement, positive_proba])

    assert utils.cmp_exprs(result, expected)
예제 #4
0
def test_binary_classification():
    """Two-tree LightGBM binary classifier (splits on features 23 and 22).

    The two tree outputs are summed directly and passed through
    ``1 / (1 + exp(0 - raw))``. The expected output vector is ``[1 - p, p]``.
    """
    model = lightgbm.LGBMClassifier(n_estimators=2, random_state=1,
                                    max_depth=1)
    utils.get_binary_classification_model_trainer()(model)

    result = assemblers.LightGBMModelAssembler(model).assemble()

    tree_a = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(23),
                     ast.NumVal(868.2000000000002),
                     ast.CompOpType.GT),
        ast.NumVal(0.26400127816506497),
        ast.NumVal(0.633133056485969))
    tree_b = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(22),
                     ast.NumVal(105.95000000000002),
                     ast.CompOpType.GT),
        ast.NumVal(-0.18744882409486507),
        ast.NumVal(0.13458899352064668))
    raw_score = ast.BinNumExpr(tree_a, tree_b, ast.BinNumOpType.ADD)

    negated = ast.BinNumExpr(ast.NumVal(0), raw_score, ast.BinNumOpType.SUB)
    denominator = ast.BinNumExpr(
        ast.NumVal(1), ast.ExpExpr(negated), ast.BinNumOpType.ADD)
    positive_proba = ast.BinNumExpr(
        ast.NumVal(1), denominator, ast.BinNumOpType.DIV, to_reuse=True)

    complement = ast.BinNumExpr(
        ast.NumVal(1), positive_proba, ast.BinNumOpType.SUB)
    expected = ast.VectorVal([complement, positive_proba])

    assert utils.cmp_exprs(result, expected)
예제 #5
0
def test_binary_classification():
    """Two-tree LightGBM binary classifier (splits on features 20 and 27).

    The two tree outputs are summed directly and passed through
    ``1 / (1 + exp(0 - raw))``. The expected output vector is ``[1 - p, p]``.
    """
    model = lightgbm.LGBMClassifier(n_estimators=2, random_state=1,
                                    max_depth=1)
    utils.get_binary_classification_model_trainer()(model)

    result = assemblers.LightGBMModelAssembler(model).assemble()

    tree_a = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(20),
                     ast.NumVal(16.795),
                     ast.CompOpType.GT),
        ast.NumVal(0.27502096830384837),
        ast.NumVal(0.6391171126839048))
    tree_b = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(27),
                     ast.NumVal(0.14205),
                     ast.CompOpType.GT),
        ast.NumVal(-0.21340153096570616),
        ast.NumVal(0.11583109256834748))
    raw_score = ast.BinNumExpr(tree_a, tree_b, ast.BinNumOpType.ADD)

    negated = ast.BinNumExpr(ast.NumVal(0), raw_score, ast.BinNumOpType.SUB)
    denominator = ast.BinNumExpr(
        ast.NumVal(1), ast.ExpExpr(negated), ast.BinNumOpType.ADD)
    positive_proba = ast.BinNumExpr(
        ast.NumVal(1), denominator, ast.BinNumOpType.DIV, to_reuse=True)

    complement = ast.BinNumExpr(
        ast.NumVal(1), positive_proba, ast.BinNumOpType.SUB)
    expected = ast.VectorVal([complement, positive_proba])

    assert utils.cmp_exprs(result, expected)
예제 #6
0
def test_binary_classification():
    """Two-tree XGBoost binary classifier (GTE splits on features 20 and 27).

    The two tree outputs are summed directly and passed through
    ``1 / (1 + exp(0 - raw))``. The expected output vector is ``[1 - p, p]``.
    """
    model = xgboost.XGBClassifier(n_estimators=2, random_state=1,
                                  max_depth=1)
    utils.get_binary_classification_model_trainer()(model)

    result = assemblers.XGBoostModelAssemblerSelector(model).assemble()

    tree_a = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(20),
                     ast.NumVal(16.795),
                     ast.CompOpType.GTE),
        ast.NumVal(-0.5178947448730469),
        ast.NumVal(0.4880000054836273))
    tree_b = ast.IfExpr(
        ast.CompExpr(ast.FeatureRef(27),
                     ast.NumVal(0.142349988),
                     ast.CompOpType.GTE),
        ast.NumVal(-0.4447747468948364),
        ast.NumVal(0.39517202973365784))
    raw_score = ast.BinNumExpr(tree_a, tree_b, ast.BinNumOpType.ADD)

    negated = ast.BinNumExpr(ast.NumVal(0), raw_score, ast.BinNumOpType.SUB)
    denominator = ast.BinNumExpr(
        ast.NumVal(1), ast.ExpExpr(negated), ast.BinNumOpType.ADD)
    positive_proba = ast.BinNumExpr(
        ast.NumVal(1), denominator, ast.BinNumOpType.DIV, to_reuse=True)

    complement = ast.BinNumExpr(
        ast.NumVal(1), positive_proba, ast.BinNumOpType.SUB)
    expected = ast.VectorVal([complement, positive_proba])

    assert utils.cmp_exprs(result, expected)
예제 #7
0
def test_lightning_binary_class():
    """lightning ``AdaGradClassifier`` assembled as a linear model.

    The expected expression applies ``ADD`` (via
    ``assemblers.utils.apply_op_to_expressions``) over an intercept of
    ``NumVal(0.0)`` and the products ``FeatureRef(i) * weights[i]`` for
    every feature index ``i``, in order.
    """
    model = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(model)

    result = assemblers.SklearnLinearModelAssembler(model).assemble()

    # Learned coefficient for each feature, listed by feature index.
    weights = [
        0.16218889967390476,     # 0
        0.10012761963766906,     # 1
        0.6289276652681673,      # 2
        0.17618420156072845,     # 3
        0.0010492096607182045,   # 4
        -0.0029135563693806913,  # 5
        -0.005923882409142498,   # 6
        -0.0023293599172479755,  # 7
        0.0020808828960210517,   # 8
        0.0009846430705550103,   # 9
        0.0010399810925427265,   # 10
        0.011203056917272093,    # 11
        -0.007271351370867731,   # 12
        -0.26333437096804224,    # 13
        1.8533543368532444e-05,  # 14
        -0.0008266341686278445,  # 15
        -0.0011090316301215724,  # 16
        -0.0001910857095336291,  # 17
        0.00010735116208006556,  # 18
        -4.076097659514017e-05,  # 19
        0.15300712110146406,     # 20
        0.06316277258339074,     # 21
        0.495291178977687,       # 22
        -0.29589136204657845,    # 23
        0.000771932729567487,    # 24
        -0.011877978242492428,   # 25
        -0.01678004536869617,    # 26
        -0.004070431062579625,   # 27
        0.001158641497209262,    # 28
        0.00010737287732588742,  # 29
    ]
    products = [
        ast.BinNumExpr(ast.FeatureRef(idx), ast.NumVal(weight),
                       ast.BinNumOpType.MUL)
        for idx, weight in enumerate(weights)
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, ast.NumVal(0.0), *products)

    assert utils.cmp_exprs(result, expected)
예제 #8
0
파일: test_e2e.py 프로젝트: goldv/m2cgen
def classification_binary(model, test_fraction=0.02):
    """Build an e2e test-case triple for a binary classifier.

    Returns ``(model, trainer, CLASSIFICATION)``. ``trainer`` comes from
    ``utils.get_binary_classification_model_trainer(test_fraction)``.
    """
    trainer = utils.get_binary_classification_model_trainer(test_fraction)
    return model, trainer, CLASSIFICATION
예제 #9
0
파일: test_e2e.py 프로젝트: goldv/m2cgen
        regression(linear_model.Ridge(random_state=RANDOM_SEED)),
        regression(linear_model.RidgeCV()),
        regression(linear_model.SGDRegressor(random_state=RANDOM_SEED)),
        regression(linear_model.TheilSenRegressor(random_state=RANDOM_SEED)),
        regression(linear_model.TweedieRegressor(power=0.0)),
        regression(linear_model.TweedieRegressor(power=1.0)),
        regression(linear_model.TweedieRegressor(power=1.5)),
        regression(linear_model.TweedieRegressor(power=2.0)),
        regression(linear_model.TweedieRegressor(power=3.0)),

        # Statsmodels Linear Regression
        classification_binary(
            utils.StatsmodelsSklearnLikeWrapper(
                sm.GLM,
                dict(fit_constrained=dict(constraints=(
                    np.eye(utils.get_binary_classification_model_trainer().
                           X_train.shape[-1])[0], [1]))))),
        classification_binary(
            utils.StatsmodelsSklearnLikeWrapper(
                sm.GLM,
                dict(fit_regularized=STATSMODELS_LINEAR_REGULARIZED_PARAMS))),
        classification_binary(
            utils.StatsmodelsSklearnLikeWrapper(
                sm.GLM,
                dict(init=dict(
                    family=sm.families.Binomial(sm.families.links.cauchy())),
                     fit=dict(maxiter=2)))),
        classification_binary(
            utils.StatsmodelsSklearnLikeWrapper(
                sm.GLM,
                dict(init=dict(
                    family=sm.families.Binomial(sm.families.links.cloglog())),
예제 #10
0
def test_lightning_binary_class():
    """lightning ``AdaGradClassifier`` assembled as a linear model.

    The expected expression applies ``ADD`` (via
    ``assemblers.utils.apply_op_to_expressions``) over an intercept of
    ``NumVal(0.0)`` and the products ``FeatureRef(i) * weights[i]`` for
    every feature index ``i``, in order.
    """
    model = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(model)

    result = assemblers.SklearnLinearModelAssembler(model).assemble()

    # Learned coefficient for each feature, listed by feature index.
    weights = [
        0.1605265174,   # 0
        0.1045225083,   # 1
        0.6237391536,   # 2
        0.1680225811,   # 3
        0.0011013688,   # 4
        -0.0027528486,  # 5
        -0.0058878714,  # 6
        -0.0023719811,  # 7
        0.0019944105,   # 8
        0.0009924456,   # 9
        0.0003994860,   # 10
        0.0124697033,   # 11
        -0.0123674096,  # 12
        -0.2844204905,  # 13
        0.0000273704,   # 14
        -0.0007498013,  # 15
        -0.0010784399,  # 16
        -0.0001848694,  # 17
        0.0000632254,   # 18
        -0.0000369618,  # 19
        0.1520223021,   # 20
        0.0925348635,   # 21
        0.4861047372,   # 22
        -0.2798670185,  # 23
        0.0009925506,   # 24
        -0.0103414976,  # 25
        -0.0155024577,  # 26
        -0.0038881538,  # 27
        0.0010126166,   # 28
        0.0002312558,   # 29
    ]
    products = [
        ast.BinNumExpr(ast.FeatureRef(idx), ast.NumVal(weight),
                       ast.BinNumOpType.MUL)
        for idx, weight in enumerate(weights)
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, ast.NumVal(0.0), *products)

    assert utils.cmp_exprs(result, expected)
예제 #11
0
def test_lightning_binary_class():
    """lightning ``AdaGradClassifier`` assembled as a linear model.

    The expected expression applies ``ADD`` (via
    ``assemblers.utils.apply_op_to_expressions``) over an intercept of
    ``NumVal(0.0)`` and the products ``FeatureRef(i) * weights[i]`` for
    every feature index ``i``, in order.
    """
    model = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(model)

    result = assemblers.SklearnLinearModelAssembler(model).assemble()

    # Learned coefficient for each feature, listed by feature index.
    weights = [
        0.1617602138,   # 0
        0.0931034793,   # 1
        0.6279180888,   # 2
        0.1856722189,   # 3
        0.0009999878,   # 4
        -0.0028974470,  # 5
        -0.0059948515,  # 6
        -0.0024173728,  # 7
        0.0020429247,   # 8
        0.0009604400,   # 9
        0.0010933747,   # 10
        0.0078588761,   # 11
        -0.0069150246,  # 12
        -0.2583249885,  # 13
        0.0000097479,   # 14
        -0.0007210600,  # 15
        -0.0011295195,  # 16
        -0.0001966115,  # 17
        0.0001358314,   # 18
        -0.0000378118,  # 19
        0.1555921773,   # 20
        0.0621307817,   # 21
        0.5138354949,   # 22
        -0.2418579612,  # 23
        0.0007953821,   # 24
        -0.0110760214,  # 25
        -0.0162178044,  # 26
        -0.0040277699,  # 27
        0.0015067033,   # 28
        0.0001536614,   # 29
    ]
    products = [
        ast.BinNumExpr(ast.FeatureRef(idx), ast.NumVal(weight),
                       ast.BinNumOpType.MUL)
        for idx, weight in enumerate(weights)
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD, ast.NumVal(0.0), *products)

    assert utils.cmp_exprs(result, expected)
예제 #12
0
def classification_binary(model):
    """Build an e2e test-case triple for a binary classifier.

    Returns ``(model, trainer, CLASSIFICATION)``. ``trainer`` comes from
    ``utils.get_binary_classification_model_trainer()`` called with no
    arguments.
    """
    trainer = utils.get_binary_classification_model_trainer()
    return model, trainer, CLASSIFICATION