예제 #1
0
파일: iterator.py 프로젝트: cesarali/Tyche
    def __iter__(self):
        """Yield, per minibatch, a lazy generator of time/text window batches.

        Every pass starts with ``init_epoch()``. Each minibatch from
        ``self.batches`` is optionally re-ordered, wrapped in a ``Batch`` and
        handed to ``__series_2_bptt``, whose result is unpacked into three
        parallel sequences (lengths, text, time). One generator expression
        producing a ``Batch.fromvars`` per zipped triple is yielded for each
        minibatch. Iteration stops after one pass unless ``self.repeat`` is
        truthy.
        """
        while True:
            self.init_epoch()
            for position, examples in enumerate(self.batches):
                if self.sort_within_batch:
                    # `rnn.pack_padded_sequence` wants decreasing order: an
                    # already sorted minibatch is simply flipped, otherwise it
                    # is sorted in descending order of the sort key.
                    if not self.sort:
                        examples.sort(key=self.sort_key, reverse=True)
                    else:
                        examples.reverse()
                whole_batch = Batch(examples, self.dataset, self.device)
                # Minibatches whose index is below the number already consumed
                # in this epoch are skipped (presumably to resume mid-epoch).
                if position < self._iterations_this_epoch:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                # Open question kept from the original author: many batches,
                # or one long batch with many windows?
                seq_len, text, time = self.__series_2_bptt(whole_batch)
                dataset = Dataset(
                    examples=self.dataset.examples,
                    fields=[
                        ('time', self.dataset.fields['time']),
                        ('text', self.dataset.fields['text']),
                        ('target_time', Field(use_vocab=False)),
                        ('target_text', Field(use_vocab=False)),
                    ])
                windows = zip(time, text, seq_len)
                yield (
                    Batch.fromvars(
                        dataset,
                        self.batch_size,
                        time=(time_win[:, :, :2], win_len),
                        text=text_win[:, :, 0],
                        target_time=time_win[:, :, -1],
                        target_text=text_win[:, :, 1])
                    for time_win, text_win, win_len in windows)

            if not self.repeat:
                return
예제 #2
0
파일: iterator.py 프로젝트: cesarali/Tyche
    def __iter__(self):
        """Yield, per minibatch, a lazy generator of time-window batches.

        Every pass starts with ``init_epoch()``. Each minibatch from
        ``self.batches`` is optionally re-ordered, wrapped in a ``Batch`` and
        handed to ``__series_2_bptt``, whose result is unpacked into two
        parallel sequences (lengths, time). One generator expression producing
        a ``Batch.fromvars`` per zipped pair is yielded for each minibatch.

        Iteration stops after one full pass over ``self.batches`` unless
        ``self.repeat`` is truthy, in which case a new pass is started.
        """
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                batch = Batch(minibatch, self.dataset, self.device)
                # Skip minibatches whose index is below the number already
                # consumed in this epoch.
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1

                seq_len, time = self.__series_2_bptt(batch)
                dataset = Dataset(examples=self.dataset.examples, fields=[
                    ('time', self.dataset.fields['time'])])

                yield (Batch.fromvars(
                        dataset, self.batch_size,
                        time=(ti, l)) for ti, l in zip(time, seq_len))

            # The repeat check belongs after the whole pass: when it sat inside
            # the `for` loop a non-repeating iterator stopped after its first
            # minibatch instead of after the epoch.
            if not self.repeat:
                return
예제 #3
0
    def __iter__(self):
        """Yield ``src``/``tgt`` batches cut from the first example's ``src`` tokens.

        The token list is right-padded with the field's pad token to a multiple
        of ``self.batch_size``, numericalized, and reshaped with
        ``view(batch_size, -1).t()``. The result is then walked in steps of
        ``self.bptt_len``; ``tgt`` is the same slice as ``src`` shifted by one
        position along the first axis. Both slices are transposed when the
        field is ``batch_first``. Stops after one pass unless ``self.repeat``.

        Side effect: sets ``eos_token`` of the ``'src'`` field to ``None``.
        """
        tokens = self.dataset[0].src
        TEXT = self.dataset.fields['src']
        TEXT.eos_token = None

        # Pad so the token count divides evenly into `batch_size` rows.
        padded_len = math.ceil(len(tokens) / self.batch_size) * self.batch_size
        n_pad = int(padded_len - len(tokens))
        tokens = tokens + [TEXT.pad_token] * n_pad

        data = TEXT.numericalize([tokens], device=self.device)
        data = data.view(self.batch_size, -1).t().contiguous()
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('src', TEXT), ('tgt', TEXT)])

        while True:
            for offset in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                # The final window may be shorter; one position is always
                # held back so the shifted target slice stays in range.
                span = min(self.bptt_len, len(data) - offset - 1)
                inputs = data[offset:offset + span]
                targets = data[offset + 1:offset + 1 + span]
                if TEXT.batch_first:
                    inputs = inputs.t().contiguous()
                    targets = targets.t().contiguous()
                yield Batch.fromvars(dataset, self.batch_size,
                                     src=inputs, tgt=targets)
            if not self.repeat:
                return
