Ejemplo n.º 1
0
    def next(self, state, weights, gradients):
        """Performs a single Yogi-style optimization step.

        Args:
          state: Mapping read via the module's key constants for the learning
            rate, beta_1, beta_2, epsilon, the step counter, the first-moment
            accumulator and the second-moment preconditioner.
          weights: Nested structure of weight tensors to update.
          gradients: Nested structure matching `weights`. It is first passed
            through `optimizer.handle_indexed_slices_gradients` (presumably to
            densify `tf.IndexedSlices`; confirm against that helper).

        Returns:
          A `(updated_state, updated_weights)` pair. The returned state keeps
          the hyperparameters unchanged, holds the step counter incremented by
          one, and carries the updated accumulator and preconditioner.
        """
        grads = optimizer.handle_indexed_slices_gradients(gradients)
        optimizer.check_weights_gradients_match(weights, grads)

        learning_rate = state[optimizer.LEARNING_RATE_KEY]
        b1 = state[_BETA_1_KEY]
        b2 = state[_BETA_2_KEY]
        eps = state[_EPSILON_KEY]
        t = state[_STEP_KEY] + 1
        m_prev = state[_ACCUMULATOR_KEY]
        v_prev = state[_PRECONDITIONER_KEY]
        optimizer.check_weights_state_match(weights, m_prev, 'accumulator')
        optimizer.check_weights_state_match(weights, v_prev,
                                            'preconditioner')

        def first_moment(m, g):
            # Exponential moving average of the gradient; algebraically equal
            # to b1 * m + (1 - b1) * g.
            return m + (g - m) * (1 - b1)

        def second_moment(v, g):
            # Yogi rule: move v toward g**2 by an additive step whose
            # direction is sign(g**2 - v), instead of a moving average.
            g_sq = tf.math.square(g)
            return v + (1 - b2) * tf.sign(g_sq - v) * g_sq

        m_new = tf.nest.map_structure(first_moment, m_prev, grads)
        v_new = tf.nest.map_structure(second_moment, v_prev, grads)

        # Fold both bias corrections into one scalar step size:
        # lr * sqrt(1 - b2**t) / (1 - b1**t).
        t_float = tf.cast(t, tf.float32)
        step_size = learning_rate * tf.math.sqrt(
            (1 - tf.math.pow(b2, t_float))) / (1 - tf.math.pow(b1, t_float))

        def apply_update(w, unused_g, m, v):
            # The raw gradient only enters through m; it is still passed so
            # tf.nest verifies that all four structures agree.
            return w - step_size * m / (tf.math.sqrt(v) + eps)

        new_weights = tf.nest.map_structure(apply_update, weights, grads,
                                            m_new, v_new)

        updated_state = collections.OrderedDict()
        updated_state[optimizer.LEARNING_RATE_KEY] = learning_rate
        updated_state[_BETA_1_KEY] = b1
        updated_state[_BETA_2_KEY] = b2
        updated_state[_EPSILON_KEY] = eps
        updated_state[_STEP_KEY] = t
        updated_state[_ACCUMULATOR_KEY] = m_new
        updated_state[_PRECONDITIONER_KEY] = v_new
        return updated_state, new_weights
Ejemplo n.º 2
0
  def next(self, state, weights, gradients):
    """Performs a single Adagrad step.

    Args:
      state: Mapping holding the learning rate, epsilon and the running
        sum-of-squared-gradients preconditioner (under the module's keys).
      weights: Nested structure of weight tensors to update.
      gradients: Nested structure matching `weights`. It is first passed
        through `optimizer.handle_indexed_slices_gradients` (presumably to
        densify `tf.IndexedSlices`; confirm against that helper).

    Returns:
      A `(updated_state, updated_weights)` pair. Only the preconditioner in
      the state changes; learning rate and epsilon are passed through.
    """
    grads = optimizer.handle_indexed_slices_gradients(gradients)
    optimizer.check_weights_gradients_match(weights, grads)
    learning_rate = state[optimizer.LEARNING_RATE_KEY]
    eps = state[_EPSILON_KEY]
    sum_sq = state[_PRECONDITIONER_KEY]
    optimizer.check_weights_state_match(weights, sum_sq, 'preconditioner')

    def accumulate(acc, g):
      # Running (undecayed) sum of squared gradients.
      return acc + tf.math.square(g)

    def descend(w, g, acc):
      # Epsilon is added inside the square root, not after it.
      return w - learning_rate * g / tf.math.sqrt(acc + eps)

    new_sum_sq = tf.nest.map_structure(accumulate, sum_sq, grads)
    new_weights = tf.nest.map_structure(descend, weights, grads, new_sum_sq)

    new_state = collections.OrderedDict()
    new_state[optimizer.LEARNING_RATE_KEY] = learning_rate
    new_state[_EPSILON_KEY] = eps
    new_state[_PRECONDITIONER_KEY] = new_sum_sq
    return new_state, new_weights
Ejemplo n.º 3
0
  def next(self, state, weights, gradients):
    """Performs a single SGD step, using momentum when the state has it.

    If `state` contains the momentum key, an accumulator is updated as
    `momentum * accumulator + gradient` and the weights move along it;
    otherwise the weights move along the raw gradient.

    Args:
      state: Mapping holding the learning rate and, optionally, the momentum
        coefficient plus its accumulator (under the module's keys).
      weights: Nested structure of weight tensors to update.
      gradients: Nested structure matching `weights`. It is first passed
        through `optimizer.handle_indexed_slices_gradients` (presumably to
        densify `tf.IndexedSlices`; confirm against that helper).

    Returns:
      A `(updated_state, updated_weights)` pair.
    """
    grads = optimizer.handle_indexed_slices_gradients(gradients)
    optimizer.check_weights_gradients_match(weights, grads)
    learning_rate = state[optimizer.LEARNING_RATE_KEY]

    if _MOMENTUM_KEY in state:
      beta = state[_MOMENTUM_KEY]
      velocity = state[_ACCUMULATOR_KEY]
      optimizer.check_weights_state_match(weights, velocity, 'accumulator')
      new_velocity = tf.nest.map_structure(
          lambda v, g: beta * v + g, velocity, grads)
      new_weights = tf.nest.map_structure(
          lambda w, v: w - learning_rate * v, weights, new_velocity)
      new_state = collections.OrderedDict()
      new_state[optimizer.LEARNING_RATE_KEY] = learning_rate
      new_state[_MOMENTUM_KEY] = beta
      new_state[_ACCUMULATOR_KEY] = new_velocity
      return new_state, new_weights

    # No momentum configured: plain gradient descent, state is just the lr.
    new_weights = tf.nest.map_structure(
        lambda w, g: w - learning_rate * g, weights, grads)
    new_state = collections.OrderedDict()
    new_state[optimizer.LEARNING_RATE_KEY] = learning_rate
    return new_state, new_weights
Ejemplo n.º 4
0
 def test_check_weights_state_match(self, weights, state):
     """Checks that a ValueError naming the state ('foo') is raised.

     NOTE(review): `weights`/`state` presumably come from a parameterization
     decorator not visible here, supplying mismatched pairs — confirm.
     """
     self.assertRaisesRegex(ValueError, 'foo',
                            optimizer.check_weights_state_match,
                            weights, state, name='foo')