def create_infer_loader(args):
    """Build the WMT14 en-de test-set DataLoader for inference.

    Also pads args.src_vocab_size / args.trg_vocab_size up to a multiple of
    args.pad_factor. Returns (data_loader, trg_vocab.to_tokens) so callers
    can map predicted ids back to tokens.
    """
    # NOTE: ('test') is just the string 'test', not a one-tuple.
    dataset = load_dataset('wmt14ende', splits=('test'))

    # Source and target share a single joint BPE vocabulary.
    src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
    trg_vocab = src_vocab

    def _round_up(size):
        # Round size up to the next multiple of args.pad_factor.
        return (size + args.pad_factor - 1) // args.pad_factor * args.pad_factor

    args.src_vocab_size = _round_up(len(src_vocab))
    args.trg_vocab_size = _round_up(len(trg_vocab))

    def _to_ids(sample):
        # Whitespace-tokenize both sides and convert tokens to vocab ids.
        src_ids = src_vocab.to_indices(sample[args.src_lang].split())
        trg_ids = trg_vocab.to_indices(sample[args.trg_lang].split())
        return src_ids, trg_ids

    dataset = dataset.map(_to_ids, lazy=False)

    sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    # Padding reuses the bos id, matching the sibling loaders in this file.
    collate = partial(prepare_infer_input,
                      bos_idx=args.bos_idx,
                      eos_idx=args.eos_idx,
                      pad_idx=args.bos_idx)
    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=sampler,
                             collate_fn=collate,
                             num_workers=0,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
def create_infer_loader(args):
    """Build the WMT14 en-de test DataLoader via the WMT14ende dataset class.

    Pads args.src_vocab_size / args.trg_vocab_size to a multiple of
    args.pad_factor and returns (data_loader, trg_vocab.to_tokens).
    """
    # args.root arrives as the literal string "None" when unset.
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    # Round a vocab size up to the next multiple of args.pad_factor.
    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
    )
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    transform_func = WMT14ende.get_default_transform_func(root=root)
    # Drop samples longer than args.max_length.
    # ("min_max_filer" is the helper's actual (misspelled) project name.)
    dataset = WMT14ende.get_datasets(
        mode="test", root=root, transform_func=transform_func).filter(
            partial(
                min_max_filer, max_len=args.max_length))
    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=partial(
            prepare_infer_input,
            bos_idx=args.bos_idx,
            eos_idx=args.eos_idx,
            # Padding reuses the bos id, as in the sibling loaders.
            pad_idx=args.bos_idx),
        num_workers=0,
        return_list=True)
    return data_loader, trg_vocab.to_tokens
def create_infer_loader(args):
    """Build the inference DataLoader, honoring optional overrides.

    Dataset: args.test_file if given, else the built-in wmt14ende test split.
    Vocab: args.vocab_file if given; otherwise the dataset's "bpe" vocab, or
    its "benchmark" vocab when args.benchmark is set.
    Returns (data_loader, trg_vocab.to_tokens).
    """
    if args.test_file is not None:
        # Custom test file supplied by the user.
        dataset = load_dataset('wmt14ende', data_files=[args.test_file], splits=['test'])
    else:
        # NOTE: ('test') is just the string 'test', not a one-tuple.
        dataset = load_dataset('wmt14ende', splits=('test'))
    if args.vocab_file is not None:
        src_vocab = Vocab.load_vocabulary(filepath=args.vocab_file,
                                          unk_token=args.unk_token,
                                          bos_token=args.bos_token,
                                          eos_token=args.eos_token)
    elif not args.benchmark:
        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
    else:
        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
    # Source and target share one vocabulary.
    trg_vocab = src_vocab
    # Round a vocab size up to the next multiple of args.pad_factor.
    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))

    def convert_samples(sample):
        # Whitespace-tokenize and map tokens to vocab ids.
        source = sample[args.src_lang].split()
        target = sample[args.trg_lang].split()
        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)
        return source, target

    dataset = dataset.map(convert_samples, lazy=False)
    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=partial(prepare_infer_input,
                                                bos_idx=args.bos_idx,
                                                eos_idx=args.eos_idx,
                                                # Padding reuses the bos id.
                                                pad_idx=args.bos_idx,
                                                pad_seq=args.pad_seq,
                                                dtype=args.input_dtype),
                             num_workers=args.num_workers,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
