def rename_length_to_memory_length(
        x, length_name="length", memory_length_name="memory_length"):
    """Rename the `length_name` dimension of `x` to `memory_length_name`.

    Args:
      x: the tensor handed to `mtf.rename_dimension`.
      length_name: name of the dimension to rename (default "length").
      memory_length_name: the replacement name (default "memory_length").

    Returns:
      Whatever `mtf.rename_dimension` returns for these arguments.
    """
    renamed = mtf.rename_dimension(x, length_name, memory_length_name)
    return renamed
def local_self_attention_spatial_blocks(query_antecedent, kv_channels, heads,
                                        memory_w_dim=None, mask_right=False,
                                        name=None):
    """Blocked local self-attention along the width dimension.

    Queries, keys and values are all projected from `query_antecedent`. The
    keys and values then go through `local_1d_halo_exchange` before
    `dot_product_attention` is applied, and the result is projected back to
    `io_channels`.

    Args:
      query_antecedent: a mtf.Tensor. Its first two dimensions are read as
        (batch, num_w_blocks) and its last two as (w_dim, io_channels).
      kv_channels: a mtf.Dimension (the size of the key and value vectors).
      heads: a mtf.Dimension (the number of heads).
      memory_w_dim: mtf.Dimension for the memory width block.
        NOTE(review): although the default is None, `memory_w_dim.name` is
        read unconditionally below, so None fails with AttributeError -
        callers effectively must pass a dimension.
      mask_right: bool; when true a local-block attention bias is built and
        passed to the attention op as `mask`.
      name: an optional string used for the variable scope.

    Returns:
      A tensor of shape [batch, num_w_blocks, w_dim, io_channels].
    """
    with tf.variable_scope(
            name, default_name="multihead_attention",
            values=[query_antecedent]):
        dims = query_antecedent.shape.dims
        w_dim, io_channels = dims[-2:]
        batch, num_w_blocks = dims[:2]
        mesh = query_antecedent.mesh

        q_var, k_var, v_var, o_var = multihead_attention_vars(
            mesh, heads, io_channels, kv_channels, query_antecedent.dtype)

        # The memory side is the same tensor with its width dimension
        # renamed to the memory width dimension.
        memory_antecedent = mtf.rename_dimension(
            query_antecedent, w_dim.name, memory_w_dim.name)

        def _project(antecedent, weights):
            # Every projection targets the same per-head output shape.
            return mtf.einsum(
                [antecedent, weights],
                mtf.Shape([batch, heads, num_w_blocks, w_dim, kv_channels]))

        q = _project(query_antecedent, q_var)
        k = _project(memory_antecedent, k_var)
        v = _project(memory_antecedent, v_var)

        # Halo exchange for memory blocks.
        if memory_w_dim is not None:
            k, v = local_1d_halo_exchange(
                k, v, num_w_blocks, w_dim, memory_w_dim, mask_right)

        # The bias depends only on the block dimensions, so it is built once
        # here and shared by all blocks.
        if mask_right:
            mask = attention_bias_local_block(mesh, w_dim, memory_w_dim)
        else:
            mask = None

        attended = dot_product_attention(q, k, v, mask=mask)
        output_shape = mtf.Shape([batch, num_w_blocks, w_dim, io_channels])
        return mtf.einsum([attended, o_var], output_shape)
