def next(self, state, weights, gradients):
    """Take one gradient-descent step, applying momentum when configured.

    When `self._momentum` is `None`, every weight becomes
    `w - self._lr * g` and `state` is handed back untouched. Otherwise
    `state` is mapped leaf-by-leaf with `gradients` as a momentum term: it
    is refreshed to `self._momentum * m + g`, and the weights then move by
    `self._lr` times the refreshed term.

    Args:
      state: Optimizer state. Only read when momentum is enabled.
      weights: Nested structure of the current weights.
      gradients: Nested structure of gradients; validated against `weights`
        by `optimizer.check_weights_gradients_match`.

    Returns:
      A `(updated_state, updated_weights)` tuple.
    """
    optimizer.check_weights_gradients_match(weights, gradients)

    if self._momentum is not None:
        # Validation is delegated to the module-level helper.
        _check_momentum_matches_weights(state, weights)

        def refresh_velocity(velocity, grad):
            return self._momentum * velocity + grad

        def descend_along_velocity(weight, velocity):
            return weight - self._lr * velocity

        new_state = tf.nest.map_structure(refresh_velocity, state, gradients)
        new_weights = tf.nest.map_structure(
            descend_along_velocity, weights, new_state)
        return new_state, new_weights

    def descend_along_gradient(weight, grad):
        return weight - self._lr * grad

    new_weights = tf.nest.map_structure(
        descend_along_gradient, weights, gradients)
    # Without momentum there is nothing to track, so the state is unchanged.
    return state, new_weights
def next(self, state, weights, gradients):
    """Apply one bias-corrected step using a sign-driven preconditioner.

    All hyperparameters (learning rate, both betas, epsilon) and the running
    statistics are read from `state`. The step counter is incremented, the
    accumulator is moved towards the gradient by a factor `1 - beta_1`, and
    the preconditioner is shifted by `(1 - beta_2) * g**2` in the direction
    of `sign(g**2 - s)`. The learning rate is rescaled by
    `sqrt(1 - beta_2**step) / (1 - beta_1**step)` before the weights are
    updated with `a / (sqrt(s) + epsilon)`.

    Args:
      state: Mapping holding the hyperparameters, step count, accumulator
        and preconditioner under the module's key constants.
      weights: Nested structure of the current weights.
      gradients: Nested structure of gradients; first passed through
        `optimizer.handle_indexed_slices_gradients`.

    Returns:
      A `(updated_state, updated_weights)` tuple, where `updated_state` is a
      `collections.OrderedDict` with the same keys that were read.
    """
    gradients = optimizer.handle_indexed_slices_gradients(gradients)
    optimizer.check_weights_gradients_match(weights, gradients)

    lr = state[optimizer.LEARNING_RATE_KEY]
    beta_1 = state[_BETA_1_KEY]
    beta_2 = state[_BETA_2_KEY]
    epsilon = state[_EPSILON_KEY]
    step = state[_STEP_KEY] + 1
    first_moment = state[_ACCUMULATOR_KEY]
    second_moment = state[_PRECONDITIONER_KEY]
    optimizer.check_weights_state_match(weights, first_moment, 'accumulator')
    optimizer.check_weights_state_match(weights, second_moment,
                                        'preconditioner')

    def blend_first_moment(acc, grad):
        return acc + (grad - acc) * (1 - beta_1)

    def shift_second_moment(pre, grad):
        grad_sq = tf.math.square(grad)
        direction = tf.sign(grad_sq - pre)
        return pre + (1 - beta_2) * direction * grad_sq

    new_first_moment = tf.nest.map_structure(
        blend_first_moment, first_moment, gradients)
    new_second_moment = tf.nest.map_structure(
        shift_second_moment, second_moment, gradients)

    # Bias correction: both beta powers use the already-incremented step.
    step_as_float = tf.cast(step, tf.float32)
    second_correction = tf.math.sqrt(1 - tf.math.pow(beta_2, step_as_float))
    first_correction = 1 - tf.math.pow(beta_1, step_as_float)
    normalized_lr = lr * second_correction / first_correction

    def descend(weight, _grad, acc, pre):
        # The raw gradient is mapped alongside but does not enter the formula.
        return weight - normalized_lr * acc / (tf.math.sqrt(pre) + epsilon)

    new_weights = tf.nest.map_structure(
        descend, weights, gradients, new_first_moment, new_second_moment)

    new_state = collections.OrderedDict()
    new_state[optimizer.LEARNING_RATE_KEY] = lr
    new_state[_BETA_1_KEY] = beta_1
    new_state[_BETA_2_KEY] = beta_2
    new_state[_EPSILON_KEY] = epsilon
    new_state[_STEP_KEY] = step
    new_state[_ACCUMULATOR_KEY] = new_first_moment
    new_state[_PRECONDITIONER_KEY] = new_second_moment
    return new_state, new_weights