def create_infer_loader(args, use_all_vocab=False):
    """Build the inference DataLoader from local files under args.root
    when present, else from the downloaded wmt14ende test split.

    use_all_vocab selects the full "bpe" vocab over the "benchmark" one.
    Returns (data_loader, trg_vocab.to_tokens).
    """
    data_files = None
    if args.root != "None" and os.path.exists(args.root):
        # Local copies of the tokenized/BPE'd newstest2014 files.
        data_files = {
            'test': (os.path.join(args.root, "newstest2014.tok.bpe.33708.en"),
                     os.path.join(args.root, "newstest2014.tok.bpe.33708.de"))
        }
    # NOTE: ('test') is just the string 'test', not a one-tuple.
    dataset = load_dataset('wmt14ende', data_files=data_files, splits=('test'))
    if use_all_vocab:
        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
    else:
        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
    # Source and target share one vocabulary.
    trg_vocab = src_vocab
    # Round a vocab size up to the next multiple of args.pad_factor.
    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
    )
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))

    def convert_samples(sample):
        # Whitespace-tokenize and map tokens to vocab ids.
        source = sample[args.src_lang].split()
        target = sample[args.trg_lang].split()
        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)
        return source, target

    dataset = dataset.map(convert_samples, lazy=False)
    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=partial(
            prepare_infer_input,
            bos_idx=args.bos_idx,
            eos_idx=args.eos_idx,
            # Padding reuses the bos id, as in the sibling loaders.
            pad_idx=args.bos_idx,
            pad_seq=args.pad_seq),
        num_workers=0,
        return_list=True)
    return data_loader, trg_vocab.to_tokens
def create_infer_loader(args):
    """Build a DataLoader over args.predict_file for inference.

    Loads separate source/target vocabularies, pads their sizes up to a
    multiple of args.pad_factor, and returns
    (data_loader, trg_vocab.to_tokens).
    """
    # Prediction mode: only a source file; no target path.
    dataset = load_dataset(read,
                           src_path=args.predict_file,
                           tgt_path=None,
                           is_predict=True,
                           lazy=False)

    def _load_vocab(path):
        # Both vocabs use the shared (bos, eos, unk) special tokens.
        return Vocab.load_vocabulary(path,
                                     bos_token=args.special_token[0],
                                     eos_token=args.special_token[1],
                                     unk_token=args.special_token[2])

    src_vocab = _load_vocab(args.src_vocab_fpath)
    trg_vocab = _load_vocab(args.trg_vocab_fpath)

    def _round_up(size):
        # Round size up to the next multiple of args.pad_factor.
        return (size + args.pad_factor - 1) // args.pad_factor * args.pad_factor

    args.src_vocab_size = _round_up(len(src_vocab))
    args.trg_vocab_size = _round_up(len(trg_vocab))

    def _to_ids(sample):
        # Whitespace-tokenize both sides and convert tokens to vocab ids.
        return (src_vocab.to_indices(sample['src'].split()),
                trg_vocab.to_indices(sample['tgt'].split()))

    dataset = dataset.map(_to_ids, lazy=False)
    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=partial(prepare_infer_input,
                                                bos_idx=args.bos_idx,
                                                eos_idx=args.eos_idx,
                                                # Padding reuses the bos id.
                                                pad_idx=args.bos_idx),
                             num_workers=2,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
