def _forward(dilated_ids, dilated_mask):

    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    output_ids = tf.argmax(logits, axis=-1)
    output_ids = tf.cast(output_ids, dtype=tf.int32)

    # special padding (using `spad` token)
    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
    right_pad = spad_id * tf.sequence_mask(
        equal_zero, dilated_seq_length, dtype=tf.int32)
    padded = tf.concat([output_ids, right_pad], axis=-1)

    # extract ids of length `max_seq_length`
    flattened_padded = tf.reshape(padded, [-1])
    is_valid = tf.greater(flattened_padded, 0)  # bool mask for `tf.boolean_mask`
    flattened_valid = tf.boolean_mask(flattened_padded, is_valid)
    valid = tf.reshape(flattened_valid, [batch_size, dilated_seq_length])
    cutted_valid = valid[:, :max_seq_length]

    # replace `spad` token with `pad`
    non_spad_mask = tf.cast(tf.not_equal(cutted_valid, spad_id), dtype=tf.int32)
    output_ids = cutted_valid * non_spad_mask
    output_length = tf.reduce_sum(non_spad_mask, axis=-1)

    # dilate
    reshaped_ids = tf.reshape(output_ids, [batch_size, max_seq_length, 1])
    reshaped_mask = tf.reshape(
        tf.sequence_mask(output_length, max_seq_length, dtype=tf.int32),
        [batch_size, max_seq_length, 1])
    concat_ids = tf.concat(
        [reshaped_ids, tf.zeros_like(reshaped_ids)], axis=-1)
    concat_mask = tf.concat(
        [reshaped_mask, tf.zeros_like(reshaped_mask, dtype=tf.int32)], axis=-1)
    dilated_ids = tf.reshape(concat_ids, [batch_size, max_seq_length * 2])
    dilated_mask = tf.reshape(concat_mask, [batch_size, max_seq_length * 2])

    return dilated_ids, dilated_mask
def _forward(self, is_training, split_placeholders, **kwargs):

    if not is_training:
        return super()._forward(is_training, split_placeholders, **kwargs)

    # `tf.boolean_mask` expects a boolean mask, so cast the float indicator
    is_unsupervised = tf.cast(
        1.0 - split_placeholders['is_supervised'], tf.bool)
    aug_input_ids = tf.boolean_mask(
        split_placeholders['aug_input_ids'], mask=is_unsupervised, axis=0)
    aug_input_mask = tf.boolean_mask(
        split_placeholders['aug_input_mask'], mask=is_unsupervised, axis=0)
    aug_segment_ids = tf.boolean_mask(
        split_placeholders['aug_segment_ids'], mask=is_unsupervised, axis=0)
    input_ids = tf.concat(
        [split_placeholders['input_ids'], aug_input_ids], axis=0)
    input_mask = tf.concat(
        [split_placeholders['input_mask'], aug_input_mask], axis=0)
    segment_ids = tf.concat(
        [split_placeholders['segment_ids'], aug_segment_ids], axis=0)

    encoder = BERTEncoder(
        bert_config=self.bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        scope='bert',
        drop_pooler=self._drop_pooler,
        **kwargs)
    encoder_output = encoder.get_pooled_output()

    label_ids = split_placeholders['label_ids']
    is_expanded = tf.zeros_like(label_ids, dtype=tf.float32)
    batch_size = util.get_shape_list(aug_input_ids)[0]
    aug_is_expanded = tf.ones((batch_size,), dtype=tf.float32)
    is_expanded = tf.concat([is_expanded, aug_is_expanded], axis=0)

    decoder = UDADecoder(
        is_training=is_training,
        input_tensor=encoder_output,
        is_supervised=split_placeholders['is_supervised'],
        is_expanded=is_expanded,
        label_ids=label_ids,
        label_size=self.label_size,
        sample_weight=split_placeholders.get('sample_weight'),
        scope='cls/seq_relationship',
        global_step=self._global_step,
        num_train_steps=self.total_steps,
        uda_softmax_temp=self._uda_softmax_temp,
        uda_confidence_thresh=self._uda_confidence_thresh,
        tsa_schedule=self._tsa_schedule,
        **kwargs)
    (total_loss, losses, probs, preds) = decoder.get_forward_outputs()
    return (total_loss, losses, probs, preds)
def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
    '''Create causal attention mask.'''
    attn_mask = tf.ones([qlen, qlen], dtype=dtype)
    mask_u = tf.matrix_band_part(attn_mask, 0, -1)
    mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
    attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
    ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
    if same_length:
        mask_l = tf.matrix_band_part(attn_mask, -1, 0)
        ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
    return ret
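# Hedged usage sketch (added for illustration, not part of the original
# source): peek at the mask that `_create_mask` produces for a query of
# length 4 attending over a memory of length 2. Entries equal to 1 mark
# masked (future) positions. Assumes a TF 1.x graph session is available.
def _demo_create_mask():
    mask = _create_mask(qlen=4, mlen=2, same_length=False)  # shape [4, 6]
    with tf.Session() as sess:
        # zeros for visible positions, ones strictly above the diagonal
        print(sess.run(mask))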
def multihead_attention(queries, keys, values, key_masks,
                        num_heads=8,
                        dropout_rate=0,
                        training=True,
                        causality=False,
                        scope='multihead_attention'):
    '''Applies multihead attention. See section 3.2.2 of "Attention Is All
    You Need".

    queries: A 3d tensor with shape of [N, T_q, d_model].
    keys: A 3d tensor with shape of [N, T_k, d_model].
    values: A 3d tensor with shape of [N, T_k, d_model].
    key_masks: A 2d tensor with shape of [N, key_seqlen].
    num_heads: An int. Number of heads.
    dropout_rate: A floating point number.
    training: Boolean. Controller of mechanism for dropout.
    causality: Boolean. If true, units that reference the future are masked.
    scope: Optional scope for `variable_scope`.

    Returns:
      A 3d tensor with shape of (N, T_q, d_model).
    '''
    d_model = queries.get_shape().as_list()[-1]
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):

        # Linear projections
        Q = tf.layers.dense(queries, d_model, use_bias=True)  # (N, T_q, d_model)
        K = tf.layers.dense(keys, d_model, use_bias=True)     # (N, T_k, d_model)
        V = tf.layers.dense(values, d_model, use_bias=True)   # (N, T_k, d_model)

        # Split and concat
        Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, d_model/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, d_model/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, d_model/h)

        # Attention
        outputs = scaled_dot_product_attention(
            Q_, K_, V_, key_masks, causality, dropout_rate, training)

        # Restore shape
        outputs = tf.concat(
            tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, d_model)

        # Residual connection
        outputs += queries

        # Normalize
        outputs = ln(outputs)

    return outputs
def __init__(self,
             is_training,
             input_tensor,
             label_ids,
             sample_weight=None,
             scope='mrc',
             name='',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    super().__init__(**kwargs)

    seq_length = input_tensor.shape.as_list()[-2]
    hidden_size = input_tensor.shape.as_list()[-1]
    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[2, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[2],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        output_layer = tf.reshape(output_layer, [-1, hidden_size])
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        logits = tf.reshape(logits, [-1, seq_length, 2])
        logits = tf.transpose(logits, [0, 2, 1])
        probs = tf.nn.softmax(logits, axis=-1, name='probs')
        self.probs[name] = probs

        start_one_hot_labels = tf.one_hot(
            label_ids[:, 0], depth=seq_length, dtype=tf.float32)
        end_one_hot_labels = tf.one_hot(
            label_ids[:, 1], depth=seq_length, dtype=tf.float32)
        start_log_probs = tf.nn.log_softmax(logits[:, 0, :], axis=-1)
        end_log_probs = tf.nn.log_softmax(logits[:, 1, :], axis=-1)
        per_example_loss = (
            -0.5 * tf.reduce_sum(
                start_one_hot_labels * start_log_probs, axis=-1)
            - 0.5 * tf.reduce_sum(
                end_one_hot_labels * end_log_probs, axis=-1))
        if sample_weight is not None:
            per_example_loss *= sample_weight

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses[name] = per_example_loss

        start_preds = tf.expand_dims(
            tf.argmax(logits[:, 0, :], axis=-1), axis=-1)
        end_preds = tf.expand_dims(
            tf.argmax(logits[:, 1, :], axis=-1), axis=-1)
        self.preds[name] = tf.concat([start_preds, end_preds], axis=-1)
def embedding_lookup(input_ids,
                     vocab_size,
                     batch_size,
                     max_seq_length,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name='word_embeddings',
                     zero_pad=True,
                     dtype=tf.float32,
                     trainable=True,
                     tilda_embeddings=None):
    if input_ids.shape.ndims == 2:
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    if tilda_embeddings is not None:
        embedding_table = tilda_embeddings
    else:
        embedding_table = tf.get_variable(
            name=word_embedding_name,
            shape=[vocab_size, embedding_size],
            initializer=util.create_initializer(initializer_range),
            dtype=dtype,
            trainable=trainable)

    if zero_pad:
        # zero out the embedding of padding id 0
        embedding_table = tf.concat(
            (tf.zeros(shape=[1, embedding_size]), embedding_table[1:, :]),
            axis=0)

    flat_input_ids = tf.reshape(input_ids, [-1])
    output = tf.gather(
        embedding_table, flat_input_ids, name='embedding_look_up')
    output = tf.reshape(output, [batch_size, max_seq_length, embedding_size])

    return (output, embedding_table)
def get_timing_signal_1d_given_position(channels,
                                        position,
                                        min_timescale=1.0,
                                        max_timescale=1.0e4):
    """Get sinusoids of diff frequencies, with timing position given.

    Adapted from add_timing_signal_1d_given_position in
    //third_party/py/tensor2tensor/layers/common_attention.py

    Args:
      channels: scalar, size of timing embeddings to create. The number of
        different timescales is equal to channels / 2.
      position: a Tensor with shape [batch, seq_len]
      min_timescale: a float
      max_timescale: a float

    Returns:
      a Tensor of timing signals [batch, seq_len, channels]
    """
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (tf.to_float(num_timescales) - 1))
    inv_timescales = min_timescale * tf.exp(
        tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
    scaled_time = (
        tf.expand_dims(tf.to_float(position), 2) *
        tf.expand_dims(tf.expand_dims(inv_timescales, 0), 0))
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
    signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]])
    return signal
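# Hedged usage sketch (added for illustration, not from the original source):
# build a [1, 3, 8] sinusoidal timing signal for explicit positions 0, 5, 10.
# Assumes TF 1.x (`tf.to_float` / `tf.mod` style APIs) and a graph session.
def _demo_timing_signal():
    position = tf.constant([[0, 5, 10]], dtype=tf.int32)  # [batch=1, seq_len=3]
    signal = get_timing_signal_1d_given_position(channels=8, position=position)
    with tf.Session() as sess:
        print(sess.run(signal).shape)  # (1, 3, 8)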
def embedding_preprocessor(self,
                           input_values,
                           batch_size=None,
                           embedding_size=128,
                           initializer_range=0.02,
                           name='cls_embedding',
                           dtype=tf.float32,
                           trainable=True):
    with tf.variable_scope(name):
        input_values = util.layer_norm(input_values, trainable=trainable)
        linear_output = tf.layers.dense(
            input_values,
            embedding_size,
            activation=None,
            name='dense',
            kernel_initializer=util.create_initializer(initializer_range),
            trainable=trainable)

        cls_embedding = tf.get_variable(
            name='cls',
            shape=[1, 1, embedding_size],
            initializer=util.create_initializer(initializer_range),
            dtype=dtype,
            trainable=trainable)
        cls_output = tf.tile(cls_embedding, [batch_size, 1, 1])

        output = tf.concat([cls_output, linear_output], axis=1)
    return output
def attn(x, scope, n_state, *, past, hparams):
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]
        assert past.shape.ndims == 5

    def split_heads(x):
        # From [batch, sequence, features] to
        # [batch, heads, sequence, features]
        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(x, [0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where
        # information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        w = w * b - tf.cast(1e10, w.dtype) * (1 - b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))
        w = mask_attn_weights(w)
        w = softmax(w)
        a = tf.matmul(w, v)
        return a

    with tf.variable_scope(scope):
        c = conv1d(x, 'c_attn', n_state * 3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present
def __init__(self,
             is_training,
             input_tensor,
             input_mask,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/sequence',
             name='',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    super().__init__(**kwargs)

    batch_size = tf.shape(input_tensor)[0]
    seq_length = input_tensor.shape.as_list()[-2]
    hidden_size = input_tensor.shape.as_list()[-1]
    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        output_layer = tf.reshape(output_layer, [-1, hidden_size])
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        logits = tf.reshape(logits, [-1, seq_length, label_size])

        self.preds[name] = tf.argmax(logits, axis=-1)
        self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        per_token_losses = -tf.reduce_mean(
            one_hot_labels * log_probs, axis=-1)

        # build a shifted mask that excludes [CLS], the trailing [SEP] and
        # padding positions from the loss
        input_mask = tf.concat([
            tf.zeros((batch_size, 1), dtype=tf.float32),
            tf.cast(input_mask[:, 2:], dtype=tf.float32),
            tf.zeros((batch_size, 1), dtype=tf.float32)
        ], axis=-1)
        per_token_losses *= input_mask
        per_example_loss = tf.reduce_mean(per_token_losses, axis=-1)
        if sample_weight is not None:
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        self.losses[name] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def scatter_update(sequence, updates, positions):
    '''Scatter-update a sequence.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
      updates: A tensor of size batch_size*seq_len(*depth)
      positions: A [batch_size, n_positions] tensor

    Returns:
      A tuple of two tensors. First is a [batch_size, seq_len] or
      [batch_size, seq_len, depth] tensor of 'sequence' with elements at
      'positions' replaced by the values at 'updates'. Updates to index 0 are
      ignored. If there are duplicated positions the update is only applied
      once. Second is a [batch_size, seq_len] mask tensor of which inputs were
      updated.
    '''
    shape = util.get_shape_list(sequence, expected_rank=[2, 3])
    depth_dimension = (len(shape) == 3)
    if depth_dimension:
        B, L, D = shape
    else:
        B, L = shape
        D = 1
        sequence = tf.expand_dims(sequence, -1)
    N = util.get_shape_list(positions)[1]

    shift = tf.expand_dims(L * tf.range(B), -1)
    flat_positions = tf.reshape(positions + shift, [-1, 1])
    flat_updates = tf.reshape(updates, [-1, D])
    updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D])
    updates = tf.reshape(updates, [B, L, D])

    flat_updates_mask = tf.ones([B * N], tf.int32)
    updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L])
    updates_mask = tf.reshape(updates_mask, [B, L])
    not_first_token = tf.concat(
        [tf.zeros((B, 1), tf.int32), tf.ones((B, L - 1), tf.int32)], -1)
    updates_mask *= not_first_token
    updates_mask_3d = tf.expand_dims(updates_mask, -1)

    # account for duplicate positions
    if sequence.dtype == tf.float32:
        updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
        updates /= tf.maximum(1.0, updates_mask_3d)
    else:
        assert sequence.dtype == tf.int32
        updates = tf.divide(updates, tf.maximum(1, updates_mask_3d))
        updates = tf.cast(updates, tf.int32)
    updates_mask = tf.minimum(updates_mask, 1)
    updates_mask_3d = tf.minimum(updates_mask_3d, 1)

    updated_sequence = (((1 - updates_mask_3d) * sequence) +
                        (updates_mask_3d * updates))
    if not depth_dimension:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
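# Hedged usage sketch (illustration only, not from the original source):
# overwrite positions 1 and 3 of each row with new token ids. Assumes a
# TF 1.x graph session and that `util.get_shape_list` resolves static shapes.
def _demo_scatter_update():
    sequence = tf.constant([[5, 6, 7, 8], [9, 10, 11, 12]], dtype=tf.int32)
    updates = tf.constant([[101, 102], [103, 104]], dtype=tf.int32)
    positions = tf.constant([[1, 3], [1, 3]], dtype=tf.int32)
    updated, mask = scatter_update(sequence, updates, positions)
    with tf.Session() as sess:
        print(sess.run(updated))  # [[5, 101, 7, 102], [9, 103, 11, 104]]
        print(sess.run(mask))     # [[0, 1, 0, 1], [0, 1, 0, 1]]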
def positional_embedding(pos_seq, inv_freq, bsz=None):
    sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq)
    pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
    pos_emb = pos_emb[:, None, :]

    if bsz is not None:
        pos_emb = tf.tile(pos_emb, [1, bsz, 1])

    return pos_emb
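# Hedged usage sketch (illustration only): sinusoidal embeddings for relative
# positions 3..-3 with d_model = 8, tiled over a batch of 2. Assumes TF 1.x
# graph execution.
def _demo_positional_embedding():
    d_model = 8
    freq_seq = tf.range(0, d_model, 2.0)
    inv_freq = 1 / (10000 ** (freq_seq / d_model))
    pos_seq = tf.range(3, -4, -1.0)  # 7 relative positions
    pos_emb = positional_embedding(pos_seq, inv_freq, bsz=2)
    with tf.Session() as sess:
        print(sess.run(pos_emb).shape)  # (7, 2, 8)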
def _forward(dilated_ids, dilated_mask):

    logits = self._bert_forward(
        bert_config,
        dilated_ids,
        dilated_mask,
        batch_size,
        dilated_seq_length,
        tilda_embeddings=tilda_embeddings)
    output_ids = tf.argmax(logits, axis=-1)
    output_ids = tf.cast(output_ids, dtype=tf.int32)

    # pad the tail with `spad` tokens
    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
    right_pad = spad_id * tf.sequence_mask(
        equal_zero, dilated_seq_length, dtype=tf.int32)
    padded = tf.concat([output_ids, right_pad], axis=-1)

    # extract ids of length `max_seq_length`
    flattened_padded = tf.reshape(padded, [-1])
    is_valid = tf.greater(flattened_padded, 0)  # bool mask for `tf.boolean_mask`
    flattened_valid = tf.boolean_mask(flattened_padded, is_valid)
    valid = tf.reshape(flattened_valid, [batch_size, dilated_seq_length])
    cutted_valid = valid[:, :max_seq_length]

    # replace `spad` token with `pad`
    nonpad_mask = tf.cast(tf.not_equal(cutted_valid, spad_id), dtype=tf.int32)
    output_ids = cutted_valid * nonpad_mask

    # dilate
    reshaped = tf.reshape(output_ids, [batch_size, max_seq_length, 1])
    concatenated = tf.concat([reshaped, tf.zeros_like(reshaped)], axis=-1)
    dilated_ids = tf.reshape(concatenated, [batch_size, max_seq_length * 2])

    input_mask = tf.reduce_sum(nonpad_mask, axis=-1)
    dilated_mask = tf.sequence_mask(
        input_mask, dilated_seq_length, dtype=tf.int32)

    return dilated_ids, dilated_mask
def _cls_self_attention(self,
                        prev_output,
                        batch_size,
                        max_seq_length,
                        label_size,
                        attention_mask=None,
                        cls_hidden_size=128,
                        cls_num_attention_heads=2,
                        attention_probs_dropout_prob=0.1,
                        initializer_range=0.02,
                        dtype=tf.float32,
                        trainable=True):
    if cls_hidden_size % cls_num_attention_heads != 0:
        raise ValueError(
            '`cls_hidden_size` (%d) is not a multiple of the number of '
            '`cls_num_attention_heads` (%d)'
            % (cls_hidden_size, cls_num_attention_heads))
    cls_attention_head_size = int(cls_hidden_size / cls_num_attention_heads)

    with tf.variable_scope('attention'):
        attention_heads = []
        with tf.variable_scope('self'):
            attention_head, _ = self.attention_layer(
                from_tensor=prev_output,
                to_tensor=prev_output,
                attention_mask=attention_mask,
                num_attention_heads=cls_num_attention_heads,
                size_per_head=cls_attention_head_size,
                attention_probs_dropout_prob=attention_probs_dropout_prob,
                initializer_range=initializer_range,
                do_return_2d_tensor=False,
                batch_size=batch_size,
                from_max_seq_length=max_seq_length,
                to_max_seq_length=max_seq_length,
                dtype=dtype,
                trainable=trainable)
            attention_heads.append(attention_head)

        attention_output = None
        if len(attention_heads) == 1:
            attention_output = attention_heads[0]
        else:
            attention_output = tf.concat(attention_heads, axis=-1)

        attention_output = util.layer_norm(
            attention_output[:, 0, :], trainable=trainable)

    with tf.variable_scope('output'):
        cls_output = tf.layers.dense(
            attention_output,
            label_size,
            kernel_initializer=util.create_initializer(initializer_range),
            trainable=trainable)

    return cls_output
def _single_seq_fn():
    batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
    example_inds = tf.reshape(
        tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
    sequence_scores = tf.gather_nd(
        tf.squeeze(inputs, [1]),
        tf.concat([example_inds, tag_indices], axis=1))
    sequence_scores = tf.where(
        tf.less_equal(sequence_lengths, 0),
        tf.zeros_like(sequence_scores),
        sequence_scores)
    return sequence_scores
def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type,
                                 bi_data, bsz=None, dtype=None):
    '''Create relative positional encoding.'''
    freq_seq = tf.range(0, d_model, 2.0)
    if dtype is not None and dtype != tf.float32:
        freq_seq = tf.cast(freq_seq, dtype=dtype)
    inv_freq = 1 / (10000 ** (freq_seq / d_model))

    if attn_type == 'bi':
        # beg, end = klen - 1, -qlen
        beg, end = klen, -qlen
    elif attn_type == 'uni':
        # beg, end = klen - 1, -1
        beg, end = klen, -1
    else:
        raise ValueError('Unknown `attn_type` {}.'.format(attn_type))

    if bi_data:
        fwd_pos_seq = tf.range(beg, end, -1.0)
        bwd_pos_seq = tf.range(-beg, -end, 1.0)

        if dtype is not None and dtype != tf.float32:
            fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
            bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)

        if clamp_len > 0:
            fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
            bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -clamp_len, clamp_len)

        if bsz is not None:
            # With bi_data, the batch size should be divisible by 2.
            assert bsz % 2 == 0
            fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
            bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
        else:
            fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
            bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)

        pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
    else:
        fwd_pos_seq = tf.range(beg, end, -1.0)
        if dtype is not None and dtype != tf.float32:
            fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
        if clamp_len > 0:
            fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
        pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)

    return pos_emb
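# Hedged usage sketch (illustration only): relative position encodings for a
# query of length 4 over a key length of 6, bidirectional attention type
# without the bi_data split, clamped at distance 3. Assumes TF 1.x graph mode.
def _demo_relative_positional_encoding():
    pos_emb = relative_positional_encoding(
        qlen=4, klen=6, d_model=8, clamp_len=3, attn_type='bi',
        bi_data=False, bsz=2)
    with tf.Session() as sess:
        print(sess.run(pos_emb).shape)  # (10, 2, 8)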
def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False):
    r'''Constructs the matrix of random projections.

    Constructs a matrix of random orthogonal projections. Each projection
    vector has direction chosen uniformly at random and either deterministic
    length \sqrt{d} or length taken from the \chi(d) distribution (in the
    latter case marginal distributions of the projections are d-dimensional
    Gaussian vectors with associated identity covariance matrix).

    Args:
      m: number of random projections.
      d: dimensionality of each random projection.
      seed: random seed used to construct projections.
      scaling: 1 if all the random projections need to be renormalized to have
        length \sqrt{d}, 0 if the lengths of random projections should follow
        \chi(d) distribution.
      struct_mode: if True then products of Givens rotations will be used to
        construct random orthogonal matrix. This bypasses Gram-Schmidt
        orthogonalization.

    Returns:
      The matrix of random projections of the shape [m, d].
    '''
    nb_full_blocks = int(m / d)
    block_list = []
    current_seed = seed
    for _ in range(nb_full_blocks):
        if struct_mode:
            q = create_products_of_givens_rotations(d, seed)
        else:
            unstructured_block = tf.random_normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q)
        current_seed += 1
    remaining_rows = m - nb_full_blocks * d
    if remaining_rows > 0:
        if struct_mode:
            q = create_products_of_givens_rotations(d, seed)
        else:
            unstructured_block = tf.random_normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q[0:remaining_rows])
    final_matrix = tf.concat(block_list, axis=0)
    current_seed += 1

    if scaling == 0:
        multiplier = tf.norm(
            tf.random_normal((m, d), seed=current_seed), axis=1)
    elif scaling == 1:
        multiplier = 1 / tf.math.rsqrt(float(d)) * tf.ones((m,))
    else:
        raise ValueError('Scaling must be one of {0, 1}. Was %s' % scaling)

    return tf.matmul(tf.linalg.diag(multiplier), final_matrix)
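# Hedged usage sketch (illustration only): draw a 256 x 64 matrix of random
# orthogonal projection blocks of the kind used for FAVOR+ feature maps.
# Assumes TF 1.x graph mode.
def _demo_projection_matrix():
    projection = create_projection_matrix(m=256, d=64, seed=42, scaling=0)
    with tf.Session() as sess:
        print(sess.run(projection).shape)  # (256, 64)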
def causal_numerator(qs, ks, vs):
    '''Computes not-normalized FAVOR causal attention A_{masked}V.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
      vs: value tensor of the shape [L,B,H,D].

    Returns:
      Not-normalized FAVOR causal attention A_{masked}V.
    '''
    result = []
    sums = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

    for index in range(qs.shape[0]):
        sums = sums + tf.einsum('ijk,ijl->ijkl', ks[index], vs[index])
        result.append(
            tf.einsum('ijkl,ijk->ijl', sums, qs[index])[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):
        grads = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

        gr_sums = sums

        q_grads = []
        k_grads = []
        v_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):
            q_grads.append(
                tf.einsum('ijkl,ijl->ijk', gr_sums,
                          res_grad[index])[None, Ellipsis])
            grads = grads + tf.einsum('ijk,ijl->ijkl', qs[index],
                                      res_grad[index])
            k_grads.append(
                tf.einsum('ijkl,ijl->ijk', grads, vs[index])[None, Ellipsis])
            v_grads.append(
                tf.einsum('ijkl,ijk->ijl', grads, ks[index])[None, Ellipsis])
            gr_sums = gr_sums - tf.einsum('ijk,ijl->ijkl', ks[index],
                                          vs[index])

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)
        v_grads = tf.concat(v_grads[::-1], axis=0)

        return q_grads, k_grads, v_grads

    return result, grad
def rel_multihead_attn(h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
                       attn_mask, mems, d_model, n_head, d_head, dropout,
                       dropatt, is_training, kernel_initializer,
                       scope='rel_attn', reuse=None):
    '''Multi-head attention with relative positional encoding.'''
    scale = 1 / (d_head ** 0.5)
    with tf.variable_scope(scope, reuse=reuse):
        if mems is not None and mems.shape.ndims > 1:
            cat = tf.concat([mems, h], 0)
        else:
            cat = h

        # content heads
        q_head_h = head_projection(
            h, d_model, n_head, d_head, kernel_initializer, 'q')
        k_head_h = head_projection(
            cat, d_model, n_head, d_head, kernel_initializer, 'k')
        v_head_h = head_projection(
            cat, d_model, n_head, d_head, kernel_initializer, 'v')

        # positional heads
        k_head_r = head_projection(
            r, d_model, n_head, d_head, kernel_initializer, 'r')

        # core attention ops
        attn_vec = rel_attn_core(
            q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
            r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt, is_training,
            scale)

        # post processing
        output = post_attention(h, attn_vec, d_model, n_head, d_head,
                                dropout, is_training, kernel_initializer)

    return output
def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
    '''Cache hidden states into memory.'''
    if mem_len is None or mem_len == 0:
        return None
    else:
        if reuse_len is not None and reuse_len > 0:
            curr_out = curr_out[:reuse_len]

        if prev_mem is None:
            new_mem = curr_out[-mem_len:]
        else:
            new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]

    return tf.stop_gradient(new_mem)
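# Hedged usage sketch (illustration only): roll a length-3 memory forward with
# the hidden states of the current segment. Layout follows the function above:
# time-major tensors of shape [seq_len, batch, hidden]. Assumes TF 1.x.
def _demo_cache_mem():
    prev_mem = tf.zeros([3, 2, 8])      # [mem_len, batch, hidden]
    curr_out = tf.ones([4, 2, 8])       # current segment hidden states
    new_mem = _cache_mem(curr_out, prev_mem, mem_len=3)
    with tf.Session() as sess:
        print(sess.run(new_mem).shape)  # (3, 2, 8), gradients stopped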
def causal_denominator(qs, ks):
    '''Computes FAVOR normalizer in causal attention.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].

    Returns:
      FAVOR normalizer in causal attention.
    '''
    result = []
    sums = tf.zeros_like(ks[0])

    for index in range(qs.shape[0]):
        sums = sums + ks[index]
        result.append(tf.reduce_sum(qs[index] * sums, axis=2)[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):
        k_grad = tf.zeros_like(ks[0])

        gr_sums = sums

        q_grads = []
        k_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):
            q_grads.append(
                tf.einsum('ijk,ij->ijk', gr_sums,
                          res_grad[index])[None, Ellipsis])
            k_grad = k_grad + tf.einsum('ijk,ij->ijk', qs[index],
                                        res_grad[index])
            k_grads.append(k_grad[None, Ellipsis])
            gr_sums = gr_sums - ks[index]

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)

        return q_grads, k_grads

    return result, grad
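# Hedged usage sketch (illustration only): assemble causal FAVOR attention
# from the two helpers above on random positive feature maps. Shapes follow
# the docstrings: [L, B, H, M] for q'/k' and [L, B, H, D] for values. Wrapping
# the helpers in tf.custom_gradient (as the Performer code does) is omitted
# here for brevity. Assumes TF 1.x graph mode.
def _demo_causal_favor():
    L, B, H, M, D = 6, 2, 4, 16, 8
    qs = tf.nn.relu(tf.random_normal([L, B, H, M])) + 1e-6
    ks = tf.nn.relu(tf.random_normal([L, B, H, M])) + 1e-6
    vs = tf.random_normal([L, B, H, D])
    num, _ = causal_numerator(qs, ks, vs)   # [L, B, H, D]
    den, _ = causal_denominator(qs, ks)     # [L, B, H]
    attention = num / den[Ellipsis, None]   # normalized causal attention
    with tf.Session() as sess:
        print(sess.run(attention).shape)    # (6, 2, 4, 8)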
def get_token_embeddings(vocab_size, num_units, zero_pad=True):
    '''Constructs token embedding matrix.

    Note that the row of index 0 is set to zeros.

    vocab_size: scalar. V.
    num_units: embedding dimensionality. E.
    zero_pad: Boolean. If True, all the values of the first row (id = 0)
      should be constant zero. To apply query/key masks easily, zero pad is
      turned on.

    Returns:
      weight variable: (V, E)
    '''
    with tf.variable_scope('shared_weight_matrix'):
        embeddings = tf.get_variable(
            'weight_mat',
            dtype=tf.float32,
            shape=(vocab_size, num_units),
            initializer=xavier_initializer())
        if zero_pad:
            embeddings = tf.concat(
                (tf.zeros(shape=[1, num_units]), embeddings[1:, :]), 0)
    return embeddings
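# Hedged usage sketch (illustration only): create a (V, E) embedding table and
# look up a short id sequence; id 0 maps to the all-zero row. Assumes TF 1.x.
def _demo_token_embeddings():
    embeddings = get_token_embeddings(vocab_size=100, num_units=16)
    ids = tf.constant([[4, 7, 0]])        # 0 is the padding id
    looked_up = tf.nn.embedding_lookup(embeddings, ids)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(looked_up).shape)  # (1, 3, 16)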
def __init__(self,
             hparams,
             is_training,
             input_ids,
             sample_weight=None,
             scope='model',
             given=1,
             use_tilda_embedding=False,
             **kwargs):
    super().__init__()

    batch_size = util.get_shape_list(input_ids, expected_rank=2)[0]
    max_seq_length = hparams.n_predict

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):

        def _forward(input_ids, past=None):
            batch, sequence = shape_list(input_ids)

            if tilda_embeddings is None:
                wte = tf.get_variable(
                    'word_embeddings',
                    [hparams.n_vocab, hparams.n_embed],
                    initializer=tf.random_normal_initializer(stddev=0.02))
            else:
                wte = tilda_embeddings
            wpe = tf.get_variable(
                'wpe',
                [hparams.n_ctx, hparams.n_embed],
                initializer=tf.random_normal_initializer(stddev=0.01))
            past_length = 0 if past is None else tf.shape(past)[-2]
            h = (tf.gather(wte, input_ids) +
                 tf.gather(wpe, positions_for(input_ids, past_length)))

            # stacked transformer layers
            presents = []
            pasts = tf.unstack(past, axis=1) if past is not None else \
                [None] * hparams.n_layer
            assert len(pasts) == hparams.n_layer
            for layer, past in enumerate(pasts):
                h, present = block(
                    h, 'h%d' % layer, past=past, hparams=hparams)
                presents.append(present)
            present = tf.stack(presents, axis=1)
            h = norm(h, 'ln_f')

            # Language model loss. Do tokens <n predict token n?
            h_flat = tf.reshape(h, [batch * sequence, hparams.n_embed])
            logits = tf.matmul(h_flat, wte, transpose_b=True)
            logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])

            return logits, present

        # convert to labels
        label_ids = tf.concat(
            [input_ids[:, 1:], tf.zeros([batch_size, 1], dtype=tf.int32)],
            axis=-1)

        # forward once
        if is_training:
            (logits, _) = _forward(input_ids)

            self.preds['LM'] = tf.argmax(logits, axis=-1)

        # forward loop
        else:
            input_ids = input_ids[:, 0:given]

            for cur_length in range(given, max_seq_length + 1):
                (logits, _) = _forward(input_ids)

                pred_ids = tf.argmax(
                    logits[:, cur_length - 1:cur_length, :], axis=-1)
                pred_ids = tf.cast(pred_ids, tf.int32)
                input_ids = tf.concat([input_ids, pred_ids], axis=-1)

            self.preds['LM'] = input_ids

        # loss
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=hparams.n_vocab)
        per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
        per_example_loss = \
            tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
            tf.reduce_sum(label_mask, axis=-1)
        if sample_weight is not None:
            per_example_loss *= tf.expand_dims(sample_weight, axis=-1)

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses['LM'] = per_example_loss
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''Sample a permutation of the factorization order, and create an
    attention mask accordingly.

    Args:
      inputs: int64 Tensor in shape [seq_len], input ids.
      targets: int64 Tensor in shape [seq_len], target ids.
      is_masked: bool Tensor in shape [seq_len]. True means being selected
        for partial prediction.
      perm_size: the length of longest permutation. Could be set to be
        reuse_len. Should not be larger than reuse_len or there will be data
        leaks.
      seq_len: int, sequence length.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
    non_mask_tokens = tf.logical_and(
        tf.logical_not(is_masked), non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to
    # the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked
    #    (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=-1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
def _expand_features(module, split_placeholders):

    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    non_reuse_len = module.max_seq_length - module.reuse_seq_length
    assert (module.perm_size <= module.reuse_seq_length and
            module.perm_size <= non_reuse_len)

    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :module.reuse_seq_length],
            target[:, :module.reuse_seq_length],
            is_masked[:, :module.reuse_seq_length],
            module.perm_size,
            module.reuse_seq_length)

    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, module.reuse_seq_length:],
            target[:, module.reuse_seq_length:],
            is_masked[:, module.reuse_seq_length:],
            module.perm_size,
            non_reuse_len)

    perm_mask_0 = tf.concat([
        tf.cast(perm_mask_0, dtype=tf.float32),
        tf.ones([batch_size, module.reuse_seq_length, non_reuse_len])
    ], axis=2)
    perm_mask_1 = tf.concat([
        tf.zeros([batch_size, non_reuse_len, module.reuse_seq_length]),
        tf.cast(perm_mask_1, dtype=tf.float32)
    ], axis=2)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        # TODO(geying): convert tensors from 1-D to 2-D

        indices = tf.range(module.max_seq_length, dtype=tf.int64)
        indices = tf.reshape(indices, [-1, module.max_seq_length])
        indices = tf.tile(indices, [batch_size, 1])
        bool_target_mask = tf.cast(target_mask, tf.bool)
        indices = tf.boolean_mask(indices, bool_target_mask)

        ##### extra padding due to CLS/SEP introduced after prepro
        actual_num_predict = tf.shape(indices)[1]
        pad_len = module._num_predict - actual_num_predict

        ##### target_mapping
        target_mapping = tf.one_hot(
            indices, module.max_seq_length, dtype=tf.float32)
        paddings = tf.zeros(
            [pad_len, module.max_seq_length], dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, paddings], axis=0)
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping, [-1, module._num_predict, module.max_seq_length])

        ##### target
        target = tf.boolean_mask(target, bool_target_mask)
        paddings = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, paddings], axis=0)
        split_placeholders['target'] = tf.reshape(
            target, [-1, module._num_predict])

        ##### target mask
        target_mask = tf.concat([
            tf.ones([batch_size, actual_num_predict], dtype=tf.float32),
            tf.zeros([batch_size, pad_len], dtype=tf.float32)
        ], axis=1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module._num_predict])
    else:
        split_placeholders['target'] = tf.reshape(
            target, [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(
        input_k, [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(
        input_q, [-1, module.max_seq_length])

    return split_placeholders
def _build_forward(layer_input):
    with tf.variable_scope('attention'):
        attention_heads = []
        with tf.variable_scope('self'):
            (attention_head, attention_scores) = self.attention_layer(
                from_tensor=layer_input,
                to_tensor=layer_input,
                attention_mask=attention_mask,
                num_attention_heads=num_attention_heads,
                size_per_head=attention_head_size,
                attention_probs_dropout_prob=attention_probs_dropout_prob,
                initializer_range=initializer_range,
                do_return_2d_tensor=True,
                batch_size=batch_size,
                from_max_seq_length=max_seq_length,
                to_max_seq_length=max_seq_length,
                dtype=dtype,
                trainable=trainable)
            attention_heads.append(attention_head)
            self.attention_scores.append(attention_scores)

        attention_output = None
        if len(attention_heads) == 1:
            attention_output = attention_heads[0]
        else:
            attention_output = tf.concat(attention_heads, axis=-1)

        with tf.variable_scope('output'):
            attention_output = tf.layers.dense(
                attention_output,
                hidden_size,
                kernel_initializer=util.create_initializer(
                    initializer_range),
                trainable=trainable)
            attention_output = util.dropout(
                attention_output, hidden_dropout_prob)
            attention_output = util.layer_norm(
                attention_output + layer_input, trainable=trainable)

    # The activation is only applied to the `intermediate` hidden layer.
    with tf.variable_scope('intermediate'):
        intermediate_output = tf.layers.dense(
            attention_output,
            intermediate_size,
            activation=intermediate_act_fn,
            kernel_initializer=util.create_initializer(initializer_range),
            trainable=trainable)

    # Down-project back to hidden_size then add the residual.
    with tf.variable_scope('output'):
        layer_output = tf.layers.dense(
            intermediate_output,
            hidden_size,
            kernel_initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        layer_output = util.dropout(layer_output, hidden_dropout_prob)
        layer_output = util.layer_norm(
            layer_output + attention_output, trainable=trainable)

    return layer_output
def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=util.gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
    '''Multi-headed, multi-layer Transformer from 'Attention is All You Need'.

    This is almost an exact implementation of the original Transformer
    encoder.

    See the original paper:
    https://arxiv.org/abs/1706.03762

    Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

    Args:
      input_tensor: float Tensor of shape [batch_size, seq_length,
        hidden_size].
      attention_mask: (optional) int32 Tensor of shape [batch_size,
        seq_length, seq_length], with 1 for positions that can be attended to
        and 0 in positions that should not be.
      hidden_size: int. Hidden size of the Transformer.
      num_hidden_layers: int. Number of layers (blocks) in the Transformer.
      num_attention_heads: int. Number of attention heads in the Transformer.
      intermediate_size: int. The size of the 'intermediate' (a.k.a., feed
        forward) layer.
      intermediate_act_fn: function. The non-linear activation function to
        apply to the output of the intermediate/feed-forward layer.
      hidden_dropout_prob: float. Dropout probability for the hidden layers.
      attention_probs_dropout_prob: float. Dropout probability of the
        attention probabilities.
      initializer_range: float. Range of the initializer (stddev of truncated
        normal).
      do_return_all_layers: Whether to also return all layers or just the
        final layer.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size], the final
      hidden layer of the Transformer.

    Raises:
      ValueError: A Tensor shape or parameter is invalid.
    '''
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            'The hidden size (%d) is not a multiple of the number of '
            'attention heads (%d)' % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = util.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError(
            'The width of the input tensor (%d) != hidden size (%d)'
            % (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back
    # and forth from a 3D tensor to a 2D tensor. Re-shapes are normally free
    # on the GPU/CPU but may not be free on the TPU, so we want to minimize
    # them to help the optimizer.
    prev_output = util.reshape_to_matrix(input_tensor)

    attn_maps = []
    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        with tf.variable_scope('layer_%d' % layer_idx):
            with tf.variable_scope('attention'):
                attention_heads = []
                with tf.variable_scope('self'):
                    attention_head, probs = attention_layer(
                        from_tensor=prev_output,
                        to_tensor=prev_output,
                        attention_mask=attention_mask,
                        num_attention_heads=num_attention_heads,
                        size_per_head=attention_head_size,
                        attention_probs_dropout_prob=
                        attention_probs_dropout_prob,
                        initializer_range=initializer_range,
                        do_return_2d_tensor=True,
                        batch_size=batch_size,
                        from_seq_length=seq_length,
                        to_seq_length=seq_length)
                    attention_heads.append(attention_head)
                    attn_maps.append(probs)

                attention_output = None
                if len(attention_heads) == 1:
                    attention_output = attention_heads[0]
                else:
                    # In the case where we have other sequences, we just
                    # concatenate them to the self-attention head before the
                    # projection.
                    attention_output = tf.concat(attention_heads, axis=-1)

                # Run a linear projection of `hidden_size` then add a
                # residual with `layer_input`.
                with tf.variable_scope('output'):
                    attention_output = tf.layers.dense(
                        attention_output,
                        hidden_size,
                        kernel_initializer=util.create_initializer(
                            initializer_range))
                    attention_output = util.dropout(
                        attention_output, hidden_dropout_prob)
                    attention_output = util.layer_norm(
                        attention_output + prev_output)

            # The activation is only applied to the 'intermediate' hidden
            # layer.
            with tf.variable_scope('intermediate'):
                intermediate_output = tf.layers.dense(
                    attention_output,
                    intermediate_size,
                    activation=intermediate_act_fn,
                    kernel_initializer=util.create_initializer(
                        initializer_range))

            # Down-project back to `hidden_size` then add the residual.
            with tf.variable_scope('output'):
                prev_output = tf.layers.dense(
                    intermediate_output,
                    hidden_size,
                    kernel_initializer=util.create_initializer(
                        initializer_range))
                prev_output = util.dropout(prev_output, hidden_dropout_prob)
                prev_output = util.layer_norm(prev_output + attention_output)
                all_layer_outputs.append(prev_output)

    attn_maps = tf.stack(attn_maps, 0)
    if do_return_all_layers:
        return tf.stack([
            util.reshape_from_matrix(layer, input_shape)
            for layer in all_layer_outputs
        ], 0), attn_maps
    else:
        return util.reshape_from_matrix(prev_output, input_shape), attn_maps
def __init__(self,
             bert_config,
             is_training,
             input_ids,
             input_mask,
             segment_ids,
             sample_weight=None,
             scope='bert',
             dtype=tf.float32,
             drop_pooler=False,
             cls_model='self-attention',
             label_size=2,
             speed=0.1,
             ignore_cls='0',
             **kwargs):
    super(FastBERTCLSDistillor, self).__init__()

    if not ignore_cls:
        ignore_cls = []
    if isinstance(ignore_cls, str):
        ignore_cls = ignore_cls.replace(' ', '').split(',')
        ignore_cls = list(map(int, ignore_cls))
    elif isinstance(ignore_cls, list):
        ignore_cls = list(map(int, ignore_cls))
    else:
        raise ValueError(
            '`ignore_cls` should be a list of child-classifier ids or '
            'a string separated with commas.')

    if not speed:
        raise ValueError(
            '`speed` should be a float number between `0` and `1`.')

    bert_config = copy.deepcopy(bert_config)
    bert_config.hidden_dropout_prob = 0.0
    bert_config.attention_probs_dropout_prob = 0.0

    input_shape = util.get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    max_seq_length = input_shape[1]

    with tf.variable_scope(scope):
        with tf.variable_scope('embeddings'):

            (self.embedding_output, self.embedding_table) = \
                self.embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=bert_config.vocab_size,
                    batch_size=batch_size,
                    max_seq_length=max_seq_length,
                    embedding_size=bert_config.hidden_size,
                    initializer_range=bert_config.initializer_range,
                    word_embedding_name='word_embeddings',
                    dtype=dtype,
                    trainable=False,
                    tilda_embeddings=None)

            # Add positional embeddings and token type embeddings
            # layer normalize and perform dropout.
            self.embedding_output = self.embedding_postprocessor(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=max_seq_length,
                hidden_size=bert_config.hidden_size,
                use_token_type=True,
                segment_ids=segment_ids,
                token_type_vocab_size=bert_config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=bert_config.initializer_range,
                max_position_embeddings=bert_config.max_position_embeddings,
                dropout_prob=bert_config.hidden_dropout_prob,
                dtype=dtype,
                trainable=False)

        with tf.variable_scope('encoder'):
            attention_mask = self.create_attention_mask_from_input_mask(
                input_mask, batch_size, max_seq_length, dtype=dtype)

            # stacked transformers
            (self.all_encoder_layers, self.all_cls_layers) = \
                self.dynamic_transformer_model(
                    is_training,
                    input_tensor=self.embedding_output,
                    input_mask=input_mask,
                    batch_size=batch_size,
                    max_seq_length=max_seq_length,
                    label_size=label_size,
                    attention_mask=attention_mask,
                    hidden_size=bert_config.hidden_size,
                    num_hidden_layers=bert_config.num_hidden_layers,
                    num_attention_heads=bert_config.num_attention_heads,
                    intermediate_size=bert_config.intermediate_size,
                    intermediate_act_fn=util.get_activation(
                        bert_config.hidden_act),
                    hidden_dropout_prob=bert_config.hidden_dropout_prob,
                    attention_probs_dropout_prob=
                    bert_config.attention_probs_dropout_prob,
                    initializer_range=bert_config.initializer_range,
                    dtype=dtype,
                    cls_model=cls_model,
                    speed=speed,
                    ignore_cls=ignore_cls)

        self.sequence_output = self.all_encoder_layers[-1]
        with tf.variable_scope('pooler'):
            first_token_tensor = self.sequence_output[:, 0, :]

            # trick: ignore the fully connected layer
            if drop_pooler:
                self.pooled_output = first_token_tensor
            else:
                self.pooled_output = tf.layers.dense(
                    first_token_tensor,
                    bert_config.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=False)

    # teacher classifier
    if bert_config.num_hidden_layers not in ignore_cls:
        with tf.variable_scope('cls/seq_relationship'):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[label_size, bert_config.hidden_size],
                initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=False)
            output_bias = tf.get_variable(
                'output_bias',
                shape=[label_size],
                initializer=tf.zeros_initializer(),
                trainable=False)

            logits = tf.matmul(
                self.pooled_output, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            probs = tf.nn.softmax(logits, axis=-1)

    # distillation
    if is_training:
        losses = []
        for cls_probs in self.all_cls_layers.values():

            # KL-Divergence
            per_example_loss = tf.reduce_sum(
                cls_probs * (tf.log(cls_probs) - tf.log(probs)), axis=-1)
            if sample_weight is not None:
                per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)
            loss = tf.reduce_mean(per_example_loss)
            losses.append(loss)

        distill_loss = tf.add_n(losses)
        self.total_loss = distill_loss
        self.losses['losses'] = distill_loss

    else:
        if bert_config.num_hidden_layers not in ignore_cls:
            self.all_cls_layers[bert_config.num_hidden_layers] = probs

        self.probs['probs'] = tf.concat(
            list(self.all_cls_layers.values()), axis=0, name='probs')