def __init__(self,
             is_training,
             input_tensor,
             label_ids,
             sample_weight=None,
             scope='mrc',
             name='',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Builds a start/end span head: probabilities, loss and predictions.

    Args:
      is_training: bool. Dropout is applied to `input_tensor` only when True.
      input_tensor: tensor whose last two static dimensions are read as
        (seq_length, hidden_size).
      label_ids: integer tensor; column 0 is used as the start index and
        column 1 as the end index.
      sample_weight: optional tensor multiplied into the per-example loss.
      scope: variable scope that holds `output_weights` and `output_bias`.
      name: key under which results are stored in `self.probs`,
        `self.losses` and `self.preds`.
      hidden_dropout_prob: dropout rate used when `is_training` is True.
      initializer_range: passed to `util.create_initializer` for the weights.
      trainable: whether the two created variables are trainable.
      **kwargs: forwarded to the parent constructor.
    '''
    super().__init__(**kwargs)

    static_dims = input_tensor.shape.as_list()
    seq_length = static_dims[-2]
    hidden_size = static_dims[-1]

    with tf.variable_scope(scope):
        weights = tf.get_variable(
            'output_weights',
            shape=[2, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        bias = tf.get_variable(
            'output_bias',
            shape=[2],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        dropout_rate = hidden_dropout_prob if is_training else 0.0
        hidden = util.dropout(input_tensor, dropout_rate)
        hidden = tf.reshape(hidden, [-1, hidden_size])

        # [batch * seq_length, 2] -> [batch, 2, seq_length]: row 0 scores
        # start positions, row 1 scores end positions.
        scores = tf.matmul(hidden, weights, transpose_b=True)
        scores = tf.nn.bias_add(scores, bias)
        scores = tf.reshape(scores, [-1, seq_length, 2])
        scores = tf.transpose(scores, [0, 2, 1])
        self.probs[name] = tf.nn.softmax(scores, axis=-1, name='probs')

        start_targets = tf.one_hot(
            label_ids[:, 0], depth=seq_length, dtype=tf.float32)
        end_targets = tf.one_hot(
            label_ids[:, 1], depth=seq_length, dtype=tf.float32)
        start_log_probs = tf.nn.log_softmax(scores[:, 0, :], axis=-1)
        end_log_probs = tf.nn.log_softmax(scores[:, 1, :], axis=-1)

        # Average of the start and end cross-entropies.
        start_term = -0.5 * tf.reduce_sum(
            start_targets * start_log_probs, axis=-1)
        end_term = 0.5 * tf.reduce_sum(
            end_targets * end_log_probs, axis=-1)
        per_example_loss = start_term - end_term
        if sample_weight is not None:
            per_example_loss = per_example_loss * sample_weight

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses[name] = per_example_loss

        start_preds = tf.expand_dims(
            tf.argmax(scores[:, 0, :], axis=-1), axis=-1)
        end_preds = tf.expand_dims(
            tf.argmax(scores[:, 1, :], axis=-1), axis=-1)
        self.preds[name] = tf.concat([start_preds, end_preds], axis=-1)
def get_timing_signal_1d_given_position(channels,
                                        position,
                                        min_timescale=1.0,
                                        max_timescale=1.0e4):
    """Get sinusoids of diff frequencies, with timing position given.

    Adapted from add_timing_signal_1d_given_position in
    //third_party/py/tensor2tensor/layers/common_attention.py

    Args:
      channels: scalar, size of timing embeddings to create. The number of
        different timescales is equal to channels / 2.
      position: a Tensor with shape [batch, seq_len]
      min_timescale: a float
      max_timescale: a float

    Returns:
      a Tensor of timing signals [batch, seq_len, channels]
    """
    num_timescales = channels // 2
    # With a single timescale (channels of 2 or 3) `num_timescales - 1` is
    # zero; dividing by it yields inf and then NaN once multiplied by the
    # zero index below. Clamp the denominator to at least 1.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        tf.maximum(tf.to_float(num_timescales) - 1, 1))
    inv_timescales = min_timescale * tf.exp(
        tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
    # [batch, seq_len, 1] * [1, 1, num_timescales]
    scaled_time = (tf.expand_dims(tf.to_float(position), 2) *
                   tf.expand_dims(tf.expand_dims(inv_timescales, 0), 0))
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
    # An odd `channels` leaves one column short; pad it with zeros.
    signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]])
    return signal
def gather_positions(sequence, positions):
    '''Picks, for every batch row, the entries of `sequence` at `positions`.

    Args:
      sequence: A [batch_size, seq_length] or
        [batch_size, seq_length, depth] tensor of values
      positions: A [batch_size, n_positions] tensor of indices

    Returns:
      A [batch_size, n_positions] or [batch_size, n_positions, depth] tensor
      holding the selected values
    '''
    dims = util.get_shape_list(sequence, expected_rank=[2, 3])
    has_depth = len(dims) == 3
    if has_depth:
        batch, length, depth = dims
    else:
        batch, length = dims
        depth = 1
        # Give rank-2 input a trailing unit axis so both cases share a path.
        sequence = tf.expand_dims(sequence, -1)

    # Row b of the flattened sequence starts at b * length, so adding that
    # offset turns per-row indices into indices of the flat tensor.
    row_offsets = tf.expand_dims(length * tf.range(batch), -1)
    flat_indices = tf.reshape(positions + row_offsets, [-1])
    flat_values = tf.reshape(sequence, [batch * length, depth])
    picked = tf.gather(flat_values, flat_indices)

    out_shape = [batch, -1, depth] if has_depth else [batch, -1]
    return tf.reshape(picked, out_shape)
def crf_unary_score(tag_indices, sequence_lengths, inputs):
    '''Sums, per sequence, the unary potentials of the given tags.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
        potentials.

    Returns:
      unary_scores: A [batch_size] vector of unary scores.
    '''
    n_batch = tf.shape(inputs)[0]
    n_steps = tf.shape(inputs)[1]
    n_tags = tf.shape(inputs)[2]

    flat_potentials = tf.reshape(inputs, [-1])

    # Index of tag 0 for every (batch, step) cell in the flattened tensor.
    batch_offsets = tf.expand_dims(tf.range(n_batch) * n_steps * n_tags, 1)
    step_offsets = tf.expand_dims(tf.range(n_steps) * n_tags, 0)
    cell_offsets = batch_offsets + step_offsets
    # Match the offsets' integer width to that of tag_indices.
    if tag_indices.dtype == tf.int64:
        cell_offsets = tf.cast(cell_offsets, tf.int64)
    flat_indices = tf.reshape(cell_offsets + tag_indices, [-1])

    per_step_scores = tf.reshape(
        tf.gather(flat_potentials, flat_indices), [n_batch, n_steps])

    # Zero out the steps beyond each sequence's true length before summing.
    valid_steps = tf.sequence_mask(sequence_lengths,
                                   maxlen=tf.shape(tag_indices)[1],
                                   dtype=tf.float32)
    return tf.reduce_sum(per_step_scores * valid_steps, 1)
def scatter_update(sequence, updates, positions):
    '''Writes `updates` into `sequence` at `positions`, row by row.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
        tensor
      updates: A tensor that is reshaped to [-1, depth] and scattered with
        one row per entry of `positions`, i.e. it holds
        batch_size*n_positions(*depth) values
      positions: A [batch_size, n_positions] tensor

    Returns:
      A tuple of two tensors. First is a [batch_size, seq_len] or
      [batch_size, seq_len, depth] tensor of 'sequence' with elements at
      'positions' replaced by the values at 'updates.' Updates to index 0
      are ignored. If there are duplicated positions the update is only
      applied once. Second is a [batch_size, seq_len] mask tensor of which
      inputs were updated.
    '''
    dims = util.get_shape_list(sequence, expected_rank=[2, 3])
    has_depth = len(dims) == 3
    if has_depth:
        batch, length, depth = dims
    else:
        batch, length = dims
        depth = 1
        sequence = tf.expand_dims(sequence, -1)
    n_positions = util.get_shape_list(positions)[1]

    # Convert per-row positions to indices into the flattened sequence.
    row_offsets = tf.expand_dims(length * tf.range(batch), -1)
    flat_indices = tf.reshape(positions + row_offsets, [-1, 1])
    flat_values = tf.reshape(updates, [-1, depth])
    scattered = tf.scatter_nd(flat_indices, flat_values,
                              [batch * length, depth])
    scattered = tf.reshape(scattered, [batch, length, depth])

    # hit_counts[b, l] counts the entries of `positions` that target (b, l).
    unit_hits = tf.ones([batch * n_positions], tf.int32)
    hit_counts = tf.scatter_nd(flat_indices, unit_hits, [batch * length])
    hit_counts = tf.reshape(hit_counts, [batch, length])
    # Column 0 is never updated: zero its count.
    beyond_first = tf.concat(
        [tf.zeros((batch, 1), tf.int32),
         tf.ones((batch, length - 1), tf.int32)], -1)
    hit_counts = hit_counts * beyond_first
    hit_counts_3d = tf.expand_dims(hit_counts, -1)

    # A position that was targeted several times is divided by its hit
    # count so duplicates do not inflate the written value.
    if sequence.dtype == tf.float32:
        hit_counts_3d = tf.cast(hit_counts_3d, tf.float32)
        scattered = scattered / tf.maximum(1.0, hit_counts_3d)
    else:
        assert sequence.dtype == tf.int32
        scattered = tf.divide(scattered, tf.maximum(1, hit_counts_3d))
        scattered = tf.cast(scattered, tf.int32)

    # Collapse counts to a 0/1 indicator of "was updated".
    updates_mask = tf.minimum(hit_counts, 1)
    updates_mask_3d = tf.minimum(hit_counts_3d, 1)

    kept = (1 - updates_mask_3d) * sequence
    written = updates_mask_3d * scattered
    updated_sequence = kept + written
    if not has_depth:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
def mask(inputs, key_masks=None, type=None):
    '''Masks padded keys or future positions in attention logits.

    inputs: 3d tensor. (h*N, T_q, T_k)
    key_masks: 2d tensor. (N, T_k). Only used for key masking; it is tiled
      along axis 0 to match `inputs`, and entries equal to 1 have a large
      negative number added to the corresponding logits.
    type: string. 'k' | 'key' | 'keys' for key masking,
      'f' | 'future' | 'right' for future (causal) masking.

    Returns a tensor shaped like `inputs`.

    Raises ValueError if `type` is not one of the values listed above.
    Query masking is not implemented.

    e.g.,
    >> inputs = tf.zeros([2, 2, 3], dtype=tf.float32)
    >> key_masks = tf.constant([[0., 0., 1.],
                                [0., 1., 1.]])
    >> mask(inputs, key_masks=key_masks, type='key')
    array([[[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],

       [[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
    '''
    padding_num = -2 ** 32 + 1
    if type in ('k', 'key', 'keys'):
        key_masks = tf.to_float(key_masks)
        key_masks = tf.tile(
            key_masks,
            [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1])  # (h*N, seqlen)
        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
        outputs = inputs + key_masks * padding_num
    elif type in ('f', 'future', 'right'):
        diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
        tril = tf.linalg.LinearOperatorLowerTriangular(
            diag_vals).to_dense()  # (T_q, T_k)
        future_masks = tf.tile(
            tf.expand_dims(tril, 0),
            [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
        paddings = tf.ones_like(future_masks) * padding_num
        # Entries above the diagonal (zeros in the triangular matrix) are
        # replaced by the padding value.
        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
    else:
        # Previously this only printed a hint and then failed on the
        # unassigned `outputs`; fail with a descriptive error instead.
        raise ValueError(
            'Check if you entered type correctly! Got type=%r; expected one '
            "of 'k', 'key', 'keys', 'f', 'future', 'right'." % (type,))
    return outputs
def __init__(self,
             xlnet_config,
             is_training,
             input_ids,
             seg_ids,
             input_mask,
             mems,
             perm_mask,
             target,
             target_mask,
             target_mapping,
             inp_q,
             sample_weight=None,
             **kwargs):
    '''Builds an XLNet encoder plus a tied-weight language-model loss.

    Args:
      xlnet_config: model configuration; `n_token` and `d_model` are read
        here and the object is passed to `XLNetEncoder`.
      is_training: bool. Selects 0.1 (True) or 0.0 (False) for both the
        `dropout` and `dropatt` fields of the run config.
      input_ids, seg_ids, input_mask, mems, perm_mask, target_mapping,
        inp_q: passed unchanged to `XLNetEncoder`.
      target: passed to `lm_loss` as the prediction target.
      target_mask: multiplied into the loss; its sum normalizes
        `self.total_loss`.
      sample_weight: optional tensor; cast to float32, given a trailing
        unit axis and multiplied into the loss.
      **kwargs: forwarded to `XLNetEncoder`.
    '''
    super().__init__()

    drop_rate = 0.1 if is_training else 0.0
    run_config = XLNetRunConfig(
        is_training=is_training,
        bi_data=True,
        use_tpu=False,
        use_bfloat16=False,
        dropout=drop_rate,
        dropatt=drop_rate,
        init='normal',
        init_range=0.1,
        init_std=0.02,
        clamp_len=-1)
    encoder = XLNetEncoder(
        xlnet_config=xlnet_config,
        is_training=is_training,
        input_ids=input_ids,
        seg_ids=seg_ids,
        input_mask=input_mask,
        mems=mems,
        perm_mask=perm_mask,
        target_mapping=target_mapping,
        inp_q=inp_q,
        **kwargs)

    with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
        token_loss, preds = lm_loss(
            hidden=encoder.get_sequence_output(),
            target=target,
            n_token=xlnet_config.n_token,
            d_model=xlnet_config.d_model,
            initializer=encoder.get_initializer(),
            lookup_table=encoder.get_embedding_table(),
            tie_weight=True,
            bi_data=run_config.bi_data,
            use_tpu=run_config.use_tpu)
        if sample_weight is not None:
            weight_column = tf.expand_dims(
                tf.cast(sample_weight, dtype=tf.float32), axis=-1)
            token_loss = token_loss * weight_column

        # Average of the loss over the positions selected by target_mask.
        masked_sum = tf.reduce_sum(token_loss * target_mask)
        self.total_loss = masked_sum / tf.reduce_sum(target_mask)
        self.losses['losses'] = token_loss * target_mask
        self.preds['preds'] = preds
        self.preds['mask'] = target_mask
def embedding_lookup(self,
                     input_ids,
                     vocab_size,
                     batch_size,
                     max_seq_length,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name='word_embeddings',
                     dtype=tf.float32,
                     trainable=True,
                     tilda_embeddings=None):
    '''Gathers embedding rows for `input_ids`.

    Args:
      input_ids: id tensor; a rank-2 input gets a trailing unit axis before
        being flattened.
      vocab_size: number of rows of a newly created table.
      batch_size, max_seq_length: leading dimensions of the returned
        tensor.
      embedding_size: width of the table and of the returned vectors.
      initializer_range: passed to `util.create_initializer` for a newly
        created table.
      word_embedding_name: variable name of a newly created table.
      dtype: dtype of a newly created table.
      trainable: whether a newly created table is trainable.
      tilda_embeddings: optional existing table; when given it is used
        as-is and no variable is created.

    Returns:
      A tuple `(output, embedding_table)` where `output` is reshaped to
      [batch_size, max_seq_length, embedding_size].
    '''
    if input_ids.shape.ndims == 2:
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    if tilda_embeddings is None:
        table = tf.get_variable(
            name=word_embedding_name,
            shape=[vocab_size, embedding_size],
            initializer=util.create_initializer(initializer_range),
            dtype=dtype,
            trainable=trainable)
    else:
        table = tilda_embeddings

    flat_ids = tf.reshape(input_ids, [-1])
    vectors = tf.gather(table, flat_ids, name='embedding_look_up')
    vectors = tf.reshape(
        vectors, [batch_size, max_seq_length, embedding_size])
    return (vectors, table)
def __call__(self, inputs, state, scope=None):
    '''Runs one step of the CRF forward recursion in log space.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous alpha
        values.
      scope: Unused variable scope of this cell.

    Returns:
      new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices
      values containing the new alpha values.
    '''
    # [batch_size, num_tags, 1] + [1, num_tags, num_tags]: broadcasting
    # pairs every previous tag (axis 1) with every next tag (axis 2). In
    # log space this addition is the product of the previous alphas with
    # the binary potentials.
    prev_alphas = tf.expand_dims(state, 2)
    pairwise_scores = prev_alphas + self._transition_params

    # Log-sum over the previous tag, then add the unary potentials.
    summed_over_prev = tf.reduce_logsumexp(pairwise_scores, [1])
    next_alphas = inputs + summed_over_prev

    # The same tensor is returned as output and as state; the output only
    # exists to satisfy the RNN cell interface.
    return next_alphas, next_alphas
def __init__(self, transition_params):
    '''Stores the transition matrix for use by the forward cell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary
        potentials. It is kept with a leading unit axis, i.e. as
        [1, num_tags, num_tags], ready for the broadcast addition done in
        the cell.
    '''
    batched_transitions = tf.expand_dims(transition_params, 0)
    self._transition_params = batched_transitions

    transition_dims = util.get_shape_list(transition_params)
    self._num_tags = transition_dims[0]
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name='word_embeddings',
                     use_one_hot_embeddings=False):
    '''Looks up words embeddings for id tensor.

    Args:
      input_ids: int32 Tensor of shape [batch_size, seq_length] containing
        word ids, or a rank-3 tensor of shape
        [batch_size, seq_length, vocab_size] that is multiplied with the
        embedding table directly.
      vocab_size: int. Size of the embedding vocabulary.
      embedding_size: int. Width of the word embeddings.
      initializer_range: float. Embedding initialization range.
      word_embedding_name: string. Name of the embedding table.
      use_one_hot_embeddings: bool. If True, use one-hot method for word
        embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is
        better for TPUs. Ignored for rank-3 input.

    Returns:
      A tuple `(output, embedding_table)`: `output` is a float Tensor of
      shape [batch_size, seq_length, embedding_size] and `embedding_table`
      is the [vocab_size, embedding_size] variable.
    '''
    # If the input is a 2D tensor of shape [batch_size, seq_length], we
    # reshape to [batch_size, seq_length, 1].
    original_dims = input_ids.shape.ndims
    if original_dims == 2:
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    embedding_table = tf.get_variable(
        name=word_embedding_name,
        shape=[vocab_size, embedding_size],
        initializer=util.create_initializer(initializer_range))

    if original_dims == 3:
        input_shape = util.get_shape_list(input_ids)
        # Flatten to a matrix before the product. The reshape result used
        # to be discarded, so matmul received the rank-3 tensor together
        # with the rank-2 table.
        flat_inputs = tf.reshape(input_ids, [-1, input_shape[-1]])
        output = tf.matmul(flat_inputs, embedding_table)
        output = tf.reshape(
            output, [input_shape[0], input_shape[1], embedding_size])
    else:
        if use_one_hot_embeddings:
            flat_input_ids = tf.reshape(input_ids, [-1])
            one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
            output = tf.matmul(one_hot_input_ids, embedding_table)
        else:
            output = tf.nn.embedding_lookup(embedding_table, input_ids)

        input_shape = util.get_shape_list(input_ids)

        # Fold the trailing id axis into the embedding axis.
        output = tf.reshape(
            output, input_shape[0:-1] + [input_shape[-1] * embedding_size])
    return output, embedding_table
def scaled_dot_product_attention(Q, K, V, key_masks,
                                 causality=False,
                                 dropout_rate=0.,
                                 training=True,
                                 scope='scaled_dot_product_attention'):
    '''Scaled dot-product attention. See 3.2.1.

    Q: Packed queries. 3d tensor. [N, T_q, d_k].
    K: Packed keys. 3d tensor. [N, T_k, d_k].
    V: Packed values. 3d tensor. [N, T_k, d_v].
    key_masks: A 2d tensor with shape of [N, key_seqlen]
    causality: If True, applies masking for future blinding
    dropout_rate: A floating point number of [0, 1].
    training: boolean for controlling dropout
    scope: Optional scope for `variable_scope`.

    Returns the context vectors, [N, T_q, d_v]. Also records an image
    summary named 'attention' for the first attention map.
    '''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        key_depth = Q.get_shape().as_list()[-1]

        # Query-key similarities, scaled by sqrt(d_k).  (N, T_q, T_k)
        keys_transposed = tf.transpose(K, [0, 2, 1])
        scores = tf.matmul(Q, keys_transposed)
        scores = scores / key_depth ** 0.5

        # Suppress padded keys, then (optionally) future positions.
        scores = mask(scores, key_masks=key_masks, type='key')
        if causality:
            scores = mask(scores, type='future')

        weights = tf.nn.softmax(scores)

        # Log the first attention map, keys along the first image axis.
        attention_map = tf.transpose(weights, [0, 2, 1])
        tf.summary.image('attention',
                         tf.expand_dims(attention_map[:1], -1))

        weights = tf.layers.dropout(
            weights, rate=dropout_rate, training=training)

        # Weighted sum of the values.  (N, T_q, d_v)
        context = tf.matmul(weights, V)

    return context
def softmax_kernel_transformation(data,
                                  is_query,
                                  projection_matrix=None,
                                  numerical_stabilizer=0.000001):
    '''Computes random features for the softmax kernel using FAVOR+ mechanism.

    Computes random features for the softmax kernel using FAVOR+ mechanism
    from https://arxiv.org/pdf/2009.14794.pdf.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - attention dimensions, H - heads, D - features.
      is_query: indicates whether input data is a query or key tensor.
      projection_matrix: random Gaussian matrix of shape [M, D], where M
        stands for the number of random features and each D x D sub-block
        has pairwise orthogonal rows. Required despite the `None` default.
      numerical_stabilizer: small positive constant for numerical
        stability.

    Returns:
      Corresponding kernel feature map, of shape [B, L, H, M].

    Raises:
      ValueError: if `projection_matrix` is None.
    '''
    if projection_matrix is None:
        # The projection below cannot be computed without the matrix; fail
        # early with a clear message instead of inside einsum.
        raise ValueError(
            'softmax_kernel_transformation requires a projection_matrix.')

    # D ** -0.25: applied to both queries and keys so that their product
    # carries the usual 1 / sqrt(D) attention scaling.
    data_normalizer = tf.math.rsqrt(
        1 / tf.math.rsqrt(tf.cast(data.shape[-1], tf.float32)))
    ratio = tf.math.rsqrt(
        tf.cast(projection_matrix.shape[0], tf.float32))

    # Project the *normalized* data. The squared-norm term below is scaled
    # by data_normalizer ** 2, so the projection has to use the same
    # scaling for exp(w.x - |x|^2 / 2) to be a softmax-kernel feature.
    data_dash = tf.einsum('blhd,md->blhm',
                          data_normalizer * data, projection_matrix)

    # |x|^2 / 2 of the normalized data, kept broadcastable over M.
    diag_data = tf.math.square(data)
    diag_data = tf.math.reduce_sum(
        diag_data, axis=tf.keras.backend.ndim(data) - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = tf.expand_dims(
        diag_data, axis=tf.keras.backend.ndim(data) - 1)

    if is_query:
        # Queries: subtract the max over the feature axis before exp.
        last_dims_t = (len(data_dash.shape) - 1, )
        data_dash = ratio * (tf.math.exp(
            data_dash - diag_data - tf.math.reduce_max(
                data_dash, axis=last_dims_t, keepdims=True)) +
            numerical_stabilizer)
    else:
        # Keys: subtract the single global maximum before exp.
        data_dash = ratio * (tf.math.exp(
            data_dash - diag_data - tf.math.reduce_max(data_dash)) +
            numerical_stabilizer)

    return data_dash
def _cls_forward(self,
                 is_training,
                 input_tensor,
                 input_mask,
                 label_ids,
                 bert_config,
                 batch_size,
                 max_seq_length,
                 prob,
                 scope,
                 name,
                 sample_weight=None,
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02):
    '''Adds a binary per-token classification head and its loss.

    Args:
      is_training: unused here.
      input_tensor: features fed to a 2-unit dense layer.
      input_mask: mask over the last loss axis; each row is normalized by
        its own sum, so the per-example loss is a masked mean.
      label_ids: integer labels, one-hot encoded with depth 2.
      bert_config: only `initializer_range` is read.
      batch_size, max_seq_length: unused here.
      prob: the loss is added to `self.total_loss` only when `prob != 0`.
      scope: variable scope for the dense layer.
      name: prefix of the keys written to `self.losses` / `self.preds`.
      sample_weight: optional per-example weight, multiplied element-wise
        into the per-example loss (assumes one weight per example, as the
        loss is reduced to one value per example -- TODO confirm against
        callers).
      hidden_dropout_prob, initializer_range: unused here; the initializer
        range is taken from `bert_config`.

    Side effects:
      Updates `self.total_loss` (if `prob != 0`), and stores
      `self.losses[name + '_loss']` and `self.preds[name + '_preds']`.
    '''
    with tf.variable_scope(scope):
        logits = tf.layers.dense(
            input_tensor,
            2,
            kernel_initializer=util.create_initializer(
                bert_config.initializer_range),
            trainable=True)

        # loss
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=2)
        per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)

        # Weight every token by 1 / (number of unmasked tokens in its row).
        input_mask = tf.cast(input_mask, tf.float32)
        per_token_loss *= input_mask / tf.reduce_sum(
            input_mask, keepdims=True, axis=-1)
        per_example_loss = tf.reduce_sum(per_token_loss, axis=-1)
        if sample_weight is not None:
            # The loss has already been reduced over the token axis, so the
            # weight must be multiplied element-wise. Adding a trailing
            # axis to the weight (as before) broadcast a [B] loss against a
            # [B, 1] weight into a [B, B] matrix.
            per_example_loss *= tf.cast(sample_weight, tf.float32)

        if prob != 0:
            self.total_loss += tf.reduce_mean(per_example_loss)
        self.losses[name + '_loss'] = per_example_loss
        self.preds[name + '_preds'] = tf.argmax(logits, axis=-1)
def positional_encoding(inputs,
                        maxlen,
                        masking=True,
                        scope='positional_encoding'):
    '''Sinusoidal Positional_Encoding. See 3.5

    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T
    masking: Boolean. If True, elements where `inputs` equals 0 keep the
      input value (zero) instead of the encoding.
    scope: Optional scope for `variable_scope`.

    returns
    3d float tensor that has the same shape as inputs.
    '''
    embed_dim = inputs.get_shape().as_list()[-1]  # static
    batch = tf.shape(inputs)[0]  # dynamic
    steps = tf.shape(inputs)[1]  # dynamic
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):

        # Position ids 0..T-1, repeated for every batch row.  (N, T)
        position_ids = tf.tile(
            tf.expand_dims(tf.range(steps), 0), [batch, 1])

        # Argument of the sinusoids: pos / 10000^((i - i % 2) / E), so the
        # columns 2i and 2i+1 share one divisor.
        divisors = [
            np.power(10000, (col - col % 2) / embed_dim)
            for col in range(embed_dim)]
        table = np.array([
            [pos / divisor for divisor in divisors]
            for pos in range(maxlen)])

        # Sine on the even columns, cosine on the odd columns.
        table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
        table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
        table = tf.convert_to_tensor(table, tf.float32)  # (maxlen, E)

        encoded = tf.nn.embedding_lookup(table, position_ids)

        if masking:
            encoded = tf.where(tf.equal(inputs, 0), inputs, encoded)

        return tf.to_float(encoded)
def favor_attention(query,
                    key,
                    value,
                    kernel_transformation,
                    causal,
                    projection_matrix=None):
    '''Computes FAVOR normalized attention.

    Args:
      query: query tensor.
      key: key tensor.
      value: value tensor.
      kernel_transformation: transformation used to get finite kernel
        features; called as `kernel_transformation(data, is_query,
        projection_matrix)`.
      causal: whether attention is causal or not.
      projection_matrix: projection matrix to be used.

    Returns:
      FAVOR normalized attention.
    '''
    to_length_major = [1, 0, 2, 3]

    q_features = kernel_transformation(query, True, projection_matrix)  # [B,L,H,M]
    k_features = kernel_transformation(key, False, projection_matrix)  # [B,L,H,M]
    q_features = tf.transpose(q_features, to_length_major)  # [L,B,H,M]
    k_features = tf.transpose(k_features, to_length_major)  # [L,B,H,M]
    value = tf.transpose(value, to_length_major)  # [L,B,H,D]

    numerator_fn, denominator_fn = (
        (causal_numerator, causal_denominator) if causal
        else (noncausal_numerator, noncausal_denominator))
    numerator = numerator_fn(q_features, k_features, value)
    normalizer = denominator_fn(q_features, k_features)

    # Back to batch-major layout, then give the normalizer a trailing unit
    # axis so it broadcasts over the last axis of the numerator.
    numerator = tf.transpose(numerator, [1, 0, 2, 3])
    normalizer = tf.transpose(normalizer, [1, 0, 2])
    normalizer = tf.expand_dims(normalizer, len(normalizer.shape))
    return numerator / normalizer
def expand_tile(value, size):
    '''Prepends a new axis to `value` and repeats it `size` times.'''
    tensor = tf.convert_to_tensor(value, name='value')
    rank = tensor.shape.ndims
    with_leading_axis = tf.expand_dims(tensor, axis=0)
    # Repeat only along the new leading axis; leave the others untouched.
    multiples = [size] + [1] * rank
    return tf.tile(with_leading_axis, multiples)
def __init__(self,
             hparams,
             is_training,
             input_ids,
             sample_weight=None,
             scope='model',
             given=1,
             use_tilda_embedding=False,
             **kwargs):
    '''Builds a left-to-right language model with its loss and predictions.

    Args:
      hparams: hyper-parameters; `n_predict`, `n_vocab`, `n_embed`,
        `n_ctx` and `n_layer` are read here and the object is passed to
        `block`.
      is_training: bool. If True the model is run once on `input_ids`;
        otherwise tokens are generated greedily, one forward pass per new
        token.
      input_ids: rank-2 integer tensor of token ids.
      sample_weight: optional per-example weight, multiplied element-wise
        into the per-example loss.
      scope: variable scope of the model.
      given: number of leading tokens of `input_ids` kept as the prompt
        when generating.
      use_tilda_embedding: if True, the existing variable
        'tilda_embeddings' is used as the token embedding table.
      **kwargs: accepted but unused here.

    Side effects:
      Sets `self.preds['LM']`, `self.losses['LM']` and `self.total_loss`.
    '''
    super().__init__()

    batch_size = util.get_shape_list(input_ids, expected_rank=2)[0]
    max_seq_length = hparams.n_predict

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):

        def _forward(input_ids, past=None):
            # Returns (logits [batch, sequence, n_vocab], stacked presents).
            batch, sequence = shape_list(input_ids)

            if tilda_embeddings is None:
                wte = tf.get_variable(
                    'word_embeddings', [hparams.n_vocab, hparams.n_embed],
                    initializer=tf.random_normal_initializer(stddev=0.02))
            else:
                wte = tilda_embeddings
            wpe = tf.get_variable(
                'wpe', [hparams.n_ctx, hparams.n_embed],
                initializer=tf.random_normal_initializer(stddev=0.01))
            past_length = 0 if past is None else tf.shape(past)[-2]
            h = (tf.gather(wte, input_ids) +
                 tf.gather(wpe, positions_for(input_ids, past_length)))

            # stacked transformer layers
            presents = []
            pasts = tf.unstack(past, axis=1) if past is not None else \
                [None] * hparams.n_layer
            assert len(pasts) == hparams.n_layer
            for layer, layer_past in enumerate(pasts):
                h, present = block(
                    h, 'h%d' % layer, past=layer_past, hparams=hparams)
                presents.append(present)
            present = tf.stack(presents, axis=1)
            h = norm(h, 'ln_f')

            # Language model loss.  Do tokens <n predict token n?
            # The output projection reuses the token embedding table.
            h_flat = tf.reshape(h, [batch * sequence, hparams.n_embed])
            logits = tf.matmul(h_flat, wte, transpose_b=True)
            logits = tf.reshape(
                logits, [batch, sequence, hparams.n_vocab])

            return logits, present

        # convert to labels: the next token at every position, with id 0
        # appended for the last position
        label_ids = tf.concat(
            [input_ids[:, 1:],
             tf.zeros([batch_size, 1], dtype=tf.int32)], axis=-1)

        # forward once
        if is_training:
            (logits, _) = _forward(input_ids)

            self.preds['LM'] = tf.argmax(logits, axis=-1)

        # forward loop
        else:
            input_ids = input_ids[:, 0:given]

            for cur_length in range(given, max_seq_length + 1):
                (logits, _) = _forward(input_ids)

                # Greedy choice for the position just after the current
                # prefix, appended to the prefix for the next pass.
                pred_ids = tf.argmax(
                    logits[:, cur_length - 1:cur_length, :], axis=-1)
                pred_ids = tf.cast(pred_ids, tf.int32)
                input_ids = tf.concat([input_ids, pred_ids], axis=-1)

            self.preds['LM'] = input_ids

        # loss (in the generation branch `logits` comes from the last pass
        # of the loop above); labels equal to 0 are excluded
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=hparams.n_vocab)
        per_token_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
        per_example_loss = \
            tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
            tf.reduce_sum(label_mask, axis=-1)
        if sample_weight is not None:
            # The loss is already reduced to one value per example, so the
            # weight is multiplied element-wise. Adding a trailing axis to
            # the weight (as before) broadcast a [B] loss against a [B, 1]
            # weight into a [B, B] matrix.
            per_example_loss *= tf.cast(sample_weight, tf.float32)

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses['LM'] = per_example_loss
def __init__(self,
             vocab_size,
             is_training,
             input_ids,
             input_mask,
             segment_ids,
             sample_weight=None,
             reduced_size=64,
             topic_size=1024,
             hidden_size=768,
             num_hidden_layers=12,
             num_attention_heads=12,
             bias=0,
             scope='vae',
             trainable=True,
             **kwargs):
    '''Builds a transformer-based variational auto-encoder over token ids.

    The graph is: embeddings -> transformer encoder -> projection to
    `reduced_size` per token -> dense layers `miu` and `sigma` of width
    `topic_size` -> sampled latent -> dense decoder -> vocabulary logits
    that share the input embedding table.

    Args:
      vocab_size: passed to `Config`.
      is_training: bool. When False both dropout probabilities of the
        config are set to 0 and the latent noise is drawn from a uniform
        distribution on [-bias, bias] instead of a normal distribution.
      input_ids: rank-2 integer tensor of token ids; also the
        reconstruction target.
      input_mask: mask over tokens; its row sums are used as lengths.
      segment_ids: passed to `embedding_postprocessor`.
      sample_weight: optional tensor; gets a trailing unit axis and is
        multiplied into the per-token reconstruction loss.
      reduced_size: per-token width after the encoder projection.
      topic_size: width of `miu` and `sigma`.
      hidden_size, num_hidden_layers, num_attention_heads: passed to
        `Config`.
      bias: half-width of the uniform noise used when not training.
      scope: outermost variable scope.
      trainable: whether created variables are trainable.
      **kwargs: only 'use_tilda_embedding' is read.

    Side effects:
      Sets `self.embedding_output`, `self.embedding_table`,
      `self.all_encoder_layers`, `self.all_decoder_layers`,
      `self.probs['miu']`, `self.probs['sigma']`, `self.preds['preds']`,
      `self.losses['losses']` and `self.total_loss`.
    '''
    super().__init__()

    # freeze parameters
    config = Config(
        vocab_size,
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads)
    # Disable dropout outside of training.
    if not is_training:
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0

    input_shape = util.get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    use_tilda_embedding = kwargs.get('use_tilda_embedding')
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):
        with tf.variable_scope('embeddings'):

            # Token embeddings; the table is kept because it is reused as
            # the output projection in the reconstruction head.
            (self.embedding_output, self.embedding_table) = \
                self.embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=config.vocab_size,
                    batch_size=batch_size,
                    max_seq_length=seq_length,
                    embedding_size=config.hidden_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name='word_embeddings',
                    tilda_embeddings=tilda_embeddings,
                    trainable=trainable)

            # Adds token-type and position embeddings (both enabled).
            self.embedding_output = self.embedding_postprocessor(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=seq_length,
                hidden_size=config.hidden_size,
                use_token_type=True,
                segment_ids=segment_ids,
                token_type_vocab_size=config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=config.initializer_range,
                max_position_embeddings=config.max_position_embeddings,
                dropout_prob=config.hidden_dropout_prob,
                trainable=trainable)

        with tf.variable_scope('encoder'):

            # stacked transformer
            attention_mask = self.create_attention_mask_from_input_mask(
                input_mask, batch_size, seq_length)

            self.all_encoder_layers = self.transformer_model(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=seq_length,
                attention_mask=attention_mask,
                hidden_size=config.hidden_size,
                num_hidden_layers=config.num_hidden_layers,
                num_attention_heads=config.num_attention_heads,
                intermediate_size=config.intermediate_size,
                intermediate_act_fn=util.get_activation(config.hidden_act),
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=\
                    config.attention_probs_dropout_prob,
                initializer_range=config.initializer_range,
                trainable=trainable)

            # projection
            with tf.variable_scope('projection'):
                # Shrink the last encoder layer to `reduced_size` features
                # per token, then flatten each example to one row.
                transformer_output = tf.layers.dense(
                    self.all_encoder_layers[-1],
                    reduced_size,
                    activation=util.gelu,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    trainable=trainable)
                transformer_output = tf.reshape(transformer_output,
                                                [batch_size, -1])

                # Per-example length = number of set entries in input_mask.
                input_length = tf.reduce_sum(input_mask, axis=-1)
                input_length = tf.cast(input_length, tf.float32)
                input_length_1d = tf.reshape(input_length, [batch_size])
                input_length_2d = tf.reshape(input_length, [batch_size, 1])

                # Keep only the first length * reduced_size features of
                # each row and rescale them by seq_length / length.
                broadcast_mask = tf.sequence_mask(
                    tf.multiply(input_length_1d, reduced_size),
                    seq_length * reduced_size,
                    dtype=tf.float32)
                broadcast_mask = tf.multiply(broadcast_mask,
                                             seq_length / input_length_2d)
                transformer_output *= broadcast_mask

                # latent space
                miu = tf.layers.dense(
                    transformer_output,
                    topic_size,
                    activation='tanh',
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    name='miu',
                    trainable=trainable)
                sigma = tf.layers.dense(
                    transformer_output,
                    topic_size,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    name='sigma',
                    trainable=trainable)
                self.probs['miu'] = miu
                self.probs['sigma'] = sigma

        with tf.variable_scope('decoder'):
            with tf.variable_scope('projection'):

                # reparameterization: latent = miu + exp(sigma) * noise.
                # Outside of training the noise is uniform on
                # [-bias, bias]; with the default bias=0 both bounds are 0.
                if is_training:
                    noise = tf.random_normal([batch_size, topic_size])
                else:
                    noise = tf.random_uniform([batch_size, topic_size],
                                              minval=-bias,
                                              maxval=bias)
                decoder_input = miu + tf.exp(sigma) * noise

                # projection back to seq_length * reduced_size features
                decoder_input = tf.layers.dense(
                    decoder_input,
                    seq_length * reduced_size,
                    activation=util.gelu,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    trainable=trainable)
                intermediate_input = tf.reshape(
                    decoder_input, [-1, seq_length, reduced_size])
                intermediate_input = util.layer_norm(intermediate_input,
                                                     trainable=trainable)
                intermediate_input = util.dropout(
                    intermediate_input, config.hidden_dropout_prob)

            # MLP
            with tf.variable_scope('intermediate'):
                intermediate_output = tf.layers.dense(
                    intermediate_input,
                    4 * reduced_size,
                    activation=util.gelu,
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
            with tf.variable_scope('output'):
                decoder_output = tf.layers.dense(
                    intermediate_output,
                    config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
                decoder_output = util.layer_norm(decoder_output,
                                                 trainable=trainable)
                decoder_output = util.dropout(decoder_output,
                                              config.hidden_dropout_prob)
            # NOTE(review): the first assignment is overwritten by the
            # second one immediately, so only `decoder_output` is kept --
            # confirm which of the two lists is intended.
            self.all_decoder_layers = [intermediate_output, decoder_output]
            self.all_decoder_layers = [decoder_output]

        # reconstruction
        with tf.variable_scope('cls/predictions'):
            with tf.variable_scope('transform'):
                input_tensor = tf.layers.dense(
                    decoder_output,
                    units=config.hidden_size,
                    activation=util.get_activation(config.hidden_act),
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
                input_tensor = util.layer_norm(input_tensor,
                                               trainable=trainable)

            # The output projection shares the input embedding table.
            output_weights = self.embedding_table
            output_bias = tf.get_variable('output_bias',
                                          shape=[config.vocab_size],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)
            flatten_input_tensor = tf.reshape(input_tensor,
                                              [-1, config.hidden_size])

            logits = tf.matmul(flatten_input_tensor,
                               output_weights,
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            logits = tf.reshape(logits,
                                [batch_size, seq_length, config.vocab_size])
            probs = tf.nn.softmax(logits, axis=-1, name='probs')
            lm_log_probs = tf.nn.log_softmax(logits, axis=-1)

            self.preds['preds'] = tf.argmax(probs, axis=-1)

            # Cross-entropy of every position against the input token.
            one_hot_labels = tf.one_hot(input_ids,
                                        depth=config.vocab_size,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(lm_log_probs * one_hot_labels,
                                              axis=[-1])
            if sample_weight is not None:
                per_example_loss *= tf.expand_dims(sample_weight, axis=-1)

            # Reconstruction loss plus two regularizers on the latent
            # parameters: mean(miu^2) and mean(exp(sigma) - sigma - 1).
            self.total_loss = (tf.reduce_mean(per_example_loss) +
                               tf.reduce_mean(tf.square(miu)) +
                               tf.reduce_mean(tf.exp(sigma) - sigma - 1))
            self.losses['losses'] = per_example_loss
def attention_layer(self,
                    from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=12,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_max_seq_length=None,
                    to_max_seq_length=None,
                    dtype=tf.float32,
                    trainable=True):
    '''Multi-headed attention from `from_tensor` to `to_tensor`.

    Both inputs are flattened to matrices, projected into query / key /
    value tensors by dense layers named 'query', 'key' and 'value', and
    combined with scaled dot-product attention.

    Dimension shorthand used in the comments below:
      B = batch size (number of sequences)
      F = `from_tensor` sequence length
      T = `to_tensor` sequence length
      N = `num_attention_heads`
      H = `size_per_head`

    Args:
      from_tensor: Tensor the queries are computed from.
      to_tensor: Tensor the keys and values are computed from.
      attention_mask: (optional) Tensor that gets a head axis inserted at
        position 1 and is then added to the [B, N, F, T] scores; entries
        equal to 0 contribute an additive -10000.0, entries equal to 1
        contribute nothing.
      num_attention_heads: int. Number of attention heads.
      size_per_head: int. Size of each attention head.
      query_act: (optional) Activation of the query projection.
      key_act: (optional) Activation of the key projection.
      value_act: (optional) Activation of the value projection.
      attention_probs_dropout_prob: Dropout probability handed to
        `util.dropout` for the attention probabilities.
      initializer_range: Range handed to `util.create_initializer` for the
        projection kernels.
      do_return_2d_tensor: bool. If True the context is reshaped to
        [B*F, N*H], otherwise to [B, F, N*H].
      batch_size: B. Always read by the reshapes, so it has to be provided.
      from_max_seq_length: F. Has to be provided.
      to_max_seq_length: T. Has to be provided.
      dtype: dtype the mask is cast to before it is turned into a bias.
      trainable: Whether the projection variables are trainable.

    Returns:
      A tuple `(context_layer, attention_scores)`. `attention_scores` are
      the scaled (and, if a mask was given, masked) pre-softmax scores of
      shape [B, N, F, T].
    '''
    projection_width = num_attention_heads * size_per_head

    def _project(matrix, activation, name):
        # Dense projection to N*H units; one fresh initializer per layer.
        return tf.layers.dense(
            matrix,
            projection_width,
            activation=activation,
            name=name,
            kernel_initializer=util.create_initializer(initializer_range),
            trainable=trainable)

    def _split_heads(flat_tensor, max_seq_length):
        # [B*S, N*H] -> [B, S, N, H] -> [B, N, S, H]
        per_head = tf.reshape(
            flat_tensor,
            [batch_size, max_seq_length, num_attention_heads, size_per_head])
        return tf.transpose(per_head, [0, 2, 1, 3])

    from_matrix = util.reshape_to_matrix(from_tensor)
    to_matrix = util.reshape_to_matrix(to_tensor)

    query = _project(from_matrix, query_act, 'query')    # [B*F, N*H]
    key = _project(to_matrix, key_act, 'key')            # [B*T, N*H]
    value = _project(to_matrix, value_act, 'value')      # [B*T, N*H]

    query = _split_heads(query, from_max_seq_length)     # [B, N, F, H]
    key = _split_heads(key, to_max_seq_length)           # [B, N, T, H]

    # Raw attention scores: scaled dot product of queries and keys.
    scale = 1.0 / math.sqrt(float(size_per_head))
    attention_scores = tf.multiply(
        tf.matmul(query, key, transpose_b=True), scale)  # [B, N, F, T]

    if attention_mask is not None:
        # [B, F, T] -> [B, 1, F, T] so the mask broadcasts over heads.
        expanded_mask = tf.expand_dims(attention_mask, axis=[1])
        attention_scores += (1.0 - tf.cast(expanded_mask, dtype)) * -10000.0

    # Normalize the scores to probabilities, then drop out whole
    # attended-to tokens.
    attention_probs = util.dropout(
        tf.nn.softmax(attention_scores, axis=-1),
        attention_probs_dropout_prob)                    # [B, N, F, T]

    value = _split_heads(value, to_max_seq_length)       # [B, N, T, H]
    context = tf.matmul(attention_probs, value)          # [B, N, F, H]
    context = tf.transpose(context, [0, 2, 1, 3])        # [B, F, N, H]

    if do_return_2d_tensor:
        output_shape = [batch_size * from_max_seq_length, projection_width]
    else:
        output_shape = [batch_size, from_max_seq_length, projection_width]
    return (tf.reshape(context, output_shape), attention_scores)
def __init__(self,
             bert_config,
             is_training,
             sketchy_encoder,
             intensive_encoder,
             query_mask,
             label_ids,
             has_answer,
             sample_weight=None,
             scope='retro_reader',
             matching_mechanism='cross-attention',
             beta_1=0.5,
             beta_2=0.5,
             threshold=1.0,
             trainable=True,
             **kwargs):
    '''Build the sketchy reader, the intensive reader and the rear verifier.

    Args:
      bert_config: Configuration object; `initializer_range`,
        `hidden_dropout_prob` and `num_attention_heads` are read from it.
      is_training: bool. Dropout is only active when this is True.
      sketchy_encoder: Encoder whose `get_pooled_output()` feeds the
        binary answerable / unanswerable classifier.
      intensive_encoder: Encoder whose `get_sequence_output()` feeds the
        span extractor.
      query_mask: Mask over sequence positions; it is multiplied into the
        sequence output to obtain the question-only representation and,
        for cross-attention, turned into an attention mask.
      label_ids: Tensor indexed as `[:, 0]` (start position) and `[:, 1]`
        (end position).
      has_answer: Integer labels with two classes for the sketchy reader.
      sample_weight: (optional) Per-example weights applied to both
        per-example losses. Presumably of shape [batch_size].
      scope: Name of the outermost variable scope.
      matching_mechanism: Either 'cross-attention' or
        'matching-attention'.
      beta_1: Weight of the intensive-reader score in the verifier.
      beta_2: Weight of the sketchy-reader score in the verifier.
      threshold: The verifier predicts 1 when the combined score exceeds
        this value.
      trainable: Whether the variables created here are trainable.
      **kwargs: Forwarded to the parent constructor.

    Raises:
      ValueError: If `matching_mechanism` is not one of the two supported
        values.
    '''
    super().__init__(**kwargs)

    # Fail early with a clear message; otherwise an unsupported value
    # only surfaces later as an unbound `H_prime`.
    if matching_mechanism not in ('cross-attention', 'matching-attention'):
        raise ValueError(
            'Invalid value of `matching_mechanism`: %s. Pick one from '
            '`cross-attention` and `matching-attention`.'
            % matching_mechanism)

    # Dropout must be a no-op outside training, for every layer below.
    dropout_prob = bert_config.hidden_dropout_prob if is_training else 0.0

    # verifier
    with tf.variable_scope(scope):

        # sketchy reading module
        with tf.variable_scope('sketchy/prediction'):
            sketchy_output = sketchy_encoder.get_pooled_output()
            hidden_size = sketchy_output.shape.as_list()[-1]

            output_weights = tf.get_variable(
                'output_weights',
                shape=[2, hidden_size],
                initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable(
                'output_bias',
                shape=[2],
                initializer=tf.zeros_initializer(),
                trainable=trainable)

            output_layer = util.dropout(sketchy_output, dropout_prob)
            logits = tf.matmul(
                output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(
                has_answer, depth=2, dtype=tf.float32)
            per_example_loss = - tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)
            if sample_weight is not None:
                per_example_loss = tf.cast(
                    sample_weight, dtype=tf.float32) * per_example_loss

            self.losses['sketchy_losses'] = per_example_loss
            sketchy_loss = tf.reduce_mean(per_example_loss)

            # External score: logit of class 1 minus logit of class 0.
            score_ext = logits[:, 1] - logits[:, 0]

        # intensive reading module
        with tf.variable_scope('intensive'):
            H = intensive_encoder.get_sequence_output()
            # Question-only representation: zero every masked-out position.
            H_Q = H * tf.cast(
                tf.expand_dims(query_mask, axis=-1), tf.float32)
            (batch_size, max_seq_length, hidden_size) = \
                util.get_shape_list(H)

            # cross-attention
            if matching_mechanism == 'cross-attention':
                with tf.variable_scope('cross_attention'):
                    attention_mask = \
                        self.create_attention_mask_from_input_mask(
                            query_mask, batch_size, max_seq_length)
                    (H_prime, _) = self.attention_layer(
                        from_tensor=H,
                        to_tensor=H_Q,
                        attention_mask=attention_mask,
                        num_attention_heads=\
                            bert_config.num_attention_heads,
                        size_per_head=\
                            hidden_size // bert_config.num_attention_heads,
                        # Gated on `is_training` like the other dropouts.
                        attention_probs_dropout_prob=dropout_prob,
                        initializer_range=bert_config.initializer_range,
                        do_return_2d_tensor=False,
                        batch_size=batch_size,
                        from_max_seq_length=max_seq_length,
                        to_max_seq_length=max_seq_length,
                        trainable=trainable)

            # matching-attention
            else:
                with tf.variable_scope('matching_attention'):
                    output_weights = tf.get_variable(
                        'output_weights',
                        shape=[hidden_size, hidden_size],
                        initializer=util.create_initializer(
                            bert_config.initializer_range),
                        trainable=trainable)
                    output_bias = tf.get_variable(
                        'output_bias',
                        shape=[hidden_size],
                        initializer=tf.zeros_initializer(),
                        trainable=trainable)
                    # Linear transform of H_Q, applied per batch element.
                    trans = tf.matmul(
                        H_Q,
                        tf.tile(
                            tf.expand_dims(output_weights, axis=0),
                            [batch_size, 1, 1]),
                        transpose_b=True)
                    trans = tf.nn.bias_add(trans, output_bias)
                    M = tf.nn.softmax(
                        tf.matmul(H, trans, transpose_b=True), axis=-1)
                    H_prime = tf.matmul(M, H_Q)

            with tf.variable_scope('prediction'):
                output_weights = tf.get_variable(
                    'output_weights',
                    shape=[2, hidden_size],
                    initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=trainable)
                output_bias = tf.get_variable(
                    'output_bias',
                    shape=[2],
                    initializer=tf.zeros_initializer(),
                    trainable=trainable)

                output_layer = util.dropout(H_prime, dropout_prob)
                output_layer = tf.reshape(
                    output_layer,
                    [batch_size * max_seq_length, hidden_size])
                logits = tf.matmul(
                    output_layer, output_weights, transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)
                logits = tf.reshape(
                    logits, [batch_size, max_seq_length, 2])
                # [B, S, 2] -> [B, 2, S]: row 0 = start, row 1 = end.
                logits = tf.transpose(logits, [0, 2, 1])
                probs = tf.nn.softmax(logits, axis=-1, name='probs')
                self.probs['mrc_probs'] = probs
                self.preds['mrc_preds'] = tf.argmax(logits, axis=-1)

                start_one_hot_labels = tf.one_hot(
                    label_ids[:, 0], depth=max_seq_length,
                    dtype=tf.float32)
                end_one_hot_labels = tf.one_hot(
                    label_ids[:, 1], depth=max_seq_length,
                    dtype=tf.float32)
                start_log_probs = tf.nn.log_softmax(
                    logits[:, 0, :], axis=-1)
                end_log_probs = tf.nn.log_softmax(
                    logits[:, 1, :], axis=-1)
                per_example_loss = (
                    - 0.5 * tf.reduce_sum(
                        start_one_hot_labels * start_log_probs, axis=-1)
                    - 0.5 * tf.reduce_sum(
                        end_one_hot_labels * end_log_probs, axis=-1))
                if sample_weight is not None:
                    # Cast like the sketchy loss does, so non-float32
                    # weights do not break the multiplication.
                    per_example_loss *= tf.cast(
                        sample_weight, dtype=tf.float32)

                intensive_loss = tf.reduce_mean(per_example_loss)
                self.losses['intensive_losses'] = per_example_loss

                # Best summed start/end probability over positions >= 1,
                # compared against the summed probability at position 0.
                score_has = tf.norm(
                    probs[:, 0, 1:] + probs[:, 1, 1:], np.inf, axis=-1)
                score_null = probs[:, 0, 0] + probs[:, 1, 0]
                score_diff = score_has - score_null

        # rear verification
        v = beta_1 * score_diff + beta_2 * score_ext
        self.preds['verifier_preds'] = \
            tf.cast(tf.greater(v, threshold), tf.int32)
        self.probs['verifier_probs'] = v

        self.total_loss = sketchy_loss + intensive_loss
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''Sample a factorization order and build the matching attention mask.

    One permutation is sampled and shared by every sequence in the batch.

    Args:
      inputs: Tensor of input ids, shape [batch_size, seq_len].
      targets: Tensor of target ids, shape [batch_size, seq_len].
      is_masked: bool Tensor of shape [batch_size, seq_len]. True means the
        position is selected for partial prediction.
      perm_size: the length of the longest permutation. Could be set to
        reuse_len. Should not be larger than reuse_len or there will be
        data leaks. `seq_len` has to be a multiple of it, because the
        position index is reshaped to [-1, perm_size].
      seq_len: int, sequence length.

    Returns:
      A tuple `(perm_mask, new_targets, target_mask, inputs_k, inputs_q)`:
        perm_mask: bool Tensor [batch_size, seq_len, seq_len]; entry
          [b, i, j] is True when position i must not attend to position j.
        new_targets: `targets` shifted right by one position, with the
          first input id prepended.
        target_mask: float32 Tensor [batch_size, seq_len]; 1.0 where the
          position is masked and is not [SEP] / [CLS].
        inputs_k: `inputs`, unchanged.
        inputs_q: the same tensor as `target_mask`.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices: shuffle the positions inside each chunk
    # of `perm_size` tokens (the same way for every chunk), then replicate
    # the resulting order for every sequence of the batch.
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
    non_mask_tokens = tf.logical_and(
        tf.logical_not(is_masked), non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to
    # the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information
    #     leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked
    #    (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    # The token-type condition describes the attended position j, so it is
    # expanded on axis 1 ([B, 1, S]) and broadcast over the querying
    # positions i. Expanding on the last axis would gate on i instead and
    # let non-masked tokens attend to masked ones, breaking (2) above.
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
def __init__(self,
             bert_config,
             is_training,
             encoder,
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             next_sentence_labels,
             sample_weight=None,
             scope_lm='cls/predictions',
             scope_cls='cls/seq_relationship',
             trainable=True,
             use_nsp_loss=True,
             **kwargs):
    '''Build the masked-LM and next-sentence-prediction heads and losses.

    Args:
      bert_config: Configuration object; `hidden_size`, `hidden_act`,
        `initializer_range` and `vocab_size` are read from it.
      is_training: Not referenced by this constructor.
      encoder: Encoder providing `get_sequence_output()`,
        `get_embedding_table()` and `get_pooled_output()`.
      masked_lm_positions: int32 positions of the masked tokens inside each
        sequence (they are added to int32 row offsets).
      masked_lm_ids: Ids of the tokens expected at those positions.
      masked_lm_weights: Weight of every masked position in the MLM loss.
      next_sentence_labels: Integer labels with two classes.
      sample_weight: (optional) Per-example weights, presumably of shape
        [batch_size]. They scale the MLM weights of the example's masked
        positions and the example's NSP loss.
      scope_lm: Variable scope of the MLM head.
      scope_cls: Variable scope of the NSP head.
      trainable: Whether the variables created here are trainable.
      use_nsp_loss: If False, the NSP loss is still computed and exposed in
        `self.losses`, but it is left out of `self.total_loss`.
      **kwargs: Forwarded to the parent constructor.
    '''
    super(BERTDecoder, self).__init__(**kwargs)

    def gather_indexes(sequence_tensor, positions):
        '''Gather the vectors at `positions` from every sequence.'''
        sequence_shape = util.get_shape_list(sequence_tensor, 3)
        batch_size = sequence_shape[0]
        seq_length = sequence_shape[1]
        width = sequence_shape[2]

        # Shift each row's positions so that they address the flattened
        # [batch_size * seq_length, width] tensor.
        flat_offsets = tf.reshape(
            tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
        flat_positions = tf.reshape(positions + flat_offsets, [-1])
        flat_sequence_tensor = tf.reshape(
            sequence_tensor, [batch_size * seq_length, width])
        return tf.gather(flat_sequence_tensor, flat_positions)

    # Keep the per-example weights in their original shape. Rebinding
    # `sample_weight` to a [B, 1] tensor for the MLM head would make the
    # NSP product below broadcast [B, 1] * [B] into a [B, B] matrix.
    example_weight = None
    if sample_weight is not None:
        example_weight = tf.cast(sample_weight, dtype=tf.float32)

    scalar_losses = []

    # masked language modeling
    input_tensor = gather_indexes(
        encoder.get_sequence_output(), masked_lm_positions)
    with tf.variable_scope(scope_lm):
        with tf.variable_scope('transform'):
            # `trainable` is forwarded so that freezing the decoder also
            # freezes the transform layer and its layer norm.
            input_tensor = tf.layers.dense(
                input_tensor,
                units=bert_config.hidden_size,
                activation=util.get_activation(bert_config.hidden_act),
                kernel_initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=trainable)
            input_tensor = util.layer_norm(
                input_tensor, trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        # The output projection shares the encoder's embedding table.
        logits = tf.matmul(
            input_tensor, encoder.get_embedding_table(), transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probs = tf.nn.softmax(logits, axis=-1, name='MLM_probs')
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(masked_lm_ids, [-1])
        if example_weight is not None:
            # [B] -> [B, 1] so one weight covers all masked positions of
            # the example.
            masked_lm_weights *= tf.expand_dims(example_weight, axis=-1)
        label_weights = tf.reshape(masked_lm_weights, [-1])
        one_hot_labels = tf.one_hot(
            label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            log_probs * one_hot_labels, axis=[-1])
        per_example_loss = label_weights * per_example_loss

        # Weighted mean; the epsilon guards against an all-zero weight sum.
        numerator = tf.reduce_sum(per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator

        scalar_losses.append(loss)
        self.losses['MLM_losses'] = per_example_loss
        self.preds['MLM_preds'] = tf.argmax(probs, axis=-1)

    # next sentence prediction
    with tf.variable_scope(scope_cls):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[2, bert_config.hidden_size],
            initializer=util.create_initializer(
                bert_config.initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[2],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        logits = tf.matmul(
            encoder.get_pooled_output(), output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probs = tf.nn.softmax(logits, axis=-1, name='probs')
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        labels = tf.reshape(next_sentence_labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        if example_weight is not None:
            per_example_loss = example_weight * per_example_loss
        loss = tf.reduce_mean(per_example_loss)

        if use_nsp_loss:
            scalar_losses.append(loss)
        self.losses['NSP_losses'] = per_example_loss
        self.probs['NSP_probs'] = probs
        self.preds['NSP_preds'] = tf.argmax(probs, axis=-1)

    self.total_loss = tf.add_n(scalar_losses)
def __init__(self,
             vocab_size,
             is_training,
             source_ids,
             target_ids,
             sos_id,
             sample_weight=None,
             hidden_size=768,
             num_blocks=6,
             num_attention_heads=12,
             scope='transformer',
             use_label_smoothing=False,
             use_tilda_embedding=False,
             trainable=True,
             **kwargs):
    '''Build an encoder-decoder Transformer with its translation loss.

    During training the decoder runs once over `target_ids` (teacher
    forcing). Otherwise it is unrolled greedily, one token per step,
    starting from `sos_id`, for as many steps as `target_ids` has columns.

    Args:
      vocab_size: Size of the vocabulary shared by source and target.
      is_training: bool. Selects teacher forcing and a dropout rate of
        0.1; otherwise greedy decoding and no dropout.
      source_ids: Rank-2 Tensor of source token ids; id 0 is treated as
        padding.
      target_ids: Rank-2 Tensor of target token ids; id 0 is treated as
        padding.
      sos_id: Id of the start-of-sequence token.
      sample_weight: (optional) Per-example weights, presumably of shape
        [batch_size], applied to the per-example loss.
      hidden_size: Width of the embeddings and of every block.
      num_blocks: Number of encoder blocks and of decoder blocks.
      num_attention_heads: Number of attention heads.
      scope: Name of the outermost variable scope.
      use_label_smoothing: Whether to smooth the one-hot labels with
        `label_smoothing`.
      use_tilda_embedding: Whether to reuse the existing variable
        'tilda_embeddings' for the embedding lookup.
      trainable: Not referenced by this constructor.
      **kwargs: Not referenced by this constructor.
    '''
    super().__init__()

    dropout_rate = 0.1 if is_training else 0.0

    source_shape = util.get_shape_list(source_ids, expected_rank=2)
    target_shape = util.get_shape_list(target_ids, expected_rank=2)
    batch_size = source_shape[0]
    source_max_seq_length = source_shape[1]
    target_max_seq_length = target_shape[1]

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):
        # True at padding positions of the source.
        source_mask = tf.math.equal(source_ids, 0)

        # embedding
        with tf.variable_scope('embeddings'):
            (enc, embedding_table) = embedding_lookup(
                input_ids=source_ids,
                vocab_size=vocab_size,
                batch_size=batch_size,
                max_seq_length=source_max_seq_length,
                embedding_size=hidden_size,
                word_embedding_name='word_embeddings',
                tilda_embeddings=tilda_embeddings)
            enc *= hidden_size ** 0.5  # scale
            enc += positional_encoding(enc, source_max_seq_length)
            enc = util.dropout(enc, dropout_rate)

        with tf.variable_scope('encoder'):

            # stacked multi-attention layers
            for i in range(num_blocks):
                with tf.variable_scope('block_%s' % i):

                    # self-attention
                    enc = multihead_attention(
                        queries=enc,
                        keys=enc,
                        values=enc,
                        key_masks=source_mask,
                        num_heads=num_attention_heads,
                        dropout_rate=dropout_rate,
                        training=is_training,
                        causality=False,
                        scope='self_attention')

                    # feed forward
                    enc = ff(enc, num_units=[hidden_size * 4, hidden_size])
        memory = enc

        # NOTE(review): outside training `_forward` runs once per decoding
        # step, so its variables must be reusable (e.g. an enclosing
        # AUTO_REUSE scope) - confirm with the caller.
        # NOTE(review): 'cls' is placed next to 'decoder', not inside it;
        # confirm against the checkpoint variable names.
        def _forward(dec_ids, dec_mask, dec_seq_length):
            '''Run the decoder stack and return logits [B, T2, vocab].'''
            with tf.variable_scope('decoder'):

                # shared embedding
                dec = tf.nn.embedding_lookup(embedding_table, dec_ids)
                dec *= hidden_size ** 0.5  # scale
                dec += positional_encoding(dec, dec_seq_length)
                dec = util.dropout(dec, dropout_rate)

                # blocks
                for i in range(num_blocks):
                    with tf.variable_scope('block_%s' % i):

                        # masked self-attention
                        dec = multihead_attention(
                            queries=dec,
                            keys=dec,
                            values=dec,
                            key_masks=dec_mask,
                            num_heads=num_attention_heads,
                            dropout_rate=dropout_rate,
                            training=is_training,
                            causality=True,
                            scope='masked_self_attention')

                        # vanilla attention
                        dec = multihead_attention(
                            queries=dec,
                            keys=memory,
                            values=memory,
                            key_masks=source_mask,
                            num_heads=num_attention_heads,
                            dropout_rate=dropout_rate,
                            training=is_training,
                            causality=False,
                            scope='vanilla_attention')

                        # feed forward
                        dec = ff(
                            dec, num_units=[4 * hidden_size, hidden_size])

            # final linear projection (embedding weights are shared)
            with tf.variable_scope('cls'):
                output_bias = tf.get_variable(
                    'output_bias', shape=[vocab_size],
                    initializer=tf.zeros_initializer())
                dec = tf.reshape(dec, [-1, hidden_size])
                logits = tf.matmul(dec, embedding_table, transpose_b=True)
                logits = tf.reshape(
                    logits, [-1, dec_seq_length, vocab_size])
                logits = tf.nn.bias_add(logits, output_bias)

            return logits

        # convert to labels: the targets shifted left by one step, padded
        # with a trailing 0
        label_ids = tf.concat(
            [target_ids[:, 1:],
             tf.zeros([batch_size, 1], dtype=tf.int32)],
            axis=-1)

        # forward once
        if is_training:
            target_mask = tf.math.equal(target_ids, 0)  # (N, T2)
            logits = _forward(
                target_ids, target_mask, target_max_seq_length)

            self.preds['MT'] = tf.argmax(logits, axis=-1)

        # forward loop
        else:
            target_mask_base = tf.zeros([batch_size, 1], dtype=tf.int32)
            target_ids = tf.ones([batch_size, 1], dtype=tf.int32) * sos_id

            for cur_length in range(1, target_max_seq_length + 1):
                target_mask = tf.tile(target_mask_base, [1, cur_length])
                logits = _forward(target_ids, target_mask, cur_length)

                # Greedily pick the token of the newest position only.
                pred_ids = tf.argmax(
                    logits[:, cur_length-1:cur_length, :],
                    axis=-1)
                pred_ids = tf.cast(pred_ids, tf.int32)
                target_ids = tf.concat([target_ids, pred_ids], axis=-1)

            self.preds['MT'] = target_ids[:, 1:]

        # loss
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=vocab_size)
        if use_label_smoothing:
            one_hot_labels = label_smoothing(one_hot_labels)
        per_token_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)

        # Mean over non-padding labels. The count is clamped to 1 so an
        # all-padding row gives 0 instead of 0 / 0 = NaN; counts >= 1 are
        # left untouched.
        num_labels = tf.maximum(tf.reduce_sum(label_mask, axis=-1), 1.0)
        per_example_loss = \
            tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
            num_labels
        if sample_weight is not None:
            # `per_example_loss` is [B]; a [B, 1] weight would broadcast
            # the product to [B, B], so the weights are applied as they are.
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses['MT'] = per_example_loss
def __init__(self,
             is_training,
             input_tensor,
             n_wide_features,
             wide_features,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Build a classifier that fuses wide features into a deep vector.

    The wide feature ids are embedded, attended over with the deep vector
    as the query, and the attention output is added to the deep vector
    before layer normalization and the final classification layer.

    Dimension shorthand used in the comments below:
      B = batch size
      N = number of wide feature slots (last dimension of `wide_features`)
      H = hidden size (last dimension of `input_tensor`)

    Args:
      is_training: bool. Dropout is only active when this is True.
      input_tensor: Deep representation of shape [B, H].
      n_wide_features: Number of valid wide features per example; slots
        beyond it are masked out of the attention.
      wide_features: Integer feature ids of shape [B, N], looked up in an
        embedding table with N + 1 rows.
      label_ids: Integer class labels.
      label_size: Number of classes.
      sample_weight: (optional) Per-example loss weights.
      scope: Variable scope of the classification layer.
      hidden_dropout_prob: Dropout probability used during training.
      initializer_range: Range handed to `util.create_initializer`.
      trainable: Whether the variables created here are trainable.
      **kwargs: Forwarded to the parent constructor. An optional float
        'tsa_thresh' additionally zeroes the loss of examples whose
        normalized prediction entropy is not greater than the threshold.
    '''
    super().__init__(**kwargs)

    hidden_size = input_tensor.shape.as_list()[-1]
    feature_size = wide_features.shape.as_list()[-1]

    with tf.variable_scope('wide'):
        feature_embeddings = tf.get_variable(
            name='feature_embeddings',
            shape=[feature_size + 1, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        wide_output = tf.gather(
            feature_embeddings, wide_features)  # [B, N, H]

    with tf.variable_scope('wide_and_deep'):
        deep_output = tf.expand_dims(input_tensor, -1)  # [B, H, 1]
        attention_scores = tf.matmul(
            wide_output, deep_output)  # [B, N, 1]
        attention_scores = tf.transpose(
            attention_scores, [0, 2, 1])  # [B, 1, N]
        attention_scores = tf.multiply(
            attention_scores, 1.0 / math.sqrt(hidden_size))

        # Push the scores of unused feature slots towards -infinity.
        feature_mask = tf.cast(
            tf.sequence_mask(n_wide_features, feature_size),
            tf.float32)  # [B, N]
        feature_mask = tf.expand_dims(feature_mask, 1)  # [B, 1, N]
        attention_scores += (1.0 - feature_mask) * -10000.0

        attention_matrix = tf.nn.softmax(attention_scores, axis=-1)
        attention_output = tf.matmul(
            attention_matrix, wide_output)  # [B, 1, H]
        attention_output = attention_output[:, 0, :]  # [B, H]
        # attention_output = util.dropout(
        #     attention_output, hidden_dropout_prob)
        input_tensor = util.layer_norm(
            attention_output + input_tensor,
            trainable=trainable)

    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        self.preds['preds'] = tf.argmax(logits, axis=-1)
        self.probs['probs'] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        if sample_weight is not None:
            per_example_loss = tf.cast(
                sample_weight, dtype=tf.float32) * per_example_loss

        thresh = kwargs.get('tsa_thresh')
        if thresh is not None:
            assert isinstance(
                thresh, float), (
                    '`tsa_thresh` must be a float between 0 and 1.')
            # Entropy of the prediction divided by log(1 / label_size).
            # `log_probs` replaces `tf.log(probs)`: a probability that
            # underflows to 0 would yield 0 * -inf = NaN, whereas
            # log_softmax stays finite and the term is 0.
            uncertainty = tf.reduce_sum(
                self.probs['probs'] * log_probs, axis=-1)
            uncertainty /= tf.log(1 / label_size)
            per_example_loss = tf.cast(
                tf.greater(uncertainty, thresh), dtype=tf.float32) * \
                per_example_loss

        self.losses['losses'] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def _get_generator_output(self, inputs, sample_weight, generator):
    '''Masked language modeling softmax layer of the generator.

    Args:
      inputs: Object exposing `masked_lm_positions`, `masked_lm_ids` and
        `masked_lm_weights`.
      sample_weight: (optional) Per-example weights; they are cast to
        float32, given a trailing axis and multiplied into the masked-LM
        weights.
      generator: Encoder exposing `get_sequence_output()` and
        `get_embedding_table()`.

    Returns:
      An `MLMOutput` namedtuple with the fields `logits`, `probs`, `loss`,
      `per_example_loss` and `preds`.
    '''

    def gather_indexes(sequence_tensor, positions):
        # Pick the vectors at `positions` out of every sequence by
        # addressing the flattened [batch * seq, width] tensor.
        dims = util.get_shape_list(sequence_tensor, 3)
        n_rows, n_steps, n_units = dims[0], dims[1], dims[2]

        row_offsets = tf.reshape(
            tf.range(0, n_rows, dtype=tf.int32) * n_steps, [-1, 1])
        flat_indices = tf.reshape(positions + row_offsets, [-1])
        flattened = tf.reshape(
            sequence_tensor, [n_rows * n_steps, n_units])
        return tf.gather(flattened, flat_indices)

    masked_states = gather_indexes(
        generator.get_sequence_output(), inputs.masked_lm_positions)

    with tf.variable_scope('generator_predictions'):
        masked_states = tf.layers.dense(
            masked_states,
            units=self.config.embedding_size,
            activation=util.get_activation(self.bert_config.hidden_act),
            kernel_initializer=util.create_initializer(
                self.bert_config.initializer_range))
        masked_states = util.layer_norm(masked_states)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[self.bert_config.vocab_size],
            initializer=tf.zeros_initializer())

        # The output projection shares the generator's embedding table.
        logits = tf.nn.bias_add(
            tf.matmul(
                masked_states,
                generator.get_embedding_table(),
                transpose_b=True),
            output_bias)
        probs = tf.nn.softmax(logits, axis=-1, name='MLM_probs')
        preds = tf.argmax(logits, axis=-1)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        flat_label_ids = tf.reshape(inputs.masked_lm_ids, [-1])
        lm_weights = inputs.masked_lm_weights
        if sample_weight is not None:
            per_row_weight = tf.expand_dims(
                tf.cast(sample_weight, dtype=tf.float32), axis=-1)
            lm_weights *= per_row_weight
        flat_weights = tf.reshape(lm_weights, [-1])

        one_hot_labels = tf.one_hot(
            flat_label_ids,
            depth=self.bert_config.vocab_size,
            dtype=tf.float32)
        per_example_loss = flat_weights * (
            -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]))

        # Weighted mean; the epsilon guards against an all-zero weight sum.
        weighted_sum = tf.reduce_sum(per_example_loss)
        weight_total = tf.reduce_sum(flat_weights) + 1e-6
        loss = weighted_sum / weight_total

        MLMOutput = collections.namedtuple(
            'MLMOutput',
            ['logits', 'probs', 'loss', 'per_example_loss', 'preds'])
        return MLMOutput(
            logits=logits,
            probs=probs,
            per_example_loss=per_example_loss,
            loss=loss,
            preds=preds)
def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None):
    '''Performs multi-headed attention from `from_tensor` to `to_tensor`.

    This follows the multi-headed attention of 'Attention is all you Need'.
    If `from_tensor` and `to_tensor` are the same, it is self-attention.

    `from_tensor` is projected into a 'query' tensor, `to_tensor` into
    'key' and 'value' tensors. Queries and keys are dot-producted and
    scaled, the scores are softmaxed into attention probabilities, and the
    values are interpolated by these probabilities. The heads are handled
    with reshapes and transposes rather than with separate tensors.

    Dimension shorthand used in the comments below:
      B = batch size (number of sequences)
      F = `from_tensor` sequence length
      T = `to_tensor` sequence length
      N = `num_attention_heads`
      H = `size_per_head`

    Args:
      from_tensor: float Tensor of shape [batch_size, from_seq_length,
        from_width], or its rank-2 flattening.
      to_tensor: float Tensor of shape [batch_size, to_seq_length,
        to_width], or its rank-2 flattening.
      attention_mask: (optional) int32 Tensor of shape [batch_size,
        from_seq_length, to_seq_length]. The values should be 1 or 0.
        Positions that are 0 receive an additive -10000.0 on their score,
        positions that are 1 are left unchanged.
      num_attention_heads: int. Number of attention heads.
      size_per_head: int. Size of each attention head.
      query_act: (optional) Activation function for the query transform.
      key_act: (optional) Activation function for the key transform.
      value_act: (optional) Activation function for the value transform.
      attention_probs_dropout_prob: (optional) float. Dropout probability
        of the attention probabilities.
      initializer_range: float. Range of the weight initializer.
      do_return_2d_tensor: bool. If True, the context is returned with
        shape [B*F, N*H], otherwise with shape [B, F, N*H].
      batch_size: (optional) int. Required when the inputs are rank 2;
        overwritten from the input shape when they are rank 3.
      from_seq_length: (optional) Same rule as `batch_size`.
      to_seq_length: (optional) Same rule as `batch_size`.

    Returns:
      A tuple `(context_layer, attention_probs)`: the attended values in
      the shape selected by `do_return_2d_tensor`, and the attention
      probabilities of shape [B, N, F, T] after dropout.

    Raises:
      ValueError: If the input ranks differ, or if rank-2 inputs are given
        without `batch_size`, `from_seq_length` and `to_seq_length`.
    '''
    from_shape = util.get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = util.get_shape_list(to_tensor, expected_rank=[2, 3])

    rank = len(from_shape)
    if rank != len(to_shape):
        raise ValueError(
            'The rank of `from_tensor` must match the rank of `to_tensor`.')

    if rank == 3:
        # The static shape wins over whatever the caller passed in.
        batch_size, from_seq_length = from_shape[0], from_shape[1]
        to_seq_length = to_shape[1]
    elif rank == 2:
        required = (batch_size, from_seq_length, to_seq_length)
        if any(dim is None for dim in required):
            raise ValueError(
                'When passing in rank 2 tensors to attention_layer, the values '
                'for `batch_size`, `from_seq_length`, and `to_seq_length` '
                'must all be specified.')

    projection_width = num_attention_heads * size_per_head

    def _project(matrix, activation, name):
        # Dense projection to N*H units; one fresh initializer per layer.
        return tf.layers.dense(
            matrix,
            projection_width,
            activation=activation,
            name=name,
            kernel_initializer=util.create_initializer(initializer_range))

    def _split_heads(flat_tensor, seq_length):
        # [B*S, N*H] -> [B, S, N, H] -> [B, N, S, H]
        per_head = tf.reshape(
            flat_tensor,
            [batch_size, seq_length, num_attention_heads, size_per_head])
        return tf.transpose(per_head, [0, 2, 1, 3])

    from_matrix = util.reshape_to_matrix(from_tensor)
    to_matrix = util.reshape_to_matrix(to_tensor)

    query = _project(from_matrix, query_act, 'query')    # [B*F, N*H]
    key = _project(to_matrix, key_act, 'key')            # [B*T, N*H]
    value = _project(to_matrix, value_act, 'value')      # [B*T, N*H]

    query = _split_heads(query, from_seq_length)         # [B, N, F, H]
    key = _split_heads(key, to_seq_length)               # [B, N, T, H]

    # Raw attention scores: scaled dot product of queries and keys.
    scale = 1.0 / math.sqrt(float(size_per_head))
    scores = tf.multiply(
        tf.matmul(query, key, transpose_b=True), scale)  # [B, N, F, T]

    if attention_mask is not None:
        # [B, F, T] -> [B, 1, F, T] so the mask broadcasts over heads.
        expanded_mask = tf.expand_dims(attention_mask, axis=[1])
        # 0.0 where attention is allowed, -10000.0 where it is masked.
        # Added before the softmax, this effectively removes the masked
        # positions.
        scores += (1.0 - tf.cast(expanded_mask, tf.float32)) * -10000.0

    # Normalize the scores to probabilities, then drop out whole
    # attended-to tokens (as in the original Transformer paper).
    attention_probs = util.dropout(
        tf.nn.softmax(scores), attention_probs_dropout_prob)

    value = _split_heads(value, to_seq_length)           # [B, N, T, H]
    context = tf.matmul(attention_probs, value)          # [B, N, F, H]
    context = tf.transpose(context, [0, 2, 1, 3])        # [B, F, N, H]

    if do_return_2d_tensor:
        output_shape = [batch_size * from_seq_length, projection_width]
    else:
        output_shape = [batch_size, from_seq_length, projection_width]
    context_layer = tf.reshape(context, output_shape)

    return context_layer, attention_probs
def _build_forward(layer_input):
    '''Build one encoder layer and return its output.

    The layer consists of a self-attention sub-layer followed by a
    feed-forward sub-layer; each is followed by dropout, a residual
    connection and layer normalization. Every hyper-parameter, the input
    mask and `self` are read from the enclosing scope.

    Args:
      layer_input: Tensor holding the output of the previous layer.

    Returns:
      The output tensor of this layer.
    '''

    def _dense(tensor, units, activation=None):
        # Dense layer with a fresh initializer of the configured range.
        return tf.layers.dense(
            tensor,
            units,
            activation=activation,
            kernel_initializer=util.create_initializer(initializer_range),
            trainable=trainable)

    with tf.variable_scope('attention'):
        with tf.variable_scope('self'):
            # Zero the masked-out positions before attending; the masked
            # tensor is also what the residual connection adds back below.
            position_mask = tf.cast(
                tf.expand_dims(input_mask, axis=-1), dtype=tf.float32)
            masked_input = layer_input * position_mask

            if bool(self.nb_random_features):
                projection_matrix_type = True
            else:
                projection_matrix_type = None

            self_attention = Attention(
                hidden_size=hidden_size,
                num_heads=num_attention_heads,
                attention_dropout=attention_probs_dropout_prob,
                kernel_transformation=self.kernel_transformation,
                numerical_stabilizer=0.001,
                causal=False,
                projection_matrix_type=projection_matrix_type,
                nb_random_features=self.nb_random_features)
            self_attention.build(masked_input.shape)
            attended = self_attention.call(
                masked_input,
                masked_input,
                bias=None,
                training=is_training,
                cache=None,
                decode_loop_step=None)

        with tf.variable_scope('output'):
            attended = _dense(attended, hidden_size)
            attended = util.dropout(attended, hidden_dropout_prob)
            attended = util.layer_norm(
                attended + masked_input, trainable=trainable)

    # The activation is only applied to the `intermediate` hidden layer.
    with tf.variable_scope('intermediate'):
        expanded = _dense(
            attended, intermediate_size, activation=intermediate_act_fn)

    # Down-project back to hidden_size then add the residual.
    with tf.variable_scope('output'):
        projected = _dense(expanded, hidden_size)
        projected = util.dropout(projected, hidden_dropout_prob)
        return util.layer_norm(projected + attended, trainable=trainable)
def __init__(self,
             bert_config,
             is_training,
             dilated_ids,
             label_ids,
             max_seq_length,
             spad_id=1,
             loop=3,
             sample_weight=None,
             scope='dilated',
             use_tilda_embedding=False,
             **kwargs):
    """Builds the dilated LM graph.

    In training mode a single forward pass is run and a masked LM loss is
    stored in `self.total_loss` / `self.losses['LM']`. Otherwise the model
    is applied `loop` times, each round compacting the predicted ids and
    re-dilating them, and the final ids are stored in `self.preds['LM']`.

    Args:
      bert_config: configuration object; `vocab_size` is read here and the
        object is forwarded to `self._bert_forward`.
      is_training: Python bool selecting the training or inference graph.
      dilated_ids: rank-2 id tensor `[batch_size, dilated_seq_length]`.
        Id 0 is treated as padding.
      label_ids: target ids for the LM loss (used in training only).
      max_seq_length: length of the non-dilated sequence. The loss mask and
        the re-dilated tensors have length `max_seq_length * 2`.
      spad_id: id of the temporary "special padding" token used while
        compacting predictions. It must be positive, because ids <= 0 are
        dropped during compaction.
      loop: number of inference rounds.
      sample_weight: optional per-example loss weights.
        # NOTE(review): assumed to hold one value per example - confirm.
      scope: variable scope under which the model is built.
      use_tilda_embedding: if True, the existing root-scope variable
        `tilda_embeddings` is fetched and passed to `self._bert_forward`.
      **kwargs: accepted for interface compatibility; not used here.
    """
    super().__init__()

    dilated_mask = tf.cast(tf.not_equal(dilated_ids, 0), tf.float32)

    shape = util.get_shape_list(dilated_ids, expected_rank=2)
    batch_size = shape[0]
    dilated_seq_length = shape[1]

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):

        # forward once
        if is_training:
            logits = self._bert_forward(
                bert_config,
                dilated_ids,
                dilated_mask,
                batch_size,
                dilated_seq_length,
                tilda_embeddings=tilda_embeddings)

            self.preds['LM'] = tf.argmax(logits, axis=-1)

            # LM loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(
                label_ids, depth=bert_config.vocab_size)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)

            # The loss covers twice as many leading positions as there are
            # non-padding input ids.
            input_length = tf.reduce_sum(dilated_mask, axis=-1) * 2
            label_mask = tf.sequence_mask(
                input_length, max_seq_length * 2, dtype=tf.float32)
            per_example_loss = \
                tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                tf.reduce_sum(label_mask, axis=-1)
            if sample_weight is not None:
                # `per_example_loss` has already been reduced over the
                # sequence axis, so the weights must stay rank 1. Expanding
                # them to `[batch, 1]` would broadcast the product to
                # `[batch, batch]` and mix weights across examples.
                per_example_loss *= tf.reshape(sample_weight, [-1])

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['LM'] = per_example_loss

        # forward loop
        else:

            def _forward(input_ids, input_mask):
                # One inference round: predict, compact, re-dilate.
                logits = self._bert_forward(
                    bert_config,
                    input_ids,
                    input_mask,
                    batch_size,
                    dilated_seq_length,
                    tilda_embeddings=tilda_embeddings)
                output_ids = tf.argmax(logits, axis=-1)
                output_ids = tf.cast(output_ids, dtype=tf.int32)

                # special padding (using `spad` token): append one `spad`
                # per predicted 0, so that every row holds exactly
                # `dilated_seq_length` positive ids.
                zero_count = tf.cast(tf.equal(output_ids, 0), tf.int32)
                zero_count = tf.reduce_sum(zero_count, axis=-1)
                right_pad = spad_id * tf.sequence_mask(
                    zero_count, dilated_seq_length, dtype=tf.int32)
                padded = tf.concat([output_ids, right_pad], axis=-1)

                # extract ids of length `max_seq_length`: dropping the
                # zeros shifts the remaining ids to the left of each row.
                flattened_padded = tf.reshape(padded, [-1])
                # `tf.boolean_mask` is documented to take a boolean mask.
                is_valid = tf.greater(flattened_padded, 0)
                flattened_valid = tf.boolean_mask(
                    flattened_padded, is_valid)
                valid = tf.reshape(
                    flattened_valid, [batch_size, dilated_seq_length])
                cutted_valid = valid[:, :max_seq_length]

                # replace `spad` token with `pad`
                non_spad_mask = tf.cast(
                    tf.not_equal(cutted_valid, spad_id), dtype=tf.int32)
                output_ids = cutted_valid * non_spad_mask
                output_length = tf.reduce_sum(non_spad_mask, axis=-1)

                # dilate: interleave every id (and mask entry) with a 0.
                reshaped_ids = tf.reshape(
                    output_ids, [batch_size, max_seq_length, 1])
                reshaped_mask = tf.reshape(
                    tf.sequence_mask(
                        output_length, max_seq_length, dtype=tf.int32),
                    [batch_size, max_seq_length, 1])
                concat_ids = tf.concat(
                    [reshaped_ids, tf.zeros_like(reshaped_ids)], axis=-1)
                concat_mask = tf.concat(
                    [reshaped_mask,
                     tf.zeros_like(reshaped_mask, dtype=tf.int32)],
                    axis=-1)
                next_ids = tf.reshape(
                    concat_ids, [batch_size, max_seq_length * 2])
                next_mask = tf.reshape(
                    concat_mask, [batch_size, max_seq_length * 2])

                return next_ids, next_mask

            for _ in range(loop):
                dilated_ids, dilated_mask = _forward(
                    dilated_ids, dilated_mask)

            self.preds['LM'] = dilated_ids
def _lm_forward(self,
                is_training,
                input_tensor,
                input_mask,
                label_ids,
                bert_config,
                batch_size,
                max_seq_length,
                prob,
                scope,
                name,
                sample_weight=None,
                hidden_dropout_prob=0.1,
                initializer_range=0.02):
    """Builds a two-stage token head: a binary verifier and an id predictor.

    The verifier classifies every position as "label id > 0" or not. The
    predictor projects `input_tensor` through a two-layer MLP and scores it
    against `self.embedding_table`. Its loss is restricted to positions the
    verifier predicted as positive.

    Side effects:
      * Both losses are added to `self.total_loss` when `prob != 0`, so
        `self.total_loss` must already be set.
      * `self.losses[name + '_loss']` receives the verifier loss only.
      * `self.preds[name + '_preds']` receives the predicted ids, zeroed
        where the verifier predicts class 0.

    Args:
      is_training: not used in this method.
      input_tensor: hidden states fed to both heads.
        # NOTE(review): presumably [batch_size, max_seq_length, hidden_size].
      input_mask: per-token mask; cast to float32 and used to weight the
        per-token losses.
      label_ids: target ids. Values > 0 form the verifier's positive class.
      bert_config: configuration providing `initializer_range`,
        `hidden_size` and `vocab_size`.
      batch_size: batch size used to flatten the predictor's output.
      max_seq_length: sequence length used to reshape the logits.
      prob: only compared with 0 here, to decide whether the losses are
        added to `self.total_loss`.
      scope: variable scope under which both heads are built.
      name: prefix of the keys written to `self.losses` / `self.preds`.
      sample_weight: optional per-example loss weights.
        # NOTE(review): assumed to hold one value per example - confirm.
      hidden_dropout_prob: not used in this method.
      initializer_range: not used in this method
        (`bert_config.initializer_range` is used instead).
    """

    def _per_example_loss(per_token_loss, token_mask):
        # Mask-weighted mean of the per-token loss over the last axis.
        # The epsilon keeps a row whose mask sums to zero from producing
        # 0 / 0 = NaN.
        normalizer = tf.reduce_sum(
            token_mask, keepdims=True, axis=-1) + 1e-6
        example_loss = tf.reduce_sum(
            per_token_loss * (token_mask / normalizer), axis=-1)
        if sample_weight is not None:
            # `example_loss` is already reduced over the sequence axis, so
            # the weights must stay rank 1. Expanding them to `[batch, 1]`
            # would broadcast the product to `[batch, batch]`.
            example_loss *= tf.reshape(sample_weight, [-1])
        return example_loss

    with tf.variable_scope(scope):

        with tf.variable_scope('verifier'):
            logits = tf.layers.dense(
                input_tensor,
                2,
                kernel_initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=True)
            verifier_label_ids = tf.cast(
                tf.greater(label_ids, 0), tf.int32)

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(verifier_label_ids, depth=2)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)

            input_mask = tf.cast(input_mask, tf.float32)
            per_example_loss = _per_example_loss(
                per_token_loss, input_mask)

            if prob != 0:
                self.total_loss += tf.reduce_mean(per_example_loss)
            verifier_loss = per_example_loss
            verifier_preds = tf.argmax(logits, axis=-1)

        with tf.variable_scope('prediction'):

            with tf.variable_scope('intermediate'):
                logits = tf.layers.dense(
                    input_tensor,
                    bert_config.hidden_size * 4,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range),
                    activation=util.gelu,
                    trainable=True)
            with tf.variable_scope('output'):
                logits = tf.layers.dense(
                    logits,
                    bert_config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=True)

            # Score every position against the embedding table to obtain
            # vocabulary logits.
            flattened = tf.reshape(
                logits,
                [batch_size * max_seq_length, bert_config.hidden_size])
            logits = tf.matmul(
                flattened, self.embedding_table, transpose_b=True)
            logits = tf.reshape(
                logits, [-1, max_seq_length, bert_config.vocab_size])

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(
                label_ids, depth=bert_config.vocab_size)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)

            # Only positions the verifier predicts as positive contribute
            # to the prediction loss.
            input_mask *= tf.cast(verifier_preds, tf.float32)
            per_example_loss = _per_example_loss(
                per_token_loss, input_mask)

            if prob != 0:
                self.total_loss += tf.reduce_mean(per_example_loss)

            # Only the verifier loss is exposed per example.
            self.losses[name + '_loss'] = verifier_loss
            self.preds[name + '_preds'] = \
                tf.argmax(logits, axis=-1) * verifier_preds