예제 #1
0
def test_predict_functions_honor_device(dataset, func):
    """Prediction methods must actually use an explicitly passed ``device``.

    ``func`` names a prediction method on the fitted estimator (supplied by
    a parametrize decorator not shown here). Calling it with a nonsense
    device string must raise ``RuntimeError`` rather than silently ignoring
    the argument.
    """
    X_train, X_test, y_train, _, vocab = dataset
    model = TorchTreeNN(vocab, max_iter=2)
    model.fit(X_train, y_train)
    predict_method = getattr(model, func)
    # An unrecognised device name must surface as an error; if the method
    # ignored ``device`` this call would succeed and the test would fail.
    with pytest.raises(RuntimeError):
        predict_method(X_test, device="FAKE_DEVICE")
예제 #2
0
def test_model_graph_dimensions(dataset, mod_attr, dim, graph_attr):
    """Estimator hyperparameters must agree with the built graph's shapes.

    ``mod_attr`` names an attribute on the fitted estimator, ``graph_attr``
    names a layer of ``estimator.model``, and ``dim`` selects which axis of
    that layer's ``weight.shape`` should equal the estimator attribute.
    """
    X_train, _, y_train, _, vocab = dataset
    estimator = TorchTreeNN(vocab, max_iter=1)
    estimator.fit(X_train, y_train)
    expected = getattr(estimator, mod_attr)
    layer = getattr(estimator.model, graph_attr)
    actual = layer.weight.shape[dim]
    assert expected == actual
예제 #3
0
def test_predict_restores_device(dataset, func):
    """A per-call ``device=`` override must not permanently move the model.

    After calling the prediction method named by ``func`` with
    ``device="cpu:0"``, ``model.device`` must equal its value from before
    the call.
    """
    X_train, X_test, y_train, _, vocab = dataset
    model = TorchTreeNN(vocab, max_iter=2)
    model.fit(X_train, y_train)
    device_before = model.device
    # Precondition: the override should be a real switch away from cpu:0.
    # NOTE(review): if ``model.device`` is a plain string (e.g. "cpu"), this
    # comparison against a ``torch.device`` may hold trivially on CPU-only
    # machines — confirm the intended precondition.
    assert device_before != torch.device("cpu:0")
    getattr(model, func)(X_test, device="cpu:0")
    assert model.device == device_before
예제 #4
0
def test_torch_tree_nn_save_load(dataset):
    """Round-trip a fitted model through ``to_pickle`` / ``from_pickle``.

    The reloaded model must give the same predictions on the test split as
    the model that was saved, and it must still be trainable.
    """
    import os

    X_train, X_test, y_train, y_test, vocab = dataset
    mod = TorchTreeNN(vocab, embed_dim=50, max_iter=100, embedding=None)
    mod.fit(X_train, y_train)
    preds = mod.predict(X_test)
    # Write into a temporary *directory* instead of reopening an open
    # NamedTemporaryFile by name, which is not supported on Windows.
    with tempfile.TemporaryDirectory() as tmpdir:
        name = os.path.join(tmpdir, "model.pkl")
        mod.to_pickle(name)
        mod2 = TorchTreeNN.from_pickle(name)
    preds2 = mod2.predict(X_test)
    # Compare as lists so this works whether predict returns a list or an array.
    assert list(preds2) == list(preds)
    # The restored estimator must remain usable for further training.
    mod2.fit(X_test, y_test)
예제 #5
0
def test_pretrained_embedding(dataset):
    """A frozen, user-supplied embedding must survive training unchanged.

    Builds an all-ones ``(len(vocab), 5)`` matrix, fits with
    ``freeze_embedding=True``, and checks the graph's embedding weights still
    equal it exactly.
    """
    X_train, _, y_train, _, vocab = dataset
    dim = 5
    pretrained = np.ones(shape=(len(vocab), dim))
    mod = TorchTreeNN(
        vocab,
        max_iter=1,
        embedding=pretrained,
        freeze_embedding=True,
    )
    mod.fit(X_train, y_train)
    # Detach from autograd and copy to CPU so NumPy can compare the values.
    weights = mod.model.embedding.weight
    graph_emb = weights.detach().cpu().numpy()
    assert np.array_equal(pretrained, graph_emb)
예제 #6
0
def test_simple_example_params(dataset, param, expected):
    """Fitting with one overridden hyperparameter must still work end to end.

    ``param``/``expected`` come from a parametrize decorator not shown here.
    The final check only requires test-set accuracy to be truthy (non-zero);
    it is a smoke test, not a quality threshold.
    """
    X_train, X_test, y_train, y_test, vocab = dataset
    overrides = {param: expected}
    model = TorchTreeNN(vocab, **overrides)
    model.fit(X_train, y_train)
    acc = accuracy_score(y_test, model.predict(X_test))
    assert acc