コード例 #1
0
ファイル: iterator.py プロジェクト: Impavidity/pbase
 def __iter__(self):
     """Yield ``(positive, negative)`` Batch pairs for a single pass.

     ``self.init_epoch()`` is called once; then every element of
     ``self.pairs`` contributes one pair, built from its ``positive`` and
     ``negative`` attributes.  Both batches use ``self.dataset`` and
     ``self.device``, and a fourth positional argument of ``True``
     (presumably torchtext's ``train`` flag — confirm against the Batch
     version in use).  The generator ends after one pass over the pairs.
     """
     self.init_epoch()
     for pair in self.pairs:
         positive_batch = Batch(pair.positive, self.dataset, self.device, True)
         negative_batch = Batch(pair.negative, self.dataset, self.device, True)
         yield positive_batch, negative_batch
コード例 #2
0
    def create_batch(self, src_examples, tgt_examples, src_lengths, indices,
                     dataset, fields):
        """Create a batch object from previously extracted parallel data.

        The four parallel lists are re-ordered together by ``PairBank.sort``.
        The source side is then converted with ``self.preprocess_side``.
        Each target example first goes through ``self.check_sos_eos`` and is
        then converted the same way.

        Args:
            src_examples (list): list of src sequence tensors
            tgt_examples (list): list of tgt sequence tensors
            src_lengths (list): lengths of each src sequence, as tensors that
                are concatenated with ``torch.cat``
            indices (list): list of indices of example instances in dataset.
                NOTE(review): these are sorted together with the examples,
                but ``batch.indices`` is always set to ``None``.
            dataset (:obj:`onmt.io.TextDataset.TextDataset`): dataset object
            fields (list): list of keys of fields in dataset object

        Returns:
            batch (torchtext.data.batch.Batch): batch with ``src`` set to
            ``(src, src_lengths)``, and with ``tgt``, ``dataset``, ``fields``,
            ``train`` (always ``True``), ``batch_size`` and ``indices``
            populated.
        """
        batch = Batch()
        src_examples, tgt_examples, src_lengths, indices = PairBank.sort(
            src_examples, tgt_examples, src_lengths, indices)
        src = self.preprocess_side(src_examples)
        tgt_examples = [self.check_sos_eos(ex) for ex in tgt_examples]
        tgt = self.preprocess_side(tgt_examples)
        # torch.cat takes any sequence of tensors; a plain list copy suffices.
        src_lengths = torch.cat(list(src_lengths))

        # Batch size is taken from dim 1 of the preprocessed source
        # (presumably a (seq_len, batch) layout — confirm in preprocess_side).
        batch.batch_size = src.size(1)
        batch.dataset = dataset
        batch.fields = fields
        batch.train = True
        batch.src = (src, src_lengths)
        batch.tgt = tgt
        # The sorted `indices` are deliberately not propagated.
        batch.indices = None

        return batch
コード例 #3
0
ファイル: evaluate.py プロジェクト: logoutlzh/pt.seq2seq
def random_eval(dset, seq2seq, N=3, device='cuda'):
    """Log ``N`` randomly chosen examples next to the model's generations.

    Examples are drawn without replacement from ``dset.examples`` and sorted
    by source length, longest first. They are then batched and decoded with
    ``seq2seq.generate``. For each example the source (``>``), reference
    (``=``) and greedy model output (``<``) are logged, followed by a blank
    line.

    Args:
        dset: dataset with ``fields['src']`` / ``fields['trg']`` vocabularies
            and an ``examples`` list.
        seq2seq: model exposing ``eval()`` and ``generate(src, src_lens)``.
            It is switched to eval mode and left there.
        N: number of examples to sample. It must not exceed
            ``len(dset.examples)``, because sampling is without replacement.
        device: device the batch is built on (default ``'cuda'``).
    """
    seq2seq.eval()
    src_vocab = dset.fields['src'].vocab
    trg_vocab = dset.fields['trg'].vocab

    examples = np.random.choice(dset.examples, replace=False, size=N).tolist()
    # Longest source first (presumably required for packed RNN input).
    examples = sorted(examples, key=lambda ex: len(ex.src), reverse=True)
    x = Batch(examples, dset, device)

    # [B, T], [B]
    src, src_lens = x.src
    tgt, tgt_lens = x.trg

    dec_outs, attn_ws = seq2seq.generate(src, src_lens)
    # topk(1)[1] has a trailing size-1 dim ([B, max_len, 1]). Squeeze only
    # that dim: a bare .squeeze() would also drop the batch dim when N == 1
    # (or max_len == 1), and the zip below would then iterate over scalars.
    topi = dec_outs.topk(1)[1].squeeze(-1)  # [B, max_len]

    for src_idx, tgt_idx, out_idx in zip(src, tgt, topi):
        # [1:] skips the first position (presumably <sos>) of src/tgt.
        src_sentence = " ".join(idx2words(src_idx[1:], src_vocab))
        tgt_sentence = " ".join(idx2words(tgt_idx[1:], trg_vocab))
        out_sentence = " ".join(idx2words(out_idx, trg_vocab))

        logger.info("> {}".format(src_sentence))
        logger.info("= {}".format(tgt_sentence))
        logger.info("< {}".format(out_sentence))
        logger.info("")
