  def _compute_rnn_state_projections(self, state, new_param_state,
                                     grads_scaled):
    """Computes the RNN state-based updates to parameters and update steps."""
    # Compute the update direction (a linear projection of the RNN output).
    update_weights = self.update_weights

    update_delta = utils.project(new_param_state, update_weights)
    if self.use_gradient_shortcut:
      # Include an affine projection of just the direction of the gradient
      # so that RNN hidden states are freed up to store more complex
      # functions of the gradient and other parameters.
      grads_scaled_tensor = tf.concat(grads_scaled, 1)
      update_delta += utils.affine(grads_scaled_tensor, 1,
                                   scope="GradsToDelta",
                                   include_bias=False,
                                   vec_mean=1. / len(grads_scaled),
                                   random_seed=self.random_seed)
    if self.dynamic_output_scale:
      # Normalize the update direction to unit RMS; the step magnitude is
      # then controlled entirely by the learned learning rate.
      update_delta /= tf.sqrt(tf.reduce_mean(update_delta ** 2) + 1e-16)

    if self.use_attention:
      attention_weights = self.attention_weights
      attention_delta = utils.project(new_param_state,
                                      attention_weights)
      if self.use_gradient_shortcut:
        attention_delta += utils.affine(grads_scaled_tensor, 1,
                                        scope="GradsToAttnDelta",
                                        include_bias=False,
                                        vec_mean=1. / len(grads_scaled),
                                        random_seed=self.random_seed)
      if self.dynamic_output_scale:
        # Normalize the attention delta to unit RMS as well.
        attention_delta /= tf.sqrt(
            tf.reduce_mean(attention_delta ** 2) + 1e-16)
    else:
      attention_delta = None

    # The updated decay is an affine projection of the hidden state.
    scl_decay = utils.project(new_param_state, self.scl_decay_weights,
                              bias=self.scl_decay_bias,
                              activation=tf.nn.sigmoid)
    # This is only used when learnable_decay is set and num_gradient_scales > 1.
    inp_decay = utils.project(new_param_state, self.inp_decay_weights,
                              bias=self.inp_decay_bias,
                              activation=tf.nn.sigmoid)

    # Also update the learning rate.
    lr_param, lr_attend, new_log_lr = self._compute_new_learning_rate(
        state, new_param_state)

    update_step = tf.reshape(lr_param * update_delta,
                             state["true_param"].get_shape())

    return (scl_decay, inp_decay, new_log_lr, update_step, lr_attend,
            attention_delta)
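
# Hedged sketch, not from the source above: the dynamic_output_scale branch
# rescales the update direction to unit RMS, so the magnitude of each step is
# carried entirely by the learned learning rate. `delta` is a made-up tensor.
import tensorflow as tf

delta = tf.random.normal([10, 1]) * 37.0  # arbitrary input scale
delta /= tf.sqrt(tf.reduce_mean(delta ** 2) + 1e-16)  # now unit RMS
# tf.reduce_mean(delta ** 2) evaluates to ~1.0 regardless of the input scale;
# the 1e-16 epsilon keeps the division safe for an all-zero delta.
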
  def _compute_new_learning_rate(self, state, new_param_state):
    if self.dynamic_output_scale:
      # Compute the change in learning rate (an affine projection of the
      # RNN state, passed through a sigmoid or log depending on flags).
      # Update the learning rate, w/ momentum.
      lr_change = utils.project(new_param_state, self.lr_weights,
                                bias=self.lr_bias)
      step_log_lr = state["log_learning_rate"] + lr_change

      # Clip the log learning rate to the max_log_lr flag at the top end and
      # to -33 at the bottom.

      # Check out this hack: we want to be able to compute the gradient
      # of the downstream result w.r.t. the lr weights and bias, even if the
      # value of step_log_lr is outside the clip range. So we clip,
      # subtract off step_log_lr, and wrap all that in a stop_gradient so
      # TF never tries to take the gradient of the clip... or the
      # subtraction. Then we add BACK step_log_lr so that downstream still
      # receives the clipped value. But the GRADIENT of step_log_lr will
      # be the gradient of the unclipped value, which we added back in
      # after stop_gradients.
      step_log_lr += tf.stop_gradient(
          tf.clip_by_value(step_log_lr, -33, self.max_log_lr)
          - step_log_lr)

      lr_momentum_logit = tf.get_variable(
          "learning_rate_momentum_logit",
          initializer=FLAGS.learning_rate_momentum_logit_init)
      lrm = tf.nn.sigmoid(lr_momentum_logit)
      new_log_lr = (lrm * state["log_learning_rate"] +
                    (1. - lrm) * step_log_lr)
      param_stepsize_offset = tf.get_variable("param_stepsize_offset",
                                              initializer=-1.)
      lr_param = tf.exp(step_log_lr + param_stepsize_offset)
      lr_attend = tf.exp(step_log_lr) if self.use_attention else lr_param
    else:
      # Dynamic output scale is off: the LR param is a 2x-sigmoid projection
      # of the RNN state, bounded in (0, 2) and equal to 1 when the
      # projection is zero.
      lr_param = 2. * utils.project(new_param_state, self.lr_weights,
                                    bias=self.lr_bias,
                                    activation=tf.nn.sigmoid)
      new_log_lr = None
      lr_attend = lr_param

    return lr_param, lr_attend, new_log_lr
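
# Hedged sketch, not from the source above: the clipping "hack" documented in
# _compute_new_learning_rate is a straight-through estimator. The forward pass
# sees the clipped value, while the backward pass sees the identity, because
# the (clip(x) - x) correction is wrapped in tf.stop_gradient. Here 10. is a
# stand-in for self.max_log_lr.
import tensorflow as tf

x = tf.constant(50.0)  # stand-in for a step_log_lr above the clip range
clipped = x + tf.stop_gradient(tf.clip_by_value(x, -33., 10.) - x)
# clipped evaluates to 10.0, but d(clipped)/dx == 1, so gradients still reach
# the lr weights and bias even when the log learning rate saturates.
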
  def _compute_update(self, param, grad, state):
    """Update parameters given the gradient and state.

    Args:
      param: tensor of parameters
      grad: tensor of gradients with the same shape as param
      state: a dictionary containing any state for the optimizer

    Returns:
      updated_param: updated parameters
      updated_state: updated state variables in a dictionary
    """

    with tf.variable_scope(opt.OPTIMIZER_SCOPE) as scope:

      if self.reuse_vars:
        scope.reuse_variables()
      else:
        self.reuse_vars = True

      param_shape = tf.shape(param)

      (grad_values, decay_state, rms_state, rnn_state, learning_rate_state,
       grad_indices) = self._extract_gradients_and_internal_state(
           grad, state, param_shape)

      # Vectorize and scale the gradients.
      grad_scaled, rms = utils.rms_scaling(grad_values, decay_state, rms_state)

      # Apply the RNN update.
      rnn_state_tuples = self._unpack_rnn_state_into_tuples(rnn_state)
      rnn_output, rnn_state_tuples = self.cell(grad_scaled, rnn_state_tuples)
      rnn_state = self._pack_tuples_into_rnn_state(rnn_state_tuples)

      # Compute the update direction (a linear projection of the RNN output).
      delta = utils.project(rnn_output, self.update_weights)

      # The updated decay is an affine projection of the hidden state.
      decay = utils.project(rnn_output, self.decay_weights,
                            bias=self.decay_bias, activation=tf.nn.sigmoid)

      # Compute the change in learning rate (an affine projection of the RNN
      # state, passed through a 2x sigmoid, so the change is bounded).
      learning_rate_change = 2. * utils.project(rnn_output, self.lr_weights,
                                                bias=self.lr_bias,
                                                activation=tf.nn.sigmoid)

      # Update the learning rate.
      new_learning_rate = learning_rate_change * learning_rate_state

      # Apply the update to the parameters.
      update = tf.reshape(new_learning_rate * delta, tf.shape(grad_values))

      if isinstance(grad, tf.IndexedSlices):
        update = utils.stack_tensor(update, grad_indices, param,
                                    param_shape[:1])
        rms = utils.update_slices(rms, grad_indices, state["rms"], param_shape)
        new_learning_rate = utils.update_slices(new_learning_rate, grad_indices,
                                                state["learning_rate"],
                                                param_shape)
        rnn_state = utils.update_slices(rnn_state, grad_indices, state["rnn"],
                                        param_shape)
        decay = utils.update_slices(decay, grad_indices, state["decay"],
                                    param_shape)

      new_param = param - update

      # Collect the update and new state.
      new_state = {
          "rms": rms,
          "learning_rate": new_learning_rate,
          "rnn": rnn_state,
          "decay": decay,
      }

    return new_param, new_state
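
# Hedged sketch, not from the source above: _compute_update adjusts the
# learning rate multiplicatively, with the per-step change bounded in (0, 2)
# by the 2x sigmoid, so one step can at most double the rate or shrink it
# toward zero. `logit` stands in for the affine projection of the RNN output.
import tensorflow as tf

old_lr = tf.constant(0.1)
logit = tf.constant(0.0)            # projection output; 0 leaves the rate unchanged
change = 2. * tf.nn.sigmoid(logit)  # bounded in (0, 2); equals 1 at logit == 0
new_lr = change * old_lr            # multiplicative update, as in the code above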