def Apply(self, lr, var_grad):
  """Applies the gradient to the variable.

  Args:
    lr: A scalar. The base learning rate.
    var_grad: A `.NestedMap` of (var, grad) pairs.

  Returns:
    The variable update op.
  """
  optimizer = self.GetOptimizer(lr)

  def _Apply():
    # Flatten() yields (var, grad) pairs; apply_gradients wants (grad, var).
    grads_and_vars = [(g, v) for v, g in var_grad.Flatten()]
    return optimizer.apply_gradients(grads_and_vars, name='meta_backprop')

  if py_utils.use_resource_variables():
    # Many optimizers, e.g., Adam, Adagrad, etc., create variables. We need
    # to ensure name scope and variable scope are cleared. Otherwise,
    # tpu.batch_parallel does not work.
    reuse = (
        tf.AUTO_REUSE if py_utils.GetOpportunisticVariableReuse() else False)
    with tf.name_scope(None):
      with tf.variable_scope(
          tf.VariableScope(use_resource=True, reuse=reuse)):
        var_update_op = _Apply()
  else:
    var_update_op = _Apply()
  self.AddSummary(lr, optimizer, var_grad)
  return var_update_op
def Apply(self, lr, var_grad):
  """Applies the gradient to the variable.

  Args:
    lr: A scalar. The base learning rate.
    var_grad: A `.NestedMap` of (var, grad) pairs.

  Returns:
    The variable update op.
  """
  optimizer = self.GetOptimizer(lr)

  def _Apply():
    # When p.use_bf16_gradients_ar is set, gradients are cast back to
    # float32 before being handed to the optimizer. Flatten() yields
    # (var, grad) pairs; apply_gradients wants (grad, var).
    cast_to_fp32 = self.params.use_bf16_gradients_ar
    grads_and_vars = [
        (tf.cast(g, tf.float32) if cast_to_fp32 else g, v)
        for v, g in var_grad.Flatten()
    ]
    return optimizer.apply_gradients(grads_and_vars, name='meta_backprop')

  if py_utils.use_resource_variables():
    # Many optimizers, e.g., Adam, Adagrad, etc., create variables. We need
    # to ensure name scope and variable scope are cleared. Otherwise,
    # tpu.batch_parallel does not work.
    with tf.name_scope(None):
      with tf.variable_scope(
          tf.VariableScope(
              use_resource=True, reuse=self.VarReuseForSlotVars())):
        var_update_op = _Apply()
  else:
    var_update_op = _Apply()
  if self.params.add_summary_in_apply:
    self.AddSummary(lr, optimizer, var_grad)
  return var_update_op
def Apply(self, lr, var_grad):
  """For each optimizer, apply the gradient to the variable.

  Each (var, grad) pair is routed to the optimizer whose regex key matches
  the variable name; unmatched variables fall through to
  'default_optimizer'.

  Args:
    lr: A scalar. The base learning rate.
    var_grad: A `.NestedMap` of (var, grad) pairs.

  Returns:
    The variable update op.

  Raises:
    Exception: When more than one regex matches the same variable
      (overlapping regexes). Variables matched by no regex are not an
      error; they are handled by 'default_optimizer'.
  """
  # Override inherited GetOptimizer even though learning rate is unused.
  tf_optimizer_map = self.GetOptimizer(0)
  # One (grad, var) bucket per configured regex key.
  var_grad_map = {regex: [] for regex in self._optimizer_map}

  # Route every (var, grad) pair to exactly one bucket by regex match on
  # the variable name; zero matches -> 'default_optimizer', more than one
  # match -> hard error (ambiguous routing).
  for (v, g) in var_grad.Flatten():
    regex_match = 0
    for regex in self._optimizer_map:
      if re.match(regex, v.name):
        var_grad_map[regex].append((g, v))
        regex_match += 1
    if regex_match == 0:
      var_grad_map['default_optimizer'].append((g, v))
    if regex_match > 1:
      raise Exception('Variable {} is matched {} times by regex {}'.format(
          v.name, regex_match, list(self._optimizer_map.keys())))

  def _Apply():
    """Use the matched optimizer to apply the gradients."""
    train_ops = []
    non_default_regex = [
        regex for regex in self._optimizer_map if regex != 'default_optimizer'
    ]
    for regex in self._optimizer_map:
      # Skip optimizers that received no variables.
      if var_grad_map[regex]:
        opt = tf_optimizer_map[regex]
        train_ops.append(opt.apply_gradients(var_grad_map[regex]))
        # The loop-variable capture in these lambdas is safe because
        # FilterKeyVal consumes the lambda immediately within this
        # iteration.
        # pylint: disable=cell-var-from-loop, g-long-lambda
        if regex == 'default_optimizer':
          # NOTE(review): this keeps entries whose var name matches a
          # NON-default regex, i.e. variables routed to the other
          # optimizers — which looks inverted for the default optimizer's
          # summary (a `not any(...)` would select the default bucket).
          # Confirm intent before changing; summaries only, no effect on
          # the actual updates above.
          filtered_var_grad = var_grad.FilterKeyVal(lambda k, v: any(
              [re.match(i, v.var.name) for i in non_default_regex]))
        else:
          filtered_var_grad = var_grad.FilterKeyVal(
              lambda k, v: (re.match(regex, v.var.name)))
        # pylint: enable=cell-var-from-loop, g-long-lambda
        # Summaries are delegated to the sub-optimizer, with its own lr.
        self._optimizer_map[regex].AddSummary(self._lr_map[regex], opt,
                                              filtered_var_grad)
    return tf.group(*train_ops, name='composite_optimizer_train_op')

  if not py_utils.use_resource_variables():
    var_update_op = _Apply()
  else:
    # Many optimizers, e.g., Adam, Adagrad, etc., create
    # variables. We need to ensure name scope and variable scope are
    # cleared. Otherwise, tpu.batch_parallel does not work.
    var_reuse = False
    if py_utils.GetOpportunisticVariableReuse():
      var_reuse = tf.AUTO_REUSE
    with tf.name_scope(None):
      with tf.variable_scope(
          tf.VariableScope(use_resource=True, reuse=var_reuse)):
        var_update_op = _Apply()
  return var_update_op