def load_train_data(data_path, batch_size, max_src_len, max_trg_len, use_cuda=False):
    """Build train and dev batch iterators from a preprocessed dataset file.

    Args:
        data_path: path to a torch-serialized dict with keys
            'train_src', 'train_tgt', 'dev_src', 'dev_tgt' holding
            already-indexed sequences.
        batch_size: number of examples per batch.
        max_src_len: drop training pairs whose source exceeds this length.
        max_trg_len: drop training pairs whose target exceeds this length.
        use_cuda: when True, place batches on the current GPU.

    Returns:
        (src_field, trg_field, train_iter, dev_iter)
    """
    # Note: use_vocab=False since we use preprocessed (already-indexed)
    # inputs; the fields only handle padding and length bookkeeping.
    src_field = Field(sequential=True, use_vocab=False, include_lengths=True,
                      batch_first=True, pad_token=PAD, unk_token=UNK,
                      init_token=None, eos_token=None,)
    # Target side additionally gets BOS/EOS markers for decoding.
    trg_field = Field(sequential=True, use_vocab=False, include_lengths=True,
                      batch_first=True, pad_token=PAD, unk_token=UNK,
                      init_token=BOS, eos_token=EOS,)
    fields = (src_field, trg_field)
    # Legacy torchtext device convention: None = current GPU, -1 = CPU.
    device = None if use_cuda else -1

    def filter_pred(example):
        # Keep only pairs within the configured length limits.
        # (Return the condition directly instead of if/return True/False.)
        return (len(example.src) <= max_src_len
                and len(example.trg) <= max_trg_len)

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # dataset files from trusted sources.
    dataset = torch.load(data_path)
    train_src, train_tgt = dataset['train_src'], dataset['train_tgt']
    dev_src, dev_tgt = dataset['dev_src'], dataset['dev_tgt']

    # Length filtering applies to training data only.
    train_data = ParallelDataset(train_src, train_tgt, fields=fields,
                                 filter_pred=filter_pred,)
    train_iter = Iterator(dataset=train_data, batch_size=batch_size,
                          train=True,  # Variable(volatile=False)
                          # Batch similar-length pairs together to reduce padding.
                          sort_key=lambda x: data.interleave_keys(
                              len(x.src), len(x.trg)),
                          repeat=False, shuffle=True, device=device)
    # Dev set is not length-filtered: evaluate on the full dev data.
    dev_data = ParallelDataset(dev_src, dev_tgt, fields=fields,)
    dev_iter = Iterator(dataset=dev_data, batch_size=batch_size,
                        train=False,  # Variable(volatile=True)
                        repeat=False, device=device, shuffle=False,
                        sort=False,)
    return src_field, trg_field, train_iter, dev_iter
def load_test_data(data_path, vocab_path, batch_size, use_cuda=False):
    """Build a source-only iterator over raw test text.

    Reads text from ``data_path``, converts it to index sequences with
    the source vocabulary stored at ``vocab_path``, and wraps the result
    in a non-shuffling, non-sorting Iterator.

    Args:
        data_path: path to the raw test corpus.
        vocab_path: torch-serialized dict holding 'src_dict' and 'lower_case'.
        batch_size: number of examples per batch.
        use_cuda: when True, place batches on the current GPU.

    Returns:
        (src_field, test_iter)
    """
    # Note: use_vocab=False since inputs are converted to indices below;
    # the field only handles padding and length bookkeeping.
    src_field = Field(sequential=True, use_vocab=False, include_lengths=True,
                      batch_first=True, pad_token=PAD, unk_token=UNK,
                      init_token=None, eos_token=None,)
    fields = (src_field, None)  # no target side at test time
    # Legacy torchtext device convention: None = current GPU, -1 = CPU.
    device = None if use_cuda else -1

    vocab = torch.load(vocab_path)
    _, src_word2idx, _ = vocab['src_dict']
    lower_case = vocab['lower_case']
    # Read the raw corpus first, then map tokens to vocabulary indices.
    corpus = read_corpus(data_path, None, lower_case)
    test_src = convert_text2idx(corpus, src_word2idx)

    test_data = ParallelDataset(test_src, None, fields=fields,)
    test_iter = Iterator(dataset=test_data, batch_size=batch_size,
                         train=False,  # Variable(volatile=True)
                         repeat=False, device=device, shuffle=False,
                         sort=False)
    return src_field, test_iter