Ejemplo n.º 1
0
def test_transform_after_fit_no_labels(solver):
    """Transforming unseen X without Y must leave the fitted V unchanged.

    A CMF model is fit on (X, Y). ``transform`` is then called on a fresh
    data matrix with ``Y=None``. The V it returns must equal, exactly, the
    V produced by ``fit_transform``.

    ``solver`` is supplied from outside this function (presumably a pytest
    fixture or ``parametrize`` value -- confirm against the test module).
    """
    random_gen = np.random.RandomState(36)
    # Draw order matters for reproducibility: X, then Y, then X_new.
    X = random_gen.randn(7, 5)
    Y = random_gen.randn(5, 3)
    X_new = random_gen.randn(15, 5)

    # SVD-based init for both inputs; no non-negativity constraints on U, V, Z.
    model = CMF(
        n_components=2,
        solver=solver,
        x_init='svd',
        y_init='svd',
        U_non_negative=False,
        V_non_negative=False,
        Z_non_negative=False,
        random_state=0,
        max_iter=100,
    )
    _, V_fit, _ = model.fit_transform(X, Y)

    # Only the new X is provided at transform time; the second matrix is None.
    _, V_new, _ = model.transform(X_new, None)
    assert_array_equal(V_new, V_fit)
Ejemplo n.º 2
0
def test_transform_after_fit(solver):
    """Check that ``fit`` + ``transform`` approximately matches ``fit_transform``.

    Two estimators with identical parameters are compared. One is fit on
    (X, Y) and then used to transform the same data. The other is a
    ``clone`` taken before fitting, and it runs ``fit_transform`` directly.
    U, V and Z must agree to 2 decimal places.

    ``solver`` is supplied from outside this function (presumably a pytest
    fixture or ``parametrize`` value -- confirm against the test module).
    """
    random_gen = np.random.RandomState(36)
    X = random_gen.randn(7, 5)
    Y = random_gen.randn(5, 3)

    estimator_params = dict(
        n_components=2, solver=solver, x_init='random', y_init='random',
        U_non_negative=False, V_non_negative=False, Z_non_negative=False,
        random_state=0, max_iter=100,
    )
    two_step = CMF(**estimator_params)
    # Clone before fitting so the second estimator starts from bare params.
    one_step = clone(two_step)

    two_step.fit(X, Y)
    U_two, V_two, Z_two = two_step.transform(X, Y)
    U_one, V_one, Z_one = one_step.fit_transform(X, Y)

    # The initializations differ between the two paths, so the results may
    # also differ slightly; compare U, V and Z (in that order) loosely.
    for two_step_part, one_step_part in zip((U_two, V_two, Z_two),
                                            (U_one, V_one, Z_one)):
        assert_array_almost_equal(two_step_part, one_step_part, decimal=2)