def create_infer_loader(args, places=None):
    """Build a source-only inference DataLoader over args.predict_file.

    Targets are seeded with a single bos id per sample. Vocab sizes are NOT
    padded here (unlike the other loaders in this file).
    Returns (data_loader, trg_vocab.to_tokens).
    """
    data_files = {
        'test': args.predict_file,
    }
    # only_src=True: the file is read without a target side.
    dataset = load_dataset(read, src_tgt_file=data_files['test'], only_src=True, lazy=False)
    src_vocab = Vocab.load_vocabulary(args.src_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    trg_vocab = Vocab.load_vocabulary(args.trg_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    args.src_vocab_size = len(src_vocab)
    args.trg_vocab_size = len(trg_vocab)

    def convert_samples(sample):
        # Tokenize the source, append eos; target starts as just [bos].
        source = [item.strip() for item in sample['src'].split()]
        source = src_vocab.to_indices(source) + [args.eos_idx]
        target = [args.bos_idx]
        return source, target

    dataset = dataset.map(convert_samples, lazy=False)
    batch_sampler = SamplerHelper(dataset).batch(batch_size=args.batch_size,
                                                 drop_last=False)
    data_loader = DataLoader(dataset=dataset,
                             places=places,
                             batch_sampler=batch_sampler,
                             # Padding reuses the bos id.
                             collate_fn=partial(prepare_infer_input,
                                                pad_idx=args.bos_idx),
                             num_workers=0,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
def create_infer_loader(args, use_all_vocab=False):
    """Build the WMT14 en-de test DataLoader via the WMT14ende class.

    Returns (data_loader, trg_vocab.to_tokens).
    """
    # args.root arrives as the literal string "None" when unset.
    root = None if args.root == "None" else args.root
    # NOTE(review): this mutates the WMT14ende CLASS attribute, affecting
    # every later use of WMT14ende in the process. Also, overriding to
    # "vocab_all.bpe.33712" when use_all_vocab is FALSE looks inverted
    # (33712 appears to be the full vocab) — confirm against the class's
    # default VOCAB_INFO before changing.
    if not use_all_vocab:
        WMT14ende.VOCAB_INFO = (os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                                             "vocab_all.bpe.33712"),
                                os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                                             "vocab_all.bpe.33712"),
                                "de485e3c2e17e23acf4b4b70b54682dd",
                                "de485e3c2e17e23acf4b4b70b54682dd")
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    # Round a vocab size up to the next multiple of args.pad_factor.
    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    transform_func = WMT14ende.get_default_transform_func(root=root)
    # Drop samples longer than args.max_length
    # ("min_max_filer" is the helper's actual (misspelled) project name).
    dataset = WMT14ende.get_datasets(mode="test",
                                     root=root,
                                     transform_func=transform_func).filter(
                                         partial(min_max_filer,
                                                 max_len=args.max_length))
    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=partial(prepare_infer_input,
                                                bos_idx=args.bos_idx,
                                                eos_idx=args.eos_idx,
                                                # Padding reuses the bos id.
                                                pad_idx=args.bos_idx,
                                                pad_seq=args.pad_seq),
                             num_workers=0,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
def create_data_loader(args, places=None, use_all_vocab=False):
    """Build [train_loader, dev_loader] for WMT14 en-de.

    Local files under args.root are used when present; otherwise the
    built-in wmt14ende splits are downloaded. Batching is token-count
    based; the train batch sampler is sharded for multi-card training.
    Also pads args.src_vocab_size / args.trg_vocab_size up to a multiple
    of args.pad_factor.
    """
    data_files = None
    if args.root != "None" and os.path.exists(args.root):
        # Local copies of the tokenized/BPE'd corpus files.
        data_files = {
            'train': (os.path.join(args.root, "train.tok.clean.bpe.33708.en"),
                      os.path.join(args.root, "train.tok.clean.bpe.33708.de")),
            'dev': (os.path.join(args.root, "newstest2013.tok.bpe.33708.en"),
                    os.path.join(args.root, "newstest2013.tok.bpe.33708.de"))
        }
    datasets = load_dataset('wmt14ende', data_files=data_files,
                            splits=('train', 'dev'))
    if use_all_vocab:
        src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["bpe"])
    else:
        src_vocab = Vocab.load_vocabulary(
            **datasets[0].vocab_info["benchmark"])
    # Source and target share one vocabulary.
    trg_vocab = src_vocab
    # Round a vocab size up to the next multiple of args.pad_factor.
    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))

    def convert_samples(sample):
        # Whitespace-tokenize and map tokens to vocab ids.
        source = sample[args.src_lang].split()
        target = sample[args.trg_lang].split()
        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)
        return source, target

    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                      data_source):
        # Running max sequence length (+1 for the added bos/eos token)
        # across the minibatch being assembled.
        return max(tokens_sofar,
                   len(data_source[current_idx][0]) + 1,
                   len(data_source[current_idx][1]) + 1)

    def _key(size_so_far, minibatch_len):
        # Effective batch cost = longest sequence * number of sequences.
        return size_so_far * minibatch_len

    data_loaders = [(None)] * 2
    for i, dataset in enumerate(datasets):
        # Index tokens, then drop samples longer than args.max_length
        # ("min_max_filer" is the helper's actual (misspelled) name).
        dataset = dataset.map(convert_samples, lazy=False).filter(
            partial(min_max_filer, max_len=args.max_length))
        sampler = SamplerHelper(dataset)
        if args.sort_type == SortType.GLOBAL:
            src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(key=trg_key).sort(key=src_key)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=args.shuffle_seed)
            max_key = (lambda x, data_source: max(len(data_source[x][0]),
                                                  len(data_source[x][1])) + 1)
            if args.sort_type == SortType.POOL:
                # Sort within a bounded pool to keep lengths similar.
                sampler = sampler.sort(key=max_key,
                                       buffer_size=args.pool_size)
        batch_sampler = sampler.batch(batch_size=args.batch_size,
                                      drop_last=False,
                                      batch_size_fn=_max_token_fn,
                                      key=_key)
        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=args.shuffle_seed)
        if i == 0:
            # Only the train split (index 0) is sharded across cards.
            batch_sampler = batch_sampler.shard()
        data_loader = DataLoader(dataset=dataset,
                                 places=places,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    # Padding reuses bos id.
                                                    pad_idx=args.bos_idx,
                                                    pad_seq=args.pad_seq),
                                 num_workers=0)
        data_loaders[i] = (data_loader)
    return data_loaders
