def test_torch_tree_nn_incremental(X_tree):
    """Check that dev-set predictions are recorded every `dev_iter` epochs.

    Fits a TorchTreeNN for 100 iterations with the training data reused as
    the dev set and `dev_iter=20`. It then checks two things:

    - The keys of `dev_predictions` are exactly 20, 40, 60, 80, 100, in that order.
    - Every stored entry has one item per example in X.
    """
    X, vocab = X_tree
    tree_model = torch_tree_nn.TorchTreeNN(
        vocab,
        embed_dim=50,
        hidden_dim=50,
        max_iter=100,
        embedding=None)
    tree_model.fit(X, X_dev=X, dev_iter=20)

    recorded = tree_model.dev_predictions
    expected_epochs = [20, 40, 60, 80, 100]
    assert list(recorded) == expected_epochs

    n_examples = len(X)
    for preds in recorded.values():
        assert len(preds) == n_examples
def test_torch_tree_nn_save_load(X_tree):
    """Round-trip a fitted TorchTreeNN through `to_pickle` / `from_pickle`.

    Fits a model and predicts with it, then pickles it to a temporary path.
    It reloads the model with `TorchTreeNN.from_pickle` and checks that the
    reloaded model can still predict and continue training.

    The pickle is written to a path inside a `TemporaryDirectory`. It is not
    written to the name of a still-open `NamedTemporaryFile`, because
    reopening an open named temporary file by name is not supported on
    Windows. `to_pickle` / `from_pickle` presumably open the path
    themselves; verify against `torch_tree_nn`.
    """
    import os  # local import: the file-level imports are not visible here

    X, vocab = X_tree
    mod = torch_tree_nn.TorchTreeNN(
        vocab,
        embed_dim=50,
        hidden_dim=50,
        max_iter=100,
        embedding=None)
    mod.fit(X)
    mod.predict(X)
    with tempfile.TemporaryDirectory() as tmpdir:
        # The directory and everything in it are removed when the block exits.
        name = os.path.join(tmpdir, "torch_tree_nn.pkl")
        mod.to_pickle(name)
        mod2 = torch_tree_nn.TorchTreeNN.from_pickle(name)
        mod2.predict(X)
        mod2.fit(X)
{ 'hidden_dim': 10, 'eta': 1.0, 'max_iter': 100, 'l2_strength': 0.01, 'embed_dim': 100, 'bidirectional': False } ], [ np_tree_nn.TreeNN( vocab=[], max_iter=10, hidden_dim=5, eta=0.1), {'embed_dim': 5, 'hidden_dim': 10, 'eta': 1.0, 'max_iter': 100} ], [ torch_tree_nn.TorchTreeNN( vocab=[], max_iter=10, hidden_dim=5, eta=0.1), { 'embed_dim': 5, 'hidden_dim': 10, 'hidden_activation': nn.ReLU(), 'eta': 1.0, 'max_iter': 100, 'l2_strength': 0.01 } ], [ np_shallow_neural_classifier.ShallowNeuralClassifier( hidden_dim=5, max_iter=1, eta=1.0), { 'hidden_dim': 10, # Reset to ReLU: