Example #1
def create_infer_loader(args):
    dataset = load_dataset('wmt14ende', splits=('test'))
    src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
    trg_vocab = src_vocab

    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))

    def convert_samples(sample):
        source = sample[args.src_lang].split()
        target = sample[args.trg_lang].split()

        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)

        return source, target

    dataset = dataset.map(convert_samples, lazy=False)

    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=partial(prepare_infer_input,
                                                bos_idx=args.bos_idx,
                                                eos_idx=args.eos_idx,
                                                pad_idx=args.bos_idx),
                             num_workers=0,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
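The snippets on this page omit their import sections. A minimal sketch of the imports they appear to assume follows; the project-local helpers named in the comment are not shown here, so treat this as an assumption rather than the exact source header.

import os
from functools import partial

import paddle
from paddle.io import DataLoader
from paddlenlp.data import SamplerHelper, Vocab
from paddlenlp.datasets import load_dataset

# Examples #2, #7, #9 and #11 additionally rely on the older dataset classes
# WMT14ende / IWSLT15 from paddlenlp.datasets, and every example uses
# project-local helpers (prepare_train_input, prepare_infer_input,
# min_max_filer, SortType, read) defined elsewhere in its reader.py.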
Example #2
File: reader.py  Project: wbj0110/models
def create_infer_loader(args):
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
    )
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    transform_func = WMT14ende.get_default_transform_func(root=root)
    dataset = WMT14ende.get_datasets(
        mode="test", root=root, transform_func=transform_func).filter(
            partial(
                min_max_filer, max_len=args.max_length))

    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=partial(
            prepare_infer_input,
            bos_idx=args.bos_idx,
            eos_idx=args.eos_idx,
            pad_idx=args.bos_idx),
        num_workers=0,
        return_list=True)
    return data_loader, trg_vocab.to_tokens
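The padding_vocab lambda that recurs in these examples rounds the vocabulary size up to the next multiple of args.pad_factor, commonly done so the embedding and softmax weights get hardware-friendly shapes. A small worked example, with pad_factor=8 chosen only for illustration:

pad_factor = 8  # assumed value, for illustration only
padding_vocab = lambda x: (x + pad_factor - 1) // pad_factor * pad_factor

print(padding_vocab(33708))  # 33712: rounded up to the next multiple of 8
print(padding_vocab(33712))  # 33712: already a multiple, left unchanged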
Example #3
def create_infer_loader(args):
    if args.test_file is not None:
        dataset = load_dataset('wmt14ende',
                               data_files=[args.test_file],
                               splits=['test'])
    else:
        dataset = load_dataset('wmt14ende', splits=('test'))

    if args.vocab_file is not None:
        src_vocab = Vocab.load_vocabulary(filepath=args.vocab_file,
                                          unk_token=args.unk_token,
                                          bos_token=args.bos_token,
                                          eos_token=args.eos_token)
    elif not args.benchmark:
        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
    else:
        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
    trg_vocab = src_vocab

    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))

    def convert_samples(sample):
        source = sample[args.src_lang].split()
        target = sample[args.trg_lang].split()

        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)

        return source, target

    dataset = dataset.map(convert_samples, lazy=False)

    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=partial(prepare_infer_input,
                                                bos_idx=args.bos_idx,
                                                eos_idx=args.eos_idx,
                                                pad_idx=args.bos_idx,
                                                pad_seq=args.pad_seq,
                                                dtype=args.input_dtype),
                             num_workers=args.num_workers,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
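prepare_infer_input, passed as collate_fn above, is project-local and not shown on this page. A plausible sketch, assuming it appends eos_idx to every source, pads with pad_idx (the examples reuse bos_idx for padding) up to a pad_seq-aligned batch maximum, and returns the single source tensor the decoder needs:

import numpy as np

def prepare_infer_input_sketch(insts, bos_idx, eos_idx, pad_idx, pad_seq=1,
                               dtype="int64"):
    # Hypothetical stand-in for the project's prepare_infer_input.
    lengths = [len(inst[0]) + 1 for inst in insts]  # +1 for the appended eos
    max_len = (max(lengths) + pad_seq - 1) // pad_seq * pad_seq
    src_word = np.full((len(insts), max_len), pad_idx, dtype=dtype)
    for i, inst in enumerate(insts):
        tokens = inst[0] + [eos_idx]
        src_word[i, :len(tokens)] = tokens
    return [src_word]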
Example #4
def create_infer_loader(args, use_all_vocab=False):
    data_files = None
    if args.root != "None" and os.path.exists(args.root):
        data_files = {
            'test': (os.path.join(args.root, "newstest2014.tok.bpe.33708.en"),
                     os.path.join(args.root, "newstest2014.tok.bpe.33708.de"))
        }

    dataset = load_dataset('wmt14ende', data_files=data_files, splits=('test'))
    if use_all_vocab:
        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
    else:
        src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
    trg_vocab = src_vocab

    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
    )
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))

    def convert_samples(sample):
        source = sample[args.src_lang].split()
        target = sample[args.trg_lang].split()

        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)

        return source, target

    dataset = dataset.map(convert_samples, lazy=False)

    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=partial(
            prepare_infer_input,
            bos_idx=args.bos_idx,
            eos_idx=args.eos_idx,
            pad_idx=args.bos_idx,
            pad_seq=args.pad_seq),
        num_workers=0,
        return_list=True)
    return data_loader, trg_vocab.to_tokens
Example #5
def create_infer_loader(args):
    dataset = load_dataset(read,
                           src_path=args.predict_file,
                           tgt_path=None,
                           is_predict=True,
                           lazy=False)

    src_vocab = Vocab.load_vocabulary(args.src_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    trg_vocab = Vocab.load_vocabulary(args.trg_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])

    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))

    def convert_samples(sample):
        source = sample['src'].split()
        target = sample['tgt'].split()

        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)

        return source, target

    dataset = dataset.map(convert_samples, lazy=False)

    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=partial(prepare_infer_input,
                                                bos_idx=args.bos_idx,
                                                eos_idx=args.eos_idx,
                                                pad_idx=args.bos_idx),
                             num_workers=2,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
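Examples #5 and #10 pass a custom read function to load_dataset instead of a registered dataset name (Examples #6 and #12 call a variant that takes a single src_tgt_file argument). Its body is not shown; a minimal sketch under the assumption that it yields one {'src': ..., 'tgt': ...} dict per line pair:

def read(src_path, tgt_path=None, is_predict=False):
    # Hypothetical reader matching the call sites above; at predict time only
    # the source file is read and the target is left empty.
    if is_predict or tgt_path is None:
        with open(src_path, encoding="utf-8") as src_f:
            for src_line in src_f:
                yield {"src": src_line.strip(), "tgt": ""}
    else:
        with open(src_path, encoding="utf-8") as src_f, \
                open(tgt_path, encoding="utf-8") as tgt_f:
            for src_line, tgt_line in zip(src_f, tgt_f):
                yield {"src": src_line.strip(), "tgt": tgt_line.strip()}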
Example #6
def create_infer_loader(args, places=None):
    data_files = {
        'test': args.predict_file,
    }
    dataset = load_dataset(read,
                           src_tgt_file=data_files['test'],
                           only_src=True,
                           lazy=False)

    src_vocab = Vocab.load_vocabulary(args.src_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])

    trg_vocab = Vocab.load_vocabulary(args.trg_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])

    args.src_vocab_size = len(src_vocab)
    args.trg_vocab_size = len(trg_vocab)

    def convert_samples(sample):
        source = [item.strip() for item in sample['src'].split()]
        source = src_vocab.to_indices(source) + [args.eos_idx]
        target = [args.bos_idx]
        return source, target

    dataset = dataset.map(convert_samples, lazy=False)

    batch_sampler = SamplerHelper(dataset).batch(batch_size=args.batch_size,
                                                 drop_last=False)

    data_loader = DataLoader(dataset=dataset,
                             places=places,
                             batch_sampler=batch_sampler,
                             collate_fn=partial(prepare_infer_input,
                                                pad_idx=args.bos_idx),
                             num_workers=0,
                             return_list=True)

    return data_loader, trg_vocab.to_tokens
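All of the create_infer_loader variants return the loader together with trg_vocab.to_tokens, so the caller can map predicted ids straight back to (sub)word strings. A hedged usage sketch; args, model and the batch layout are assumptions, not part of the examples:

test_loader, to_tokens = create_infer_loader(args)  # args prepared elsewhere

for (src_word, ) in test_loader:      # assumes one padded source tensor per batch
    finished_ids = model(src_word)    # placeholder for the (not shown) beam search
    for ids in finished_ids.numpy().tolist():
        print(" ".join(to_tokens(ids)))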
Example #7
File: reader.py  Project: wbj0110/models
def create_infer_loader(args, use_all_vocab=False):
    root = None if args.root == "None" else args.root
    if not use_all_vocab:
        WMT14ende.VOCAB_INFO = (os.path.join("WMT14.en-de",
                                             "wmt14_ende_data_bpe",
                                             "vocab_all.bpe.33712"),
                                os.path.join("WMT14.en-de",
                                             "wmt14_ende_data_bpe",
                                             "vocab_all.bpe.33712"),
                                "de485e3c2e17e23acf4b4b70b54682dd",
                                "de485e3c2e17e23acf4b4b70b54682dd")
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    transform_func = WMT14ende.get_default_transform_func(root=root)
    dataset = WMT14ende.get_datasets(mode="test",
                                     root=root,
                                     transform_func=transform_func).filter(
                                         partial(min_max_filer,
                                                 max_len=args.max_length))

    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=partial(prepare_infer_input,
                                                bos_idx=args.bos_idx,
                                                eos_idx=args.eos_idx,
                                                pad_idx=args.bos_idx,
                                                pad_seq=args.pad_seq),
                             num_workers=0,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
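min_max_filer, used as a dataset.filter predicate in several examples on this page, is another project-local helper. A sketch consistent with how it is called, keeping samples whose lengths (plus one special token) fall within [min_len, max_len]; the real body may differ:

def min_max_filer(data, max_len, min_len=0):
    # Assumed predicate: data is a (source_ids, target_ids) pair; the +1
    # leaves room for the bos/eos token added by the collate functions.
    data_min_len = min(len(data[0]), len(data[1])) + 1
    data_max_len = max(len(data[0]), len(data[1])) + 1
    return (data_min_len >= min_len) and (data_max_len <= max_len)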
Example #8
def create_data_loader(args, places=None, use_all_vocab=False):
    data_files = None
    if args.root != "None" and os.path.exists(args.root):
        data_files = {
            'train': (os.path.join(args.root, "train.tok.clean.bpe.33708.en"),
                      os.path.join(args.root, "train.tok.clean.bpe.33708.de")),
            'dev': (os.path.join(args.root, "newstest2013.tok.bpe.33708.en"),
                    os.path.join(args.root, "newstest2013.tok.bpe.33708.de"))
        }

    datasets = load_dataset('wmt14ende',
                            data_files=data_files,
                            splits=('train', 'dev'))
    if use_all_vocab:
        src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["bpe"])
    else:
        src_vocab = Vocab.load_vocabulary(
            **datasets[0].vocab_info["benchmark"])
    trg_vocab = src_vocab

    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))

    def convert_samples(sample):
        source = sample[args.src_lang].split()
        target = sample[args.trg_lang].split()

        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)

        return source, target

    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                      data_source):
        return max(tokens_sofar,
                   len(data_source[current_idx][0]) + 1,
                   len(data_source[current_idx][1]) + 1)

    def _key(size_so_far, minibatch_len):
        return size_so_far * minibatch_len

    data_loaders = [(None)] * 2
    for i, dataset in enumerate(datasets):
        dataset = dataset.map(convert_samples, lazy=False).filter(
            partial(min_max_filer, max_len=args.max_length))

        sampler = SamplerHelper(dataset)

        if args.sort_type == SortType.GLOBAL:
            src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(key=trg_key).sort(key=src_key)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=args.shuffle_seed)
            max_key = (lambda x, data_source: max(len(data_source[x][0]),
                                                  len(data_source[x][1])) + 1)
            if args.sort_type == SortType.POOL:
                sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)

        batch_sampler = sampler.batch(batch_size=args.batch_size,
                                      drop_last=False,
                                      batch_size_fn=_max_token_fn,
                                      key=_key)

        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=args.shuffle_seed)

        if i == 0:
            batch_sampler = batch_sampler.shard()

        data_loader = DataLoader(dataset=dataset,
                                 places=places,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    pad_idx=args.bos_idx,
                                                    pad_seq=args.pad_seq),
                                 num_workers=0)
        data_loaders[i] = (data_loader)
    return data_loaders
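In the training loaders, batch_size acts as a token budget rather than a sentence count: _max_token_fn tracks the longest sequence seen so far in the minibatch and _key multiplies that by the number of sentences, i.e. the size of the padded batch in tokens. A small worked illustration with made-up lengths:

lengths = [12, 18, 25]  # longest side of each candidate sample, +1 already applied
max_len_so_far = 0
for n_samples, length in enumerate(lengths, start=1):
    max_len_so_far = max(max_len_so_far, length)     # what _max_token_fn returns
    padded_tokens = max_len_so_far * n_samples       # what _key returns
    print(n_samples, max_len_so_far, padded_tokens)  # 1 12 12 / 2 18 36 / 3 25 75
# The sampler closes the batch once this padded-token count would exceed
# args.batch_size.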
Example #9
def create_data_loader(args):
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
    transform_func = WMT14ende.get_default_transform_func(root=root)
    datasets = [
        WMT14ende.get_datasets(mode=m, transform_func=transform_func)
        for m in ["train", "dev"]
    ]

    if args.shuffle or args.shuffle_batch:
        if args.shuffle_seed == "None" or args.shuffle_seed is None:
            shuffle_seed = 0
        else:
            shuffle_seed = args.shuffle_seed

    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                      data_source):
        return max(tokens_sofar,
                   len(data_source[current_idx][0]) + 1,
                   len(data_source[current_idx][1]) + 1)

    def _key(size_so_far, minibatch_len):
        return size_so_far * minibatch_len

    data_loaders = [(None)] * 2
    for i, dataset in enumerate(datasets):
        m = dataset.mode
        dataset = dataset.filter(
            partial(min_max_filer, max_len=args.max_length))
        sampler = SamplerHelper(dataset)

        src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
        if args.sort_type == SortType.GLOBAL:
            buffer_size = -1
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(key=trg_key, buffer_size=buffer_size).sort(
                key=src_key, buffer_size=buffer_size)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=shuffle_seed)
            if args.sort_type == SortType.POOL:
                buffer_size = args.pool_size
                sampler = sampler.sort(key=src_key, buffer_size=buffer_size)

        batch_sampler = sampler.batch(batch_size=args.batch_size,
                                      drop_last=False,
                                      batch_size_fn=_max_token_fn,
                                      key=_key)

        if m == "train":
            batch_sampler = batch_sampler.shard()

        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=shuffle_seed)

        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    pad_idx=args.bos_idx),
                                 num_workers=0,
                                 return_list=True)
        data_loaders[i] = (data_loader)
    return data_loaders
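SortType, referenced in every training loader, is a small constants holder from the project and is not shown here. A sketch of what the comparisons above assume it provides:

class SortType(object):
    # Assumed string constants matching the comparisons in the examples above.
    GLOBAL = "global"  # sort the whole dataset, target length then source length
    POOL = "pool"      # sort only within a buffer of args.pool_size samples
    NONE = "none"      # keep the (optionally shuffled) original order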
Example #10
def create_data_loader(args):
    train_dataset = load_dataset(read,
                                 src_path=args.training_file.split(',')[0],
                                 tgt_path=args.training_file.split(',')[1],
                                 lazy=False)
    dev_dataset = load_dataset(read,
                               src_path=args.validation_file.split(',')[0],
                               tgt_path=args.validation_file.split(',')[1],
                               lazy=False)
    print('load src vocab')
    print(args.src_vocab_fpath)
    src_vocab = Vocab.load_vocabulary(args.src_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    print('load trg vocab')
    print(args.trg_vocab_fpath)
    trg_vocab = Vocab.load_vocabulary(args.trg_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    print('padding')
    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    print('convert example')

    def convert_samples(sample):
        source = sample['src'].split()
        target = sample['tgt'].split()

        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)

        return source, target

    data_loaders = [(None)] * 2
    print('dataset loop')
    for i, dataset in enumerate([train_dataset, dev_dataset]):
        dataset = dataset.map(convert_samples, lazy=False).filter(
            partial(min_max_filer, max_len=args.max_length))

        sampler = SamplerHelper(dataset)

        if args.sort_type == SortType.GLOBAL:
            src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(key=trg_key).sort(key=src_key)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=args.shuffle_seed)
            max_key = (lambda x, data_source: max(len(data_source[x][0]),
                                                  len(data_source[x][1])) + 1)
            if args.sort_type == SortType.POOL:
                sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)

        batch_size_fn = lambda new, count, sofar, data_source: max(
            sofar,
            len(data_source[new][0]) + 1,
            len(data_source[new][1]) + 1)
        batch_sampler = sampler.batch(
            batch_size=args.batch_size,
            drop_last=False,
            batch_size_fn=batch_size_fn,
            key=lambda size_so_far, minibatch_len: size_so_far * minibatch_len)

        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=args.shuffle_seed)

        if i == 0:
            batch_sampler = batch_sampler.shard()

        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    pad_idx=args.bos_idx),
                                 num_workers=2,
                                 return_list=True)
        data_loaders[i] = (data_loader)
    return data_loaders
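prepare_train_input is the training-side counterpart of prepare_infer_input and is likewise project-local. A hedged sketch, assuming it builds the padded source, the bos-shifted decoder input, and the eos-terminated labels used for teacher forcing; the real version may differ in shape details:

import numpy as np

def prepare_train_input_sketch(insts, bos_idx, eos_idx, pad_idx, pad_seq=1):
    # Hypothetical stand-in for the project's prepare_train_input.
    def pad_batch(seqs):
        max_len = (max(len(s) for s in seqs) + pad_seq - 1) // pad_seq * pad_seq
        out = np.full((len(seqs), max_len), pad_idx, dtype="int64")
        for i, s in enumerate(seqs):
            out[i, :len(s)] = s
        return out

    src_word = pad_batch([inst[0] + [eos_idx] for inst in insts])  # encoder input
    trg_word = pad_batch([[bos_idx] + inst[1] for inst in insts])  # decoder input
    lbl_word = pad_batch([inst[1] + [eos_idx] for inst in insts])  # labels
    return [src_word, trg_word, lbl_word]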
Example #11
batch_size_fn = lambda idx, minibatch_len, size_so_far, data_source: max(
    size_so_far, len(data_source[idx][0]))

batch_key = lambda size_so_far, minibatch_len: size_so_far * minibatch_len

if __name__ == '__main__':
    batch_size = 4096  #32
    pad_id = 2

    transform_func = IWSLT15.get_default_transform_func()
    train_dataset = IWSLT15(transform_func=transform_func)

    key = (lambda x, data_source: len(data_source[x][0]))

    train_batch_sampler = SamplerHelper(train_dataset).shuffle().sort(
        key=key,
        buffer_size=batch_size * 20).batch(batch_size=batch_size,
                                           drop_last=True,
                                           batch_size_fn=batch_size_fn,
                                           key=batch_key).shard()

    train_loader = paddle.io.DataLoader(train_dataset,
                                        batch_sampler=train_batch_sampler,
                                        collate_fn=partial(prepare_train_input,
                                                           pad_id=pad_id))

    for i, data in enumerate(train_loader):
        print(data[1])
        print(paddle.max(data[1]) * len(data[1]))
        print(len(data[1]))
Example #12
def create_data_loader(args, places=None):
    data_files = {'train': args.training_file, 'dev': args.validation_file}

    datasets = [
        load_dataset(read, src_tgt_file=filename, lazy=False)
        for split, filename in data_files.items()
    ]

    src_vocab = Vocab.load_vocabulary(args.src_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    trg_vocab = Vocab.load_vocabulary(args.trg_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])

    args.src_vocab_size = len(src_vocab)
    args.trg_vocab_size = len(trg_vocab)

    def convert_samples(sample):
        source = [item.strip() for item in sample['src'].split()]
        target = [item.strip() for item in sample['trg'].split()]

        source = src_vocab.to_indices(source) + [args.eos_idx]
        target = [args.bos_idx] + \
                 trg_vocab.to_indices(target) + [args.eos_idx]

        return source, target

    data_loaders = [(None)] * 2
    for i, dataset in enumerate(datasets):
        dataset = dataset.map(convert_samples, lazy=False).filter(
            partial(min_max_filer, max_len=args.max_length))

        sampler = SamplerHelper(dataset)

        if args.sort_type == SortType.GLOBAL:
            src_key = (lambda x, data_source: len(data_source[x][0]))
            trg_key = (lambda x, data_source: len(data_source[x][1]))
            # Sort twice
            sampler = sampler.sort(key=trg_key).sort(key=src_key)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=args.random_seed)
            max_key = (lambda x, data_source: max(len(data_source[x][0]),
                                                  len(data_source[x][1])))
            if args.sort_type == SortType.POOL:
                sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)

        batch_size_fn = lambda new, count, sofar, data_source: max(
            sofar, len(data_source[new][0]), len(data_source[new][1]))
        batch_sampler = sampler.batch(
            batch_size=args.batch_size,
            drop_last=False,
            batch_size_fn=batch_size_fn,
            key=lambda size_so_far, minibatch_len: size_so_far * minibatch_len)

        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=args.random_seed)

        if i == 0:
            batch_sampler = batch_sampler.shard()

        data_loader = DataLoader(dataset=dataset,
                                 places=places,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    pad_idx=args.bos_idx),
                                 num_workers=0)

        data_loaders[i] = (data_loader)

    return data_loaders
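A hedged sketch of how the two returned loaders are typically consumed; the epoch count, model, criterion and optimizer are placeholders rather than part of the examples above, and the batch layout follows the prepare_train_input sketch after Example #10:

train_loader, dev_loader = create_data_loader(args)

for epoch in range(args.epoch):                    # args.epoch assumed to exist
    for src_word, trg_word, lbl_word in train_loader:
        logits = model(src_word, trg_word)         # placeholder Transformer forward
        loss = criterion(logits, lbl_word)         # placeholder label-smoothed CE
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()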