def test_get_visual_block_single_estimator():
    """A lone estimator becomes a "single" block labelled with its class name."""
    estimator = LogisticRegression(C=10.0)
    block = _get_visual_block(estimator)

    assert block.kind == "single"
    assert block.estimators == estimator
    # The short label is the bare class name; the details are the full repr.
    assert block.names == type(estimator).__name__
    assert block.name_details == str(estimator)
def test_get_visual_block_single_str_none(est):
    """Estimator placeholders given as strings (or None) render as "single" blocks.

    Both the label and the details of the block are ``str(est)``.
    """
    # NOTE(review): `est` must be supplied by a pytest parametrization that is
    # not visible in this chunk — confirm the decorator is present.
    block = _get_visual_block(est)
    expected_label = str(est)

    assert block.kind == "single"
    assert block.estimators == est
    assert block.names == expected_label
    assert block.name_details == expected_label
def test_get_visual_block_feature_union():
    """A FeatureUnion becomes a "parallel" block with one lane per transformer."""
    union = FeatureUnion([("pca", PCA()), ("svd", TruncatedSVD())])
    block = _get_visual_block(union)

    assert block.kind == "parallel"
    assert block.names == ("pca", "svd")
    assert block.estimators == tuple(pair[1] for pair in union.transformer_list)
    # FeatureUnion lanes carry no per-transformer details.
    assert block.name_details == (None, None)
def test_get_visual_block_voting():
    """A VotingClassifier becomes a "parallel" block of its named sub-estimators."""
    members = [("log_reg", LogisticRegression()), ("mlp", MLPClassifier())]
    voter = VotingClassifier(members)
    block = _get_visual_block(voter)

    assert block.kind == "parallel"
    assert block.estimators == tuple(member[1] for member in voter.estimators)
    assert block.names == ("log_reg", "mlp")
    # Voting members carry no per-estimator details.
    assert block.name_details == (None, None)
def test_get_visual_block_column_transformer_all_instances():
    """ColumnTransformer whose transformers are all instances renders as "parallel".

    Each transformer becomes one lane: its name is the step name and its
    details are the column selection given for that step.

    This test has its own name so it is not shadowed by (and silently
    dropped in favour of) the ``test_get_visual_block_column_transformer``
    definition in this module. That definition passes ``TruncatedSVD`` as a
    bare class; this one passes an instance.
    """
    ct = ColumnTransformer(
        [("pca", PCA(), ["num1", "num2"]), ("svd", TruncatedSVD(), [0, 3])]
    )
    est_html_info = _get_visual_block(ct)
    assert est_html_info.kind == "parallel"
    assert est_html_info.estimators == tuple(trans[1] for trans in ct.transformers)
    assert est_html_info.names == ("pca", "svd")
    assert est_html_info.name_details == (["num1", "num2"], [0, 3])
def test_get_visual_block_column_transformer():
    """ColumnTransformer steps map to the lanes of a "parallel" block.

    Lane names are the step names and lane details are the column selections.
    """
    specs = [
        ("pca", PCA(), ["num1", "num2"]),
        ("svd", TruncatedSVD, [0, 3]),
    ]
    col_tf = ColumnTransformer(specs)
    block = _get_visual_block(col_tf)

    assert block.kind == "parallel"
    assert block.estimators == tuple(spec[1] for spec in col_tf.transformers)
    assert block.names == ("pca", "svd")
    assert block.name_details == (["num1", "num2"], [0, 3])
def test_get_visual_block_pipeline_without_passthrough():
    """A pipeline of real estimators only renders as a "serial" block.

    Each step is labelled ``"<step name>: <estimator class name>"`` and its
    details are ``str(estimator)``.

    This test has its own name so it is not shadowed by (and silently
    dropped in favour of) the ``test_get_visual_block_pipeline`` definition
    in this module. That definition covers "passthrough" / None steps.
    """
    pipe = Pipeline(
        [
            ("imputer", SimpleImputer()),
            ("classifier", LogisticRegression()),
        ]
    )
    est_html_info = _get_visual_block(pipe)
    assert est_html_info.kind == "serial"
    assert est_html_info.estimators == tuple(step[1] for step in pipe.steps)
    assert est_html_info.names == [
        "imputer: SimpleImputer",
        "classifier: LogisticRegression",
    ]
    assert est_html_info.name_details == [str(est) for _, est in pipe.steps]
def test_get_visual_block_pipeline():
    """Pipeline steps become a "serial" block.

    Steps given as "passthrough" or None are both labelled
    ``"<step name>: passthrough"``. Real estimators are labelled with their
    class name. Each step's details are ``str(step_estimator)``.
    """
    steps = [
        ("imputer", SimpleImputer()),
        ("do_nothing", "passthrough"),
        ("do_nothing_more", None),
        ("classifier", LogisticRegression()),
    ]
    pipe = Pipeline(steps)
    block = _get_visual_block(pipe)

    expected_names = [
        "imputer: SimpleImputer",
        "do_nothing: passthrough",
        "do_nothing_more: passthrough",
        "classifier: LogisticRegression",
    ]

    assert block.kind == "serial"
    assert block.estimators == tuple(pair[1] for pair in pipe.steps)
    assert block.names == expected_names
    assert block.name_details == list(map(str, (pair[1] for pair in pipe.steps)))