def create_data_loader(args):
    """Build [train_loader, dev_loader] for WMT14 en-de via the WMT14ende
    dataset class.

    Sets args.src_vocab_size / args.trg_vocab_size from the loaded vocabs
    (no pad_factor rounding in this variant). Batching is token-count
    based; the train batch sampler is sharded for multi-card training.
    """
    # args.root arrives as the literal string "None" when unset.
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
    transform_func = WMT14ende.get_default_transform_func(root=root)
    # NOTE(review): root is not forwarded to get_datasets here — confirm
    # whether the default download location is intended.
    datasets = [
        WMT14ende.get_datasets(mode=m, transform_func=transform_func)
        for m in ["train", "dev"]
    ]
    # Resolve the shuffle seed once; only defined (and needed) when some
    # form of shuffling is enabled.
    if args.shuffle or args.shuffle_batch:
        if args.shuffle_seed == "None" or args.shuffle_seed is None:
            shuffle_seed = 0
        else:
            shuffle_seed = args.shuffle_seed

    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                      data_source):
        # Running max sequence length (+1 for the added bos/eos token)
        # across the minibatch being assembled.
        return max(tokens_sofar,
                   len(data_source[current_idx][0]) + 1,
                   len(data_source[current_idx][1]) + 1)

    def _key(size_so_far, minibatch_len):
        # Effective batch cost = longest sequence * number of sequences.
        return size_so_far * minibatch_len

    data_loaders = [(None)] * 2
    for i, dataset in enumerate(datasets):
        m = dataset.mode
        # Drop samples longer than args.max_length
        # ("min_max_filer" is the helper's actual (misspelled) name).
        dataset = dataset.filter(
            partial(min_max_filer, max_len=args.max_length))
        sampler = SamplerHelper(dataset)
        src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
        if args.sort_type == SortType.GLOBAL:
            buffer_size = -1
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice: stable sort by target, then by source length.
            sampler = sampler.sort(key=trg_key, buffer_size=buffer_size).sort(
                key=src_key, buffer_size=buffer_size)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=shuffle_seed)
            if args.sort_type == SortType.POOL:
                # Sort within a bounded pool to keep lengths similar.
                buffer_size = args.pool_size
                sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
        batch_sampler = sampler.batch(batch_size=args.batch_size,
                                      drop_last=False,
                                      batch_size_fn=_max_token_fn,
                                      key=_key)
        if m == "train":
            # Only the train split is sharded across cards.
            batch_sampler = batch_sampler.shard()
        if args.shuffle_batch:
            # BUG FIX: SamplerHelper.shuffle() returns a new sampler; the
            # original discarded the result, so --shuffle_batch silently
            # did nothing. Every sibling create_data_loader in this file
            # assigns the result back.
            batch_sampler = batch_sampler.shuffle(seed=shuffle_seed)
        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    # Padding reuses bos id.
                                                    pad_idx=args.bos_idx),
                                 num_workers=0,
                                 return_list=True)
        data_loaders[i] = (data_loader)
    return data_loaders
