def _compute_rnn_state_projections(self, state, new_param_state,
                                   grads_scaled):
  """Computes the RNN state-based updates to parameters and update steps."""
  # Compute the update direction (a linear projection of the RNN output).
  update_weights = self.update_weights

  update_delta = utils.project(new_param_state, update_weights)
  if self.use_gradient_shortcut:
    # Include an affine projection of just the direction of the gradient
    # so that RNN hidden states are freed up to store more complex
    # functions of the gradient and other parameters.
    grads_scaled_tensor = tf.concat([g for g in grads_scaled], 1)
    update_delta += utils.affine(grads_scaled_tensor, 1,
                                 scope="GradsToDelta",
                                 include_bias=False,
                                 vec_mean=1. / len(grads_scaled),
                                 random_seed=self.random_seed)
  if self.dynamic_output_scale:
    # Rescale the update direction to have unit RMS magnitude.
    denom = tf.sqrt(tf.reduce_mean(update_delta**2) + 1e-16)
    update_delta /= denom

  if self.use_attention:
    attention_weights = self.attention_weights
    attention_delta = utils.project(new_param_state, attention_weights)
    if self.use_gradient_shortcut:
      attention_delta += utils.affine(grads_scaled_tensor, 1,
                                      scope="GradsToAttnDelta",
                                      include_bias=False,
                                      vec_mean=1. / len(grads_scaled),
                                      random_seed=self.random_seed)
    if self.dynamic_output_scale:
      attention_delta /= tf.sqrt(
          tf.reduce_mean(attention_delta**2) + 1e-16)
  else:
    attention_delta = None

  # The updated decay is an affine projection of the hidden state.
  scl_decay = utils.project(new_param_state, self.scl_decay_weights,
                            bias=self.scl_decay_bias,
                            activation=tf.nn.sigmoid)
  # This is only used if learnable_decay and num_gradient_scales > 1.
  inp_decay = utils.project(new_param_state, self.inp_decay_weights,
                            bias=self.inp_decay_bias,
                            activation=tf.nn.sigmoid)

  # Also update the learning rate.
  lr_param, lr_attend, new_log_lr = self._compute_new_learning_rate(
      state, new_param_state)

  # Scale the update direction by the learned per-parameter learning rate.
  update_step = tf.reshape(lr_param * update_delta,
                           state["true_param"].get_shape())

  return (scl_decay, inp_decay, new_log_lr,
          update_step, lr_attend, attention_delta)
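# Hedged sketch (not the repository implementation): `utils.project` and
# `utils.affine` are helpers defined elsewhere in this codebase. The stand-in
# below only illustrates the assumed semantics of `project` -- a plain matrix
# multiply with optional bias and activation -- so the projections above are
# easier to read. `utils.affine` is assumed to be similar, except that it
# creates its own weight variable under the given scope. Assumes
# `import tensorflow as tf` (TF 1.x), as in the surrounding file.
def _project_sketch(inputs, weights, bias=None, activation=tf.identity):
  """Hypothetical stand-in: activation(inputs @ weights [+ bias])."""
  output = tf.matmul(inputs, weights)
  if bias is not None:
    output += bias
  return activation(output)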
def __call__(self, inputs, state, bias=None):
  # Split the injected bias vector into a bias for the r, u, and c updates.
  if bias is None:
    bias = tf.zeros((1, 3))
  r_bias, u_bias, c_bias = tf.split(bias, 3, 1)

  with tf.variable_scope(type(self).__name__):  # "BiasGRUCell"
    with tf.variable_scope("gates"):  # Reset gate and update gate.
      proj = utils.affine([inputs, state], 2 * self._num_units,
                          scale=self._scale,
                          bias_init=self._gate_bias_init,
                          random_seed=self._random_seed)
      r_lin, u_lin = tf.split(proj, 2, 1)
      r, u = tf.nn.sigmoid(r_lin + r_bias), tf.nn.sigmoid(u_lin + u_bias)

    with tf.variable_scope("candidate"):
      proj = utils.affine([inputs, r * state], self._num_units,
                          scale=self._scale,
                          random_seed=self._random_seed)
      c = self._activation(proj + c_bias)

    new_h = u * state + (1 - u) * c

  return new_h, new_h
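# Hedged usage sketch (illustrative names only, not from this file): the bias
# injected into BiasGRUCell is a single tensor of width 3 * num_units that
# __call__ splits into r/u/c pieces. Here a hypothetical projection of a
# higher-level RNN state plays that role; the constructor is assumed to take
# the number of units, and TF1-style graph construction is assumed.
def _bias_gru_usage_sketch():
  num_units = 8
  cell = BiasGRUCell(num_units)               # per-parameter GRU cell
  inputs = tf.zeros((4, 2))                   # e.g. 4 parameters, 2 features
  state = tf.zeros((4, num_units))            # previous per-parameter state
  upper_state = tf.zeros((1, 16))             # higher-level RNN output
  bias_proj = tf.get_variable("bias_proj", (16, 3 * num_units))
  bias = tf.matmul(upper_state, bias_proj)    # width 3 * num_units
  new_h, _ = cell(inputs, state, bias=bias)   # bias broadcasts over rows
  return new_h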
def _update_rnn_cells(self, state, global_state, rnn_input_tensor,
                      use_additional_features):
  """Updates the component RNN cells with the given state and tensor.

  Args:
    state: The current state of the optimizer.
    global_state: The current global RNN state.
    rnn_input_tensor: The input tensor to the RNN.
    use_additional_features: Whether the rnn input tensor contains additional
        features beyond the scaled gradients (affects whether the rnn input
        tensor is used as input to the RNN).

  Returns:
    layer_state: The new state of the per-tensor RNN.
    new_param_state: The new state of the per-parameter RNN.
  """
  # lowest level (per parameter)
  #   input -> gradient for this parameter
  #   bias -> output from the layer RNN
  with tf.variable_scope("Layer0_RNN"):
    total_bias = None
    if self.num_layers > 1:
      sz = 3 * self.cells[0].state_size  # size of the concatenated bias
      param_bias = utils.affine([state["layer"]], sz,
                                scope="Param/Affine",
                                scale=FLAGS.hrnn_affine_scale,
                                random_seed=self.random_seed)
      total_bias = param_bias
      if self.num_layers == 3:
        global_bias = utils.affine(global_state, sz,
                                   scope="Global/Affine",
                                   scale=FLAGS.hrnn_affine_scale,
                                   random_seed=self.random_seed)
        total_bias += global_bias

    new_param_state, _ = self.cells[0](
        rnn_input_tensor, state["parameter"], bias=total_bias)

  if self.num_layers > 1:
    # middle level (per layer)
    #   input -> average hidden state from each parameter in this layer
    #   bias -> output from the RNN at the global level
    with tf.variable_scope("Layer1_RNN"):
      if not use_additional_features:
        # Restore old behavior and only add the mean of the new params.
        layer_input = tf.reduce_mean(new_param_state, 0, keep_dims=True)
      else:
        layer_input = tf.reduce_mean(
            tf.concat((new_param_state, rnn_input_tensor), 1), 0,
            keep_dims=True)
      if self.num_layers == 3:
        sz = 3 * self.cells[1].state_size
        layer_bias = utils.affine(global_state, sz,
                                  scale=FLAGS.hrnn_affine_scale,
                                  random_seed=self.random_seed)
        layer_state, _ = self.cells[1](
            layer_input, state["layer"], bias=layer_bias)
      else:
        layer_state, _ = self.cells[1](layer_input, state["layer"])
  else:
    layer_state = None

  return layer_state, new_param_state
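# Hedged sketch of the per-tensor `state` dictionary assumed by the methods
# above, inferred only from the keys they read ("parameter", "layer",
# "true_param"). Shapes are illustrative; the real state is built elsewhere
# in this file.
def _state_sketch(num_params, param_units, layer_units, param_shape):
  return {
      "parameter": tf.zeros((num_params, param_units)),  # per-parameter RNN
      "layer": tf.zeros((1, layer_units)),                # per-tensor RNN
      "true_param": tf.zeros(param_shape),                # the actual weights
  }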