def cond_fn(position, ids, *unused_states):
  """Should we run another loop iteration?"""
  past_end = mtf.greater_equal(position, length_dim.size)
  if max_steps:
    past_end = mtf.logical_or(
        past_end, mtf.greater_equal(position - initial_position, max_steps))

  is_done = past_end
  if stop_at_token is not None:
    # A sequence is done once it contains more stop tokens than it was
    # primed with, i.e. once it has generated a new EOS of its own.
    eos_count = mtf.reduce_sum(
        mtf.to_int32(mtf.equal(ids, stop_at_token)),
        reduced_dim=length_dim)
    has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
    is_done = mtf.logical_or(is_done, has_additional_eos)
  all_done = mtf.reduce_all(is_done)
  return mtf.logical_not(all_done)
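

# Illustrative sketch (not from the original module): cond_fn above compares
# eos_count against partial_sequences_eos_count, which the enclosing decoding
# routine is assumed to precompute from the primed partial sequences, roughly
# as below.  `partial_sequences` and `stop_at_token` are placeholders for the
# decoder's inputs.
def example_count_primed_eos(partial_sequences, stop_at_token, length_dim):
  """Count how many stop tokens each partial sequence already contains."""
  return mtf.reduce_sum(
      mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
      reduced_dim=length_dim)

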
def cond_fn(position, ids, *unused_states):
  """Should we run another loop iteration?"""
  past_end = mtf.greater_equal(position, length_dim.size)
  is_done = past_end
  if stop_at_token is not None:
    has_eos = mtf.reduce_any(
        mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
    is_done = mtf.logical_or(is_done, has_eos)
  all_done = mtf.reduce_all(is_done)
  return mtf.logical_not(all_done)
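

# Illustrative sketch (assumption, not part of the original code): a cond_fn
# like the two above is intended to be passed to mtf.while_loop together with
# a body_fn that writes the next sampled token and advances `position`.
# `body_fn`, `initial_position`, and `partial_ids` are hypothetical
# placeholders for the decoder's loop state.
def example_decode_loop(initial_position, partial_ids, body_fn):
  """Run the decoding loop until cond_fn reports every sequence is done."""
  while_loop_inputs = [initial_position, partial_ids]
  final_position, final_ids = mtf.while_loop(
      cond_fn, body_fn, while_loop_inputs)[:2]
  del final_position  # Only the decoded ids are needed by the caller.
  return final_ids

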
def call_simple(self,
                inputs,
                targets,
                compute_loss,
                attributes=None,
                mode=tf.estimator.ModeKeys.TRAIN,
                variable_dtype=mtf.VariableDType(tf.float32),
                sequence_id=None,
                subsequence_id=None,
                position=None,
                encoder_output=None,
                encoder_sequence_id=None,
                encoder_inputs=None,
                shared_params=None,
                layer_outputs=None,
                encoder_layer_outputs=None,
                z=None):
  """Compute logits based on inputs (all positions in parallel).

  This is called during training and evaluation.

  Args:
    inputs: an int32 Tensor with shape [<batch_dims>, length_dim].  For
      training autoregressive models this should be equal to
      mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
    targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
    compute_loss: a boolean
    attributes: an optional int32 Tensor with shape
      [<batch_dims>, length_dim] or [<batch_dims>]
    mode: a tf.estimator.ModeKeys
    variable_dtype: a mtf.VariableDType
    sequence_id: an optional Tensor
    subsequence_id: an optional Tensor
    position: an optional Tensor
    encoder_output: an optional Tensor
    encoder_sequence_id: an optional Tensor
    encoder_inputs: an optional Tensor
    shared_params: an optional dictionary
    layer_outputs: an optional list to append Tensor layer activations to
    encoder_layer_outputs: optional - readonly list of tensor activations when
      decoding, one per each input layer + the embedding layer
    z: an optional Tensor

  Returns:
    logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
    loss: an optional Scalar (if compute_loss=True)
  """
  batch_dims = inputs.shape.dims[:-1]
  length_dim = inputs.shape.dims[-1]
  length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32)
  if not self.positional_embedding:
    # To make relative attention faster, we drop the information about the
    # position in the subsequence.  The relative attention code then
    # assumes that the positions are given by index in the tensor,
    # which still leads to the correct computation of relative position.
    position = None
  if position is None:
    position_is_default = True
    position = length_range
  else:
    position_is_default = False
  if self.input_full_attention:
    # The inputs part of each sequence can fully attend within itself.
    full_attention_region = delimited_lm_inputs_mask(targets)
    # We can include one additional position to the right - the position
    # where the final EOS of the inputs is read and the first target token
    # is predicted.
    full_attention_region = mtf.logical_or(
        full_attention_region,
        mtf.shift(full_attention_region, offset=1, dim=length_dim,
                  wrap=False))
    # We set read_priority and write_priority to 0 in the full-attention
    # region and equal to the position elsewhere.
    read_priority = write_priority = length_range * mtf.cast(
        mtf.logical_not(full_attention_region), tf.int32)
  elif self.autoregressive:
    # Vanilla autoregressive model - each position can see previous positions.
    read_priority = write_priority = length_range
  else:
    read_priority = write_priority = None
  context = Context(
      model=self,
      mesh=inputs.mesh,
      batch_dims=batch_dims,
      length_dim=length_dim,
      variable_dtype=variable_dtype,
      mode=mode,
      losses=[] if compute_loss else None,
      sequence_id=sequence_id,
      subsequence_id=subsequence_id,
      position=position,
      position_is_default=position_is_default,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      shared_params=shared_params,
      layer_outputs=layer_outputs,
      encoder_layer_outputs=encoder_layer_outputs,
      write_priority=write_priority,
      read_priority=read_priority,
      inputs=inputs,
      encoder_inputs=encoder_inputs)
  with tf.variable_scope(self.name):
    logits = self._call_internal(context, inputs, targets, attributes, z=z)
  if compute_loss:
    loss = mtf.add_n(context.losses)
  else:
    loss = None
  return logits, loss
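

# Illustrative sketch (assumption): a typical training-time call to
# call_simple for an autoregressive model, with inputs equal to the targets
# shifted right by one position, as the Args docstring above describes.
# `model`, `targets`, and `variable_dtype` are hypothetical placeholders.
def example_training_call(model, targets, variable_dtype):
  """Compute logits and loss for a decoder-only language model."""
  length_dim = targets.shape.dims[-1]
  shifted_targets = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
  logits, loss = model.call_simple(
      inputs=shifted_targets,
      targets=targets,
      compute_loss=True,
      mode=tf.estimator.ModeKeys.TRAIN,
      variable_dtype=variable_dtype)
  return logits, loss

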
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       attention_kwargs=None):
  """Attention to a neighborhood around the source.

  If autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not autoregressive, then query position p can only see memory positions
  in the range (p - radius, p + radius].

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    attention_kwargs: optional keyword arguments for attention()

  Returns:
    a Tensor with shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
  if attention_kwargs is None:
    # Allow callers to omit attention_kwargs; attention() is then called
    # with no extra keyword arguments.
    attention_kwargs = {}
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128).
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)

  def _reshape_query(x):
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])

  def _reshape_memory(x):
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if autoregressive else mtf.halo_exchange)(
        x, num_blocks, memory_block_length, radius)

  q = _reshape_query(q)
  k = _reshape_memory(k)
  if v:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  illegal = mtf.not_equal(q_sequence_id, m_sequence_id)
  illegal = mtf.logical_or(
      illegal, mtf.less_equal(relative_position, -radius))
  illegal = mtf.logical_or(
      illegal,
      mtf.greater(relative_position, 0 if autoregressive else radius))
  mask = mtf.cast(illegal, q.dtype) * -1e9

  o = attention(q, k, v, padded_memory_block_length, key_dim, value_dim,
                mask, **attention_kwargs)

  return mtf.replace_dimensions(
      o, [num_blocks, query_block_length], length_dim)
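

# Illustrative sketch (restating the selection logic above, not a new API):
# the block length used by local_attention_1d is the largest divisor of
# length_per_split that does not exceed max(radius, 128), so every split of
# the length dimension divides evenly into blocks.
def example_choose_block_length(length_dim_size, length_dim_num_splits,
                                radius):
  """Return the block length local_attention_1d would pick."""
  length_per_split = length_dim_size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  return block_length


# For example, with length 1000, one split, and radius 128, the result is 125
# (the largest divisor of 1000 not exceeding 128); with length 1024 and two
# splits it is 128 exactly.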