Esempio n. 1
0
def test_simple_example_params(count_matrix, param, expected):
    """Fit TorchGloVe with one overridden hyperparameter and check its score.

    ``count_matrix`` is presumably a pytest fixture, and ``param``/``expected``
    presumably come from a parametrization defined outside this view.

    The model is built with ``{param: expected}``, fitted on the count matrix,
    and scored on the same matrix. The score must exceed 0.40, except when
    ``max_iter`` is set to 0; that case is exempt (presumably because the model
    is not trained at all then).
    """
    X = count_matrix
    mod = TorchGloVe(**{param: expected})
    # fit's return value is not needed here; only the fitted state matters.
    mod.fit(X)
    corr = mod.score(X)
    skip_threshold = param == "max_iter" and expected == 0
    if not skip_threshold:
        assert corr > 0.40, f"score {corr!r} too low with {param}={expected!r}"
Esempio n. 2
0
def test_model_graph_embed_dim(count_matrix, param):
    """Check that ``embed_dim`` equals the second dimension of a graph parameter.

    After a one-iteration fit, the attribute named ``param`` is read from
    ``mod.model``. Its ``shape[1]`` must equal the estimator's ``embed_dim``.
    ``param`` presumably names a weight tensor via parametrization defined
    elsewhere.
    """
    mod = TorchGloVe(max_iter=1)
    mod.fit(count_matrix)
    declared_dim = mod.embed_dim
    weight = getattr(mod.model, param)
    assert declared_dim == weight.shape[1]
Esempio n. 3
0
def test_build_dataset(count_matrix):
    """Check that the dataset built by ``build_dataset`` yields 3-element items."""
    # Real weights are not required for this structural check, so the count
    # matrix doubles as the weight matrix.
    counts = count_matrix
    model = TorchGloVe()
    first_item = next(iter(model.build_dataset(counts, counts)))
    assert len(first_item) == 3
Esempio n. 4
0
def test_model(count_matrix, pandas):
    """Check that ``fit`` returns a DataFrame exactly when its input is one.

    When ``pandas`` is true, the count matrix is wrapped in a
    ``pd.DataFrame`` before fitting.
    """
    data = pd.DataFrame(count_matrix) if pandas else count_matrix
    embeddings = TorchGloVe().fit(data)
    assert isinstance(embeddings, pd.DataFrame) == pandas
Esempio n. 5
0
def test_save_load(count_matrix):
    """Round-trip a fitted TorchGloVe through ``to_pickle``/``from_pickle``.

    The reloaded model must keep its ``max_iter`` setting and must still be
    fittable on the same data.
    """
    import os

    X = count_matrix
    mod = TorchGloVe(max_iter=2)
    mod.fit(X)
    # Write into a fresh temporary directory instead of reopening an
    # already-open NamedTemporaryFile by name, which fails on Windows.
    with tempfile.TemporaryDirectory() as tmpdir:
        name = os.path.join(tmpdir, "glove.pkl")
        mod.to_pickle(name)
        mod2 = TorchGloVe.from_pickle(name)
        assert mod2.max_iter == mod.max_iter
        mod2.fit(X)
Esempio n. 6
0
def test_parameter_setting(param, expected):
    """Check that ``set_params`` stores the value on the same-named attribute."""
    model = TorchGloVe()
    overrides = {param: expected}
    model.set_params(**overrides)
    actual = getattr(model, param)
    assert actual == expected
Esempio n. 7
0
def test_params(param, expected):
    """Check that a constructor keyword is readable back as an attribute."""
    model = TorchGloVe(**{param: expected})
    assert getattr(model, param) == expected