def __init__(self, device='cpu', jit=False):
    """Build the UDPOS data pipeline, a NeuralCFG model, and cache one batch.

    Args:
        device: torch device string the model and cached batch are moved to.
        jit: if True, wrap the model with ``torch.jit.script``.

    Side effects: downloads the UDPOS dataset via torchtext on first use.
    """
    self.device = device
    self.jit = jit
    # Field for the raw words; lengths are kept so batches can be packed.
    WORD = torchtext.data.Field(include_lengths=True)
    # Field for the universal POS tags, with explicit sentence delimiters.
    UD_TAG = torchtext.data.Field(
        init_token="<bos>", eos_token="<eos>", include_lengths=True
    )
    # Download and load the default UDPOS splits; keep only sentences
    # of 6..29 tokens (the third column of the dataset is discarded).
    train, val, test = torchtext.datasets.UDPOS.splits(
        fields=(("word", WORD), ("udtag", UD_TAG), (None, None)),
        filter_pred=lambda ex: 5 < len(ex.word) < 30,
    )
    # Words rarer than 3 occurrences fall back to the <unk> token.
    WORD.build_vocab(train.word, min_freq=3)
    UD_TAG.build_vocab(train.udtag)
    # Token-bucketed iterator groups similar-length sentences per batch.
    self.train_iter = torch_struct.data.TokenBucket(train, batch_size=100, device=device)
    H = 256   # hidden size
    T = 30    # number of terminal symbols
    NT = 30   # number of non-terminal symbols
    self.model = NeuralCFG(len(WORD.vocab), T, NT, H)
    if jit:
        self.model = torch.jit.script(self.model)
    self.model.to(device=device)
    self.opt = torch.optim.Adam(self.model.parameters(), lr=0.001, betas=[0.75, 0.999])
    # Cache the first batch for later use; iteration stops immediately.
    # Words are transposed to (batch, seq) — presumably what NeuralCFG
    # expects; TODO confirm against the model's forward().
    for i, ex in enumerate(self.train_iter):
        words, lengths = ex.word
        self.words = words.long().to(device).transpose(0, 1)
        self.lengths = lengths.to(device)
        break
class Model:
    """Benchmark harness: trains a NeuralCFG on the UDPOS dataset."""

    def __init__(self, device='cpu', jit=False):
        """Set up data, model, and optimizer.

        Args:
            device: torch device string for the model and inputs.
            jit: if True, wrap the model with ``torch.jit.script``.

        Side effects: downloads UDPOS via torchtext on first use.
        """
        self.device = device
        self.jit = jit
        # Field for raw words; lengths kept for batching.
        WORD = torchtext.data.Field(include_lengths=True)
        # Field for universal POS tags with sentence delimiters.
        UD_TAG = torchtext.data.Field(init_token="<bos>", eos_token="<eos>", include_lengths=True)
        # Download and load the default UDPOS splits; keep sentences
        # of 6..29 tokens, dropping the dataset's third column.
        train, val, test = torchtext.datasets.UDPOS.splits(
            fields=(("word", WORD), ("udtag", UD_TAG), (None, None)),
            filter_pred=lambda ex: 5 < len(ex.word) < 30,
        )
        # Words rarer than 3 occurrences map to <unk>.
        WORD.build_vocab(train.word, min_freq=3)
        UD_TAG.build_vocab(train.udtag)
        self.train_iter = torch_struct.data.TokenBucket(train, batch_size=100, device=device)
        H = 256   # hidden size
        T = 30    # number of terminal symbols
        NT = 30   # number of non-terminal symbols
        self.model = NeuralCFG(len(WORD.vocab), T, NT, H)
        if jit:
            self.model = torch.jit.script(self.model)
        self.model.to(device=device)
        self.opt = torch.optim.Adam(self.model.parameters(), lr=0.001, betas=[0.75, 0.999])

    def get_module(self):
        """Return ``(model, example_inputs)`` using the first batch of the iterator.

        The words tensor is transposed to (batch, seq) — presumably the
        layout NeuralCFG's forward expects; TODO confirm.
        """
        for ex in self.train_iter:
            words, _ = ex.word
            words = words.long()
            return self.model, (words.to(device=self.device).transpose(0, 1), )

    def train(self, niter=1):
        """Run ``niter`` training steps (one batch each).

        Each step: zero grads, forward to grammar params, build a SentCFG
        distribution, maximize the mean log-partition, clip grads at 3.0,
        then step the optimizer.
        """
        losses = []
        for i, ex in enumerate(self.train_iter):
            if i == niter:
                break
            self.opt.zero_grad()
            words, lengths = ex.word
            words = words.long()
            params = self.model(words.to(device=self.device).transpose(0, 1))
            # NOTE(review): `lengths` is not moved to self.device here,
            # unlike `words` — verify SentCFG accepts CPU lengths.
            dist = SentCFG(params, lengths=lengths)
            # Maximize the partition function, so minimize its negation.
            loss = dist.partition.mean()
            (-loss).backward()
            losses.append(loss.detach())
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3.0)
            self.opt.step()

    def eval(self, niter=1):
        # Evaluation is intentionally a no-op for this benchmark.
        pass
# Download and the load default data. train, val, test = UDPOS.splits( fields=(("word", WORD), ("udtag", UD_TAG), (None, None)), filter_pred=lambda ex: 5 < len(ex.word) < 30, ) WORD.build_vocab(train.word, min_freq=3) UD_TAG.build_vocab(train.udtag) train_iter = torch_struct.data.TokenBucket(train, batch_size=100, device="cuda:0") H = 256 T = 30 NT = 30 model = NeuralCFG(len(WORD.vocab), T, NT, H) if args.script: print("scripting...") model = torch.jit.script(model) model.cuda() opt = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.75, 0.999]) def train(): # model.train() losses = [] for epoch in range(2): for i, ex in enumerate(train_iter): opt.zero_grad() words, lengths = ex.word N, batch = words.shape