def test_batch(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    batch_size = 3
    batch_sampler = train_batch_sampler.batch(batch_size)
    for i, sample in enumerate(batch_sampler):
        for j, minibatch in enumerate(sample):
            self.check_output_equal(i * batch_size + j, minibatch)

def test_list(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    list_sampler = train_batch_sampler.list()
    self.check_output_equal(
        type(iter(list_sampler)).__name__, "list_iterator")
    for i, sample in enumerate(list_sampler):
        self.check_output_equal(i, sample)

def test_sort_no_buffer_size(self):
    train_ds_len = len(self.train_ds)
    ds_iter = iter(range(train_ds_len - 1, -1, -1))
    train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
    # `cmp` is assumed to be a module-level three-way comparator helper
    # defined in the test file (it is not a builtin in Python 3).
    sort_sampler = train_batch_sampler.sort(
        cmp=lambda x, y, dataset: cmp(x, y))
    for i, sample in enumerate(sort_sampler):
        self.check_output_equal(i, sample)

def test_apply(self):
    train_ds_len = len(self.train_ds)
    ds_iter = iter(range(train_ds_len - 1, -1, -1))
    train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
    fn = lambda sampler: SamplerHelper.sort(
        sampler, cmp=lambda x, y, dataset: cmp(x, y))
    apply_sampler = train_batch_sampler.apply(fn)
    for i, sample in enumerate(apply_sampler):
        self.check_output_equal(i, sample)
def test_length(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    self.check_output_equal(len(train_batch_sampler), 25000)
    self.check_output_equal(
        len(train_batch_sampler), train_batch_sampler.length)

    train_batch_sampler.length = 20
    self.check_output_equal(len(train_batch_sampler), 20)

def test_shard(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    shard_sampler1 = train_batch_sampler.shard(2, 0)
    shard_sampler2 = train_batch_sampler.shard(2, 1)
    for i, sample in enumerate(shard_sampler1):
        self.check_output_equal(i * 2, sample)

    for i, sample in enumerate(shard_sampler2):
        self.check_output_equal(i * 2 + 1, sample)
def test_shuffle_buffer_size(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    shuffle_sampler = train_batch_sampler.shuffle(buffer_size=10, seed=102)
    expected_result = {
        0: 4,
        12000: 12003,
        24999: 24997,
    }
    for i, sample in enumerate(shuffle_sampler):
        if i in expected_result:
            self.check_output_equal(sample, expected_result[i])
def test_batch_oversize(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    batch_size = 3
    key = lambda size_so_far, minibatch_len: max(size_so_far, minibatch_len)
    batch_size_fn = lambda new, count, sofar, data_source: len(data_source)

    batch_sampler = train_batch_sampler.batch(
        batch_size, key=key, batch_size_fn=batch_size_fn)
    for i, sample in enumerate(batch_sampler):
        for j, minibatch in enumerate(sample):
            self.check_output_equal(i * batch_size + j, minibatch)
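The tests above exercise SamplerHelper's operators one at a time; the loader examples below chain them. A minimal sketch of that chaining on a toy dataset (assuming `SamplerHelper` from `paddlenlp.data` and a `paddle.io.Dataset`, as in the surrounding snippets):

import paddle
from paddlenlp.data import SamplerHelper

class ToyDataset(paddle.io.Dataset):
    # Ten variable-length "sequences" to make bucketing visible.
    def __init__(self):
        self.data = [list(range(n % 4 + 1)) for n in range(10)]

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

ds = ToyDataset()
# Shuffle within a buffer, sort each buffer by sequence length, then
# group indices into batches -- the same chain used by the
# create_train_loader examples below.
sampler = SamplerHelper(ds).shuffle(buffer_size=8, seed=0).sort(
    key=lambda idx, data_source: len(data_source[idx]),
    buffer_size=8).batch(batch_size=4)
for batch_indices in sampler:
    print(batch_indices)  # each item is a list of dataset indices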
Example #9
def create_infer_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len

    test_ds = load_dataset('iwslt15', splits='test')
    src_vocab = Vocab.load_vocabulary(**test_ds.vocab_info['en'])
    tgt_vocab = Vocab.load_vocabulary(**test_ds.vocab_info['vi'])
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    def convert_example(example):
        source = example['en'].split()
        target = example['vi'].split()

        source = src_vocab.to_indices(source)
        target = tgt_vocab.to_indices(target)

        return source, target

    test_ds.map(convert_example)
    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)

    test_loader = paddle.io.DataLoader(test_ds,
                                       batch_sampler=test_batch_sampler,
                                       collate_fn=partial(prepare_infer_input,
                                                          bos_id=bos_id,
                                                          eos_id=eos_id,
                                                          pad_id=pad_id))
    return test_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
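A sketch of consuming the returned loader (hypothetical; `args` and the decoding step come from the surrounding example project, and the exact batch layout depends on `prepare_infer_input`):

test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = \
    create_infer_loader(args)
for batch in test_loader:
    # prepare_infer_input is assumed to yield padded source ids plus
    # their true lengths; feed them to the seq2seq model here.
    src_ids, src_lengths = batch
    ...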
Example #10
def create_train_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len

    train_ds, dev_ds = load_dataset('iwslt15', splits=('train', 'dev'))
    src_vocab = Vocab.load_vocabulary(**train_ds.vocab_info['en'])
    tgt_vocab = Vocab.load_vocabulary(**train_ds.vocab_info['vi'])
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    def convert_example(example):
        source = example['en'].split()[:max_len]
        target = example['vi'].split()[:max_len]

        source = src_vocab.to_indices(source)
        target = tgt_vocab.to_indices(target)

        return source, target

    key = (lambda x, data_source: len(data_source[x][0]))

    # Truncate and convert example to ids
    train_ds = train_ds.map(convert_example, lazy=False)
    dev_ds = dev_ds.map(convert_example, lazy=False)

    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    train_loader = paddle.io.DataLoader(train_ds,
                                        batch_sampler=train_batch_sampler,
                                        collate_fn=partial(prepare_train_input,
                                                           bos_id=bos_id,
                                                           eos_id=eos_id,
                                                           pad_id=pad_id))

    dev_loader = paddle.io.DataLoader(dev_ds,
                                      batch_sampler=dev_batch_sampler,
                                      collate_fn=partial(prepare_train_input,
                                                         bos_id=bos_id,
                                                         eos_id=eos_id,
                                                         pad_id=pad_id))

    return train_loader, dev_loader, len(src_vocab), len(tgt_vocab), pad_id
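Sorting within a `batch_size * 20` buffer before batching buckets similar-length sequences together, which keeps per-batch padding small. A minimal consumption sketch (hypothetical; the model, optimizer, and `args.max_epoch` are assumed):

train_loader, dev_loader, src_vocab_size, tgt_vocab_size, pad_id = \
    create_train_loader(args)
for epoch in range(args.max_epoch):
    for batch in train_loader:
        # prepare_train_input is assumed to yield model-ready tensors,
        # e.g. source ids/lengths and shifted target ids; run the usual
        # forward/backward/optimizer step here.
        pass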
Example #11
def create_infer_loader(batch_size=128):
    test_ds = load_dataset('couplet', splits='test')
    vocab = Vocab.load_vocabulary(**test_ds.vocab_info)
    pad_id = vocab[vocab.eos_token]
    trans_func = partial(convert_example, vocab=vocab)
    test_ds = test_ds.map(trans_func, lazy=False)
    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)

    test_loader = paddle.io.DataLoader(test_ds,
                                       batch_sampler=test_batch_sampler,
                                       collate_fn=partial(prepare_input,
                                                          pad_id=pad_id))
    return test_loader, vocab
Example #12
def create_train_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    # Default (tokenize + to-ids) transforms for source/target, as in
    # the companion create_infer_loader example below.
    trans_func_tuple = IWSLT15.get_default_transform_func()
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    train_ds, dev_ds = IWSLT15.get_datasets(
        mode=["train", "dev"],
        transform_func=[trans_func_tuple, trans_func_tuple])

    key = (lambda x, data_source: len(data_source[x][0]))
    cut_fn = lambda data: (data[0][:max_len], data[1][:max_len])

    train_ds = train_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    dev_ds = dev_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))

    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))

    return train_loader, dev_loader, len(src_vocab), len(tgt_vocab), pad_id
Example #13
def create_infer_loader(batch_size=128):
    test_ds = CoupletDataset.get_datasets(["test"])
    vocab, _ = CoupletDataset.get_vocab()
    pad_id = vocab[CoupletDataset.EOS_TOKEN]
    bos_id = vocab[CoupletDataset.BOS_TOKEN]
    eos_id = vocab[CoupletDataset.EOS_TOKEN]

    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)

    test_loader = paddle.io.DataLoader(test_ds,
                                       batch_sampler=test_batch_sampler,
                                       collate_fn=partial(prepare_input,
                                                          pad_id=pad_id))
    return test_loader, len(vocab), pad_id, bos_id, eos_id
Example #14
def create_train_loader(batch_size=128):
    train_ds = CoupletDataset.get_datasets(["train"])
    vocab, _ = CoupletDataset.get_vocab()
    pad_id = vocab[CoupletDataset.EOS_TOKEN]

    train_batch_sampler = SamplerHelper(train_ds).shuffle().batch(
        batch_size=batch_size)

    train_loader = paddle.io.DataLoader(train_ds,
                                        batch_sampler=train_batch_sampler,
                                        collate_fn=partial(prepare_input,
                                                           pad_id=pad_id))

    return train_loader, len(vocab), pad_id
Example #15
def create_infer_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    trans_func_tuple = IWSLT15.get_default_transform_func()
    test_ds = IWSLT15.get_datasets(
        mode=["test"], transform_func=[trans_func_tuple])
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)

    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(
            prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return test_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
Example #16
def test_iter2(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    for i, sample in enumerate(train_batch_sampler):
        self.check_output_equal(i, sample)
Example #17
def test_iter1(self):
    train_ds_len = len(self.train_ds)
    ds_iter = iter(range(train_ds_len - 1, -1, -1))
    train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
    for i, sample in enumerate(train_batch_sampler):
        self.check_output_equal(i, train_ds_len - 1 - sample)
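These two tests show SamplerHelper's optional second argument: an iterable of indices that overrides the default 0..len(dataset)-1 order. A small standalone sketch (toy 5-element dataset assumed):

# Iterate a dataset in reverse by passing an explicit index iterator.
ds_iter = iter(range(4, -1, -1))
sampler = SamplerHelper(toy_ds, ds_iter)  # toy_ds: any 5-element dataset
print(list(sampler))  # [4, 3, 2, 1, 0]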
Example #18
def create_data_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    if args.dataset == 'yahoo':
        train_ds, dev_ds, test_ds = load_dataset('yahoo_answer_100k',
                                                 splits=('train', 'valid',
                                                         'test'))
        vocab = Vocab.load_vocabulary(**train_ds.vocab_info)
    else:
        train_ds, dev_ds, test_ds = load_dataset('ptb',
                                                 splits=('train', 'valid',
                                                         'test'))
        examples = [
            train_ds[i]['sentence'].split() for i in range(len(train_ds))
        ]
        vocab = Vocab.build_vocab(examples)

    vocab_size = len(vocab)
    bos_id = vocab_size
    eos_id = vocab_size + 1
    pad_id = vocab_size + 1

    def convert_example(example):
        features = vocab.to_indices(example['sentence'].split()[:max_len])
        return features

    key = (lambda x, data_source: len(data_source[x]))
    # Truncate and convert example to ids
    train_ds = train_ds.map(convert_example, lazy=False)
    dev_ds = dev_ds.map(convert_example, lazy=False)
    test_ds = test_ds.map(convert_example, lazy=False)

    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    test_batch_sampler = SamplerHelper(test_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    train_loader = paddle.io.DataLoader(train_ds,
                                        batch_sampler=train_batch_sampler,
                                        collate_fn=partial(prepare_train_input,
                                                           bos_id=bos_id,
                                                           eos_id=eos_id,
                                                           pad_id=pad_id))

    dev_loader = paddle.io.DataLoader(dev_ds,
                                      batch_sampler=dev_batch_sampler,
                                      collate_fn=partial(prepare_train_input,
                                                         bos_id=bos_id,
                                                         eos_id=eos_id,
                                                         pad_id=pad_id))

    test_loader = paddle.io.DataLoader(test_ds,
                                       batch_sampler=test_batch_sampler,
                                       collate_fn=partial(prepare_train_input,
                                                          bos_id=bos_id,
                                                          eos_id=eos_id,
                                                          pad_id=pad_id))

    return train_loader, dev_loader, test_loader, vocab, bos_id, pad_id, len(
        train_ds)
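Note that `bos_id` and `eos_id`/`pad_id` are assigned `vocab_size` and `vocab_size + 1`, two ids past the end of the vocabulary, so any embedding or output projection must cover `vocab_size + 2` entries. A sketch of the sizing, reusing `vocab_size` and `pad_id` from the function above (hyperparameters illustrative):

# The special ids live past the vocabulary, so size embeddings accordingly.
embedder = paddle.nn.Embedding(
    num_embeddings=vocab_size + 2,
    embedding_dim=300,
    padding_idx=pad_id)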
Example #19
def test_shard_default(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    shard_sampler1 = train_batch_sampler.shard()
    for i, sample in enumerate(shard_sampler1):
        self.check_output_equal(i, sample)