def generate(self, input_list, state_list, k=1, feed_all_timesteps=False, get_attention=False):
        # input_list is expected to be a list/tuple of token-id sequences,
        # one per active hypothesis, matching the entries of state_list

        view_shape = (-1, 1) if self.decoder.batch_first else (1, -1)
        time_dim = 1 if self.decoder.batch_first else 0

        # For recurrent models only the last input token matters;
        # pass feed_all_timesteps=True whenever the whole input must be fed
        if feed_all_timesteps:
            inputs = [torch.LongTensor(inp) for inp in input_list]
            inputs = batch_sequences(
                inputs, batch_first=self.decoder.batch_first)[0]
        else:
            inputs = torch.LongTensor(
                [inputs[-1] for inputs in input_list]).view(*view_shape)

        # pre-0.4 PyTorch API: wrap in a volatile Variable to disable autograd
        inputs_var = Variable(inputs, volatile=True)
        if next(self.decoder.parameters()).is_cuda:
            inputs_var = inputs_var.cuda()

        states = State().from_list(state_list)
        logits, new_states = self.decode(
            inputs_var, states, get_attention=get_attention)
        # use only last prediction
        logits = logits.select(time_dim, -1).contiguous()
        logprobs = log_softmax(logits.view(-1, logits.size(-1)), dim=-1)
        logprobs, words = logprobs.data.topk(k, 1)
        new_states_list = [new_states[i] for i in range(len(input_list))]
        return words, logprobs, new_states_list
Example #2
    def _decode_step(self, input_list, state_list, k=1,
                     feed_all_timesteps=False,
                     remove_unknown=False,
                     get_attention=False,
                     device_ids=None):

        view_shape = (-1, 1) if self.decoder.batch_first else (1, -1)
        time_dim = 1 if self.decoder.batch_first else 0
        device = next(self.decoder.parameters()).device

        # For recurrent models only the last input token matters;
        # pass feed_all_timesteps=True whenever the whole input must be fed
        if feed_all_timesteps:
            inputs = [torch.tensor(inp, device=device, dtype=torch.long)
                      for inp in input_list]
            inputs = batch_sequences(
                inputs, device=device, batch_first=self.decoder.batch_first)[0]

        else:
            last_tokens = [inputs[-1] for inputs in input_list]
            inputs = torch.stack(last_tokens).view(*view_shape)

        states = State().from_list(state_list)
        logits, new_states = self.decode(inputs, states,
                                         get_attention=get_attention,
                                         device_ids=device_ids)
        # use only last prediction
        logits = logits.select(time_dim, -1).contiguous()
        if remove_unknown:
            # Remove possibility of unknown
            logits[:, UNK].fill_(-float('inf'))
        logprobs = log_softmax(logits, dim=1)
        logprobs, words = logprobs.topk(k, 1)
        new_states_list = [new_states[i] for i in range(len(input_list))]
        return words, logprobs, new_states_list
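As a usage illustration, here is a minimal sketch of how a beam-search loop might consume _decode_step; the model handle, the beam bookkeeping and the scoring scheme below are assumptions for illustration, not part of the snippet above.

# Hypothetical single beam-search expansion built on the _decode_step above (sketch only).
# `model` is assumed to expose that method; each beam is a (token_ids, state, score) tuple
# whose token_ids is a plain Python list of ints.
def expand_beams(model, beams, k=3):
    input_list = [tokens for tokens, _, _ in beams]
    state_list = [state for _, state, _ in beams]
    words, logprobs, new_states = model._decode_step(
        input_list, state_list, k=k, feed_all_timesteps=True)
    candidates = []
    for b, (tokens, _, score) in enumerate(beams):
        for j in range(k):
            candidates.append((tokens + [words[b, j].item()],
                               new_states[b],
                               score + logprobs[b, j].item()))
    # keep only the k best-scoring hypotheses
    candidates.sort(key=lambda c: c[2], reverse=True)
    return candidates[:k]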
Example #3
    def _decode_step(self, input_list, state_list, k=1,
                     feed_all_timesteps=False,
                     remove_unknown=False,
                     get_attention=False):

        view_shape = (-1, 1) if self.decoder.batch_first else (1, -1)
        time_dim = 1 if self.decoder.batch_first else 0
        device = next(self.decoder.parameters()).device

        # For recurrent models only the last input token matters;
        # pass feed_all_timesteps=True whenever the whole input must be fed
        if feed_all_timesteps:
            inputs = [torch.tensor(inp, device=device, dtype=torch.long)
                      for inp in input_list]
            inputs = batch_sequences(
                inputs, device=device, batch_first=self.decoder.batch_first)[0]

        else:
            last_tokens = [inputs[-1] for inputs in input_list]
            inputs = torch.stack(last_tokens).view(*view_shape)

        states = State().from_list(state_list)
        logits, new_states = self.decode(
            inputs, states, get_attention=get_attention)
        # use only last prediction
        logits = logits.select(time_dim, -1).contiguous()
        if remove_unknown:
            # Remove possibility of unknown
            logits[:, UNK].fill_(-float('inf'))
        logprobs = log_softmax(logits, dim=1)
        logprobs, words = logprobs.topk(k, 1)
        new_states_list = [new_states[i] for i in range(len(input_list))]
        return words, logprobs, new_states_list
Example #4
 def collate(img_seq_tuple):
     if sort or pack:  # packing requires a sorted batch by length
         img_seq_tuple.sort(key=lambda p: len(p[1]), reverse=True)
     imgs, seqs = zip(*img_seq_tuple)
     imgs = torch.cat([img.unsqueeze(0) for img in imgs], 0)
     seq_tensor = batch_sequences(seqs, max_length=max_length,
                                  batch_first=batch_first,
                                  sort=False, pack=pack)
     return (imgs, seq_tensor)
Example #5
 def collate(img_seq_tuple):
     if sort or pack:  # packing requires a sorted batch by length
         img_seq_tuple.sort(key=lambda p: len(p[1]), reverse=True)
     imgs, seqs = zip(*img_seq_tuple)
     imgs = torch.cat([img.unsqueeze(0) for img in imgs], 0)
     seq_tensor = batch_sequences(seqs, max_length=max_length,
                                  max_tokens=max_tokens,
                                  batch_first=batch_first,
                                  sort=False, pack=pack, augment=augment)
     return (imgs, seq_tensor)
Example #6
 def collate(seqs, sort=sort, pack=pack):
     if not torch.is_tensor(seqs[0]):
         if sort or pack:  # packing requires a sorted batch by length
             # sort by the first set
             seqs.sort(key=lambda x: len(x[0]), reverse=True)
         # TODO: for now, just the first input will be packed
         return tuple([collate(s, sort=False, pack=pack and (i == 0))
                       for i, s in enumerate(zip(*seqs))])
     return batch_sequences(seqs, max_length=max_length,
                            max_tokens=max_tokens,
                            batch_first=batch_first,
                            sort=False, pack=pack, augment=augment)
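The collate above is a closure over sort, pack, max_length, max_tokens, batch_first and augment, which are defined in an enclosing scope this page does not show. The following self-contained sketch illustrates one way such a closure might be built and applied to (source, target) pairs; the build_collate wrapper, its defaults, and the sample data are assumptions for illustration, not library code.

# Hypothetical closure construction around the nested collate above (sketch only).
import torch
from seq2seq.tools import batch_sequences

def build_collate(max_length=None, max_tokens=None,
                  batch_first=True, sort=True, pack=False, augment=False):
    def collate(seqs, sort=sort, pack=pack):
        if not torch.is_tensor(seqs[0]):
            if sort or pack:  # packing requires a batch sorted by length of the first field
                seqs.sort(key=lambda x: len(x[0]), reverse=True)
            # only the first field is ever packed
            return tuple(collate(s, sort=False, pack=pack and (i == 0))
                         for i, s in enumerate(zip(*seqs)))
        return batch_sequences(seqs, max_length=max_length,
                               max_tokens=max_tokens,
                               batch_first=batch_first,
                               sort=False, pack=pack, augment=augment)
    return collate

pairs = [(torch.LongTensor([4, 5, 6, 7]), torch.LongTensor([8, 9])),
         (torch.LongTensor([1, 2]), torch.LongTensor([3]))]
src_batch, tgt_batch = build_collate()(pairs)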
Example #7
 def collate(img_seq_tuple):
     if sort or pack:  # packing requires a sorted batch by length
         img_seq_tuple.sort(key=lambda p: len(p[1]), reverse=True)
     imgs, seqs = zip(*img_seq_tuple)
     imgs = torch.stack(imgs, 0)
     seq_tensor = batch_sequences(seqs,
                                  max_length=max_length,
                                  max_tokens=max_tokens,
                                  batch_first=batch_first,
                                  sort=False,
                                  pack=pack,
                                  augment=augment)
     return (imgs, seq_tensor)
Example #8
 def get_loader(self, sort=False, pack=False,
                batch_size=1, shuffle=False, sampler=None, num_workers=0,
                max_length=None, batch_first=False, pin_memory=False, drop_last=False):
     collate_fn = lambda seqs: batch_sequences(seqs, max_length=max_length,
                                               batch_first=batch_first,
                                               sort=sort, pack=pack)
     return torch.utils.data.DataLoader(self,
                                        batch_size=batch_size,
                                        collate_fn=collate_fn,
                                        sampler=sampler,
                                        shuffle=shuffle,
                                        num_workers=num_workers,
                                        pin_memory=pin_memory,
                                        drop_last=drop_last)
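A hypothetical call site for get_loader (sketch); dataset is assumed to be an instance of the class defining the method and to yield LongTensor token sequences, and the tuple-shaped batch (padded ids first) is inferred from how the other snippets on this page index batch_sequences(...)[0].

# Hypothetical usage of get_loader (sketch only); `dataset` and the loop body are assumptions.
loader = dataset.get_loader(batch_size=32, shuffle=True,
                            max_length=100, batch_first=True,
                            num_workers=4)
for batch in loader:
    padded_ids = batch[0]  # padded token ids; remaining fields carry length information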
Example #10
    def translate(self, input_sentences, target_priming=None):
        """input_sentences is either a string or list of strings"""
        if isinstance(input_sentences, list):
            flatten = False
        else:
            input_sentences = [input_sentences]
            flatten = True
        batch = len(input_sentences)
        src_tok = [self.src_tok.tokenize(sentence,
                                         insert_start=self.insert_src_start,
                                         insert_end=self.insert_src_end)
                   for sentence in input_sentences]

        order = range(batch)
        if self.pack_encoder_inputs:
            # sort by the first set
            sorted_idx, src_tok = zip(*sorted(
                enumerate(src_tok), key=lambda x: x[1].numel(), reverse=True))
            order = [sorted_idx.index(i) for i in order]

        if target_priming is None:
            bos = [self.insert_target_start] * batch
        else:
            if isinstance(target_priming, list):
                bos = [list(self.target_tok.tokenize(target_priming[i],
                                                     insert_start=self.insert_target_start))
                       for i in order]
            else:
                bos = self.target_tok.tokenize(target_priming,
                                               insert_start=self.insert_target_start)
                bos = [list(bos)] * batch

        src = batch_sequences(src_tok,
                              sort=False,
                              pack=self.pack_encoder_inputs,
                              device=self.device,
                              batch_first=self.model.encoder.batch_first)[0]

        with torch.no_grad():
            seqs = self.model.generate(src, bos,
                                       beam_size=self.beam_size,
                                       max_sequence_length=self.max_sequence_length,
                                       length_normalization_factor=self.length_normalization_factor,
                                       get_attention=self.get_attention, devices=None)
        # remove the forced (priming) tokens
        preds = [s.output[len(self.insert_target_start):] for s in seqs]
        output = [self.target_tok.detokenize(p[:-1]) for p in preds]

        output = output[0] if flatten else output
        if self.get_attention:
            attentions = [s.attention for s in seqs]
            # if target_priming is not None:
            # preds = [preds[b][-len(attentions[b]):] for b in range(batch)]
            attentions = attentions[0] if flatten else attentions

            preds = [[self.target_tok.idx2word(
                idx) for idx in p] for p in preds]
            preds = preds[0] if flatten else preds
            src = [[self.src_tok.idx2word(idx)
                    for idx in list(s)] for s in src_tok]
            src = src[0] if flatten else src
            return output, (attentions, src, preds)
        else:
            return output
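For context, hypothetical call sites for translate (sketch); translator is assumed to be an instance of the class above, constructed elsewhere with get_attention=False.

# Hypothetical calls to translate (sketch only); `translator` is an assumed instance.
single = translator.translate("the cat sat on the mat")           # str in -> str out
several = translator.translate(["first sentence", "second one"])  # list in -> list out
primed = translator.translate("guten morgen", target_priming="Good")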
Example #11
import torch
from seq2seq.tools import batch_sequences

s1 = torch.LongTensor([1, 2, 3, 4, 5, 6])
s2 = torch.LongTensor([10, 20, 30])

seqs = [s1, s2]
batch = batch_sequences(seqs, max_length=4, augment=True)
print(batch)
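For readers without the library at hand, the following is a rough, illustrative stand-in for the padding side of batch_sequences; the real seq2seq.tools helper also handles max_tokens, sorting, packing and augmentation, and the padding index of 0 is an assumption.

# Rough stand-in for the padding behaviour batch_sequences is used for above (sketch only).
import torch

PAD = 0  # assumed padding index

def pad_batch(seqs, max_length=None, batch_first=True):
    lengths = [min(len(s), max_length) if max_length else len(s) for s in seqs]
    max_len = max(lengths)
    shape = (len(seqs), max_len) if batch_first else (max_len, len(seqs))
    out = torch.full(shape, PAD, dtype=torch.long)
    for i, (s, n) in enumerate(zip(seqs, lengths)):
        if batch_first:
            out[i, :n] = s[:n]
        else:
            out[:n, i] = s[:n]
    return out, torch.tensor(lengths)

padded, lengths = pad_batch([torch.LongTensor([1, 2, 3, 4, 5, 6]),
                             torch.LongTensor([10, 20, 30])], max_length=4)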
Example #12
    def translate(self, input_sentences, target_priming=None):
        """input_sentences is either a string or list of strings"""
        if isinstance(input_sentences, list):
            flatten = False
        else:
            input_sentences = [input_sentences]
            flatten = True
        batch = len(input_sentences)
        src_tok = [
            self.src_tok.tokenize(sentence,
                                  insert_start=self.insert_src_start,
                                  insert_end=self.insert_src_end)
            for sentence in input_sentences
        ]
        if target_priming is None:
            bos = [self.insert_target_start] * batch
        else:
            if isinstance(target_priming, list):
                bos = [
                    list(
                        self.target_tok.tokenize(
                            priming, insert_start=self.insert_target_start))
                    for priming in target_priming
                ]
            else:
                bos = self.target_tok.tokenize(
                    target_priming, insert_start=self.insert_target_start)
                bos = [list(bos)] * batch

        order = range(batch)
        if self.pack_encoder_inputs:
            # sort by the first set
            sorted_idx, src_tok = zip(*sorted(
                enumerate(src_tok), key=lambda x: x[1].numel(), reverse=True))
            order = [sorted_idx.index(i) for i in order]

        src = batch_sequences(src_tok,
                              sort=False,
                              pack=self.pack_encoder_inputs,
                              batch_first=self.model.encoder.batch_first)[0]

        # Allow packed source sequences - for cudnn rnns
        if isinstance(src, PackedSequence):
            src_var = Variable(src[0].cuda() if self.cuda else src[0],
                               volatile=True)
            src = PackedSequence(src_var, src[1])
        else:
            src = Variable(src.cuda() if self.cuda else src, volatile=True)
        context = self.model.encode(src)

        if hasattr(self.model, 'bridge'):
            state = self.model.bridge(context)

        state_list = [state[idx] for idx in order]
        seqs = self.generator.beam_search(bos, state_list)
        # remove the forced (priming) tokens
        preds = [s.sentence[len(self.insert_target_start):] for s in seqs]
        output = [self.target_tok.detokenize(p[:-1]) for p in preds]

        output = output[0] if flatten else output
        if self.get_attention:
            attentions = [s.attention for s in seqs]
            # if target_priming is not None:
            # preds = [preds[b][-len(attentions[b]):] for b in range(batch)]
            attentions = attentions[0] if flatten else attentions

            preds = [[self.target_tok.idx2word(idx) for idx in p]
                     for p in preds]
            preds = preds[0] if flatten else preds
            src = [[self.src_tok.idx2word(idx) for idx in list(s)]
                   for s in src_tok]
            src = src[0] if flatten else src
            return output, (attentions, src, preds)
        else:
            return output
Example #13
 def collate_fn(seqs):
     return batch_sequences(seqs,
                            max_length=max_length,
                            batch_first=batch_first,
                            sort=sort,
                            pack=pack)
Example #14
    def translate(self, input_sentences, target_priming=None):
        """input_sentences is either a string or list of strings"""
        if isinstance(input_sentences, list):
            flatten = False
        else:
            input_sentences = [input_sentences]
            flatten = True
        batch = len(input_sentences)
        src_tok = [
            self.src_tok.tokenize(sentence,
                                  insert_start=self.insert_src_start,
                                  insert_end=self.insert_src_end)
            for sentence in input_sentences
        ]
        if target_priming is None:
            bos = [self.insert_target_start] * batch
        else:
            if isinstance(target_priming, list):
                bos = [
                    list(
                        self.target_tok.tokenize(
                            priming, insert_start=self.insert_target_start))
                    for priming in target_priming
                ]
            else:
                bos = self.target_tok.tokenize(
                    target_priming, insert_start=self.insert_target_start)
                bos = [list(bos)] * batch

        src = batch_sequences(src_tok,
                              batch_first=self.model.encoder.batch_first)[0]
        # pre-0.4 PyTorch API: wrap in a volatile Variable to disable autograd
        src = Variable(src, volatile=True)
        if self.cuda:
            src = src.cuda()

        context = self.model.encode(src)
        if hasattr(self.model, 'bridge'):
            state = self.model.bridge(context)
        state_list = [state[i] for i in range(batch)]

        seqs = self.generator.beam_search(bos, state_list)
        # remove the forced (priming) tokens
        preds = [s.sentence[len(self.insert_target_start):] for s in seqs]
        output = [self.target_tok.detokenize(p[:-1]) for p in preds]

        output = output[0] if flatten else output
        if self.get_attention:
            attentions = [s.attention for s in seqs]
            # if target_priming is not None:
            # preds = [preds[b][-len(attentions[b]):] for b in range(batch)]
            attentions = attentions[0] if flatten else attentions

            preds = [[self.target_tok.idx2word(idx) for idx in p]
                     for p in preds]
            preds = preds[0] if flatten else preds
            src = [[self.src_tok.idx2word(idx) for idx in list(s)]
                   for s in src_tok]
            src = src[0] if flatten else src
            return output, (attentions, src, preds)
        else:
            return output