コード例 #4
0
    def __iter__(self):
        """Yield BPTT language-model batches over the dataset's single text.

        The token list of ``self.dataset[0]`` is right-padded with the field's
        pad token up to a multiple of ``self.batch_size`` and numericalized.
        It is then rearranged with named-tensor ``stack``/``split``/
        ``transpose`` so that ``'seqlen'`` comes before ``'batch'``.

        Each yielded ``Batch`` holds:

        - ``text``: a window of up to ``self.bptt_len`` steps;
        - ``target``: the same window shifted one step ahead.

        Repeats indefinitely unless ``self.repeat`` is false. Side effect:
        sets ``eos_token`` of the dataset's ``'text'`` field to ``None``.
        """
        # `text` aliases the example's own token list. Padding below builds a
        # new list rather than extending in place (`+=`), so starting a new
        # iteration never mutates the dataset's example.
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None

        num_batches = math.ceil(len(text) / self.batch_size)
        pad_amount = int(num_batches * self.batch_size - len(text))
        text = text + [TEXT.pad_token] * pad_amount

        data = TEXT.numericalize([text], device=self.device)
        data = data.stack(('seqlen', 'batch'), 'flat') \
                   .split('flat', ('batch', 'seqlen'), batch=self.batch_size) \
                   .transpose('seqlen', 'batch')

        fields = [('text', TEXT), ('target', TEXT)]
        dataset = Dataset(examples=self.dataset.examples, fields=fields)

        while True:
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                # Leave room for the one-step-shifted target.
                seq_len = min(self.bptt_len, len(data) - i - 1)
                yield Batch.fromvars(dataset,
                                     self.batch_size,
                                     text=data.narrow('seqlen', i, seq_len),
                                     target=data.narrow(
                                         'seqlen', i + 1, seq_len))

            if not self.repeat:
                return
コード例 #5
0
ファイル: iterators.py プロジェクト: yushu-liu/quick-nlp
    def __iter__(self) -> Iterator[Batch]:
        """Same iterator almost as bucket iterator.

        Each minibatch is turned into ``(context, response, targets)`` by
        ``self.process_minibatch``. One Batch is then yielded per row
        ``index`` of ``context``, carrying ``context[:index + 1]``,
        ``response[index]`` and ``targets[index]``.

        A row is skipped when:

        - its target consists only of the pad id, or
        - ``np.prod`` of its context-prefix shape exceeds
          ``self.max_context_size``.

        Repeats indefinitely unless ``self.repeat`` is false.
        """
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)

                context, response, targets = self.process_minibatch(minibatch)
                # The pad id is fixed for the whole minibatch; look it up once.
                pad_index = self.text_field.vocab.stoi[self.text_field.pad_token]
                for index in range(context.shape[0]):
                    # do not yield if the target is just padding (does not provide anything to training)
                    if (targets[index] == pad_index).all():
                        continue
                    # skip examples with contexts that won't fit in gpu memory
                    if np.prod(context[:index + 1].shape) > self.max_context_size:
                        continue
                    yield Batch.fromvars(dataset=self.dataset, batch_size=len(minibatch),
                                         train=self.train,
                                         context=context[:index + 1],
                                         response=response[index],
                                         targets=targets[index]
                                         )
            if not self.repeat:
                # PEP 479: `raise StopIteration` inside a generator becomes a
                # RuntimeError on Python 3.7+; `return` ends it cleanly.
                return
コード例 #6
0
ファイル: iterators.py プロジェクト: zhangjiekui/quick-nlp
    def __iter__(self) -> Iter[Batch]:
        """Same iterator almost as bucket iterator.

        Each minibatch (after optional fast-forward and in-batch ordering) is
        converted by ``self.process_minibatch`` into ``(context, response,
        targets)``. These are yielded as a single Batch whose ``batch_size``
        is ``len(minibatch)``. Repeats indefinitely unless ``self.repeat``
        is false.
        """
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)

                context, response, targets = self.process_minibatch(minibatch)
                yield Batch.fromvars(dataset=self.dataset,
                                     batch_size=len(minibatch),
                                     train=self.train,
                                     context=context,
                                     response=response,
                                     targets=targets)
            if not self.repeat:
                # PEP 479: `raise StopIteration` inside a generator becomes a
                # RuntimeError on Python 3.7+; `return` ends it cleanly.
                return
