def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights, num_partitions=1):
  """Get loss and log probs for the masked LM."""
  input_tensor = gather_indexes(input_tensor, positions)

  with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=bert_config.hidden_size,
          activation=modeling.get_activation(bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              bert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[bert_config.vocab_size],
        initializer=tf.zeros_initializer())
    if num_partitions > 1:
      output_weights = xla_sharding.split(
          output_weights, 0, num_partitions, use_sharding_op=True)
      output_bias = xla_sharding.split(
          output_bias, 0, num_partitions, use_sharding_op=True)
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    label_ids = tf.reshape(label_ids, [-1])
    label_weights = tf.reshape(label_weights, [-1])

    one_hot_labels = tf.one_hot(
        label_ids, depth=bert_config.vocab_size, dtype=tf.float32)

    # The `positions` tensor might be zero-padded (if the sequence is too
    # short to have the maximum number of predictions). The `label_weights`
    # tensor has a value of 1.0 for every real prediction and 0.0 for the
    # padding predictions.
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    denominator = tf.reduce_sum(label_weights) + 1e-5
    loss = numerator / denominator

  return (loss, per_example_loss, log_probs)
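# A minimal numerical sketch (added for illustration, not part of the original
# file) of the weighted reduction above: padded prediction slots carry a
# label_weight of 0.0, so they contribute nothing to the numerator, and the
# 1e-5 in the denominator only guards against an all-padding batch.
import numpy as np

per_example_loss = np.array([2.3, 1.7, 0.9, 0.0])   # last slot is padding
label_weights = np.array([1.0, 1.0, 1.0, 0.0])
loss = np.sum(label_weights * per_example_loss) / (np.sum(label_weights) + 1e-5)
# loss ~= (2.3 + 1.7 + 0.9) / 3 = 1.633...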
def XlaShardRelu(i):
  if num_partitions:
    i = xla_sharding.split(i, 2, num_partitions, use_sharding_op=True)
  return fn(i)
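# XlaShardRelu above is a closure: `fn` and `num_partitions` are captured from
# an enclosing scope that is not shown here. A hypothetical enclosing wrapper
# might look like the sketch below (the name `sharded_activation` is
# illustrative only and not from the original code); it annotates the
# activation input with an SPMD split along dimension 2 before applying the
# wrapped activation function.
def sharded_activation(fn, num_partitions):
  """Wraps an activation `fn` so its input is SPMD-split along dim 2."""

  def XlaShardRelu(i):
    if num_partitions:
      i = xla_sharding.split(i, 2, num_partitions, use_sharding_op=True)
    return fn(i)

  return XlaShardRelu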
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False,
                     use_bfloat16_activation=False,
                     num_partitions=1):
  """Looks up word embeddings for an id tensor.

  Args:
    input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
      ids.
    vocab_size: int. Size of the embedding vocabulary.
    embedding_size: int. Width of the word embeddings.
    initializer_range: float. Embedding initialization range.
    word_embedding_name: string. Name of the embedding table.
    use_one_hot_embeddings: bool. If True, use one-hot method for word
      embeddings. If False, use `tf.nn.embedding_lookup()`.
    use_bfloat16_activation: bool. If True, cast the embedding to bfloat16.
    num_partitions: (optional) Number of SPMD partitions.

  Returns:
    float Tensor of shape [batch_size, seq_length, embedding_size].
  """
  # This function assumes that the input is of shape [batch_size, seq_length,
  # num_inputs].
  #
  # If the input is a 2D tensor of shape [batch_size, seq_length], we
  # reshape to [batch_size, seq_length, 1].
  if input_ids.shape.ndims == 2:
    input_ids = tf.expand_dims(input_ids, axis=[-1])

  embedding_table = tf.get_variable(
      name=word_embedding_name,
      shape=[vocab_size, embedding_size],
      initializer=create_initializer(initializer_range))
  if num_partitions > 1:
    embedding_table = xla_sharding.split(
        embedding_table, 0, num_partitions, use_sharding_op=True)
  if use_bfloat16_activation:
    embedding_table = tf.cast(embedding_table, tf.bfloat16)

  if use_one_hot_embeddings:
    flat_input_ids = tf.reshape(input_ids, [-1])
    one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
    output = tf.matmul(one_hot_input_ids, embedding_table)
  else:
    output = tf.nn.embedding_lookup(embedding_table, input_ids)

  input_shape = get_shape_list(input_ids)

  output = tf.reshape(output,
                      input_shape[0:-1] + [input_shape[-1] * embedding_size])
  return (output, embedding_table)
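# Illustrative usage sketch (added, not part of the original file). Assumes
# TF1 graph mode and a [batch_size, seq_length] int32 tensor of token ids;
# the vocab_size value is only an example. With num_partitions > 1 the
# embedding table is split along the vocab dimension across SPMD partitions
# before the lookup.
example_input_ids = tf.constant([[101, 2023, 102], [101, 7592, 102]],
                                dtype=tf.int32)
embedding_output, embedding_table = embedding_lookup(
    example_input_ids,
    vocab_size=30522,
    embedding_size=128,
    use_one_hot_embeddings=True,   # matmul path, often preferred on TPU
    use_bfloat16_activation=False,
    num_partitions=1)
# embedding_output: float Tensor of shape [2, 3, 128]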
def mask_rcnn_loss(mask_outputs, mask_targets, select_class_targets, params):
  """Computes the mask loss of Mask-RCNN.

  This function implements the mask loss of Mask-RCNN. As the `mask_outputs`
  produces `num_classes` masks for each RoI, the reference model expands
  `mask_targets` to match the shape of `mask_outputs` and selects only the
  target that the RoI has a maximum overlap with. (Reference:
  https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/mask_rcnn.py)  # pylint: disable=line-too-long
  Instead, this function selects the `mask_outputs` by the `class_targets` so
  that it doesn't expand `mask_targets`.

  Args:
    mask_outputs: a float tensor representing the class prediction for each
      mask with a shape of
      [batch_size, num_masks, mask_height, mask_width, num_classes].
    mask_targets: a float tensor representing the binary mask of ground truth
      labels for each mask with a shape of
      [batch_size, num_masks, mask_height, mask_width].
    select_class_targets: a tensor with a shape of [batch_size, num_masks],
      representing the foreground mask targets.
    params: the dictionary including training parameters specified in the
      default_hparams function in this file.

  Returns:
    mask_loss: a float tensor representing total mask loss.
  """
  with tf.name_scope('mask_loss'):
    # Selects the mask from `mask_outputs` based on `class_targets`, with which
    # the mask has the maximum overlap.
    num_partitions = (
        params['num_cores_per_replica'] if params['use_spmd'] else 1)
    if num_partitions is not None and num_partitions > 1:
      mask_outputs = xla_sharding.split(
          mask_outputs, 1, num_partitions, use_sharding_op=True)
      mask_targets = xla_sharding.split(
          mask_targets, 1, num_partitions, use_sharding_op=True)
      select_class_targets = xla_sharding.split(
          select_class_targets, 1, num_partitions, use_sharding_op=True)
    (batch_size, num_masks, mask_height,
     mask_width) = mask_outputs.get_shape().as_list()

    weights = tf.tile(
        tf.reshape(tf.greater(select_class_targets, 0),
                   [batch_size, num_masks, 1, 1]),
        [1, 1, mask_height, mask_width])
    loss = tf.losses.sigmoid_cross_entropy(
        mask_targets,
        mask_outputs,
        weights=weights,
        reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
    return params['mrcnn_weight_loss_mask'] * loss
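# A small illustration (added, not part of the original file) of the `weights`
# tensor built above: each RoI whose class target is positive contributes a
# full mask_height x mask_width window of ones, while background RoIs are
# zeroed out, so SUM_BY_NONZERO_WEIGHTS averages the loss only over foreground
# mask pixels. The toy shapes below are assumptions for illustration.
example_class_targets = tf.constant([[3, 0], [0, 7]])   # [batch=2, num_masks=2]
example_weights = tf.tile(
    tf.reshape(tf.greater(example_class_targets, 0), [2, 2, 1, 1]),
    [1, 1, 4, 4])                                        # mask_height=mask_width=4
# example_weights[0, 0] is all True (class 3 is foreground);
# example_weights[0, 1] is all False (background RoI contributes no loss).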
def attention_layer(from_tensor,
                    to_tensor,
                    layer_idx,
                    total_layers,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None,
                    num_partitions=1):
  """Performs multi-headed attention from `from_tensor` to `to_tensor`.

  This is an implementation of multi-headed attention based on "Attention Is
  All You Need". If `from_tensor` and `to_tensor` are the same, then this is
  self-attention. Each timestep in `from_tensor` attends to the corresponding
  sequence in `to_tensor`, and returns a fixed-width vector.

  This function first projects `from_tensor` into a "query" tensor and
  `to_tensor` into "key" and "value" tensors. These are (effectively) a list
  of tensors of length `num_attention_heads`, where each tensor is of shape
  [batch_size, seq_length, size_per_head].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor and returned.

  In practice, the multi-headed attention is done with tf.einsum as follows:
    Input_tensor: [BFD]
    Wq, Wk, Wv: [DNH]
    Q:[BFNH] = einsum('BFD,DNH->BFNH', Input_tensor, Wq)
    K:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wk)
    V:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wv)
    attention_scores:[BNFT] = einsum('BFNH,BTNH->BNFT', Q, K) / sqrt(H)
    attention_probs:[BNFT] = softmax(attention_scores)
    context_layer:[BFNH] = einsum('BNFT,BTNH->BFNH', attention_probs, V)
    Wout:[DNH]
    Output:[BFD] = einsum('BFNH,DNH->BFD', context_layer, Wout)

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    layer_idx: the index of the current layer.
    total_layers: total number of layers.
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions
      in the mask that are 0, and will be unchanged for positions that are 1.
    num_attention_heads: int. Number of attention heads.
    size_per_head: int. Size of each attention head.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.
    initializer_range: float. Range of the weight initializer.
    batch_size: (Optional) int. If the input is 2D, this might be the batch
      size of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq
      length of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `to_tensor`.
    num_partitions: (optional) Number of SPMD partitions.

  Returns:
    float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
      size_per_head].

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
""" from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) if len(from_shape) != len(to_shape): raise ValueError( "The rank of `from_tensor` must match the rank of `to_tensor`.") if len(from_shape) == 3: batch_size = from_shape[0] from_seq_length = from_shape[1] to_seq_length = to_shape[1] elif len(from_shape) == 2: if (batch_size is None or from_seq_length is None or to_seq_length is None): raise ValueError( "When passing in rank 2 tensors to attention_layer, the values " "for `batch_size`, `from_seq_length`, and `to_seq_length` " "must all be specified.") # Scalar dimensions referenced here: # B = batch size (number of sequences) # F = `from_tensor` sequence length # T = `to_tensor` sequence length # N = `num_attention_heads` # H = `size_per_head` # `query_layer` = [B, F, N, H] query_layer = dense_layer_3d(from_tensor, layer_idx, total_layers, num_attention_heads, size_per_head, create_initializer(initializer_range), query_act, name="query") # `key_layer` = [B, T, N, H] key_layer = dense_layer_3d(to_tensor, layer_idx, total_layers, num_attention_heads, size_per_head, create_initializer(initializer_range), key_act, name="key") # `value_layer` = [B, T, N, H] value_layer = dense_layer_3d(to_tensor, layer_idx, total_layers, num_attention_heads, size_per_head, create_initializer(initializer_range), value_act, name="value") if num_partitions > 1: # partition along the heads dimension query_layer = xla_sharding.split(query_layer, 2, num_partitions, use_sharding_op=True) key_layer = xla_sharding.split(key_layer, 2, num_partitions, use_sharding_op=True) value_layer = xla_sharding.split(value_layer, 2, num_partitions, use_sharding_op=True) query_layer = tf.multiply(query_layer, 1.0 / math.sqrt(float(size_per_head))) # Take the dot product between "query" and "key" to get the raw # attention scores. attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_layer, query_layer) if attention_mask is not None: # `attention_mask` = [B, 1, F, T] attention_mask = tf.expand_dims(attention_mask, axis=[1]) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0 # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. attention_scores += adder # Normalize the attention scores to probabilities. # `attention_probs` = [B, N, F, T] attention_scores = tf.cast(attention_scores, tf.float32) attention_scores = attention_scores - tf.stop_gradient( tf.reduce_max(attention_scores, -1, True)) attention_scores = tf.exp(attention_scores) attention_sum = tf.reduce_sum(attention_scores, -1, True) attention_probs = tf.cast(attention_scores, key_layer.dtype) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
  # Split the mask and scaling ops in dropout.
  random_u = tf.random_uniform(attention_probs.shape, dtype=tf.bfloat16)
  keep_mask = random_u >= attention_probs_dropout_prob
  keep_mask = tf.cast(keep_mask, dtype=attention_probs.dtype)
  attention_probs = tf.multiply(keep_mask, attention_probs)

  # `context_layer` = [B, F, N, H]
  context_layer = tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_layer)
  context_layer = context_layer / tf.cast(
      tf.transpose(attention_sum, [0, 2, 1, 3]), context_layer.dtype)
  if num_partitions > 1:
    # Partition along the heads dimension.
    context_layer = xla_sharding.split(
        context_layer, 2, num_partitions, use_sharding_op=True)

  # Split the mask and scaling ops in dropout:
  # move the scaling from dropout to here to save some mul ops.
  # TODO(yuemmawang) Automate this optimization in xla.
  keep_prob = 1 - attention_probs_dropout_prob
  scale = 1 / keep_prob
  context_layer = tf.multiply(context_layer, scale)
  return context_layer
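# A small numpy check (added for illustration, not part of the original file)
# of the normalization trick used in attention_layer: the softmax is left
# unnormalized before the value contraction, and the normalizer is divided out
# of `context_layer` afterwards. Ignoring dropout, this is algebraically the
# same as contracting softmax(scores) with V, because the division by the
# per-row sum distributes over the sum across T. Shapes below are arbitrary.
import numpy as np

B, N, F, T, H = 1, 2, 3, 3, 4
scores = np.random.randn(B, N, F, T).astype(np.float32)
v = np.random.randn(B, T, N, H).astype(np.float32)

# Reference: normalize first, then contract.
probs = np.exp(scores - scores.max(-1, keepdims=True))
probs = probs / probs.sum(-1, keepdims=True)
ref = np.einsum("BNFT,BTNH->BFNH", probs, v)

# Deferred normalization, as in attention_layer above.
unnorm = np.exp(scores - scores.max(-1, keepdims=True))
ctx = np.einsum("BNFT,BTNH->BFNH", unnorm, v)
ctx = ctx / unnorm.sum(-1, keepdims=True).transpose(0, 2, 1, 3)

assert np.allclose(ref, ctx, atol=1e-5)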
def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False,
                      num_partitions=1):
  """Multi-headed, multi-layer Transformer from "Attention is All You Need".

  This is almost an exact implementation of the original Transformer encoder.

  See the original paper:
  https://arxiv.org/abs/1706.03762

  Also see:
  https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    do_return_all_layers: Whether to also return all layers or just the final
      layer.
    num_partitions: (optional) int32. Number of SPMD partitions.

  Returns:
    float Tensor of shape [batch_size, seq_length, hidden_size], the final
    hidden layer of the Transformer.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if hidden_size % num_attention_heads != 0:
    raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (hidden_size, num_attention_heads))

  attention_head_size = int(hidden_size / num_attention_heads)
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  input_width = input_shape[2]

  # The Transformer performs sum residuals on all layers so the input needs
  # to be the same as the hidden size.
  if input_width != hidden_size:
    raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                     (input_width, hidden_size))

  prev_output = input_tensor
  all_layer_outputs = []
  for layer_idx in range(num_hidden_layers):
    with tf.variable_scope("layer_common", reuse=tf.AUTO_REUSE):
      layer_input = prev_output

      with tf.variable_scope("attention"):
        with tf.variable_scope("self"):
          attention_output = attention_layer(
              from_tensor=layer_input,
              to_tensor=layer_input,
              layer_idx=layer_idx,
              total_layers=num_hidden_layers,
              attention_mask=attention_mask,
              num_attention_heads=num_attention_heads,
              size_per_head=attention_head_size,
              attention_probs_dropout_prob=attention_probs_dropout_prob,
              initializer_range=initializer_range,
              num_partitions=num_partitions)

        # Run a linear projection of `hidden_size` then add a residual
        # with `layer_input`.
with tf.variable_scope("output"): attention_output = dense_layer_3d_proj( attention_output, layer_idx, num_hidden_layers, hidden_size, num_attention_heads, attention_head_size, create_initializer(initializer_range), None, "dense") attention_output = dropout(attention_output, hidden_dropout_prob) attention_output = layer_norm( attention_output + layer_input, layer_idx, num_hidden_layers) # The activation is only applied to the "intermediate" hidden layer. with tf.variable_scope("intermediate"): intermediate_output = dense_layer_2d( attention_output, layer_idx, num_hidden_layers, intermediate_size, create_initializer(initializer_range), intermediate_act_fn, "dense") if num_partitions > 1: # partition along the feature dimension intermediate_output = xla_sharding.split(intermediate_output, 2, num_partitions, use_sharding_op=True) # Down-project back to `hidden_size` then add the residual. with tf.variable_scope("output"): layer_output = dense_layer_2d( intermediate_output, layer_idx, num_hidden_layers, hidden_size, create_initializer(initializer_range), None, "dense") layer_output = dropout(layer_output, hidden_dropout_prob) layer_output = layer_norm(layer_output + attention_output, layer_idx, num_hidden_layers) prev_output = layer_output all_layer_outputs.append(layer_output) if do_return_all_layers: return all_layer_outputs else: return all_layer_outputs[-1]