Example #1
def get_train_iterator(args, vocab: Vocabulary, label_encoder: preprocessing.LabelEncoder):
    threads = 6
    buffer_size = 10000

    fact = TextLine(args.train, buffer_size=buffer_size, num_threads=threads)
    label = TextLine(args.train_label, buffer_size=buffer_size, num_threads=threads)
    ds = Zip([fact, label], buffer_size=buffer_size, num_threads=threads)

    ds = ds.select(
        lambda sample: (
            torch.as_tensor(text_to_indices(sample[0], vocab)),
            sample[1]
        )
    )
    ds = Shuffle(ds, buffer_size=-1, num_threads=threads)

    def collate_fn(samples):
        # transpose [(x, y), ...] into xs and ys
        xs, ys = itertools.zip_longest(*samples)
        xs = to_cuda(pack_tensors(xs, vocab.pad_id))
        ys = to_cuda(torch.as_tensor(label_encoder.transform(ys)).long())
        return xs, ys

    iterator = Iterator(
        ds, args.batch_size,
        cache_size=2048,
        collate_fn=collate_fn,
        sort_desc_by=lambda sample: sample[0].size(0)
    )

    return iterator
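
This example and the ones below rely on project helpers such as text_to_indices and pack_tensors that are not shown here. A minimal, hypothetical sketch of what they might look like, using a plain dict in place of the project's Vocabulary class (the real implementations may differ):

import torch

def text_to_indices(line, token2id, unk_id=1):
    # Hypothetical helper: whitespace-tokenize a raw line and map each token
    # to its integer id, falling back to unk_id for out-of-vocabulary tokens.
    return [token2id.get(token, unk_id) for token in line.split()]

def pack_tensors(tensors, pad_id):
    # Hypothetical helper: right-pad a list of 1-D LongTensors to the length
    # of the longest one and stack them into a (batch, max_len) matrix.
    max_len = max(t.size(0) for t in tensors)
    batch = torch.full((len(tensors), max_len), pad_id, dtype=torch.long)
    for i, t in enumerate(tensors):
        batch[i, :t.size(0)] = t
    return batch
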
Example #2
def get_devtest_iterator(fact_path, vocab: Vocabulary, label_encoder: preprocessing.LabelEncoder, label_path=None):
    threads = 6
    buffer_size = 10000
    if label_path is not None:
        fact = TextLine(fact_path, buffer_size=buffer_size, num_threads=threads)
        label = TextLine(label_path, buffer_size=buffer_size, num_threads=threads)
        ds = Zip([fact, label], buffer_size=buffer_size, num_threads=threads)

        ds = ds.select(
            lambda sample: (
                torch.as_tensor(text_to_indices(sample[0], vocab)),
                sample[1]
            )
        )

        def collate_fn(samples):
            # transpose [(x, y), ...] into xs and ys
            xs, ys = itertools.zip_longest(*samples)
            xs = to_cuda(pack_tensors(xs, vocab.pad_id))
            ys = to_cuda(torch.as_tensor(label_encoder.transform(ys)).long())
            return xs, ys

        iterator = Iterator(
            ds, 100,
            cache_size=100,
            collate_fn=collate_fn,
        )
    else:
        ds = TextLine(fact_path, buffer_size=buffer_size, num_threads=threads)
        ds = ds.select(
            # each sample here is a raw text line, not a (fact, label) pair
            lambda sample: torch.as_tensor(text_to_indices(sample, vocab))
        )

        def collate_fn(samples):
            return to_cuda(pack_tensors(samples, vocab.pad_id))

        iterator = Iterator(
            ds, 100,
            cache_size=100,
            collate_fn=collate_fn,
        )

    return iterator
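
A hypothetical call site for the two branches of get_devtest_iterator; the file paths are placeholders, and vocab and label_encoder are assumed to be built by the surrounding script:

# Dev set: labels are available, so each batch is an (inputs, labels) pair.
dev_iterator = get_devtest_iterator('dev.fact.txt', vocab, label_encoder,
                                    label_path='dev.label.txt')

# Test set: no labels, so each batch contains only the packed inputs.
test_iterator = get_devtest_iterator('test.fact.txt', vocab, label_encoder)
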
Example #3
def get_dev_iterator(args, source_vocab: Vocabulary):
    threads = args.num_workers

    src = TextLine(args.dev[0], bufsize=args.buffer_size, num_threads=threads)
    refs = [
        TextLine(ref, bufsize=args.buffer_size, num_threads=threads)
        for ref in args.dev[1:]
    ]
    src = src.select(
        lambda x: torch.as_tensor(text_to_indices(x, source_vocab)))
    refs = Zip(refs, bufsize=args.buffer_size, num_threads=threads)
    ds = Zip([src, refs], bufsize=args.buffer_size, num_threads=threads)
    ds = ds.select(
        lambda x, ys: {
            'src': x,
            'n_tok': x.size(0),
            'refs': [y.split() for y in ys],  # tokenize references
        })

    def collate_fn(xs):
        return {
            'src': cuda(pack_tensors(aggregate_value_by_key(xs, 'src'),
                                     source_vocab.pad_id)),
            'n_tok': aggregate_value_by_key(xs, 'n_tok', sum),
            'refs': aggregate_value_by_key(xs, 'refs'),
        }

    batch_size = args.eval_batch_size
    iterator = Iterator(ds,
                        batch_size,
                        cache_size=batch_size,
                        collate_fn=collate_fn,
                        sort_cache_by=lambda sample: -sample['n_tok'])

    return iterator
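
aggregate_value_by_key is another helper assumed by this and the following examples. A minimal sketch consistent with how it is called here (gather one field from every sample in a batch, optionally folding the values with a reducer such as sum); the project's real helper may differ:

def aggregate_value_by_key(samples, key, reduce_fn=None):
    # Hypothetical helper: collect samples[i][key] into a list and, if a
    # reducer is given, fold that list into a single value.
    values = [sample[key] for sample in samples]
    return reduce_fn(values) if reduce_fn is not None else values
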
Example #4
def get_dev_iterator(args, vocs):
    svoc, tvoc = vocs

    n_input = 2 if args.arch in ['abdrnn2', 'mulsrc'] else 1
    dev = args.dev[:n_input]
    refs = args.dev[n_input:]

    if n_input == 1:
        vocabs = [svoc]
    elif n_input == 2:
        vocabs = [svoc, tvoc]
    else:
        raise NotImplementedError('a maximum of 2 sources is allowed')
    bufsize = args.buffer_size
    batch_size = args.eval_batch_size
    threads = args.num_workers

    def fn(*xs):
        tensors = [
            torch.as_tensor(text_to_indices(x, voc))
            for x, voc in zip(xs, vocabs)
        ]
        return {'src': tensors, 'n_tok': tensors[0].size(0)}

    src = [TextLine(f, bufsize=bufsize, num_threads=threads) for f in dev]
    src = Zip(src, bufsize=bufsize, num_threads=threads).select(fn)

    refs = [
        TextLine(ref, bufsize=args.buffer_size, num_threads=threads)
        for ref in refs
    ]
    refs = Zip(refs, bufsize=args.buffer_size, num_threads=threads)

    ds = Zip([src, refs], bufsize=args.buffer_size, num_threads=threads)
    ds = ds.select(
        lambda x, ys: {
            'src': x['src'],
            'n_tok': x['n_tok'],
            'refs': [y.split() for y in ys],  # tokenize references
        })

    def collate_fn(xs):
        inputs = aggregate_value_by_key(xs, 'src')
        inputs = list(zip(*inputs))
        inputs = [
            cuda(pack_tensors(input, voc.pad_id))
            for input, voc in zip(inputs, vocabs)
        ]
        return {
            'src': inputs[0] if len(inputs) == 1 else inputs,
            'n_tok': aggregate_value_by_key(xs, 'n_tok', sum),
            'n_snt': inputs[0].size(0),
            'refs': aggregate_value_by_key(xs, 'refs')
        }

    iterator = Iterator(
        ds,
        batch_size,
        cache_size=batch_size,
        collate_fn=collate_fn,
        sort_cache_by=lambda sample: -sample['n_tok'],
    )

    return iterator
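
The list(zip(*inputs)) call in collate_fn turns a per-sample list of source tuples into one tuple per source, so that each source can be padded with its own vocabulary. A small standalone illustration of that transpose:

import torch

per_sample = [
    (torch.tensor([1, 2, 3]), torch.tensor([7, 8])),  # sample 1: source A, source B
    (torch.tensor([4, 5]), torch.tensor([9])),        # sample 2: source A, source B
]
per_source = list(zip(*per_sample))
# per_source[0] holds all source-A tensors, per_source[1] all source-B tensors,
# so each group can be passed to pack_tensors with the matching pad id.
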
Example #5
def get_train_iterator(args, source_vocab: Vocabulary,
                       target_vocab: Vocabulary):
    threads = args.num_workers

    src = TextLine(args.train[0],
                   bufsize=args.buffer_size,
                   num_threads=threads)
    r2l = TextLine(args.train[1],
                   bufsize=args.buffer_size,
                   num_threads=threads)
    l2r = TextLine(args.train[2],
                   bufsize=args.buffer_size,
                   num_threads=threads)
    ds = Zip([src, r2l, l2r], bufsize=args.buffer_size, num_threads=threads)

    def fn(src, r2l, l2r):
        src = torch.as_tensor(text_to_indices(src, source_vocab))
        r2l = torch.as_tensor(text_to_indices(r2l, target_vocab))
        l2r = torch.as_tensor(text_to_indices(l2r, target_vocab))
        n_src_tok = src.size(0)
        n_trg_tok_r2l = r2l.size(0)
        n_trg_tok_l2r = l2r.size(0)

        return {
            'src': src,
            'r2l': r2l,
            'l2r': l2r,
            'n_src_tok': n_src_tok,
            'ntok_r2l': n_trg_tok_r2l,
            'ntok_l2r': n_trg_tok_l2r,
        }

    ds = ds.select(fn)
    shuffle = args.shuffle
    if shuffle != 0:
        ds = Shuffle(ds,
                     shufsize=shuffle,
                     bufsize=args.buffer_size,
                     num_threads=threads)

    # limit = args.length_limit
    # if limit is not None and len(limit) == 1:
    #     limit *= len(args.train)
    # if limit is not None:
    #     ds = ds.where(
    #         lambda x: x['n_src_tok'] - 1 <= limit[0] and x['n_trg_tok'] - 1 <= limit[1]
    #     )

    def collate_fn(xs):
        return {
            'src': cuda(pack_tensors(aggregate_value_by_key(xs, 'src'),
                                     source_vocab.pad_id)),
            'r2l': cuda(pack_tensors(aggregate_value_by_key(xs, 'r2l'),
                                     target_vocab.pad_id)),
            'l2r': cuda(pack_tensors(aggregate_value_by_key(xs, 'l2r'),
                                     target_vocab.pad_id)),
            'ntok_src': aggregate_value_by_key(xs, 'n_src_tok', sum),
            'ntok_r2l': aggregate_value_by_key(xs, 'ntok_r2l', sum),
            'ntok_l2r': aggregate_value_by_key(xs, 'ntok_l2r', sum),
        }

    sample_size_fn = None
    if not args.batch_by_sentence:
        sample_size_fn = lambda x: x['ntok_l2r']

    batch_size = args.batch_size[0]
    padded_size = None

    def padded_size_fn(xs):
        # estimated padded size of a group of samples: longest target length
        # among them times the number of samples
        if not xs:
            return 0
        return max(x['ntok_l2r'] for x in xs) * len(xs)

    if torch.cuda.is_available():
        batch_size *= torch.cuda.device_count()
        if len(args.batch_size) > 1:
            padded_size = args.batch_size[1]
            padded_size *= torch.cuda.device_count()
    # padded_size = None

    iterator = Iterator(
        ds,
        batch_size,
        padded_size=padded_size,
        cache_size=max(args.sort_buffer_factor, 1) * batch_size,
        sample_size_fn=sample_size_fn,
        padded_size_fn=padded_size_fn,
        collate_fn=collate_fn,
        sort_cache_by=lambda sample: sample['ntok_l2r'],
        sort_batch_by=lambda sample: -sample['n_src_tok'],
        # sort_cache_by=lambda sample: -sample['n_src_tok'],
        strip_batch=True)

    return iterator
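
padded_size_fn above estimates the padded token count of a group of samples as the longest target length among them times the number of samples. A small worked check of that formula, independent of the library:

cache = [{'ntok_l2r': 12}, {'ntok_l2r': 7}, {'ntok_l2r': 9}]
padded = max(x['ntok_l2r'] for x in cache) * len(cache)
assert padded == 36  # 12 tokens * 3 sentences once each is padded to length 12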