def test_predict_functions_honor_device(dataset, func):
    X_train, X_test, y_train, y_test, vocab = dataset
    mod = TorchTreeNN(vocab, max_iter=2)
    mod.fit(X_train, y_train)
    prediction_func = getattr(mod, func)
    # Requesting a nonexistent device should raise rather than fall back silently:
    with pytest.raises(RuntimeError):
        prediction_func(X_test, device="FAKE_DEVICE")
def test_model_graph_dimensions(dataset, mod_attr, dim, graph_attr):
    X_train, X_test, y_train, y_test, vocab = dataset
    mod = TorchTreeNN(vocab, max_iter=1)
    mod.fit(X_train, y_train)
    mod_attr_val = getattr(mod, mod_attr)
    # The wrapper's attribute should match the corresponding dimension of the
    # underlying graph parameter:
    graph_attr_val = getattr(mod.model, graph_attr).weight.shape[dim]
    assert mod_attr_val == graph_attr_val
def test_build_tree_rep(tree, subtree_indices, emb_indices, n):
    tree = Tree.fromstring(tree)
    vocab = ["0", "1", "2", "3", "4", "5", "6", "$UNK"]
    mod = TorchTreeNN(vocab)
    result = mod._build_tree_rep(tree)
    assert result[0] == subtree_indices
    assert result[1] == emb_indices
    assert result[2] == n
def test_parameter_setting(param, expected):
    vocab = []
    mod = TorchTreeNN(vocab)
    mod.set_params(**{param: expected})
    result = getattr(mod, param)
    if param == "embedding":
        # Embeddings are arrays, so compare them elementwise:
        assert np.array_equal(result, expected)
    else:
        assert result == expected
def test_predict_restores_device(dataset, func):
    X_train, X_test, y_train, y_test, vocab = dataset
    mod = TorchTreeNN(vocab, max_iter=2)
    mod.fit(X_train, y_train)
    current_device = mod.device
    assert current_device != torch.device("cpu:0")
    prediction_func = getattr(mod, func)
    prediction_func(X_test, device="cpu:0")
    # Predicting on a temporary device should not change `mod.device`:
    assert mod.device == current_device
def test_build_dataset(dataset, with_y, expected):
    X_train, X_test, y_train, y_test, vocab = dataset
    mod = TorchTreeNN(vocab)
    if with_y:
        dataset = mod.build_dataset(X_train, y_train)
    else:
        dataset = mod.build_dataset(X_train)
    result = next(iter(dataset))
    # With labels, each dataset element should have one more field than without:
    assert len(result) == expected
def test_pretrained_embedding(dataset):
    X_train, X_test, y_train, y_test, vocab = dataset
    embed_dim = 5
    embedding = np.ones((len(vocab), embed_dim))
    mod = TorchTreeNN(
        vocab, max_iter=1, embedding=embedding, freeze_embedding=True)
    mod.fit(X_train, y_train)
    # With `freeze_embedding=True`, training must leave the embedding unchanged:
    graph_emb = mod.model.embedding.weight.detach().cpu().numpy()
    assert np.array_equal(embedding, graph_emb)
def test_torch_tree_nn_save_load(dataset):
    X_train, X_test, y_train, y_test, vocab = dataset
    mod = TorchTreeNN(vocab, embed_dim=50, max_iter=100, embedding=None)
    mod.fit(X_train, y_train)
    mod.predict(X_test)
    with tempfile.NamedTemporaryFile(mode='wb') as f:
        name = f.name
        mod.to_pickle(name)
        mod2 = TorchTreeNN.from_pickle(name)
        # The restored model should support both inference and further training:
        mod2.predict(X_test)
        mod2.fit(X_test, y_test)
# In[41]:

tree_glove_dev_predictions = tree_nn_glove.predict(X_tree_dev)

# In[42]:

print(classification_report(y_tree_dev, tree_glove_dev_predictions))

# ### PyTorch TreeNN implementation

# In[43]:

torch_tree_nn_glove = TorchTreeNN(
    sst_glove_vocab,
    embedding=glove_embedding,
    embed_dim=50,
    max_iter=10,
    eta=0.05)

# In[44]:

get_ipython().run_line_magic('time', '_ = torch_tree_nn_glove.fit(X_tree_train)')

# As with `TreeNN` above, you have the option of specifying the labels
# separately, and this is required if you are cross-validating the model
# using scikit-learn methods (see the sketch below).

# In[45]:

torch_tree_glove_dev_predictions = torch_tree_nn_glove.predict(X_tree_dev)

# In[46]:
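# A minimal sketch (not a cell from the original notebook) of cross-validating
# `TorchTreeNN` with the labels supplied separately. It assumes the SST
# convention that each tree stores its label at the root; `y_tree_train` and
# `cv_results` are names introduced here for illustration.

from sklearn.model_selection import cross_validate

# Pull the root label off each training tree so that X and y can be passed
# as separate arguments, as scikit-learn's utilities require:
y_tree_train = [tree.label() for tree in X_tree_train]

cv_results = cross_validate(
    TorchTreeNN(
        sst_glove_vocab,
        embedding=glove_embedding,
        embed_dim=50,
        max_iter=10,
        eta=0.05),
    X_tree_train,
    y_tree_train,
    cv=3)  # 3-fold cross-validation, scored with the model's own `score` method

print(cv_results['test_score'])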
def test_simple_example_params(dataset, param, expected):
    X_train, X_test, y_train, y_test, vocab = dataset
    model = TorchTreeNN(vocab, **{param: expected})
    model.fit(X_train, y_train)
    preds = model.predict(X_test)
    # A sanity check that the model trains and predicts with this setting:
    assert accuracy_score(y_test, preds)