コード例 #7
0
 def __iter__(self):
     """Yield BPTT batches (``text`` plus next-step ``target``) over one text.

     The first example's token list is right-padded with the field's pad
     token up to a multiple of ``self.batch_size``. It is then numericalized
     and laid out as ``(steps, batch_size)`` via ``view(...).t()``.

     Windows of at most ``self.bptt_len`` rows are yielded, with ``target``
     offset by one row. When the field is ``batch_first``, both tensors are
     transposed before yielding. Repeats forever unless ``self.repeat`` is
     false. Side effect: the ``'text'`` field's ``eos_token`` is set to
     ``None``.
     """
     tokens = self.dataset[0].text
     text_field = self.dataset.fields['text']
     text_field.eos_token = None
     padded_length = math.ceil(len(tokens) / self.batch_size) * self.batch_size
     tokens = tokens + [text_field.pad_token] * int(padded_length - len(tokens))
     numbered = text_field.numericalize([tokens], device=self.device)
     # (batch_size, steps) -> (steps, batch_size)
     numbered = numbered.view(self.batch_size, -1).t().contiguous()
     lm_dataset = Dataset(examples=self.dataset.examples,
                          fields=[('text', text_field), ('target', text_field)])
     while True:
         for start in range(0, len(self) * self.bptt_len, self.bptt_len):
             self.iterations += 1
             window = min(self.bptt_len, len(numbered) - start - 1)
             inputs = numbered[start:start + window]
             targets = numbered[start + 1:start + 1 + window]
             if text_field.batch_first:
                 inputs = inputs.t().contiguous()
                 targets = targets.t().contiguous()
             yield Batch.fromvars(lm_dataset, self.batch_size,
                                  text=inputs, target=targets)
         if not self.repeat:
             return
コード例 #8
0
    def __iter__(self):
        """Yield named-tensor BPTT batches over the dataset's single text.

        The text is right-padded to a multiple of ``self.batch_size`` and
        numericalized. It is then reshaped via ``stack``/``split``/
        ``transpose`` so that ``"seqlen"`` precedes ``"batch"``. Each batch
        narrows ``"seqlen"`` to a window of at most ``self.bptt_len`` steps
        for ``text``, and the same window one step later for ``target``.
        Repeats forever unless ``self.repeat`` is false.
        """
        example_text = self.dataset[0].text
        field = self.dataset.fields["text"]
        field.eos_token = None
        pad_count = int(
            math.ceil(len(example_text) / self.batch_size) * self.batch_size
            - len(example_text))
        example_text = example_text + ([field.pad_token] * pad_count)

        tensor = field.numericalize([example_text], device=self.device)
        flat = tensor.stack(("seqlen", "batch"), "flat")
        by_batch = flat.split("flat", ("batch", "seqlen"), batch=self.batch_size)
        tensor = by_batch.transpose("seqlen", "batch")

        lm_dataset = Dataset(examples=self.dataset.examples,
                             fields=[("text", field), ("target", field)])
        while True:
            for offset in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                span = min(self.bptt_len, len(tensor) - offset - 1)
                yield Batch.fromvars(
                    lm_dataset,
                    self.batch_size,
                    text=tensor.narrow("seqlen", offset, span),
                    target=tensor.narrow("seqlen", offset + 1, span),
                )

            if not self.repeat:
                return
コード例 #9
0
 def consume_buffer(self):
     """Yield one Batch per ``(text, target)`` pair drawn from the buffer.

     The contiguous buffer is prepared into ``(data, dataset)``. Each pair
     produced by ``self.consume_data`` is attached to the batch under
     ``self.field_name`` (text) and ``'target'``.
     """
     buffer = self.get_contiguous_buffer()
     data, dataset = self.prepare_text_buffer(buffer)
     total_len = self.get_len(buffer)
     for text_chunk, target_chunk in self.consume_data(data, total_len):
         yield Batch.fromvars(dataset, self.batch_size,
                              **{self.field_name: text_chunk,
                                 'target': target_chunk})
コード例 #10
0
 def consume_buffer(self):
     """Refill the buffer, rebuild its batches and yield them one at a time.

     ``self.iterations`` and ``self._iterations_this_epoch`` are bumped for
     every yielded batch. With ``self.sort_within_batch`` set, each minibatch
     is put in descending ``sort_key`` order:

     - reversed, if ``self.sort`` is set (presumably the batches are already
       in ascending key order);
     - otherwise sorted explicitly.
     """
     self.prepare_buffer()
     self.create_batches()
     for chunk in self.batches:
         self.iterations += 1
         self._iterations_this_epoch += 1
         if self.sort_within_batch and self.sort:
             chunk.reverse()
         elif self.sort_within_batch:
             chunk.sort(key=self.sort_key, reverse=True)
         yield Batch(chunk, self.dataset, self.device)