def create_data_loader(args):
    """Build [train_loader, dev_loader] from files named in
    args.training_file ("src_path,tgt_path").

    Pads args.src_vocab_size / args.trg_vocab_size up to a multiple of
    args.pad_factor. Batching is token-count based; the train batch
    sampler (index 0) is sharded for multi-card training.
    """
    train_dataset = load_dataset(read,
                                 src_path=args.training_file.split(',')[0],
                                 tgt_path=args.training_file.split(',')[1],
                                 lazy=False)
    # NOTE(review): dev also reads args.training_file, so validation runs
    # on the training data. Presumably this should read a separate
    # validation file (compare the variant that uses args.validation_file)
    # — confirm before changing.
    dev_dataset = load_dataset(read,
                               src_path=args.training_file.split(',')[0],
                               tgt_path=args.training_file.split(',')[1],
                               lazy=False)
    print('load src vocab')
    print(args.src_vocab_fpath)
    src_vocab = Vocab.load_vocabulary(args.src_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    print('load trg vocab')
    print(args.trg_vocab_fpath)
    trg_vocab = Vocab.load_vocabulary(args.trg_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    print('padding')
    # Round a vocab size up to the next multiple of args.pad_factor.
    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    print('convert example')

    def convert_samples(sample):
        # Whitespace-tokenize and map tokens to vocab ids.
        source = sample['src'].split()
        target = sample['tgt'].split()
        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)
        return source, target

    data_loaders = [(None)] * 2
    print('dataset loop')
    for i, dataset in enumerate([train_dataset, dev_dataset]):
        # Index tokens, then drop samples longer than args.max_length
        # ("min_max_filer" is the helper's actual (misspelled) name).
        dataset = dataset.map(convert_samples, lazy=False).filter(
            partial(min_max_filer, max_len=args.max_length))
        sampler = SamplerHelper(dataset)
        if args.sort_type == SortType.GLOBAL:
            src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(key=trg_key).sort(key=src_key)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=args.shuffle_seed)
            max_key = (lambda x, data_source: max(len(data_source[x][0]),
                                                  len(data_source[x][1])) + 1)
            if args.sort_type == SortType.POOL:
                # Sort within a bounded pool to keep lengths similar.
                sampler = sampler.sort(key=max_key,
                                       buffer_size=args.pool_size)
        # Running max sequence length (+1 for bos/eos) in the minibatch.
        batch_size_fn = lambda new, count, sofar, data_source: max(
            sofar, len(data_source[new][0]) + 1, len(data_source[new][1]) + 1)
        batch_sampler = sampler.batch(
            batch_size=args.batch_size,
            drop_last=False,
            batch_size_fn=batch_size_fn,
            # Effective batch cost = longest sequence * batch length.
            key=lambda size_so_far, minibatch_len: size_so_far * minibatch_len)
        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=args.shuffle_seed)
        if i == 0:
            # Only the train split (index 0) is sharded across cards.
            batch_sampler = batch_sampler.shard()
        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    # Padding reuses bos id.
                                                    pad_idx=args.bos_idx),
                                 num_workers=2,
                                 return_list=True)
        data_loaders[i] = (data_loader)
    return data_loaders
# Token-budget batching helpers for SamplerHelper.batch:
# batch_size_fn tracks the longest source sequence seen so far in the
# minibatch; batch_key turns that into the batch's effective token cost
# (longest sequence * number of sequences).
batch_size_fn = lambda idx, minibatch_len, size_so_far, data_source: max(
    size_so_far, len(data_source[idx][0]))
batch_key = lambda size_so_far, minibatch_len: size_so_far * minibatch_len

