Example 1
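All six snippets below appear to come from the same PaddleNLP WMT14 English-German Transformer example and assume a shared set of imports and module-level helpers. A minimal sketch of those dependencies (module paths follow older PaddleNLP releases and may differ; min_max_filer, prepare_infer_input, prepare_train_input, TransformerBatchSampler and SortType are helpers assumed to be defined in the same file, and are not shown in the examples):

import os
from functools import partial

import paddle.distributed as dist
from paddle.io import DataLoader
from paddlenlp.data import SamplerHelper
from paddlenlp.datasets import WMT14ende

# Assumed to be defined alongside these functions (not shown in the examples):
# min_max_filer, prepare_infer_input, prepare_train_input,
# TransformerBatchSampler, SortType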
def create_infer_loader(args):
    # The literal string "None" means "use PaddleNLP's default download root".
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    # Round both vocab sizes up to the nearest multiple of pad_factor.
    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
    )
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    transform_func = WMT14ende.get_default_transform_func(root=root)
    # Load the test split and drop samples longer than max_length.
    dataset = WMT14ende.get_datasets(
        mode="test", root=root, transform_func=transform_func).filter(
            partial(
                min_max_filer, max_len=args.max_length))

    # Plain fixed-size batches for inference; keep the final partial batch.
    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=partial(
            prepare_infer_input,
            bos_idx=args.bos_idx,
            eos_idx=args.eos_idx,
            pad_idx=args.bos_idx),  # padding reuses the bos id
        num_workers=0,
        return_list=True)
    return data_loader, trg_vocab.to_tokens
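A hedged usage sketch for the loader above. The Namespace fields are simply the attributes the function reads, filled with made-up values; the second return value is trg_vocab.to_tokens, which maps predicted ids back to BPE tokens:

from argparse import Namespace

# Hypothetical settings; only the fields create_infer_loader reads are set.
args = Namespace(root="None", pad_factor=8, max_length=256,
                 infer_batch_size=8, bos_idx=0, eos_idx=1)

test_loader, to_tokens = create_infer_loader(args)
for batch in test_loader:
    # Each batch is whatever prepare_infer_input builds
    # (presumably the padded source-id tensors).
    pass

# After decoding, a sequence of predicted ids can be turned back into tokens:
# tokens = to_tokens([53, 671, 2])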
Example 2
def create_data_loader(args, places=None, use_all_vocab=False):
    root = None if args.root == "None" else args.root
    if not use_all_vocab:
        WMT14ende.VOCAB_INFO = (os.path.join("WMT14.en-de",
                                             "wmt14_ende_data_bpe",
                                             "vocab_all.bpe.33712"),
                                os.path.join("WMT14.en-de",
                                             "wmt14_ende_data_bpe",
                                             "vocab_all.bpe.33712"),
                                "de485e3c2e17e23acf4b4b70b54682dd",
                                "de485e3c2e17e23acf4b4b70b54682dd")
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    transform_func = WMT14ende.get_default_transform_func(root=root)
    datasets = [
        WMT14ende.get_datasets(mode=m,
                               root=root,
                               transform_func=transform_func)
        for m in ["train", "dev"]
    ]

    data_loaders = [None] * 2
    for i, dataset in enumerate(datasets):
        dataset = dataset.filter(
            partial(min_max_filer, max_len=args.max_length))
        batch_sampler = TransformerBatchSampler(
            dataset=dataset,
            batch_size=args.batch_size,
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            shuffle=args.shuffle,
            shuffle_batch=args.shuffle_batch,
            use_token_batch=True,
            max_length=args.max_length,
            distribute_mode=(i == 0),
            world_size=dist.get_world_size(),
            rank=dist.get_rank(),
            pad_seq=args.pad_seq,
            bsz_multi=args.bsz_multi)

        data_loader = DataLoader(dataset=dataset,
                                 places=places,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    pad_idx=args.bos_idx,
                                                    pad_seq=args.pad_seq),
                                 num_workers=0)
        data_loaders[i] = data_loader
    return data_loaders
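A similar sketch for the training side. dist.get_world_size() and dist.get_rank() fall back to 1 and 0 in a single-process run, so the function can also be exercised outside paddle.distributed.launch; all Namespace values below are hypothetical:

from argparse import Namespace

# Hypothetical settings covering the fields create_data_loader reads.
# The sort_type value depends on the SortType constants defined in the example.
args = Namespace(root="None", pad_factor=8, max_length=256,
                 batch_size=4096, pool_size=200000, sort_type="pool",
                 shuffle=True, shuffle_batch=True, bos_idx=0, eos_idx=1,
                 pad_seq=1, bsz_multi=8)

train_loader, dev_loader = create_data_loader(args)
for batch in train_loader:
    # One training step's worth of ids, presumably the source/target/label
    # tensors assembled by prepare_train_input.
    pass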
Example 3
def create_data_loader(args):
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    padding_vocab = (
        lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
    )
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    transform_func = WMT14ende.get_default_transform_func(root=root)
    datasets = [
        WMT14ende.get_datasets(
            mode=m, root=root, transform_func=transform_func)
        for m in ["train", "dev"]
    ]

    data_loaders = [None] * 2
    for i, dataset in enumerate(datasets):
        dataset = dataset.filter(
            partial(
                min_max_filer, max_len=args.max_length))
        batch_sampler = TransformerBatchSampler(
            dataset=dataset,
            batch_size=args.batch_size,
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            shuffle=args.shuffle,
            shuffle_batch=args.shuffle_batch,
            use_token_batch=True,
            max_length=args.max_length,
            distribute_mode=(i == 0),
            world_size=dist.get_world_size(),
            rank=dist.get_rank())

        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            collate_fn=partial(
                prepare_train_input,
                bos_idx=args.bos_idx,
                eos_idx=args.eos_idx,
                pad_idx=args.bos_idx),
            num_workers=0,
            return_list=True)
        data_loaders[i] = data_loader
    return data_loaders
Example 4
def adapt_vocab_size(args):
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)

    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
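adapt_vocab_size only performs the vocab-size rounding shared by all the loaders above: it pads the raw vocab size up to the next multiple of args.pad_factor (commonly done so the embedding and softmax dimensions stay divisible by a hardware-friendly factor). A quick worked example of the ceiling-to-multiple arithmetic, with a hypothetical pad_factor of 8:

pad_factor = 8  # hypothetical value for args.pad_factor
for x in (33710, 33712):
    padded = (x + pad_factor - 1) // pad_factor * pad_factor
    print(x, "->", padded)
# 33710 -> 33712   (rounded up to the next multiple of 8)
# 33712 -> 33712   (already a multiple, unchanged)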
Example 5
def create_infer_loader(args, use_all_vocab=False):
    root = None if args.root == "None" else args.root
    if not use_all_vocab:
        WMT14ende.VOCAB_INFO = (os.path.join("WMT14.en-de",
                                             "wmt14_ende_data_bpe",
                                             "vocab_all.bpe.33712"),
                                os.path.join("WMT14.en-de",
                                             "wmt14_ende_data_bpe",
                                             "vocab_all.bpe.33712"),
                                "de485e3c2e17e23acf4b4b70b54682dd",
                                "de485e3c2e17e23acf4b4b70b54682dd")
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    padding_vocab = (
        lambda x:
        (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor)
    args.src_vocab_size = padding_vocab(len(src_vocab))
    args.trg_vocab_size = padding_vocab(len(trg_vocab))
    transform_func = WMT14ende.get_default_transform_func(root=root)
    dataset = WMT14ende.get_datasets(mode="test",
                                     root=root,
                                     transform_func=transform_func).filter(
                                         partial(min_max_filer,
                                                 max_len=args.max_length))

    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=partial(prepare_infer_input,
                                                bos_idx=args.bos_idx,
                                                eos_idx=args.eos_idx,
                                                pad_idx=args.bos_idx,
                                                pad_seq=args.pad_seq),
                             num_workers=0,
                             return_list=True)
    return data_loader, trg_vocab.to_tokens
Example 6
def create_data_loader(args):
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
    transform_func = WMT14ende.get_default_transform_func(root=root)
    datasets = [
        WMT14ende.get_datasets(mode=m, root=root, transform_func=transform_func)
        for m in ["train", "dev"]
    ]

    if args.shuffle or args.shuffle_batch:
        if args.shuffle_seed == "None" or args.shuffle_seed is None:
            shuffle_seed = 0
        else:
            shuffle_seed = args.shuffle_seed

    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                      data_source):
        return max(tokens_sofar,
                   len(data_source[current_idx][0]) + 1,
                   len(data_source[current_idx][1]) + 1)

    def _key(size_so_far, minibatch_len):
        return size_so_far * minibatch_len

    data_loaders = [None] * 2
    for i, dataset in enumerate(datasets):
        m = dataset.mode
        dataset = dataset.filter(
            partial(min_max_filer, max_len=args.max_length))
        sampler = SamplerHelper(dataset)

        src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
        if args.sort_type == SortType.GLOBAL:
            buffer_size = -1
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(key=trg_key, buffer_size=buffer_size).sort(
                key=src_key, buffer_size=buffer_size)
        else:
            if args.shuffle:
                sampler = sampler.shuffle(seed=shuffle_seed)
            if args.sort_type == SortType.POOL:
                buffer_size = args.pool_size
                sampler = sampler.sort(key=src_key, buffer_size=buffer_size)

        batch_sampler = sampler.batch(batch_size=args.batch_size,
                                      drop_last=False,
                                      batch_size_fn=_max_token_fn,
                                      key=_key)

        if m == "train":
            batch_sampler = batch_sampler.shard()

        if args.shuffle_batch:
            batch_sampler = batch_sampler.shuffle(seed=shuffle_seed)

        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=batch_sampler,
                                 collate_fn=partial(prepare_train_input,
                                                    bos_idx=args.bos_idx,
                                                    eos_idx=args.eos_idx,
                                                    pad_idx=args.bos_idx),
                                 num_workers=0,
                                 return_list=True)
        data_loaders[i] = data_loader
    return data_loaders
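The _max_token_fn/_key pair above implements token-based batching: batch_size_fn tracks the longest source or target sequence seen so far in the current minibatch, and key turns that into a cost (longest length times number of samples) which SamplerHelper.batch compares against args.batch_size. The stand-alone sketch below mirrors my reading of those callbacks; the actual comparison logic lives inside SamplerHelper.batch and may differ in details:

def token_batches(lengths, batch_size):
    # Group sample indices so that (longest length in batch) * (batch length)
    # stays within batch_size -- the same cost the _key callback computes.
    batches, current, tokens_sofar = [], [], 0
    for idx, length in enumerate(lengths):
        candidate = max(tokens_sofar, length)                         # _max_token_fn
        if current and candidate * (len(current) + 1) > batch_size:   # _key
            batches.append(current)
            current, tokens_sofar = [], 0
            candidate = length
        current.append(idx)
        tokens_sofar = candidate
    if current:
        batches.append(current)
    return batches


print(token_batches([5, 4, 6, 3, 10, 2], batch_size=12))
# [[0, 1], [2, 3], [4], [5]]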