예제 #4
0
 def __text_bptt_2_minibatches(self, batch, batch_size, dataset):
     """Yield a single lazy generator of ``Batch`` objects derived from ``batch``.

     ``__text_series_2_bptt`` returns four parallel sequences (time, time
     lengths, text, text lengths). They are zipped, and every tuple becomes
     one ``Batch.fromvars`` call once the yielded generator is consumed.
     """
     time, time_len, text, text_len = self.__text_series_2_bptt(batch, batch_size)
     pieces = zip(time, time_len, text, text_len)
     window_batches = (
         Batch.fromvars(
             dataset,
             batch_size,
             # First two entries of the last axis are inputs, the final
             # entry of that axis is the target.
             time=(time_win[:, :, :2], time_win_len),
             text=(text_win, text_win_len),
             target_time=time_win[:, :, -1])
         for time_win, time_win_len, text_win, text_win_len in pieces)
     yield window_batches
예제 #5
0
 def __bow_bptt_2_minibatches(self, batch, batch_size, dataset):
     """Yield a single lazy generator of ``Batch`` objects derived from ``batch``.

     ``__bow_series_2_bptt`` returns three parallel sequences (time, time
     lengths, bow). They are zipped, and every triple becomes one
     ``Batch.fromvars`` call once the yielded generator is consumed.
     """
     time, time_len, bow = self.__bow_series_2_bptt(batch, batch_size)
     pieces = zip(time, time_len, bow)
     window_batches = (
         Batch.fromvars(
             dataset,
             batch_size,
             # Inputs come from the front of the last axis, targets from its
             # final entry.
             time=(time_win[:, :, :2], win_len),
             bow=bow_win[:, :, 0],
             target_time=time_win[:, :, -1],
             target_bow=bow_win[:, :, -1])
         for time_win, win_len, bow_win in pieces)
     yield window_batches
예제 #6
0
 def get_single_rnd_batch(self):
     """Return one random batch from the wrapped iterator after ``__hole_batch``.

     The batch is drawn from ``self.noised_win_iter``; ``__hole_batch`` is
     applied with the range given by ``__compute_middle_range``, and its two
     results are packed as the ``noised`` and ``clean`` fields of a new
     ``Batch`` that reuses the source batch's dataset and batch size.
     """
     source_batch = self.noised_win_iter.get_single_rnd_batch()
     middle_start, middle_end = self.__compute_middle_range()
     holed = self.__hole_batch(source_batch, middle_start, middle_end)
     holed_noised, clean_middle = holed
     return Batch.fromvars(
         source_batch.dataset,
         source_batch.batch_size,
         noised=holed_noised,
         clean=clean_middle,
     )
예제 #7
0
 def __iter__(self):
     """Yield a ``noised``/``clean`` batch for every batch of the wrapped iterator.

     The middle range is computed once up front; each batch coming out of
     ``self.noised_win_iter`` is passed through ``__hole_batch`` with that
     range and re-packed with the source batch's dataset and batch size.
     """
     middle_start, middle_end = self.__compute_middle_range()
     for source_batch in self.noised_win_iter:
         holed = self.__hole_batch(source_batch, middle_start, middle_end)
         holed_noised, clean_middle = holed
         yield Batch.fromvars(
             source_batch.dataset,
             source_batch.batch_size,
             noised=holed_noised,
             clean=clean_middle,
         )
