예제 #1
0
파일: test_maui.py 프로젝트: hayat221/maui
def test_maui_updates_neural_weight_product_when_training():
    """Fine-tuning a fitted Maui model must change its outputs.

    The model is fitted and its latent factors and neural weight product are
    captured. It is then fine-tuned on the same inputs and both are
    re-captured. Neither quantity may be numerically unchanged
    (``np.allclose``) afterwards.
    """
    model = Maui(n_hidden=[10], n_latent=2, epochs=1)

    # Snapshot right after the initial fit.
    latent_initial = model.fit_transform({"d1": df1, "d2": df2})
    weights_initial = model.get_neural_weight_product()

    # Snapshot after an extra round of training via fine_tune.
    model.fine_tune({"d1": df1, "d2": df2})
    latent_tuned = model.transform({"d1": df1, "d2": df2})
    weights_tuned = model.get_neural_weight_product()

    assert not np.allclose(latent_initial, latent_tuned)
    assert not np.allclose(weights_initial, weights_tuned)
예제 #2
0
def test_maui_complains_if_fine_tune_with_wrong_features():
    """``fine_tune`` must raise ValueError when the input does not match the fit.

    The model is fitted on ``{"d1": df1, "d2": df2}``. It is then fine-tuned on
    the same mapping, except that the last row of ``df1`` is dropped. The dict
    keys are deliberately identical to those used in ``fit``. With different
    keys, a ValueError caused by the key mismatch would also satisfy
    ``pytest.raises``, and the wrong-features check would go untested.
    """
    maui_model = Maui(n_hidden=[], n_latent=2, epochs=1)
    maui_model.fit({"d1": df1, "d2": df2})

    # NOTE(review): dropping the last row is assumed to remove a feature,
    # per the test name; confirm against Maui's expected input orientation.
    df1_wrong_features = df1.iloc[:-1]
    with pytest.raises(ValueError):
        maui_model.fine_tune({"d1": df1_wrong_features, "d2": df2})
예제 #3
0
def test_maui_can_fine_tune():
    """A fitted Maui model accepts ``fine_tune`` with an explicit epoch count.

    The object returned by ``fit`` is the one fine-tuned. The test passes as
    long as neither call raises.
    """
    fitted = Maui(n_hidden=[], n_latent=2, epochs=1).fit({"d1": df1, "d2": df2})
    fitted.fine_tune({"d1": df1, "d2": df2}, epochs=1)