コード例 #11
0
    def __iter__(self):
        """Yield ``source`` batches, plus ``target``/``label`` unless inferring.

        Each pass calls ``self.init_epoch()`` and skips minibatch positions
        already consumed according to ``self._iterations_this_epoch``. It can
        also order each minibatch by descending ``sort_key``. Field lists are
        passed through ``postprocessing(..., self.params)`` and masks come
        from ``sequence_mask``. When ``self.mode == "infer"`` only ``source``
        and ``source_mask`` are produced. Repeats indefinitely unless
        ``self.repeat`` is false.
        """
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                # The last minibatch of an epoch may hold fewer than
                # self.batch_size examples; report its real size.
                batch_size = len(minibatch)
                source_batch = [m.source for m in minibatch]
                source_mask = sequence_mask(source_batch)
                if self.mode != "infer":
                    target_batch = [m.target for m in minibatch]
                    label_batch = [m.label for m in minibatch]
                    target_mask = sequence_mask(target_batch)
                    yield Batch.fromvars(
                        self.dataset,
                        batch_size,
                        source=postprocessing(source_batch, self.params),
                        source_mask=source_mask,
                        target=postprocessing(target_batch, self.params),
                        target_mask=target_mask,
                        label=postprocessing(label_batch, self.params))
                else:
                    yield Batch.fromvars(self.dataset,
                                         batch_size,
                                         source=postprocessing(
                                             source_batch, self.params),
                                         source_mask=source_mask)

            if not self.repeat:
                return
コード例 #12
0
ファイル: trainer.py プロジェクト: clay-lab/transductions
    def batchify(self, args):
        """Turn the REPL input into a batch for the model to process.

        ``args`` is the raw REPL line. Its first whitespace-separated word is
        the transformation name; the rest is the source sentence.

        Returns:
            Batch with two fields:

            - ``source``: a LongTensor built from
              ``self._model._encoder.to_ids`` applied to each token of
              ``<sos> <source tokens...> <eos>``.
            - ``annotation``: a LongTensor of shape ``(3, 1)`` holding the
              transform-field vocabulary ids of ``<sos>``, the
              transformation and ``<eos>``.

        Raises:
            ValueError: if ``args`` lacks a transformation followed by at
                least one source token.
        """
        parts = args.split(None, 1)
        if len(parts) != 2:
            raise ValueError(
                "expected '<transformation> <source sentence>', got {!r}"
                .format(args))
        transf, source = parts

        # split() with no separator collapses repeated whitespace and drops a
        # trailing newline, so no empty-string tokens reach the encoder.
        source = ['<sos>'] + source.split() + ['<eos>']
        source = [self._model._encoder.to_ids([s]) for s in source]
        source = torch.LongTensor(source)

        transf = ['<sos>', transf, '<eos>']
        transf = [[self._dataset.transform_field.vocab.stoi[t]]
                  for t in transf]
        transf = torch.LongTensor(transf)

        batch = Batch()
        batch.source = source
        batch.annotation = transf

        return batch
コード例 #13
0
 def __iter__(self):
     """Yield BPTT ``text``/``target`` batches over the first example's field.

     The ``self.field_name`` attribute of ``self.dataset[0]`` is converted
     into ``(data, dataset)`` by ``self.prepare_text``. Windows of at most
     ``self.cur_bptt_len`` rows are sliced from ``data`` along its first
     dimension, with ``target`` offset by one row. Both are transposed when
     ``self.batch_first`` is set. Repeats forever unless ``self.repeat`` is
     false.
     """
     source_text = getattr(self.dataset[0], self.field_name)
     data, dataset = self.prepare_text(source_text)
     while True:
         for offset in range(0, len(self) * self.cur_bptt_len, self.cur_bptt_len):
             self.iterations += 1
             window = min(self.cur_bptt_len, len(data) - offset - 1)
             text_slice = data[offset:offset + window]
             target_slice = data[offset + 1:offset + 1 + window]
             if self.batch_first:
                 text_slice = text_slice.t().contiguous()
                 target_slice = target_slice.t().contiguous()
             yield Batch.fromvars(dataset, self.batch_size,
                                  text=text_slice, target=target_slice)
         if not self.repeat:
             return
コード例 #14
0
 def __iter__(self):
     """Yield batch-major BPTT batches over the dataset's single text.

     The text is right-padded to a multiple of ``self.batch_size``,
     numericalized, and viewed as ``(batch_size, steps)``. Each batch slices
     a window of at most ``self.bptt_len`` steps along dim 1 for ``text``,
     and the same window one step later for ``target``. Repeats forever
     unless ``self.repeat`` is false. Side effect: the ``'text'`` field's
     ``eos_token`` is set to ``None``.
     """
     text = self.dataset[0].text
     TEXT = self.dataset.fields['text']
     TEXT.eos_token = None
     pad_num = int(math.ceil(len(text) / self.batch_size) * self.batch_size
                   - len(text))
     text = text + ([TEXT.pad_token] * pad_num)
     data = TEXT.numericalize([text], device=self.device)
     # (batch_size, steps): time runs along dim 1.
     data = data.view(self.batch_size, -1).contiguous()
     dataset = Dataset(examples=self.dataset.examples,
                       fields=[('text', TEXT), ('target', TEXT)])
     num_steps = data.size(1)
     while True:
         for i in range(0, len(self) * self.bptt_len, self.bptt_len):
             self.iterations += 1
             # Clamp so the one-step-shifted target never runs past the end.
             # Otherwise the final window can give a `text` one step longer
             # than its `target`.
             seq_len = min(self.bptt_len, num_steps - i - 1)
             yield Batch.fromvars(dataset,
                                  self.batch_size,
                                  text=data[:, i:i + seq_len],
                                  target=data[:, i + 1:i + 1 + seq_len])
         if not self.repeat:
             return
