Example #1
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1

                seq_len, time = self.__series_2_bptt(batch)
                dataset = Dataset(examples=self.dataset.examples, fields=[
                    ('time', self.dataset.fields['time'])])

                yield (Batch.fromvars(
                        dataset, self.batch_size,
                        time=(ti, l)) for ti, l in zip(time, seq_len))

                if not self.repeat:
                    return
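Note that this __iter__ yields a generator of windowed Batch objects per minibatch rather than a single Batch. A minimal consumption sketch, assuming bptt_iter is an instance of the iterator class above (the loop variables and field handling are illustrative):

    for windows in bptt_iter:             # one generator per minibatch
        for window_batch in windows:      # one Batch per BPTT window
            times, lengths = window_batch.time
            # feed one BPTT window to the model here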
Example #2
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                # open question: should we yield many batches, or one long batch with many windows?
                seq_len, text, time = self.__series_2_bptt(batch)
                dataset = Dataset(examples=self.dataset.examples, fields=[
                    ('time', self.dataset.fields['time']), ('text', self.dataset.fields['text']),
                    ('target_time', Field(use_vocab=False)), ('target_text', Field(use_vocab=False))])
                yield (Batch.fromvars(
                        dataset, self.batch_size,
                        time=(ti[:, :, :2], l),
                        text=te[:, :, 0],
                        target_time=ti[:, :, -1],
                        target_text=te[:, :, 1]) for ti, te, l in zip(time, text, seq_len))

            if not self.repeat:
                return
Example #3
def get_interactive_batch(args, vid, question, iters):
    batch = []
    for key, it in iters.items():
        for ex in it.it.dataset.examples:
            if ex.vid == vid:
                batch.append((key, ex))
    tokenizer = get_tokenizer(args)
    # pick the stored example whose question is most similar to the query
    key, ex = batch[check_que_sim(tokenizer, question, batch)]
    ex = Batch([ex], iters[key].it.dataset)
    ex.images = iters[key].get_image(ex.vid)

    return ex
Example #4
    def next_example(self, example=None):
        # the mini-batch only contains one example due to batch_size == 1
        if example is None:
            self.raw_example_batch = next(self.env_data_iter)
        else:
            self.raw_example_batch = Batch(data=[example],
                                           dataset=self.train_set,
                                           device=self.device)

        # keep example ids so predictions can be traced back to their data sources
        self.raw_example_ids = self.raw_example_batch.Id.tolist()

        # get raw prediction probability
        _, input_batch, input_lengths, _ = self.prepare_relation_mini_batch(
            self.device, self.raw_example_batch)
        self.net.eval()
        with torch.no_grad():
            input_emb_seq, pred_logp = self.net.get_init_state_info(
                input_batch, input_lengths)

        # set input embedding (word embedding + position embedding)
        self.input_emb_seq = input_emb_seq
        self.input_seq_lens = input_lengths

        # raw prediction log-probability of the target class (index 1)
        self.raw_tgt_logp = pred_logp[:, 1]
        # raw prediction log-probabilities for all classes
        self.raw_pred_logps = pred_logp.tolist()
Example #5
    def __iter__(self):
        text = self.dataset[0].src
        TEXT = self.dataset.fields['src']
        TEXT.eos_token = None
        text = text + ([TEXT.pad_token] * int(
            math.ceil(len(text) / self.batch_size) * self.batch_size -
            len(text)))
        data = TEXT.numericalize([text], device=self.device)
        data = data.view(self.batch_size, -1).t().contiguous()
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('src', TEXT), ('tgt', TEXT)])

        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                seq_len = min(self.bptt_len, len(data) - i - 1)
                batch_text = data[i:i + seq_len]
                batch_target = data[i + 1:i + 1 + seq_len]
                if TEXT.batch_first:
                    batch_text = batch_text.t().contiguous()
                    batch_target = batch_target.t().contiguous()
                yield Batch.fromvars(dataset,
                                     self.batch_size,
                                     src=batch_text,
                                     tgt=batch_target)
            if not self.repeat:
                return
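A minimal training-loop sketch for consuming the src/tgt batches yielded above; lm_iter, model, criterion, optimizer, and max_steps are illustrative names for an instance of this iterator and a standard PyTorch language-model setup (max_steps matters because repeat=True loops forever):

    for step, batch in enumerate(lm_iter):
        logits = model(batch.src)                       # (seq_len, batch, vocab)
        loss = criterion(logits.view(-1, logits.size(-1)),
                         batch.tgt.view(-1))            # next-token targets
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step + 1 >= max_steps:
            break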
Example #6
    def __text_bptt_2_minibatches(self, batch, batch_size, dataset):
        time, time_len, text, text_len = self.__text_series_2_bptt(batch, batch_size)
        yield (Batch.fromvars(dataset, batch_size,
                              time=(ti[:, :, :2], ti_l),
                              text=(te, te_l),
                              target_time=ti[:, :, -1])
               for ti, ti_l, te, te_l in zip(time, time_len, text, text_len))
Example #7
    def __bow_bptt_2_minibatches(self, batch, batch_size, dataset):
        time, time_len, bow = self.__bow_series_2_bptt(batch, batch_size)
        yield (Batch.fromvars(dataset, batch_size,
                              time=(t[:, :, :2], l),
                              bow=b[:, :, 0],
                              target_time=t[:, :, -1],
                              target_bow=b[:, :, -1])
               for t, l, b in zip(time, time_len, bow))
Example #8
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                # open question: should we yield many batches, or one long batch with many windows?
                batch_size: int = len(batch)

                dataset = Dataset(examples=self.dataset.examples, fields=self.fields)
                if self.train:
                    yield from self.__bptt_2_minibatch(batch, batch_size, dataset)
                else:
                    if 'bow' in dataset.fields.keys():
                        batch.target_bow = batch.bow[:, :, 1]
                        batch.bow = batch.bow[:, :, 0]
                    else:
                        batch.text = (batch.text[0], batch.text[2])
                    batch.target_time = batch.time[0][:, :, -1]
                    batch.time = (batch.time[0][:, :, :2], batch.time[1])
                    yield batch

            if not self.repeat:
                return
Example #9
    def __iter__(self):
        text_data, valid_noise, dataset = self.__generate_text_data()
        for b in range(len(self)):
            start, end, batch_noised, batch_clean = self.__generate_batch(
                b, len(self.dataset[0].text), text_data, valid_noise)
            yield Batch.fromvars(dataset,
                                 end - start,
                                 noised=batch_noised,
                                 clean=batch_clean)
Example #10
def create_homogenous_batches(dataset: Dataset, max_batch_size: int,
                              key_fn=lambda ex: -len(ex.src),
                              filter_fn=lambda ex: len(ex.src) > 0,
                              sort: bool = True) -> List[Batch]:
    """
    Creates a list of batches such that for each batch b it holds that:
        - b contains at least one and at most max_batch_size examples
        - for any two examples e1, e2 in b: key_fn(e1) == key_fn(e2)
        - b does not contain any example e for which filter_fn(e) == False
    In addition, the batches b are sorted by (increasing) key_fn(e) for any e in b.

    Args:
        dataset: the dataset to take the batches from
        max_batch_size: how many examples one batch may contain at the most
        key_fn: function of type (Example) -> int, that is used to sort the batches. Each batch will only have examples
            that all have exactly the same key (e.g. source sentence length).
        filter_fn: function of type (Example) -> bool, that is used to filter the examples. No example e with
            filter_fn(e) == False will be contained in any batch.
        sort: whether or not to sort the examples (by the given key_fn)

    Returns: a list of batches with the above properties
    """
    sorted_examples = sorted(dataset.examples, key=key_fn) if sort else dataset.examples

    same_key_blocks = []
    previous_key = None  # sentinel: no key seen yet (keys themselves may be negative)
    current_block = []

    for example in sorted_examples:
        if not filter_fn(example):
            continue

        key = key_fn(example)
        if previous_key is None or key != previous_key:
            previous_key = key
            # start a new block
            if len(current_block) > 0:
                same_key_blocks.append(current_block)
                current_block = []
        # append current example to corresponding block
        current_block.append(example)

    # append last block
    if len(current_block) > 0:
        same_key_blocks.append(current_block)

    # split up blocks in batches of size at most max_batch_size
    batches = []
    for block in same_key_blocks:
        i = 0
        while i < len(block):
            # take the next at most max_batch_size examples from this block
            data = block[i:i + max_batch_size]
            batches.append(Batch(data=data, dataset=dataset))
            i += len(data)

    return batches
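A short usage sketch, assuming train_set is a torchtext Dataset whose examples carry a tokenized src field (the names are illustrative):

    # With the default key_fn (negative source length), batches are ordered from
    # longest to shortest source, and every batch is padding-free because all of
    # its examples share exactly the same source length.
    batches = create_homogenous_batches(train_set, max_batch_size=64)
    for batch in batches:
        src = batch.src  # uniform length within each batch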
Example #11
    def get_single_rnd_batch(self):
        text_data, valid_noise, dataset = self.__generate_text_data()
        b = rnd.randint(0, len(self))
        start, end, batch_noised, batch_clean = self.__generate_batch(
            b, len(self.dataset[0].text), text_data, valid_noise)
        return Batch.fromvars(dataset,
                              end - start,
                              noised=batch_noised,
                              clean=batch_clean)
Example #12
    def get_single_rnd_batch(self):
        b = self.noised_win_iter.get_single_rnd_batch()
        middle_start, middle_end = self.__compute_middle_range()
        holed_noised, clean_middle = self.__hole_batch(b, middle_start,
                                                       middle_end)
        return Batch.fromvars(b.dataset,
                              b.batch_size,
                              noised=holed_noised,
                              clean=clean_middle)
Example #13
    def __iter__(self):
        middle_start, middle_end = self.__compute_middle_range()
        for b in self.noised_win_iter:
            holed_noised, clean_middle = self.__hole_batch(
                b, middle_start, middle_end)
            yield Batch.fromvars(b.dataset,
                                 b.batch_size,
                                 noised=holed_noised,
                                 clean=clean_middle)
Example #14
    def __iter__(self):
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        CHARS = self.dataset.fields['chars']
        TEXT.eos_token = None
        text = text + ([TEXT.pad_token] * int(
            math.ceil(len(text) / self.batch_size) * self.batch_size -
            len(text)))
        data = TEXT.numericalize([text], device=self.device, train=self.train)

        char_data = [
            CHARS.numericalize([t], device=self.device, train=self.train)
            for t in text
        ]

        max_chars = max([c.shape[0] for c in char_data])

        pad_idx = CHARS.vocab.stoi['<pad>']

        for i, c in enumerate(char_data):
            if c.shape[0] < max_chars:
                # pad with the char-level pad index on the same device as `c`
                padding = torch.full((max_chars - c.shape[0], 1),
                                     pad_idx,
                                     dtype=torch.long,
                                     device=c.device)
                char_data[i] = torch.cat((c, padding), dim=0)

        char_data = torch.stack(char_data)

        char_data = char_data.view(self.batch_size, -1,
                                   max_chars).permute(1, 0, 2).contiguous()

        data = data.view(self.batch_size, -1).t().contiguous()

        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('text', TEXT), ('chars', CHARS),
                                  ('target', TEXT)])

        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                seq_len = min(self.bptt_len, len(data) - i - 1)
                yield Batch.fromvars(dataset,
                                     self.batch_size,
                                     train=self.train,
                                     text=data[i:i + seq_len],
                                     chars=char_data[i:i + seq_len],
                                     target=data[i + 1:i + 1 + seq_len])
            if not self.repeat:
                return
Example #15
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                yield minibatch, Batch(minibatch, self.dataset, self.device)
            if not self.repeat:
                return
Example #16
    def __iter__(self):
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None

        while True:
            for _ in range(
                    int(self.dataset.numberOfTokens / self.bsz /
                        self.bptt_len)):
                text = next(self.dataset.text)

                data = TEXT.numericalize([text], device=self.device)
                data = data.view(self.bsz, -1).t().contiguous()
                dataset = Dataset(examples=self.dataset.examples,
                                  fields=[('text', TEXT), ('target', TEXT)])
                yield Batch.fromvars(dataset,
                                     self.bsz,
                                     train=self.train,
                                     text=data[:, :-1],
                                     target=data[:, 1:])
            if not self.repeat:
                return
Example #17
def corrupt(batch, config, inputs, eos_id=2, pad_id=1, noise_level=0.5):
    # Build a random tensor of token ids drawn from inputs.vocab, sample a 0/1
    # mask with probability noise_level, and mix batch.hypothesis with the
    # random tensor everywhere except at EOS and pad positions. Return an
    # updated "batch" whose hypothesis is replaced; batch.label stays the same.
    is_eos_or_pad_mask = (batch.hypothesis == eos_id).long() + (
        batch.hypothesis == pad_id).long()
    bernoulli_mask = (torch.rand(batch.hypothesis.size(),
                                 device=batch.hypothesis.device) <
                      noise_level).long()
    if config.noise_type == "uniform":
        replacement_indices = torch.randint(
            low=3,
            high=len(inputs.vocab),
            size=tuple(batch.hypothesis.size()),
            dtype=batch.hypothesis.dtype,
            device=batch.hypothesis.device).long()
    elif config.noise_type == "unigram" or config.noise_type == "uniroot":
        replacement_indices = torch.multinomial(
            config.freq_list_tensor.unsqueeze(dim=0).expand(
                (batch.hypothesis.size()[0], -1)),
            batch.hypothesis.size()[1],
            replacement=True) + 3
        device = torch.device('cuda:{}'.format(config.gpu))
        replacement_indices = replacement_indices.to(device)
    corrupted_hypothesis = (
        1 - bernoulli_mask
    ) * batch.hypothesis + bernoulli_mask * replacement_indices
    corrupted_hypothesis = is_eos_or_pad_mask * batch.hypothesis + (
        1 - is_eos_or_pad_mask) * corrupted_hypothesis
    # Build and return the corrupted batch using the corrupted hypothesis.
    corrupted_batch = Batch.fromvars(batch.dataset, batch.batch_size)
    corrupted_batch.label = batch.label
    corrupted_batch.hypothesis = corrupted_hypothesis
    return corrupted_batch
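A brief usage sketch; train_iter, model, and criterion are illustrative names, while config, inputs, and the hypothesis/label fields follow the example above:

    for batch in train_iter:
        # replace roughly half of the hypothesis tokens, leaving EOS/pad intact
        noisy = corrupt(batch, config, inputs, noise_level=0.5)
        logits = model(noisy.hypothesis)
        loss = criterion(logits, noisy.label)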
Example #18
    def __iter__(self):
        valid_noise = torch.tensor(
            list(range(2, len(self.dataset.fields['text'].vocab))),
            dtype=torch.long,
            device=self.device)
        text_field = self.dataset.fields['text']
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('noised', text_field),
                                  ('clean', text_field)])
        if self.shuffle:
            shuffled_examples = [
                self.dataset.examples[i]
                for i in torch.randperm(len(self.dataset))
            ]
        else:
            shuffled_examples = self.dataset.examples
        for b in range(len(self)):
            actual_batch_size = min(self.batch_size,
                                    len(self.dataset) - b * self.batch_size)
            batch_clean = torch.zeros([actual_batch_size, self.window_size],
                                      dtype=torch.long,
                                      device=self.device)
            for i in range(actual_batch_size):
                example_index = b * self.batch_size + i
                batch_clean[i] = text_field.numericalize(
                    shuffled_examples[example_index].text).squeeze(1)
            if self.noise_ratio > 0:
                batch_noised = binary_noise_char_input(batch_clean,
                                                       valid_noise,
                                                       self.noise_ratio)
            else:
                batch_noised = batch_clean
            yield Batch.fromvars(dataset,
                                 actual_batch_size,
                                 noised=batch_noised,
                                 clean=batch_clean)
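binary_noise_char_input is not shown in this example. A minimal sketch of what such a helper might look like, assuming it replaces a noise_ratio fraction of positions with tokens drawn uniformly from valid_noise (an assumption, not the original implementation):

    import torch

    def binary_noise_char_input(batch_clean, valid_noise, noise_ratio):
        # Bernoulli mask: 1 where a position should be replaced by noise.
        mask = (torch.rand(batch_clean.shape, device=batch_clean.device)
                < noise_ratio).long()
        # Uniformly sample replacement token ids from the valid noise ids.
        idx = torch.randint(0, len(valid_noise), batch_clean.shape,
                            device=batch_clean.device)
        return (1 - mask) * batch_clean + mask * valid_noise[idx]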
Example #19
    def __iter__(self):
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                batch_size: int = len(batch)
                dataset = Dataset(examples=self.dataset.examples, fields=[
                    ('time', self.dataset.fields['time']), ('mark', self.dataset.fields['mark']),
                    ('target_time', Field(use_vocab=False)), ('target_mark', Field(use_vocab=False))])
                if self.train:
                    seq_len, time, mark = self.__series_2_bptt(batch, batch_size)
                    yield (Batch.fromvars(
                            dataset, batch_size,
                            time=(ti[:, :, :2], l),
                            mark=m[:, :, 0],
                            target_time=ti[:, :, -1],
                            target_mark=m[:, :, -1]) for ti, l, m in zip(time, seq_len, mark))
                else:
                    batch.target_time = batch.time[0][:, :, 2]
                    batch.time = (batch.time[0][:, :, :2], batch.time[1])
                    batch.target_mark = batch.mark[:, :, 1]
                    batch.mark = batch.mark[:, :, 0]
                    yield batch

            if not self.repeat:
                return
Example #20
    def __iter__(self):
        with self._elastic.context():
            if self._elastic.skipdone():
                return

            self.batch_size = self._elastic._sync_local_bsz()

            text = self.dataset[0].text
            TEXT = self.dataset.fields['text']
            TEXT.eos_token = None
            text = text + ([TEXT.pad_token] * int(
                math.ceil(len(text) / self.batch_size) * self.batch_size -
                len(text)))
            data = TEXT.numericalize([text], device=self.device)
            data = data.view(self.batch_size, -1).t().contiguous()
            dataset = Dataset(examples=self.dataset.examples,
                              fields=[('text', TEXT), ('target', TEXT)])
            end = data.size(0)  # current length of dataset

            # Change in current batch size changes the dimensions of dataset
            # which changes the starting position in the reshaped dataset. The
            # local batch size is also a function of number of replicas, so
            # when we rescale we need to recalculate where to start the
            # iterations from for the new batch size.
            self._elastic.current_index = \
                self._recompute_start(self._elastic.current_index,
                                      self._elastic.end_index, end)
            self._elastic.end_index = end

            # Every replica reads data strided by bptt_len
            start = self._elastic.current_index + (self.bptt_len * self.rank)
            step = self.bptt_len * self.num_replicas

            # The starting index of the highest rank
            highest_start = self._elastic.current_index + \
                (self.bptt_len * (self.num_replicas - 1))

            # Number of steps we will take on the highest rank. We limit
            # iterations on all replicas by this number to prevent asymmetric
            # reduce ops which would result in a deadlock.
            min_steps_in_epoch = max(math.ceil((end - highest_start) / step),
                                     0)  # noqa: E501
            self.iterations = 0
            while True:
                for i in range(start, end, step):
                    self.iterations += 1
                    # Make sure that _elastic.profile is called equal number of
                    # times on all replicas
                    if self.iterations > min_steps_in_epoch:
                        break
                    with self._elastic.profile(self.training and i > 0):
                        seq_len = min(self.bptt_len, data.size(0) - i - 1)
                        assert seq_len > 0
                        batch_text = data[i:i + seq_len]
                        batch_target = data[i + 1:i + 1 + seq_len]
                        if TEXT.batch_first:
                            batch_text = batch_text.t().contiguous()
                            batch_target = batch_target.t().contiguous()
                        yield Batch.fromvars(dataset,
                                             self.batch_size,
                                             text=batch_text,
                                             target=batch_target)
                        self._elastic.current_index += step

                if not self.repeat:
                    break