示例#1
0
def test_log_reg_sklearn_coherence():
    """Checks that the sklearn and creme implementations produce the same results.

    A creme ``LogisticRegression`` and a scikit-learn ``SGDClassifier`` are both
    trained online, one sample at a time, on the Bananas dataset. The features
    are standardized on the fly, and both models use plain SGD with a constant
    learning rate of 0.01. Afterwards, their learned weights and intercepts must
    agree according to ``math.isclose``.
    """

    ss = preprocessing.StandardScaler()
    cr = lm.LogisticRegression(optimizer=optim.SGD(.01))
    # alpha=0 sets sklearn's regularization strength to zero.
    # NOTE(review): loss='log' was renamed 'log_loss' in scikit-learn 1.1 and
    # removed in 1.3; update this if the pinned sklearn version is bumped.
    sk = sklm.SGDClassifier(learning_rate='constant', eta0=.01, alpha=.0, loss='log')

    for x, y in datasets.Bananas():
        # Update the running scaler statistics, then standardize this sample.
        x = ss.fit_one(x).transform_one(x)
        cr.fit_one(x, y)
        # sklearn takes a 2D array-like of shape (1, n_features). The column
        # order follows the key order of x.
        sk.partial_fit([list(x.values())], [y], classes=[False, True])

    sk_weights = sk.coef_[0]
    # Guard against a vacuous pass. Without this check, if creme learned fewer
    # weights than sklearn, the missing comparisons would be skipped silently.
    assert len(cr.weights) == len(sk_weights)
    # Weights are compared positionally. This assumes cr.weights iterates its
    # features in the same order as x.values(); TODO confirm.
    for cr_weight, sk_weight in zip(cr.weights.values(), sk_weights):
        assert math.isclose(cr_weight, sk_weight)

    assert math.isclose(cr.intercept, sk.intercept_[0])
示例#2
0
def test_perceptron_sklearn_coherence():
    """Checks that the sklearn and creme implementations produce the same results.

    A creme ``Perceptron`` and a scikit-learn ``Perceptron`` are both built with
    their default hyperparameters. Each is trained online, one sample at a time,
    on the Bananas dataset, with features standardized on the fly. Afterwards,
    their learned weights and intercepts must agree according to
    ``math.isclose``.
    """

    ss = preprocessing.StandardScaler()
    cr = lm.Perceptron()
    sk = sklm.Perceptron()

    for x, y in datasets.Bananas():
        # Update the running scaler statistics, then standardize this sample.
        x = ss.fit_one(x).transform_one(x)
        cr.fit_one(x, y)
        # sklearn takes a 2D array-like of shape (1, n_features). The column
        # order follows the key order of x.
        sk.partial_fit([list(x.values())], [y], classes=[False, True])

    sk_weights = sk.coef_[0]
    # Guard against a vacuous pass. Without this check, if creme learned fewer
    # weights than sklearn, the missing comparisons would be skipped silently.
    assert len(cr.weights) == len(sk_weights)
    # Weights are compared positionally. This assumes cr.weights iterates its
    # features in the same order as x.values(); TODO confirm.
    for cr_weight, sk_weight in zip(cr.weights.values(), sk_weights):
        assert math.isclose(cr_weight, sk_weight)

    assert math.isclose(cr.intercept, sk.intercept_[0])
示例#3
0
        for j in p:
            p[j] /= norm
        yield p


@pytest.mark.parametrize(
    'lm, dataset',
    [
        pytest.param(
            lm(optimizer=copy.deepcopy(optimizer), initializer=initializer, l2=0),
            dataset,
            id=f'{lm.__name__} - {optimizer} - {initializer}'
        )
        for lm, dataset in [
            (lm.LinearRegression, datasets.TrumpApproval()),
            (lm.LogisticRegression, datasets.Bananas())
        ]
        for optimizer, initializer in itertools.product(
            [
                optim.AdaBound(),
                optim.AdaDelta(),
                optim.AdaGrad(),
                optim.AdaMax(),
                optim.Adam(),
                optim.AMSGrad(),
                # TODO: check momentum optimizers
                # optim.Momentum(),
                # optim.NesterovMomentum(),
                optim.RMSProp(),
                optim.SGD()
            ],