def decode_training(self, x, y, gru_dec, enc_hidden, enc_annotations):
        """Decode with teacher forcing and return the per-step output distributions.

        At every step the decoder is fed the reference token from ``y`` (not its
        own prediction), so the number of decoding steps is fixed by ``y.size(1)``.

        Args:
            x: Encoder input batch; ``x.size(0)`` is taken as the batch size and
                the tensor is forwarded unchanged to ``self.decoder``.
            y: Reference token indices, indexed as ``y[batch, position]``.
                Position 0 is never fed to the decoder; the first decoder input
                is ``self.sos_token``.
            gru_dec: Recurrent module forwarded to ``self.decoder``; its
                ``hidden_size`` bounds the slice of the previous decoder output
                that is fed back.
            enc_hidden: Last encoder hidden state, used as the first decoder
                hidden state.
            enc_annotations: Encoder annotations forwarded to ``self.decoder``.

        Returns:
            The per-step distributions stacked along dim 1 and passed through a
            bare ``squeeze()``. The first entry is a one-hot on
            ``self.sos_token``.
        """
        batch_size = x.size(0)
        # decoder first hidden is last encoder hidden
        dec_hidden = enc_hidden

        # decoder first output (needed for attention) init to zeros
        dec_output = self.decoder.init_first_output(batch_size).float()

        # init first decoder input <sos>
        last_char_index = decoder.init_first_input_index(batch_size, self.sos_token).double()

        p_gen = decoder.init_first_p_gen(batch_size)

        # Position 0 of the output is by construction a one-hot on <sos>.
        first_prob = torch.zeros((batch_size, 1, self.vocabulary_size)).to(device)
        first_prob[:, 0, self.sos_token] = 1
        probabilities = [first_prob]

        fw_size = gru_dec.hidden_size
        for pos in range(1, y.size(1)):
            # DecoderAndPointer; only the first `fw_size` features ("only fw")
            # of the previous output are fed back.
            _, dec_output, dec_hidden, prob, _, p_gen = self.decoder(
                last_char_index,
                dec_output[:, :, :fw_size],
                dec_hidden, enc_annotations, x, p_gen,
                self.embedding, gru_dec)
            probabilities.append(prob)

            # Teacher forcing: next input is the reference token [Bx1x1].
            # A column slice replaces the per-row Python loop, which extracted
            # one scalar per batch element on every step.
            last_char_index = y[:, pos].double().reshape(batch_size, 1, 1).to(device)

        # NOTE(review): the bare squeeze() drops *every* size-1 dimension, so a
        # batch of one (or a single-step target) loses that axis as well. Kept
        # as-is because callers may rely on the current shapes.
        return torch.stack(probabilities, dim=1).squeeze()
    def decode_eval(self, x, beam_size, alpha, gru_dec, enc_hidden, enc_annotations, max_len):
        """Beam-search decode a single input and return the best hypothesis.

        Decoder inputs and initial states are built with a batch size of one,
        so ``x`` is expected to hold exactly one sequence.

        Each beam element is the tuple
        ``(likelihood, probabilities, last_index, attentions, dec_output,
        dec_hidden, p_gens)``.

        Args:
            x: Encoder input, forwarded to ``self.decoder``; ``x.size(1)`` sizes
                the initial all-zero attention entry.
            beam_size: Number of hypotheses kept after each step, and number of
                candidates expanded from each live hypothesis.
            alpha: Forwarded to ``length_penalty`` when ranking hypotheses.
            gru_dec: Recurrent module forwarded to ``self.decoder``; its
                ``hidden_size`` bounds the slice of the previous decoder output
                that is fed back.
            enc_hidden: Last encoder hidden state, used as the first decoder
                hidden state.
            enc_annotations: Encoder annotations forwarded to ``self.decoder``.
            max_len: Upper bound on the hypothesis length (including ``<sos>``).

        Returns:
            ``(probabilities, attentions, p_gens)`` of the top-ranked
            hypothesis, each stacked along a new leading dimension.
        """
        # One-hot on <sos>, built the same way as in decode_training instead of
        # looping over the whole vocabulary in Python.
        first_prob = torch.zeros(self.vocabulary_size).to(device)
        first_prob[self.sos_token] = 1.

        # beam element: <likelihood, probabilities, last_index, attentions, dec_output, dec_hidden, p_gens>
        beam = [(torch.tensor([0.]).to(device),
                 [first_prob],
                 self.sos_token,
                 [torch.zeros((1, x.size(1))).to(device)],
                 self.decoder.init_first_output(1).float(),
                 # decoder first hidden is last encoder hidden
                 enc_hidden,
                 [decoder.init_first_p_gen(1)])]

        for _ in range(1, max_len):
            # Once every hypothesis has emitted <eos> nothing can change any
            # more: the remaining iterations would only copy and re-sort the
            # same, already ordered beam. Stop early.
            if all(elem[2] == self.eos_token for elem in beam):
                break

            new_elements = []
            for beam_elem in beam:
                # Completed hypotheses are carried over untouched.
                if beam_elem[2] == self.eos_token:
                    new_elements.append(beam_elem)
                    continue
                likelihood, probs, last_index, attentions, dec_output, dec_hidden, p_gens = beam_elem
                # DecoderAndPointer; only the first `hidden_size` features
                # ("only fw") of the previous output are fed back.
                att, dec_output, dec_hidden, prob, _, p_gen = self.decoder(
                    torch.tensor([[[last_index]]], dtype=torch.float64).to(device),
                    dec_output[:, :, :gru_dec.hidden_size],
                    dec_hidden, enc_annotations, x, p_gens[-1],
                    self.embedding, gru_dec)

                # Build new lists so sibling hypotheses do not share history.
                probs = probs + [prob.squeeze()]
                attentions = attentions + [att]
                p_gens = p_gens + [p_gen]

                # Expand beam with the `beam_size` best next tokens.
                best_probs, top_indices = prob.topk(beam_size, 2)
                top_indices = top_indices.squeeze().tolist() if beam_size > 1 else [top_indices.item()]

                # NOTE(review): scores are accumulated by addition, which is the
                # usual beam-search rule only if `prob` holds log-probabilities
                # — confirm what self.decoder returns.
                for i in range(beam_size):
                    next_char = top_indices[i]
                    new_elements.append((likelihood + best_probs[:, :, i].squeeze(), probs, next_char, attentions,
                                         dec_output, dec_hidden, p_gens))

            # Rank by length-normalised score and keep the best `beam_size`.
            new_elements.sort(key=lambda elem: elem[0][0] / length_penalty(len(elem[1]), alpha), reverse=True)
            beam = new_elements[:beam_size]

        _, probabilities, _, attentions, _, _, p_gens = beam[0]
        probabilities = torch.stack(probabilities)
        attentions = torch.stack(attentions)
        p_gens = torch.stack(p_gens)
        return probabilities, attentions, p_gens