コード例 #15
0
 def __iter__(self):
     """Yield one ``Batch`` per non-empty minibatch, repeating if configured.

     Each pass calls ``self.init_epoch()`` and walks ``self.batches``. Two
     kinds of position are skipped:

     - positions already consumed according to ``self._iterations_this_epoch``
       (presumably restored from saved state);
     - empty minibatches.

     With ``self.sort_within_batch`` each minibatch is put in descending
     ``sort_key`` order. Stops after one pass unless ``self.repeat`` is true.
     """
     while True:
         self.init_epoch()
         for idx, minibatch in enumerate(self.batches):
             # fast-forward if loaded from state
             if self._iterations_this_epoch > idx:
                 continue
             if len(minibatch) == 0:
                 # Count empty slots in the per-epoch position too, so that
                 # `_iterations_this_epoch` keeps tracking `idx`. Otherwise a
                 # resume from state saved after an empty minibatch would
                 # re-yield an already-seen batch. `self.iterations` (batches
                 # actually yielded) is left unchanged.
                 self._iterations_this_epoch += 1
                 continue
             self.iterations += 1
             self._iterations_this_epoch += 1
             if self.sort_within_batch:
                 # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                 # be sorted by decreasing order, which requires reversing
                 # relative to typical sort keys
                 if self.sort:
                     minibatch.reverse()
                 else:
                     minibatch.sort(key=self.sort_key, reverse=True)
             yield Batch(minibatch, self.dataset, self.device)
         if not self.repeat:
             return
コード例 #16
0
    def _transform(self, batch):
        """Return a new Batch whose src/trg ids are remapped via lookup tables.

        Every id in ``batch.src[0]`` becomes ``self.src_b2s[id]`` and every id
        in ``batch.trg[0]`` becomes ``self.trg_b2s[id]``. Tensor shapes and the
        accompanying length values are kept. The remapped tensors are moved
        to CUDA when ``self.use_cuda`` is set.
        """
        def remap(ids, table):
            # Flatten, translate each id through `table`, restore the shape.
            shape = ids.size()
            flat_ids = ids.data.view(-1).tolist()
            return torch.LongTensor([table[token] for token in flat_ids]).view(shape)

        src, src_lens = batch.src
        new_src = remap(src, self.src_b2s)

        trg, trg_lens = batch.trg
        new_trg = remap(trg, self.trg_b2s)

        if self.use_cuda:
            new_src, new_trg = new_src.cuda(), new_trg.cuda()

        return Batch.fromvars(batch.dataset, batch.batch_size, batch.train,
                              src=(new_src, src_lens),
                              trg=(new_trg, trg_lens))
コード例 #17
0
    zip(["<bos>"] + t2 + ["<eos>"], ss0[0].tolist(), ss1[0, :, :-1].tolist(),
        ss3[0].tolist(), ss2[0].tolist(),
        psi_ys[0, :, 1:, 1:].transpose(-1, -2).tolist()))
print("## model comparison ##")
for w, x, y, a, z, p in ok:
    print(
        f"{w:<10}:\t[{' '.join(map(f, x))}]\t[{' '.join(map(f, y))}]\t[{' '.join(map(f, a))}]\t[{' '.join(map(f, z))}]"
    )
print()

print("===== Comparisons on valid between models =====")
dataset = valid
for i in range(len(dataset)):
    # Wrap a single validation example in its own Batch.
    example = Batch([dataset[i]], dataset, device)
    # Only examples whose sentiment label is non-zero are compared.
    if example.sentiments.item() != 0:
        s0 = lstmfinal.observe(example.text[0], example.text[1],
                               example.locations, example.aspects,
                               example.sentiments)
        s1 = crflstmdiag.observe(example.text[0], example.text[1],
                                 example.locations, example.aspects,
                                 example.sentiments)
        s2, psi_ys = crfemblstm.observe(example.text[0], example.text[1],
                                        example.locations, example.aspects,
                                        example.sentiments)
        s3, psi_ys1 = crflstmlstm.observe(example.text[0], example.text[1],
                                          example.locations, example.aspects,
                                          example.sentiments)
        # NOTE(review): s0-s3, psi_ys and psi_ys1 are not used within this
        # span — confirm whether the reporting code was cut off.
