-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
160 lines (132 loc) · 5.75 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# -*- coding: utf-8 -*-
import os
import argparse
import time
import math
import pickle
import torch
from torch.utils.data import Dataset
from preprocess import Preprocessor
from model import CNNModel, LSTMModel
from train import train, evaluate
# =================================================
# Utility functions
# =================================================
def parse_args(argv=None):
    """Parse command-line options for essay-scoring training.

    Args:
        argv: optional list of argument strings. Defaults to ``None``,
            which makes argparse read ``sys.argv[1:]`` exactly as before;
            passing an explicit list keeps the function testable.

    Returns:
        argparse.Namespace with all option values.
    """
    parser = argparse.ArgumentParser()
    # paths
    parser.add_argument('--data', type=str, default='./data/',
                        help='path to load data')
    parser.add_argument('--save', type=str, default='./save/',
                        help='path to save model')
    # set/model/embedding choice
    parser.add_argument('--set', type=str, default='1',
                        help="essay set to use (1-8)")
    parser.add_argument('--model', type=str, default='LSTM',
                        help="model to use (CNN, LSTM, bi-LSTM)")
    parser.add_argument('--embed', type=str, default='glove',
                        help="embedding to use (none, glove)")
    # parameters to tune
    parser.add_argument('--e_dim', type=int, default=300,
                        help='embedding dimension')
    parser.add_argument('--h_dim', type=int, default=200,
                        help='hidden dimension')
    parser.add_argument('--batch', type=int, default=20,
                        help='batch size')
    parser.add_argument('--epochs', type=int, default=40,
                        help='upper epoch limit')
    parser.add_argument('--lr', type=float, default=20,
                        help='initial learning rate')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='dropout applied to layers')
    # boolean flags
    parser.add_argument('-n', '--noise', action='store_true',
                        help='use noise')
    parser.add_argument('-c', '--cuda', action='store_true',
                        help='use CUDA')
    return parser.parse_args(argv)
def pprint(filepath, s):
    """Print *s* to stdout and append it (with a trailing newline) to *filepath*.

    Uses a context manager so the file handle is closed deterministically;
    the original left the handle open for the garbage collector to reclaim.
    """
    print(s)
    with open(filepath, 'a') as f:
        f.write(s + '\n')
# =================================================
# Corpus class
# =================================================
class Corpus(Dataset):
    """Dataset backed by a pickled dict of samples.

    The pickle file is expected to hold a dict; each value becomes one
    item of the dataset (keys are discarded, insertion order preserved).
    """

    def __init__(self, filename):
        # Context manager closes the file deterministically; the original
        # `pickle.load(open(...))` leaked the handle to the GC.
        with open(filename, 'rb') as f:
            _data = pickle.load(f)
        # Equivalent to `[_data[key] for key in _data]` — dict iteration
        # order matches .values() order.
        self.data = list(_data.values())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
# =================================================
# Main
# =================================================
if __name__ == '__main__':
    args = parse_args()

    # Resolve data / output paths. The set number in the file names comes
    # straight from args.set; the original used datapath.strip('/')[-1]
    # (the last *character* of the path), which only worked because sets
    # are single digits. For sets 1-8 the resulting names are identical.
    datapath = os.path.join(args.data, "essayset_{}".format(args.set))
    name_stub = "{}_{}--set{}_emb{}_hid{}_bat{}_epc{}".format(
        args.model, args.embed, args.set,
        args.e_dim, args.h_dim, args.batch, args.epochs)
    savefile_path = os.path.join(args.save, "model_" + name_stub + ".pt")
    # NOTE(review): the log file keeps the original '.pt' extension for
    # backward compatibility, although '.log' would be more accurate.
    logfile_path = os.path.join(args.save, "log_" + name_stub + ".pt")
    # Without this, the very first pprint() crashes when ./save/ is absent.
    os.makedirs(args.save, exist_ok=True)

    # Fixed seed for reproducibility.
    torch.manual_seed(1)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1)

    ### Load preprocessed data
    pprint(logfile_path, '=' * 89)
    pprint(logfile_path, 'preprocessing data...')
    pprint(logfile_path, '=' * 89)
    glove_path = "./data/glove.840B.300d.txt"
    preprocessor = Preprocessor(datapath)
    train_data = Corpus(os.path.join(datapath, 'train.dat'))
    valid_data = Corpus(os.path.join(datapath, 'valid.dat'))
    test_data = Corpus(os.path.join(datapath, 'test.dat'))

    ### Build model
    pprint(logfile_path, '=' * 89)
    pprint(logfile_path, 'building model...')
    pprint(logfile_path, '=' * 89)
    n = len(preprocessor.vocab)
    if args.model == 'LSTM':
        model = LSTMModel(n, args.e_dim, args.h_dim, args.dropout, False)
    elif args.model == 'bi-LSTM':
        model = LSTMModel(n, args.e_dim, args.h_dim, args.dropout, True)
    elif args.model == 'CNN':
        model = CNNModel(n, args.e_dim, args.h_dim, args.dropout)
    else:
        # Fail fast instead of hitting a NameError on an undefined `model`
        # further down (the original had no else branch).
        raise ValueError("unknown model: {!r}".format(args.model))
    if args.embed == 'glove':
        model.init_weights_glove(glove_path, preprocessor.word2idx)
    if args.cuda:
        model.cuda()

    ### Train model
    pprint(logfile_path, '=' * 89)
    pprint(logfile_path, 'training model...')
    pprint(logfile_path, '=' * 89)
    best_val_loss = None
    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train(model, train_data, args.batch, logfile_path, args.noise)
        val_loss = evaluate(model, valid_data, args.batch, logfile_path)
        pprint(logfile_path, '-' * 89)
        pprint(logfile_path, '| end of epoch {:3d} | time: {:5.2f}s | '
               'valid loss {:5.2f} | '.format(
                   epoch, (time.time() - epoch_start_time), val_loss))
        pprint(logfile_path, '-' * 89)
        # Periodic checkpoint every 20 epochs.
        if epoch % 20 == 0:
            with open(savefile_path + '--e{}'.format(epoch), 'wb') as f:
                torch.save(model, f)
        # Keep the model with the best validation loss seen so far.
        # `is None` (not truthiness) so a legitimate 0.0 loss is not
        # mistaken for "no best yet".
        if best_val_loss is None or val_loss < best_val_loss:
            with open(savefile_path, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
    pprint(logfile_path, '-' * 89)
    pprint(logfile_path, 'best_val_loss: {}'.format(best_val_loss))
    pprint(logfile_path, '-' * 89)

    ### Reload the best checkpoint and run on test data
    with open(savefile_path, 'rb') as f:
        model = torch.load(f)
    test_loss = evaluate(model, test_data, args.batch, logfile_path)
    pprint(logfile_path, '=' * 89)
    pprint(logfile_path, '| End of training | test loss {:5.2f}'.format(test_loss))
    pprint(logfile_path, '=' * 89)