Example #3
0
    def forward(self, x, y=None):
        """Encode ``x`` and decode greedily (eval) or with teacher forcing (training).

        Args:
            x: A pair ``(tokens, lengths)``; both are passed to ``self.encoder``
                and ``tokens.size(0)`` is taken as the batch size.
            y: Reference token indices, indexed as ``y[batch, position]``.
                Must be given in training mode and must be ``None`` otherwise.

        Returns:
            In training mode, the per-step distributions stacked along dim 1
            and passed through a bare ``squeeze()``. In eval mode, a tuple of
            that same tensor, the stacked attentions and the stacked
            generation probabilities.

        Raises:
            AssertionError: If the presence of ``y`` does not match
                ``self.training``.
        """
        assert y is not None if self.training else y is None
        # Unpack input
        x, lengths = x
        batch_size = x.size(0)

        # encoder pass (the returned annotation lengths are not needed here)
        enc_annotations, _annotations_len, enc_hidden = self.encoder(x, lengths)

        # decoder first hidden is last encoder hidden: average of the first and
        # second half of enc_hidden (forward and backward pass)
        fw_to_bw = enc_hidden.size(0) // 2
        dec_hidden = 0.5 * (enc_hidden[:fw_to_bw] + enc_hidden[fw_to_bw:])  # [num_dir x B x H]

        # decoder first output (needed for attention) init to zeros
        dec_output = self.decoder.init_first_output(batch_size).float()

        # init first decoder input <sos>
        last_char_index = decoder.init_first_input_index(
            batch_size, self.sos_token).double()

        p_gen = decoder.init_first_p_gen(batch_size)

        # Position 0 of the output is by construction a one-hot on <sos>.
        first_prob = torch.zeros(
            (batch_size, 1, self.vocabulary_size)).to(device)
        first_prob[:, 0, self.sos_token] = 1
        probabilities = [first_prob]

        if self.training:
            max_len = y.size(1)
        else:
            max_len = self.max_string_length
            sentence_end = [False for _ in range(batch_size)]
            attentions = [torch.zeros((batch_size, x.size(1))).to(device)]
            # NOTE(review): the first entry is p_gen[0] while later entries are
            # p_gen.squeeze(2); the final torch.stack only works if those
            # shapes agree — verify for batch_size > 1.
            p_gens = [p_gen[0]]

        for pos in range(1, max_len):
            # DecoderAndPointer
            att, dec_output, dec_hidden, prob, _, p_gen = self.decoder(
                last_char_index, dec_output, dec_hidden, enc_annotations, x,
                p_gen)
            probabilities.append(prob)

            # find next char [Bx1x1]
            if self.training:
                # Teacher forcing: feed the reference token. A column slice
                # replaces the per-row Python loop, which extracted one scalar
                # per batch element on every step.
                last_char_index = y[:, pos].double().reshape(
                    batch_size, 1, 1).to(device)
                continue

            attentions.append(att)
            p_gens.append(p_gen.squeeze(2))

            # Greedy choice: most probable token becomes the next input.
            last_char_index = prob.detach().argmax(2, keepdim=True).double()

            # Update list of ended sentences; a sentence stays ended once it
            # has produced <eos>, even though decoding continues for the rest
            # of the batch.
            for i, ch in enumerate(last_char_index.view(-1).tolist()):
                if ch == self.eos_token:
                    sentence_end[i] = True

            if all(sentence_end):
                break

        # NOTE(review): the bare squeeze() drops *every* size-1 dimension, so a
        # batch of one loses its batch axis. Kept as-is because callers may
        # rely on the current shapes.
        stacked_probabilities = torch.stack(probabilities, dim=1).squeeze()
        if self.training:
            return stacked_probabilities
        return stacked_probabilities, torch.stack(attentions), torch.stack(p_gens)