import random

import torch
import torch.nn as nn
from torch.autograd import Variable
# Repo-local helpers (apply_cuda, build_glove, encode, decode, normalize,
# extractive, data, set_hard_constraints, INF and the model classes) are
# assumed imported from elsewhere in the project.


def train_sample(self, sample):
    (article, context), targets = sample
    if self.hier:
        hidden_state = self.encoder.init_hidden()
        summ_hidden_state = self.encoder.init_hidden(
            n=self.opt.summLstmLayers, K=self.opt.K)
        encoder_out, hidden_state, _ = self.encoder(
            article, hidden_state, summ_hidden_state)
        err = 0
        # Use teacher forcing (when enabled) on a random half of the samples.
        teacher_forcing = self.opt.useTeacherForcing and random.random() < 0.5
        if teacher_forcing:
            # Feed the ground-truth previous token as context at every step.
            for i in range(len(targets)):
                target = targets[i].unsqueeze(0)
                ctx = context[i].unsqueeze(0)
                out, hidden_state, _ = self.mlp(encoder_out, ctx, hidden_state)
                err += self.loss(out, target)
        else:
            # Free-running decoding: feed the model's own previous prediction.
            ctx = apply_cuda(torch.tensor(self.dict["w2i"]["<s>"]))
            for i in range(len(targets)):
                target = targets[i].unsqueeze(0)
                ctx = ctx.unsqueeze(0).unsqueeze(0)
                out, hidden_state, _ = self.mlp(encoder_out, ctx, hidden_state)
                err += self.loss(out, target)
                _, topi = out.topk(1)
                ctx = topi.squeeze().detach()  # don't backprop through the sampled token
    else:
        out, attn = self.mlp(article, context)
        err = self.loss(out, targets)
    return err
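
# Illustrative only (hypothetical, not part of the original file): one
# optimisation step driving train_sample, assuming `trainer` is a Trainer
# instance and `sample` arrives in the ((article, context), targets) layout
# consumed above.
def _example_train_step(trainer, sample):
    trainer.optimizer.zero_grad()
    if trainer.hier:
        trainer.encoder_optimizer.zero_grad()
    err = trainer.train_sample(sample)  # accumulated NLL over the target
    err.backward()                      # backprop through decoder (and encoder)
    if trainer.hier:
        trainer.encoder_optimizer.step()
    trainer.optimizer.step()
    return err.item()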
def make_input(article, context, K):
    # Replicate the single encoded article K times (one row per beam
    # hypothesis) so the model scores all beams in one batched forward pass.
    bucket = article.size(0)
    article_tensor = apply_cuda(
        article.view(bucket, 1).expand(bucket, K).t().contiguous())
    return [
        Variable(tensor.long()) for tensor in [article_tensor, context]
    ]
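
# Illustrative only (hypothetical): a shape check for make_input. A length-5
# article fanned out across K=3 beams yields one article row per beam.
def _example_make_input():
    article = torch.arange(5)           # encoded article, shape (bucket,) = (5,)
    context = torch.zeros(3, 2).long()  # K=3 beams, each holding a 2-token window
    article_t, context_t = make_input(article, context, K=3)
    assert article_t.shape == (3, 5)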
def torchify(arr, variable=False, revsort=False, opt=None):
    if variable:
        batch_size = len(arr)
        if revsort:
            # Longest sequence first, as packed-sequence utilities expect.
            arr = sorted(arr, key=len, reverse=True)
        lengths = [len(batch) for batch in arr]
        largest_length = opt.maxWordLength
        # Zero-pad every sequence out to the fixed maximum length.
        out = torch.zeros(batch_size, largest_length).long()
        # HACK There must be a better way to do this
        for batch in range(batch_size):
            for j in range(lengths[batch]):
                out[batch][j] = arr[batch][j]
        return apply_cuda(Variable(out)), lengths
    else:
        return apply_cuda(Variable(torch.tensor(list(arr)).long()))
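
# Illustrative only (hypothetical): the variable-length branch zero-pads a
# ragged batch to opt.maxWordLength and returns the true lengths alongside it.
def _example_torchify():
    from argparse import Namespace
    batch, lengths = torchify([[5, 6, 7], [8]], variable=True, revsort=True,
                              opt=Namespace(maxWordLength=4))
    # batch: [[5, 6, 7, 0],
    #         [8, 0, 0, 0]]   lengths: [3, 1]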
def __init__(self, opt, dict):
    super(Trainer, self).__init__()
    self.opt = opt
    self.dict = dict
    self.hier = opt.hier
    if opt.restore:
        # Resume from a saved checkpoint.
        if opt.hier:
            self.mlp, self.encoder = torch.load(opt.model)
        else:
            self.mlp = torch.load(opt.model)
            self.encoder = self.mlp.encoder
        self.mlp.epoch += 1
        print("Restoring MLP {} with epoch {}".format(
            opt.model, self.mlp.epoch))
    else:
        # Build a fresh model.
        if opt.hier:
            glove_weights = build_glove(dict["w2i"]) if opt.glove else None
            self.encoder = apply_cuda(HierAttnEncoder(
                len(dict["i2w"]), opt.bowDim, opt.hiddenSize, opt, glove_weights))
            self.mlp = apply_cuda(HierAttnDecoder(
                len(dict["i2w"]), opt.bowDim, opt.hiddenSize, opt, glove_weights))
        else:
            self.mlp = apply_cuda(LanguageModel(self.dict, opt))
            self.encoder = self.mlp.encoder
        self.mlp.epoch = 0
    # Index 0 is the padding token, so it is excluded from the loss.
    self.loss = apply_cuda(nn.NLLLoss(ignore_index=0))
    self.decoder_embedding = self.mlp.context_embedding
    if opt.hier:
        c = 0.9  # shared momentum / weight-decay constant
        self.encoder_optimizer = torch.optim.RMSprop(
            filter(lambda p: p.requires_grad, self.encoder.parameters()),
            self.opt.learningRate, momentum=c, weight_decay=c)
        self.optimizer = torch.optim.RMSprop(
            filter(lambda p: p.requires_grad, self.mlp.parameters()),
            self.opt.learningRate, momentum=c, weight_decay=c)
    else:
        self.optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, self.mlp.parameters()),
            self.opt.learningRate)  # Half learning rate
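
# Illustrative only (hypothetical values): constructing a fresh flat Trainer
# from parsed options. The field names mirror the ones read above; any option
# the underlying LanguageModel needs beyond these would come from the
# project's full argument parser.
def _example_build_trainer(dict):
    from argparse import Namespace
    opt = Namespace(hier=False, restore=False, glove=False,
                    bowDim=50, hiddenSize=256, learningRate=0.1)
    return Trainer(opt, dict)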
def init_hidden(self, n=1, K=1):
    return (apply_cuda(torch.zeros(n, K, self.hidden_size)),
            apply_cuda(torch.zeros(n, K, self.hidden_size)))
def main():
    state = torch.load(opt.model)
    if opt.hier:
        mlp, encoder = state
    else:
        mlp = state
    dict = data.load_dict(opt.dictionary)
    sent_file = open(opt.inputf).read().split("\n")
    if not opt.hier:
        W = mlp.window
        opt.window = mlp.window
    else:
        W = 1
    w2i = dict["w2i"]
    i2w = dict["i2w"]
    K = opt.beamSize
    actual = open(opt.outputf).read().split('\n')
    sent_num = 0
    with torch.no_grad():
        for line in sent_file:
            if line.strip() == "":
                continue
            # Add padding
            if opt.hier:
                summaries = extractive(line).split("\t")
                print("\n> {}...".format(summaries[0]))
                encoded_summaries = [
                    encode("<s> {} </s>".format(normalize(summary)), w2i)
                    for summary in summaries
                ]
                article = HierDataLoader.torchify(
                    encoded_summaries, variable=True, revsort=True, opt=opt)
                hidden_state = encoder.init_hidden()
                summ_hidden_state = encoder.init_hidden(
                    n=opt.summLstmLayers, K=opt.K)
                print(hidden_state[0].shape, summ_hidden_state[0].shape)  # debug
                print(article[0].shape)  # debug
                encoder_out, hidden_state, _ = encoder(
                    article, hidden_state, summ_hidden_state)
            else:
                print("\n> {}".format(line))
                true_line = "<s> <s> <s> {} </s> </s> </s>".format(
                    normalize(line))
                article = torch.tensor(encode(true_line, w2i))

            n = opt.length
            # K live hypotheses, each seeded with W <s> padding tokens.
            hyps = apply_cuda(torch.zeros(K, W + n).long().fill_(w2i["<s>"]))
            scores = apply_cuda(torch.zeros(K).float())
            if opt.hier:
                # Every beam starts from the encoder's final hidden/cell state.
                hidden_size = len(hidden_state[0][0][0])
                hidden = apply_cuda(torch.zeros(K, hidden_size).float())
                cell = apply_cuda(torch.zeros(K, hidden_size).float())
                for k in range(K):
                    hidden[k] = hidden_state[0][0]
                    cell[k] = hidden_state[1][0]

            for step in range(n):
                new_candidates = []
                start = step
                end = step + W
                context = hyps[:, start:end]  # the last W tokens of each beam

                if opt.hier:
                    # Score each beam separately, threading its own LSTM state.
                    model_scores = torch.zeros(K, len(w2i))
                    for c in range(K):
                        ctx = article[0][0][step].view(1, -1)
                        model_scores[c], new_hidden, attn = mlp(
                            encoder_out, ctx,
                            (hidden[c].view(1, 1, -1), cell[c].view(1, 1, -1)))
                        hidden[c] = new_hidden[0]
                        cell[c] = new_hidden[1]
                else:
                    # Score all beams in one batched pass.
                    article_t, context_t = AbsDataLoader.make_input(
                        article, context, K)
                    model_scores, attn = mlp(article_t, context_t)

                out_scores = model_scores.data

                # Apply hard constraints
                finalized = (step == n - 1) and opt.fixedLength
                set_hard_constraints(out_scores, w2i, finalized)

                # Expand each live hypothesis with its K best continuations.
                for sample in range(K):
                    top_scores, top_indexes = torch.topk(out_scores[sample], K)
                    for ix, score in zip(top_indexes, top_scores):
                        repetition = opt.noRepeat and apply_cuda(
                            ix) in apply_cuda(hyps[sample])
                        combined = torch.cat(
                            (hyps[sample][:end], apply_cuda(torch.tensor([ix]))))
                        if opt.hier:
                            candidate = [
                                combined,
                                -INF if repetition else scores[sample] + apply_cuda(score),
                                hidden[sample],  # carry this beam's own recurrent state
                                cell[sample]
                            ]
                        else:
                            candidate = [
                                combined,
                                -INF if repetition else scores[sample] + apply_cuda(score),
                                None, None
                            ]
                        new_candidates.append(candidate)

                # Keep the K best of the K*K candidates.
                ordered = sorted(new_candidates, key=lambda cand: cand[1],
                                 reverse=True)
                h, s, hidden_temp, cell_temp = zip(*ordered)
                for r in range(K):
                    hyps[r][0:end + 1] = h[r]
                    scores[r] = s[r]
                    if opt.hier:
                        hidden[r] = hidden_temp[r]
                        cell[r] = cell_temp[r]

            s, top_ixs = torch.topk(scores, 1)
            final = hyps[int(top_ixs)][W:-1]
            print("= {}".format(actual[sent_num]))
            print("< {}".format(decode(final, i2w)))
            print("")
            sent_num += 1
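
# Illustrative only (hypothetical, self-contained): the beam-update rule used
# above, stripped to plain tensors. Each of the K live hypotheses proposes its
# K best next tokens, and the K highest-scoring of the K*K candidates survive.
def _example_beam_step(scores, log_probs, K):
    # scores: (K,) running log-probabilities; log_probs: (K, V) next-token scores
    candidates = []
    for k in range(K):
        top_scores, top_ix = torch.topk(log_probs[k], K)
        for ix, s in zip(top_ix, top_scores):
            candidates.append((float(scores[k] + s), k, int(ix)))
    candidates.sort(key=lambda cand: cand[0], reverse=True)
    return candidates[:K]  # (new score, parent beam, token id) per survivor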
def torchify(arr, variable=False):
    return apply_cuda(Variable(torch.tensor(list(arr)).long()))
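
# Illustrative only (hypothetical): the flat torchify turns any iterable of
# token ids into a long Variable on the training device.
def _example_flat_torchify():
    t = torchify(range(3))  # -> tensor([0, 1, 2]) (CUDA when available)
    assert t.shape == (3,)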