Example #1
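Per-sample preprocessing for a standard source/target pair: both sentences are converted to index tensors, and the per-side token counts are stored alongside them in the sample dict.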
    def fn(src, trg):
        # source_vocab, target_vocab and text_to_indices come from the
        # enclosing scope.
        src = torch.as_tensor(text_to_indices(src, source_vocab))
        trg = torch.as_tensor(text_to_indices(trg, target_vocab))
        # Per-sample token counts, used downstream for batching and
        # aggregation (see Example #3).
        n_src_tok = src.size(0)
        n_trg_tok = trg.size(0)

        return {
            'src': src,
            'trg': trg,
            'n_src_tok': n_src_tok,
            'n_trg_tok': n_trg_tok,
        }
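A minimal sketch of how fn might be wired into a pipeline, mirroring the Zip(...).select(...) pattern from Example #3; the file paths and keyword values here are assumptions, not part of the original example:

    # Hypothetical wiring (paths and buffer settings are assumptions);
    # mirrors the Zip(...).select(...) pattern of Example #3.
    src_lines = TextLine('train.src', bufsize=10000, num_threads=1)
    trg_lines = TextLine('train.trg', bufsize=10000, num_threads=1)
    # Each zipped (src, trg) line pair is mapped through fn, yielding
    # the sample dict above.
    train_ds = Zip([src_lines, trg_lines], bufsize=10000,
                   num_threads=1).select(fn)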
Example #2
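The same preprocessing as Example #1, extended to two target sequences per sample: r2l and l2r, presumably the target tokens in right-to-left and left-to-right order for a bidirectional decoding setup.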
    def fn(src, r2l, l2r):
        # source_vocab, target_vocab and text_to_indices come from the
        # enclosing scope.
        src = torch.as_tensor(text_to_indices(src, source_vocab))
        # Two variants of the target sequence, both indexed with the
        # same target vocabulary.
        r2l = torch.as_tensor(text_to_indices(r2l, target_vocab))
        l2r = torch.as_tensor(text_to_indices(l2r, target_vocab))
        n_src_tok = src.size(0)
        n_trg_tok_r2l = r2l.size(0)
        n_trg_tok_l2r = l2r.size(0)

        return {
            'src': src,
            'r2l': r2l,
            'l2r': l2r,
            'n_src_tok': n_src_tok,
            'ntok_r2l': n_trg_tok_r2l,
            'ntok_l2r': n_trg_tok_l2r,
        }
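Note the asymmetric key names in the returned dict: the source count is stored as 'n_src_tok' while the target counts use 'ntok_r2l' and 'ntok_l2r', so any downstream aggregation has to reference these exact keys.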
Example #3
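A complete development-set pipeline: the source file is zipped with one or more reference files, sources are tensorized, references are whitespace-tokenized, and a collate function pads and batches the sources while aggregating token counts and references.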
def get_dev_iterator(args, source_vocab: Vocabulary):
    threads = args.num_workers

    # args.dev[0] is the source file; any remaining paths are reference
    # translations.
    src = TextLine(args.dev[0], bufsize=args.buffer_size, num_threads=threads)
    refs = [
        TextLine(ref, bufsize=args.buffer_size, num_threads=threads)
        for ref in args.dev[1:]
    ]
    # Convert each source line to a tensor of vocabulary indices.
    src = src.select(
        lambda x: torch.as_tensor(text_to_indices(x, source_vocab)))
    refs = Zip(refs, bufsize=args.buffer_size, num_threads=threads)
    ds = Zip([src, refs], bufsize=args.buffer_size,
             num_threads=threads).select(
                 lambda x, ys: {
                     'src': x,
                     'n_tok': x.size(0),
                     'refs': [y.split() for y in ys]  # tokenize references
                 })

    def collate_fn(xs):
        # Pad the source tensors into a single batch tensor on GPU, and
        # aggregate the per-sample token counts and references.
        return {
            'src': cuda(pack_tensors(aggregate_value_by_key(xs, 'src'),
                                     source_vocab.pad_id)),
            'n_tok': aggregate_value_by_key(xs, 'n_tok', sum),
            'refs': aggregate_value_by_key(xs, 'refs'),
        }

    batch_size = args.eval_batch_size
    # Sort cached samples by descending token count so similar-length
    # samples end up in the same batch.
    iterator = Iterator(ds,
                        batch_size,
                        cache_size=batch_size,
                        collate_fn=collate_fn,
                        sort_cache_by=lambda sample: -sample['n_tok'])

    return iterator
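A sketch of how the returned iterator might be consumed; the batch keys come from collate_fn above, while everything else (direct iterability, model.translate, score_hypotheses) is a hypothetical stand-in:

    # Hypothetical evaluation loop; only the batch keys ('src', 'n_tok',
    # 'refs') are taken from the code above, and Iterator is assumed to
    # be directly iterable.
    dev_iter = get_dev_iterator(args, source_vocab)
    for batch in dev_iter:
        hyps = model.translate(batch['src'])   # hypothetical model call
        score_hypotheses(hyps, batch['refs'])  # hypothetical metric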
Example #4
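A variadic generalization of Example #1: fn accepts any number of parallel text streams, pairs each with its own vocabulary from an enclosing vocabs list, and returns the index tensors as a list under a single key.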
    def fn(*xs):
        # xs holds one line of text per input stream; vocabs (from the
        # enclosing scope) supplies the matching vocabulary for each stream.
        tensors = [
            torch.as_tensor(text_to_indices(x, voc))
            for x, voc in zip(xs, vocabs)
        ]
        return {
            'src': tensors,
            'n_tok': tensors[0].size(0)  # token count of the first stream
        }
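Because fn takes *xs, it pairs naturally with a Zip over an arbitrary number of TextLine streams, provided vocabs lists one vocabulary per stream in the same order. A sketch, with all paths and names assumed:

    # Hypothetical multi-stream wiring; the file paths, buffer settings
    # and extra_vocab are assumptions.
    vocabs = [source_vocab, extra_vocab]  # one vocabulary per stream
    ds = Zip([
        TextLine('train.src', bufsize=10000, num_threads=1),
        TextLine('train.extra', bufsize=10000, num_threads=1),
    ], bufsize=10000, num_threads=1).select(fn)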