def norm(x, scope, *, axis=-1, epsilon=1e-5):
    '''Normalize to mean = 0, std = 1, then do a diagonal affine transform.'''
    with tf.variable_scope(scope):
        n_state = x.shape[-1].value
        g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0))
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
        x = (x - u) * tf.rsqrt(s + epsilon)
        x = x * g + b
        return x
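# Example usage of `norm` (a minimal sketch, assuming TensorFlow 1.x graph
# mode; the tensor shape and scope name below are illustrative only and not
# taken from the original code).
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 10, 768])
y = norm(x, 'ln_example')  # layer-norm over the last axis, then scale/shift
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, {x: np.random.randn(2, 10, 768).astype(np.float32)})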
def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    with tf.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.get_variable(
            'w', [1, nx, nf],
            initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
        c = tf.reshape(
            tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
            start + [nf])
        return c
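# `conv1d` above relies on a `shape_list` helper that is not shown in this
# excerpt. A minimal sketch of the usual definition is given below (an
# assumption: it returns static dimensions where known and dynamic tensors
# otherwise, so `*start, nx = shape_list(x)` works for partially known shapes).
def shape_list(x):
    '''Deal with dynamic shapes in TensorFlow cleanly.'''
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]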
def __init__(self,
             student_config,
             bert_config,
             is_training,
             input_ids,
             input_mask,
             segment_ids,
             label_ids=None,
             sample_weight=None,
             scope='bert',
             dtype=tf.float32,
             drop_pooler=False,
             label_size=2,
             pred_temporature=1.0,
             emd_temporature=1.0,
             beta=0.01,
             **kwargs):
    super(TinyBERTCLSDistillor, self).__init__()

    def _get_logits(pooled_output, hidden_size, scope, trainable):
        with tf.variable_scope(scope):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[label_size, hidden_size],
                initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable(
                'output_bias',
                shape=[label_size],
                initializer=tf.zeros_initializer(),
                trainable=trainable)

            logits = tf.matmul(pooled_output, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            return logits

    use_tilda_embedding = kwargs.get('use_tilda_embedding')

    student = BERTEncoder(
        bert_config=student_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        scope='tiny/bert',
        use_tilda_embedding=use_tilda_embedding,
        drop_pooler=drop_pooler,
        trainable=True,
        **kwargs)
    student_logits = _get_logits(
        student.get_pooled_output(),
        student_config.hidden_size,
        'tiny/cls/seq_relationship',
        True)

    if is_training:
        teacher = BERTEncoder(
            bert_config=bert_config,
            is_training=False,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            scope=scope,
            use_tilda_embedding=False,
            drop_pooler=drop_pooler,
            trainable=False,
            **kwargs)
        teacher_logits = _get_logits(
            teacher.get_pooled_output(),
            bert_config.hidden_size,
            'cls/seq_relationship',
            False)

        weights = 1.0
        if sample_weight is not None:
            weights = tf.cast(sample_weight, dtype=tf.float32)

        # embedding loss
        embedding_loss = self._get_embedding_loss(
            teacher, student, bert_config, weights)

        # emd
        M = bert_config.num_hidden_layers
        N = student_config.num_hidden_layers
        with tf.variable_scope('emd'):
            teacher_weight = tf.get_variable(
                'teacher_weight', shape=[M],
                initializer=tf.constant_initializer(1 / M),
                trainable=False)
            student_weight = tf.get_variable(
                'student_weight', shape=[N],
                initializer=tf.constant_initializer(1 / N),
                trainable=False)
            self.teacher_weight = teacher_weight
            self.student_weight = student_weight

        # attention emd
        (attention_emd,
         new_attention_teacher_weight,
         new_attention_student_weight) = self._get_attention_emd(
            teacher, student, teacher_weight, student_weight,
            weights, emd_temporature)

        # hidden emd
        (hidden_emd,
         new_hidden_teacher_weight,
         new_hidden_student_weight) = self._get_hidden_emd(
            teacher, student, teacher_weight, student_weight,
            bert_config, weights, emd_temporature)

        # update weights
        new_teacher_weight = (
            new_attention_teacher_weight + new_hidden_teacher_weight) / 2
        new_student_weight = (
            new_attention_student_weight + new_hidden_student_weight) / 2
        update_teacher_weight_op = tf.assign(teacher_weight, new_teacher_weight)
        update_student_weight_op = tf.assign(student_weight, new_student_weight)

        # prediction loss
        pred_loss = self._get_pred_loss(
            teacher_logits, student_logits, weights, pred_temporature)

        # sum up
        with tf.control_dependencies(
                [update_teacher_weight_op, update_student_weight_op]):
            distill_loss = (
                beta * (embedding_loss + attention_emd + hidden_emd) +
                pred_loss)
        self.total_loss = distill_loss
        self.losses['losses'] = distill_loss

    else:
        student_probs = tf.nn.softmax(student_logits, axis=-1, name='probs')
        self.probs['probs'] = student_probs
        self.preds['preds'] = tf.argmax(student_probs, axis=-1)
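# Note on the objective assembled above: the teacher encoder is built with
# frozen weights (trainable=False, under the original `scope`), the student is
# trained under 'tiny/bert', and the EMD layer-weight update ops are forced to
# run before the loss is consumed via tf.control_dependencies. The total loss,
# as implemented, is
#
#     distill_loss = beta * (embedding_loss + attention_emd + hidden_emd) + pred_loss
#
# so with the default beta = 0.01 the embedding/EMD alignment terms are
# down-weighted relative to the soft-label prediction loss.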
def _get_hidden_emd(self, teacher, student, teacher_weight, student_weight,
                    bert_config, sample_weight, emd_temporature):
    teacher_hidden_layers = teacher.all_encoder_layers
    teacher_hidden_layers = [
        tf.stop_gradient(value) for value in teacher_hidden_layers]
    student_hidden_layers = student.all_encoder_layers
    M = len(teacher_hidden_layers)
    N = len(student_hidden_layers)

    with tf.variable_scope('hidden_emd'):
        flow = tf.get_variable(
            'flow', shape=[M, N],
            initializer=tf.constant_initializer(1 / M / N),
            trainable=False)

        # MSE
        rows = []
        for m in range(M):
            cols = []
            for n in range(N):
                linear_trans = tf.layers.dense(
                    student_hidden_layers[n],
                    bert_config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range))
                mse = tf.losses.mean_squared_error(
                    teacher_hidden_layers[m], linear_trans,
                    weights=tf.reshape(sample_weight, [-1, 1, 1]))
                col = tf.reshape(mse, [1, 1])
                cols.append(col)
            row = tf.concat(cols, axis=1)
            rows.append(row)
        distance = tf.concat(rows, axis=0)

        # cost attention mechanism
        teacher_cost = (tf.reduce_sum(flow, axis=1) *
                        tf.reduce_sum(distance, axis=1) /
                        (teacher_weight + 1e-6))
        student_cost = (tf.reduce_sum(flow, axis=0) *
                        tf.reduce_sum(distance, axis=0) /
                        (student_weight + 1e-6))

        # new weights
        new_teacher_weight = tf.where(
            teacher_cost > 1e-12,
            tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
            teacher_weight)
        new_student_weight = tf.where(
            student_cost > 1e-12,
            tf.reduce_sum(student_cost) / (student_cost + 1e-6),
            student_weight)
        new_teacher_weight = tf.nn.softmax(new_teacher_weight / emd_temporature)
        new_student_weight = tf.nn.softmax(new_student_weight / emd_temporature)

        self.hidden_flow = flow
        self.hidden_distance = distance
        hidden_emd = tf.reduce_sum(flow * distance)
        return hidden_emd, new_teacher_weight, new_student_weight
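# Sanity check of the EMD reduction used above, as a minimal NumPy sketch
# (independent of the TF graph; the layer counts are illustrative): the earth
# mover's distance is the flow-weighted sum of per-layer-pair MSE distances.
import numpy as np

M, N = 12, 4                           # teacher / student layer counts
flow = np.full((M, N), 1.0 / (M * N))  # uniform initial flow, as initialized above
distance = np.random.rand(M, N)        # stand-in for the pairwise MSE matrix
hidden_emd = np.sum(flow * distance)   # matches tf.reduce_sum(flow * distance)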
def _get_attention_emd(self, teacher, student, teacher_weight,
                       student_weight, sample_weight, emd_temporature):
    teacher_attention_scores = teacher.get_attention_scores()
    teacher_attention_scores = [
        tf.stop_gradient(value) for value in teacher_attention_scores]
    student_attention_scores = student.get_attention_scores()
    M = len(teacher_attention_scores)
    N = len(student_attention_scores)

    with tf.variable_scope('attention_emd'):
        flow = tf.get_variable(
            'flow', shape=[M, N],
            initializer=tf.constant_initializer(1 / M / N),
            trainable=False)

        # MSE
        rows = []
        for m in range(M):
            cols = []
            for n in range(N):
                # Replace the large negative scores produced by the additive
                # attention mask with zeros before computing the MSE.
                teacher_matrix = tf.where(
                    teacher_attention_scores[m] < -1e2,
                    tf.zeros_like(teacher_attention_scores[m]),
                    teacher_attention_scores[m])
                student_matrix = tf.where(
                    student_attention_scores[n] < -1e2,
                    tf.zeros_like(student_attention_scores[n]),
                    student_attention_scores[n])
                mse = tf.losses.mean_squared_error(
                    teacher_matrix, student_matrix,
                    weights=tf.reshape(sample_weight, [-1, 1, 1, 1]))
                col = tf.reshape(mse, [1, 1])
                cols.append(col)
            row = tf.concat(cols, axis=1)
            rows.append(row)
        distance = tf.concat(rows, axis=0)

        # cost attention mechanism
        teacher_cost = (tf.reduce_sum(flow, axis=1) *
                        tf.reduce_sum(distance, axis=1) /
                        (teacher_weight + 1e-6))
        student_cost = (tf.reduce_sum(flow, axis=0) *
                        tf.reduce_sum(distance, axis=0) /
                        (student_weight + 1e-6))

        # new weights
        new_teacher_weight = tf.where(
            teacher_cost > 1e-12,
            tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
            teacher_weight)
        new_student_weight = tf.where(
            student_cost > 1e-12,
            tf.reduce_sum(student_cost) / (student_cost + 1e-6),
            student_weight)
        new_teacher_weight = tf.nn.softmax(new_teacher_weight / emd_temporature)
        new_student_weight = tf.nn.softmax(new_student_weight / emd_temporature)

        self.attention_flow = flow
        self.attention_distance = distance
        attention_emd = tf.reduce_sum(flow * distance)
        return attention_emd, new_teacher_weight, new_student_weight
def __init__(self,
             vocab_size,
             filter_sizes,
             num_channels,
             is_training,
             input_ids,
             scope='text_cnn',
             embedding_size=256,
             dropout_prob=0.1,
             trainable=True,
             **kwargs):
    input_shape = util.get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    max_seq_length = input_shape[1]

    if isinstance(filter_sizes, str):
        # convert to integers so the shapes below are valid
        filter_sizes = [int(size) for size in filter_sizes.split(',')]
    assert isinstance(filter_sizes, list), (
        '`filter_sizes` should be a list of integers or a string '
        'separated with commas.')

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    use_tilda_embedding = kwargs.get('use_tilda_embedding')
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):
        with tf.variable_scope('embeddings'):

            if tilda_embeddings is not None:
                embedding_table = tilda_embeddings
            else:
                embedding_table = tf.get_variable(
                    name='word_embeddings',
                    shape=[vocab_size, embedding_size],
                    initializer=util.create_initializer(0.02),
                    dtype=tf.float32,
                    trainable=trainable)

            flat_input_ids = tf.reshape(input_ids, [-1])
            output = tf.gather(
                embedding_table, flat_input_ids, name='embedding_look_up')
            output = tf.reshape(
                output, [batch_size, max_seq_length, embedding_size])

            output_expanded = tf.expand_dims(output, -1)

        # Create a convolution + maxpool layer for each filter size
        pooled_outputs = []
        for i, filter_size in enumerate(filter_sizes):
            with tf.variable_scope('conv_%s' % filter_size):

                # Convolution Layer
                filter_shape = [filter_size, embedding_size, 1, num_channels]
                W = tf.get_variable(
                    name='W',
                    shape=filter_shape,
                    initializer=tf.truncated_normal_initializer(0.1),
                    dtype=tf.float32,
                    trainable=trainable)
                b = tf.get_variable(
                    name='b',
                    shape=[num_channels],
                    initializer=tf.constant_initializer(0.1),
                    dtype=tf.float32,
                    trainable=trainable)
                conv = tf.nn.conv2d(
                    output_expanded, W,
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name='conv')

                # Apply nonlinearity
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')

                # Maxpooling over the outputs
                pooled = tf.nn.max_pool(
                    h,
                    ksize=[1, max_seq_length - int(filter_size) + 1, 1, 1],
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name='pool')
                pooled_outputs.append(pooled)

        num_channels_total = num_channels * len(filter_sizes)
        h_pool = tf.concat(pooled_outputs, 3)
        h_pool_flat = tf.reshape(h_pool, [batch_size, num_channels_total])

        with tf.name_scope('dropout'):
            self.pooled_output = util.dropout(h_pool_flat, dropout_prob)
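# Example instantiation (a minimal sketch, assuming TF 1.x graph mode; the
# class name `TextCNN`, the vocabulary size, and the surrounding `util` module
# are assumptions inferred from the scope name and calls above, not shown in
# this excerpt).
import tensorflow as tf

input_ids = tf.placeholder(tf.int32, [None, 128])
model = TextCNN(vocab_size=30522, filter_sizes='2,3,4', num_channels=128,
                is_training=True, input_ids=input_ids)
# model.pooled_output has shape [batch_size, 3 * 128]: one max-pooled feature
# vector of `num_channels` values per filter size, concatenated, then dropout.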