Example #1
    def train_batch(self, batch):
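        """Train on a single batch, processing the target in truncated chunks.

        Returns a Counter of per-token loss, accuracy, and attention statistics.
        """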
        self.model.zero_grad()  # gradients are zeroed again at the start of each truncation step
        src, lengths = batch.src
        tgt_outer = batch.tgt[0] if isinstance(batch.tgt, tuple) else batch.tgt

        side_info = side_information(batch)

        target_size = tgt_outer.size(0)

        trunc_size = self.trunc_size if self.trunc_size else target_size
        batch_stats = Counter()

        for j in range(0, target_size - 1, trunc_size):
            # 1. Create truncated target.
            tgt = tgt_outer[j:j + trunc_size]

            self.model.zero_grad()
            outputs, attns = self.model(src, tgt, lengths=lengths, **side_info)

            logits = self.model.generator(outputs, **side_info)
            logits = logits.view(-1, logits.size(-1))
            gold = tgt[1:].view(-1)

            loss = self.train_loss(logits, gold)
            if isinstance(loss, tuple):
                loss, p_star = loss
            else:
                p_star = None
            loss.div(batch.batch_size).backward()

            # Accumulate token-level accuracy, loss, and attention statistics.
            pred = logits.max(1)[1]
            non_pad = gold.ne(self.train_loss.ignore_index)
            n_correct = pred.eq(gold).masked_select(non_pad).sum().item()
            n_words = non_pad.sum().item()
            batch_loss = loss.clone().item()
            if self.log_sparse_attn:
                for k, v in attns.items():
                    bottled_attn = v.view(v.size(0) * v.size(1), -1)
                    attn_support = bottled_attn.gt(0).sum(dim=1)
                    n_attended = attn_support.masked_select(non_pad)
                    batch_stats[k + "_attended"] += n_attended.sum().item()
            if self.log_gate:
                seq_len, batch_size, classes = attns["gate"].size()
                bottled_gate = attns["gate"].view(seq_len * batch_size, -1)
                gate_support = bottled_gate.gt(0)
                lemma_gate = gate_support[:, 0].masked_select(non_pad)
                infl_gate = gate_support[:, 1].masked_select(non_pad)
                both_gate = lemma_gate & infl_gate
                batch_stats["lemma_gate"] += lemma_gate.sum().item()
                batch_stats["inflection_gate"] += infl_gate.sum().item()
                batch_stats["both_gate"] += both_gate.sum().item()
            if p_star is not None:
                support = p_star.gt(0).sum(dim=1)
                support = support.masked_select(non_pad).sum().item()
                batch_stats["support"] += support
            else:
                batch_stats["support"] += n_words * logits.size(-1)

            batch_stats["loss"] += batch_loss
            batch_stats["n_words"] += n_words
            batch_stats["n_correct"] += n_correct

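            # Clip gradients per parameter group, then take an optimizer step.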
            for group in self.optim.param_groups:
                if self.max_grad_norm:
                    clip_grad_norm_(group['params'], self.max_grad_norm)
            self.optim.step()

            # If truncated, don't backprop fully.
            if self.model.decoder.state is not None:
                self.model.decoder.detach_state()
        return batch_stats
Example #2
    def translate_batch(self, batch):
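        """Translate one batch with beam search.

        Returns a dict with the n-best predictions, their scores, attention
        distributions, and a gold score when the batch carries a target.
        """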
        beam_size = self.beam_size
        tgt_field = self.fields['tgt'][0][1]
        vocab = tgt_field.vocab

        pad = vocab.stoi[tgt_field.pad_token]
        eos = vocab.stoi[tgt_field.eos_token]
        bos = vocab.stoi[tgt_field.init_token]
        b = Beam(beam_size,
                 n_best=self.n_best,
                 cuda=self.cuda,
                 pad=pad,
                 eos=eos,
                 bos=bos)

        src, src_lengths = batch.src
        # why doesn't this contain inflection source lengths when ensembling?
        side_info = side_information(batch)

        encoder_out = self.model.encode(src, lengths=src_lengths, **side_info)
        enc_states = encoder_out["enc_state"]
        memory_bank = encoder_out["memory_bank"]
        infl_memory_bank = encoder_out.get("inflection_memory_bank", None)

        self.model.init_decoder_state(enc_states)

        results = dict()

        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch,
                memory_bank,
                src_lengths,
                inflection_memory_bank=infl_memory_bank,
                **side_info)
            self.model.init_decoder_state(enc_states)
        else:
            results["gold_score"] = 0

        # (2) Repeat src objects `beam_size` times.
        self.model.map_decoder_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        if infl_memory_bank is not None:
            if isinstance(infl_memory_bank, tuple):
                infl_memory_bank = tuple(
                    tile(x, beam_size, dim=1) for x in infl_memory_bank)
            else:
                infl_memory_bank = tile(infl_memory_bank, beam_size, dim=1)
            tiled_infl_len = tile(side_info["inflection_lengths"], beam_size)
            side_info["inflection_lengths"] = tiled_infl_len

        if "language" in side_info:
            side_info["language"] = tile(side_info["language"], beam_size)

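        # (3) Decode step by step until the beam finishes or max_length is reached.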
        for i in range(self.max_length):
            if b.done():
                break

            inp = b.current_state.unsqueeze(0)

            # the decoder expects an input of tgt_len x batch
            dec_out, dec_attn = self.model.decode(
                inp,
                memory_bank,
                memory_lengths=memory_lengths,
                inflection_memory_bank=infl_memory_bank,
                **side_info)
            out = self.model.generator(dec_out.squeeze(0),
                                       transform=True,
                                       **side_info)

            # b.advance consumes the generator output and the attention dict
            # (each attention is beam_size x src_length).
            b.advance(out, dec_attn)
            select_indices = b.current_origin

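            # Reorder the decoder state so it tracks the surviving beam hypotheses.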
            self.model.map_decoder_state(
                lambda state, dim: state.index_select(dim, select_indices))

        scores, ks = b.sort_finished()
        hyps, attn, out_probs = [], [], []
        for times, k in ks[:self.n_best]:
            hyp, att, out_p = b.get_hyp(times, k)
            hyps.append(hyp)
            attn.append(att)
            out_probs.append(out_p)

        results["preds"] = hyps
        results["scores"] = scores
        results["attn"] = attn

        if self.beam_accum is not None:
            parent_ids = [t.tolist() for t in b.prev_ks]
            self.beam_accum["beam_parent_ids"].append(parent_ids)
            scores = [["%4f" % s for s in t.tolist()]
                      for t in b.all_scores][1:]
            self.beam_accum["scores"].append(scores)
            pred_ids = [[vocab.itos[i] for i in t.tolist()]
                        for t in b.next_ys][1:]
            self.beam_accum["predicted_ids"].append(pred_ids)

        if self.attn_path is not None:
            save_attn = {k: v.cpu() for k, v in attn[0].items()}
            src_seq = self.itos(src, "src")
            pred_seq = self.itos(hyps[0], "tgt")
            attn_dict = {"src": src_seq, "pred": pred_seq, "attn": save_attn}
            if "inflection" in save_attn:
                inflection_seq = self.itos(batch.inflection[0], "inflection")
                attn_dict["inflection"] = inflection_seq
            self.attns.append(attn_dict)

        if self.probs_path is not None:
            save_probs = out_probs[0].cpu()
            self.probs.append(save_probs)

        return results
Example #3
    def validate(self, valid_iter):
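        """Evaluate the model on valid_iter without updating parameters.

        Returns a Counter of loss, accuracy, and attention statistics.
        """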
        with torch.no_grad():
            self.model.eval()

            stats = Counter()

            for batch in valid_iter:
                src, lengths = batch.src
                tgt = batch.tgt[0] if isinstance(batch.tgt, tuple) \
                    else batch.tgt
                side_info = side_information(batch)

                outputs, attns = self.model(src,
                                            tgt,
                                            lengths=lengths,
                                            **side_info)
                logits = self.model.generator(outputs, **side_info)
                logits = logits.view(-1, logits.size(-1))
                gold = tgt[1:].view(-1)

                loss = self.valid_loss(logits, gold)
                if isinstance(loss, tuple):
                    loss, p_star = loss
                else:
                    p_star = None

                pred = logits.max(1)[1]
                non_pad = gold.ne(self.valid_loss.ignore_index)
                n_correct = pred.eq(gold).masked_select(non_pad).sum().item()
                n_words = non_pad.sum().item()
                batch_loss = loss.clone().item()
                if self.log_sparse_attn:
                    for k, v in attns.items():
                        bottled_attn = v.view(v.size(0) * v.size(1), -1)
                        attn_supp = bottled_attn.gt(0).sum(dim=1)
                        n_attended = attn_supp.masked_select(non_pad)
                        stats[k + "_attended"] += n_attended.sum().item()
                if self.log_gate:
                    seq_len, batch_size, classes = attns["gate"].size()
                    bottled_gate = attns["gate"].view(seq_len * batch_size, -1)
                    gate_support = bottled_gate.gt(0)
                    lemma_gate = gate_support[:, 0].masked_select(non_pad)
                    infl_gate = gate_support[:, 1].masked_select(non_pad)
                    both_gate = lemma_gate & infl_gate
                    stats["lemma_gate"] += lemma_gate.sum().item()
                    stats["inflection_gate"] += infl_gate.sum().item()
                    stats["both_gate"] += both_gate.sum().item()
                if p_star is not None:
                    support = p_star.gt(0).sum(dim=1)
                    support = support.masked_select(non_pad).sum().item()
                    stats["support"] += support
                else:
                    stats["support"] += n_words * logits.size(-1)

                stats["loss"] += batch_loss
                stats["n_words"] += n_words
                stats["n_correct"] += n_correct

            self.model.train()

            return stats
Example #4
    def validate(self, valid_iter):
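        """Evaluate the model on valid_iter without updating parameters.

        Returns a Counter of loss, accuracy, and attention statistics,
        including global gate attention counts when enabled.
        """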
        with torch.no_grad():
            self.model.eval()

            stats = Counter()

            for batch in valid_iter:
                src, lengths = batch.src
                tgt = batch.tgt[0] if isinstance(batch.tgt, tuple) \
                    else batch.tgt
                side_info = side_information(batch)

                outputs, attns = self.model(src,
                                            tgt,
                                            lengths=lengths,
                                            **side_info)
                logits = self.model.generator(outputs, **side_info)
                logits = logits.view(-1, logits.size(-1))
                gold = tgt[1:].view(-1)

                loss = self.valid_loss(logits, gold)
                if isinstance(loss, tuple):
                    loss, p_star = loss
                else:
                    p_star = None

                pred = logits.max(1)[1]
                non_pad = gold.ne(self.valid_loss.ignore_index)
                n_correct = pred.eq(gold).masked_select(non_pad).sum().item()
                n_words = non_pad.sum().item()

                batch_loss = loss.clone().item()
                if self.log_global_gate_attn:
                    for n in range(self.model_opt.global_gate_heads_number):
                        k = "gate_lemma_global_" + str(n)
                        v = attns[k]
                        bottled_attn = v.view(v.size(0) * v.size(1), -1)
                        attn_supp = bottled_attn.gt(0).sum(dim=1)
                        n_attended = attn_supp
                        stats["n_words_global"] += tgt.size()[1]  # batch size
                        stats[k + "_attended"] += n_attended.sum().item()
                if self.log_sparse_attn:
                    for k in ["lemma", "inflection"]:
                        v = attns[k]
                        bottled_attn = v.view(v.size(0) * v.size(1), -1)
                        attn_supp = bottled_attn.gt(0).sum(dim=1)
                        n_attended = attn_supp.masked_select(non_pad)
                        stats[k + "_attended"] += n_attended.sum().item()
                if self.log_gate:
                    seq_len, batch_size, classes = attns["gate"].size()
                    bottled_gate = attns["gate"].view(seq_len * batch_size, -1)
                    gate_support = bottled_gate.gt(0)
                    lemma_gate = gate_support[:, 0].masked_select(non_pad)
                    infl_gate = gate_support[:, 1].masked_select(non_pad)
                    both_gate = lemma_gate & infl_gate
                    stats["lemma_gate"] += lemma_gate.sum().item()
                    stats["inflection_gate"] += infl_gate.sum().item()
                    stats["both_gate"] += both_gate.sum().item()
                if p_star is not None:
                    support = p_star.gt(0).sum(dim=1)
                    support = support.masked_select(non_pad).sum().item()
                    stats["support"] += support
                else:
                    stats["support"] += n_words * logits.size(-1)

                stats["loss"] += batch_loss
                stats["n_words"] += n_words
                stats["n_correct"] += n_correct

            self.model.train()

            return stats