コード例 #18
0
def test_single_gpu_batch_parse():
    """Check that ``batch_to_device`` moves every tensor in a batch to ``cuda:0``.

    Covers:

    - primitive objects (returned unchanged);
    - a bare tensor, lists, nested lists, dicts, tuples and namedtuples;
    - an arbitrary object with ``.to()``;
    - a ``torchtext.data.Batch``.

    Requires one CUDA device.
    """
    trainer = Trainer(gpus=1)
    device = torch.device('cuda:0')

    def assert_on_gpu0(tensor, expected_type='torch.cuda.FloatTensor'):
        # Separate asserts so a failure shows which property was wrong.
        assert tensor.device.index == 0
        assert tensor.type() == expected_type

    # non-transferrable types
    primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}]
    for batch in primitive_objects:
        data = trainer.accelerator.batch_to_device(batch, device)
        assert data == batch

    # batch is just a tensor
    batch = trainer.accelerator.batch_to_device(torch.rand(2, 3), device)
    assert_on_gpu0(batch)

    # tensor list
    batch = [torch.rand(2, 3), torch.rand(2, 3)]
    batch = trainer.accelerator.batch_to_device(batch, device)
    for tensor in batch:
        assert_on_gpu0(tensor)

    # tensor list of lists
    batch = [[torch.rand(2, 3), torch.rand(2, 3)]]
    batch = trainer.accelerator.batch_to_device(batch, device)
    for tensor in batch[0]:
        assert_on_gpu0(tensor)

    # tensor dict
    batch = [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)}]
    batch = trainer.accelerator.batch_to_device(batch, device)
    for tensor in batch[0].values():
        assert_on_gpu0(tensor)

    # tuple of tensor list and list of tensor dict
    batch = ([torch.rand(2, 3) for _ in range(2)],
             [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)} for _ in range(2)])
    batch = trainer.accelerator.batch_to_device(batch, device)
    # Check every element, not only the first of each container.
    for tensor in batch[0]:
        assert_on_gpu0(tensor)
    for item in batch[1]:
        for tensor in item.values():
            assert_on_gpu0(tensor)

    # namedtuple of tensor
    BatchType = namedtuple('BatchType', ['a', 'b'])
    batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)]
    batch = trainer.accelerator.batch_to_device(batch, device)
    for item in batch:
        assert_on_gpu0(item.a)
        assert_on_gpu0(item.b)

    # non-Tensor that has `.to()` defined
    class CustomBatchType:

        def __init__(self):
            self.a = torch.rand(2, 2)

        def to(self, *args, **kwargs):
            self.a = self.a.to(*args, **kwargs)
            return self

    batch = trainer.accelerator.batch_to_device(CustomBatchType(), device)
    assert_on_gpu0(batch.a)

    # torchtext.data.Batch
    samples = [{
        'text': 'PyTorch Lightning is awesome!',
        'label': 0
    }, {
        'text': 'Please make it work with torchtext',
        'label': 1
    }]

    text_field = Field()
    label_field = LabelField()
    fields = {'text': ('text', text_field), 'label': ('label', label_field)}

    examples = [Example.fromdict(sample, fields) for sample in samples]
    dataset = Dataset(examples=examples, fields=fields.values())

    # Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first
    text_field.build_vocab(dataset)
    label_field.build_vocab(dataset)

    batch = Batch(data=examples, dataset=dataset)
    batch = trainer.accelerator.batch_to_device(batch, device)

    assert_on_gpu0(batch.text, 'torch.cuda.LongTensor')
    assert_on_gpu0(batch.label, 'torch.cuda.LongTensor')
