def attention_bias_local_block(mesh, block_length, memory_length,
                               dtype=tf.int32):
    """Additive attention bias for local blocks that may not look rightwards.

    The bias is assembled from two pieces laid side by side along the memory
    axis:

    * an all-zero piece for the part of the memory that does not overlap the
      query block, and
    * a piece that marks every memory index strictly greater than the query
      index, i.e. positions to the right of the current query position.

    The joined 0/1 tensor is converted to float32 and scaled by -1e9, so marked
    positions carry a large negative bias.

    Args:
      mesh: a MeshTensorflow object
      block_length: a mtf.Dimension
      memory_length: a mtf.Dimension. Only its name is used; each of the two
        pieces is built with `block_length.size` entries along this axis.
      dtype: a tf.dtype used for the intermediate range / mask tensors

    Returns:
      a float32 mtf.Tensor with dimensions [block_length, memory_length], where
      the memory axis is the concatenation of the two pieces described above.
    """
    # Each piece spans block_length.size entries along the memory axis.
    piece_dim = mtf.Dimension(memory_length.name, block_length.size)

    # Memory that precedes the query block: nothing is masked.
    unmasked_piece = mtf.zeros(mesh, [block_length, piece_dim], dtype=dtype)

    # Memory that overlaps the query block: mark entries right of the query.
    query_index = mtf.range(mesh, block_length, dtype=dtype)
    memory_index = mtf.range(mesh, piece_dim, dtype=dtype)
    rightward_piece = mtf.cast(mtf.less(query_index, memory_index),
                               dtype=dtype)

    joined = mtf.concat([unmasked_piece, rightward_piece], piece_dim.name)
    return mtf.cast(joined, dtype=tf.float32) * -1e9