def next(self, state, weights, gradients):
    """Apply one step scaled by an accumulated squared-gradient term.

    The preconditioner stored in `state` grows by the element-wise square of
    each gradient, and every weight then moves by
    `lr * g / sqrt(preconditioner + epsilon)`. Learning rate and epsilon are
    read from `state` and written back unchanged.

    Args:
      state: Mapping holding the learning rate, epsilon and preconditioner
        under the module's key constants.
      weights: Nested structure of the current weights.
      gradients: Nested structure of gradients; first passed through
        `optimizer.handle_indexed_slices_gradients`.

    Returns:
      A `(updated_state, updated_weights)` tuple, where `updated_state` is a
      `collections.OrderedDict`.
    """
    gradients = optimizer.handle_indexed_slices_gradients(gradients)
    optimizer.check_weights_gradients_match(weights, gradients)

    lr = state[optimizer.LEARNING_RATE_KEY]
    epsilon = state[_EPSILON_KEY]
    squared_grads = state[_PRECONDITIONER_KEY]
    optimizer.check_weights_state_match(weights, squared_grads,
                                        'preconditioner')

    def accumulate(total, grad):
        return total + tf.math.square(grad)

    def descend(weight, grad, total):
        # Epsilon is added inside the square root, not to its result.
        return weight - lr * grad / tf.math.sqrt(total + epsilon)

    new_squared_grads = tf.nest.map_structure(
        accumulate, squared_grads, gradients)
    new_weights = tf.nest.map_structure(
        descend, weights, gradients, new_squared_grads)

    new_state = collections.OrderedDict()
    new_state[optimizer.LEARNING_RATE_KEY] = lr
    new_state[_EPSILON_KEY] = epsilon
    new_state[_PRECONDITIONER_KEY] = new_squared_grads
    return new_state, new_weights
def next(self, state, weights, gradients):
    """Take one gradient-descent step, using momentum if `state` carries it.

    Whether momentum is applied is decided by the presence of
    `_MOMENTUM_KEY` in `state`. Without it, each weight becomes `w - lr * g`
    and the returned state holds only the learning rate. With it, the
    accumulator is refreshed to `momentum * a + g` and the weights move by
    `lr` times the refreshed accumulator.

    Args:
      state: Mapping holding the learning rate and, optionally, the momentum
        value and accumulator under the module's key constants.
      weights: Nested structure of the current weights.
      gradients: Nested structure of gradients; first passed through
        `optimizer.handle_indexed_slices_gradients`.

    Returns:
      A `(updated_state, updated_weights)` tuple, where `updated_state` is a
      `collections.OrderedDict`.
    """
    gradients = optimizer.handle_indexed_slices_gradients(gradients)
    optimizer.check_weights_gradients_match(weights, gradients)
    lr = state[optimizer.LEARNING_RATE_KEY]

    if _MOMENTUM_KEY in state:
        momentum = state[_MOMENTUM_KEY]
        velocity = state[_ACCUMULATOR_KEY]
        optimizer.check_weights_state_match(weights, velocity, 'accumulator')

        def refresh_velocity(acc, grad):
            return momentum * acc + grad

        def descend_along_velocity(weight, acc):
            return weight - lr * acc

        new_velocity = tf.nest.map_structure(
            refresh_velocity, velocity, gradients)
        new_weights = tf.nest.map_structure(
            descend_along_velocity, weights, new_velocity)

        new_state = collections.OrderedDict()
        new_state[optimizer.LEARNING_RATE_KEY] = lr
        new_state[_MOMENTUM_KEY] = momentum
        new_state[_ACCUMULATOR_KEY] = new_velocity
        return new_state, new_weights

    def descend_along_gradient(weight, grad):
        return weight - lr * grad

    new_weights = tf.nest.map_structure(
        descend_along_gradient, weights, gradients)
    new_state = collections.OrderedDict()
    new_state[optimizer.LEARNING_RATE_KEY] = lr
    return new_state, new_weights
def test_check_weights_gradients_match(self, weights, gradients):
    """Expect the weights/gradients check to reject the supplied pair.

    Args:
      weights: Weights structure supplied by the test's caller.
      gradients: Gradients structure that should not be accepted alongside
        `weights`.
    """
    # Callable form of assertRaises: the check is invoked with both arguments
    # and must raise ValueError for the test to pass.
    self.assertRaises(
        ValueError,
        optimizer.check_weights_gradients_match,
        weights,
        gradients,
    )