Code Example #1
File: data.py  Project: justinchiu/hmmlm-jax
    def __iter__(self):
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None
        # Pad the flat token stream so it splits evenly into batch_size columns.
        text = text + ([TEXT.pad_token] * int(
            math.ceil(len(text) / self.batch_size) * self.batch_size -
            len(text)))
        data = TEXT.numericalize([text], device=self.device)
        data = data.view(self.batch_size, -1).t().contiguous()
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('text', TEXT), ('target', TEXT)])
        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                seq_len = min(self.bptt_len, len(data) - i - 1)
                batch_text = data[i:i + seq_len]
                # The target is the same window shifted forward by one token.
                batch_target = data[i + 1:i + 1 + seq_len]
                if TEXT.batch_first:
                    batch_text = batch_text.t().contiguous()
                    batch_target = batch_target.t().contiguous()
                yield Batch.fromvars(dataset,
                                     self.batch_size,
                                     text=batch_text,
                                     target=batch_target)
            if not self.repeat:
                return
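This __iter__ follows the pattern of torchtext's legacy BPTT language-model iterator: the dataset is expected to hold a single example whose text attribute is the entire corpus as one token list. Below is a minimal sketch of preparing such a dataset with the legacy API (torchtext.data up to 0.8, torchtext.legacy.data in 0.9 to 0.11); the toy corpus and variable names are purely illustrative.

# Hedged sketch of the kind of dataset this iterator expects: a single
# Example whose .text field holds the whole corpus as one token list.
# Assumes the legacy torchtext API (torchtext.data in <=0.8,
# torchtext.legacy.data in 0.9-0.11).
from torchtext.data import Field, Example, Dataset

TEXT = Field(sequential=True, tokenize=str.split)
fields = [('text', TEXT)]

corpus = "the quick brown fox jumps over the lazy dog"
examples = [Example.fromlist([corpus], fields)]
lm_dataset = Dataset(examples, fields)

TEXT.build_vocab(lm_dataset)   # required before TEXT.numericalize() is called
print(lm_dataset[0].text[:5])  # ['the', 'quick', 'brown', 'fox', 'jumps']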
Code Example #2
    def create_boosted_dataset(self, new_training_data: list) -> DataIterator:
        """
        Create a new Dataset and DataIterator with the new hard training instances

        Arguments:
            new_training_data: list of new training data to create a new `Dataset` object and a new `DataIterator`

        Returns:
            A new DataIterator
        """

        # create a Dataset over the boosted examples, reusing the existing SRC/TRG fields
        dataset = Dataset(new_training_data,
                          fields=[("src", self.params.SRC),
                                  ("trg", self.params.TRG)])

        data_iterator = DataIterator(dataset,
                                     batch_size=self.params.train_batch_size,
                                     device=self.params.device,
                                     repeat=False,
                                     sort_key=lambda x: (len(x.src), len(x.trg)),
                                     batch_size_fn=batch_size_fn,
                                     train=True,
                                     sort_within_batch=True,
                                     shuffle=True)
        return data_iterator
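batch_size_fn is passed in above but not defined in this snippet; torchtext's bucketing iterators call it as batch_size_fn(new_example, count, size_so_far) to measure batch size in padded tokens rather than in sentences. Here is a hedged sketch of a token-count version; the global trackers mirror the widely copied annotated-transformer recipe and are an assumption, not this project's code.

# Hedged sketch of a token-count batch_size_fn; not taken from this project.
global_max_src, global_max_trg = 0, 0

def batch_size_fn(new, count, sofar):
    # torchtext passes the newest example, the running example count, and the
    # size accumulated so far; the return value decides when a batch is full.
    global global_max_src, global_max_trg
    if count == 1:  # first example of a fresh batch, reset the trackers
        global_max_src, global_max_trg = 0, 0
    global_max_src = max(global_max_src, len(new.src))
    global_max_trg = max(global_max_trg, len(new.trg) + 2)  # +2 for <sos>/<eos>
    # Effective size = number of padded tokens on the larger side.
    return count * max(global_max_src, global_max_trg)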
Code Example #3
File: data.py  Project: justinchiu/hmmlm-jax
    def __iter__(self):
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None
        # ^ not sure what this is for

        dataset = Dataset(
            examples=self.dataset.examples,
            fields=[
                ('text', TEXT),  #('target', TEXT)
            ],
        )

        # double check this is deterministic during eval
        raw_data = self.data()
        # concatenate everything
        text = [word for ex in raw_data for word in ex.text]
        text = text + ([TEXT.pad_token] * int(
            math.ceil(len(text) / self.batch_size) * self.batch_size -
            len(text)))
        data = TEXT.numericalize([text], device=self.device)
        data = data.view(self.batch_size, -1).t()
        eos_token = TEXT.vocab["<eos>"]
        data = th.cat([
            th.empty(
                self.batch_size,
                device=data.device,
                dtype=data.dtype,
            )[None].fill_(eos_token),
            data,
        ], 0)

        while True:
            for i in range(1, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                """
                seq_len = min(self.bptt_len, len(data) - i - 1)
                batch_text = data[i:i + seq_len]
                batch_target = data[i + 1:i + 1 + seq_len]
                if TEXT.batch_first:
                    batch_text = batch_text.t().contiguous()
                    batch_target = batch_target.t().contiguous()
                yield Batch.fromvars(
                    dataset, self.batch_size,
                    text=batch_text,
                    target=batch_target)
                """
                seq_len = min(self.bptt_len, len(data) - i)
                batch_text = data[i:i + seq_len]
                batch_textp1 = data[i - 1:i + seq_len]
                if TEXT.batch_first:
                    batch_text = batch_text.t().contiguous()
                    batch_textp1 = batch_textp1.t().contiguous()
                yield Batch.fromvars(
                    dataset,
                    self.batch_size,
                    text=batch_text,
                    textp1=batch_textp1,
                )
            if not self.repeat:
                return
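Compared with Example #1, this variant prepends an <eos> row to the reshaped corpus and yields the window together with the previous token (textp1) instead of a one-step-shifted target. A tiny hedged illustration of the resulting alignment, using made-up token ids and batch_size = 1:

import torch as th

# 9 stands in for <eos>; 1..4 are ordinary word ids; shape is (time, batch).
data = th.tensor([[9], [1], [2], [3], [4]])
i, bptt_len = 1, 3
seq_len = min(bptt_len, len(data) - i)
text = data[i:i + seq_len]        # rows 1..3 -> tokens [1, 2, 3]
textp1 = data[i - 1:i + seq_len]  # rows 0..3 -> tokens [9, 1, 2, 3]
print(text.squeeze(1).tolist(), textp1.squeeze(1).tolist())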
Code Example #4
def _bunch_to_ds(bunch: Bunch, text: Field, label: LabelField) -> Dataset:
    r""" Converts the \p bunch to a classification dataset """
    fields = [('text', text), ('label', label)]
    examples = [
        Example.fromlist(x, fields)
        for x in zip(bunch[DATA_COL], bunch[LABEL_COL])
    ]
    return Dataset(examples, fields)
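A hedged usage sketch follows. DATA_COL and LABEL_COL are constants defined in the project module and not shown here; for scikit-learn's fetch_20newsgroups Bunch they are presumably "data" and "target". Everything below except _bunch_to_ds itself is an assumption.

# Hedged usage sketch; field settings and Bunch keys are assumptions.
from sklearn.datasets import fetch_20newsgroups
from torchtext.data import Field, LabelField

text = Field(sequential=True, lower=True, tokenize=str.split)
label = LabelField()

bunch = fetch_20newsgroups(subset="train")
ds = _bunch_to_ds(bunch, text, label)

# Vocabularies must be built before the dataset can be batched.
text.build_vocab(ds)
label.build_vocab(ds)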
Code Example #5
    def __iter__(self):
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        CHARS = self.dataset.fields['chars']
        TEXT.eos_token = None
        text = text + ([TEXT.pad_token] * int(
            math.ceil(len(text) / self.batch_size) * self.batch_size -
            len(text)))
        data = TEXT.numericalize([text], device=self.device, train=self.train)

        # Numericalize each token's characters separately; each entry is
        # treated below as a (num_chars, 1) column tensor.
        char_data = [
            CHARS.numericalize([t], device=self.device, train=self.train)
            for t in text
        ]

        max_chars = max([c.shape[0] for c in char_data])

        pad_idx = CHARS.vocab.stoi['<pad>']

        for i, c in enumerate(char_data):
            if c.shape[0] < max_chars:
                padding = Variable(
                    torch.LongTensor(max_chars - c.shape[0],
                                     1).fill_(pad_idx).cuda())
                char_data[i] = torch.cat((c, padding), dim=0)

        char_data = torch.stack(char_data)

        print(char_data.shape)

        char_data = char_data.view(self.batch_size, -1,
                                   max_chars).permute(1, 0, 2).contiguous()

        print(char_data.shape)

        data = data.view(self.batch_size, -1).t().contiguous()

        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('text', TEXT), ('chars', CHARS),
                                  ('target', TEXT)])

        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                seq_len = min(self.bptt_len, len(data) - i - 1)
                yield Batch.fromvars(dataset,
                                     self.batch_size,
                                     train=self.train,
                                     text=data[i:i + seq_len],
                                     chars=char_data[i:i + seq_len],
                                     target=data[i + 1:i + 1 + seq_len])
            if not self.repeat:
                return
Code Example #6
    def __generate_text_data(self):
        text_field = self.dataset.fields['text']
        raw_text = self.dataset[0].text
        text_data = text_field.numericalize([raw_text],
                                            device=self.device).squeeze(0)
        # Ids 0 and 1 are typically the <unk>/<pad> specials, so noise tokens
        # are sampled from id 2 upward.
        valid_noise = torch.tensor(
            list(range(2, len(self.dataset.fields['text'].vocab))),
            dtype=torch.long,
            device=self.device)
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('noised', text_field),
                                  ('clean', text_field)])
        return text_data, valid_noise, dataset
Code Example #7
    def load(cls, args: Namespace):
        r""" Load the serialized newsgroups dataset """
        path = cls._pickle_filename(args)

        with open(str(path), "rb") as f_in:
            flds = pk.load(f_in)
        newsgroup = cls(text=flds["text"], label=flds["label"])

        for key in vars(newsgroup).keys():
            if newsgroup.__getattribute__(key) is not None:
                continue
            newsgroup.__setattr__(key,
                                  Dataset(flds[key], newsgroup.build_fields()))
        return newsgroup
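load implies a simple serialization format: one pickle mapping "text" and "label" to the Field objects and every other attribute name to a raw example list, from which a Dataset is rebuilt. Below is a hedged sketch of what the matching writer could look like; the method name, attribute layout, and the assumption that the fields pickle cleanly are all guesses, not the project's code.

import pickle as pk
from torchtext.data import Dataset  # legacy API, as in the snippet above

def dump(self, args):
    r""" Hedged sketch only: store the fields plus, for every Dataset
    attribute, its raw example list so that load() can rebuild it via
    Dataset(flds[key], self.build_fields()). """
    flds = {"text": self.text, "label": self.label}
    for key, value in vars(self).items():
        if isinstance(value, Dataset):
            flds[key] = value.examples
    with open(str(self._pickle_filename(args)), "wb") as f_out:
        pk.dump(flds, f_out)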
Code Example #8
def get_data_set(path,
                 tokenizer,
                 text_field,
                 label_field,
                 max_len,
                 max_num=0,
                 include_neutual=True):
    fields = [('text', text_field), ('label', label_field)]
    examples = get_examples(path,
                            tokenizer,
                            max_len=max_len,
                            max_num=max_num,
                            include_neutual=include_neutual)
    data_set = Dataset(examples=examples, fields=fields)
    return data_set
Code Example #9
    def __iter__(self):
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None

        while True:
            for _ in range(
                    int(self.dataset.numberOfTokens / self.bsz /
                        self.bptt_len)):
                text = self.dataset.text.__next__()

                data = TEXT.numericalize([text], device=self.device)
                data = data.view(self.bsz, -1).t().contiguous()
                dataset = Dataset(examples=self.dataset.examples,
                                  fields=[('text', TEXT), ('target', TEXT)])
                yield Batch.fromvars(dataset,
                                     self.bsz,
                                     train=self.train,
                                     text=data[:, :-1],
                                     target=data[:, 1:])
            if not self.repeat:
                return
Code Example #10
    def __iter__(self):
        # Ids 0 and 1 are typically the <unk>/<pad> specials, so noise tokens
        # are sampled from id 2 upward.
        valid_noise = torch.tensor(
            list(range(2, len(self.dataset.fields['text'].vocab))),
            dtype=torch.long,
            device=self.device)
        text_field = self.dataset.fields['text']
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('noised', text_field),
                                  ('clean', text_field)])
        if self.shuffle:
            shuffled_examples = [
                self.dataset.examples[i]
                for i in torch.randperm(len(self.dataset))
            ]
        else:
            shuffled_examples = self.dataset.examples
        for b in range(len(self)):
            # The last batch may be smaller than batch_size.
            actual_batch_size = min(self.batch_size,
                                    len(self.dataset) - b * self.batch_size)
            batch_clean = torch.zeros([actual_batch_size, self.window_size],
                                      dtype=torch.long,
                                      device=self.device)
            for i in range(actual_batch_size):
                example_index = b * self.batch_size + i
                batch_clean[i] = text_field.numericalize(
                    shuffled_examples[example_index].text).squeeze(1)
            if self.noise_ratio > 0:
                batch_noised = binary_noise_char_input(batch_clean,
                                                       valid_noise,
                                                       self.noise_ratio)
            else:
                batch_noised = batch_clean
            yield Batch.fromvars(dataset,
                                 actual_batch_size,
                                 noised=batch_noised,
                                 clean=batch_clean)
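binary_noise_char_input is defined elsewhere in the project. Below is a hedged sketch of one plausible implementation, corrupting each position independently with probability noise_ratio by a token drawn uniformly from valid_noise; this is a guess at the behaviour, not the project's function.

import torch

def binary_noise_char_input(clean, valid_noise, noise_ratio):
    """Hedged sketch: replace each token of `clean` with a random entry of
    `valid_noise` with probability `noise_ratio`."""
    mask = torch.rand(clean.shape, device=clean.device) < noise_ratio
    random_idx = torch.randint(len(valid_noise), clean.shape,
                               device=clean.device)
    return torch.where(mask, valid_noise[random_idx], clean)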
Code Example #11
    def __iter__(self):
        with self._elastic.context():
            if self._elastic.skipdone():
                return

            self.batch_size = self._elastic._sync_local_bsz()

            text = self.dataset[0].text
            TEXT = self.dataset.fields['text']
            TEXT.eos_token = None
            text = text + ([TEXT.pad_token] * int(
                math.ceil(len(text) / self.batch_size) * self.batch_size -
                len(text)))
            data = TEXT.numericalize([text], device=self.device)
            data = data.view(self.batch_size, -1).t().contiguous()
            dataset = Dataset(examples=self.dataset.examples,
                              fields=[('text', TEXT), ('target', TEXT)])
            end = data.size(0)  # current length of dataset

            # Change in current batch size changes the dimensions of dataset
            # which changes the starting position in the reshaped dataset. The
            # local batch size is also a function of number of replicas, so
            # when we rescale we need to recalculate where to start the
            # iterations from for the new batch size.
            self._elastic.current_index = \
                self._recompute_start(self._elastic.current_index,
                                      self._elastic.end_index, end)
            self._elastic.end_index = end

            # Every replica reads data strided by bptt_len
            start = self._elastic.current_index + (self.bptt_len * self.rank)
            step = self.bptt_len * self.num_replicas

            # The starting index of the highest rank
            highest_start = self._elastic.current_index + \
                (self.bptt_len * (self.num_replicas - 1))

            # Number of steps we will take on the highest rank. We limit
            # iterations on all replicas by this number to prevent asymmetric
            # reduce ops which would result in a deadlock.
            min_steps_in_epoch = max(math.ceil((end - highest_start) / step),
                                     0)  # noqa: E501
            self.iterations = 0
            while True:
                for i in range(start, end, step):
                    self.iterations += 1
                    # Make sure that _elastic.profile is called equal number of
                    # times on all replicas
                    if self.iterations > min_steps_in_epoch:
                        break
                    with self._elastic.profile(self.training and i > 0):
                        seq_len = min(self.bptt_len, data.size(0) - i - 1)
                        assert seq_len > 0
                        batch_text = data[i:i + seq_len]
                        batch_target = data[i + 1:i + 1 + seq_len]
                        if TEXT.batch_first:
                            batch_text = batch_text.t().contiguous()
                            batch_target = batch_target.t().contiguous()
                        yield Batch.fromvars(dataset,
                                             self.batch_size,
                                             text=batch_text,
                                             target=batch_target)
                        self._elastic.current_index += step

                if not self.repeat:
                    break
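_recompute_start belongs to the surrounding elastic-training machinery and is not shown here. A hedged guess at its intent, based only on the comment in the snippet above: map the old position proportionally into the re-batched corpus and snap it onto the bptt stride so roughly the same fraction of the epoch counts as consumed after a rescale.

def _recompute_start(self, old_index, old_end, new_end):
    """Hedged sketch only; the real implementation may differ."""
    if not old_end:  # nothing consumed yet (first epoch after startup)
        return 0
    fraction_done = old_index / old_end
    # Snap to the bptt stride so replicas keep reading aligned windows.
    return int(fraction_done * new_end) // self.bptt_len * self.bptt_len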