def __init__(self,
             is_training,
             input_tensor,
             label_ids,
             sample_weight=None,
             scope='mrc',
             name='',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Builds a span-extraction head predicting start and end positions.

    A single linear layer maps every token to two logits (start, end). The
    logits are transposed so that softmax runs over the sequence axis, i.e.
    each of the two rows is a distribution over token positions.

    Args:
      is_training: bool. Dropout is only applied when True.
      input_tensor: float Tensor whose last two static dimensions are
        [seq_length, hidden_size].
      label_ids: int Tensor; column 0 is used as the start position and
        column 1 as the end position.
      sample_weight: (optional) per-example weights multiplied into the
        per-example loss. Cast to float32 before use.
      scope: string. Variable scope of the output layer.
      name: string. Key under which results are stored in `self.probs`,
        `self.losses` and `self.preds`.
      hidden_dropout_prob: float. Dropout rate applied to `input_tensor`.
      initializer_range: float. Passed to `util.create_initializer`.
      trainable: bool. Whether the output variables are trainable.
      **kwargs: forwarded to the parent constructor.
    '''
    super().__init__(**kwargs)

    seq_length = input_tensor.shape.as_list()[-2]
    hidden_size = input_tensor.shape.as_list()[-1]
    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[2, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[2],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)

        # Project every token to (start, end) logits, then move the two
        # logit rows in front of the sequence axis: [B, 2, seq_length].
        output_layer = tf.reshape(output_layer, [-1, hidden_size])
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        logits = tf.reshape(logits, [-1, seq_length, 2])
        logits = tf.transpose(logits, [0, 2, 1])
        probs = tf.nn.softmax(logits, axis=-1, name='probs')
        self.probs[name] = probs

        start_one_hot_labels = tf.one_hot(
            label_ids[:, 0], depth=seq_length, dtype=tf.float32)
        end_one_hot_labels = tf.one_hot(
            label_ids[:, 1], depth=seq_length, dtype=tf.float32)
        start_log_probs = tf.nn.log_softmax(logits[:, 0, :], axis=-1)
        end_log_probs = tf.nn.log_softmax(logits[:, 1, :], axis=-1)

        # Average of the start and end cross-entropies.
        per_example_loss = (
            - 0.5 * tf.reduce_sum(
                start_one_hot_labels * start_log_probs, axis=-1)
            - 0.5 * tf.reduce_sum(
                end_one_hot_labels * end_log_probs, axis=-1))
        if sample_weight is not None:
            # Cast so that non-float32 weights do not trigger a dtype
            # mismatch against the float32 loss.
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses[name] = per_example_loss

        start_preds = tf.expand_dims(
            tf.argmax(logits[:, 0, :], axis=-1), axis=-1)
        end_preds = tf.expand_dims(
            tf.argmax(logits[:, 1, :], axis=-1), axis=-1)
        self.preds[name] = tf.concat([start_preds, end_preds], axis=-1)
def classification_loss(hidden, labels, n_class, initializer, scope,
                        reuse=None, return_logits=False):
    '''Computes a per-example softmax cross-entropy classification loss.

    Each classification task should pass its own `scope` so that it gets
    its own dense layer (parameters) for producing the logits. The one
    exception is transfer learning, where sharing the scope is how the
    classification weights get transferred.

    Args:
      hidden: float Tensor fed to the dense logit layer.
      labels: int Tensor of class ids.
      n_class: int. Number of classes.
      initializer: kernel initializer of the dense layer.
      scope: string. Variable scope holding the `logit` dense layer.
      reuse: forwarded to `tf.variable_scope`.
      return_logits: bool. If True, the logits are returned as well.

    Returns:
      `loss`, or the tuple `(loss, logits)` when `return_logits` is True.
    '''
    with tf.variable_scope(scope, reuse=reuse):
        logits = tf.layers.dense(
            hidden,
            n_class,
            kernel_initializer=initializer,
            name='logit')

        targets = tf.one_hot(labels, n_class, dtype=hidden.dtype)
        log_probs = tf.nn.log_softmax(logits)
        loss = -tf.reduce_sum(log_probs * targets, -1)

    return (loss, logits) if return_logits else loss
def lm_loss(hidden, target, n_token, d_model, initializer, lookup_table=None,
            tie_weight=False, bi_data=True, use_tpu=False):
    '''Computes the language-modeling loss and the greedy predictions.

    Args:
      hidden: float Tensor contracted with the softmax weights through the
        einsum pattern `ibd,nd->ibn`.
      target: int Tensor of target token ids.
      n_token: int. Vocabulary size.
      d_model: int. Width of the softmax weight matrix.
      initializer: initializer for the softmax weights when not tied.
      lookup_table: embedding table reused as softmax weights when
        `tie_weight` is True.
      tie_weight: bool. Whether to reuse `lookup_table` as softmax weights.
      bi_data: unused in this function.
      use_tpu: bool. If True, the loss is computed with an explicit one-hot
        target instead of the sparse cross-entropy op.

    Returns:
      Tuple `(loss, preds)`.
    '''
    with tf.variable_scope('lm_loss'):
        if not tie_weight:
            softmax_w = tf.get_variable(
                'weight', [n_token, d_model],
                dtype=hidden.dtype, initializer=initializer)
        else:
            assert lookup_table is not None, \
                'lookup_table cannot be None for tie_weight'
            softmax_w = lookup_table

        softmax_b = tf.get_variable(
            'bias', [n_token],
            dtype=hidden.dtype, initializer=tf.zeros_initializer())

        projected = tf.einsum('ibd,nd->ibn', hidden, softmax_w)
        logits = projected + softmax_b
        preds = tf.argmax(logits, axis=-1)

        if not use_tpu:
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=target, logits=logits)
        else:
            one_hot_target = tf.one_hot(target, n_token, dtype=logits.dtype)
            log_probs = tf.nn.log_softmax(logits)
            loss = -tf.reduce_sum(log_probs * one_hot_target, -1)

        return loss, preds
def embedding_lookup(x, n_token, d_embed, initializer, use_tpu=True,
                     scope='embedding', tilda_embeddings=None, reuse=None,
                     dtype=tf.float32):
    '''Looks up embeddings with either the TPU or the GPU code path.

    Args:
      x: int Tensor of token ids.
      n_token: int. Vocabulary size.
      d_embed: int. Embedding width.
      initializer: initializer for a newly created lookup table.
      use_tpu: bool. If True, the lookup is a one-hot einsum; otherwise
        `tf.nn.embedding_lookup` is used.
      scope: string. Variable scope for a newly created lookup table.
      tilda_embeddings: (optional) existing table used instead of creating
        the `lookup_table` variable.
      reuse: forwarded to `tf.variable_scope`.
      dtype: dtype of the table and of the one-hot encoding.

    Returns:
      Tuple `(embeddings, lookup_table)`.
    '''
    lookup_table = tilda_embeddings
    if lookup_table is None:
        with tf.variable_scope(scope, reuse=reuse):
            lookup_table = tf.get_variable(
                'lookup_table', [n_token, d_embed],
                dtype=dtype, initializer=initializer)

    if not use_tpu:
        return tf.nn.embedding_lookup(lookup_table, x), lookup_table

    one_hot_idx = tf.one_hot(x, n_token, dtype=dtype)
    # Rank-2 one-hot ids come from a rank-1 `x`; everything else takes the
    # rank-3 contraction.
    if one_hot_idx.shape.ndims == 2:
        equation = 'in,nd->id'
    else:
        equation = 'ibn,nd->ibd'
    return tf.einsum(equation, one_hot_idx, lookup_table), lookup_table
def __init__(self,
             is_training,
             input_tensor,
             input_mask,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/sequence',
             name='',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Builds a per-token classification (sequence labeling) head.

    Args:
      is_training: bool. Dropout is only applied when True.
      input_tensor: float Tensor whose last two static dimensions are
        [seq_length, hidden_size].
      input_mask: Tensor indexed as [batch, position]; used to mask the
        per-token losses (see the note in the body).
      label_ids: int Tensor of per-token label ids.
      label_size: int. Number of labels.
      sample_weight: (optional) per-example weights, cast to float32.
      scope: string. Variable scope of the output layer.
      name: string. Key used in `self.preds`, `self.probs`, `self.losses`.
      hidden_dropout_prob: float. Dropout rate applied to `input_tensor`.
      initializer_range: float. Passed to `util.create_initializer`.
      trainable: bool. Whether the output variables are trainable.
      **kwargs: forwarded to the parent constructor.
    '''
    super().__init__(**kwargs)

    static_shape = input_tensor.shape.as_list()
    batch_size = tf.shape(input_tensor)[0]
    seq_length = static_shape[-2]
    hidden_size = static_shape[-1]
    dropout_rate = hidden_dropout_prob if is_training else 0.0

    with tf.variable_scope(scope):
        kernel = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        features = util.dropout(input_tensor, dropout_rate)
        flat_features = tf.reshape(features, [-1, hidden_size])
        flat_logits = tf.nn.bias_add(
            tf.matmul(flat_features, kernel, transpose_b=True), bias)
        logits = tf.reshape(flat_logits, [-1, seq_length, label_size])

        self.preds[name] = tf.argmax(logits, axis=-1)
        self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        # Note: mean (not sum) over the label axis, i.e. the cross-entropy
        # divided by `label_size`.
        per_token_losses = -tf.reduce_mean(
            one_hot_labels * log_probs, axis=-1)

        # The loss mask is `input_mask` shifted left by one position with
        # the first and last positions forced to zero -- presumably to
        # exclude the leading and trailing special tokens (TODO confirm).
        head_pad = tf.zeros((batch_size, 1), dtype=tf.float32)
        shifted_mask = tf.cast(input_mask[:, 2:], dtype=tf.float32)
        tail_pad = tf.zeros((batch_size, 1), dtype=tf.float32)
        input_mask = tf.concat([head_pad, shifted_mask, tail_pad], axis=-1)
        per_token_losses *= input_mask

        per_example_loss = tf.reduce_mean(per_token_losses, axis=-1)
        if sample_weight is not None:
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        self.losses[name] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def sample_from_softmax(logits, disallow=None):
    '''Draws one-hot samples from `softmax(logits)` via the Gumbel-max trick.

    Args:
      logits: float Tensor of unnormalized log-probabilities; the last
        axis enumerates the categories.
      disallow: (optional) Tensor reshaped to [-1, num_categories]; 1000
        times its value is subtracted from the logits so that flagged
        categories are effectively never sampled.

    Returns:
      One-hot Tensor whose depth is the size of the last axis of `logits`.
    '''
    num_categories = logits.shape[-1]
    if disallow is not None:
        penalty = 1000.0 * tf.reshape(disallow, [-1, num_categories])
        logits -= penalty

    uniform_noise = tf.random_uniform(
        util.get_shape_list(logits), minval=0, maxval=1)
    # The small epsilons keep both logarithms finite.
    inner_log = tf.log(uniform_noise + 1e-9)
    gumbel_noise = -tf.log(-inner_log + 1e-9)

    sampled_ids = tf.argmax(
        logits + gumbel_noise, -1, output_type=tf.int32)
    return tf.one_hot(sampled_ids, logits.shape[-1])
def embedding_postprocessor(self,
                            input_tensor,
                            position_ids,
                            batch_size,
                            max_seq_length,
                            hidden_size,
                            use_token_type=False,
                            segment_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name='token_type_embeddings',
                            use_position_embeddings=True,
                            position_embedding_name='position_embeddings',
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1,
                            dtype=tf.float32,
                            trainable=True):
    '''Adds token-type and position embeddings, then layer norm + dropout.

    Args:
      input_tensor: float Tensor of word embeddings.
      position_ids: int Tensor of position indices gathered from the
        position embedding table.
      batch_size: batch dimension used when reshaping token-type embeddings.
      max_seq_length: sequence dimension used in the same reshape.
      hidden_size: int. Embedding width.
      use_token_type: bool. Whether to add token-type embeddings.
      segment_ids: int Tensor of token-type ids; required when
        `use_token_type` is True.
      token_type_vocab_size: int. Number of token types.
      token_type_embedding_name: string. Variable name of that table.
      use_position_embeddings: bool. Whether to add position embeddings.
      position_embedding_name: string. Variable name of the position table.
      initializer_range: float. Passed to `util.create_initializer`.
      max_position_embeddings: int. Number of rows of the position table.
      dropout_prob: float. Dropout rate of the final output.
      dtype: dtype of the created tables and of the one-hot encoding.
      trainable: bool. Whether created variables are trainable.

    Returns:
      The post-processed embedding Tensor.

    Raises:
      ValueError: if `use_token_type` is True but `segment_ids` is None.
    '''
    output = input_tensor

    if use_token_type:
        if segment_ids is None:
            raise ValueError(
                'segment_ids must be specified if use_token_type is True.')
        token_type_table = tf.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            dtype=dtype,
            trainable=trainable)

        # The token-type vocabulary is small, so a one-hot matmul is
        # always used here; it is faster than a gather at this size.
        one_hot_ids = tf.one_hot(
            tf.reshape(segment_ids, [-1]),
            depth=token_type_vocab_size,
            dtype=dtype)
        flat_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
        output = output + tf.reshape(
            flat_type_embeddings,
            [batch_size, max_seq_length, hidden_size])

    if use_position_embeddings:
        position_table = tf.get_variable(
            name=position_embedding_name,
            shape=[max_position_embeddings, hidden_size],
            initializer=util.create_initializer(initializer_range),
            dtype=dtype,
            trainable=trainable)
        output = output + tf.gather(position_table, position_ids)

    return util.layer_norm_and_dropout(
        output, dropout_prob, trainable=trainable)
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name='word_embeddings',
                     use_one_hot_embeddings=False):
    '''Looks up word embeddings for an id (or soft one-hot) tensor.

    Args:
      input_ids: either an int Tensor of shape [batch_size, seq_length]
        containing word ids, or a rank-3 Tensor that is multiplied
        directly with the embedding table (so its last dimension must be
        `vocab_size` -- presumably a distribution over the vocabulary).
      vocab_size: int. Size of the embedding vocabulary.
      embedding_size: int. Width of the word embeddings.
      initializer_range: float. Embedding initialization range.
      word_embedding_name: string. Name of the embedding table.
      use_one_hot_embeddings: bool. If True, use one-hot method for word
        embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is
        better for TPUs. Only consulted for id input.

    Returns:
      Tuple `(output, embedding_table)` where `output` is a float Tensor
      of shape [batch_size, seq_length, embedding_size] and
      `embedding_table` is the [vocab_size, embedding_size] variable.
    '''
    # A rank-2 id tensor [batch_size, seq_length] is expanded to
    # [batch_size, seq_length, 1] so that the generic reshape at the end
    # of the id branch applies.
    original_dims = input_ids.shape.ndims
    if original_dims == 2:
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    embedding_table = tf.get_variable(
        name=word_embedding_name,
        shape=[vocab_size, embedding_size],
        initializer=util.create_initializer(initializer_range))

    if original_dims == 3:
        input_shape = util.get_shape_list(input_ids)
        # Flatten to rank 2 before the matmul. The original code called
        # `tf.reshape` here but discarded its result, so the rank-3 tensor
        # was multiplied with the rank-2 table.
        flat_input = tf.reshape(input_ids, [-1, input_shape[-1]])
        output = tf.matmul(flat_input, embedding_table)
        output = tf.reshape(
            output, [input_shape[0], input_shape[1], embedding_size])
    else:
        if use_one_hot_embeddings:
            flat_input_ids = tf.reshape(input_ids, [-1])
            one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
            output = tf.matmul(one_hot_input_ids, embedding_table)
        else:
            output = tf.nn.embedding_lookup(embedding_table, input_ids)

        input_shape = util.get_shape_list(input_ids)
        output = tf.reshape(
            output,
            input_shape[0:-1] + [input_shape[-1] * embedding_size])
    return output, embedding_table
