def _compute_rnn_state_projections(self, state, new_param_state,
                                   grads_scaled):
    """Projects the new RNN state into the per-step update quantities.

    Args:
      state: optimizer state dictionary. Its "true_param" entry supplies the
        static shape of the returned update step, and the whole dictionary
        is forwarded to _compute_new_learning_rate.
      new_param_state: RNN output that every projection below is taken from.
      grads_scaled: sized collection of scaled-gradient tensors. It is only
        read when the gradient shortcut is enabled, where the tensors are
        concatenated along axis 1.

    Returns:
      A tuple (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
      attention_delta). attention_delta is None when attention is disabled.
    """

    def _gradient_shortcut(scope_name):
        # Bias-free affine projection of the concatenated scaled gradients,
        # so that the RNN hidden state is freed up to store more complex
        # functions of the gradient and other parameters.
        return utils.affine(grads_scaled_tensor, 1,
                            scope=scope_name,
                            include_bias=False,
                            vec_mean=1. / len(grads_scaled),
                            random_seed=self.random_seed)

    def _unit_rms(tensor):
        # Divide by the root-mean-square of the tensor; the 1e-16 keeps the
        # square root away from zero.
        return tensor / tf.sqrt(tf.reduce_mean(tensor ** 2) + 1e-16)

    # Update direction: a linear projection of the RNN output.
    direction = utils.project(new_param_state, self.update_weights)
    if self.use_gradient_shortcut:
        grads_scaled_tensor = tf.concat(list(grads_scaled), 1)
        direction = direction + _gradient_shortcut("GradsToDelta")
    if self.dynamic_output_scale:
        direction = _unit_rms(direction)

    # Attention direction, built the same way from its own weights.
    attention_delta = None
    if self.use_attention:
        attention_delta = utils.project(new_param_state,
                                        self.attention_weights)
        if self.use_gradient_shortcut:
            attention_delta = (attention_delta +
                               _gradient_shortcut("GradsToAttnDelta"))
        if self.dynamic_output_scale:
            attention_delta = _unit_rms(attention_delta)

    # Both decays are sigmoid-activated affine projections of the RNN state.
    scl_decay = utils.project(new_param_state, self.scl_decay_weights,
                              bias=self.scl_decay_bias,
                              activation=tf.nn.sigmoid)
    # Only used if learnable_decay and num_gradient_scales > 1.
    inp_decay = utils.project(new_param_state, self.inp_decay_weights,
                              bias=self.inp_decay_bias,
                              activation=tf.nn.sigmoid)

    # Learning-rate multipliers, then the step reshaped like the parameter.
    lr_param, lr_attend, new_log_lr = self._compute_new_learning_rate(
        state, new_param_state)
    target_shape = state["true_param"].get_shape()
    update_step = tf.reshape(lr_param * direction, target_shape)

    return (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
            attention_delta)
def _compute_rnn_state_projections(self, state, new_param_state,
                                   grads_scaled):
    """Projects the new RNN state into the per-step update quantities.

    Args:
      state: optimizer state dictionary. Its "true_param" entry supplies the
        static shape of the returned update step, and the whole dictionary
        is forwarded to _compute_new_learning_rate.
      new_param_state: RNN output that every projection below is taken from.
      grads_scaled: sized collection of scaled-gradient tensors. It is only
        read when the gradient shortcut is enabled, where the tensors are
        concatenated along axis 1.

    Returns:
      A tuple (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
      attention_delta). attention_delta is None when attention is disabled.
    """

    def _gradient_shortcut(scope_name):
        # Bias-free affine projection of the concatenated scaled gradients,
        # so that the RNN hidden state is freed up to store more complex
        # functions of the gradient and other parameters.
        return utils.affine(grads_scaled_tensor, 1,
                            scope=scope_name,
                            include_bias=False,
                            vec_mean=1. / len(grads_scaled),
                            random_seed=self.random_seed)

    def _unit_rms(tensor):
        # Divide by the root-mean-square of the tensor; the 1e-16 keeps the
        # square root away from zero.
        return tensor / tf.sqrt(tf.reduce_mean(tensor ** 2) + 1e-16)

    # Update direction: a linear projection of the RNN output.
    direction = utils.project(new_param_state, self.update_weights)
    if self.use_gradient_shortcut:
        grads_scaled_tensor = tf.concat(list(grads_scaled), 1)
        direction = direction + _gradient_shortcut("GradsToDelta")
    if self.dynamic_output_scale:
        direction = _unit_rms(direction)

    # Attention direction, built the same way from its own weights.
    attention_delta = None
    if self.use_attention:
        attention_delta = utils.project(new_param_state,
                                        self.attention_weights)
        if self.use_gradient_shortcut:
            attention_delta = (attention_delta +
                               _gradient_shortcut("GradsToAttnDelta"))
        if self.dynamic_output_scale:
            attention_delta = _unit_rms(attention_delta)

    # Both decays are sigmoid-activated affine projections of the RNN state.
    scl_decay = utils.project(new_param_state, self.scl_decay_weights,
                              bias=self.scl_decay_bias,
                              activation=tf.nn.sigmoid)
    # Only used if learnable_decay and num_gradient_scales > 1.
    inp_decay = utils.project(new_param_state, self.inp_decay_weights,
                              bias=self.inp_decay_bias,
                              activation=tf.nn.sigmoid)

    # Learning-rate multipliers, then the step reshaped like the parameter.
    lr_param, lr_attend, new_log_lr = self._compute_new_learning_rate(
        state, new_param_state)
    target_shape = state["true_param"].get_shape()
    update_step = tf.reshape(lr_param * direction, target_shape)

    return (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
            attention_delta)
