def _build_train(self):
    """Build training ops.

    Runs a forward pass over the training batch, scales the learning rate by
    the ratio of the realized sequence length to the nominal BPTT length,
    clips gradients by global norm, and creates the SGD update op together
    with moving-average bookkeeping ops.

    Sets the following attributes:
      self.train_loss: scalar training loss tensor (unregularized).
      self.train_op: op that applies one SGD step and bumps global_step.
      self.grad_norm: global norm of the (pre-clip) gradients.
      self.learning_rate: the scaled learning-rate tensor.
      self.update_moving_avg_ops, self.use_moving_avg_vars,
      self.restore_normal_vars: ops from self._create_average_ops().
    """
    print('-' * 80)
    print('Building train graph')
    # reg_loss includes regularization terms and is what we differentiate;
    # loss is the raw objective reported as self.train_loss.
    reg_loss, loss = self._forward(self.x_train, self.y_train,
                                   self.train_params, self.batch_init_states,
                                   is_training=True)

    tf_vars = tf.trainable_variables()
    global_step = tf.train.get_or_create_global_step()
    # Scale the LR by (actual unrolled length / nominal bptt_steps) so that
    # shorter/longer sampled BPTT windows keep per-token step sizes comparable.
    lr_scale = (tf.cast(tf.shape(self.y_train)[-1], dtype=tf.float32) /
                tf.cast(self.params.bptt_steps, dtype=tf.float32))
    learning_rate = utils.get_lr(global_step, self.params) * lr_scale
    grads = tf.gradients(reg_loss, tf_vars)
    clipped_grads, grad_norm = tf.clip_by_global_norm(
        grads, self.params.grad_bound)

    (self.update_moving_avg_ops, self.use_moving_avg_vars,
     self.restore_normal_vars) = self._create_average_ops()
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(
        zip(clipped_grads, tf_vars), global_step=global_step)

    self.train_loss = loss
    self.train_op = train_op
    self.grad_norm = grad_norm
    self.learning_rate = learning_rate
def _build_train(self):
    """Build training ops.

    Runs a forward pass over the training batch, scales the learning rate by
    the ratio of the realized sequence length to the nominal BPTT length, and
    creates the SGD update op over this model's variables only (those whose
    names start with self.name).

    Sets the following attributes:
      self.train_loss: scalar training loss tensor (unregularized).
      self.train_op: op that applies one SGD step and bumps global_step.
      self.grad_norm: global norm of the (pre-clip) gradients.
      self.learning_rate: the scaled learning-rate tensor.
    """
    print('-' * 80)
    print('Building train graph')
    # reg_loss includes regularization terms and is what we differentiate;
    # loss is the raw objective reported as self.train_loss.
    reg_loss, loss = self._forward(self.x_train, self.y_train,
                                   self.train_params, self.batch_init_states,
                                   is_training=True)

    # Only optimize variables belonging to this model (name-scope prefix).
    tf_vars = [
        v for v in tf.trainable_variables() if v.name.startswith(self.name)
    ]
    global_step = tf.train.get_or_create_global_step()
    # Scale the LR by (actual unrolled length / nominal bptt_steps) so that
    # shorter/longer sampled BPTT windows keep per-token step sizes comparable.
    lr_scale = (tf.cast(tf.shape(self.y_train)[-1], dtype=tf.float32) /
                tf.cast(self.params.bptt_steps, dtype=tf.float32))
    learning_rate = utils.get_lr(global_step, self.params) * lr_scale

    # BUG FIX: the original computed grads/clipped_grads/grad_norm only inside
    # `if self.params.grad_bound:` but used them unconditionally below, so a
    # falsy grad_bound raised NameError. Always compute gradients; clip only
    # when a bound is configured, otherwise apply raw gradients and still
    # report their global norm.
    grads = tf.gradients(reg_loss, tf_vars)
    if self.params.grad_bound:
        clipped_grads, grad_norm = tf.clip_by_global_norm(
            grads, self.params.grad_bound)
    else:
        clipped_grads = grads
        grad_norm = tf.global_norm(grads)

    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(
        zip(clipped_grads, tf_vars), global_step=global_step)

    self.train_loss = loss
    self.train_op = train_op
    self.grad_norm = grad_norm
    self.learning_rate = learning_rate