def train_batch(self, batch):
    self.model.zero_grad()  # right thing to do?
    src, lengths = batch.src
    tgt_outer = batch.tgt[0] if isinstance(batch.tgt, tuple) else batch.tgt
    side_info = side_information(batch)

    target_size = tgt_outer.size(0)
    trunc_size = self.trunc_size if self.trunc_size else target_size

    batch_stats = Counter()
    for j in range(0, target_size - 1, trunc_size):
        # 1. Create truncated target.
        tgt = tgt_outer[j:j + trunc_size]

        # 2. Forward pass through the model and the generator.
        self.model.zero_grad()
        outputs, attns = self.model(src, tgt, lengths=lengths, **side_info)
        logits = self.model.generator(outputs, **side_info)
        logits = logits.view(-1, logits.size(-1))
        gold = tgt[1:].view(-1)

        # 3. Compute the loss (some losses also return the target
        # distribution p_star) and backpropagate.
        loss = self.train_loss(logits, gold)
        if isinstance(loss, tuple):
            loss, p_star = loss
        else:
            p_star = None
        loss.div(batch.batch_size).backward()

        # this stuff should be refactored
        pred = logits.max(1)[1]
        non_pad = gold.ne(self.train_loss.ignore_index)
        n_correct = pred.eq(gold).masked_select(non_pad).sum().item()
        n_words = non_pad.sum().item()
        batch_loss = loss.clone().item()

        if self.log_sparse_attn:
            for k, v in attns.items():
                bottled_attn = v.view(v.size(0) * v.size(1), -1)
                attn_support = bottled_attn.gt(0).sum(dim=1)
                n_attended = attn_support.masked_select(non_pad)
                batch_stats[k + "_attended"] += n_attended.sum().item()

        if self.log_gate:
            seq_len, batch_size, classes = attns["gate"].size()
            bottled_gate = attns["gate"].view(seq_len * batch_size, -1)
            gate_support = bottled_gate.gt(0)
            lemma_gate = gate_support[:, 0].masked_select(non_pad)
            infl_gate = gate_support[:, 1].masked_select(non_pad)
            both_gate = lemma_gate & infl_gate
            batch_stats["lemma_gate"] += lemma_gate.sum().item()
            batch_stats["inflection_gate"] += infl_gate.sum().item()
            batch_stats["both_gate"] += both_gate.sum().item()

        if p_star is not None:
            support = p_star.gt(0).sum(dim=1)
            support = support.masked_select(non_pad).sum().item()
            batch_stats["support"] += support
        else:
            batch_stats["support"] += n_words * logits.size(-1)

        batch_stats["loss"] += batch_loss
        batch_stats["n_words"] += n_words
        batch_stats["n_correct"] += n_correct

        # 4. Clip gradients and update the parameters.
        for group in self.optim.param_groups:
            if self.max_grad_norm:
                clip_grad_norm_(group['params'], self.max_grad_norm)
        self.optim.step()

        # If truncated, don't backprop fully.
        if self.model.decoder.state is not None:
            self.model.decoder.detach_state()

    return batch_stats
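# --- Usage sketch (not from the original code): a minimal epoch loop that
# drives train_batch and reports running metrics from the Counter it returns.
# The names run_epoch, trainer and train_iter are illustrative assumptions;
# the perplexity computation assumes the accumulated "loss" is a token-level
# sum, as suggested by the loss/n_words bookkeeping above.
import math
from collections import Counter


def run_epoch(trainer, train_iter, report_every=50):
    totals = Counter()
    for step, batch in enumerate(train_iter, 1):
        # Counter.update adds the per-batch counts to the running totals.
        totals.update(trainer.train_batch(batch))
        if step % report_every == 0:
            n_words = max(totals["n_words"], 1)
            ppl = math.exp(min(totals["loss"] / n_words, 100))
            acc = 100 * totals["n_correct"] / n_words
            print("step {}: ppl {:.2f} acc {:.2f}".format(step, ppl, acc))
    return totals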
def translate_batch(self, batch):
    beam_size = self.beam_size
    tgt_field = self.fields['tgt'][0][1]
    vocab = tgt_field.vocab
    pad = vocab.stoi[tgt_field.pad_token]
    eos = vocab.stoi[tgt_field.eos_token]
    bos = vocab.stoi[tgt_field.init_token]
    b = Beam(beam_size, n_best=self.n_best, cuda=self.cuda,
             pad=pad, eos=eos, bos=bos)

    src, src_lengths = batch.src
    # why doesn't this contain inflection source lengths when ensembling?
    side_info = side_information(batch)

    # (1) Encode the source.
    encoder_out = self.model.encode(src, lengths=src_lengths, **side_info)
    enc_states = encoder_out["enc_state"]
    memory_bank = encoder_out["memory_bank"]
    infl_memory_bank = encoder_out.get("inflection_memory_bank", None)

    self.model.init_decoder_state(enc_states)

    results = dict()
    if "tgt" in batch.__dict__:
        results["gold_score"] = self._score_target(
            batch, memory_bank, src_lengths,
            inflection_memory_bank=infl_memory_bank, **side_info)
        self.model.init_decoder_state(enc_states)
    else:
        results["gold_score"] = 0

    # (2) Repeat src objects `beam_size` times.
    self.model.map_decoder_state(
        lambda state, dim: tile(state, beam_size, dim=dim))
    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
    memory_lengths = tile(src_lengths, beam_size)

    if infl_memory_bank is not None:
        if isinstance(infl_memory_bank, tuple):
            infl_memory_bank = tuple(
                tile(x, beam_size, dim=1) for x in infl_memory_bank)
        else:
            infl_memory_bank = tile(infl_memory_bank, beam_size, dim=1)
        tiled_infl_len = tile(side_info["inflection_lengths"], beam_size)
        side_info["inflection_lengths"] = tiled_infl_len

    if "language" in side_info:
        side_info["language"] = tile(side_info["language"], beam_size)

    # (3) Run beam search, one target step at a time.
    for i in range(self.max_length):
        if b.done():
            break

        inp = b.current_state.unsqueeze(0)

        # the decoder expects an input of tgt_len x batch
        dec_out, dec_attn = self.model.decode(
            inp, memory_bank,
            memory_lengths=memory_lengths,
            inflection_memory_bank=infl_memory_bank,
            **side_info)

        attn = dec_attn["lemma"].squeeze(0)
        out = self.model.generator(
            dec_out.squeeze(0), transform=True, **side_info)

        # b.advance will take attn (beam size x src length)
        b.advance(out, dec_attn)
        select_indices = b.current_origin
        self.model.map_decoder_state(
            lambda state, dim: state.index_select(dim, select_indices))

    # (4) Extract the n-best hypotheses.
    scores, ks = b.sort_finished()
    hyps, attn, out_probs = [], [], []
    for i, (times, k) in enumerate(ks[:self.n_best]):
        hyp, att, out_p = b.get_hyp(times, k)
        hyps.append(hyp)
        attn.append(att)
        out_probs.append(out_p)

    results["preds"] = hyps
    results["scores"] = scores
    results["attn"] = attn

    if self.beam_accum is not None:
        parent_ids = [t.tolist() for t in b.prev_ks]
        self.beam_accum["beam_parent_ids"].append(parent_ids)
        scores = [["%4f" % s for s in t.tolist()]
                  for t in b.all_scores][1:]
        self.beam_accum["scores"].append(scores)
        pred_ids = [[vocab.itos[i] for i in t.tolist()]
                    for t in b.next_ys][1:]
        self.beam_accum["predicted_ids"].append(pred_ids)

    if self.attn_path is not None:
        save_attn = {k: v.cpu() for k, v in attn[0].items()}
        src_seq = self.itos(src, "src")
        pred_seq = self.itos(hyps[0], "tgt")
        attn_dict = {"src": src_seq, "pred": pred_seq, "attn": save_attn}
        if "inflection" in save_attn:
            inflection_seq = self.itos(batch.inflection[0], "inflection")
            attn_dict["inflection"] = inflection_seq
        self.attns.append(attn_dict)

    if self.probs_path is not None:
        save_probs = out_probs[0].cpu()
        self.probs.append(save_probs)

    return results
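# --- Helper sketch (not from the original code): translate_batch relies on a
# `tile` function that repeats every batch entry `count` times along `dim`,
# keeping the copies of one sentence contiguous so beam bookkeeping stays
# simple. A minimal implementation with that assumed behaviour, not
# necessarily the project's own:
import torch


def tile(x, count, dim=0):
    """Repeat each slice of ``x`` along ``dim`` ``count`` times.

    E.g. a memory bank of shape (src_len, batch, hidden) tiled with dim=1
    becomes (src_len, batch * count, hidden).
    """
    return torch.repeat_interleave(x, count, dim=dim)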
def validate(self, valid_iter):
    with torch.no_grad():
        self.model.eval()
        stats = Counter()
        for batch in valid_iter:
            src, lengths = batch.src
            tgt = batch.tgt[0] if isinstance(batch.tgt, tuple) else batch.tgt
            side_info = side_information(batch)
            outputs, attns = self.model(src, tgt, lengths=lengths, **side_info)
            logits = self.model.generator(outputs, **side_info)
            logits = logits.view(-1, logits.size(-1))
            gold = tgt[1:].view(-1)
            loss = self.valid_loss(logits, gold)
            if isinstance(loss, tuple):
                loss, p_star = loss
            else:
                p_star = None

            pred = logits.max(1)[1]
            non_pad = gold.ne(self.valid_loss.ignore_index)
            n_correct = pred.eq(gold).masked_select(non_pad).sum().item()
            n_words = non_pad.sum().item()
            batch_loss = loss.clone().item()

            if self.log_sparse_attn:
                for k, v in attns.items():
                    bottled_attn = v.view(v.size(0) * v.size(1), -1)
                    attn_supp = bottled_attn.gt(0).sum(dim=1)
                    n_attended = attn_supp.masked_select(non_pad)
                    stats[k + "_attended"] += n_attended.sum().item()

            if self.log_gate:
                seq_len, batch_size, classes = attns["gate"].size()
                bottled_gate = attns["gate"].view(seq_len * batch_size, -1)
                gate_support = bottled_gate.gt(0)
                lemma_gate = gate_support[:, 0].masked_select(non_pad)
                infl_gate = gate_support[:, 1].masked_select(non_pad)
                both_gate = lemma_gate & infl_gate
                stats["lemma_gate"] += lemma_gate.sum().item()
                stats["inflection_gate"] += infl_gate.sum().item()
                stats["both_gate"] += both_gate.sum().item()

            if p_star is not None:
                support = p_star.gt(0).sum(dim=1)
                support = support.masked_select(non_pad).sum().item()
                stats["support"] += support
            else:
                stats["support"] += n_words * logits.size(-1)

            stats["loss"] += batch_loss
            stats["n_words"] += n_words
            stats["n_correct"] += n_correct

    self.model.train()
    return stats
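# --- Usage sketch (not from the original code): turning the Counter returned
# by validate into readable metrics. It assumes "loss" is a token-level sum
# and "support" counts nonzero output probabilities per non-pad target token;
# summarize_validation is an illustrative name.
import math


def summarize_validation(stats):
    n_words = max(stats["n_words"], 1)
    metrics = {
        "per_word_loss": stats["loss"] / n_words,
        "ppl": math.exp(min(stats["loss"] / n_words, 100)),
        "accuracy": 100 * stats["n_correct"] / n_words,
        "avg_output_support": stats["support"] / n_words,
    }
    # Average attention support per head, if sparse-attention logging was on.
    for key, value in stats.items():
        if key.endswith("_attended"):
            metrics["avg_" + key] = value / n_words
    return metrics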
def validate(self, valid_iter):
    # Variant of validate with additional global-gate attention logging.
    with torch.no_grad():
        self.model.eval()
        stats = Counter()
        for batch in valid_iter:
            src, lengths = batch.src
            tgt = batch.tgt[0] if isinstance(batch.tgt, tuple) else batch.tgt
            side_info = side_information(batch)
            outputs, attns = self.model(src, tgt, lengths=lengths, **side_info)
            logits = self.model.generator(outputs, **side_info)
            logits = logits.view(-1, logits.size(-1))
            gold = tgt[1:].view(-1)
            loss = self.valid_loss(logits, gold)
            if isinstance(loss, tuple):
                loss, p_star = loss
            else:
                p_star = None

            pred = logits.max(1)[1]
            non_pad = gold.ne(self.valid_loss.ignore_index)
            n_correct = pred.eq(gold).masked_select(non_pad).sum().item()
            n_words = non_pad.sum().item()
            batch_loss = loss.clone().item()

            # if self.log_global_gate_attn_mix:
            #     attn_types = []
            #     for n in range(self.model_opt.global_gate_heads_number):
            #         attn_types.append("gate_lemma_subw_global_" + str(n))
            #         attn_types.append("gate_lemma_char_global_" + str(n))
            #     for k in attn_types:
            #         v = attns[k]
            #         bottled_attn = v.view(v.size(0) * v.size(1), -1)
            #         attn_supp = bottled_attn.gt(0).sum(dim=1)
            #         n_attended = attn_supp
            #         stats["n_words_global"] += tgt.size()[1]  # batch size
            #         stats[k + "_attended"] += n_attended.sum().item()

            if self.log_global_gate_attn:
                for n in range(self.model_opt.global_gate_heads_number):
                    k = "gate_lemma_global_" + str(n)
                    v = attns[k]
                    bottled_attn = v.view(v.size(0) * v.size(1), -1)
                    attn_supp = bottled_attn.gt(0).sum(dim=1)
                    n_attended = attn_supp
                    stats["n_words_global"] += tgt.size()[1]  # batch size
                    stats[k + "_attended"] += n_attended.sum().item()

            if self.log_sparse_attn:
                for k in ["lemma", "inflection"]:
                    v = attns[k]
                    bottled_attn = v.view(v.size(0) * v.size(1), -1)
                    attn_supp = bottled_attn.gt(0).sum(dim=1)
                    n_attended = attn_supp.masked_select(non_pad)
                    stats[k + "_attended"] += n_attended.sum().item()

            if self.log_gate:
                seq_len, batch_size, classes = attns["gate"].size()
                bottled_gate = attns["gate"].view(seq_len * batch_size, -1)
                gate_support = bottled_gate.gt(0)
                lemma_gate = gate_support[:, 0].masked_select(non_pad)
                infl_gate = gate_support[:, 1].masked_select(non_pad)
                both_gate = lemma_gate & infl_gate
                stats["lemma_gate"] += lemma_gate.sum().item()
                stats["inflection_gate"] += infl_gate.sum().item()
                stats["both_gate"] += both_gate.sum().item()

            if p_star is not None:
                support = p_star.gt(0).sum(dim=1)
                support = support.masked_select(non_pad).sum().item()
                stats["support"] += support
            else:
                stats["support"] += n_words * logits.size(-1)

            stats["loss"] += batch_loss
            stats["n_words"] += n_words
            stats["n_correct"] += n_correct

    self.model.train()
    return stats
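# --- Usage sketch (not from the original code): summarizing how often the
# two-way gate puts nonzero weight on the lemma source, the inflection
# source, or both, using the counts collected by the log_gate branch above.
# Fractions are over non-padding target tokens; gate_usage is an
# illustrative name.
def gate_usage(stats):
    n_words = max(stats["n_words"], 1)
    return {
        "lemma_only": (stats["lemma_gate"] - stats["both_gate"]) / n_words,
        "inflection_only":
            (stats["inflection_gate"] - stats["both_gate"]) / n_words,
        "both": stats["both_gate"] / n_words,
    }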