def test_batch(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    batch_size = 3
    batch_sampler = train_batch_sampler.batch(batch_size)
    for i, batch in enumerate(batch_sampler):
        for j, sample in enumerate(batch):
            self.check_output_equal(i * batch_size + j, sample)
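# Illustrative sketch, separate from the test suite: SamplerHelper iterates
# dataset indices and .batch() groups them into lists. Assuming the default
# drop_last=False, the trailing short batch is kept rather than dropped.
from paddlenlp.data import SamplerHelper

toy_sampler = SamplerHelper(list(range(7))).batch(batch_size=3)
print(list(toy_sampler))  # [[0, 1, 2], [3, 4, 5], [6]]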
def test_list(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    list_sampler = train_batch_sampler.list()
    self.check_output_equal(
        type(iter(list_sampler)).__name__, "list_iterator")
    for i, sample in enumerate(list_sampler):
        self.check_output_equal(i, sample)
def test_sort_no_buffer_size(self):
    train_ds_len = len(self.train_ds)
    # Feed indices in reverse so sort() has something to reorder.
    ds_iter = iter(range(train_ds_len - 1, -1, -1))
    train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
    sort_sampler = train_batch_sampler.sort(
        cmp=lambda x, y, dataset: cmp(x, y))
    for i, sample in enumerate(sort_sampler):
        self.check_output_equal(i, sample)
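# Python 3 removed the builtin cmp(), so the sort tests above and below
# presumably rely on a module-level helper like this (an assumed definition
# mirroring Python 2's three-way comparison):
def cmp(x, y):
    # Returns -1, 0, or 1 as x is less than, equal to, or greater than y.
    return (x > y) - (x < y)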
def test_apply(self):
    train_ds_len = len(self.train_ds)
    ds_iter = iter(range(train_ds_len - 1, -1, -1))
    train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
    fn = lambda sampler: SamplerHelper.sort(
        sampler, cmp=lambda x, y, dataset: cmp(x, y))
    apply_sampler = train_batch_sampler.apply(fn)
    for i, sample in enumerate(apply_sampler):
        self.check_output_equal(i, sample)
def test_length(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    self.check_output_equal(len(train_batch_sampler), 25000)
    self.check_output_equal(
        len(train_batch_sampler), train_batch_sampler.length)
    # length is writable and overrides what len() reports.
    train_batch_sampler.length = 20
    self.check_output_equal(len(train_batch_sampler), 20)
def test_shard(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    shard_sampler1 = train_batch_sampler.shard(2, 0)
    shard_sampler2 = train_batch_sampler.shard(2, 1)
    for i, sample in enumerate(shard_sampler1):
        self.check_output_equal(i * 2, sample)
    for i, sample in enumerate(shard_sampler2):
        self.check_output_equal(i * 2 + 1, sample)
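# Standalone sketch of what test_shard checks: shard(num_replicas, rank)
# deals indices round-robin across workers, so with two replicas worker 0
# sees the even indices and worker 1 the odd ones.
from paddlenlp.data import SamplerHelper

indices = list(range(6))
print(list(SamplerHelper(indices).shard(2, 0)))  # [0, 2, 4]
print(list(SamplerHelper(indices).shard(2, 1)))  # [1, 3, 5]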
def test_shuffle_buffer_size(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    shuffle_sampler = train_batch_sampler.shuffle(buffer_size=10, seed=102)
    # Spot-check a few positions; the values are fixed by the seed.
    expected_result = {
        0: 4,
        12000: 12003,
        24999: 24997,
    }
    for i, sample in enumerate(shuffle_sampler):
        if i in expected_result:
            self.check_output_equal(sample, expected_result[i])
def test_batch_oversize(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    batch_size = 3
    # Force the running size past batch_size to exercise the oversize
    # branch of batch().
    key = lambda size_so_far, minibatch_len: max(size_so_far, minibatch_len)
    batch_size_fn = lambda new, count, sofar, data_source: len(data_source)
    batch_sampler = train_batch_sampler.batch(
        batch_size, key=key, batch_size_fn=batch_size_fn)
    for i, batch in enumerate(batch_sampler):
        for j, sample in enumerate(batch):
            self.check_output_equal(i * batch_size + j, sample)
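# Hedged sketch of a practical batch_size_fn: token-count bucketing. This
# assumes batch() accumulates the value returned by batch_size_fn and flushes
# a minibatch once it reaches or would exceed batch_size, which is the
# convention the oversize test above exercises. The toy dataset and
# max_tokens value are illustrative only.
from paddlenlp.data import SamplerHelper

toy_ds = [([0] * n, [0] * n) for n in (3, 5, 2, 7, 4)]  # (src, tgt) pairs
max_tokens = 8
count_tokens = lambda new, count, sofar, ds: sofar + len(ds[new][0])
token_sampler = SamplerHelper(toy_ds).batch(
    batch_size=max_tokens, batch_size_fn=count_tokens)
print(list(token_sampler))  # index groups whose source lengths fit the budget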
def create_infer_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len  # unused here; kept to mirror create_train_loader
    test_ds = load_dataset('iwslt15', splits='test')
    src_vocab = Vocab.load_vocabulary(**test_ds.vocab_info['en'])
    tgt_vocab = Vocab.load_vocabulary(**test_ds.vocab_info['vi'])
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    def convert_example(example):
        source = example['en'].split()
        target = example['vi'].split()
        source = src_vocab.to_indices(source)
        target = tgt_vocab.to_indices(target)
        return source, target

    test_ds.map(convert_example)
    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)
    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(
            prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return test_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
def create_train_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    train_ds, dev_ds = load_dataset('iwslt15', splits=('train', 'dev'))
    src_vocab = Vocab.load_vocabulary(**train_ds.vocab_info['en'])
    tgt_vocab = Vocab.load_vocabulary(**train_ds.vocab_info['vi'])
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    def convert_example(example):
        source = example['en'].split()[:max_len]
        target = example['vi'].split()[:max_len]
        source = src_vocab.to_indices(source)
        target = tgt_vocab.to_indices(target)
        return source, target

    key = (lambda x, data_source: len(data_source[x][0]))

    # Truncate and convert examples to ids.
    train_ds = train_ds.map(convert_example, lazy=False)
    dev_ds = dev_ds.map(convert_example, lazy=False)

    # Sort within a buffer by source length so batches need less padding.
    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return train_loader, dev_loader, len(src_vocab), len(tgt_vocab), pad_id
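# Hypothetical driver for the loader above; argparse.Namespace stands in for
# parsed CLI arguments, and the field values are illustrative only. The batch
# layout depends on prepare_train_input, which is defined elsewhere in this
# module.
from argparse import Namespace

args = Namespace(batch_size=64, max_len=50)
train_loader, dev_loader, src_vocab_size, tgt_vocab_size, pad_id = \
    create_train_loader(args)
for batch in train_loader:
    break  # each batch is the padded tensor tuple built by prepare_train_input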
def create_infer_loader(batch_size=128):
    test_ds = load_dataset('couplet', splits='test')
    vocab = Vocab.load_vocabulary(**test_ds.vocab_info)
    pad_id = vocab[vocab.eos_token]
    trans_func = partial(convert_example, vocab=vocab)
    test_ds = test_ds.map(trans_func, lazy=False)
    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)
    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(prepare_input, pad_id=pad_id))
    return test_loader, vocab
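# Hypothetical usage of the couplet inference loader above; convert_example
# and prepare_input are assumed to be defined alongside it in this module.
test_loader, vocab = create_infer_loader(batch_size=64)
pad_id = vocab[vocab.eos_token]  # same padding convention used inside the loader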
def create_train_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id
    trans_func_tuple = IWSLT15.get_default_transform_func()
    train_ds, dev_ds = IWSLT15.get_datasets(
        mode=["train", "dev"],
        transform_func=[trans_func_tuple, trans_func_tuple])

    key = (lambda x, data_source: len(data_source[x][0]))
    cut_fn = lambda data: (data[0][:max_len], data[1][:max_len])

    # Drop empty pairs, then truncate both sides to max_len.
    train_ds = train_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    dev_ds = dev_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)

    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return train_loader, dev_loader, len(src_vocab), len(tgt_vocab), pad_id
def create_infer_loader(batch_size=128):
    test_ds = CoupletDataset.get_datasets(["test"])
    vocab, _ = CoupletDataset.get_vocab()
    pad_id = vocab[CoupletDataset.EOS_TOKEN]
    bos_id = vocab[CoupletDataset.BOS_TOKEN]
    eos_id = vocab[CoupletDataset.EOS_TOKEN]
    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)
    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(prepare_input, pad_id=pad_id))
    return test_loader, len(vocab), pad_id, bos_id, eos_id
def create_train_loader(batch_size=128):
    train_ds = CoupletDataset.get_datasets(["train"])
    vocab, _ = CoupletDataset.get_vocab()
    pad_id = vocab[CoupletDataset.EOS_TOKEN]
    train_batch_sampler = SamplerHelper(train_ds).shuffle().batch(
        batch_size=batch_size)
    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(prepare_input, pad_id=pad_id))
    return train_loader, len(vocab), pad_id
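# Hypothetical skeleton around the legacy CoupletDataset loader above; the
# epoch count is illustrative only.
train_loader, vocab_size, pad_id = create_train_loader(batch_size=64)
for epoch in range(1):
    for batch in train_loader:
        pass  # model forward/backward would go here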
def create_infer_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    trans_func_tuple = IWSLT15.get_default_transform_func()
    test_ds = IWSLT15.get_datasets(
        mode=["test"], transform_func=[trans_func_tuple])
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id
    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)
    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(
            prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return test_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
def test_iter2(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    for i, sample in enumerate(train_batch_sampler):
        self.check_output_equal(i, sample)
def test_iter1(self):
    train_ds_len = len(self.train_ds)
    # A custom iterable overrides the default 0..N-1 index order.
    ds_iter = iter(range(train_ds_len - 1, -1, -1))
    train_batch_sampler = SamplerHelper(self.train_ds, ds_iter)
    for i, sample in enumerate(train_batch_sampler):
        self.check_output_equal(i, train_ds_len - 1 - sample)
def create_data_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    if args.dataset == 'yahoo':
        train_ds, dev_ds, test_ds = load_dataset(
            'yahoo_answer_100k', splits=('train', 'valid', 'test'))
        vocab = Vocab.load_vocabulary(**train_ds.vocab_info)
    else:
        train_ds, dev_ds, test_ds = load_dataset(
            'ptb', splits=('train', 'valid', 'test'))
        examples = [
            train_ds[i]['sentence'].split() for i in range(len(train_ds))
        ]
        vocab = Vocab.build_vocab(examples)

    vocab_size = len(vocab)
    # bos/eos/pad ids are appended past the original vocabulary.
    bos_id = vocab_size
    eos_id = vocab_size + 1
    pad_id = vocab_size + 1

    def convert_example(example):
        features = vocab.to_indices(example['sentence'].split()[:max_len])
        return features

    key = (lambda x, data_source: len(data_source[x]))

    # Truncate and convert examples to ids.
    train_ds = train_ds.map(convert_example, lazy=False)
    dev_ds = dev_ds.map(convert_example, lazy=False)
    test_ds = test_ds.map(convert_example, lazy=False)

    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
    test_batch_sampler = SamplerHelper(test_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return (train_loader, dev_loader, test_loader, vocab, bos_id, pad_id,
            len(train_ds))
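# Illustrative call for the VAE data pipeline above: with 'ptb' the vocab is
# built from the training sentences, so the bos/eos/pad ids land just past
# the original vocabulary as computed inside create_data_loader. Field
# values are assumptions for the sketch.
from argparse import Namespace

args = Namespace(batch_size=32, max_len=100, dataset='ptb')
(train_loader, dev_loader, test_loader,
 vocab, bos_id, pad_id, num_train_examples) = create_data_loader(args)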
def test_shard_default(self):
    train_batch_sampler = SamplerHelper(self.train_ds)
    # With no arguments, shard() falls back to the distributed world size
    # and rank; in a single-process test it leaves the order unchanged.
    shard_sampler1 = train_batch_sampler.shard()
    for i, sample in enumerate(shard_sampler1):
        self.check_output_equal(i, sample)