# Example #1 (0)
from torch.utils.data import DataLoader
from practical_2.TreeDataset import TreeDataset, prepare_example, pad_batch
from practical_2.models.BOW import *
from torch.optim import *
from practical_2.callbacks.callbacks import *
from practical_2.models.CBOW import create_cbow_model
from practical_2.prepare import prepare

from practical_2.utils import *
from practical_2.train import train_model

### For reproducibility.
# NOTE(review): assumes prepare() seeds the RNGs -- confirm in practical_2.prepare.
prepare()

train_dataset = TreeDataset("trees/train.txt")
eval_testset = TreeDataset("trees/dev.txt")

# Take the vocabulary from the training set once, up front. Both the
# per-example transform and the batch collate function use this same
# vocabulary, and the dev set is encoded with the training vocabulary too.
v = train_dataset.v

### Now we need to set the transformation function.
# `vocab=v` is a default argument, so the current vocabulary object is
# captured when the lambda is defined. Otherwise the global name would be
# looked up only when a batch is loaded, and code later in this file
# rebinds `v`.
transform = lambda example, vocab=v: prepare_example(example, vocab)

train_dataset.transform = transform
eval_testset.transform = transform

# DataLoader calls collate_fn with a single argument (the list of samples),
# so the extra defaulted parameter is never overridden.
collate_fn = lambda batch, vocab=v: pad_batch(batch, vocab)
train_dataloader = DataLoader(train_dataset, batch_size=128, collate_fn=collate_fn)
eval_dataloader = DataLoader(eval_testset, batch_size=128, collate_fn=collate_fn)

model = create_cbow_model(v)

optimizer = Adam(model.parameters(), lr=0.0005)
# Example #2 (0)
from torch.utils.data import DataLoader
from practical_2.TreeDataset import TreeDataset, prepare_treelstm_minibatch
from practical_2.models.BOW import *
from torch.optim import *
from practical_2.callbacks.callbacks import *
from practical_2.models.TreeLSTM import create_tree_lstm
from practical_2.prepare import prepare
from practical_2.utils import *
from practical_2.train import train_model

# Import torch explicitly for device selection below instead of relying on
# the star imports above to provide it.
import torch

### For reproducibility.
# NOTE(review): assumes prepare() seeds the RNGs -- confirm in practical_2.prepare.
prepare()

train_dataset = TreeDataset("trees/train.txt")
eval_testset = TreeDataset("trees/dev.txt")

# No per-example transform is set here. Minibatches are built directly from
# the raw dataset items by prepare_treelstm_minibatch, using the vocabulary
# attached to the model.
model = create_tree_lstm()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
v = model.vocab

# Bind the vocabulary as a default argument so collate_fn keeps using this
# object even if the global names `model` or `v` are rebound later in the
# file. DataLoader only ever passes the batch argument.
collate_fn = lambda batch, vocab=v: prepare_treelstm_minibatch(batch, vocab)

train_dataloader = DataLoader(train_dataset,
                              batch_size=512,
                              collate_fn=collate_fn)
eval_dataloader = DataLoader(eval_testset,
                             batch_size=512,
                             collate_fn=collate_fn)
# Example #3 (0)
from practical_2.TreeDataset import TreeDataset

# Load the training treebank and print the entry stored under the word
# "century" in its vocabulary's w2i mapping (presumably word -> index).
train_dataset = TreeDataset(
    "trees/train.txt",
)
v = train_dataset.v
print(v.w2i["century"])