def test_multi_class():
    """Assemble a LogisticRegression with hand-set 3x2 coefficients.

    The expected result is a vector of three expressions, each of the form
    ``(intercept + x0 * w0) + x1 * w1``.
    """
    estimator = linear_model.LogisticRegression()
    estimator.coef_ = np.array([[1, 2], [3, 4], [5, 6]])
    estimator.intercept_ = np.array([7, 8, 9])

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    def linear_expr(intercept, weights):
        # Left fold: start at the intercept and add one weighted feature
        # per step, which yields the same nesting as spelling it out by hand.
        expr = ast.NumVal(intercept)
        for feature_idx, weight in enumerate(weights):
            term = ast.BinNumExpr(
                ast.FeatureRef(feature_idx),
                ast.NumVal(weight),
                ast.BinNumOpType.MUL)
            expr = ast.BinNumExpr(expr, term, ast.BinNumOpType.ADD)
        return expr

    per_class = [
        (7, [1, 2]),
        (8, [3, 4]),
        (9, [5, 6]),
    ]
    expected = ast.VectorVal([
        linear_expr(intercept, weights) for intercept, weights in per_class
    ])

    assert utils.cmp_exprs(actual, expected)
def test_single_feature():
    """Assemble a LinearRegression with one coefficient and one intercept.

    The expected expression is ``3 + x0 * 1``.
    """
    estimator = linear_model.LinearRegression()
    estimator.coef_ = np.array([1])
    estimator.intercept_ = np.array([3])

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    weighted_feature = ast.BinNumExpr(
        ast.FeatureRef(0),
        ast.NumVal(1),
        ast.BinNumOpType.MUL)
    expected = ast.BinNumExpr(
        ast.NumVal(3),
        weighted_feature,
        ast.BinNumOpType.ADD)

    assert utils.cmp_exprs(actual, expected)
def test_lightning_regression():
    """Assemble an AdaGradRegressor and compare with hard-coded weights.

    The estimator is passed through the shared regression trainer
    (presumably fitting it) before assembly. The expected expression is a
    zero intercept plus one ``feature * weight`` term per feature.
    """
    estimator = AdaGradRegressor(random_state=1)
    utils.get_regression_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # One expected weight per feature, in feature-index order.
    weights = [
        -0.0610645819,
        0.0856563713,
        -0.0562044566,
        0.2804204925,
        0.1359261760,
        1.6307305501,
        0.0866147265,
        -0.0726894150,
        0.0435440193,
        -0.0077364839,
        0.2902775116,
        0.0229879957,
        -0.7614706871,
    ]
    feature_weight_mul = [
        ast.BinNumExpr(
            ast.FeatureRef(idx), ast.NumVal(weight), ast.BinNumOpType.MUL)
        for idx, weight in enumerate(weights)
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD,
        ast.NumVal(0.0),
        *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
def test_lightning_regression():
    """Assemble an AdaGradRegressor and compare with hard-coded weights.

    The estimator is handed to the shared regression trainer (presumably
    fitting it) and then assembled. The expected expression sums a zero
    intercept with one ``feature * weight`` product per feature.
    """
    estimator = AdaGradRegressor(random_state=1)
    utils.get_regression_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # Expected weights keyed by position: entry i belongs to FeatureRef(i).
    expected_weights = (
        -0.08558826944690746,
        0.0803724696787377,
        -0.03516743076774846,
        0.26469178947134087,
        0.15651985221012488,
        1.5186399078028587,
        0.10089874009193693,
        -0.011426237067026246,
        0.0861987777487293,
        -0.0057791506839322574,
        0.3357752757550913,
        0.020189965076849486,
        -0.7390647599317609,
    )

    feature_weight_mul = []
    for position, weight in enumerate(expected_weights):
        feature_weight_mul.append(
            ast.BinNumExpr(
                ast.FeatureRef(position),
                ast.NumVal(weight),
                ast.BinNumOpType.MUL))

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD,
        ast.NumVal(0.0),
        *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
def test_lightning_regression():
    """Assemble an AdaGradRegressor and compare with hard-coded weights.

    After the shared regression trainer has been applied to the estimator
    (presumably fitting it), the assembled expression must equal a zero
    intercept plus a ``feature * weight`` term for each of 13 features.
    """
    estimator = AdaGradRegressor(random_state=1)
    utils.get_regression_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    def weighted(feature_idx, weight):
        # Single ``feature * weight`` product node.
        return ast.BinNumExpr(
            ast.FeatureRef(feature_idx),
            ast.NumVal(weight),
            ast.BinNumOpType.MUL)

    feature_weight_mul = [
        weighted(0, -0.0961163452),
        weighted(1, 0.1574398180),
        weighted(2, -0.0251799219),
        weighted(3, 0.1975142192),
        weighted(4, 0.1189621635),
        weighted(5, 1.2977018274),
        weighted(6, 0.1192977978),
        weighted(7, 0.0331955333),
        weighted(8, 0.1433964513),
        weighted(9, 0.0014943531),
        weighted(10, 0.3116036672),
        weighted(11, 0.0258421936),
        weighted(12, -0.7386996349),
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD,
        ast.NumVal(0.0),
        *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
def test_lightning_multi_class():
    """Assemble a multi-class AdaGradClassifier and compare per-class output.

    The estimator goes through the shared classification trainer
    (presumably fitting it). The expected result is a vector of three
    expressions, each a zero intercept plus four weighted features.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    def class_expr(weights):
        # Left fold from NumVal(0.0): ((((0 + x0*w0) + x1*w1) + x2*w2) + x3*w3).
        expr = ast.NumVal(0.0)
        for feature_idx, weight in enumerate(weights):
            product = ast.BinNumExpr(
                ast.FeatureRef(feature_idx),
                ast.NumVal(weight),
                ast.BinNumOpType.MUL)
            expr = ast.BinNumExpr(expr, product, ast.BinNumOpType.ADD)
        return expr

    # One row of feature weights per class.
    class_weights = [
        [0.09686334892116512,
         0.32572202133211947,
         -0.48444233646554424,
         -0.219719145605816],
        [-0.1089136473832088,
         -0.16956003333433572,
         0.0365531256007199,
         -0.01016100116780896],
        [-0.16690339219780817,
         -0.19466284646233858,
         0.2953585236360389,
         0.21288203082531384],
    ]
    expected = ast.VectorVal([class_expr(row) for row in class_weights])

    assert utils.cmp_exprs(actual, expected)
def test_lightning_binary_class():
    """Assemble a binary AdaGradClassifier and compare with hard-coded weights.

    The estimator is passed through the shared binary-classification trainer
    (presumably fitting it). The expected expression is a zero intercept
    plus one ``feature * weight`` term for each of 30 features.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # Expected weights in feature-index order (entry i pairs with FeatureRef(i)).
    weights = [
        0.16218889967390476,
        0.10012761963766906,
        0.6289276652681673,
        0.17618420156072845,
        0.0010492096607182045,
        -0.0029135563693806913,
        -0.005923882409142498,
        -0.0023293599172479755,
        0.0020808828960210517,
        0.0009846430705550103,
        0.0010399810925427265,
        0.011203056917272093,
        -0.007271351370867731,
        -0.26333437096804224,
        1.8533543368532444e-05,
        -0.0008266341686278445,
        -0.0011090316301215724,
        -0.0001910857095336291,
        0.00010735116208006556,
        -4.076097659514017e-05,
        0.15300712110146406,
        0.06316277258339074,
        0.495291178977687,
        -0.29589136204657845,
        0.000771932729567487,
        -0.011877978242492428,
        -0.01678004536869617,
        -0.004070431062579625,
        0.001158641497209262,
        0.00010737287732588742,
    ]
    feature_weight_mul = [
        ast.BinNumExpr(
            ast.FeatureRef(idx), ast.NumVal(weight), ast.BinNumOpType.MUL)
        for idx, weight in enumerate(weights)
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD,
        ast.NumVal(0.0),
        *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
def test_lightning_multi_class():
    """Assemble a multi-class AdaGradClassifier and compare per-class output.

    The shared classification trainer is applied to the estimator
    (presumably fitting it) before assembly. Three expressions are
    expected, each summing a zero intercept and four weighted features.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    mul = ast.BinNumOpType.MUL
    add = ast.BinNumOpType.ADD

    # Row k holds the four feature weights expected for class k.
    weight_rows = (
        (0.0895848274, 0.3258329434, -0.4900856238, -0.2214482506),
        (-0.1074247041, -0.1693225196, 0.0357417324, -0.0161614171),
        (-0.1825063678, -0.2185655665, 0.3053017646, 0.2175198459),
    )

    class_exprs = []
    for row in weight_rows:
        # Accumulate left to right so the ADD nodes nest on the left side.
        acc = ast.NumVal(0.0)
        for feature_idx, weight in enumerate(row):
            term = ast.BinNumExpr(
                ast.FeatureRef(feature_idx), ast.NumVal(weight), mul)
            acc = ast.BinNumExpr(acc, term, add)
        class_exprs.append(acc)

    expected = ast.VectorVal(class_exprs)

    assert utils.cmp_exprs(actual, expected)
def test_lightning_binary_class():
    """Assemble a binary AdaGradClassifier and compare with hard-coded weights.

    The shared binary-classification trainer is applied to the estimator
    (presumably fitting it). The assembled expression must equal a zero
    intercept plus a ``feature * weight`` product for each of 30 features.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    # Position in the tuple is the feature index the weight applies to.
    expected_weights = (
        0.1605265174,
        0.1045225083,
        0.6237391536,
        0.1680225811,
        0.0011013688,
        -0.0027528486,
        -0.0058878714,
        -0.0023719811,
        0.0019944105,
        0.0009924456,
        0.0003994860,
        0.0124697033,
        -0.0123674096,
        -0.2844204905,
        0.0000273704,
        -0.0007498013,
        -0.0010784399,
        -0.0001848694,
        0.0000632254,
        -0.0000369618,
        0.1520223021,
        0.0925348635,
        0.4861047372,
        -0.2798670185,
        0.0009925506,
        -0.0103414976,
        -0.0155024577,
        -0.0038881538,
        0.0010126166,
        0.0002312558,
    )

    feature_weight_mul = []
    for position, weight in enumerate(expected_weights):
        feature_weight_mul.append(
            ast.BinNumExpr(
                ast.FeatureRef(position),
                ast.NumVal(weight),
                ast.BinNumOpType.MUL))

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD,
        ast.NumVal(0.0),
        *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)
def test_lightning_multi_class():
    """Assemble a multi-class AdaGradClassifier and compare per-class output.

    The estimator is handed to the shared classification trainer
    (presumably fitting it), then assembled. The expected vector has three
    entries, each a zero intercept plus four ``feature * weight`` terms.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    def weighted(feature_idx, weight):
        # ``feature * weight`` product node.
        return ast.BinNumExpr(
            ast.FeatureRef(feature_idx),
            ast.NumVal(weight),
            ast.BinNumOpType.MUL)

    def plus(left, right):
        # ``left + right`` sum node.
        return ast.BinNumExpr(left, right, ast.BinNumOpType.ADD)

    def class_expr(w0, w1, w2, w3):
        # Same left-nested shape as writing the four additions out by hand.
        return plus(
            plus(
                plus(
                    plus(ast.NumVal(0.0), weighted(0, w0)),
                    weighted(1, w1)),
                weighted(2, w2)),
            weighted(3, w3))

    expected = ast.VectorVal([
        class_expr(0.0935146297, 0.3213921354, -0.4855914264, -0.2214295302),
        class_expr(-0.1103262586, -0.1662457692, 0.0379823341, -0.0128634938),
        class_expr(-0.1685751402, -0.2045901693, 0.2932121798, 0.2138148665),
    ])

    assert utils.cmp_exprs(actual, expected)
def test_lightning_binary_class():
    """Assemble a binary AdaGradClassifier and compare with hard-coded weights.

    The estimator goes through the shared binary-classification trainer
    (presumably fitting it). The expected expression adds a zero intercept
    to one ``feature * weight`` term per feature, 30 features in total.
    """
    estimator = AdaGradClassifier(random_state=1)
    utils.get_binary_classification_model_trainer()(estimator)

    actual = assemblers.SklearnLinearModelAssembler(estimator).assemble()

    def weighted(feature_idx, weight):
        # Single ``feature * weight`` product node.
        return ast.BinNumExpr(
            ast.FeatureRef(feature_idx),
            ast.NumVal(weight),
            ast.BinNumOpType.MUL)

    feature_weight_mul = [
        weighted(0, 0.1617602138),
        weighted(1, 0.0931034793),
        weighted(2, 0.6279180888),
        weighted(3, 0.1856722189),
        weighted(4, 0.0009999878),
        weighted(5, -0.0028974470),
        weighted(6, -0.0059948515),
        weighted(7, -0.0024173728),
        weighted(8, 0.0020429247),
        weighted(9, 0.0009604400),
        weighted(10, 0.0010933747),
        weighted(11, 0.0078588761),
        weighted(12, -0.0069150246),
        weighted(13, -0.2583249885),
        weighted(14, 0.0000097479),
        weighted(15, -0.0007210600),
        weighted(16, -0.0011295195),
        weighted(17, -0.0001966115),
        weighted(18, 0.0001358314),
        weighted(19, -0.0000378118),
        weighted(20, 0.1555921773),
        weighted(21, 0.0621307817),
        weighted(22, 0.5138354949),
        weighted(23, -0.2418579612),
        weighted(24, 0.0007953821),
        weighted(25, -0.0110760214),
        weighted(26, -0.0162178044),
        weighted(27, -0.0040277699),
        weighted(28, 0.0015067033),
        weighted(29, 0.0001536614),
    ]

    expected = assemblers.utils.apply_op_to_expressions(
        ast.BinNumOpType.ADD,
        ast.NumVal(0.0),
        *feature_weight_mul)

    assert utils.cmp_exprs(actual, expected)