def test_get_visual_block_single_estimator():
    """A lone estimator is rendered as a "single" visual block.

    The block wraps the estimator object itself, is labelled with the
    estimator's class name, and uses ``str(estimator)`` as detail text.
    """
    log_reg = LogisticRegression(C=10.0)
    block = _get_visual_block(log_reg)

    assert block.kind == "single"
    assert block.estimators == log_reg
    expected_label = type(log_reg).__name__
    assert block.names == expected_label
    assert block.name_details == str(log_reg)
def test_get_visual_block_single_str_none(est):
    """Steps given as strings (or None) render as "single" visual blocks.

    The block wraps ``est`` unchanged, and both its label and its detail
    text are ``str(est)``.
    """
    # NOTE(review): ``est`` is not a fixture defined in this function; it is
    # presumably supplied by a ``pytest.mark.parametrize`` decorator that is
    # not visible in this view — confirm it exists in the file.
    block = _get_visual_block(est)
    as_text = str(est)

    assert block.kind == "single"
    assert block.estimators == est
    assert block.names == as_text
    assert block.name_details == as_text
def test_get_visual_block_feature_union():
    """FeatureUnion renders as a "parallel" block, one branch per transformer.

    Branch names come from the union's transformer list, and no per-branch
    detail text is attached (``name_details`` is all ``None``).
    """
    union = FeatureUnion([("pca", PCA()), ("svd", TruncatedSVD())])
    block = _get_visual_block(union)

    assert block.kind == "parallel"
    assert block.names == ("pca", "svd")
    expected_estimators = tuple(
        transformer for _, transformer in union.transformer_list)
    assert block.estimators == expected_estimators
    assert block.name_details == (None, None)
def test_get_visual_block_voting():
    """VotingClassifier renders as a "parallel" block of its sub-estimators.

    Branch names are the names given in ``estimators``, and no per-branch
    detail text is attached (``name_details`` is all ``None``).
    """
    members = [
        ("log_reg", LogisticRegression()),
        ("mlp", MLPClassifier()),
    ]
    voter = VotingClassifier(members)
    block = _get_visual_block(voter)

    assert block.kind == "parallel"
    expected_estimators = tuple(model for _, model in voter.estimators)
    assert block.estimators == expected_estimators
    assert block.names == ("log_reg", "mlp")
    assert block.name_details == (None, None)
def test_get_visual_block_column_transformer():
    """ColumnTransformer renders as a "parallel" block, one branch per transformer.

    Branch names are the transformer names, and each branch's detail is the
    column selection that transformer was given.
    """
    # Previously this test was defined twice under the same name; the second
    # definition silently shadowed the first, so only one copy is kept.
    ct = ColumnTransformer([
        ("pca", PCA(), ["num1", "num2"]),
        # Use an estimator instance, not the bare class: the transformer list
        # is meant to hold estimator objects.
        ("svd", TruncatedSVD(), [0, 3]),
    ])
    est_html_info = _get_visual_block(ct)
    assert est_html_info.kind == "parallel"
    assert est_html_info.estimators == tuple(
        trans[1] for trans in ct.transformers)
    assert est_html_info.names == ("pca", "svd")
    assert est_html_info.name_details == (["num1", "num2"], [0, 3])
def test_get_visual_block_pipeline():
    """Pipeline renders as a "serial" block with one entry per step.

    Labels have the form ``"<step name>: <class name>"``. Both a
    ``"passthrough"`` step and a ``None`` step are expected to be labelled
    ``passthrough``. Each step's detail text is ``str(step)``.
    """
    # Previously this test was defined twice under the same name; the second
    # definition silently shadowed the first, so only one copy is kept.
    pipe = Pipeline([
        ("imputer", SimpleImputer()),
        ("do_nothing", "passthrough"),
        ("do_nothing_more", None),
        ("classifier", LogisticRegression()),
    ])
    est_html_info = _get_visual_block(pipe)
    assert est_html_info.kind == "serial"
    assert est_html_info.estimators == tuple(step[1] for step in pipe.steps)
    assert est_html_info.names == [
        "imputer: SimpleImputer",
        "do_nothing: passthrough",
        "do_nothing_more: passthrough",
        "classifier: LogisticRegression",
    ]
    assert est_html_info.name_details == [str(est) for _, est in pipe.steps]