train.py (forked from mpezeshki/RNN_Experiments)
import logging
import os
import numpy as np
import theano
from blocks.algorithms import (Adam, CompositeRule, GradientDescent,
Momentum, RMSProp, StepClipping,
RemoveNotFinite)
from blocks.extensions import Printing, ProgressBar
from blocks.extensions.monitoring import TrainingDataMonitoring
from blocks.extensions.saveload import Load
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph, apply_noise
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.roles import WEIGHT
from extensions import (EarlyStopping, TextGenerationExtension,
ResetStates, InteractiveMode, VisualizeGateSoft,
VisualizeGateLSTM)
from datastream_monitoring import DataStreamMonitoring
# from blocks.extensions.saveload import Checkpoint

floatX = theano.config.floatX

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)


def learning_algorithm(args):
    name = args.algorithm
    learning_rate = float(args.learning_rate)
    momentum = args.momentum
    clipping_threshold = args.clipping
    if name == 'adam':
        clipping = StepClipping(threshold=np.cast[floatX](clipping_threshold))
        adam = Adam(learning_rate=learning_rate)
        # Order matters: [adam, clipping] clips the step computed by Adam
        # ('step clipping'), whereas [clipping, adam] would clip the raw
        # gradient before Adam sees it ('gradient clipping').
        step_rule = CompositeRule([adam, clipping])
    elif name == 'rms_prop':
        clipping = StepClipping(threshold=np.cast[floatX](clipping_threshold))
        rms_prop = RMSProp(learning_rate=learning_rate)
        rm_non_finite = RemoveNotFinite()
        step_rule = CompositeRule([clipping, rms_prop, rm_non_finite])
    else:
        clipping = StepClipping(threshold=np.cast[floatX](clipping_threshold))
        sgd_momentum = Momentum(learning_rate=learning_rate, momentum=momentum)
        rm_non_finite = RemoveNotFinite()
        step_rule = CompositeRule([clipping, sgd_momentum, rm_non_finite])
    return step_rule
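
# A minimal usage sketch (hypothetical argparse.Namespace; only the
# attributes read by learning_algorithm are shown):
#
#   from argparse import Namespace
#   args = Namespace(algorithm='adam', learning_rate='1e-3',
#                    momentum=0.9, clipping=1.0)
#   step_rule = learning_algorithm(args)  # Adam step, then step clipping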


def train_model(cost, cross_entropy, updates,
                train_stream, valid_stream, args, gate_values=None):

    step_rule = learning_algorithm(args)
    cg = ComputationGraph(cost)

    # Regularization: corrupt the weights with additive Gaussian noise
    # during training (weight noise).
    weight_noise = args.weight_noise
    if weight_noise > 0:
        weights = VariableFilter(roles=[WEIGHT])(cg.variables)
        cg_train = apply_noise(cg, weights, weight_noise)
        cost = cg_train.outputs[0]
        cost.name = "cost_with_weight_noise"
        cg = ComputationGraph(cost)

    logger.info(cg.parameters)

    algorithm = GradientDescent(cost=cost, step_rule=step_rule,
                                params=cg.parameters)
    # `updates` carries the recurrent-state updates, so hidden states are
    # propagated across consecutive batches.
    algorithm.add_updates(updates)

    # Extensions to be added to the main loop
    extensions = []

    if args.load_path is not None:
        extensions.append(Load(args.load_path))

    outputs = [
        variable for variable in cg.variables if variable.name == "presoft"]

    if args.generate:
        extensions.append(TextGenerationExtension(
            outputs=outputs,
            generation_length=args.generated_text_lenght,
            initial_text_length=args.initial_text_length,
            every_n_batches=args.monitoring_freq,
            ploting_path=os.path.join(args.save_path, 'prob_plot.png'),
            softmax_sampling=args.softmax_sampling,
            dataset=args.dataset,
            updates=updates,
            interactive_mode=args.interactive_mode))

    extensions.extend([
        TrainingDataMonitoring([cost], prefix='train',
                               every_n_batches=args.monitoring_freq,
                               after_epoch=True),
        DataStreamMonitoring([cost, cross_entropy],
                             valid_stream, args.mini_batch_size_valid,
                             state_updates=updates,
                             prefix='valid',
                             before_first_epoch=not args.visualize_gates,
                             every_n_batches=args.monitoring_freq),
        ResetStates([v for v, _ in updates], every_n_batches=100),
        ProgressBar()])

    # Create the directory for saving the model.
    if not args.interactive_mode:
        if not os.path.exists(args.save_path):
            os.makedirs(args.save_path)
        else:
            raise Exception('Directory %s already exists' % args.save_path)

    # The 'valid_cross_entropy' channel is produced by the
    # DataStreamMonitoring extension above ('valid' prefix + variable name).
    early_stopping = EarlyStopping('valid_cross_entropy',
                                   args.patience, args.save_path,
                                   every_n_batches=args.monitoring_freq)

    # Visualization extensions
    if args.interactive_mode:
        extensions.append(InteractiveMode())
    if args.visualize_gates and (gate_values is not None):
        if args.rnn_type == "lstm":
            extensions.append(VisualizeGateLSTM(gate_values, updates,
                                                args.dataset,
                                                ploting_path=None))
        elif args.rnn_type == "soft":
            extensions.append(VisualizeGateSoft(gate_values, updates,
                                                args.dataset,
                                                ploting_path=None))
        else:
            raise ValueError("Gate visualization is only supported for "
                             "'lstm' and 'soft' rnn_type, got %s"
                             % args.rnn_type)

    extensions.append(early_stopping)
    extensions.append(Printing(every_n_batches=args.monitoring_freq))

    main_loop = MainLoop(
        model=Model(cost),
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=extensions)
    main_loop.run()
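
# A minimal sketch of how this module is typically driven (hypothetical
# helper names; in this repository the model, the data streams and the
# argument parsing live in other modules):
#
#   args = parse_args()
#   cost, cross_entropy, updates, gate_values = build_model(args)
#   train_stream, valid_stream = get_streams(args)
#   train_model(cost, cross_entropy, updates,
#               train_stream, valid_stream, args, gate_values=gate_values)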