Example #1
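These snippets all follow the data pipeline of PaddleNLP's Transformer machine-translation example. A plausible set of shared imports is sketched below; note that `SortType`, `prepare_train_input`, `min_max_filer`, and `read` are assumed to be helpers defined in the example's local reader module, not part of the library itself.

import os
from functools import partial

from paddle.io import DataLoader
from paddlenlp.data import Vocab, SamplerHelper
from paddlenlp.datasets import load_dataset
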
def create_data_loader(args, places=None, use_all_vocab=False):
    data_files = None
    if args.root != "None" and os.path.exists(args.root):
        data_files = {
            'train': (os.path.join(args.root, "train.tok.clean.bpe.33708.en"),
                      os.path.join(args.root, "train.tok.clean.bpe.33708.de")),
            'dev': (os.path.join(args.root, "newstest2013.tok.bpe.33708.en"),
                    os.path.join(args.root, "newstest2013.tok.bpe.33708.de"))
        }

    datasets = load_dataset('wmt14ende',
                            data_files=data_files,
                            splits=('train', 'dev'))
    if use_all_vocab:
        src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["bpe"])
    else:
        src_vocab = Vocab.load_vocabulary(
            **datasets[0].vocab_info["benchmark"])
    trg_vocab = src_vocab

    # Round the vocab sizes up to a multiple of args.pad_factor.
    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))

    def convert_samples(sample):
        source = sample[args.src_lang].split()
        target = sample[args.trg_lang].split()

        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)

        return source, target

    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                      data_source):
        # Track the longest source/target length (plus one for bos/eos)
        # seen so far in the current minibatch.
        return max(tokens_sofar,
                   len(data_source[current_idx][0]) + 1,
                   len(data_source[current_idx][1]) + 1)

    def _key(size_so_far, minibatch_len):
        # Batch cost: longest sequence so far times number of samples,
        # i.e. roughly the number of padded tokens the batch will occupy.
        return size_so_far * minibatch_len

    data_loaders = [None] * 2
    for i, dataset in enumerate(datasets):
        dataset = dataset.map(convert_samples, lazy=False).filter(
            partial(min_max_filer, max_len=args.max_length))

        sampler = SamplerHelper(dataset)

        if args.sort_type == SortType.GLOBAL:
            src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(key=trg_key).sort(key=src_key)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=args.shuffle_seed)
            max_key = (lambda x, data_source: max(len(data_source[x][0]),
                                                  len(data_source[x][1])) + 1)
            if args.sort_type == SortType.POOL:
                sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)

        batch_sampler = sampler.batch(batch_size=args.batch_size,
                                      drop_last=False,
                                      batch_size_fn=_max_token_fn,
                                      key=_key)

        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=args.shuffle_seed)

        if i == 0:
            # Shard only the training batches across distributed workers.
            batch_sampler = batch_sampler.shard()

        data_loader = DataLoader(dataset=dataset,
                                 places=places,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    pad_idx=args.bos_idx,
                                                    pad_seq=args.pad_seq),
                                 num_workers=0)
        data_loaders[i] = data_loader
    return data_loaders
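A minimal sketch of how this variant might be driven, assuming an argparse-style namespace; every field and value below is illustrative rather than taken from the original configuration.

from types import SimpleNamespace

args = SimpleNamespace(
    root="None",                 # fall back to the downloaded wmt14ende data
    src_lang="en", trg_lang="de",
    pad_factor=8, max_length=256,
    sort_type=SortType.POOL, pool_size=200000,
    shuffle=True, shuffle_seed=128, shuffle_batch=True,
    batch_size=4096,             # roughly a token budget via _max_token_fn/_key
    bos_idx=0, eos_idx=1, pad_seq=1)
train_loader, dev_loader = create_data_loader(args)
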
Example #2
def create_data_loader(args):
    train_dataset = load_dataset(read,
                                 src_path=args.training_file.split(',')[0],
                                 tgt_path=args.training_file.split(',')[1],
                                 lazy=False)
    # NOTE: the dev loader is built from the same file pair as the training
    # split; point these at a separate validation file if one is available.
    dev_dataset = load_dataset(read,
                               src_path=args.training_file.split(',')[0],
                               tgt_path=args.training_file.split(',')[1],
                               lazy=False)
    print('load src vocab')
    print(args.src_vocab_fpath)
    src_vocab = Vocab.load_vocabulary(args.src_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    print('load trg vocab')
    print(args.trg_vocab_fpath)
    trg_vocab = Vocab.load_vocabulary(args.trg_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    print('padding')
    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    print('convert example')

    def convert_samples(sample):
        source = sample['src'].split()
        target = sample['tgt'].split()

        source = src_vocab.to_indices(source)
        target = trg_vocab.to_indices(target)

        return source, target

    data_loaders = [None] * 2
    print('dataset loop')
    for i, dataset in enumerate([train_dataset, dev_dataset]):
        dataset = dataset.map(convert_samples, lazy=False).filter(
            partial(min_max_filer, max_len=args.max_length))

        sampler = SamplerHelper(dataset)

        if args.sort_type == SortType.GLOBAL:
            src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(key=trg_key).sort(key=src_key)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=args.shuffle_seed)
            max_key = (lambda x, data_source: max(len(data_source[x][0]),
                                                  len(data_source[x][1])) + 1)
            if args.sort_type == SortType.POOL:
                sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)

        batch_size_fn = lambda new, count, sofar, data_source: max(
            sofar,
            len(data_source[new][0]) + 1,
            len(data_source[new][1]) + 1)
        batch_sampler = sampler.batch(
            batch_size=args.batch_size,
            drop_last=False,
            batch_size_fn=batch_size_fn,
            key=lambda size_so_far, minibatch_len: size_so_far * minibatch_len)

        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=args.shuffle_seed)

        if i == 0:
            batch_sampler = batch_sampler.shard()

        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    pad_idx=args.bos_idx),
                                 num_workers=2,
                                 return_list=True)
        data_loaders[i] = data_loader
    return data_loaders
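Here `load_dataset` is handed a local `read` callable rather than a registered dataset name. A minimal sketch of what that reader might look like, assuming one tokenized sentence per line and the 'src'/'tgt' keys that convert_samples expects (the file format is an assumption):

def read(src_path, tgt_path):
    # Yield parallel sentence pairs, one dict per example.
    with open(src_path, encoding="utf-8") as src_f, \
            open(tgt_path, encoding="utf-8") as tgt_f:
        for src_line, tgt_line in zip(src_f, tgt_f):
            yield {"src": src_line.strip(), "tgt": tgt_line.strip()}
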
Example #3
def create_data_loader(args):
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
    transform_func = WMT14ende.get_default_transform_func(root=root)
    datasets = [
        WMT14ende.get_datasets(mode=m, transform_func=transform_func)
        for m in ["train", "dev"]
    ]

    if args.shuffle or args.shuffle_batch:
        if args.shuffle_seed == "None" or args.shuffle_seed is None:
            shuffle_seed = 0
        else:
            shuffle_seed = args.shuffle_seed

    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                      data_source):
        return max(tokens_sofar,
                   len(data_source[current_idx][0]) + 1,
                   len(data_source[current_idx][1]) + 1)

    def _key(size_so_far, minibatch_len):
        return size_so_far * minibatch_len

    data_loaders = [None] * 2
    for i, dataset in enumerate(datasets):
        m = dataset.mode
        dataset = dataset.filter(
            partial(min_max_filer, max_len=args.max_length))
        sampler = SamplerHelper(dataset)

        src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
        if args.sort_type == SortType.GLOBAL:
            buffer_size = -1
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(key=trg_key, buffer_size=buffer_size).sort(
                key=src_key, buffer_size=buffer_size)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=shuffle_seed)
            if args.sort_type == SortType.POOL:
                buffer_size = args.pool_size
                sampler = sampler.sort(key=src_key, buffer_size=buffer_size)

        batch_sampler = sampler.batch(batch_size=args.batch_size,
                                      drop_last=False,
                                      batch_size_fn=_max_token_fn,
                                      key=_key)

        if m == "train":
            batch_sampler = batch_sampler.shard()

        if args.shuffle_batch:
            # shuffle() returns a new sampler; reassign it so the batch-level
            # shuffle actually takes effect.
            batch_sampler = batch_sampler.shuffle(seed=shuffle_seed)

        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    pad_idx=args.bos_idx),
                                 num_workers=0,
                                 return_list=True)
        data_loaders[i] = data_loader
    return data_loaders
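This variant uses the dataset-class style API (WMT14ende.get_vocab / get_default_transform_func / get_datasets) from earlier PaddleNLP releases. Once built, the two loaders are consumed like any Paddle DataLoader; the loop below is only a sketch, since the exact batch layout depends on prepare_train_input.

train_loader, dev_loader = create_data_loader(args)
for batch in train_loader:
    # With return_list=True each batch arrives as a list of tensors.
    print([t.shape for t in batch])
    break
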
Example #4
def create_data_loader(args, places=None):
    data_files = {'train': args.training_file, 'dev': args.validation_file}

    datasets = [
        load_dataset(read, src_tgt_file=filename, lazy=False)
        for filename in data_files.values()
    ]

    src_vocab = Vocab.load_vocabulary(args.src_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])
    trg_vocab = Vocab.load_vocabulary(args.trg_vocab_fpath,
                                      bos_token=args.special_token[0],
                                      eos_token=args.special_token[1],
                                      unk_token=args.special_token[2])

    args.src_vocab_size = len(src_vocab)
    args.trg_vocab_size = len(trg_vocab)

    def convert_samples(sample):
        source = [item.strip() for item in sample['src'].split()]
        target = [item.strip() for item in sample['trg'].split()]

        source = src_vocab.to_indices(source) + [args.eos_idx]
        target = [args.bos_idx] + \
                 trg_vocab.to_indices(target) + [args.eos_idx]

        return source, target

    data_loaders = [None] * 2
    for i, dataset in enumerate(datasets):
        dataset = dataset.map(convert_samples, lazy=False).filter(
            partial(min_max_filer, max_len=args.max_length))

        sampler = SamplerHelper(dataset)

        if args.sort_type == SortType.GLOBAL:
            src_key = (lambda x, data_source: len(data_source[x][0]))
            trg_key = (lambda x, data_source: len(data_source[x][1]))
            # Sort twice
            sampler = sampler.sort(key=trg_key).sort(key=src_key)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=args.random_seed)
            max_key = (lambda x, data_source: max(len(data_source[x][0]),
                                                  len(data_source[x][1])))
            if args.sort_type == SortType.POOL:
                sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)

        batch_size_fn = lambda new, count, sofar, data_source: max(
            sofar, len(data_source[new][0]), len(data_source[new][1]))
        batch_sampler = sampler.batch(
            batch_size=args.batch_size,
            drop_last=False,
            batch_size_fn=batch_size_fn,
            key=lambda size_so_far, minibatch_len: size_so_far * minibatch_len)

        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=args.random_seed)

        if i == 0:
            batch_sampler = batch_sampler.shard()

        data_loader = DataLoader(dataset=dataset,
                                 places=places,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    pad_idx=args.bos_idx),
                                 num_workers=0)

        data_loaders[i] = data_loader

    return data_loaders
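In this variant a single file per split carries both sides of the pair (`src_tgt_file`), and convert_samples reads the 'src'/'trg' keys. A minimal reader matching that contract might look like the sketch below; the tab-separated layout is an assumption, not taken from the original repository.

def read(src_tgt_file):
    # Each line: source sentence, a tab, then the target sentence.
    with open(src_tgt_file, encoding="utf-8") as f:
        for line in f:
            src, trg = line.rstrip("\n").split("\t", 1)
            yield {"src": src, "trg": trg}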