def create_positional_emb_2d(self, targets, max_length_dim, model_dim):
    """Learned 2d positional embedding for images.

    Creates two embedding tables, "positional_emb_rows" and
    "positional_emb_cols", each shaped [max_length_dim, model_dim]. The row
    table is indexed with 0..img_len-1 and the column table with
    0..img_len*num_channels-1; both lookups are broadcast to
    [rows, cols, model_dim] and added together.

    Args:
      targets: object whose `.mesh` attribute supplies the mesh (nothing else
        is read from it).
      max_length_dim: dimension that indexes the embedding tables.
      model_dim: dimension of the embedding depth.

    Returns:
      The sum of the broadcast row and column embeddings.
    """
    mesh = targets.mesh
    hparams = self._hparams
    act_dtype = self.set_activation_type()

    rows_dim = mtf.Dimension("rows", hparams.img_len)
    cols_dim = mtf.Dimension("cols", hparams.img_len * hparams.num_channels)

    def _new_table(var_name):
        # One randomly initialised [max_length, d_model] embedding table.
        return mtf.get_variable(
            mesh,
            var_name,
            mtf.Shape([max_length_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=act_dtype)

    def _lookup_and_tile(table, positions):
        # Gather the embeddings, then tile them over the full image grid.
        looked_up = mtf.gather(table, positions, max_length_dim)
        return mtf.broadcast(
            looked_up, mtf.Shape([rows_dim, cols_dim, model_dim]))

    rows_table = _new_table("positional_emb_rows")
    cols_table = _new_table("positional_emb_cols")

    row_ids = mtf.range(mesh, rows_dim, dtype=tf.int32)
    col_ids = mtf.range(mesh, cols_dim, dtype=tf.int32)

    row_emb = _lookup_and_tile(rows_table, row_ids)
    col_emb = _lookup_and_tile(cols_table, col_ids)
    return row_emb + col_emb
def create_positional_emb_2d(self, targets):
    """Learned 2d positional embedding for images.

    Variant that takes every dimension from the instance (`self.rows_dim`,
    `self.cols_dim`, `self.max_length_dim`, `self.model_dim`) and the
    activation dtype from `self.activation_type`.

    Two tables, "positional_emb_rows" and "positional_emb_cols", are created
    with shape [max_length_dim, model_dim]. They are looked up with the row
    indices and column indices respectively; both results are broadcast to
    [rows_dim, cols_dim, model_dim] and summed.

    Args:
      targets: object whose `.mesh` attribute supplies the mesh.

    Returns:
      The sum of the broadcast row and column embeddings.
    """
    mesh = targets.mesh

    # Create both tables first (rows, then cols) ...
    rows_table, cols_table = [
        mtf.get_variable(
            mesh,
            table_name,
            mtf.Shape([self.max_length_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=self.activation_type)
        for table_name in ("positional_emb_rows", "positional_emb_cols")
    ]

    # ... then the index vectors along each image axis ...
    row_ids, col_ids = [
        mtf.range(mesh, axis_dim, dtype=tf.int32)
        for axis_dim in (self.rows_dim, self.cols_dim)
    ]

    # ... and finally look up and tile each embedding over the whole grid.
    row_emb, col_emb = [
        mtf.broadcast(
            mtf.gather(table, ids, self.max_length_dim),
            mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
        for table, ids in ((rows_table, row_ids), (cols_table, col_ids))
    ]
    return row_emb + col_emb
def attention_bias_local_block(mesh, block_length, memory_length,
                               dtype=tf.int32):
    """Additive attention bias for local blocks that may not look rightwards.

    Every (query index, memory index) pair whose memory index is strictly
    greater than the query index is marked, and the resulting 0/1 mask is
    converted to float32 and scaled by -1e9.

    Args:
      mesh: a MeshTensorflow object
      block_length: a mtf.Dimension indexing the query positions
      memory_length: a mtf.Dimension indexing the memory positions
      dtype: a tf.dtype used for the intermediate range / mask tensors

    Returns:
      a float32 mtf.Tensor over the block_length and memory_length dimensions
      (presumably shaped [block_length, memory_length]).
    """
    query_index = mtf.range(mesh, block_length, dtype=dtype)
    memory_index = mtf.range(mesh, memory_length, dtype=dtype)
    looks_right = mtf.cast(mtf.less(query_index, memory_index), dtype=dtype)
    return mtf.cast(looks_right, dtype=tf.float32) * -1e9
def multihead_self_attention_incremental(query_antecedent,
                                         prev_k,
                                         prev_v,
                                         step_num,
                                         name="multihead_attention"):
    """Incremental self-attention (one decode step).

    In order to use only one variable containing the four weight matrices
    packed together, we insist that the query and memory antecedents have the
    same dimensionality (io_channels) and that the keys and values have the
    same dimensionality (kv_channels).

    The key and value computed for the current step are written into slot
    `step_num` of the cached keys / values by multiplying them with a one-hot
    vector over the memory axis and adding the result to the cache. Memory
    positions with an index greater than `step_num` receive a -1e9 bias.

    Args:
      query_antecedent: a mtf.Tensor with shape [batch..., io_channels]
      prev_k: mtf.Tensor with shape
        [batch..., heads, memory_length, kv_channels]
      prev_v: mtf.Tensor with shape
        [batch..., heads, memory_length, kv_channels]
      step_num: mtf Scalar with dtype tf.int32
      name: an optional string.

    Returns:
      y: A mtf.Tensor with shape [batch..., io_channels]
      new_k: mtf.Tensor with shape
        [batch..., heads, memory_length, kv_channels]
      new_v: mtf.Tensor with shape
        [batch..., heads, memory_length, kv_channels]

    Note:
      This function performs no explicit validation itself; any error for
      mismatched dimensions would have to come from the mtf ops it calls.
    """
    x = query_antecedent
    x_dims = x.shape.dims
    batch_dims, io_channels = x_dims[:-1], x_dims[-1]
    heads, memory_length, kv_channels = prev_k.shape.dims[-3:]

    with tf.variable_scope(name, default_name="multihead_attention"):
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            x.mesh, heads, io_channels, kv_channels, x.dtype)

        # Self-attention: queries, keys and values all come from the same
        # input and share the per-head output shape.
        per_head_shape = mtf.Shape(batch_dims + [heads, kv_channels])
        q = mtf.einsum([x, q_var], per_head_shape)
        step_k = mtf.einsum([x, k_var], per_head_shape)
        step_v = mtf.einsum([x, v_var], per_head_shape)

        # Scatter this step's key / value into slot `step_num` of the cache.
        new_k = prev_k + mtf.multiply(
            step_k, mtf.one_hot(step_num, memory_length),
            output_shape=prev_k.shape)
        new_v = prev_v + mtf.multiply(
            step_v, mtf.one_hot(step_num, memory_length),
            output_shape=prev_v.shape)

        # Bias away memory slots whose index is beyond the current step.
        memory_index = mtf.range(x.mesh, memory_length, dtype=tf.int32)
        beyond_step = mtf.to_float(mtf.greater(memory_index, step_num))
        mask = beyond_step * -1e9

        attended = dot_product_attention(q, new_k, new_v, mask)
        y = mtf.einsum([attended, o_var], x.shape)
    return y, new_k, new_v
def _mtf_model_fn(self, features, mesh):
    """Builds the Mesh TensorFlow graph and returns the logits and the loss.

    Steps performed, in order:
      1. Cast `features["targets"]` to int32, squeeze axes 2 and 3 when the
         tensor has more than two axes, and pad axis 1 up to
         `hparams.max_length`. Any segmentation / position features present are
         padded the same way.
      2. Import targets and right-shifted targets into the mesh and build the
         decoder self-attention mask (autoregressive, plus a same-segment mask
         when "targets_segmentation" is present).
      3. If `self.has_input`, embed the inputs, run the encoder layer stack and
         rename its length dimension to the memory-length dimension.
      4. Embed the shifted targets, run the decoder layer stack, and project
         with `softmax_var` to obtain logits.
      5. Compute label-smoothed softmax cross entropy, weight it with
         `mtf_layers.weights_nonzero(targets)`, take the mean, and add every
         entry of `extra_losses`.

    Args:
      features: dict of tf Tensors. Reads "targets"; reads "inputs" when
        `self.has_input`; optionally reads "targets_segmentation",
        "targets_position", "inputs_segmentation" and "inputs_position".
      mesh: the mesh passed on to the mtf import / range / variable helpers.

    Returns:
      A (logits, loss) pair of mtf Tensors.
    """
    # Shallow copy so that replacing entries below does not touch the caller's
    # dict.
    features = copy.copy(features)
    hparams = self._hparams
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
        tf.logging.info("targets = %s" % targets)
        targets = tf.squeeze(targets, [2, 3])

    # pad targets to max_length
    def pad_to_max_length(x):
        # Pads axis 1 on the right, then pins the static shape to
        # [batch_size, max_length].
        extra_length = hparams.max_length - tf.shape(x)[1]
        x = tf.pad(x, [[0, 0], [0, extra_length]])
        x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
        return x

    targets = pad_to_max_length(targets)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
        if key in features:
            features[key] = pad_to_max_length(features[key])
    # Decoder input is the target sequence shifted right; done on the tf
    # tensor before importing into the mesh.
    shifted_targets = common_layers.shift_right_2d(targets)

    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    shifted_targets = self._import_to_batch_by_length(
        shifted_targets, "shifted_targets", mesh, hparams)

    if "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = self._import_to_batch_by_length(
            features["targets_segmentation"], "targets_segmentation",
            mesh, hparams)
        targets_position = self._import_to_batch_by_length(
            features["targets_position"], "targets_position",
            mesh, hparams)
        decoder_self_attention_mask = (
            mtf_layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype) +
            mtf_layers.attention_mask_same_segment(
                targets_segmentation, dtype=self.activation_dtype))
    else:
        # Unpacked data: positions are simply 0..length-1.
        targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)

    def layer_prepostprocess_dropout(x):
        # Dropout whose noise shape is [batch, d_model].
        return mtf.dropout(
            x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

    # Collects auxiliary losses appended by the layer stacks.
    extra_losses = []
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
        # ENCODER
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = pad_to_max_length(inputs)
        inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
        if "inputs_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            inputs_segmentation = self._import_to_batch_by_length(
                features["inputs_segmentation"], "inputs_segmentation",
                mesh, hparams)
            inputs_position = self._import_to_batch_by_length(
                features["inputs_position"], "inputs_position",
                mesh, hparams)
            encoder_self_attention_mask = (
                mtf_layers.attention_mask_same_segment(
                    inputs_segmentation, dtype=self.activation_dtype))
            # NOTE(review): targets_segmentation is only bound when
            # "targets_segmentation" is in features; this branch presumably
            # relies on packed inputs always coming with packed targets -
            # confirm against the data pipeline.
            encoder_decoder_attention_mask = (
                mtf_layers.attention_mask_same_segment(
                    targets_segmentation, inputs_segmentation,
                    dtype=self.activation_dtype))
        else:
            inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            encoder_self_attention_mask = (
                mtf_layers.attention_mask_ignore_padding(
                    inputs, dtype=self.activation_dtype))
            # The same padding mask is reused for encoder-decoder attention.
            encoder_decoder_attention_mask = encoder_self_attention_mask

        # Token embedding plus learned positional embedding.
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.gather(positional_embedding_var, inputs_position,
                        self.max_length_dim))
        x = layer_prepostprocess_dropout(x)
        with tf.variable_scope("encoder"):
            x = self._layer_stack(x,
                                  hparams.num_encoder_layers,
                                  self_attention_mask=encoder_self_attention_mask,
                                  losses=extra_losses)
        # Rename the length dimension to the memory-length dimension before
        # handing the encoder output to the decoder.
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)
    else:
        # No inputs: the decoder stack receives no encoder output or mask.
        encoder_output = None
        encoder_decoder_attention_mask = None

    # DECODER
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)

    # Decoder
    with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.num_decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    logits = mtf.matmul(x, softmax_var)
    # Label smoothing: label_smoothing / vocab_size goes to every class, and
    # the true class additionally gets 1 - label_smoothing.
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    # Presumably zero weight for padding (id 0) positions - see
    # mtf_layers.weights_nonzero.
    weights = mtf_layers.weights_nonzero(
        targets, dtype=self.activation_dtype)
    # Mean of the weighted per-position loss.
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
        loss += l
    return logits, loss