if __name__ == '__main__':
    # Smoke test: iterate IWSLT15 with token-budget batching and print
    # per-batch target stats.
    batch_size = 4096  # token budget per batch (an earlier value was 32)
    pad_id = 2
    transform_func = IWSLT15.get_default_transform_func()
    train_dataset = IWSLT15(transform_func=transform_func)
    # Sort by source length within a bounded buffer so batches have
    # similar lengths.
    key = (lambda x, data_source: len(data_source[x][0]))
    train_batch_sampler = SamplerHelper(train_dataset).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size,
                                                    drop_last=True,
                                                    batch_size_fn=batch_size_fn,
                                                    key=batch_key).shard()
    train_loader = paddle.io.DataLoader(train_dataset,
                                        batch_sampler=train_batch_sampler,
                                        collate_fn=partial(prepare_train_input,
                                                           pad_id=pad_id))
    for i, data in enumerate(train_loader):
        print(data[1])
        # Padded token count of the batch: max length * batch size.
        print(paddle.max(data[1]) * len(data[1]))
        print(len(data[1]))
def create_data_loader(args, places=None):
    """Build [train_loader, dev_loader] from args.training_file and
    args.validation_file.

    Adds eos to each source and wraps each target in bos/eos during
    conversion. Vocab sizes are NOT padded in this variant. Batching is
    token-count based; the train batch sampler (index 0) is sharded.
    """
    data_files = {'train': args.training_file, 'dev': args.validation_file}
    datasets = [
        load_dataset(read, src_tgt_file=filename, lazy=False)
        for split, filename in data_files.items()
    ]
    src_vocab = Vocab.load_vocabulary(args.src_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    trg_vocab = Vocab.load_vocabulary(args.trg_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    args.src_vocab_size = len(src_vocab)
    args.trg_vocab_size = len(trg_vocab)

    def convert_samples(sample):
        # Tokenize, index, and add sentence delimiters:
        # source gets a trailing eos; target gets bos ... eos.
        source = [item.strip() for item in sample['src'].split()]
        target = [item.strip() for item in sample['trg'].split()]
        source = src_vocab.to_indices(source) + [args.eos_idx]
        target = [args.bos_idx] + \
            trg_vocab.to_indices(target) + [args.eos_idx]
        return source, target

    data_loaders = [(None)] * 2
    for i, dataset in enumerate(datasets):
        # Index tokens, then drop samples longer than args.max_length
        # ("min_max_filer" is the helper's actual (misspelled) name).
        dataset = dataset.map(convert_samples, lazy=False).filter(
            partial(min_max_filer, max_len=args.max_length))
        sampler = SamplerHelper(dataset)
        if args.sort_type == SortType.GLOBAL:
            src_key = (lambda x, data_source: len(data_source[x][0]))
            trg_key = (lambda x, data_source: len(data_source[x][1]))
            # Sort twice
            sampler = sampler.sort(key=trg_key).sort(key=src_key)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=args.random_seed)
            max_key = (lambda x, data_source: max(len(data_source[x][0]),
                                                  len(data_source[x][1])))
            if args.sort_type == SortType.POOL:
                # Sort within a bounded pool to keep lengths similar.
                sampler = sampler.sort(key=max_key,
                                       buffer_size=args.pool_size)
        # Running max sequence length in the minibatch being assembled.
        batch_size_fn = lambda new, count, sofar, data_source: max(
            sofar, len(data_source[new][0]), len(data_source[new][1]))
        batch_sampler = sampler.batch(
            batch_size=args.batch_size,
            drop_last=False,
            batch_size_fn=batch_size_fn,
            # Effective batch cost = longest sequence * batch length.
            key=lambda size_so_far, minibatch_len: size_so_far * minibatch_len)
        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=args.random_seed)
        if i == 0:
            # Only the train split (index 0) is sharded across cards.
            batch_sampler = batch_sampler.shard()
        data_loader = DataLoader(dataset=dataset,
                                 places=places,
                                 batch_sampler=batch_sampler,
                                 # Padding reuses the bos id.
                                 collate_fn=partial(prepare_train_input,
                                                    pad_idx=args.bos_idx),
                                 num_workers=0)
        data_loaders[i] = (data_loader)
    return data_loaders