def _compute_new_learning_rate(self, state, new_param_state):
    """Derives the step-size multipliers from the new RNN state.

    Args:
      state: optimizer state dictionary; its "log_learning_rate" entry is
        read only when dynamic output scaling is enabled.
      new_param_state: RNN output the learning-rate projection is taken
        from.

    Returns:
      A tuple (lr_param, lr_attend, new_log_lr). new_log_lr is None when
      dynamic output scaling is disabled, in which case lr_attend is the
      same tensor as lr_param.
    """
    if not self.dynamic_output_scale:
        # No log learning rate is tracked in this mode: the multiplier is
        # simply twice a sigmoid-activated projection of the RNN state.
        lr_param = 2. * utils.project(new_param_state, self.lr_weights,
                                      bias=self.lr_bias,
                                      activation=tf.nn.sigmoid)
        return lr_param, lr_param, None

    # Change in log learning rate: an affine projection of the RNN state,
    # added to the log learning rate carried in the optimizer state.
    lr_change = utils.project(new_param_state, self.lr_weights,
                              bias=self.lr_bias)
    prev_log_lr = state["log_learning_rate"]
    unclipped_log_lr = prev_log_lr + lr_change

    # Clip the value to [-33, self.max_log_lr] without clipping the
    # gradient. The correction (clipped - unclipped) is wrapped in
    # stop_gradient, so downstream ops see the clipped value while the
    # gradient w.r.t. the lr weights and bias is that of the unclipped
    # value, even when it lies outside the clip range.
    clipped_log_lr = tf.clip_by_value(unclipped_log_lr, -33,
                                      self.max_log_lr)
    step_log_lr = unclipped_log_lr + tf.stop_gradient(
        clipped_log_lr - unclipped_log_lr)

    # Blend the previous and the stepped log learning rate using a learned
    # momentum in (0, 1).
    lr_momentum_logit = tf.get_variable(
        "learning_rate_momentum_logit",
        initializer=FLAGS.learning_rate_momentum_logit_init)
    lrm = tf.nn.sigmoid(lr_momentum_logit)
    new_log_lr = (lrm * prev_log_lr +
                  (1. - lrm) * step_log_lr)

    # The parameter step size gets an extra learned offset in log space.
    param_stepsize_offset = tf.get_variable("param_stepsize_offset",
                                            initializer=-1.)
    lr_param = tf.exp(step_log_lr + param_stepsize_offset)
    if self.use_attention:
        lr_attend = tf.exp(step_log_lr)
    else:
        lr_attend = lr_param

    return lr_param, lr_attend, new_log_lr