def __init__(self,
             is_training,
             input_tensor,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Builds a classification head with optional uncertainty thresholding.

    Args:
      is_training: bool. Dropout is only applied when True.
      input_tensor: float Tensor whose last static dimension is the hidden
        size.
      label_ids: int Tensor of class ids.
      label_size: int. Number of classes.
      sample_weight: (optional) per-example weights, cast to float32.
      scope: string. Variable scope of the output layer.
      hidden_dropout_prob: float. Dropout rate applied to `input_tensor`.
      initializer_range: float. Passed to `util.create_initializer`.
      trainable: bool. Whether the output variables are trainable.
      **kwargs: forwarded to the parent constructor. If it contains a
        float `tsa_thresh`, the loss of every example whose normalized
        prediction entropy is not greater than the threshold is zeroed.
    '''
    super().__init__(**kwargs)

    hidden_size = input_tensor.shape.as_list()[-1]
    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        self.preds['preds'] = tf.argmax(logits, axis=-1)
        self.probs['probs'] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        per_example_loss = - tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        if sample_weight is not None:
            per_example_loss = tf.cast(
                sample_weight, dtype=tf.float32) * per_example_loss

        thresh = kwargs.get('tsa_thresh')
        if thresh is not None:
            assert isinstance(thresh, float), (
                '`tsa_thresh` must be a float between 0 and 1.')
            # Negative entropy sum(p * log p). `log_probs` is reused rather
            # than `tf.log(probs)`: log-softmax stays finite when a
            # probability underflows to 0, whereas 0 * log(0) is NaN.
            uncertainty = tf.reduce_sum(
                self.probs['probs'] * log_probs, axis=-1)
            # Dividing by log(1 / label_size) normalizes the entropy by
            # its maximum, reached for the uniform distribution.
            uncertainty /= tf.log(1 / label_size)
            per_example_loss = tf.cast(
                tf.greater(uncertainty, thresh), dtype=tf.float32) * \
                per_example_loss

        self.losses['losses'] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def __init__(self,
             is_training,
             input_tensor,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             name='',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Builds a softmax classification head on top of `input_tensor`.

    Args:
      is_training: bool. Dropout is only applied when True.
      input_tensor: float Tensor whose last static dimension is the hidden
        size.
      label_ids: int Tensor of class ids.
      label_size: int. Number of classes.
      sample_weight: (optional) per-example weights, cast to float32.
      scope: string. Variable scope of the output layer.
      name: string. Key used in `self.preds`, `self.probs`, `self.losses`.
      hidden_dropout_prob: float. Dropout rate applied to `input_tensor`.
      initializer_range: float. Passed to `util.create_initializer`.
      trainable: bool. Whether the output variables are trainable.
      **kwargs: forwarded to the parent constructor.
    '''
    super().__init__(**kwargs)

    hidden_size = input_tensor.shape.as_list()[-1]
    dropout_rate = hidden_dropout_prob if is_training else 0.0

    with tf.variable_scope(scope):
        kernel = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        features = util.dropout(input_tensor, dropout_rate)
        logits = tf.nn.bias_add(
            tf.matmul(features, kernel, transpose_b=True), bias)

        self.preds[name] = tf.argmax(logits, axis=-1)
        self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

        # Cross-entropy against the one-hot encoded labels.
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        if sample_weight is not None:
            weights = tf.cast(sample_weight, dtype=tf.float32)
            per_example_loss = weights * per_example_loss

        self.losses[name] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def _get_fake_data(self, inputs, mlm_logits):
    '''Samples generator tokens and scatters them into the input ids.

    Args:
      inputs: input structure exposing `input_ids`, `masked_lm_ids` and
        `masked_lm_positions`; it is first passed through `unmask`.
      mlm_logits: generator logits at the masked positions; divided by
        `self.config.temperature` before sampling.

    Returns:
      `FakedData` namedtuple with fields `inputs` (the updated inputs),
      `is_fake_tokens` (1 where a masked position received a token that
      differs from the unmasked input id, else 0) and `sampled_tokens`
      (the one-hot samples, with gradients stopped).
    '''
    inputs = unmask(inputs)

    # Optionally forbid the generator from sampling the correct token.
    if self.config.disallow_correct:
        disallow = tf.one_hot(
            inputs.masked_lm_ids,
            depth=self.bert_config.vocab_size,
            dtype=tf.float32)
    else:
        disallow = None

    tempered_logits = mlm_logits / self.config.temperature
    sampled_tokens = tf.stop_gradient(
        sample_from_softmax(tempered_logits, disallow=disallow))
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)

    updated_input_ids, masked = scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
    unchanged = tf.cast(
        tf.equal(updated_input_ids, inputs.input_ids), tf.int32)
    labels = masked * (1 - unchanged)

    updated_inputs = get_updated_inputs(
        inputs, input_ids=updated_input_ids)

    FakedData = collections.namedtuple(
        'FakedData', ['inputs', 'is_fake_tokens', 'sampled_tokens'])
    return FakedData(
        inputs=updated_inputs,
        is_fake_tokens=labels,
        sampled_tokens=sampled_tokens)
def _cls_forward(self,
                 is_training,
                 input_tensor,
                 input_mask,
                 label_ids,
                 bert_config,
                 batch_size,
                 max_seq_length,
                 prob,
                 scope,
                 name,
                 sample_weight=None,
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02):
    '''Adds a binary per-token classification task to the model.

    The per-token cross-entropy is averaged over the positions selected by
    `input_mask`, giving one loss value per example.

    Args:
      is_training: unused in this method.
      input_tensor: float Tensor fed to a 2-unit dense layer.
      input_mask: Tensor marking the positions that count towards the
        loss; reduced over its last axis to normalize per example.
      label_ids: int Tensor of per-token binary labels.
      bert_config: config object; only `initializer_range` is read.
      batch_size: unused in this method.
      max_seq_length: unused in this method.
      prob: the loss is added to `self.total_loss` only when this is
        non-zero.
      scope: string. Variable scope of the dense layer.
      name: string. Prefix of the `_loss` / `_preds` result keys.
      sample_weight: (optional) per-example weights -- assumed to align
        with the batch axis, as in the sibling heads (TODO confirm).
      hidden_dropout_prob: unused in this method.
      initializer_range: unused; `bert_config.initializer_range` is used.
    '''
    with tf.variable_scope(scope):
        logits = tf.layers.dense(
            input_tensor,
            2,
            kernel_initializer=util.create_initializer(
                bert_config.initializer_range),
            trainable=True)

        # loss
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=2)
        per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)

        # Normalize the mask per row so that the masked sum below is a
        # mean over the selected positions.
        input_mask = tf.cast(input_mask, tf.float32)
        per_token_loss *= input_mask / tf.reduce_sum(
            input_mask, keepdims=True, axis=-1)
        per_example_loss = tf.reduce_sum(per_token_loss, axis=-1)
        if sample_weight is not None:
            # `per_example_loss` has already been reduced over the token
            # axis, so the weights are applied element-wise. Expanding
            # them to [B, 1] (as before) broadcast the product to [B, B].
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        if prob != 0:
            self.total_loss += tf.reduce_mean(per_example_loss)
        self.losses[name + '_loss'] = per_example_loss
        self.preds[name + '_preds'] = tf.argmax(logits, axis=-1)
