Example #1
    def forward(self, seqs_tensor: PackedSequence) -> tuple:
        # requires: from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
        unsort_idx = seqs_tensor.unsorted_indices
        if self.cuda:
            # .cuda() returns a new PackedSequence, so the result must be assigned
            seqs_tensor = seqs_tensor.cuda()

        # lstm_output : PackedSequence over (batch, seq_len, num_directions * hidden_size)
        # hn : (num_layers * num_directions, batch, hidden_size)
        # cn : (num_layers * num_directions, batch, hidden_size)
        lstm_output, (hn, cn) = self.rnn(seqs_tensor)
        # pad_packed_sequence restores the original order, so only hn/cn need unsorting
        hn = hn.index_select(1, unsort_idx)
        cn = cn.index_select(1, unsort_idx)
        # reshape hidden/cell states to (batch, -1)
        hn = hn.view(hn.size(1), -1)
        cn = cn.view(cn.size(1), -1)
        hn = self.dropout(hn)
        cn = self.dropout(cn)
        # apply dropout to the packed .data tensor while keeping the packing metadata
        lstm_output = PackedSequence(self.dropout(lstm_output.data),
                                     lstm_output.batch_sizes,
                                     lstm_output.sorted_indices,
                                     lstm_output.unsorted_indices)
        out_tensor, lengths = pad_packed_sequence(lstm_output, batch_first=True)
        return (out_tensor, lengths), hn, cn
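A minimal usage sketch for the forward method above, assuming a hypothetical encoder instance that exposes the rnn, dropout, and cuda attributes the method relies on; names and tensor sizes are illustrative only:

    # Hedged sketch, not part of the original example; `encoder` is a hypothetical
    # instance of the class that defines forward() above.
    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    padded = torch.randn(2, 5, 8)              # batch of 2, max length 5, feature size 8
    lengths = torch.tensor([5, 3])             # true lengths of the two sequences
    packed = pack_padded_sequence(padded, lengths, batch_first=True,
                                  enforce_sorted=False)
    (out, out_lengths), hn, cn = encoder(packed)
    # out: (2, 5, num_directions * hidden_size)
    # hn, cn: (2, num_layers * num_directions * hidden_size)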
Example #2
    def translate(self, input_sentences, target_priming=None):
        """input_sentences is either a string or a list of strings."""
        # requires: from torch.autograd import Variable
        #           from torch.nn.utils.rnn import PackedSequence
        # batch_sequences is a project-level helper that batches (and optionally packs) token tensors
        if isinstance(input_sentences, list):
            flatten = False
        else:
            input_sentences = [input_sentences]
            flatten = True
        batch = len(input_sentences)
        src_tok = [
            self.src_tok.tokenize(sentence,
                                  insert_start=self.insert_src_start,
                                  insert_end=self.insert_src_end)
            for sentence in input_sentences
        ]
        if target_priming is None:
            bos = [self.insert_target_start] * batch
        else:
            if isinstance(target_priming, list):
                bos = [
                    list(
                        self.target_tok.tokenize(
                            priming, insert_start=self.insert_target_start))
                    for priming in target_priming
                ]
            else:
                bos = self.target_tok.tokenize(
                    target_priming, insert_start=self.insert_target_start)
                bos = [list(bos)] * batch

        order = range(batch)
        if self.pack_encoder_inputs:
            # sort source sequences by decreasing length, as required for packing,
            # and remember how to map back to the original order
            sorted_idx, src_tok = zip(*sorted(
                enumerate(src_tok), key=lambda x: x[1].numel(), reverse=True))
            order = [sorted_idx.index(i) for i in order]

        src = batch_sequences(src_tok,
                              sort=False,
                              pack=self.pack_encoder_inputs,
                              batch_first=self.model.encoder.batch_first)[0]

        # Allow packed source sequences - for cuDNN RNNs
        if isinstance(src, PackedSequence):
            src_var = Variable(src[0].cuda() if self.cuda else src[0],
                               volatile=True)
            src = PackedSequence(src_var, src[1])
        elif self.cuda:
            src = Variable(src.cuda(), volatile=True)
        context = self.model.encode(src)

        if hasattr(self.model, 'bridge'):
            state = self.model.bridge(context)
        else:
            # no bridge layer: pass the encoder context through unchanged
            state = context

        state_list = [state[idx] for idx in order]
        seqs = self.generator.beam_search(bos, state_list)
        # remove the forced target-start tokens
        preds = [s.sentence[len(self.insert_target_start):] for s in seqs]
        output = [self.target_tok.detokenize(p[:-1]) for p in preds]

        output = output[0] if flatten else output
        if self.get_attention:
            attentions = [s.attention for s in seqs]
            # if target_priming is not None:
            # preds = [preds[b][-len(attentions[b]):] for b in range(batch)]
            attentions = attentions[0] if flatten else attentions

            preds = [[self.target_tok.idx2word(idx) for idx in p]
                     for p in preds]
            preds = preds[0] if flatten else preds
            src = [[self.src_tok.idx2word(idx) for idx in list(s)]
                   for s in src_tok]
            src = src[0] if flatten else src
            return output, (attentions, src, preds)
        else:
            return output
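A small usage sketch for the translate method above, assuming a hypothetical, already constructed translator object (model loading and tokenizer setup are outside this example):

    # Hedged sketch, not part of the original example; `translator` is a hypothetical
    # instance of the class that defines translate() above.
    single = translator.translate("How are you?")               # -> one detokenized string
    batch = translator.translate(["How are you?", "Hello."])    # -> list of strings
    primed = translator.translate("How are you?",
                                  target_priming="Wie")         # prime the decoder output
    # if translator.get_attention is set, translate() also returns (attentions, src, preds)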