def get_layerwise_gate(layer_id):
  """Returns a binary (0/1) layer-wise warmup gate for layer `layer_id`."""
  steps_per_phase = num_train_steps // bert_config.num_hidden_layers
  layer_wise_gate = distill_util.layer_wise_learning_rate(
      layer_id=layer_id, steps_per_phase=steps_per_phase, binary=True)
  return layer_wise_gate
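# Illustrative sketch (not part of the original module): a minimal pure-Python
# approximation of the binary gate above, assuming layer `layer_id` becomes
# trainable once global_step reaches layer_id * steps_per_phase. The helper
# name `sketch_layerwise_gate` and the exact phase boundary are hypothetical;
# the real schedule comes from distill_util.layer_wise_learning_rate.
def sketch_layerwise_gate(layer_id, global_step, num_train_steps,
                          num_hidden_layers):
  steps_per_phase = num_train_steps // num_hidden_layers
  return 1.0 if global_step >= layer_id * steps_per_phase else 0.0

# Example: with 12 layers and 1200 training steps, under this assumption the
# gate for layer 3 stays at 0.0 until step 300 and is 1.0 afterwards.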
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """See base class."""
  assignments = []
  background_lr = distill_util.get_background_lr(
      global_step=global_step, steps_per_phase=self.steps_per_phase)
  for (grad, param) in grads_and_vars:
    if grad is None or param is None:
      continue

    param_name = self._get_variable_name(param.name)

    # Per-parameter first and second moment accumulators for Adam.
    m = tf.get_variable(
        name=param_name + "/adam_m",
        shape=param.shape.as_list(),
        dtype=tf.float32,
        trainable=False,
        initializer=tf.zeros_initializer())
    v = tf.get_variable(
        name=param_name + "/adam_v",
        shape=param.shape.as_list(),
        dtype=tf.float32,
        trainable=False,
        initializer=tf.zeros_initializer())

    if self.use_layer_wise_warmup:
      # Use model-specific name spaces to get layer id.
      if param_name.startswith("bert/encoder/layer_"):
        layer_id = int(
            param_name[len("bert/encoder/layer_"):].split("/", 1)[0])
        layer_wise_lr = distill_util.layer_wise_learning_rate(
            layer_id=layer_id,
            steps_per_phase=self.steps_per_phase,
            background_lr=background_lr)
        layer_wise_gate = tf.where(
            tf.math.greater(layer_wise_lr, 0.0), 1.0, 0.0)
      else:
        layer_wise_lr = 0.0
        layer_wise_gate = 0.0
    else:
      layer_wise_lr = 1.0
      layer_wise_gate = 1.0

    # Standard Adam update.
    next_m = layer_wise_gate * (
        tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
    next_v = layer_wise_gate * (
        tf.multiply(self.beta_2, v) +
        tf.multiply(1.0 - self.beta_2, tf.square(grad)))

    update = next_m / (tf.sqrt(next_v) + self.epsilon)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters. This is equivalent to adding the square
    # of the weights to the loss with plain (non-momentum) SGD.
    if self._do_use_weight_decay(param_name):
      update += layer_wise_gate * self.weight_decay_rate * param

    # Optional layer adaptation: scale the update by the trust ratio
    # ||param|| / ||update||, falling back to 1.0 when either norm is zero.
    ratio = 1.0
    if self._do_layer_adaptation(param_name):
      w_norm = tf.linalg.norm(param, ord=2)
      g_norm = tf.linalg.norm(update, ord=2)
      ratio = tf.where(
          tf.math.greater(w_norm, 0),
          tf.where(tf.math.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)

    update_with_lr = layer_wise_lr * ratio * self.learning_rate * update

    next_param = param - update_with_lr

    assignments.extend(
        [param.assign(next_param),
         m.assign(next_m),
         v.assign(next_v)])
  return tf.group(*assignments, name=name)
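# Illustrative sketch (not part of the original optimizer): decoupled weight
# decay as described in the comment above, written in plain NumPy. The decay
# term is added to the update *after* the Adam moments are computed, so it
# never flows through m or v. The function name and defaults are hypothetical;
# like the code above, this sketch omits Adam's bias correction.
import numpy as np

def sketch_adamw_step(param, grad, m, v, lr=1e-3, beta_1=0.9, beta_2=0.999,
                      epsilon=1e-6, weight_decay_rate=0.01):
  m = beta_1 * m + (1.0 - beta_1) * grad
  v = beta_2 * v + (1.0 - beta_2) * np.square(grad)
  update = m / (np.sqrt(v) + epsilon)
  # Decay is applied directly to the parameter, outside the m/v statistics.
  update += weight_decay_rate * param
  return param - lr * update, m, v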