コード例 #19
0
    def to_batch(self,
                 device,
                 fields=None,
                 parent_label=None,
                 x2y='d2e',
                 index=0):
        """Build a single-example ``Batch`` from this document or one of its spans.

        Args:
            device: device passed through to ``Batch``.
            fields: field mapping for ``Example.fromdict``; defaults to
                ``self.fields``. A dict is flattened into a list (mirroring
                ``torchtext.data.TabularDataset``) before building the
                ``Dataset``.
            parent_label: value stored under ``'parent_label'``; defaults to
                ``self.parent_label``.
            x2y: conversion granularity.

                - ``'d2e'``, ``'d2s'``, ``'d2p'``: use the whole document,
                  segmented (via ``self.make_edu``) by the EDU, sentence or
                  paragraph start flags respectively.
                - ``'p2s'``: use the ``index``-th paragraph, segmented by
                  sentence starts.
                - ``'s2e'``: use the ``index``-th sentence, segmented by EDU
                  starts.
            index: which paragraph/sentence to use for ``'p2s'``/``'s2e'``.

        Returns:
            Batch: a batch holding one example, with ``batch.tree = None``.

        Raises:
            AssertionError: if no fields are available, or if the number of
                produced sequences differs between the strings and the
                sentence/paragraph start lists.
            ValueError: if ``x2y`` is not one of the supported values.
        """
        if fields is None:
            fields = self.fields

        # Validate the fields that will actually be used: an explicitly passed
        # `fields` is fine even when `self.fields` was never set.
        assert fields is not None, 'you need to call after setting fields with Doc.set_fields()'

        # Fail early on an unknown conversion. Otherwise `tokens` and
        # `starts_xxx` stay unbound and surface as an UnboundLocalError below.
        if x2y not in ('d2e', 'd2p', 'd2s', 'p2s', 's2e'):
            raise ValueError(
                "x2y must be one of 'd2e', 'd2p', 'd2s', 'p2s' or 's2e', "
                "got {!r}".format(x2y))

        if x2y in ['d2e', 'd2p', 'd2s']:
            # Whole document.
            starts_sentence = self.starts_sentence
            starts_paragraph = self.starts_paragraph
            tokens = self.tokens
            if x2y == 'd2e':
                starts_xxx = self.starts_edu
            elif x2y == 'd2s':
                starts_xxx = self.starts_sentence
            elif x2y == 'd2p':
                starts_xxx = self.starts_paragraph

        start_offset = 0
        if x2y in ['p2s', 's2e']:

            def get_span(hierarchical_type, word_starts_sentence,
                         word_starts_paragraph, index):
                """Return the ``[start, end)`` token range of the ``index``-th
                paragraph (type starting with ``'p'``) or sentence (``'s'``)."""
                if hierarchical_type.startswith('p'):
                    flags = word_starts_paragraph
                elif hierarchical_type.startswith('s'):
                    flags = word_starts_sentence

                # flags: [True, False, False, False, True, False, True, ...]
                # index: 0 -> start, end = 0, 4
                # index: 1 -> start, end = 4, 6
                # NOTE(review): if `index` is past the last unit, start/end
                # keep their defaults and the whole document is returned.
                count = 0
                start, end = 0, len(flags)
                for i, f in enumerate(flags):
                    if not f:
                        continue
                    if count == index:
                        start = i
                    if count == index + 1:
                        end = i
                        break
                    count += 1
                return start, end

            start, end = get_span(x2y, self.starts_sentence,
                                  self.starts_paragraph, index)
            start_offset = start  # to make text span
            starts_sentence = self.starts_sentence[start:end]
            starts_paragraph = self.starts_paragraph[start:end]
            tokens = self.tokens[start:end]

            if x2y == 'p2s':
                starts_xxx = self.starts_sentence[start:end]
            elif x2y == 's2e':
                starts_xxx = self.starts_edu[start:end]

        # tokenized_strings: List of edus
        # raw_tokenized_strings: List of edus splitted by white-space
        # starts_*: List of bool value representing start of *
        tokenized_strings = self.make_edu(tokens, starts_xxx)
        raw_tokenized_strings = [edu.split() for edu in tokenized_strings]
        starts_sentence = self.make_starts(starts_xxx, starts_sentence)
        starts_paragraph = self.make_starts(starts_xxx, starts_paragraph)
        parent_label = self.parent_label if parent_label is None else parent_label
        spans, _ = self.make_text_span(tokenized_strings, start_offset,
                                       self.doc_id)

        assert len(tokenized_strings) == len(starts_sentence) == len(starts_paragraph), \
            'num input seqs not same'

        example = Example.fromdict(
            {
                'doc_id': self.doc_id,
                'labelled_attachment_tree':
                '(nucleus-nucleus:Elaboration (text 1) (text 2))',  # DummyTree
                'tokenized_strings': tokenized_strings,
                'raw_tokenized_strings': raw_tokenized_strings,
                'spans': spans,
                'starts_sentence': starts_sentence,
                'starts_paragraph': starts_paragraph,
                'parent_label': parent_label,
            },
            fields)

        if isinstance(fields, dict):  # copy from torchtext.data.TabularDataset
            fields, field_dict = [], fields
            for field in field_dict.values():
                if isinstance(field, list):
                    fields.extend(field)
                else:
                    fields.append(field)

        dataset = Dataset([example], fields)
        batch = Batch([example], dataset, device=device)
        batch.tree = None
        return batch
コード例 #20
0
 def torchtext_collate(data):
     """Collate ``data`` into a torchtext ``Batch`` built against ``train_data``.

     Returns the batch's ``src`` and ``trg`` attributes as a dict.
     """
     batch = Batch(data, train_data)
     return dict(src=batch.src, trg=batch.trg)
コード例 #21
0
    def __iter__(self):
        """Yield processed minibatches while printing overhead statistics.

        Each minibatch from ``self.batches`` goes through three steps:

        1. converted by ``self.convert_minibatch`` (index values -> x, y
           tensors);
        2. optionally ordered by descending ``sort_key``;
        3. wrapped in ``Batch``.

        Every ``report_every`` minibatch positions, two lines are printed:
        the cumulative time for conversion, and for conversion plus
        ``Batch`` construction, each as a percentage of the time elapsed
        since the previous report. The counters are then reset. Repeats
        indefinitely unless ``self.repeat`` is false.
        """
        # set to "0" to skip reporting of minibatch processing overhead
        report_every = 1000

        # perf_counter() is monotonic and meant for interval timing;
        # time.time() follows the system clock and can jump.
        clock = time.perf_counter
        start_wall = clock()
        total_convert = 0
        total_both = 0

        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1

                start = clock()

                # apply JIT processing to convert from index values to x,y tensors
                minibatch = self.convert_minibatch(minibatch)

                total_convert += (clock() - start)

                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)

                # process() the minibatch now
                mini = Batch(minibatch, self.dataset, self.device)

                total_both += (clock() - start)

                if idx and report_every and idx % report_every == 0:
                    wall_elapsed = clock() - start_wall

                    percent = 100 * total_convert / wall_elapsed
                    print(
                        "{} minibatch CONVERT: total={:.2f}, wall={:.2f}, overhead={:.2f} %"
                        .format(report_every, total_convert, wall_elapsed,
                                percent))

                    percent = 100 * total_both / wall_elapsed
                    print(
                        "{} minibatch CONVERT+PROCESS: total={:.2f}, wall={:.2f}, overhead={:.2f} %"
                        .format(report_every, total_both, wall_elapsed,
                                percent))

                    total_both = 0
                    total_convert = 0
                    start_wall = clock()

                yield mini
            if not self.repeat:
                return