def __init__(self,
             is_training,
             input_tensor,
             n_wide_features,
             wide_features,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Builds a wide & deep classification head.

    The "wide" feature ids are embedded, attended over with the "deep"
    vector as the query, and the attention output is added to the deep
    vector (followed by layer norm) before the softmax classifier.

    Args:
      is_training: bool. Dropout is only applied when True.
      input_tensor: float Tensor [B, H] (deep representation).
      n_wide_features: int Tensor with the number of valid wide features
        per example; used to build the attention mask.
      wide_features: int Tensor [B, N] of wide feature ids. The embedding
        table has N + 1 rows.
      label_ids: int Tensor of class ids.
      label_size: int. Number of classes.
      sample_weight: (optional) per-example weights, cast to float32.
      scope: string. Variable scope of the output layer.
      hidden_dropout_prob: float. Dropout rate before the classifier.
      initializer_range: float. Passed to `util.create_initializer`.
      trainable: bool. Whether created variables are trainable.
      **kwargs: forwarded to the parent constructor. If it contains a
        float `tsa_thresh`, the loss of every example whose normalized
        prediction entropy is not greater than the threshold is zeroed.
    '''
    super().__init__(**kwargs)

    hidden_size = input_tensor.shape.as_list()[-1]
    feature_size = wide_features.shape.as_list()[-1]

    with tf.variable_scope('wide'):
        feature_embeddings = tf.get_variable(
            name='feature_embeddings',
            shape=[feature_size + 1, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        wide_output = tf.gather(
            feature_embeddings, wide_features)  # [B, N, H]

    with tf.variable_scope('wide_and_deep'):
        # Scaled dot-product attention with the deep vector as the query.
        deep_output = tf.expand_dims(input_tensor, -1)  # [B, H, 1]
        attention_scores = tf.matmul(
            wide_output, deep_output)  # [B, N, 1]
        attention_scores = tf.transpose(
            attention_scores, [0, 2, 1])  # [B, 1, N]
        attention_scores = tf.multiply(
            attention_scores, 1.0 / math.sqrt(hidden_size))

        # Push the scores of padded feature slots towards -inf.
        feature_mask = tf.cast(
            tf.sequence_mask(n_wide_features, feature_size),
            tf.float32)  # [B, N]
        feature_mask = tf.expand_dims(feature_mask, 1)  # [B, 1, N]
        attention_scores += (1.0 - feature_mask) * -10000.0

        attention_matrix = tf.nn.softmax(attention_scores, axis=-1)
        attention_output = tf.matmul(
            attention_matrix, wide_output)  # [B, 1, H]
        attention_output = attention_output[:, 0, :]  # [B, H]

        # Residual connection; no dropout is applied to the attention
        # output.
        input_tensor = util.layer_norm(
            attention_output + input_tensor, trainable=trainable)

    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        self.preds['preds'] = tf.argmax(logits, axis=-1)
        self.probs['probs'] = tf.nn.softmax(logits, axis=-1, name='probs')

        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        if sample_weight is not None:
            per_example_loss = tf.cast(
                sample_weight, dtype=tf.float32) * per_example_loss

        thresh = kwargs.get('tsa_thresh')
        if thresh is not None:
            assert isinstance(thresh, float), (
                '`tsa_thresh` must be a float between 0 and 1.')
            # Negative entropy sum(p * log p). `log_probs` is reused rather
            # than `tf.log(probs)`: log-softmax stays finite when a
            # probability underflows to 0, whereas 0 * log(0) is NaN.
            uncertainty = tf.reduce_sum(
                self.probs['probs'] * log_probs, axis=-1)
            # Dividing by log(1 / label_size) normalizes the entropy by
            # its maximum, reached for the uniform distribution.
            uncertainty /= tf.log(1 / label_size)
            per_example_loss = tf.cast(
                tf.greater(uncertainty, thresh), dtype=tf.float32) * \
                per_example_loss

        self.losses['losses'] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)
def _get_generator_output(self, inputs, sample_weight, generator):
    '''Builds the masked language modeling softmax layer of the generator.

    Args:
      inputs: input structure exposing `masked_lm_positions`,
        `masked_lm_ids` and `masked_lm_weights`.
      sample_weight: (optional) per-example weights; expanded on the last
        axis and multiplied into the masked-LM weights.
      generator: encoder exposing `get_sequence_output()` and
        `get_embedding_table()`.

    Returns:
      `MLMOutput` namedtuple with fields `logits`, `probs`, `loss` (the
      weight-normalized scalar), `per_example_loss` (flattened over the
      masked positions) and `preds`.
    '''

    def gather_indexes(sequence_tensor, positions):
        '''Gathers the vectors found at `positions` of every sequence.'''
        batch_size, seq_length, width = util.get_shape_list(
            sequence_tensor, 3)
        # Offset each row's positions into the flattened
        # [batch_size * seq_length, width] view.
        row_offsets = tf.range(0, batch_size, dtype=tf.int32) * seq_length
        flat_offsets = tf.reshape(row_offsets, [-1, 1])
        flat_positions = tf.reshape(positions + flat_offsets, [-1])
        flat_sequence_tensor = tf.reshape(
            sequence_tensor, [batch_size * seq_length, width])
        return tf.gather(flat_sequence_tensor, flat_positions)

    input_tensor = gather_indexes(
        generator.get_sequence_output(), inputs.masked_lm_positions)
    vocab_size = self.bert_config.vocab_size

    with tf.variable_scope('generator_predictions'):
        input_tensor = tf.layers.dense(
            input_tensor,
            units=self.config.embedding_size,
            activation=util.get_activation(self.bert_config.hidden_act),
            kernel_initializer=util.create_initializer(
                self.bert_config.initializer_range))
        input_tensor = util.layer_norm(input_tensor)

        # The output projection is tied to the generator's embedding
        # table; only the bias is a new variable.
        output_bias = tf.get_variable(
            'output_bias',
            shape=[vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.nn.bias_add(
            tf.matmul(input_tensor,
                      generator.get_embedding_table(),
                      transpose_b=True),
            output_bias)

        probs = tf.nn.softmax(logits, axis=-1, name='MLM_probs')
        preds = tf.argmax(logits, axis=-1)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(inputs.masked_lm_ids, [-1])
        masked_lm_weights = inputs.masked_lm_weights
        if sample_weight is not None:
            sample_weight = tf.expand_dims(
                tf.cast(sample_weight, dtype=tf.float32), axis=-1)
            masked_lm_weights *= sample_weight
        label_weights = tf.reshape(masked_lm_weights, [-1])

        one_hot_labels = tf.one_hot(
            label_ids, depth=vocab_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            log_probs * one_hot_labels, axis=[-1])
        per_example_loss = label_weights * per_example_loss

        # Weighted mean; the epsilon guards against an all-zero weight
        # sum.
        weighted_sum = tf.reduce_sum(per_example_loss)
        weight_total = tf.reduce_sum(label_weights) + 1e-6
        loss = weighted_sum / weight_total

        MLMOutput = collections.namedtuple(
            'MLMOutput',
            ['logits', 'probs', 'loss', 'per_example_loss', 'preds'])
        return MLMOutput(
            logits=logits,
            probs=probs,
            per_example_loss=per_example_loss,
            loss=loss,
            preds=preds)
def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name='token_type_embeddings',
                            use_position_embeddings=True,
                            position_embedding_name='position_embeddings',
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
    '''Performs various post-processing on a word embedding tensor.

    Args:
      input_tensor: float Tensor of shape [batch_size, seq_length,
        embedding_size].
      use_token_type: bool. Whether to add embeddings for `token_type_ids`.
      token_type_ids: (optional) int32 Tensor of shape [batch_size,
        seq_length]. Must be specified if `use_token_type` is True.
      token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
      token_type_embedding_name: string. The name of the embedding table
        variable for token type ids.
      use_position_embeddings: bool. Whether to add position embeddings for
        the position of each token in the sequence.
      position_embedding_name: string. The name of the embedding table
        variable for positional embeddings.
      initializer_range: float. Range of the weight initialization.
      max_position_embeddings: int. Maximum sequence length that might ever
        be used with this model. This can be longer than the sequence
        length of input_tensor, but cannot be shorter.
      dropout_prob: float. Dropout probability applied to the final output
        tensor.

    Returns:
      float tensor with same shape as `input_tensor`.

    Raises:
      ValueError: `use_token_type` is True but `token_type_ids` is None, or
        (presumably, via `util.get_shape_list`) `input_tensor` is not
        rank 3.
    '''
    input_shape = util.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    width = input_shape[2]

    output = input_tensor

    if use_token_type:
        if token_type_ids is None:
            # The two literals are concatenated, so the first one must end
            # with a space to keep the words apart.
            raise ValueError('`token_type_ids` must be specified if '
                             '`use_token_type` is True.')
        token_type_table = tf.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, width],
            initializer=util.create_initializer(initializer_range))
        # This vocab will be small so we always do one-hot here, since it
        # is always faster for a small vocabulary.
        flat_token_type_ids = tf.reshape(token_type_ids, [-1])
        one_hot_ids = tf.one_hot(
            flat_token_type_ids, depth=token_type_vocab_size)
        token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
        token_type_embeddings = tf.reshape(
            token_type_embeddings, [batch_size, seq_length, width])
        output += token_type_embeddings

    if use_position_embeddings:
        assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
        with tf.control_dependencies([assert_op]):
            full_position_embeddings = tf.get_variable(
                name=position_embedding_name,
                shape=[max_position_embeddings, width],
                initializer=util.create_initializer(initializer_range))
            # Since the position embedding table is a learned variable, we
            # create it using a (long) sequence length
            # `max_position_embeddings`. The actual sequence length might
            # be shorter than this, for faster training of tasks that do
            # not have long sequences.
            #
            # So `full_position_embeddings` is effectively an embedding
            # table for position [0, 1, 2, ..., max_position_embeddings-1],
            # and the current sequence has positions
            # [0, 1, 2, ... seq_length-1], so we can just perform a slice.
            position_embeddings = tf.slice(
                full_position_embeddings, [0, 0], [seq_length, -1])
            num_dims = len(output.shape.as_list())

            # Only the last two dimensions are relevant (`seq_length` and
            # `width`), so we broadcast among the first dimensions, which
            # is typically just the batch size.
            position_broadcast_shape = (
                [1] * (num_dims - 2) + [seq_length, width])
            position_embeddings = tf.reshape(
                position_embeddings, position_broadcast_shape)
            output += position_embeddings

    output = util.layer_norm_and_dropout(output, dropout_prob)
    return output
def __init__(self,
             vocab_size,
             is_training,
             input_ids,
             input_mask,
             segment_ids,
             sample_weight=None,
             reduced_size=64,
             topic_size=1024,
             hidden_size=768,
             num_hidden_layers=12,
             num_attention_heads=12,
             bias=0,
             scope='vae',
             trainable=True,
             **kwargs):
    '''Builds a transformer-based variational auto-encoder over token ids.

    Pipeline: embeddings -> transformer encoder -> per-token projection to
    `reduced_size` -> latent `miu` / `sigma` of width `topic_size` ->
    sampled latent -> decoder MLP -> vocabulary logits tied to the word
    embedding table. The reconstruction target is `input_ids` itself.

    Args:
      vocab_size: int. Vocabulary size passed to `Config`.
      is_training: bool. When False, the dropout rates are set to zero and
        the latent noise is uniform in [-bias, bias] instead of normal.
      input_ids: int Tensor of rank 2, [batch_size, seq_length].
      input_mask: Tensor summed over its last axis to get input lengths.
      segment_ids: int Tensor of token-type ids.
      sample_weight: (optional) per-example weights; expanded on the last
        axis and multiplied into the per-token reconstruction loss.
      reduced_size: int. Per-token width after the encoder projection.
      topic_size: int. Width of the latent vectors.
      hidden_size: int. Transformer hidden size.
      num_hidden_layers: int. Number of transformer layers.
      num_attention_heads: int. Number of attention heads.
      bias: half-width of the uniform noise used when not training.
      scope: string. Variable scope of the encoder/decoder.
      trainable: bool. Whether created variables are trainable.
      **kwargs: only `use_tilda_embedding` is read here; note that kwargs
        are not forwarded to the parent constructor.
    '''
    # NOTE(review): the block nesting below was reconstructed from a
    # whitespace-collapsed source -- confirm the variable-scope boundaries
    # (they determine variable names) against existing checkpoints.
    super().__init__()

    # freeze parameters
    config = Config(vocab_size,
                    hidden_size=hidden_size,
                    num_hidden_layers=num_hidden_layers,
                    num_attention_heads=num_attention_heads)
    if not is_training:
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0

    input_shape = util.get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    # Tilda embeddings for SMART algorithm: when requested, an existing
    # root-scope `tilda_embeddings` variable is fetched (reuse=True) and
    # handed to the embedding lookup.
    tilda_embeddings = None
    use_tilda_embedding = kwargs.get('use_tilda_embedding')
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):
        with tf.variable_scope('embeddings'):

            (self.embedding_output, self.embedding_table) = \
                self.embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=config.vocab_size,
                    batch_size=batch_size,
                    max_seq_length=seq_length,
                    embedding_size=config.hidden_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name='word_embeddings',
                    tilda_embeddings=tilda_embeddings,
                    trainable=trainable)
            self.embedding_output = self.embedding_postprocessor(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=seq_length,
                hidden_size=config.hidden_size,
                use_token_type=True,
                segment_ids=segment_ids,
                token_type_vocab_size=config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=config.initializer_range,
                max_position_embeddings=config.max_position_embeddings,
                dropout_prob=config.hidden_dropout_prob,
                trainable=trainable)

        with tf.variable_scope('encoder'):

            # stacked transformer
            attention_mask = self.create_attention_mask_from_input_mask(
                input_mask, batch_size, seq_length)
            self.all_encoder_layers = self.transformer_model(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=seq_length,
                attention_mask=attention_mask,
                hidden_size=config.hidden_size,
                num_hidden_layers=config.num_hidden_layers,
                num_attention_heads=config.num_attention_heads,
                intermediate_size=config.intermediate_size,
                intermediate_act_fn=util.get_activation(config.hidden_act),
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=\
                    config.attention_probs_dropout_prob,
                initializer_range=config.initializer_range,
                trainable=trainable)

            # projection
            with tf.variable_scope('projection'):
                transformer_output = tf.layers.dense(
                    self.all_encoder_layers[-1],
                    reduced_size,
                    activation=util.gelu,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    trainable=trainable)
                # Flatten every example to one vector of size
                # seq_length * reduced_size.
                transformer_output = tf.reshape(transformer_output,
                                                [batch_size, -1])
                input_length = tf.reduce_sum(input_mask, axis=-1)
                input_length = tf.cast(input_length, tf.float32)
                input_length_1d = tf.reshape(input_length, [batch_size])
                input_length_2d = tf.reshape(input_length, [batch_size, 1])

                # Keep only the first input_length * reduced_size entries
                # of the flattened vector and rescale them by
                # seq_length / input_length.
                broadcast_mask = tf.sequence_mask(
                    tf.multiply(input_length_1d, reduced_size),
                    seq_length * reduced_size,
                    dtype=tf.float32)
                broadcast_mask = tf.multiply(broadcast_mask,
                                             seq_length / input_length_2d)
                transformer_output *= broadcast_mask

                # latent space
                miu = tf.layers.dense(
                    transformer_output,
                    topic_size,
                    activation='tanh',
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    name='miu',
                    trainable=trainable)
                sigma = tf.layers.dense(
                    transformer_output,
                    topic_size,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    name='sigma',
                    trainable=trainable)
                self.probs['miu'] = miu
                self.probs['sigma'] = sigma

        with tf.variable_scope('decoder'):
            with tf.variable_scope('projection'):

                # reparameterization: normal noise when training, uniform
                # noise in [-bias, bias] otherwise. `sigma` is used through
                # `tf.exp`, i.e. as a log-scale.
                if is_training:
                    noise = tf.random_normal([batch_size, topic_size])
                else:
                    noise = tf.random_uniform([batch_size, topic_size],
                                              minval=-bias,
                                              maxval=bias)
                decoder_input = miu + tf.exp(sigma) * noise

                # projection
                decoder_input = tf.layers.dense(
                    decoder_input,
                    seq_length * reduced_size,
                    activation=util.gelu,
                    kernel_initializer=tf.truncated_normal_initializer(
                        stddev=config.initializer_range),
                    trainable=trainable)
                intermediate_input = tf.reshape(
                    decoder_input, [-1, seq_length, reduced_size])
                intermediate_input = util.layer_norm(intermediate_input,
                                                     trainable=trainable)
                intermediate_input = util.dropout(
                    intermediate_input, config.hidden_dropout_prob)

            # MLP
            with tf.variable_scope('intermediate'):
                intermediate_output = tf.layers.dense(
                    intermediate_input,
                    4 * reduced_size,
                    activation=util.gelu,
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
            with tf.variable_scope('output'):
                decoder_output = tf.layers.dense(
                    intermediate_output,
                    config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
                decoder_output = util.layer_norm(decoder_output,
                                                 trainable=trainable)
                decoder_output = util.dropout(decoder_output,
                                              config.hidden_dropout_prob)
            # NOTE(review): the first assignment is dead -- it is
            # immediately overwritten by the second one.
            self.all_decoder_layers = [intermediate_output, decoder_output]
            self.all_decoder_layers = [decoder_output]

    # reconstruction
    with tf.variable_scope('cls/predictions'):
        with tf.variable_scope('transform'):
            input_tensor = tf.layers.dense(
                decoder_output,
                units=config.hidden_size,
                activation=util.get_activation(config.hidden_act),
                kernel_initializer=util.create_initializer(
                    config.initializer_range),
                trainable=trainable)
            input_tensor = util.layer_norm(input_tensor,
                                           trainable=trainable)
        # Output projection is tied to the word embedding table.
        output_weights = self.embedding_table
        output_bias = tf.get_variable('output_bias',
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer(),
                                      trainable=trainable)
        flatten_input_tensor = tf.reshape(input_tensor,
                                          [-1, config.hidden_size])

        logits = tf.matmul(flatten_input_tensor,
                           output_weights,
                           transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        logits = tf.reshape(logits,
                            [batch_size, seq_length, config.vocab_size])
        probs = tf.nn.softmax(logits, axis=-1, name='probs')
        lm_log_probs = tf.nn.log_softmax(logits, axis=-1)

        self.preds['preds'] = tf.argmax(probs, axis=-1)
        one_hot_labels = tf.one_hot(input_ids,
                                    depth=config.vocab_size,
                                    dtype=tf.float32)
        # Reduced over the vocabulary axis only, so despite its name this
        # loss is per token: [batch_size, seq_length].
        per_example_loss = -tf.reduce_sum(lm_log_probs * one_hot_labels,
                                          axis=[-1])
        if sample_weight is not None:
            per_example_loss *= tf.expand_dims(sample_weight, axis=-1)

        # Total loss = reconstruction loss + two latent regularizers:
        # mean(miu^2) and mean(exp(sigma) - sigma - 1).
        self.total_loss = (tf.reduce_mean(per_example_loss) +
                           tf.reduce_mean(tf.square(miu)) +
                           tf.reduce_mean(tf.exp(sigma) - sigma - 1))
        self.losses['losses'] = per_example_loss
def __init__(self,
             bert_config,
             is_training,
             sketchy_encoder,
             intensive_encoder,
             query_mask,
             label_ids,
             has_answer,
             sample_weight=None,
             scope='retro_reader',
             matching_mechanism='cross-attention',
             beta_1=0.5,
             beta_2=0.5,
             threshold=1.0,
             trainable=True,
             **kwargs):
    '''
    Builds the sketchy reading module, the intensive reading module and
    the rear verifier that combines the scores of both.

    Args:
      bert_config: config object; `initializer_range`,
        `hidden_dropout_prob` and `num_attention_heads` are read.
      is_training: bool, dropout is only applied when True.
      sketchy_encoder: encoder exposing `get_pooled_output()`.
      intensive_encoder: encoder exposing `get_sequence_output()`.
      query_mask: mask multiplied onto the sequence output to keep the
        query positions (presumably [batch_size, max_seq_length] --
        confirm against the caller).
      label_ids: tensor whose columns 0 and 1 hold the start and the
        end position of the answer.
      has_answer: labels of the 2-way sketchy classifier.
      sample_weight: optional per-example weights, applied to both
        per-example losses.
      scope: name of the outer variable scope.
      matching_mechanism: 'cross-attention' or 'matching-attention'.
      beta_1: weight of the intensive score in the verifier.
      beta_2: weight of the sketchy score in the verifier.
      threshold: verifier score above which the prediction is 1.
      trainable: whether the created variables are trainable.

    Raises:
      ValueError: if `matching_mechanism` is not supported.
    '''
    super().__init__(**kwargs)

    # Every dropout in this head is disabled outside of training.
    dropout_prob = bert_config.hidden_dropout_prob if is_training else 0.0

    # verifier
    with tf.variable_scope(scope):

        # sketchy reading module
        with tf.variable_scope('sketchy/prediction'):
            sketchy_output = sketchy_encoder.get_pooled_output()
            hidden_size = sketchy_output.shape.as_list()[-1]

            output_weights = tf.get_variable(
                'output_weights',
                shape=[2, hidden_size],
                initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable(
                'output_bias',
                shape=[2],
                initializer=tf.zeros_initializer(),
                trainable=trainable)

            output_layer = util.dropout(sketchy_output, dropout_prob)
            logits = tf.matmul(
                output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            one_hot_labels = tf.one_hot(
                has_answer, depth=2, dtype=tf.float32)
            per_example_loss = - tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)
            if sample_weight is not None:
                per_example_loss = tf.cast(
                    sample_weight, dtype=tf.float32) * per_example_loss

            self.losses['sketchy_losses'] = per_example_loss
            sketchy_loss = tf.reduce_mean(per_example_loss)

            # external score: margin between the two class logits
            score_ext = logits[:, 1] - logits[:, 0]

        # intensive reading module
        with tf.variable_scope('intensive'):
            H = intensive_encoder.get_sequence_output()
            H_Q = H * tf.cast(
                tf.expand_dims(query_mask, axis=-1), tf.float32)
            (batch_size, max_seq_length, hidden_size) = \
                util.get_shape_list(H)

            # cross-attention
            if matching_mechanism == 'cross-attention':
                with tf.variable_scope('cross_attention'):
                    attention_mask = \
                        self.create_attention_mask_from_input_mask(
                            query_mask, batch_size, max_seq_length)
                    (H_prime, _) = self.attention_layer(
                        from_tensor=H,
                        to_tensor=H_Q,
                        attention_mask=attention_mask,
                        num_attention_heads=\
                            bert_config.num_attention_heads,
                        size_per_head=\
                            hidden_size // bert_config.num_attention_heads,
                        # gated on `is_training` like the other dropouts
                        attention_probs_dropout_prob=dropout_prob,
                        initializer_range=bert_config.initializer_range,
                        do_return_2d_tensor=False,
                        batch_size=batch_size,
                        from_max_seq_length=max_seq_length,
                        to_max_seq_length=max_seq_length,
                        trainable=trainable)

            # matching-attention
            elif matching_mechanism == 'matching-attention':
                with tf.variable_scope('matching_attention'):
                    output_weights = tf.get_variable(
                        'output_weights',
                        shape=[hidden_size, hidden_size],
                        initializer=util.create_initializer(
                            bert_config.initializer_range),
                        trainable=trainable)
                    output_bias = tf.get_variable(
                        'output_bias',
                        shape=[hidden_size],
                        initializer=tf.zeros_initializer(),
                        trainable=trainable)
                    trans = tf.matmul(
                        H_Q, tf.tile(
                            tf.expand_dims(output_weights, axis=0),
                            [batch_size, 1, 1]),
                        transpose_b=True)
                    trans = tf.nn.bias_add(trans, output_bias)
                    M = tf.nn.softmax(
                        tf.matmul(H, trans, transpose_b=True), axis=-1)
                    H_prime = tf.matmul(M, H_Q)

            else:
                # Without this branch `H_prime` would be undefined below.
                raise ValueError(
                    'Unsupported matching mechanism: %s'
                    % matching_mechanism)

            with tf.variable_scope('prediction'):
                output_weights = tf.get_variable(
                    'output_weights',
                    shape=[2, hidden_size],
                    initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=trainable)
                output_bias = tf.get_variable(
                    'output_bias',
                    shape=[2],
                    initializer=tf.zeros_initializer(),
                    trainable=trainable)

                output_layer = util.dropout(H_prime, dropout_prob)
                output_layer = tf.reshape(
                    output_layer,
                    [batch_size * max_seq_length, hidden_size])
                logits = tf.matmul(output_layer,
                                   output_weights,
                                   transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)
                logits = tf.reshape(
                    logits, [batch_size, max_seq_length, 2])
                # [batch, 2, seq]: row 0 scores starts, row 1 scores ends
                logits = tf.transpose(logits, [0, 2, 1])
                probs = tf.nn.softmax(logits, axis=-1, name='probs')
                self.probs['mrc_probs'] = probs
                self.preds['mrc_preds'] = tf.argmax(logits, axis=-1)

                start_one_hot_labels = tf.one_hot(
                    label_ids[:, 0], depth=max_seq_length,
                    dtype=tf.float32)
                end_one_hot_labels = tf.one_hot(
                    label_ids[:, 1], depth=max_seq_length,
                    dtype=tf.float32)
                start_log_probs = tf.nn.log_softmax(
                    logits[:, 0, :], axis=-1)
                end_log_probs = tf.nn.log_softmax(
                    logits[:, 1, :], axis=-1)
                per_example_loss = (
                    - 0.5 * tf.reduce_sum(
                        start_one_hot_labels * start_log_probs, axis=-1)
                    - 0.5 * tf.reduce_sum(
                        end_one_hot_labels * end_log_probs, axis=-1))
                if sample_weight is not None:
                    # cast like the sketchy branch, so that integer
                    # weights do not break the multiplication
                    per_example_loss *= tf.cast(
                        sample_weight, dtype=tf.float32)

                intensive_loss = tf.reduce_mean(per_example_loss)
                self.losses['intensive_losses'] = per_example_loss

                # best non-null span evidence vs. evidence at position 0
                score_has = tf.norm(
                    probs[:, 0, 1:] + probs[:, 1, 1:], np.inf, axis=-1)
                score_null = probs[:, 0, 0] + probs[:, 1, 0]
                score_diff = score_has - score_null

        # rear verification
        v = beta_1 * score_diff + beta_2 * score_ext
        self.preds['verifier_preds'] = \
            tf.cast(tf.greater(v, threshold), tf.int32)
        self.probs['verifier_probs'] = v

        self.total_loss = sketchy_loss + intensive_loss
def __init__(self,
             bert_config,
             is_training,
             encoder,
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             next_sentence_labels,
             sample_weight=None,
             scope_lm='cls/predictions',
             scope_cls='cls/seq_relationship',
             trainable=True,
             use_nsp_loss=True,
             **kwargs):
    '''
    Builds the masked language modeling head and the next sentence
    prediction head on top of `encoder`.

    Args:
      bert_config: config object; `hidden_size`, `hidden_act`,
        `initializer_range` and `vocab_size` are read.
      is_training: bool, not used by this head.
      encoder: encoder exposing `get_sequence_output()`,
        `get_pooled_output()` and `get_embedding_table()`.
      masked_lm_positions: positions gathered from the sequence output.
      masked_lm_ids: vocabulary ids of the masked tokens.
      masked_lm_weights: weights of the masked positions.
      next_sentence_labels: labels of the 2-way NSP classifier.
      sample_weight: optional per-example weights (presumably of shape
        [batch_size] -- confirm against the caller).
      scope_lm: variable scope of the MLM head.
      scope_cls: variable scope of the NSP head.
      trainable: whether the created variables are trainable.
      use_nsp_loss: whether the NSP loss is added to `total_loss`.
    '''
    super(BERTDecoder, self).__init__(**kwargs)

    def gather_indexes(sequence_tensor, positions):
        '''Gathers the vectors at `positions` from every sequence.'''
        sequence_shape = util.get_shape_list(sequence_tensor, 3)
        batch_size = sequence_shape[0]
        seq_length = sequence_shape[1]
        width = sequence_shape[2]

        # offset of every example inside the flattened sequence tensor
        flat_offsets = tf.reshape(
            tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
        flat_positions = tf.reshape(positions + flat_offsets, [-1])
        flat_sequence_tensor = tf.reshape(sequence_tensor,
                                          [batch_size * seq_length, width])
        output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
        return output_tensor

    scalar_losses = []

    # masked language modeling
    input_tensor = gather_indexes(encoder.get_sequence_output(),
                                  masked_lm_positions)
    with tf.variable_scope(scope_lm):
        with tf.variable_scope('transform'):
            # `trainable` is forwarded so that a frozen head really
            # freezes the transform layer as well
            input_tensor = tf.layers.dense(
                input_tensor,
                units=bert_config.hidden_size,
                activation=util.get_activation(bert_config.hidden_act),
                kernel_initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=trainable)
            input_tensor = util.layer_norm(input_tensor,
                                           trainable=trainable)
        output_bias = tf.get_variable('output_bias',
                                      shape=[bert_config.vocab_size],
                                      initializer=tf.zeros_initializer(),
                                      trainable=trainable)

        # output projection shares its weights with the embedding table
        logits = tf.matmul(input_tensor,
                           encoder.get_embedding_table(),
                           transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probs = tf.nn.softmax(logits, axis=-1, name='MLM_probs')
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(masked_lm_ids, [-1])
        if sample_weight is not None:
            # Use a separate name for the expanded weights: the NSP
            # branch below needs the original rank of `sample_weight`,
            # otherwise its loss is broadcast to [batch, batch].
            lm_sample_weight = tf.expand_dims(
                tf.cast(sample_weight, dtype=tf.float32), axis=-1)
            masked_lm_weights *= lm_sample_weight
        label_weights = tf.reshape(masked_lm_weights, [-1])
        one_hot_labels = tf.one_hot(label_ids,
                                    depth=bert_config.vocab_size,
                                    dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels,
                                          axis=[-1])
        per_example_loss = label_weights * per_example_loss

        # weighted mean; the epsilon avoids a division by zero
        numerator = tf.reduce_sum(per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator

        scalar_losses.append(loss)
        self.losses['MLM_losses'] = per_example_loss
        self.preds['MLM_preds'] = tf.argmax(probs, axis=-1)

    # next sentence prediction
    with tf.variable_scope(scope_cls):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[2, bert_config.hidden_size],
            initializer=util.create_initializer(
                bert_config.initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable('output_bias',
                                      shape=[2],
                                      initializer=tf.zeros_initializer(),
                                      trainable=trainable)

        logits = tf.matmul(encoder.get_pooled_output(),
                           output_weights,
                           transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probs = tf.nn.softmax(logits, axis=-1, name='probs')
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        labels = tf.reshape(next_sentence_labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                          axis=-1)
        if sample_weight is not None:
            per_example_loss = (tf.cast(sample_weight, dtype=tf.float32) *
                                per_example_loss)
        loss = tf.reduce_mean(per_example_loss)

        if use_nsp_loss:
            scalar_losses.append(loss)
        self.losses['NSP_losses'] = per_example_loss
        self.probs['NSP_probs'] = probs
        self.preds['NSP_preds'] = tf.argmax(probs, axis=-1)

    self.total_loss = tf.add_n(scalar_losses)
def __init__(self,
             is_training,
             input_tensor,
             is_supervised,
             is_expanded,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             global_step=None,
             num_train_steps=None,
             uda_softmax_temp=-1,
             uda_confidence_thresh=-1,
             tsa_schedule='linear',
             **kwargs):
    '''
    Builds a classification head with a supervised cross-entropy loss
    and an unsupervised consistency (KL) loss between original and
    expanded examples.

    Args:
      is_training: bool, enables dropout and the TSA loss masking.
      input_tensor: features fed to the classifier; the last dimension
        is the hidden size.
      is_supervised: mask selecting the supervised examples among the
        non-expanded ones.
      is_expanded: mask selecting the expanded examples.
      label_ids: labels, filtered by `is_supervised`.
      label_size: number of classes.
      sample_weight: optional weights, filtered by `is_supervised`.
      scope: variable scope of the classifier.
      hidden_dropout_prob: dropout rate used when training.
      initializer_range: passed to `util.create_initializer`.
      trainable: whether the created variables are trainable.
      global_step, num_train_steps: forwarded to `get_tsa_threshold`.
      uda_softmax_temp: temperature for the target distribution,
        -1 disables it.
      uda_confidence_thresh: confidence threshold for the unsupervised
        loss, -1 disables it.
      tsa_schedule: forwarded to `get_tsa_threshold`; a falsy value
        disables the TSA masking.
    '''
    super().__init__(**kwargs)

    is_supervised = tf.cast(is_supervised, tf.float32)
    is_expanded = tf.cast(is_expanded, tf.float32)

    width = input_tensor.shape.as_list()[-1]
    dropout_prob = hidden_dropout_prob if is_training else 0.0
    with tf.variable_scope(scope):
        kernel = tf.get_variable(
            'output_weights',
            shape=[label_size, width],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        features = util.dropout(input_tensor, dropout_prob)
        logits = tf.matmul(features, kernel, transpose_b=True)
        logits = tf.nn.bias_add(logits, bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        with tf.variable_scope('sup_loss'):

            # keep the non-expanded rows, then the supervised ones
            sup_ori_log_probs = tf.boolean_mask(
                log_probs, mask=(1.0 - is_expanded), axis=0)
            sup_log_probs = tf.boolean_mask(
                sup_ori_log_probs, mask=is_supervised, axis=0)
            sup_label_ids = tf.boolean_mask(
                label_ids, mask=is_supervised, axis=0)

            self.preds['preds'] = tf.argmax(sup_ori_log_probs, axis=-1)

            targets = tf.one_hot(
                sup_label_ids, depth=label_size, dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(
                targets * sup_log_probs, axis=-1)

            loss_mask = tf.ones_like(per_example_loss, dtype=tf.float32)
            gold_probs = tf.reduce_sum(
                targets * tf.exp(sup_log_probs), axis=-1)

            if is_training and tsa_schedule:
                # drop examples the model is already confident about
                tsa_threshold = get_tsa_threshold(
                    tsa_schedule, global_step, num_train_steps,
                    1.0 / label_size, end=1)
                above_threshold = tf.greater(gold_probs, tsa_threshold)
                loss_mask = loss_mask * (
                    1 - tf.cast(above_threshold, tf.float32))

            loss_mask = tf.stop_gradient(loss_mask)
            per_example_loss = per_example_loss * loss_mask
            if sample_weight is not None:
                sup_sample_weight = tf.boolean_mask(
                    sample_weight, mask=is_supervised, axis=0)
                per_example_loss *= tf.cast(
                    sup_sample_weight, dtype=tf.float32)
            kept = tf.maximum(tf.reduce_sum(loss_mask), 1)
            sup_loss = tf.reduce_sum(per_example_loss) / kept

            self.losses['supervised'] = per_example_loss

        with tf.variable_scope('unsup_loss'):

            # original (unsupervised) rows vs. their expanded versions
            ori_log_probs = tf.boolean_mask(
                sup_ori_log_probs, mask=(1.0 - is_supervised), axis=0)
            aug_log_probs = tf.boolean_mask(
                log_probs, mask=is_expanded, axis=0)
            sup_ori_logits = tf.boolean_mask(
                logits, mask=(1.0 - is_expanded), axis=0)
            ori_logits = tf.boolean_mask(
                sup_ori_logits, mask=(1.0 - is_supervised), axis=0)

            unsup_loss_mask = 1
            if uda_softmax_temp == -1:
                tgt_ori_log_probs = tf.stop_gradient(ori_log_probs)
            else:
                # temperature-scaled target distribution
                tgt_ori_log_probs = tf.nn.log_softmax(
                    ori_logits / uda_softmax_temp, axis=-1)
                tgt_ori_log_probs = tf.stop_gradient(tgt_ori_log_probs)

            if uda_confidence_thresh != -1:
                top_prob = tf.reduce_max(tf.exp(ori_log_probs), axis=-1)
                unsup_loss_mask = tf.cast(
                    tf.greater(top_prob, uda_confidence_thresh),
                    tf.float32)
                unsup_loss_mask = tf.stop_gradient(unsup_loss_mask)

            per_example_loss = kl_for_log_probs(
                tgt_ori_log_probs, aug_log_probs) * unsup_loss_mask
            if sample_weight is not None:
                unsup_sample_weight = tf.boolean_mask(
                    sample_weight, mask=(1.0 - is_supervised), axis=0)
                per_example_loss *= tf.cast(
                    unsup_sample_weight, dtype=tf.float32)
            unsup_loss = tf.reduce_mean(per_example_loss)

            self.losses['unsupervised'] = per_example_loss

        self.total_loss = sup_loss + unsup_loss
def __init__(self,
             vocab_size,
             is_training,
             source_ids,
             target_ids,
             sos_id,
             sample_weight=None,
             hidden_size=768,
             num_blocks=6,
             num_attention_heads=12,
             scope='transformer',
             use_label_smoothing=False,
             use_tilda_embedding=False,
             trainable=True,
             **kwargs):
    '''
    Builds an encoder-decoder transformer together with its
    next-token loss.

    When training, the decoder runs once on `target_ids`. Otherwise
    the decoder is unrolled step by step, starting from `sos_id` and
    feeding back its own argmax predictions.

    Args:
      vocab_size: size of the shared vocabulary.
      is_training: bool, selects teacher forcing and enables dropout.
      source_ids: rank-2 source token ids; id 0 is treated as padding.
      target_ids: rank-2 target token ids; id 0 is treated as padding.
      sos_id: id fed as the first decoder input at inference.
      sample_weight: optional per-example weights (presumably of shape
        [batch_size] -- confirm against the caller).
      hidden_size: embedding and model width.
      num_blocks: number of encoder blocks and of decoder blocks.
      num_attention_heads: number of attention heads.
      scope: name of the outer variable scope.
      use_label_smoothing: whether to smooth the one-hot labels.
      use_tilda_embedding: whether to reuse the existing
        'tilda_embeddings' variable for the embedding lookup.
      trainable: not used by this constructor.
    '''
    super().__init__()

    dropout_rate = 0.0
    if is_training:
        dropout_rate = 0.1

    source_shape = util.get_shape_list(source_ids, expected_rank=2)
    target_shape = util.get_shape_list(target_ids, expected_rank=2)
    batch_size = source_shape[0]
    source_max_seq_length = source_shape[1]
    target_max_seq_length = target_shape[1]

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):
        source_mask = tf.math.equal(source_ids, 0)

        # embedding
        with tf.variable_scope('embeddings'):
            (enc, embedding_table) = embedding_lookup(
                input_ids=source_ids,
                vocab_size=vocab_size,
                batch_size=batch_size,
                max_seq_length=source_max_seq_length,
                embedding_size=hidden_size,
                word_embedding_name='word_embeddings',
                tilda_embeddings=tilda_embeddings)
            enc *= hidden_size ** 0.5  # scale
            enc += positional_encoding(enc, source_max_seq_length)
            enc = util.dropout(enc, dropout_rate)

        with tf.variable_scope('encoder'):

            # stacked multi-attention layers
            for i in range(num_blocks):
                with tf.variable_scope('block_%s' % i):

                    # self-attention
                    enc = multihead_attention(
                        queries=enc,
                        keys=enc,
                        values=enc,
                        key_masks=source_mask,
                        num_heads=num_attention_heads,
                        dropout_rate=dropout_rate,
                        training=is_training,
                        causality=False,
                        scope='self_attention')

                    # feed forward
                    enc = ff(enc, num_units=[hidden_size * 4, hidden_size])
            memory = enc

        def _forward(target_ids, target_mask, target_max_seq_length):
            '''Runs the decoder and returns the vocabulary logits.'''

            with tf.variable_scope('decoder'):

                # shared embedding
                dec = tf.nn.embedding_lookup(embedding_table, target_ids)
                dec *= hidden_size ** 0.5  # scale
                dec += positional_encoding(dec, target_max_seq_length)
                dec = util.dropout(dec, dropout_rate)

                # blocks
                for i in range(num_blocks):
                    with tf.variable_scope('block_%s' % i):

                        # masked self-attention
                        dec = multihead_attention(
                            queries=dec,
                            keys=dec,
                            values=dec,
                            key_masks=target_mask,
                            num_heads=num_attention_heads,
                            dropout_rate=dropout_rate,
                            training=is_training,
                            causality=True,
                            scope='masked_self_attention')

                        # vanilla attention
                        dec = multihead_attention(
                            queries=dec,
                            keys=memory,
                            values=memory,
                            key_masks=source_mask,
                            num_heads=num_attention_heads,
                            dropout_rate=dropout_rate,
                            training=is_training,
                            causality=False,
                            scope='vanilla_attention')

                        # feed forward
                        dec = ff(
                            dec, num_units=[4 * hidden_size, hidden_size])

            # final linear projection (embedding weights are shared)
            with tf.variable_scope('cls'):
                output_bias = tf.get_variable(
                    'output_bias', shape=[vocab_size],
                    initializer=tf.zeros_initializer())
                dec = tf.reshape(dec, [-1, hidden_size])
                logits = tf.matmul(dec, embedding_table, transpose_b=True)
                logits = tf.reshape(
                    logits, [-1, target_max_seq_length, vocab_size])
                logits = tf.nn.bias_add(logits, output_bias)

            return logits

        # convert to labels: the target shifted left by one position,
        # padded with id 0 at the end
        label_ids = tf.concat(
            [target_ids[:, 1:],
             tf.zeros([batch_size, 1], dtype=tf.int32)],
            axis=-1)

        # forward once
        if is_training:
            target_mask = tf.math.equal(target_ids, 0)  # (N, T2)
            logits = _forward(
                target_ids, target_mask, target_max_seq_length)

            self.preds['MT'] = tf.argmax(logits, axis=-1)

        # forward loop
        else:
            target_mask_base = tf.zeros([batch_size, 1], dtype=tf.int32)
            target_ids = tf.ones([batch_size, 1], dtype=tf.int32) * sos_id

            for cur_length in range(1, target_max_seq_length + 1):
                target_mask = tf.tile(target_mask_base, [1, cur_length])
                logits = _forward(target_ids, target_mask, cur_length)

                # append the prediction for the newest position
                pred_ids = tf.argmax(
                    logits[:, cur_length-1:cur_length, :],
                    axis=-1)
                pred_ids = tf.cast(pred_ids, tf.int32)
                target_ids = tf.concat([target_ids, pred_ids], axis=-1)

            self.preds['MT'] = target_ids[:, 1:]

        # loss
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=vocab_size)
        if use_label_smoothing:
            one_hot_labels = label_smoothing(one_hot_labels)
        per_token_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
        # `tf.maximum` keeps rows without any label from dividing by
        # zero (which would turn the whole loss into NaN)
        per_example_loss = \
            tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
            tf.maximum(tf.reduce_sum(label_mask, axis=-1), 1.0)
        if sample_weight is not None:
            # `per_example_loss` has one value per example here, so the
            # weights are applied without an extra axis; expanding them
            # would broadcast the product to [batch, batch]
            per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses['MT'] = per_example_loss
def call(self,
         query_input,
         source_input,
         bias,
         training,
         cache=None,
         decode_loop_step=None):
    '''Apply FAVOR attention to query_input and source_input.

    Args:
      query_input: A tensor with shape [batch_size, length_query,
        hidden_size].
      source_input: A tensor with shape [batch_size, length_source,
        hidden_size].
      bias: Accepted as part of the signature; not read by this
        implementation.
      training: Accepted as part of the signature; not read by this
        implementation.
      cache: (Used during prediction) A dictionary with the items
        {'k': tensor with shape [batch_size, i, heads, dim_per_head],
         'v': tensor with shape [batch_size, i, heads, dim_per_head]}
        where i is the current decoded length for non-padded decode, or
        max sequence length for padded decode. It is updated in place
        with the combined keys and values.
      decode_loop_step: An integer, step number of the decoding loop.
        When given, the new key/value is written into that slot of the
        cached tensors instead of being appended.

    Returns:
      Attention layer output with shape [batch_size, length_query,
      hidden_size]
    '''
    # The dense layers also split the heads:
    # --> [batch_size, length, num_heads, dim_per_head].
    query = self.query_dense_layer(query_input)
    key = self.key_dense_layer(source_input)
    value = self.value_dense_layer(source_input)

    projection_matrix = None
    if self.projection_matrix_type is not None:
        # a fixed seed is used for the random features
        head_dim = query.shape.as_list()[-1]
        projection_matrix = create_projection_matrix(
            self.nb_random_features, head_dim, seed=0)

    if cache is not None:
        if decode_loop_step is None:
            # non-padded decode: append along the length axis
            key = tf.concat(
                [tf.cast(cache['k'], key.dtype), key], axis=1)
            value = tf.concat(
                [tf.cast(cache['v'], value.dtype), value], axis=1)
        else:
            # padded decode: add the new entry at `decode_loop_step`
            k_length = cache['k'].shape.as_list()[1]
            k_slot = tf.reshape(
                tf.one_hot(decode_loop_step, k_length, dtype=key.dtype),
                [1, k_length, 1, 1])
            key = cache['k'] + key * k_slot

            v_length = cache['v'].shape.as_list()[1]
            v_slot = tf.reshape(
                tf.one_hot(decode_loop_step, v_length, dtype=value.dtype),
                [1, v_length, 1, 1])
            value = cache['v'] + value * v_slot

        # Update cache
        cache['k'] = key
        cache['v'] = value

    attended = favor_attention(
        query, key, value,
        self.kernel_transformation, self.causal, projection_matrix)
    return self.output_dense_layer(attended)
def embedding_postprocessor(self,
                            input_tensor,
                            batch_size,
                            max_seq_length,
                            hidden_size,
                            use_token_type=False,
                            segment_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name=\
                                'token_type_embeddings',
                            use_position_embeddings=True,
                            position_embedding_name='position_embeddings',
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1,
                            dtype=tf.float32,
                            trainable=True):
    '''
    Adds token type and position embeddings to `input_tensor`, then
    applies layer normalization and dropout.

    Args:
      input_tensor: word embeddings; its last two dimensions are
        treated as [max_seq_length, hidden_size].
      batch_size: batch size used to reshape the token type embeddings.
      max_seq_length: sequence length of `input_tensor`.
      hidden_size: width of the embeddings.
      use_token_type: whether to add token type embeddings.
      segment_ids: token type ids, required if `use_token_type`.
      token_type_vocab_size: number of rows of the token type table.
      token_type_embedding_name: variable name of the token type table.
      use_position_embeddings: whether to add position embeddings.
      position_embedding_name: variable name of the position table.
      initializer_range: passed to `util.create_initializer`.
      max_position_embeddings: number of rows of the position table.
      dropout_prob: passed to `util.layer_norm_and_dropout`.
      dtype: dtype of the created variables.
      trainable: whether the created variables are trainable.

    Returns:
      The post-processed embeddings.

    Raises:
      ValueError: if `use_token_type` is set without `segment_ids`.
    '''
    if use_token_type and segment_ids is None:
        raise ValueError(
            'segment_ids must be specified if use_token_type is True.')

    output = input_tensor

    if use_token_type:
        type_table = tf.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            dtype=dtype,
            trainable=trainable)

        # The token type vocabulary is small, so a one-hot matmul is
        # used for the lookup.
        one_hot_ids = tf.one_hot(
            tf.reshape(segment_ids, [-1]),
            depth=token_type_vocab_size,
            dtype=dtype)
        type_embeddings = tf.reshape(
            tf.matmul(one_hot_ids, type_table),
            [batch_size, max_seq_length, hidden_size])
        output = output + type_embeddings

    if use_position_embeddings:
        position_table = tf.get_variable(
            name=position_embedding_name,
            shape=[max_position_embeddings, hidden_size],
            initializer=util.create_initializer(initializer_range),
            dtype=dtype,
            trainable=trainable)
        # keep the rows of the first `max_seq_length` positions
        position_embeddings = tf.slice(
            position_table, [0, 0], [max_seq_length, -1])

        # Only the last two dimensions carry positions and features;
        # every leading dimension is broadcast.
        rank = len(output.shape.as_list())
        broadcast_shape = [1] * (rank - 2) + [max_seq_length, hidden_size]
        position_embeddings = tf.reshape(
            position_embeddings, broadcast_shape)
        output = output + position_embeddings

    return util.layer_norm_and_dropout(
        output, dropout_prob, trainable=trainable)
def transformer_xl(inp_k, n_token, n_layer, d_model, n_head,
                   d_head, d_inner, dropout, dropatt, attn_type,
                   bi_data, initializer, is_training, mem_len=None,
                   inp_q=None, mems=None,
                   same_length=False, clamp_len=-1, untie_r=False,
                   use_tpu=True, input_mask=None,
                   perm_mask=None, seg_id=None, reuse_len=None,
                   ff_activation='relu', target_mapping=None,
                   use_bfloat16=False, scope='transformer',
                   tilda_embeddings=None, **kwargs):
    '''
    Defines a Transformer-XL computation graph with additional
    support for XLNet.

    Args:
      inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
      seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
      input_mask: float32 Tensor in shape [len, bsz], the input mask.
        0 for real tokens and 1 for padding.
      mems: a list of float32 Tensors in shape [mem_len, bsz, d_model],
        memory from previous batches. The length of the list equals
        n_layer. If None, no memory is used.
      perm_mask: float32 Tensor in shape [len, len, bsz].
        If perm_mask[i, j, k] = 0, i attend to j in batch k;
        if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
        If None, each position attends to all the others.
      target_mapping: float32 Tensor in shape [num_predict, len, bsz].
        If target_mapping[i, j, k] = 1, the i-th predict in batch k is
        on the j-th token.
        Only used during pretraining for partial prediction.
        Set to None during finetuning.
      inp_q: float32 Tensor in shape [len, bsz].
        1 for tokens with losses and 0 for tokens without losses.
        Only used during pretraining for two-stream attention.
        Set to None during finetuning.

      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      ff_activation: str, 'relu' or 'gelu'.
      untie_r: bool, whether to untie the biases in attention.
      n_token: int, the vocab size.
      attn_type: str, 'uni' for a causal attention mask or 'bi' for
        no attention mask.

      is_training: bool, whether in training mode.
      use_tpu: bool, whether TPUs are used.
      use_bfloat16: bool, use bfloat16 instead of float32.
      dropout: float, dropout rate.
      dropatt: float, dropout rate on attention probabilities.

      mem_len: int, the number of tokens to cache.
      reuse_len: int, the number of tokens in the current batch to be
        cached and reused in the future.
      bi_data: bool, whether to use bidirectional input pipeline.
        Usually set to True during pretraining and False during
        finetuning.
      clamp_len: int, clamp all relative distances larger than
        clamp_len. -1 means no clamping.
      same_length: bool, whether to use the same attention length for
        each token.
      initializer: A tf initializer.
      scope: scope name for the computation graph.
      tilda_embeddings: forwarded to `embedding_lookup`.
      **kwargs: ignored.

    Returns:
      A tuple `(output, new_mems, lookup_table)`: the output of the
      last layer after dropout (the query stream if `inp_q` is given,
      the content stream otherwise), the list of per-layer memories
      and the word embedding table.

    Raises:
      ValueError: if `attn_type` is neither 'uni' nor 'bi'.
    '''
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        bsz = tf.shape(inp_k)[1]
        qlen = tf.shape(inp_k)[0]
        mlen = tf.shape(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        ##### Attention mask
        # causal attention mask
        if attn_type == 'uni':
            attn_mask = _create_mask(qlen, mlen, tf_float, same_length)
            attn_mask = attn_mask[:, :, None, None]
        elif attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError('Unsupported attention type: %s' % attn_type)

        # data mask: input mask & perm mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                                 dtype=tf_float)
            # cast to the compute dtype: `tf.concat` needs both parts
            # to share one dtype, also when bfloat16 is used
            data_mask = tf.cast(data_mask, dtype=tf_float)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)

        if attn_mask is not None:
            # same mask, except that every token may attend to itself
            non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([qlen, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        ##### Word embedding
        word_emb_k, lookup_table = embedding_lookup(
            x=inp_k,
            n_token=n_token,
            d_embed=d_model,
            initializer=initializer,
            use_tpu=use_tpu,
            dtype=tf_float,
            scope='word_embedding',
            tilda_embeddings=tilda_embeddings)

        if inp_q is not None:
            with tf.variable_scope('mask_emb'):
                mask_emb = tf.get_variable('mask_emb', [1, 1, d_model],
                                           dtype=tf_float)
                if target_mapping is not None:
                    word_emb_q = tf.tile(mask_emb,
                                         [tf.shape(target_mapping)[0], bsz, 1])
                else:
                    inp_q_ext = inp_q[:, :, None]
                    word_emb_q = \
                        inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
        output_h = tf.layers.dropout(word_emb_k, dropout,
                                     training=is_training)
        if inp_q is not None:
            output_g = tf.layers.dropout(word_emb_q, dropout,
                                         training=is_training)

        ##### Segment embedding
        if seg_id is not None:
            if untie_r:
                r_s_bias = tf.get_variable('r_s_bias',
                                           [n_layer, n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)
            else:
                # default case (tie)
                r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        ##### Positional encoding
        pos_emb = relative_positional_encoding(qlen, klen, d_model,
                                               clamp_len, attn_type,
                                               bi_data, bsz=bsz,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        ##### Attention layers
        if mems is None:
            mems = [None] * n_layer

        for i in range(n_layer):
            # cache new mems
            new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                if inp_q is not None:
                    output_h, output_g = two_stream_rel_attn(
                        h=output_h,
                        g=output_g,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask_h=non_tgt_mask,
                        attn_mask_g=attn_mask,
                        mems=mems[i],
                        target_mapping=target_mapping,
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer)
                    # the content stream reuses the feed-forward
                    # variables created for the query stream below
                    reuse = True
                else:
                    reuse = False

                    output_h = rel_multihead_attn(
                        h=output_h,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask=non_tgt_mask,
                        mems=mems[i],
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer,
                        reuse=reuse)

                if inp_q is not None:
                    output_g = positionwise_ffn(inp=output_g,
                                                d_model=d_model,
                                                d_inner=d_inner,
                                                dropout=dropout,
                                                kernel_initializer=initializer,
                                                activation_type=ff_activation,
                                                is_training=is_training)

                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=reuse)

        if inp_q is not None:
            output = tf.layers.dropout(output_g, dropout,
                                       training=is_training)
        else:
            output = tf.layers.dropout(output_h, dropout,
                                       training=is_training)

        return output, new_mems, lookup_table
def _expand_features(module, split_placeholders):
    '''
    Derives the permutation features from the 'input', 'target' and
    'is_masked' entries of `split_placeholders`.

    Every sequence is split at `module.reuse_seq_length`; `_local_perm`
    is applied to both parts separately and the results are joined
    again. The entries 'perm_mask', 'target', 'target_mask', 'input_k'
    and 'input_q' (plus 'target_mapping' if `module._num_predict` is
    set) are written into `split_placeholders`.

    Args:
      module: object providing `max_seq_length`, `reuse_seq_length`,
        `perm_size` and `_num_predict`.
      split_placeholders: dict of tensors; presumably every tensor is
        of shape [batch_size, max_seq_length] -- confirm against the
        caller.

    Returns:
      The same `split_placeholders` dict, updated in place.
    '''

    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    # length of the part that follows the reused prefix
    non_reuse_len = module.max_seq_length - module.reuse_seq_length
    assert (module.perm_size <= module.reuse_seq_length and
            module.perm_size <= non_reuse_len)

    # permutation features of the reused prefix
    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :module.reuse_seq_length],
            target[:, :module.reuse_seq_length],
            is_masked[:, :module.reuse_seq_length],
            module.perm_size,
            module.reuse_seq_length)

    # permutation features of the remaining part
    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, module.reuse_seq_length:],
            target[:, module.reuse_seq_length:],
            is_masked[:, module.reuse_seq_length:],
            module.perm_size,
            non_reuse_len)

    # Extend both masks to the full sequence length: the prefix rows
    # get ones for the columns of the remaining part, the remaining
    # rows get zeros for the columns of the prefix (presumably 1 means
    # "must not attend" -- confirm against the consumer of perm_mask).
    perm_mask_0 = tf.concat([
        tf.cast(perm_mask_0, dtype=tf.float32),
        tf.ones([batch_size, module.reuse_seq_length, non_reuse_len])
    ], axis=2)
    perm_mask_1 = tf.concat([
        tf.zeros([batch_size, non_reuse_len, module.reuse_seq_length]),
        tf.cast(perm_mask_1, dtype=tf.float32)
    ], axis=2)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        # TODO(geying): convert tensors from 1-D to 2-D
        # NOTE(review): `tf.boolean_mask` with a rank-2 mask appears to
        # return a rank-1 tensor, so `tf.shape(indices)[1]` below and
        # the batched padding look inconsistent with it -- confirm
        # before relying on this branch.

        # position index of every token, repeated for every example
        indices = tf.range(module.max_seq_length, dtype=tf.int64)
        indices = tf.reshape(indices, [-1, module.max_seq_length])
        indices = tf.tile(indices, [batch_size, 1])
        bool_target_mask = tf.cast(target_mask, tf.bool)
        indices = tf.boolean_mask(indices, bool_target_mask)

        # extra padding due to CLS/SEP introduced after preprocessing
        actual_num_predict = tf.shape(indices)[1]
        pad_len = module._num_predict - actual_num_predict

        # target_mapping: one-hot rows of the predicted positions,
        # padded with all-zero rows
        target_mapping = tf.one_hot(indices, module.max_seq_length,
                                    dtype=tf.float32)
        paddings = tf.zeros([pad_len, module.max_seq_length],
                            dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, paddings], axis=0)
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping,
            [-1, module._num_predict, module.max_seq_length])

        # target: the selected targets, padded with zeros
        target = tf.boolean_mask(target, bool_target_mask)
        paddings = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, paddings], axis=0)
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module._num_predict])

        # target mask: ones for real predictions, zeros for padding
        target_mask = tf.concat([
            tf.ones([batch_size, actual_num_predict], dtype=tf.float32),
            tf.zeros([batch_size, pad_len], dtype=tf.float32)
        ], axis=1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module._num_predict])
    else:
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(input_k,
                                               [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(input_q,
                                               [-1, module.max_seq_length])

    return split_placeholders
def __init__(self,
             bert_config,
             is_training,
             dilated_ids,
             label_ids,
             max_seq_length,
             spad_id=1,
             loop=3,
             sample_weight=None,
             scope='dilated',
             use_tilda_embedding=False,
             **kwargs):
    '''Build the dilated language-model head.

    In training mode one forward pass is run and a length-masked token
    cross-entropy is stored. Otherwise the network is applied `loop` times;
    after each pass the predictions are compacted, truncated and re-dilated
    to form the next pass's input, and the final ids become the prediction.

    Args:
        bert_config: configuration object; `vocab_size` is read here and
            the object is forwarded to `self._bert_forward`.
        is_training: Python bool selecting the training branch.
        dilated_ids: rank-2 integer tensor of input ids; id 0 is treated
            as padding when the input mask is derived.
        label_ids: integer labels, one-hot encoded over
            `bert_config.vocab_size`. Only used when training.
        max_seq_length: Python int; the loss mask and the re-dilated
            sequences are `2 * max_seq_length` wide.
        spad_id: id of the temporary "special pad" token used while
            compacting predictions. The compaction relies on it being
            greater than 0.
        loop: number of decoding passes in inference mode.
        sample_weight: optional per-example loss weights, one value per
            batch row.
        scope: variable scope under which the network is built.
        use_tilda_embedding: if true, reuse the existing variable
            `tilda_embeddings` and forward it to `self._bert_forward`.
        **kwargs: accepted for interface compatibility; not forwarded to
            the base constructor.
    '''
    super().__init__()

    dilated_mask = tf.cast(tf.not_equal(dilated_ids, 0), tf.float32)

    shape = util.get_shape_list(dilated_ids, expected_rank=2)
    batch_size = shape[0]
    dilated_seq_length = shape[1]

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):

        # forward once
        if is_training:
            logits = self._bert_forward(
                bert_config,
                dilated_ids,
                dilated_mask,
                batch_size,
                dilated_seq_length,
                tilda_embeddings=tilda_embeddings)

            self.preds['LM'] = tf.argmax(logits, axis=-1)

            # LM loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(
                label_ids, depth=bert_config.vocab_size)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)

            # Labels cover twice as many positions as there are non-zero
            # input ids.
            input_length = tf.reduce_sum(dilated_mask, axis=-1) * 2
            label_mask = tf.sequence_mask(
                input_length, max_seq_length * 2, dtype=tf.float32)
            per_example_loss = \
                tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                tf.reduce_sum(label_mask, axis=-1)
            if sample_weight is not None:
                # `per_example_loss` is rank 1 here, so the weights must be
                # applied elementwise. Expanding them to [batch, 1] would
                # broadcast the product to [batch, batch].
                per_example_loss *= tf.reshape(
                    tf.cast(sample_weight, tf.float32), [-1])

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['LM'] = per_example_loss

        # forward loop
        else:

            def _forward(dilated_ids, dilated_mask):
                '''Run one decoding pass and re-dilate its predictions.'''
                logits = self._bert_forward(
                    bert_config,
                    dilated_ids,
                    dilated_mask,
                    batch_size,
                    dilated_seq_length,
                    tilda_embeddings=tilda_embeddings)
                output_ids = tf.argmax(logits, axis=-1)
                output_ids = tf.cast(output_ids, dtype=tf.int32)

                # special padding (using `spad` token): append one `spad`
                # per predicted zero so that every row still holds exactly
                # `dilated_seq_length` positive ids once zeros are removed
                equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
                equal_zero = tf.reduce_sum(equal_zero, axis=-1)
                right_pad = spad_id * tf.sequence_mask(
                    equal_zero, dilated_seq_length, dtype=tf.int32)
                paded = tf.concat([output_ids, right_pad], axis=-1)

                # extract ids of length `max_seq_length`
                flattened_padded = tf.reshape(paded, [-1])
                is_valid = tf.cast(
                    tf.greater(flattened_padded, 0), dtype=tf.int32)
                flattened_valid = tf.boolean_mask(
                    flattened_padded, is_valid)
                valid = tf.reshape(
                    flattened_valid, [batch_size, dilated_seq_length])
                cutted_valid = valid[:, :max_seq_length]

                # replace `spad` token with `pad`
                non_spad_mask = tf.cast(
                    tf.not_equal(cutted_valid, spad_id), dtype=tf.int32)
                output_ids = cutted_valid * non_spad_mask
                output_length = tf.reduce_sum(non_spad_mask, axis=-1)

                # dilate: interleave every id (and mask bit) with a zero
                reshaped_ids = tf.reshape(
                    output_ids, [batch_size, max_seq_length, 1])
                reshaped_mask = tf.reshape(
                    tf.sequence_mask(
                        output_length, max_seq_length, dtype=tf.int32),
                    [batch_size, max_seq_length, 1])
                concat_ids = tf.concat(
                    [reshaped_ids, tf.zeros_like(reshaped_ids)], axis=-1)
                concat_mask = tf.concat(
                    [reshaped_mask,
                     tf.zeros_like(reshaped_mask, dtype=tf.int32)],
                    axis=-1)
                dilated_ids = tf.reshape(
                    concat_ids, [batch_size, max_seq_length * 2])
                dilated_mask = tf.reshape(
                    concat_mask, [batch_size, max_seq_length * 2])

                return dilated_ids, dilated_mask

            for _ in range(loop):
                dilated_ids, dilated_mask = _forward(
                    dilated_ids, dilated_mask)

            self.preds['LM'] = dilated_ids
def __init__(self,
             bert_config,
             is_training,
             input_tensor,
             sa_mask,
             label_ids,
             sample_weight=None,
             scope='sanet',
             alpha=0.5,
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Build the sentence-attention span-extraction head.

    A masked self-attention layer is applied to `input_tensor`, its output
    is blended with the input (`alpha * attention + (1 - alpha) * input`),
    and a two-row linear layer produces start and end logits over the
    sequence positions.

    Args:
        bert_config: configuration object; `num_attention_heads`,
            `hidden_dropout_prob` and `initializer_range` are read for the
            attention layer.
        is_training: Python bool; attention dropout is only applied when
            it is true.
        input_tensor: rank-3 tensor unpacked as
            [batch_size, seq_length, hidden_size].
        sa_mask: attention mask, reshaped to
            [batch_size, seq_length, seq_length].
        label_ids: integer tensor whose columns 0 and 1 hold the start and
            end positions.
        sample_weight: optional per-example loss weights.
        scope: outer variable scope.
        alpha: blending weight of the attention output.
        hidden_dropout_prob: accepted for interface compatibility; not
            used in this constructor.
        initializer_range: stddev-like range for the output weights
            initializer.
        trainable: whether the created variables are trainable.
        **kwargs: forwarded to the base constructor.
    '''
    super().__init__(**kwargs)

    shape = util.get_shape_list(input_tensor)
    batch_size = shape[0]
    seq_length = shape[1]
    hidden_size = shape[2]
    sa_mask = tf.reshape(sa_mask, [batch_size, seq_length, seq_length])

    # Dropout must be switched off outside training; otherwise predictions
    # become stochastic whenever the caller has not zeroed the config.
    attention_dropout_prob = \
        bert_config.hidden_dropout_prob if is_training else 0.0

    with tf.variable_scope(scope):
        with tf.variable_scope('sentence_attention'):
            (sa_output, _) = self.attention_layer(
                from_tensor=input_tensor,
                to_tensor=input_tensor,
                attention_mask=sa_mask,
                num_attention_heads=bert_config.num_attention_heads,
                size_per_head=(
                    hidden_size // bert_config.num_attention_heads),
                attention_probs_dropout_prob=attention_dropout_prob,
                initializer_range=bert_config.initializer_range,
                do_return_2d_tensor=False,
                batch_size=batch_size,
                from_max_seq_length=seq_length,
                to_max_seq_length=seq_length,
                trainable=trainable)

        with tf.variable_scope('cls/mrc'):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[2, hidden_size],
                initializer=util.create_initializer(initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable(
                'output_bias',
                shape=[2],
                initializer=tf.zeros_initializer(),
                trainable=trainable)

            # Blend attention output with the raw encoder output.
            output_layer = alpha * sa_output + (1 - alpha) * input_tensor
            output_layer = tf.reshape(output_layer, [-1, hidden_size])
            logits = tf.matmul(
                output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            logits = tf.reshape(logits, [-1, seq_length, 2])
            # -> [batch, 2, seq_length]: row 0 = start, row 1 = end
            logits = tf.transpose(logits, [0, 2, 1])
            probs = tf.nn.softmax(logits, axis=-1, name='probs')
            self.probs['probs'] = probs
            self.preds['preds'] = tf.argmax(logits, axis=-1)

            start_one_hot_labels = tf.one_hot(
                label_ids[:, 0], depth=seq_length, dtype=tf.float32)
            end_one_hot_labels = tf.one_hot(
                label_ids[:, 1], depth=seq_length, dtype=tf.float32)
            start_log_probs = tf.nn.log_softmax(logits[:, 0, :], axis=-1)
            end_log_probs = tf.nn.log_softmax(logits[:, 1, :], axis=-1)

            # Average of the start and end cross-entropies.
            per_example_loss = (
                -0.5 * tf.reduce_sum(
                    start_one_hot_labels * start_log_probs, axis=-1)
                - 0.5 * tf.reduce_sum(
                    end_one_hot_labels * end_log_probs, axis=-1))
            if sample_weight is not None:
                per_example_loss *= sample_weight

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['losses'] = per_example_loss
def _lm_forward(self,
                is_training,
                input_tensor,
                input_mask,
                label_ids,
                bert_config,
                batch_size,
                max_seq_length,
                prob,
                scope,
                name,
                sample_weight=None,
                hidden_dropout_prob=0.1,
                initializer_range=0.02):
    '''Build a verifier head and a token-prediction head for one task.

    The verifier is a 2-way classifier per token whose label is
    `label_ids > 0`. The prediction head projects onto
    `self.embedding_table` to get vocabulary logits; its loss is only
    counted on positions that are both in `input_mask` and flagged by the
    verifier's own prediction.

    Side effects:
        * `self.total_loss` is increased by both mean losses when
          `prob != 0`.
        * `self.losses[name + '_loss']` receives the verifier's
          per-example loss.
        * `self.preds[name + '_preds']` receives the predicted token ids,
          zeroed wherever the verifier predicts class 0.

    Args:
        is_training: accepted for interface compatibility; not used here.
        input_tensor: encoder output fed to both heads.
        input_mask: mask over sequence positions (cast to float32).
        label_ids: integer token labels; values > 0 mark positive verifier
            targets.
        bert_config: configuration object; `initializer_range`,
            `hidden_size` and `vocab_size` are read.
        batch_size: batch size used to flatten the hidden states.
        max_seq_length: sequence length used to flatten and restore the
            logits.
        prob: the losses are added to `self.total_loss` only if this is
            non-zero.
        scope: variable scope for this task.
        name: key prefix for `self.losses` and `self.preds`.
        sample_weight: optional per-example loss weights, one value per
            batch row.
        hidden_dropout_prob: accepted for interface compatibility; not
            used here.
        initializer_range: accepted for interface compatibility; not used
            here (`bert_config.initializer_range` is used instead).
    '''
    if sample_weight is not None:
        # The per-example losses below are rank 1, so the weights must be
        # applied elementwise. Expanding them to [batch, 1] would
        # broadcast the product to [batch, batch].
        sample_weight = tf.reshape(
            tf.cast(sample_weight, tf.float32), [-1])

    with tf.variable_scope(scope):

        with tf.variable_scope('verifier'):
            logits = tf.layers.dense(
                input_tensor,
                2,
                kernel_initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=True)
            verifier_label_ids = tf.cast(
                tf.greater(label_ids, 0), tf.int32)

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(verifier_label_ids, depth=2)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)

            # Mean over the unmasked positions of every example.
            input_mask = tf.cast(input_mask, tf.float32)
            per_token_loss *= input_mask / tf.reduce_sum(
                input_mask, keepdims=True, axis=-1)
            per_example_loss = tf.reduce_sum(per_token_loss, axis=-1)
            if sample_weight is not None:
                per_example_loss *= sample_weight

            if prob != 0:
                self.total_loss += tf.reduce_mean(per_example_loss)
            verifier_loss = per_example_loss
            verifier_preds = tf.argmax(logits, axis=-1)

        with tf.variable_scope('prediction'):

            with tf.variable_scope('intermediate'):
                logits = tf.layers.dense(
                    input_tensor,
                    bert_config.hidden_size * 4,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range),
                    activation=util.gelu,
                    trainable=True)
            with tf.variable_scope('output'):
                logits = tf.layers.dense(
                    logits,
                    bert_config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=True)

            # Project onto the embedding table to get vocabulary logits.
            flattened = tf.reshape(
                logits,
                [batch_size * max_seq_length, bert_config.hidden_size])
            logits = tf.matmul(
                flattened, self.embedding_table, transpose_b=True)
            logits = tf.reshape(
                logits, [-1, max_seq_length, bert_config.vocab_size])

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(
                label_ids, depth=bert_config.vocab_size)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)

            # Only positions the verifier flags contribute; the epsilon
            # guards against examples where it flags none.
            input_mask *= tf.cast(verifier_preds, tf.float32)
            per_token_loss *= input_mask / (
                tf.reduce_sum(input_mask, keepdims=True, axis=-1) + 1e-6)
            per_example_loss = tf.reduce_sum(per_token_loss, axis=-1)
            if sample_weight is not None:
                per_example_loss *= sample_weight

            if prob != 0:
                self.total_loss += tf.reduce_mean(per_example_loss)
            self.losses[name + '_loss'] = verifier_loss
            self.preds[name + '_preds'] = \
                tf.argmax(logits, axis=-1) * verifier_preds
def __init__(self,
             hparams,
             is_training,
             input_ids,
             sample_weight=None,
             scope='model',
             given=1,
             use_tilda_embedding=False,
             **kwargs):
    '''Build an autoregressive transformer language model.

    In training mode the network is run once over `input_ids`. Otherwise
    only the first `given` tokens are kept and the sequence is extended
    greedily, one argmax token per step, re-running the network on the
    growing prefix. A next-token cross-entropy is computed in both modes
    from the most recent logits.

    Args:
        hparams: hyper-parameter object; `n_predict`, `n_vocab`,
            `n_embed`, `n_ctx` and `n_layer` are read, and the object is
            forwarded to `block`.
        is_training: Python bool selecting single-pass vs greedy decoding.
        input_ids: rank-2 int32 tensor of token ids.
        sample_weight: optional per-example loss weights, one value per
            batch row.
        scope: variable scope under which the model is built.
        given: number of leading tokens kept as the prompt when decoding.
        use_tilda_embedding: if true, reuse the existing variable
            `tilda_embeddings` as the token embedding table.
        **kwargs: accepted for interface compatibility; not forwarded to
            the base constructor.
    '''
    super().__init__()

    batch_size = util.get_shape_list(input_ids, expected_rank=2)[0]
    max_seq_length = hparams.n_predict

    # Tilda embeddings for SMART algorithm
    tilda_embeddings = None
    if use_tilda_embedding:
        with tf.variable_scope('', reuse=True):
            tilda_embeddings = tf.get_variable('tilda_embeddings')

    with tf.variable_scope(scope):

        def _forward(input_ids, past=None):
            '''Return (vocabulary logits, stacked per-layer presents).'''
            batch, sequence = shape_list(input_ids)

            if tilda_embeddings is None:
                wte = tf.get_variable(
                    'word_embeddings',
                    [hparams.n_vocab, hparams.n_embed],
                    initializer=tf.random_normal_initializer(stddev=0.02))
            else:
                wte = tilda_embeddings
            wpe = tf.get_variable(
                'wpe',
                [hparams.n_ctx, hparams.n_embed],
                initializer=tf.random_normal_initializer(stddev=0.01))
            past_length = 0 if past is None else tf.shape(past)[-2]
            h = (tf.gather(wte, input_ids)
                 + tf.gather(wpe, positions_for(input_ids, past_length)))

            # stacked transformer layers
            presents = []
            pasts = tf.unstack(past, axis=1) if past is not None else \
                [None] * hparams.n_layer
            assert len(pasts) == hparams.n_layer
            for layer, layer_past in enumerate(pasts):
                h, present = block(
                    h, 'h%d' % layer, past=layer_past, hparams=hparams)
                presents.append(present)
            present = tf.stack(presents, axis=1)
            h = norm(h, 'ln_f')

            # Language model loss. Do tokens <n predict token n?
            # The output projection shares weights with the embedding.
            h_flat = tf.reshape(h, [batch * sequence, hparams.n_embed])
            logits = tf.matmul(h_flat, wte, transpose_b=True)
            logits = tf.reshape(
                logits, [batch, sequence, hparams.n_vocab])

            return logits, present

        # convert to labels: shift left by one and pad the end with 0
        label_ids = tf.concat(
            [input_ids[:, 1:],
             tf.zeros([batch_size, 1], dtype=tf.int32)],
            axis=-1)

        # forward once
        if is_training:
            (logits, _) = _forward(input_ids)

            self.preds['LM'] = tf.argmax(logits, axis=-1)

        # forward loop
        else:
            input_ids = input_ids[:, 0:given]

            for cur_length in range(given, max_seq_length + 1):
                (logits, _) = _forward(input_ids)

                # Greedy choice at the last position of the prefix.
                pred_ids = tf.argmax(
                    logits[:, cur_length - 1:cur_length, :], axis=-1)
                pred_ids = tf.cast(pred_ids, tf.int32)
                input_ids = tf.concat([input_ids, pred_ids], axis=-1)

            self.preds['LM'] = input_ids

        # loss
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(label_ids, depth=hparams.n_vocab)
        per_token_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        # Label id 0 marks padding and is excluded from the average.
        label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
        per_example_loss = \
            tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
            tf.reduce_sum(label_mask, axis=-1)
        if sample_weight is not None:
            # `per_example_loss` is rank 1, so the weights must be applied
            # elementwise. Expanding them to [batch, 1] would broadcast
            # the product to [batch, batch].
            per_example_loss *= tf.reshape(
                tf.cast(sample_weight, tf.float32), [-1])

        self.total_loss = tf.reduce_mean(per_example_loss)
        self.losses['LM'] = per_example_loss
def __init__(self,
             bert_config,
             is_training,
             input_tensor,
             input_mask,
             sem_features,
             label_ids,
             max_seq_length,
             feature_size,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             **kwargs):
    '''Build a classifier that fuses semantic-feature attention.

    Semantic feature ids are embedded and passed through one self-attention
    layer. The first position of the result is added to `input_tensor` and
    layer-normalised, then a linear layer produces the class logits.
    Predictions, probabilities and losses are stored on the instance.

    Args:
        bert_config: configuration object; only `num_attention_heads` is
            read here.
        is_training: Python bool enabling dropout.
        input_tensor: encoder output; its first and last dimensions are
            taken as batch size and hidden size.
        input_mask: mask used to build the attention mask.
        sem_features: integer feature ids looked up in the feature
            embedding table.
        label_ids: integer class labels.
        max_seq_length: sequence length of the semantic features.
        feature_size: number of distinct features (3 extra rows are added
            for [PAD], [CLS] and [SEP]).
        label_size: number of classes.
        sample_weight: optional per-example loss weights.
        scope: variable scope of the output layer.
        hidden_dropout_prob: dropout rate used while training.
        initializer_range: range for the weight initialisers.
        trainable: whether the created variables are trainable.
        **kwargs: forwarded to the base constructor; `tsa_thresh`, if
            present, enables uncertainty-based loss masking.
    '''
    super().__init__(**kwargs)

    dims = util.get_shape_list(input_tensor)
    batch_size, hidden_size = dims[0], dims[-1]
    dropout_prob = hidden_dropout_prob if is_training else 0.0

    with tf.variable_scope('sem'):
        # Three extra rows for [PAD], [CLS], [SEP].
        feature_embeddings = tf.get_variable(
            name='feature_embeddings',
            shape=[feature_size + 3, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        sem_output = tf.gather(
            feature_embeddings, sem_features)  # [B, N, H]

        with tf.variable_scope('self'):
            attention_mask = \
                BERTEncoder.create_attention_mask_from_input_mask(
                    input_mask, batch_size, max_seq_length)
            (attention_output, _) = BERTEncoder.attention_layer(
                from_tensor=sem_output,
                to_tensor=sem_output,
                attention_mask=attention_mask,
                num_attention_heads=bert_config.num_attention_heads,
                size_per_head=(
                    hidden_size // bert_config.num_attention_heads),
                attention_probs_dropout_prob=dropout_prob,
                initializer_range=initializer_range,
                do_return_2d_tensor=False,
                batch_size=batch_size,
                from_max_seq_length=max_seq_length,
                to_max_seq_length=max_seq_length,
                trainable=trainable)

        # Keep the first position only.
        first_token = attention_output[:, 0, :]  # [B, H]

        # Residual connection followed by layer normalisation.
        input_tensor = util.layer_norm(
            first_token + input_tensor, trainable=trainable)

    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(input_tensor, dropout_prob)
        logits = tf.nn.bias_add(
            tf.matmul(output_layer, output_weights, transpose_b=True),
            output_bias)

        self.preds['preds'] = tf.argmax(logits, axis=-1)
        self.probs['probs'] = tf.nn.softmax(
            logits, axis=-1, name='probs')

        # Cross-entropy against the one-hot labels.
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(
            label_ids, depth=label_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        if sample_weight is not None:
            weights = tf.cast(sample_weight, dtype=tf.float32)
            per_example_loss = weights * per_example_loss

        thresh = kwargs.get('tsa_thresh')
        if thresh is not None:
            assert isinstance(
                thresh, float), ('`tsa_thresh` must be a float between 0 and 1.')
            # Entropy of the predicted distribution divided by
            # log(label_size); examples at or below the threshold get a
            # zero loss.
            probs = self.probs['probs']
            uncertainty = tf.reduce_sum(probs * tf.log(probs), axis=-1)
            uncertainty = uncertainty / tf.log(1 / label_size)
            keep = tf.cast(
                tf.greater(uncertainty, thresh), dtype=tf.float32)
            per_example_loss = keep * per_example_loss

        self.losses['losses'] = per_example_loss
        self.total_loss = tf.reduce_mean(per_example_loss)