def _compute_update(self, param, grad, state):
    """Applies one optimizer step to a parameter.

    Args:
      param: tensor of parameters.
      grad: gradient for param; either a dense tensor or a
        tf.IndexedSlices, in which case the update and the new state are
        merged back into the full-size values.
      state: dictionary of optimizer state; the "rms", "learning_rate",
        "rnn" and "decay" entries are read directly for the
        tf.IndexedSlices case.

    Returns:
      new_param: the parameters after subtracting the computed update.
      new_state: dictionary with the updated "rms", "learning_rate", "rnn"
        and "decay" state.
    """
    with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope:
        # Variables are created on the first call and reused afterwards.
        if not self.reuse_vars:
            self.reuse_vars = True
        else:
            scope.reuse_variables()

        param_shape = tf.shape(param)
        extracted = self._extract_gradients_and_internal_state(
            grad, state, param_shape)
        (grad_values, decay_state, rms_state, rnn_state,
         learning_rate_state, grad_indices) = extracted

        # Vectorize and scale the gradients.
        grad_scaled, rms = utils.rms_scaling(grad_values, decay_state,
                                             rms_state)

        # Run the RNN cell on the scaled gradients.
        cell_state_in = self._unpack_rnn_state_into_tuples(rnn_state)
        rnn_output, cell_state_out = self.cell(grad_scaled, cell_state_in)
        rnn_state = self._pack_tuples_into_rnn_state(cell_state_out)

        # Update direction: a linear projection of the RNN output.
        delta = utils.project(rnn_output, self.update_weights)

        # New decay: a sigmoid-activated affine projection of the output.
        decay = utils.project(rnn_output, self.decay_weights,
                              bias=self.decay_bias,
                              activation=tf.nn.sigmoid)

        # Multiplicative learning-rate change; the 2x sigmoid bounds it.
        learning_rate_change = 2. * utils.project(rnn_output,
                                                  self.lr_weights,
                                                  bias=self.lr_bias,
                                                  activation=tf.nn.sigmoid)
        new_learning_rate = learning_rate_change * learning_rate_state

        # Parameter update, shaped like the gradient values.
        update = tf.reshape(new_learning_rate * delta,
                            tf.shape(grad_values))

        new_state = {
            "rms": rms,
            "learning_rate": new_learning_rate,
            "rnn": rnn_state,
            "decay": decay,
        }

        if isinstance(grad, tf.IndexedSlices):
            # Only some slices were touched: merge the update and each new
            # state entry back into the corresponding full-size value.
            update = utils.stack_tensor(update, grad_indices, param,
                                        param_shape[:1])
            for key in ("rms", "learning_rate", "rnn", "decay"):
                new_state[key] = utils.update_slices(
                    new_state[key], grad_indices, state[key], param_shape)

        new_param = param - update

    return new_param, new_state
def _compute_update(self, param, grad, state):
    """Applies one optimizer step to a parameter.

    Args:
      param: tensor of parameters.
      grad: gradient for param; either a dense tensor or a
        tf.IndexedSlices, in which case the update and the new state are
        merged back into the full-size values.
      state: dictionary of optimizer state; the "rms", "learning_rate",
        "rnn" and "decay" entries are read directly for the
        tf.IndexedSlices case.

    Returns:
      new_param: the parameters after subtracting the computed update.
      new_state: dictionary with the updated "rms", "learning_rate", "rnn"
        and "decay" state.
    """
    with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope:
        # Variables are created on the first call and reused afterwards.
        if not self.reuse_vars:
            self.reuse_vars = True
        else:
            scope.reuse_variables()

        param_shape = tf.shape(param)
        extracted = self._extract_gradients_and_internal_state(
            grad, state, param_shape)
        (grad_values, decay_state, rms_state, rnn_state,
         learning_rate_state, grad_indices) = extracted

        # Vectorize and scale the gradients.
        grad_scaled, rms = utils.rms_scaling(grad_values, decay_state,
                                             rms_state)

        # Run the RNN cell on the scaled gradients.
        cell_state_in = self._unpack_rnn_state_into_tuples(rnn_state)
        rnn_output, cell_state_out = self.cell(grad_scaled, cell_state_in)
        rnn_state = self._pack_tuples_into_rnn_state(cell_state_out)

        # Update direction: a linear projection of the RNN output.
        delta = utils.project(rnn_output, self.update_weights)

        # New decay: a sigmoid-activated affine projection of the output.
        decay = utils.project(rnn_output, self.decay_weights,
                              bias=self.decay_bias,
                              activation=tf.nn.sigmoid)

        # Multiplicative learning-rate change; the 2x sigmoid bounds it.
        learning_rate_change = 2. * utils.project(rnn_output,
                                                  self.lr_weights,
                                                  bias=self.lr_bias,
                                                  activation=tf.nn.sigmoid)
        new_learning_rate = learning_rate_change * learning_rate_state

        # Parameter update, shaped like the gradient values.
        update = tf.reshape(new_learning_rate * delta,
                            tf.shape(grad_values))

        new_state = {
            "rms": rms,
            "learning_rate": new_learning_rate,
            "rnn": rnn_state,
            "decay": decay,
        }

        if isinstance(grad, tf.IndexedSlices):
            # Only some slices were touched: merge the update and each new
            # state entry back into the corresponding full-size value.
            update = utils.stack_tensor(update, grad_indices, param,
                                        param_shape[:1])
            for key in ("rms", "learning_rate", "rnn", "decay"):
                new_state[key] = utils.update_slices(
                    new_state[key], grad_indices, state[key], param_shape)

        new_param = param - update

    return new_param, new_state