예제 #8
0
 def get_single_rnd_batch(self):
     """Generate and return one randomly chosen ``noised``/``clean`` batch.

     The text data is regenerated through ``__generate_text_data``, a batch
     index is drawn with ``rnd.randint`` and ``__generate_batch`` produces the
     batch bounds and tensors. The resulting ``Batch`` has size
     ``end - start``.
     """
     text_data, valid_noise, dataset = self.__generate_text_data()
     # NOTE(review): if `rnd` is the stdlib `random` module, `randint` includes
     # its upper bound, so an index equal to `len(self)` could be drawn —
     # confirm which `rnd` this is and whether that index is valid.
     batch_index = rnd.randint(0, len(self))
     generated = self.__generate_batch(
         batch_index, len(self.dataset[0].text), text_data, valid_noise)
     start, end, batch_noised, batch_clean = generated
     return Batch.fromvars(
         dataset,
         end - start,
         noised=batch_noised,
         clean=batch_clean,
     )
예제 #9
0
 def __iter__(self):
     """Yield one ``noised``/``clean`` batch per batch index in ``range(len(self))``.

     The text data is generated once through ``__generate_text_data``; every
     index is then handed to ``__generate_batch``, whose bounds give the batch
     size (``end - start``) of the yielded ``Batch``.
     """
     text_data, valid_noise, dataset = self.__generate_text_data()
     for batch_index in range(len(self)):
         generated = self.__generate_batch(
             batch_index, len(self.dataset[0].text), text_data, valid_noise)
         start, end, batch_noised, batch_clean = generated
         yield Batch.fromvars(
             dataset,
             end - start,
             noised=batch_noised,
             clean=batch_clean,
         )
예제 #10
0
    def __iter__(self):
        """Yield BPTT batches with word ids, per-word character ids and targets.

        The first example's ``text`` is right-padded with the text field's pad
        token to a multiple of ``self.batch_size`` and numericalized twice:
        once as a whole with the ``'text'`` field and once token by token with
        the ``'chars'`` field. Character tensors are padded along their first
        axis with the ``'<pad>'`` index of the chars vocabulary to the longest
        token, stacked, and reshaped so that their first axis lines up with
        the first axis of the word tensor.

        Each yielded ``Batch`` carries ``text``, ``chars`` and ``target``,
        where ``target`` is ``text`` shifted by one position along the first
        axis. Iteration stops after one pass unless ``self.repeat`` is truthy.

        Side effect: sets ``eos_token`` of the ``'text'`` field to ``None``.

        Raises:
            ValueError: from ``max`` if the padded text is empty.
        """
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        CHARS = self.dataset.fields['chars']
        TEXT.eos_token = None
        # Pad so the token count divides evenly into `batch_size` rows.
        text = text + ([TEXT.pad_token] * int(
            math.ceil(len(text) / self.batch_size) * self.batch_size -
            len(text)))
        data = TEXT.numericalize([text], device=self.device, train=self.train)

        char_data = [
            CHARS.numericalize([t], device=self.device, train=self.train)
            for t in text
        ]

        max_chars = max(c.shape[0] for c in char_data)

        pad_idx = CHARS.vocab.stoi['<pad>']

        for i, c in enumerate(char_data):
            missing = max_chars - c.shape[0]
            if missing > 0:
                padding = torch.LongTensor(missing, 1).fill_(pad_idx)
                # Put the padding where the character tensor already lives
                # instead of unconditionally on the default GPU, so CPU runs
                # and non-default GPUs work too.
                if c.is_cuda:
                    padding = padding.cuda(c.get_device())
                char_data[i] = torch.cat((c, Variable(padding)), dim=0)

        char_data = torch.stack(char_data)

        # Same row/column layout as `data` below, with characters kept on the
        # trailing axis.
        char_data = char_data.view(self.batch_size, -1,
                                   max_chars).permute(1, 0, 2).contiguous()

        data = data.view(self.batch_size, -1).t().contiguous()

        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('text', TEXT), ('chars', CHARS),
                                  ('target', TEXT)])

        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                # The final window may be shorter; one position is always
                # held back so the shifted target slice stays in range.
                seq_len = min(self.bptt_len, len(data) - i - 1)
                yield Batch.fromvars(dataset,
                                     self.batch_size,
                                     train=self.train,
                                     text=data[i:i + seq_len],
                                     chars=char_data[i:i + seq_len],
                                     target=data[i + 1:i + 1 + seq_len])
            if not self.repeat:
                return
