# NOTE: this code uses TF1-style graph APIs (tf.get_variable, tf.group), so
# `tf` is assumed to be TensorFlow 1.x or `tensorflow.compat.v1`;
# `distill_util` is the repo-local module providing the layer-wise warmup
# schedules.
def get_layerwise_gate(layer_id):
    """Returns the binary warmup gate for the given encoder layer.

    `num_train_steps` and `bert_config` are expected to be defined in the
    enclosing scope.
    """
    steps_per_phase = num_train_steps // bert_config.num_hidden_layers
    layer_wise_gate = distill_util.layer_wise_learning_rate(
        layer_id=layer_id,
        steps_per_phase=steps_per_phase,
        binary=True)
    return layer_wise_gate
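
# A possible use of the binary gate above (sketch, not in the original file):
# mask per-layer distillation losses so that a layer only contributes to the
# total loss once its warmup phase has started. `layer_losses` is a
# hypothetical list with one scalar loss tensor per encoder layer.
def gate_layer_losses(layer_losses):
    gated = [
        get_layerwise_gate(layer_id) * loss
        for layer_id, loss in enumerate(layer_losses)
    ]
    return tf.add_n(gated)
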
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    background_lr = distill_util.get_background_lr(
        global_step=global_step, steps_per_phase=self.steps_per_phase)
    for (grad, param) in grads_and_vars:
        if grad is None or param is None:
            continue
        param_name = self._get_variable_name(param.name)
        # Adam first- and second-moment accumulators, stored as non-trainable
        # slot variables keyed by the parameter name.
        m = tf.get_variable(
            name=param_name + "/adam_m",
            shape=param.shape.as_list(),
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer())
        v = tf.get_variable(
            name=param_name + "/adam_v",
            shape=param.shape.as_list(),
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer())
        if self.use_layer_wise_warmup:
            # Use model-specific name scopes to recover the layer id.
            if param_name.startswith("bert/encoder/layer_"):
                layer_id = int(
                    param_name[len("bert/encoder/layer_"):].split("/", 1)[0])
                layer_wise_lr = distill_util.layer_wise_learning_rate(
                    layer_id=layer_id,
                    steps_per_phase=self.steps_per_phase,
                    background_lr=background_lr)
                layer_wise_gate = tf.where(
                    tf.math.greater(layer_wise_lr, 0.0), 1.0, 0.0)
            else:
                # Variables outside the encoder layers get a zero learning
                # rate and gate, so they are not updated during warmup.
                layer_wise_lr = 0.0
                layer_wise_gate = 0.0
        else:
            layer_wise_lr = 1.0
            layer_wise_gate = 1.0
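        # When a layer's gate is 0, the gated update below leaves its
        # parameters unchanged and resets its Adam moments to zero, so a layer
        # only accumulates optimizer state once its warmup phase has started.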
        # Standard Adam update (without bias correction), gated by the
        # layer-wise warmup gate.
        next_m = layer_wise_gate * (
            tf.multiply(self.beta_1, m) +
            tf.multiply(1.0 - self.beta_1, grad))
        next_v = layer_wise_gate * (
            tf.multiply(self.beta_2, v) +
            tf.multiply(1.0 - self.beta_2, tf.square(grad)))
        update = next_m / (tf.sqrt(next_v) + self.epsilon)
        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange
        # ways.
        #
        # Instead we want to decay the weights in a manner that doesn't
        # interact with the m/v parameters. This is equivalent to adding the
        # square of the weights to the loss with plain (non-momentum) SGD.
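        # Concretely, for parameters subject to weight decay the step applied
        # further below is
        #   param <- param - layer_wise_lr * ratio * learning_rate
        #            * (next_m / (sqrt(next_v) + epsilon)
        #               + layer_wise_gate * weight_decay_rate * param),
        # i.e. the decay term is decoupled from the gradient moments
        # (AdamW-style) and is switched off for layers still gated out.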
        if self._do_use_weight_decay(param_name):
            update += layer_wise_gate * self.weight_decay_rate * param
        # LAMB-style layer adaptation: scale the update by the trust ratio
        # ||param|| / ||update|| when both norms are positive.
        ratio = 1.0
        if self._do_layer_adaptation(param_name):
            w_norm = tf.linalg.norm(param, ord=2)
            g_norm = tf.linalg.norm(update, ord=2)
            ratio = tf.where(
                tf.math.greater(w_norm, 0),
                tf.where(tf.math.greater(g_norm, 0), (w_norm / g_norm), 1.0),
                1.0)
        update_with_lr = layer_wise_lr * ratio * self.learning_rate * update
        next_param = param - update_with_lr
        assignments.extend(
            [param.assign(next_param),
             m.assign(next_m),
             v.assign(next_v)])
    return tf.group(*assignments, name=name)
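
# Usage sketch (not part of the original file): wiring apply_gradients into a
# TF1-style training op, mirroring the usual BERT setup. `loss` is assumed to
# be a scalar tensor and `optimizer` an instance of the optimizer class that
# defines apply_gradients above; neither is defined in this section.
def build_train_op(loss, optimizer, clip_norm=1.0):
    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)
    grads, _ = tf.clip_by_global_norm(grads, clip_norm)
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.apply_gradients(
        list(zip(grads, tvars)), global_step=global_step)
    # apply_gradients above does not increment the global step, so bump it
    # explicitly as part of the returned op.
    new_global_step = global_step + 1
    return tf.group(train_op, global_step.assign(new_global_step))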