def _sample(self, features, mesh):
    """Build the decoding graph and return the decoded ids.

    With `self.has_input`, features["inputs"] is padded to
    [hparams.batch_size, hparams.max_length], run through the encoder stack,
    and the encoder-decoder attention keys/values are precomputed once per
    decoder layer. Without input, features["inputs"] (or, failing that,
    features["targets"]) is treated as a partial target sequence.

    Decoding uses `mtf_beam_search.greedy_decode` when hparams.beam_size == 1
    and `mtf_beam_search.beam_search` otherwise.

    NOTE(review): `partial_targets` is only forwarded on the greedy path (as
    `forced_ids`); the beam-search call does not receive it.

    Args:
      features: mapping read via `features["inputs"]` and
        `features.get("inputs"/"targets")`.
      mesh: the mesh passed to the variable/constant constructors.

    Returns:
      The result of `greedy_decode` for beam_size == 1; otherwise the beam
      search output gathered at index 0 along the "beam" dimension.
    """
    hparams = self._hparams
    (inputs_embedding_var, targets_embedding_var, softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
        inputs = features["inputs"]
        # Drop axis 2 repeatedly until the tensor is rank 2.
        while len(inputs.shape.as_list()) > 2:
            inputs = tf.squeeze(inputs, axis=2)
        # Pad both axes (at the end) up to the hparams batch size / max length.
        actual_batch_size = tf.shape(inputs)[0]
        actual_length = tf.shape(inputs)[1]
        inputs = tf.pad(
            inputs, [[0, hparams.batch_size - actual_batch_size],
                     [0, hparams.max_length - actual_length]])
        inputs = self._import_to_batch_by_length(
            inputs, "inputs", mesh, hparams)
        # Token embedding plus the positional embedding reshaped to
        # [length_dim, model_dim].
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.reshape(positional_embedding_var,
                         mtf.Shape([self.length_dim, self.model_dim])))
        encoder_attention_mask = (
            mtf_layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
        with tf.variable_scope("encoder"):
            x = self._layer_stack(x,
                                  hparams.num_encoder_layers,
                                  self_attention_mask=encoder_attention_mask)
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)
        # Precompute, per decoder layer, the encoder-decoder attention keys
        # and values, so they are built once here rather than inside
        # logits_fn.
        encdec_tensors = []
        for layer_num in xrange(hparams.num_decoder_layers):
            with tf.variable_scope(
                    "decoder/layer_%d/encdec_attention" % layer_num):
                q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
                    mesh, self.heads_dim, self.model_dim,
                    self.kv_dim, self.activation_dtype)
                k = mtf.einsum(
                    [encoder_output, k_var],
                    mtf.Shape(
                        [self.batch_dim, self.heads_dim,
                         self.memory_length_dim, self.kv_dim]))
                v = mtf.einsum(
                    [encoder_output, v_var],
                    mtf.Shape(
                        [self.batch_dim, self.heads_dim,
                         self.memory_length_dim, self.kv_dim]))
            encdec_tensors.append((q_var, o_var, k, v))
        partial_targets = None
    else:
        encdec_tensors = None
        encoder_output = None
        encoder_attention_mask = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs", None)
        if partial_targets is None:
            partial_targets = features.get("targets", None)
        if partial_targets is not None:
            partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
            partial_targets = tf.to_int32(partial_targets)
            # Same end-padding as applied to the inputs above.
            partial_targets_batch = tf.shape(partial_targets)[0]
            partial_targets_length = tf.shape(partial_targets)[1]
            partial_targets = tf.pad(
                partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                                  [0, hparams.max_length - partial_targets_length]])
            partial_targets = self._import_to_batch_by_length(
                partial_targets, "partial_targets", mesh, hparams)

    # Beam search carries an extra "beam" dimension on ids and on the
    # self-attention key/value states.
    if hparams.beam_size == 1:
        ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
        kv_shape = mtf.Shape([self.batch_dim, self.heads_dim,
                              self.memory_length_dim, self.kv_dim])
    else:
        beam_dim = mtf.Dimension("beam", hparams.beam_size)
        ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
        kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim,
                              self.memory_length_dim, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    # One zero tensor repeated 2 * num_decoder_layers times (the list holds
    # the same object); logits_fn treats the first half as keys and the
    # second half as values.
    initial_kv_states = (
        [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)]
        * (2 * hparams.num_decoder_layers))

    def logits_fn(step_num, ids, states):
        """Produce logits for this step, and new states."""
        self_attention_k = states[:hparams.num_decoder_layers]
        self_attention_v = states[hparams.num_decoder_layers:]
        # Embed the id at position step_num - 1 and add the positional
        # embedding for step_num.
        ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
        x = (mtf.gather(targets_embedding_var, ids_this_step,
                        self.targets_vocab_dim) +
             mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
        with tf.variable_scope("decoder"):
            x, new_self_attention_k, new_self_attention_v = (
                self._decoder_layer_stack_incremental(
                    x,
                    step_num,
                    encdec_tensors,
                    self_attention_k,
                    self_attention_v,
                    encdec_attention_mask=encoder_attention_mask))
        logits = mtf.matmul(x, softmax_var)
        # Keys then values, matching the split at the top of this function.
        return logits, new_self_attention_k + new_self_attention_v

    if hparams.beam_size == 1:
        # "argmax" sampling is expressed as temperature 0.0.
        temperature = (0.0 if hparams.sampling_method == "argmax"
                       else hparams.sampling_temp)
        return mtf_beam_search.greedy_decode(
            logits_fn,
            initial_ids,
            temperature=temperature,
            initial_states=initial_kv_states,
            forced_ids=partial_targets,
            use_tpu=hparams.use_tpu)
    else:
        if self.has_input:
            # Count nonzero input ids per example, then derive the decode
            # length from the largest count.
            input_length = mtf.reduce_sum(
                mtf.to_float(mtf.cast(inputs, tf.bool)),
                reduced_dim=self.length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * hparams.decode_length_multiplier
                + hparams.decode_length_constant, tf.int32)
        else:
            decode_length = None
        beams, unused_scores = mtf_beam_search.beam_search(
            logits_fn,
            initial_ids,
            hparams.alpha,
            states=initial_kv_states,
            decode_length=decode_length,
            use_tpu=hparams.use_tpu)
        # Keep only index 0 along the beam dimension.
        return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def _mtf_model_fn(self, features, mesh):
    """Build the training graph and return (logits, loss).

    Targets (and inputs, when `self.has_input`) are padded to
    hparams.max_length, imported onto the mesh, embedded, and run through the
    encoder/decoder layer stacks. The loss is softmax cross-entropy against
    label-smoothed one-hot targets, multiplied by `weights_nonzero(targets)`,
    reduced with `mtf.reduce_mean`, plus any extra losses collected from the
    layer stacks.

    NOTE(review): the packed-inputs branch reads `targets_segmentation`,
    which is only bound when "targets_segmentation" is in `features`.
    NOTE(review): inputs are squeezed on axes [2, 3] unconditionally, whereas
    targets are only squeezed when their rank exceeds 2.

    Args:
      features: mapping with "targets", optionally "inputs" and the
        "*_segmentation" / "*_position" entries of a packed dataset.
      mesh: the mesh passed to the variable/range constructors.

    Returns:
      A (logits, loss) tuple.
    """
    # Shallow copy: padded entries are written back into `features` below,
    # and this keeps those writes out of the caller's mapping.
    features = copy.copy(features)
    hparams = self._hparams
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
        tf.logging.info("targets = %s" % targets)
        targets = tf.squeeze(targets, [2, 3])

    # pad targets to max_length
    def pad_to_max_length(x):
        # End-pad axis 1 up to hparams.max_length, then pin the static shape
        # to [batch_size, max_length].
        extra_length = hparams.max_length - tf.shape(x)[1]
        x = tf.pad(x, [[0, 0], [0, extra_length]])
        x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
        return x
    targets = pad_to_max_length(targets)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
        if key in features:
            features[key] = pad_to_max_length(features[key])
    shifted_targets = common_layers.shift_right_2d(targets)

    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    shifted_targets = self._import_to_batch_by_length(
        shifted_targets, "shifted_targets", mesh, hparams)

    if "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = self._import_to_batch_by_length(
            features["targets_segmentation"], "targets_segmentation",
            mesh, hparams)
        targets_position = self._import_to_batch_by_length(
            features["targets_position"], "targets_position",
            mesh, hparams)
        # Sum of the autoregressive mask and the same-segment mask.
        decoder_self_attention_mask = (
            mtf_layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype) +
            mtf_layers.attention_mask_same_segment(
                targets_segmentation, dtype=self.activation_dtype))
    else:
        targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)

    def layer_prepostprocess_dropout(x):
        # Dropout with keep_prob = 1 - hparams.layer_prepostprocess_dropout
        # and a [batch_dim, model_dim] noise shape.
        return mtf.dropout(
            x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

    # Losses appended by the layer stacks; added to the main loss at the end.
    extra_losses = []
    (inputs_embedding_var, targets_embedding_var, softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = pad_to_max_length(inputs)
        inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
        if "inputs_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            inputs_segmentation = self._import_to_batch_by_length(
                features["inputs_segmentation"], "inputs_segmentation",
                mesh, hparams)
            inputs_position = self._import_to_batch_by_length(
                features["inputs_position"], "inputs_position",
                mesh, hparams)
            encoder_self_attention_mask = (
                mtf_layers.attention_mask_same_segment(
                    inputs_segmentation, dtype=self.activation_dtype))
            encoder_decoder_attention_mask = (
                mtf_layers.attention_mask_same_segment(
                    targets_segmentation, inputs_segmentation,
                    dtype=self.activation_dtype))
        else:
            inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            encoder_self_attention_mask = (
                mtf_layers.attention_mask_ignore_padding(
                    inputs, dtype=self.activation_dtype))
            # Unpacked data: the padding mask is reused for enc-dec attention.
            encoder_decoder_attention_mask = encoder_self_attention_mask

        # ENCODER: token embedding + positional embedding, dropout, stack.
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.gather(positional_embedding_var, inputs_position,
                        self.max_length_dim))
        x = layer_prepostprocess_dropout(x)
        with tf.variable_scope("encoder"):
            x = self._layer_stack(x,
                                  hparams.num_encoder_layers,
                                  self_attention_mask=encoder_self_attention_mask,
                                  losses=extra_losses)
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)
    else:
        encoder_output = None
        encoder_decoder_attention_mask = None

    # DECODER: embeds the right-shifted targets rather than the targets
    # themselves.
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)

    # Decoder
    with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.num_decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    logits = mtf.matmul(x, softmax_var)
    # Label smoothing: off_value is label_smoothing / vocab size; the true
    # class gets 1 - label_smoothing + off_value.
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim,
        on_value=on_value, off_value=off_value, dtype=self.activation_dtype)
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    # Weight the per-position loss by weights_nonzero(targets) before
    # averaging.
    weights = mtf_layers.weights_nonzero(
        targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
        loss += l
    return logits, loss
def bottleneck_block(inputs, filters, is_training, strides,
                     projection_shortcut=None, row_blocks_dim=None,
                     col_blocks_dim=None):
    """Residual bottleneck block: 1x1 conv, 3x3 conv, 1x1 conv, then add.

    Each convolution is followed by `batch_norm_relu`; the last one is called
    with `relu=False` and the ReLU is applied after the shortcut is added.

    Args:
      inputs: a `mtf.Tensor`; its last dimension is used as the input-channel
        dimension of the first kernel (and of the projection kernel).
      filters: `int` number of output filters. As written, all three
        convolutions (and the projection kernel) use this same size.
      is_training: `bool` forwarded to `batch_norm_relu`.
      strides: forwarded as the strides argument of the third convolution;
        the first two convolutions use [1, 1, 1, 1].
      projection_shortcut: optional callable `(inputs, kernel) -> tensor`
        producing the shortcut. If None, `inputs` itself is the shortcut.
      row_blocks_dim: a mtf.Dimension passed as `h_blocks_dim` to the 3x3
        convolution only (the 1x1 convolutions pass None).
      col_blocks_dim: a mtf.Dimension passed as `w_blocks_dim` to all three
        convolutions.

    Returns:
      The output tensor of the block.
    """
    def _make_kernel(tensor, var_name, dims):
        # Kernel variable living on the same mesh as `tensor`.
        return mtf.get_variable(tensor.mesh, var_name, mtf.Shape(dims))

    filter_h_dim = mtf.Dimension("filter_height", 3)
    filter_w_dim = mtf.Dimension("filter_width", 3)
    one_h_dim = mtf.Dimension("filter_height", 1)
    one_w_dim = mtf.Dimension("filter_width", 1)

    if projection_shortcut is None:
        shortcut = inputs
    else:
        filters_dim = mtf.Dimension("filtersp", filters)
        kernel = _make_kernel(
            inputs, "kernel",
            [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim])
        shortcut = projection_shortcut(inputs, kernel)

    net = inputs

    # First conv block (1x1).
    filters1_dim = mtf.Dimension("filters1", filters)
    kernel1 = _make_kernel(
        net, "kernel1",
        [one_h_dim, one_w_dim, net.shape.dims[-1], filters1_dim])
    net = mtf.conv2d_with_blocks(
        net, kernel1, strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)
    # TODO(nikip): Add Dropout?
    net = batch_norm_relu(net, is_training)

    # Second conv block (3x3).
    filters2_dim = mtf.Dimension("filters2", filters)
    kernel2 = _make_kernel(
        net, "kernel2",
        [filter_h_dim, filter_w_dim, filters1_dim, filters2_dim])
    net = mtf.conv2d_with_blocks(
        net, kernel2, strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)
    net = batch_norm_relu(net, is_training)

    # Third conv block (1x1); this is the one that receives `strides`.
    filters3_dim = mtf.Dimension("filters3", filters)
    filters3_kernel = _make_kernel(
        net, "wide_kernel",
        [one_h_dim, one_w_dim, filters2_dim, filters3_dim])
    net = mtf.conv2d_with_blocks(
        net, filters3_kernel, strides, padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)
    net = batch_norm_relu(net, is_training, relu=False)

    # TODO(nikip): Maybe add residual with a projection?
    # Rename the shortcut's channel dimension to the block output's channel
    # dimension before the two tensors are added.
    shortcut_channels = shortcut.shape.dims[-1].name
    output_channels = net.shape.dims[-1].name
    aligned_shortcut = mtf.rename_dimension(
        shortcut, shortcut_channels, output_channels)
    return mtf.relu(net + aligned_shortcut)
def _my_concat(a, b):
    """Concatenate `a` and `b` along a common "triple_beam" dimension.

    The "beam" dimension of `a` and the "double_beam" dimension of `b` are
    both renamed to "triple_beam", and the two tensors are then concatenated
    along it (`a` first).
    """
    merged_name = "triple_beam"
    pieces = [
        mtf.rename_dimension(a, "beam", merged_name),
        mtf.rename_dimension(b, "double_beam", merged_name),
    ]
    return mtf.concat(pieces, merged_name)