/
preprocess.py
100 lines (73 loc) · 2.81 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Pre-process Data / features files and build vocabulary
"""
import argparse
import glob
import sys
import gc
import os
import codecs
import torch
from onmt.utils.logging import init_logger, logger
import onmt.inputters as inputters
import onmt.opts as opts
def parse_args():
    """Parse command-line preprocessing options and seed torch's RNG.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(
        description='preprocess.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Register the markdown-help flag and all preprocessing options
    # defined by the project's opts module.
    opts.add_md_help_argument(parser)
    opts.preprocess_opts(parser)

    options = parser.parse_args()
    # Seed torch so any randomness during preprocessing is reproducible.
    torch.manual_seed(options.seed)
    return options
def build_save_dataset(corpus_type, fields, opt):
    """Build the dataset for one corpus split and serialize it with torch.

    Args:
        corpus_type (str): which split to build; must be 'train' or 'valid'.
        fields: the `Fields` object describing example features.
        opt: parsed preprocessing options (see ``parse_args``).

    Returns:
        str: path of the saved ``<save_data>.<corpus_type>.pt`` file.
    """
    assert corpus_type in ['train', 'valid']

    # Pick the source directory matching the requested split.
    corpus_dir = opt.train_dir if corpus_type == 'train' else opt.valid_dir

    dataset = inputters.build_dataset(
        fields,
        data_path=corpus_dir,
        data_type=opt.data_type,
        total_token_length=opt.total_token_length,
        src_seq_length=opt.src_seq_length,
        src_sent_length=opt.src_sent_length,
        seq_length_trunc=opt.seq_length_trunc)

    # Fields are saved separately in vocab.pt, so strip them here to keep
    # the serialized dataset small.
    dataset.fields = []

    pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)

    return pt_file
def build_save_vocab(train_dataset, data_type, fields, opt):
    """Build the vocabulary from the training dataset and save it to disk.

    Args:
        train_dataset (str): path of the saved training dataset file.
        data_type (str): the data type used when building the fields.
        fields: the `Fields` object to populate with vocabularies.
        opt: parsed preprocessing options (see ``parse_args``).
    """
    fields = inputters.build_vocab(
        train_dataset, data_type, fields,
        opt.share_vocab,
        opt.src_vocab_size,
        opt.src_words_min_frequency,
        opt.tgt_vocab_size,
        opt.tgt_words_min_frequency)

    # Fields can't be serialized directly; persist a vocab-only
    # representation and reconstruct the fields at training time.
    vocab_file = opt.save_data + '.vocab.pt'
    torch.save(inputters.save_fields_to_vocab(fields), vocab_file)
def main():
    """Entry point: build train/valid datasets, then the shared vocabulary."""
    opt = parse_args()
    init_logger(opt.log_file)

    logger.info("Extracting features...")
    # Dump the options through the logger (not bare print) so the record
    # lands in opt.log_file along with the rest of the run's output.
    logger.info(opt)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type)

    logger.info("Building & saving training data...")
    # build_save_dataset returns the single saved .pt path for the split.
    train_dataset_file = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_file, opt.data_type, fields, opt)


if __name__ == "__main__":
    main()