def get_train_iterator(args, vocab: Vocabulary, label_encoder: preprocessing.LabelEncoder):
    """Build the shuffled training iterator for the text-classification task."""
    threads = 6
    buffer_size = 10000
    fact = TextLine(args.train, buffer_size=buffer_size, num_threads=threads)
    label = TextLine(args.train_label, buffer_size=buffer_size, num_threads=threads)
    ds = Zip([fact, label], buffer_size=buffer_size, num_threads=threads)
    ds = ds.select(
        lambda sample: (
            torch.as_tensor(text_to_indices(sample[0], vocab)),
            sample[1]
        )
    )
    # buffer_size=-1 shuffles over the whole dataset.
    ds = Shuffle(ds, buffer_size=-1, num_threads=threads)

    def collate_fn(samples):
        samples = list(itertools.zip_longest(*samples))
        xs, ys = samples
        return (to_cuda(pack_tensors(xs, vocab.pad_id)),
                to_cuda(torch.as_tensor(label_encoder.transform(ys)).long()))

    iterator = Iterator(
        ds,
        args.batch_size,
        cache_size=2048,
        collate_fn=collate_fn,
        sort_desc_by=lambda sample: sample[0].size(0)
    )
    return iterator
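# A minimal usage sketch (not part of the original pipeline): how the training
# iterator above might be consumed by a classifier. `model`, `optimizer`, and the
# cross-entropy criterion are hypothetical stand-ins; the batch layout (packed
# fact tensor, encoded label tensor) follows collate_fn above.
def example_train_epoch(args, vocab, label_encoder, model, optimizer):
    criterion = torch.nn.CrossEntropyLoss()
    iterator = get_train_iterator(args, vocab, label_encoder)
    for xs, ys in iterator:
        optimizer.zero_grad()
        logits = model(xs)           # assumed shape: (batch, num_classes)
        loss = criterion(logits, ys)
        loss.backward()
        optimizer.step()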
def get_devtest_iterator(fact_path, vocab: Vocabulary, label_encoder: preprocessing.LabelEncoder, label_path=None):
    """Build a dev/test iterator; labels are optional (omit label_path for unlabeled test data)."""
    threads = 6
    buffer_size = 10000
    if label_path is not None:
        fact = TextLine(fact_path, buffer_size=buffer_size, num_threads=threads)
        label = TextLine(label_path, buffer_size=buffer_size, num_threads=threads)
        ds = Zip([fact, label], buffer_size=buffer_size, num_threads=threads)
        ds = ds.select(
            lambda sample: (
                torch.as_tensor(text_to_indices(sample[0], vocab)),
                sample[1]
            )
        )

        def collate_fn(samples):
            samples = list(itertools.zip_longest(*samples))
            xs, ys = samples
            return (to_cuda(pack_tensors(xs, vocab.pad_id)),
                    to_cuda(torch.as_tensor(label_encoder.transform(ys)).long()))

        iterator = Iterator(
            ds,
            100,
            cache_size=100,
            collate_fn=collate_fn,
        )
    else:
        ds = TextLine(fact_path, buffer_size=buffer_size, num_threads=threads)
        # Each sample here is a raw text line (no Zip), so it is converted directly.
        ds = ds.select(
            lambda sample: torch.as_tensor(text_to_indices(sample, vocab))
        )

        def collate_fn(samples):
            return to_cuda(pack_tensors(samples, vocab.pad_id))

        iterator = Iterator(
            ds,
            100,
            cache_size=100,
            collate_fn=collate_fn,
        )
    return iterator
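# A hypothetical inference sketch: iterating the unlabeled variant of the
# dev/test iterator and mapping predicted class indices back to label strings
# via the fitted LabelEncoder. `model` is a stand-in for the trained classifier.
def example_predict(fact_path, vocab, label_encoder, model):
    iterator = get_devtest_iterator(fact_path, vocab, label_encoder)
    predictions = []
    with torch.no_grad():
        for xs in iterator:
            scores = model(xs)                       # assumed shape: (batch, num_classes)
            indices = scores.argmax(dim=-1).cpu().numpy()
            predictions.extend(label_encoder.inverse_transform(indices))
    return predictions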
def get_dev_iterator(args, source_vocab: Vocabulary):
    """Build the dev iterator for translation evaluation; args.dev is [source, ref_1, ..., ref_n]."""
    threads = args.num_workers
    src = TextLine(args.dev[0], bufsize=args.buffer_size, num_threads=threads)
    refs = [
        TextLine(ref, bufsize=args.buffer_size, num_threads=threads)
        for ref in args.dev[1:]
    ]
    src = src.select(
        lambda x: torch.as_tensor(text_to_indices(x, source_vocab)))
    refs = Zip(refs, bufsize=args.buffer_size, num_threads=threads)
    ds = Zip([src, refs], bufsize=args.buffer_size, num_threads=threads).select(
        lambda x, ys: {
            'src': x,
            'n_tok': x.size(0),
            'refs': [y.split() for y in ys]  # tokenize references
        })

    def collate_fn(xs):
        return {
            'src': cuda(
                pack_tensors(aggregate_value_by_key(xs, 'src'),
                             source_vocab.pad_id)),
            'n_tok': aggregate_value_by_key(xs, 'n_tok', sum),
            'refs': aggregate_value_by_key(xs, 'refs')
        }

    batch_size = args.eval_batch_size
    iterator = Iterator(ds,
                        batch_size,
                        cache_size=batch_size,
                        collate_fn=collate_fn,
                        sort_cache_by=lambda sample: -sample['n_tok'])
    return iterator
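# For reference, a guessed sketch of the command-line arguments the builders in
# this module read (names taken from the attribute accesses above; defaults are
# illustrative, not the project's actual values).
def example_add_dev_args(parser):
    parser.add_argument('--dev', nargs='+',
                        help='source file followed by one or more reference files')
    parser.add_argument('--num-workers', type=int, default=6)
    parser.add_argument('--buffer-size', type=int, default=10000)
    parser.add_argument('--eval-batch-size', type=int, default=32)
    return parser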
def get_dev_iterator(args, vocs):
    """Build the dev iterator for single- or dual-source models; args.dev holds the source file(s) followed by the reference file(s)."""
    svoc, tvoc = vocs
    n_input = 2 if args.arch in ['abdrnn2', 'mulsrc'] else 1
    dev = args.dev[:n_input]
    refs = args.dev[n_input:]
    if n_input == 1:
        vocabs = [svoc]
    elif n_input == 2:
        vocabs = [svoc, tvoc]
    else:
        raise NotImplementedError('a maximum of 2 sources is allowed')
    bufsize = args.buffer_size
    batch_size = args.eval_batch_size
    threads = args.num_workers

    def fn(*xs):
        tensors = [
            torch.as_tensor(text_to_indices(x, voc))
            for x, voc in zip(xs, vocabs)
        ]
        return {'src': tensors, 'n_tok': tensors[0].size(0)}

    src = [TextLine(f, bufsize=bufsize, num_threads=threads) for f in dev]
    src = Zip(src, bufsize=bufsize, num_threads=threads).select(fn)
    refs = [
        TextLine(ref, bufsize=args.buffer_size, num_threads=threads)
        for ref in refs
    ]
    refs = Zip(refs, bufsize=args.buffer_size, num_threads=threads)
    ds = Zip([src, refs], bufsize=args.buffer_size, num_threads=threads).select(
        lambda x, ys: {
            'src': x['src'],
            'n_tok': x['n_tok'],
            'refs': [y.split() for y in ys]  # tokenize references
        })

    def collate_fn(xs):
        inputs = aggregate_value_by_key(xs, 'src')
        inputs = list(zip(*inputs))
        inputs = [
            cuda(pack_tensors(inp, voc.pad_id))
            for inp, voc in zip(inputs, vocabs)
        ]
        return {
            'src': inputs[0] if len(inputs) == 1 else inputs,
            'n_tok': aggregate_value_by_key(xs, 'n_tok', sum),
            'n_snt': inputs[0].size(0),
            'refs': aggregate_value_by_key(xs, 'refs')
        }

    iterator = Iterator(
        ds,
        batch_size,
        cache_size=batch_size,
        collate_fn=collate_fn,
        sort_cache_by=lambda sample: -sample['n_tok'],
    )
    return iterator
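# A hypothetical evaluation sketch for the dual-source variant above: when two
# inputs are configured, batch['src'] is a list of two packed tensors, otherwise
# a single tensor. `translator.translate_batch` and `corpus_bleu` are stand-ins
# for the project's decoder and metric.
def example_evaluate(args, vocs, translator, corpus_bleu):
    iterator = get_dev_iterator(args, vocs)
    hypotheses, references = [], []
    for batch in iterator:
        srcs = batch['src'] if isinstance(batch['src'], list) else [batch['src']]
        hypotheses.extend(translator.translate_batch(*srcs))
        references.extend(batch['refs'])
    return corpus_bleu(hypotheses, references)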
def get_train_iterator(args, source_vocab: Vocabulary, target_vocab: Vocabulary):
    """Build the training iterator over (source, right-to-left target, left-to-right target) triples."""
    threads = args.num_workers
    src = TextLine(args.train[0], bufsize=args.buffer_size, num_threads=threads)
    r2l = TextLine(args.train[1], bufsize=args.buffer_size, num_threads=threads)
    l2r = TextLine(args.train[2], bufsize=args.buffer_size, num_threads=threads)
    ds = Zip([src, r2l, l2r], bufsize=args.buffer_size, num_threads=threads)

    def fn(src, r2l, l2r):
        src = torch.as_tensor(text_to_indices(src, source_vocab))
        r2l = torch.as_tensor(text_to_indices(r2l, target_vocab))
        l2r = torch.as_tensor(text_to_indices(l2r, target_vocab))
        n_src_tok = src.size(0)
        n_trg_tok_r2l = r2l.size(0)
        n_trg_tok_l2r = l2r.size(0)
        return {
            'src': src,
            'r2l': r2l,
            'l2r': l2r,
            'n_src_tok': n_src_tok,
            'ntok_r2l': n_trg_tok_r2l,
            'ntok_l2r': n_trg_tok_l2r,
        }

    ds = ds.select(fn)

    shuffle = args.shuffle
    if shuffle != 0:
        ds = Shuffle(ds, shufsize=shuffle, bufsize=args.buffer_size, num_threads=threads)

    # limit = args.length_limit
    # if limit is not None and len(limit) == 1:
    #     limit *= len(args.train)
    # if limit is not None:
    #     ds = ds.where(
    #         lambda x: x['n_src_tok'] - 1 <= limit[0] and x['n_trg_tok'] - 1 <= limit[1]
    #     )

    def collate_fn(xs):
        return {
            'src': cuda(
                pack_tensors(aggregate_value_by_key(xs, 'src'),
                             source_vocab.pad_id)),
            'r2l': cuda(
                pack_tensors(aggregate_value_by_key(xs, 'r2l'),
                             target_vocab.pad_id)),
            'l2r': cuda(
                pack_tensors(aggregate_value_by_key(xs, 'l2r'),
                             target_vocab.pad_id)),
            'ntok_src': aggregate_value_by_key(xs, 'n_src_tok', sum),
            'ntok_r2l': aggregate_value_by_key(xs, 'ntok_r2l', sum),
            'ntok_l2r': aggregate_value_by_key(xs, 'ntok_l2r', sum),
        }

    sample_size_fn = None
    if not args.batch_by_sentence:
        # Batch by target tokens rather than by sentences.
        sample_size_fn = lambda x: x['ntok_l2r']

    batch_size = args.batch_size[0]
    padded_size = None
    padded_size_fn = lambda xs: 0 if not xs else \
        max(xs, key=lambda x: x['ntok_l2r'])['ntok_l2r'] * len(xs)
    if torch.cuda.is_available():
        # Scale the batch budget by the number of visible GPUs.
        batch_size *= torch.cuda.device_count()
        if len(args.batch_size) > 1:
            padded_size = args.batch_size[1]
            padded_size *= torch.cuda.device_count()
    # padded_size = None

    iterator = Iterator(
        ds,
        batch_size,
        padded_size=padded_size,
        cache_size=max(args.sort_buffer_factor, 1) * batch_size,
        sample_size_fn=sample_size_fn,
        padded_size_fn=padded_size_fn,
        collate_fn=collate_fn,
        sort_cache_by=lambda sample: sample['ntok_l2r'],
        sort_batch_by=lambda sample: -sample['n_src_tok'],
        # sort_cache_by=lambda sample: -sample['n_src_tok'],
        strip_batch=True)
    return iterator
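# A minimal sketch of how a training step might consume the batches above:
# each batch carries packed 'src', 'r2l', and 'l2r' tensors plus token counts
# for loss normalization. A `model` returning per-direction losses and an
# `optimizer` are hypothetical stand-ins, not part of this module.
def example_train_step(batch, model, optimizer):
    optimizer.zero_grad()
    loss_r2l, loss_l2r = model(batch['src'], batch['r2l'], batch['l2r'])
    # Normalize each direction by its own target-token count.
    loss = loss_r2l / batch['ntok_r2l'] + loss_l2r / batch['ntok_l2r']
    loss.backward()
    optimizer.step()
    return loss.item()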