def _truncated_top_2_gating_mtf(
    gates, group_dim, experts_dim, expert_capacity_dim):
    """Compute top-2 gating for mixture-of-experts using Mesh TensorFlow ops.

    gates is usually the output of a softmax function.
    The return value is a dense representation of the mapping between
    the input positions and the positions in the batches sent to the experts.

    Each position picks its best expert and then a second expert; a choice is
    dropped ("truncated") when the position it would occupy in that expert's
    mini-batch is not below the expert capacity. The two surviving gate values
    are renormalized to add up to 1.

    TODO(noam): this function contains code factored out of
    expert_utils.local_moe_tpu. Move this function to that file and
    call it from both places.

    Args:
      gates: a Tensor
      group_dim: one dimension of gates
      experts_dim: one dimension of gates
      expert_capacity_dim: a Dimension not in gates

    Returns:
      a Tensor with shape gates.shape + expert_capacity_dim

    Raises:
      ValueError: if expert_capacity_dim has size >256
    """
    gates = mtf.to_float(gates)
    expert_capacity_f = float(expert_capacity_dim.size)
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=gates.dtype)

    if expert_capacity_dim.size > 256:
        # using mtf.cumsum (implemented on TPU as bfloat16 matmul) to compute
        # position in the mini-batch sent to the expert. This will cause
        # very bad things to happen if expert_capacity_dim > 256.
        raise ValueError(
            "expert_capacity_dim.size must be <=256 to avoid roundoff errors in"
            " indices - got %s" % (expert_capacity_dim,))
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    # (computed from the mask before it is truncated below).
    position_in_expert_1 = mtf.cumsum(mask_1, group_dim, exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(
        position_in_expert_1, reduced_dim=experts_dim)
    # Weight assigned to first expert. [batch, group]
    gate_1 *= mask_1_flat

    # Pick a second-place expert for each position.
    # We first mask out the experts that we expect to be over-capacity
    # [batch, experts]
    space_remaining = expert_capacity_f - mask_1_count
    use_rate = (mask_1_count + 1.0) / float(group_dim.size)
    # At what point in the sequence do we expect the expert to be full.
    # [batch, experts]
    expected_exhaustion_pos = space_remaining / use_rate
    # A Tensor with shape [batch, group, experts] representing a boolean
    # - whether we expect that the expert will already be full.
    expected_exhausted = mtf.to_float(mtf.greater(
        mtf.range(gates.mesh, group_dim, tf.float32), expected_exhaustion_pos))
    # Subtracting the 0/1 masks lowers the score of the first-choice expert and
    # of experts expected to be full before the second top_1.
    masked_gates = gates - mask_1 - expected_exhausted
    # This section is similar to the section above.
    # [batch, group]
    index_2, gate_2 = mtf.top_1(masked_gates, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=gates.dtype)
    # [batch, group, experts]
    # Second choices are placed after the first choices already assigned to
    # the same expert, hence the mask_1_count offset.
    position_in_expert_2 = (
        mtf.cumsum(mask_2, group_dim, exclusive=True) + mask_1_count)
    position_in_expert_2 *= mask_2
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    position_in_expert_2 = mtf.reduce_sum(
        position_in_expert_2, reduced_dim=experts_dim)
    gate_2 *= mask_2_flat

    # renormalize the two gate values to add up to 1
    # (the 1e-9 keeps the division finite when both gates were truncated)
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # [batch, group, experts, expert_capacity]
    assignment = (
        gate_1 * mask_1_flat
        * mtf.one_hot(index_1, experts_dim)
        * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat
        * mtf.one_hot(index_2, experts_dim)
        * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    return assignment
def mtf_model_fn(self, features, mesh):
    """Builds the Mesh TensorFlow graph of the image decoder.

    The targets are flattened to a 1D sequence of
    `img_len * img_len * num_channels` entries per example and shifted right.
    The shifted targets are embedded, a learned positional embedding is added,
    and the result is passed through `hparams.num_decoder_layers` layers, each
    consisting of masked local 1d self-attention followed by a
    dense-relu-dense block (both applied to a layer-normalised input, with
    dropout and a residual connection). After a final layer norm, a dense
    layer produces 256-way logits, and the loss is the mean softmax cross
    entropy against the one-hot targets.

    Args:
      features: dict of tf Tensors. Reads "targets", and "inputs" when the
        model has inputs and `hparams.unconditional` is false.
      mesh: the mesh passed on to the mtf import / range / variable helpers.

    Returns:
      A (logits, loss) pair of mtf Tensors.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()

    targets = tf.to_int32(features["targets"])

    # Image preprocessing, reshape into a 1D sequence and shift right.
    length = hparams.img_len * hparams.img_len * hparams.num_channels
    targets = tf.reshape(targets, [hparams.batch_size, length])
    shifted_targets = common_layers.shift_right_2d(targets)

    # Declare all the dimensions
    model_dim = mtf.Dimension("d_model", hparams.hidden_size)
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    length_dim = mtf.Dimension("length", length)
    max_length_dim = mtf.Dimension("max_length", hparams.max_length)
    filter_dim = mtf.Dimension("d_ff", hparams.d_ff)
    kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
    heads = mtf.Dimension("heads", hparams.num_heads)

    def import_to_batch_by_length(x, name):
        # Imports a tf Tensor into the mesh as a [batch, length] mtf Tensor.
        return mtf.import_tf_tensor(
            mesh, x, mtf.Shape([batch_dim, length_dim]), name=name)

    def layer_prepostprocess_dropout(x):
        # Dropout whose noise shape is [batch, d_model].
        return mtf.dropout(
            x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape([batch_dim, model_dim]))

    targets = import_to_batch_by_length(targets, "targets")
    shifted_targets = import_to_batch_by_length(
        shifted_targets, "shifted_targets")

    # Never appended to in this function; kept so the loss assembly below
    # stays uniform.
    extra_losses = []

    # TODO(nikip): Verify conditional.
    if self.has_input and not hparams.unconditional:
        vocab_size = hparams.num_classes
        inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = import_to_batch_by_length(inputs, "inputs")

        # Input embeddings
        # NOTE(review): the embedded inputs are not consumed anywhere below;
        # the call is kept because it creates the "inputs_embedding" variable.
        inputs, _ = mtf_layers.embedding(
            inputs, inputs_vocab_dim, model_dim,
            activation_dtype=activation_dtype,
            name="inputs_embedding")

    # Create targets content and position embeddings.
    targets_position = mtf.range(mesh, length_dim, dtype=tf.int32)
    # Fixed vocab size for targets: 256 values per channel on the embedding
    # side, 256 output classes.
    targets_vocab_size = 256 * hparams.num_channels
    targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
    outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

    # Create embedding var for targets and positions and do a gather.
    targets_embedding_var = mtf.get_variable(
        mesh, "targets_embedding",
        mtf.Shape([targets_vocab_dim, model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)

    positional_embedding_var = mtf.get_variable(
        mesh, "positional_embedding",
        mtf.Shape([max_length_dim, model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)
    x = (mtf.gather(targets_embedding_var, shifted_targets, targets_vocab_dim) +
         mtf.gather(positional_embedding_var, targets_position, max_length_dim))

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    for layer in range(hparams.num_decoder_layers):
        layer_name = "decoder_layer_%d" % layer
        with tf.variable_scope(layer_name):
            # Self attention layer
            x += layer_prepostprocess_dropout(
                mtf_layers.masked_local_attention_1d(
                    mtf_layers.layer_norm(x, model_dim,
                                          name="layer_norm_self_att"),
                    None,
                    kv_channels,
                    heads,
                    block_length=hparams.block_length,
                    name="self_att"))
            # ffn layer
            x += layer_prepostprocess_dropout(
                mtf_layers.dense_relu_dense(
                    mtf_layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
                    filter_dim,
                    hparams.dropout,
                    dropout_broadcast_dims=[length_dim]))

    x = mtf_layers.layer_norm(x, model_dim, name="decoder_final_layer_norm")

    # Calculate the logits and loss.
    logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
    soft_targets = mtf.one_hot(
        targets, outputs_vocab_dim, dtype=activation_dtype)
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, outputs_vocab_dim)

    loss = mtf.reduce_mean(loss)
    for l in extra_losses:
        loss += l
    return logits, loss