-
Notifications
You must be signed in to change notification settings - Fork 1
/
train-pos.py
executable file
·94 lines (78 loc) · 4.17 KB
/
train-pos.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#!/usr/bin/env python3
import sys, argparse, math
import chainer.functions as F
import chainn.util.functions as UF
from chainer import optimizers
from chainn.util import AlignmentVisualizer
from chainn.util.io import ModelSerializer, load_pos_train_data, batch_generator
from chainn.classifier import ParallelTextClassifier
from chainn.machine import ParallelTrainer
# Command-line interface for the POS-tagger trainer.
parser = argparse.ArgumentParser("Program to train POS Tagger model using LSTM")
positive = lambda x: UF.check_positive(x, int)            # strictly-positive int validator
positive_decimal = lambda x: UF.check_positive(x, float)  # strictly-positive float validator
# Required
parser.add_argument("--model_out", type=str, required=True)
# Options
parser.add_argument("--hidden", type=positive, default=128, help="Size of hidden layer.")
parser.add_argument("--embed", type=positive, default=128, help="Size of embedding vector.")
parser.add_argument("--batch", type=positive, default=64, help="Number of (src) sentences in batch.")
parser.add_argument("--epoch", type=positive, default=10, help="Number of epoch to train the model.")
parser.add_argument("--depth", type=positive, default=1, help="Depth of the network.")
# FIX: the original help string was truncated ("Number of iteration being done for ").
parser.add_argument("--save_len", type=positive, default=1, help="Number of epochs between model checkpoints (used with --save_models).")
parser.add_argument("--verbose", action="store_true", help="To output the training progress for every sentence in corpora.")
parser.add_argument("--use_cpu", action="store_true", help="Force to use CPU.")
parser.add_argument("--save_models", action="store_true", help="Save models for every iteration with auto enumeration.")
parser.add_argument("--gpu", type=int, default=-1, help="Specify GPU to be used, negative for using CPU.")
parser.add_argument("--init_model", type=str, help="Init the training weights with saved model.")
parser.add_argument("--model", type=str, choices=["lstm"], default="lstm", help="Type of model being trained.")
parser.add_argument("--unk_cut", type=int, default=1, help="Threshold for words in corpora to be treated as unknown.")
parser.add_argument("--dropout", type=positive_decimal, default=0.2, help="Dropout ratio for LSTM.")
parser.add_argument("--seed", type=int, default=0, help="Seed for RNG. 0 for totally random seed.")
args = parser.parse_args()
# --use_cpu overrides any --gpu selection.
if args.use_cpu:
    args.gpu = -1
""" Training """
trainer = ParallelTrainer(args.seed, args.gpu)
# data
UF.trace("Loading corpus + dictionary")
X, Y, data = load_pos_train_data(sys.stdin, cut_threshold=args.unk_cut)
data = list(batch_generator(data, (X, Y), args.batch))
UF.trace("INPUT size:", len(X))
UF.trace("LABEL size:", len(Y))
UF.trace("Data loaded.")
""" Setup model """
UF.trace("Setting up classifier")
opt = optimizers.Adam()
model = ParallelTextClassifier(args, X, Y, opt, args.gpu, activation=F.relu, collect_output=args.verbose)
""" Training Callback """
def onEpochStart(epoch):
UF.trace("Starting Epoch", epoch+1)
def report(output, src, trg, trained, epoch):
    """Print each sentence in the batch as human-readable SRC/OUT/REF lines.

    output  -- classifier output; output.y holds the predicted label sequences
    src/trg -- batch of input-word / reference-label id sequences
    trained -- number of sentences already processed before this batch
    epoch   -- current 0-based epoch index
    """
    for index in range(len(src)):
        # BUG FIX: the original referenced undefined names SRC/TRG; the
        # vocabularies loaded at module level are X (input) and Y (labels).
        source = X.str_rpr(src[index])
        ref = Y.str_rpr(trg[index])
        out = Y.str_rpr(output.y[index])
        UF.trace("Epoch (%d/%d) sample %d:\n\tSRC: %s\n\tOUT: %s\n\tREF: %s" % (epoch+1, args.epoch, index+trained, source, out, ref))
def onBatchUpdate(output, src, trg, trained, epoch, accum_loss):
    """Per-batch hook: optionally dump translated samples, then log the loss."""
    if args.verbose:
        report(output, src, trg, trained, epoch)
    batch_cols = len(trg[0])
    UF.trace("Trained %d: %f, col_size=%d" % (trained, accum_loss, batch_cols))
def save_model(epoch):
    """Serialize the model; enumerate the filename when --save_models is set."""
    suffix = "-" + str(epoch) if args.save_models else ""
    out_file = args.model_out + suffix
    UF.trace("saving model to " + out_file + "...")
    ModelSerializer(out_file).save(model)
def onEpochUpdate(epoch_loss, prev_loss, epoch):
    """End-of-epoch hook: log loss and perplexity, and periodically checkpoint."""
    before = float(prev_loss)
    after = float(epoch_loss)
    UF.trace("Train Loss:", before, "->", after)
    UF.trace("Train PPL:", math.exp(before), "->", math.exp(after))
    # Checkpoint every --save_len epochs when --save_models is enabled.
    if args.save_models and (epoch + 1) % args.save_len == 0:
        save_model(epoch)
def onTrainingFinish(epoch):
    """Final hook: write the model unless this epoch was already checkpointed."""
    # NOTE(review): onEpochUpdate saves on (epoch + 1) % save_len while this
    # checks epoch % save_len — presumably the trainer passes a past-the-end
    # epoch index here; confirm against ParallelTrainer.
    already_saved = args.save_models and epoch % args.save_len == 0
    if not already_saved:
        save_model(epoch)
    UF.trace("training complete!")
""" Execute Training loop """
trainer.train(data, model, args.epoch, onEpochStart, onBatchUpdate, onEpochUpdate, onTrainingFinish)