def shift_targets_no_offset(targets, bos_id=0, eos_id=1):
  """Transforms decoder labels to decoder inputs without shifting positions.

  Every position whose label equals `eos_id` is replaced by 0. If `bos_id` is
  nonzero, the positions that became 0 through this masking (i.e. that are 0
  now but were nonzero in `targets`) are then set to `bos_id`. Positions that
  were already 0 in `targets` stay 0.

  Args:
    targets: decoder labels
    bos_id: begin of sequence id, defaults to 0
    eos_id: end of sequence id, defaults to 1

  Returns:
    Decoder inputs.
  """
  shifted_targets = targets
  # Zero out every EOS token. Unlike a shifting variant, no positional offset
  # is applied here: the masking happens in place on the label positions.
  shifted_targets *= mtf.to_int32(mtf.not_equal(shifted_targets, eos_id))

  if bos_id:
    # Only positions zeroed by the masking above receive bos_id; positions
    # that were 0 in the original targets are excluded by the second clause.
    shifted_targets += mtf.to_int32(
        mtf.logical_and(
            mtf.equal(shifted_targets, 0),
            mtf.not_equal(targets, 0))) * bos_id

  return shifted_targets
def get_cluster_mask(self, targets):
  """Returns a mask that is true only for targets inside this cluster.

  A target belongs to the cluster when its id lies in the half-open range
  [self._start_token_id, self._end_token_id).
  """
  at_or_after_start = mtf.greater_equal(targets, self._start_token_id)
  before_end = mtf.less(targets, self._end_token_id)
  return mtf.logical_and(at_or_after_start, before_end)
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None):
  """Attention to a neighborhood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer.  None is treated as 1.
    write_priority: an optional Tensor containing length_dim.  Only used when
      read_priority is also given.
    read_priority: an optional Tensor containing length_dim.  Requires
      write_priority.
    attention_kwargs: optional dict of keyword arguments for attention().
      None is treated as no extra arguments.

  Returns:
    The output of attention() with the block dimensions merged back into
    length_dim.  Per the original documentation this has the shape of q with
    key_dim replaced by value_dim (determined by attention(), which is not
    visible here).

  Raises:
    ValueError: if read_priority is given without write_priority.  The
      original documentation also lists mismatched channels or depth; that
      check is not in this function and is presumably raised by a callee.
  """
  # Fail early, before any ops are created, instead of failing obscurely while
  # trying to reshape a missing write_priority further down.
  if read_priority is not None and write_priority is None:
    raise ValueError(
        "write_priority must be provided when read_priority is provided")
  # The default of None cannot be **-unpacked; treat it as "no extra kwargs".
  if attention_kwargs is None:
    attention_kwargs = {}

  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128).
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)

  def _reshape_query(x):
    # Split length_dim into [num_blocks, query_block_length].
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])

  def _reshape_memory(x):
    # Split length_dim into blocks, then run the halo exchange with `radius`
    # so each block also carries neighbouring positions (left-only variant
    # when fully autoregressive).
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    exchange = (
        mtf.left_halo_exchange if fully_autoregressive else mtf.halo_exchange)
    return exchange(x, num_blocks, memory_block_length, radius)

  q = _reshape_query(q)
  k = _reshape_memory(k)
  # Explicit None check: the contract is "v omitted -> reuse k", which should
  # not depend on how a Tensor evaluates in a boolean context.
  if v is not None:
    v = _reshape_memory(v)
  else:
    v = k

  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    # Give sequence_id a length_dim so it can be reshaped into blocks below.
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)

  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  # Size of the memory block dimension after the halo exchange: one extra
  # `radius` when fully autoregressive, two otherwise.
  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  # A memory position is visible when it belongs to the same sequence and
  # lies inside the allowed window around the query position.
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible, mtf.greater_equal(read_priority, write_priority))

  bias = visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention(q, k, v, padded_memory_block_length,
                key_dim, value_dim, bias, **attention_kwargs)
  # Merge the block dimensions back into the original length dimension.
  return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)