예제 #11
0
    def __iter__(self):
        """Yield ``text``/``target`` batches from successive chunks of ``self.dataset.text``.

        Per pass, ``int(numberOfTokens / bsz / bptt_len)`` chunks are pulled
        from the ``self.dataset.text`` iterator. Each chunk is numericalized,
        reshaped with ``view(bsz, -1).t()`` and split into two slices offset
        by one. Stops after one pass unless ``self.repeat`` is truthy.

        Side effect: sets ``eos_token`` of the ``'text'`` field to ``None``.
        """
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None

        while True:
            n_steps = int(
                self.dataset.numberOfTokens / self.bsz / self.bptt_len)
            for _ in range(n_steps):
                chunk = next(self.dataset.text)

                ids = TEXT.numericalize([chunk], device=self.device)
                ids = ids.view(self.bsz, -1).t().contiguous()
                paired = Dataset(examples=self.dataset.examples,
                                 fields=[('text', TEXT), ('target', TEXT)])
                # NOTE(review): after `.t()` the second axis has size `bsz`,
                # so these slices shift along that axis rather than the
                # first one — confirm this is intended.
                inputs = ids[:, :-1]
                targets = ids[:, 1:]
                yield Batch.fromvars(paired,
                                     self.bsz,
                                     train=self.train,
                                     text=inputs,
                                     target=targets)
            if not self.repeat:
                return
예제 #12
0
def corrupt(batch, config, inputs, eos_id=2, pad_id=1, noise_level=0.5):
    """Return a new batch whose hypothesis tokens are randomly replaced.

    Every position of ``batch.hypothesis`` is selected for replacement with
    probability ``noise_level``; positions holding ``eos_id`` or ``pad_id``
    are always left untouched. ``batch.label`` is carried over unchanged.

    Args:
        batch: object exposing ``hypothesis`` (integer tensor; the unigram
            modes index it as 2-D), ``label``, ``dataset`` and ``batch_size``.
        config: object exposing ``noise_type`` (``"uniform"``, ``"unigram"``
            or ``"uniroot"``) and, for the two unigram modes,
            ``freq_list_tensor`` (1-D sampling weights).
        inputs: field whose ``vocab`` length bounds the uniform sampling.
        eos_id: token id that is never replaced.
        pad_id: token id that is never replaced.
        noise_level: per-position replacement probability.

    Returns:
        A ``Batch`` built with ``Batch.fromvars`` carrying the corrupted
        ``hypothesis`` and the original ``label``.

    Raises:
        ValueError: if ``config.noise_type`` is not a supported mode.
    """
    noise_type = config.noise_type
    if noise_type not in ("uniform", "unigram", "uniroot"):
        # Fail early with a clear message instead of an UnboundLocalError
        # further down.
        raise ValueError(
            "unsupported noise_type: {!r}".format(noise_type))

    hypothesis = batch.hypothesis
    # Logical OR keeps the mask in {0, 1} even if eos_id == pad_id.
    is_eos_or_pad_mask = (
        (hypothesis == eos_id) | (hypothesis == pad_id)).long()
    bernoulli_mask = (torch.rand(hypothesis.size(),
                                 device=hypothesis.device) <
                      noise_level).long()

    if noise_type == "uniform":
        # Ids are drawn from [3, len(vocab)); presumably indices 0-2 are
        # reserved special tokens — confirm against the vocabulary.
        replacement_indices = torch.randint(
            low=3,
            high=len(inputs.vocab),
            size=tuple(hypothesis.size()),
            dtype=hypothesis.dtype,
            device=hypothesis.device).long()
    else:
        # One multinomial draw per row; the +3 offset assumes
        # `freq_list_tensor` excludes the first three vocabulary entries —
        # TODO confirm.
        replacement_indices = torch.multinomial(
            config.freq_list_tensor.unsqueeze(dim=0).expand(
                (hypothesis.size()[0], -1)),
            hypothesis.size()[1],
            replacement=True) + 3
        # Follow the hypothesis' device rather than a hard-coded CUDA device
        # so CPU batches work as well.
        replacement_indices = replacement_indices.to(hypothesis.device)

    # Replace where the Bernoulli mask is set ...
    corrupted_hypothesis = (
        1 - bernoulli_mask
    ) * hypothesis + bernoulli_mask * replacement_indices
    # ... then restore the EOS / PAD positions.
    corrupted_hypothesis = is_eos_or_pad_mask * hypothesis + (
        1 - is_eos_or_pad_mask) * corrupted_hypothesis

    corrupted_batch = Batch.fromvars(batch.dataset, batch.batch_size)
    corrupted_batch.label = batch.label
    corrupted_batch.hypothesis = corrupted_hypothesis
    return corrupted_batch