コード例 #22
0
def process(split='train'):
    """Tokenize, numericalize and save one data split as batched ``.npz`` shards.

    Reads the parallel files ``<data_dir>/<split>/<split>.en`` and
    ``<data_dir>/<split>/<split>.de`` line by line and builds torchtext
    ``Example`` objects (``src`` from English, ``tar`` from German). The
    examples are sorted by source length. Every ``args.batch_size`` examples
    are written as one ``np.savez`` archive (arrays ``src`` and ``tar``)
    under ``<data_dir>/tokenzied/<split>/``.

    For the ``train`` split, the EN/DE ``stoi``/``itos`` tables are also
    pickled to ``<data_dir>/tokenzied/vocab.pth``.

    Vocabularies are always built from the ``train`` split. For other splits
    they are only built if ``TEXT_en``/``TEXT_de`` have none yet. As a
    result, a split processed after ``train`` reuses the training ids.
    NOTE(review): a non-train split processed in a fresh process still gets
    its own vocabulary.
    """
    en_fname = os.path.join(args.data_dir, f'{split}', f'{split}.en')
    de_fname = os.path.join(args.data_dir, f'{split}', f'{split}.de')

    en_fields = ('src', TEXT_en)
    de_fields = ('tar', TEXT_de)

    examples = []
    # Explicit UTF-8: German text is non-ASCII and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(en_fname, 'r', encoding='utf-8') as f_en, \
         open(de_fname, 'r', encoding='utf-8') as f_de:

        for line_en, line_de in tqdm(zip(f_en, f_de),
                                     total=get_num_lines(en_fname)):
            ex = Example.fromlist([line_en, line_de], [en_fields, de_fields])
            examples.append(ex)  # Examples stores tokenized sequential data

    ds_train = Dataset(examples=examples, fields=[en_fields, de_fields])
    # Examples are already tokenized by Example.fromlist; no vocabulary yet.

    # Rebuilding the vocab from a valid/test split would numericalize it with
    # ids that disagree with the vocab.pth saved for `train`.
    if split == "train" or not hasattr(TEXT_en, "vocab"):
        start_time = time.time()
        TEXT_en.build_vocab(ds_train)
        print("EN Vocab Built. Time Taken:{}s".format(time.time() - start_time))

    if split == "train" or not hasattr(TEXT_de, "vocab"):
        start_time = time.time()
        TEXT_de.build_vocab(ds_train)
        print("DE Vocab Built. Time Taken:{}s".format(time.time() - start_time))

    # Keep the sorted result: the previous bare `sorted(...)` call discarded it,
    # so shards were never length-grouped.
    examples = sorted(
        examples,
        key=lambda x: len(x.src))  # x.src is a list of tokenized src sentence

    # [TODO] Handle Long sentences (some ignore field is present)

    B = args.batch_size
    # "tokenzied" (sic) is the existing on-disk directory name; keep it in
    # sync with any reader before changing the spelling.
    out_base_dir = os.path.join(args.data_dir, "tokenzied")
    out_dir = os.path.join(out_base_dir, f"{split}")
    os.makedirs(out_dir, exist_ok=True)

    for file_name_idx, idx in enumerate(range(0, len(examples), B)):
        # [TODO] Can improve batching so as to reduce padding, save some space
        batch = Batch(
            data=examples[idx:idx + B],
            dataset=ds_train,  # Most likely used to access the Fields
            device=device)
        # .cpu() is a no-op for CPU tensors and lets .numpy() work on GPU ones.
        data_en, data_de = batch.src.cpu().numpy(), batch.tar.cpu().numpy()

        with open(os.path.join(out_dir, f"{split}_{file_name_idx:02}"),
                  "wb") as f:
            np.savez(f, src=data_en, tar=data_de)

    if split == "train":
        vocab = {
            "en": {
                "stoi": TEXT_en.vocab.stoi,
                "itos": TEXT_en.vocab.itos
            },
            "de": {
                "stoi": TEXT_de.vocab.stoi,
                "itos": TEXT_de.vocab.itos
            }
        }

        with open(os.path.join(out_base_dir, "vocab.pth"), "wb") as f:
            pickle.dump(vocab, f)