def attention_bias_local_block(mesh, block_length, memory_length,
                               dtype=tf.int32):
  """Bias for attention for local blocks where attention to the right is disallowed.

  Create the bias matrix using two separate masks: one for the memory part,
  which doesn't overlap with the query, and a second for the part that
  interacts with the query and must not look to the right of the current
  query position.

  Args:
    mesh: a MeshTensorflow object
    block_length: a mtf.Dimension
    memory_length: a mtf.Dimension
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [block_length, memory_length]
  """
  memory_length = mtf.Dimension(memory_length.name, block_length.size)
  memory_mask = mtf.zeros(mesh, [block_length, memory_length], dtype=dtype)
  mask = mtf.cast(mtf.less(mtf.range(mesh, block_length, dtype=dtype),
                           mtf.range(mesh, memory_length, dtype=dtype)),
                  dtype=dtype)
  mask = mtf.cast(
      mtf.concat([memory_mask, mask], memory_length.name),
      dtype=tf.float32) * -1e9
  return mask
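
# Illustrative usage sketch (an assumption, not part of the original module):
# it builds a small bias with made-up dimension names and sizes, using the
# standard mtf Graph/Mesh/Dimension constructors already imported in this file.
def _example_attention_bias_local_block():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  block_length = mtf.Dimension("block_length", 4)
  memory_length = mtf.Dimension("memory_length", 4)
  # The result has shape [block_length=4, memory_length=8]: zeros over the
  # first block_length memory positions (the memory block, always visible)
  # and -1e9 wherever a query position would look to its right within the
  # current block.
  return attention_bias_local_block(mesh, block_length, memory_length)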
def _sample(self, features, mesh):
  hparams = self._hparams
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if self.has_input:
    inputs = features["inputs"]
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    # Pad inputs up to the static [batch_size, max_length] shape expected
    # when importing into the mesh.
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(
        inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf_layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.num_encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
    # Precompute the encoder-decoder attention keys and values for every
    # decoder layer, since the encoder output is fixed during decoding.
    encdec_tensors = []
    for layer_num in xrange(hparams.num_decoder_layers):
      with tf.variable_scope("decoder/layer_%d/encdec_attention" % layer_num):
        q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
            mesh, self.heads_dim, self.model_dim,
            self.kv_dim, self.activation_dtype)
        k = mtf.einsum(
            [encoder_output, k_var],
            mtf.Shape(
                [self.batch_dim, self.heads_dim,
                 self.memory_length_dim, self.kv_dim]))
        v = mtf.einsum(
            [encoder_output, v_var],
            mtf.Shape(
                [self.batch_dim, self.heads_dim,
                 self.memory_length_dim, self.kv_dim]))
      encdec_tensors.append((q_var, o_var, k, v))
    partial_targets = None
  else:
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
partial_targets = features.get("inputs", None) if partial_targets is None: partial_targets = features.get("targets", None) if partial_targets is not None: partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2) partial_targets = tf.to_int32(partial_targets) partial_targets_batch = tf.shape(partial_targets)[0] partial_targets_length = tf.shape(partial_targets)[1] partial_targets = tf.pad( partial_targets, [[0, hparams.batch_size - partial_targets_batch], [0, hparams.max_length - partial_targets_length]]) partial_targets = self._import_to_batch_by_length( partial_targets, "partial_targets", mesh, hparams) if hparams.beam_size == 1: ids_shape = mtf.Shape([self.batch_dim, self.length_dim]) kv_shape = mtf.Shape([self.batch_dim, self.heads_dim, self.memory_length_dim, self.kv_dim]) else: beam_dim = mtf.Dimension("beam", hparams.beam_size) ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim]) kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim]) initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32) initial_kv_states = ( [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * (2 * hparams.num_decoder_layers)) def logits_fn(step_num, ids, states): """Produce logits for this step, and new states.""" self_attention_k = states[:hparams.num_decoder_layers] self_attention_v = states[hparams.num_decoder_layers:] ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim) x = (mtf.gather(targets_embedding_var, ids_this_step, self.targets_vocab_dim) + mtf.gather(positional_embedding_var, step_num, self.max_length_dim)) with tf.variable_scope("decoder"): x, new_self_attention_k, new_self_attention_v = ( self._decoder_layer_stack_incremental( x, step_num, encdec_tensors, self_attention_k, self_attention_v, encdec_attention_mask=encoder_attention_mask)) logits = mtf.matmul(x, softmax_var) return logits, new_self_attention_k + new_self_attention_v if hparams.beam_size == 1: temperature = (0.0 if hparams.sampling_method == "argmax" else hparams.sampling_temp) return mtf_beam_search.greedy_decode( logits_fn, initial_ids, temperature=temperature, initial_states=initial_kv_states, forced_ids=partial_targets, use_tpu=hparams.use_tpu) else: if self.has_input: input_length = mtf.reduce_sum( mtf.to_float(mtf.cast(inputs, tf.bool)), reduced_dim=self.length_dim) max_input_length = mtf.reduce_max(input_length) decode_length = mtf.cast( max_input_length * hparams.decode_length_multiplier + hparams.decode_length_constant, tf.int32) else: decode_length = None beams, unused_scores = mtf_beam_search.beam_search( logits_fn, initial_ids, hparams.alpha, states=initial_kv_states, decode_length=decode_length, use_tpu=hparams.use_tpu) return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)