예제 #13
0
    def __iter__(self):
        """Yield time/mark batches: windowed generators in training, whole batches otherwise.

        Every pass starts with ``init_epoch()``. Each minibatch from
        ``self.batches`` is optionally re-ordered and wrapped in a ``Batch``.
        When ``self.train`` is truthy the batch goes through
        ``__series_2_bptt`` and a lazy generator producing one
        ``Batch.fromvars`` per zipped (time, length, mark) triple is yielded.
        Otherwise the batch itself is yielded after its ``time`` and ``mark``
        attributes are split in place into input and target parts.
        Iteration stops after one pass unless ``self.repeat`` is truthy.
        """
        while True:
            self.init_epoch()
            for position, examples in enumerate(self.batches):
                if self.sort_within_batch:
                    # `rnn.pack_padded_sequence` wants decreasing order: an
                    # already sorted minibatch is simply flipped, otherwise it
                    # is sorted in descending order of the sort key.
                    if not self.sort:
                        examples.sort(key=self.sort_key, reverse=True)
                    else:
                        examples.reverse()
                batch = Batch(examples, self.dataset, self.device)
                # Minibatches whose index is below the number already consumed
                # in this epoch are skipped.
                if position < self._iterations_this_epoch:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                batch_size: int = len(batch)
                dataset = Dataset(
                    examples=self.dataset.examples,
                    fields=[
                        ('time', self.dataset.fields['time']),
                        ('mark', self.dataset.fields['mark']),
                        ('target_time', Field(use_vocab=False)),
                        ('target_mark', Field(use_vocab=False)),
                    ])
                if not self.train:
                    # Targets are read before the source attributes are
                    # overwritten with their input-only slices.
                    full_time = batch.time
                    batch.target_time = full_time[0][:, :, 2]
                    batch.time = (full_time[0][:, :, :2], full_time[1])
                    full_mark = batch.mark
                    batch.target_mark = full_mark[:, :, 1]
                    batch.mark = full_mark[:, :, 0]
                    yield batch
                else:
                    seq_len, time, mark = self.__series_2_bptt(batch, batch_size)
                    windows = zip(time, seq_len, mark)
                    yield (
                        Batch.fromvars(
                            dataset,
                            batch_size,
                            time=(time_win[:, :, :2], win_len),
                            mark=mark_win[:, :, 0],
                            target_time=time_win[:, :, -1],
                            target_mark=mark_win[:, :, -1])
                        for time_win, win_len, mark_win in windows)

            if not self.repeat:
                return
예제 #14
0
 def __iter__(self):
     """Yield ``noised``/``clean`` batches of fixed-width text windows.

     Examples are taken in dataset order, or in a ``torch.randperm`` order
     when ``self.shuffle`` is truthy. Each batch stacks the numericalized
     ``text`` of up to ``self.batch_size`` examples into a
     ``(rows, self.window_size)`` long tensor. When ``self.noise_ratio`` is
     positive the noised tensor comes from ``binary_noise_char_input``;
     otherwise ``noised`` is the very same tensor object as ``clean``.
     """
     text_field = self.dataset.fields['text']
     # Candidate replacement ids: every vocabulary index from 2 upwards.
     valid_noise = torch.tensor(
         list(range(2, len(text_field.vocab))),
         dtype=torch.long,
         device=self.device)
     dataset = Dataset(
         examples=self.dataset.examples,
         fields=[('noised', text_field), ('clean', text_field)])

     ordered_examples = self.dataset.examples
     if self.shuffle:
         ordered_examples = [
             self.dataset.examples[i]
             for i in torch.randperm(len(self.dataset))
         ]

     for batch_index in range(len(self)):
         offset = batch_index * self.batch_size
         # The final batch may hold fewer rows than `batch_size`.
         rows = min(self.batch_size, len(self.dataset) - offset)
         batch_clean = torch.zeros(
             [rows, self.window_size],
             dtype=torch.long,
             device=self.device)
         for row in range(rows):
             numericalized = text_field.numericalize(
                 ordered_examples[offset + row].text)
             batch_clean[row] = numericalized.squeeze(1)

         batch_noised = (
             binary_noise_char_input(batch_clean, valid_noise,
                                     self.noise_ratio)
             if self.noise_ratio > 0 else batch_clean)
         yield Batch.fromvars(
             dataset,
             rows,
             noised=batch_noised,
             clean=batch_clean,
         )
예제 #15
0
    def __iter__(self):
        """Yield this replica's ``text``/``target`` BPTT batches inside the elastic context.

        The whole body runs under ``self._elastic.context()`` and returns
        immediately when ``self._elastic.skipdone()`` is truthy. Otherwise the
        local batch size is refreshed from ``self._elastic._sync_local_bsz()``,
        the first example's ``text`` is padded to a multiple of that batch
        size, numericalized and reshaped with ``view(batch_size, -1).t()``.

        Windows of at most ``self.bptt_len`` rows are then read starting at
        ``current_index + bptt_len * rank`` and advancing by
        ``bptt_len * num_replicas``; ``target`` is the ``text`` slice shifted
        by one position along the first axis. The number of yielded batches is
        capped at the number of steps available to the highest rank.

        Side effects: sets ``eos_token`` of the ``'text'`` field to ``None``,
        overwrites ``self.batch_size``, resets ``self.iterations`` to 0, and
        updates ``self._elastic.current_index`` / ``self._elastic.end_index``.
        """
        with self._elastic.context():
            if self._elastic.skipdone():
                return

            self.batch_size = self._elastic._sync_local_bsz()

            text = self.dataset[0].text
            TEXT = self.dataset.fields['text']
            TEXT.eos_token = None
            # Pad so the token count divides evenly into `batch_size` rows.
            text = text + ([TEXT.pad_token] * int(
                math.ceil(len(text) / self.batch_size) * self.batch_size -
                len(text)))
            data = TEXT.numericalize([text], device=self.device)
            data = data.view(self.batch_size, -1).t().contiguous()
            dataset = Dataset(examples=self.dataset.examples,
                              fields=[('text', TEXT), ('target', TEXT)])
            end = data.size(0)  # current length of dataset

            # Change in current batch size changes the dimensions of dataset
            # which changes the starting position in the reshaped dataset. The
            # local batch size is also a function of number of replicas, so
            # when we rescale we need to recalculate where to start the
            # iterations from for the new batch size.
            self._elastic.current_index = \
                self._recompute_start(self._elastic.current_index,
                                      self._elastic.end_index, end)
            self._elastic.end_index = end

            # Every replica reads data strided by bptt_len
            start = self._elastic.current_index + (self.bptt_len * self.rank)
            step = self.bptt_len * self.num_replicas

            # The starting index of the highest rank
            highest_start = self._elastic.current_index + \
                (self.bptt_len * (self.num_replicas - 1))

            # Number of steps we will take on the highest rank. We limit
            # iterations on all replicas by this number to prevent asymmetric
            # reduce ops which would result in a deadlock.
            min_steps_in_epoch = max(math.ceil((end - highest_start) / step),
                                     0)  # noqa: E501
            self.iterations = 0
            # NOTE(review): `self.iterations` is not reset between passes of
            # this `while` loop, so with `self.repeat` true later passes
            # appear to hit the `break` below on their first iteration —
            # confirm the intended behaviour for repeating iterators.
            while True:
                for i in range(start, end, step):
                    self.iterations += 1
                    # Make sure that _elastic.profile is called equal number of
                    # times on all replicas
                    if self.iterations > min_steps_in_epoch:
                        break
                    with self._elastic.profile(self.training and i > 0):
                        # One row is held back so the shifted target slice
                        # stays in range.
                        seq_len = min(self.bptt_len, data.size(0) - i - 1)
                        assert seq_len > 0
                        batch_text = data[i:i + seq_len]
                        batch_target = data[i + 1:i + 1 + seq_len]
                        if TEXT.batch_first:
                            batch_text = batch_text.t().contiguous()
                            batch_target = batch_target.t().contiguous()
                        yield Batch.fromvars(dataset,
                                             self.batch_size,
                                             text=batch_text,
                                             target=batch_target)
                        # Runs only once the consumer asks for the next
                        # batch, i.e. after the yielded batch was handed out.
                        self._elastic.current_index += step

                if not self.repeat:
                    break