def call(self, context, x, losses=None):
  """Run encoder-decoder attention for this layer.

  Attends from the query antecedent x to the memory antecedent supplied by
  self._get_memory_antecedent(context).

  Args:
    context: a transformer.Context
    x: a Tensor - the query antecedent
    losses: an optional list for auxiliary losses (unused here)

  Returns:
    a Tensor
  """
  memory = self._get_memory_antecedent(context)
  if memory.shape[-1] != context.model.model_dim:
    raise NotImplementedError(
        "TODO(noam): support different model_dim in encoder and decoder.")
  q = self.compute_q(context, x)
  if context.mode == "incremental":
    # Encoder keys/values do not change between decoding steps; reuse the
    # constant state recorded during "first_part".
    k, v, memory_length = context.get_constant_state()
  else:
    k = self.compute_k(context, memory)
    v = self.compute_v(context, memory)
    [memory_length] = [
        d for d in memory.shape.dims if d.name == "memory_length"]
    if context.mode == "first_part":
      context.record_constant_state((k, v, memory_length))
  bias = None
  if context.encoder_sequence_id and context.sequence_id:
    # With packed examples, only attend within the matching sequence.
    visible = mtf.equal(context.sequence_id, context.encoder_sequence_id)
    bias = attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype)
  return self.attention_internal(
      context, x, memory, q, k, v, memory_length, bias)
def enc_dec_attention(self_attention_layer, memory_antecedent, context, x,
                      losses):
  """Multi-head attention over the encoder outputs.

  Borrows parameters and configuration from `self_attention_layer` to attend
  from the decoder activations x to memory_antecedent.

  Args:
    self_attention_layer: the attention layer supplying params and config
    memory_antecedent: a Tensor with a "memory_length" dimension
    context: a transformer.Context
    x: a Tensor - the query antecedent
    losses: an optional list for auxiliary losses

  Returns:
    a Tensor
  """
  layer = self_attention_layer
  if memory_antecedent.shape[-1] != context.model.model_dim:
    raise NotImplementedError(
        "TODO(noam): support different model_dim in encoder and decoder.")
  params = layer.make_params(context)
  q = params.compute_q(x)
  if context.mode == "incremental":
    # Encoder keys/values are constant across decoding steps; read the cache.
    k, v, memory_length = context.get_constant_state()
  else:
    enc = memory_antecedent
    if layer.shared_kv:
      # A single projection serves as both keys and values.
      k = v = params.compute_kv(enc)
    else:
      k = params.compute_k(enc)
      v = params.compute_v(enc)
    [memory_length] = [d for d in enc.shape.dims if d.name == "memory_length"]
    if context.mode == "first_part":
      context.record_constant_state((k, v, memory_length))
  bias = None
  if context.encoder_sequence_id and context.sequence_id:
    # With packed examples, only attend within the matching sequence.
    visible = mtf.equal(context.sequence_id, context.encoder_sequence_id)
    bias = attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype)
  a = attention.attention(
      q, k, v, memory_length, layer.kv_dim, layer.kv_dim, bias,
      **layer.attention_kwargs_from_context(context))
  out_shape = layer.expected_attention_output_shape(x, params)
  attention_output = params.compute_output(a, output_shape=out_shape)
  return layer.layer_output_from_attention_output(
      context, attention_output, losses)
def call(self, context, x, losses=None):
  """Run attention from x over a separately-computed memory antecedent.

  Args:
    context: a transformer.Context
    x: a Tensor - the query antecedent
    losses: an optional list for auxiliary losses (unused here)

  Returns:
    a Tensor with the same shape as x
  """
  memory = self._get_memory_antecedent(context)
  if memory.shape[-1] != context.model.model_dim:
    raise NotImplementedError(
        "TODO(noam): support different model_dim in encoder and decoder."
    )
  params = self.make_params(context)
  q = params.compute_q(x)
  if context.mode == "incremental":
    # Memory keys/values are constant across decoding steps; read the cache.
    k, v, memory_length = context.get_constant_state()
  else:
    if self.shared_kv:
      # A single projection serves as both keys and values.
      k = v = params.compute_kv(memory)
    else:
      k = params.compute_k(memory)
      v = params.compute_v(memory)
    [memory_length] = [
        d for d in memory.shape.dims if d.name == "memory_length"]
    if context.mode == "first_part":
      context.record_constant_state((k, v, memory_length))
  bias = None
  if context.encoder_sequence_id and context.sequence_id:
    # With packed examples, only attend within the matching sequence.
    visible = mtf.equal(context.sequence_id, context.encoder_sequence_id)
    bias = attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype)
  o = attention.attention(
      q, k, v, memory_length, self.kv_dim, self.kv_dim, bias,
      **self.attention_kwargs_from_context(context))
  return params.compute_output(o, output_shape=x.shape)
def call(self, context, x, losses=None):
  """Call the layer.

  Local self-attention: each query attends to a window of up to
  `self.radius` preceding positions.  Three execution paths:

  * "incremental" mode: single-step decoding against a circular buffer of
    the most recent `radius` keys/values held in layer state.
  * short sequences: plain full attention with a locality bias mask.
  * long sequences: the blocked local-attention algorithm.

  Args:
    context: a transformer.Context
    x: a Tensor - the query/memory antecedent
    losses: an optional list for auxiliary losses (unused here)

  Returns:
    a Tensor with the same shape as x
  """
  params = self.make_params(context)
  q = params.compute_q(x)
  if self.shared_kv:
    # A single projection serves as both keys and values.
    kv = params.compute_kv(x)
    k = kv
    v = kv
  else:
    k = params.compute_k(x)
    v = params.compute_v(x)
  if context.mode == "incremental":
    # State tensors form a circular buffer over window_dim, indexed by
    # position mod radius.
    if self.shared_kv:
      prev_kv, = context.get_states(1)
    else:
      prev_k, prev_v = context.get_states(2)
    # One-hot over window_dim marking the buffer slot for this position.
    current_position = mtf.equal(
        mtf.range(context.mesh, self.window_dim, dtype=tf.int32),
        mtf.mod(context.position, self.radius))
    if self.shared_kv:
      # Overwrite only the current slot of the circular buffer.
      kv = mtf.where(current_position, kv, prev_kv,
                     output_shape=prev_kv.shape)
      k = kv
      v = kv
      context.record_new_states([kv])
    else:
      k = mtf.where(current_position, params.compute_k(x), prev_k,
                    output_shape=prev_k.shape)
      v = mtf.where(current_position, params.compute_v(x), prev_v,
                    output_shape=prev_v.shape)
      context.record_new_states([k, v])
    # Hide buffer slots beyond the current position (presumably this only
    # matters while position < radius, before every slot has been written
    # - TODO confirm).
    window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
    visible = mtf.greater_equal(context.position, window_pos)
    bias = attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype)
    o = attention.attention(
        q, k, v, self.window_dim, self.kv_dim, self.kv_dim, bias,
        **self.attention_kwargs_from_context(context))
  elif context.length_dim.size <= max(256, self.radius * 4):
    # nothing fancy - just do full attention and mask
    memory_length = self.rename_length_to_memory_length(
        context.position, context)
    o = attention.attention(
        q,
        self.rename_length_to_memory_length(k, context),
        self.rename_length_to_memory_length(v, context),
        self.memory_length(context),
        self.kv_dim, self.kv_dim,
        self.compute_bias(context, memory_length, x),
        **self.attention_kwargs_from_context(context))
  else:
    # fancy local attention algorithm
    # NOTE(review): the keyword here is `autoregressive`, while the
    # local_attention_1d defined in this file takes `fully_autoregressive`
    # - confirm which signature attention.local_attention_1d exposes.
    o = attention.local_attention_1d(
        q=q,
        k=k,
        v=None if self.shared_kv else v,
        length_dim=context.length_dim,
        key_dim=self.kv_dim,
        value_dim=self.kv_dim,
        length_dim_num_splits=1,  # TODO(noam): look at the layout
        autoregressive=context.model.fully_autoregressive,
        radius=self.radius,
        sequence_id=context.sequence_id,
        write_priority=context.write_priority,
        read_priority=context.read_priority,
        attention_kwargs=self.attention_kwargs_from_context(context))
  if context.mode == "first_part":
    # Initialize the circular k/v buffers for incremental decoding by
    # scattering the last `radius` prefix positions into window slots.
    window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
    pos = mtf.range(context.mesh, context.length_dim, tf.int32)
    # 1.0 where pos hashes to this window slot...
    select_recent = mtf.cast(
        mtf.equal(mtf.mod(pos, self.radius), window_pos), x.dtype)
    # ...and pos lies in [initial_position - radius, initial_position).
    select_recent *= mtf.cast(
        mtf.less(pos, context.initial_position), x.dtype)
    select_recent *= mtf.cast(
        mtf.greater_equal(
            pos, context.initial_position - self.radius), x.dtype)
    state_shape = (k.shape - [context.length_dim, self.kv_dim]
                   + [self.window_dim, self.kv_dim])
    # einsum against the 0/1 selector gathers the chosen positions.
    k_state = mtf.einsum(
        [k, select_recent], output_shape=state_shape,
        reduced_dims=[context.length_dim])
    context.new_states.append(k_state)
    if not self.shared_kv:
      v_state = mtf.einsum(
          [v, select_recent], output_shape=state_shape,
          reduced_dims=[context.length_dim])
      context.new_states.append(v_state)
  return params.compute_output(o, output_shape=x.shape)
def compute_bias(self, context, memory_position, x):
  """Compute attention bias.

  Accumulates several additive bias terms: visibility masks from the
  relative-position window, read/write priorities, and (sub)sequence ids,
  plus a learned relative-position bias when configured.

  Args:
    context: a transformer.Context
    memory_position: an int32 tensor containing memory_length dimension.
    x: a Tensor - the query antecedent - required for relative attention

  Returns:
    a Tensor or None
  """
  min_relative_position = self.min_relative_position(context)
  max_relative_position = self.max_relative_position(context)
  # we can often cache the result of this function between similar layers
  can_cache = (
      self.relative_attention_type is None or
      self.relative_attention_type == "bias_shared")
  if can_cache:
    cache_key = ("self_attention_mask",
                 min_relative_position,
                 max_relative_position,
                 self.relative_attention_type,
                 self.num_heads)
    if cache_key in context.cache:
      return context.cache[cache_key]
  biases = []
  # Signed offset of each memory position from the current query position.
  relative_position = memory_position - context.position
  if min_relative_position is not None:
    visible = mtf.greater_equal(relative_position, min_relative_position)
    biases.append(attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype))
  if max_relative_position is not None:
    visible = mtf.less_equal(relative_position, max_relative_position)
    biases.append(attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype))
  if context.read_priority is not None:
    # A query may only read memory positions whose write_priority does not
    # exceed its read_priority.
    visible = mtf.greater_equal(
        context.read_priority,
        mtf.layers.rename_length_to_memory_length(context.write_priority))
    biases.append(attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype))

  sequence_id = None
  # Subsequence id should only be set if we are in the decoder and have
  # multiple targets per input. This will allow each sub-target to only attend
  # to itself.
  if isinstance(context.subsequence_id, mtf.Tensor):
    sequence_id = context.subsequence_id
  elif isinstance(context.sequence_id, mtf.Tensor):
    sequence_id = context.sequence_id
  if (sequence_id is not None and
      context.length_dim in sequence_id.shape):
    # With packed examples, only attend within the matching (sub)sequence.
    visible = mtf.equal(
        sequence_id,
        self.rename_length_to_memory_length(sequence_id, context))
    biases.append(attention.visibility_mask_to_attention_bias(
        visible, context.activation_dtype))
  if self.relative_attention_type is not None:
    buckets_dim = mtf.Dimension(
        "buckets", self.relative_attention_num_buckets)
    heads_dim = mtf.Dimension("heads", self.num_heads)
    bidirectional = not context.model.fully_autoregressive
    # Map each relative offset to a (log-spaced) bucket index.
    rp_bucket = _relative_position_bucket(
        relative_position,
        bidirectional=bidirectional,
        num_buckets=buckets_dim.size)
    if (self.relative_attention_type == "bias" or
        self.relative_attention_type == "bias_shared"):
      # A learned per-head bias per bucket ("bias_shared" reuses it across
      # layers via the cache above).
      values = mtf.get_variable(
          context.mesh, "relative_attention_bias",
          [heads_dim, buckets_dim],
          dtype=context.variable_dtype)
    elif self.relative_attention_type == "contextual":
      # Bias computed from the query antecedent itself.
      values = layers.dense(
          x, [buckets_dim, heads_dim],
          variable_dtype=context.variable_dtype,
          name="relative_attention_contextual")
    else:
      raise ValueError(
          "unrecognized relative_attention_type \"%s\"" %
          self.relative_attention_type)
    biases.append(mtf.gather(values, rp_bucket, buckets_dim))
  ret = mtf.add_n(biases) if biases else None
  if can_cache:
    context.cache[cache_key] = ret
  return ret
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None,
                       context=None):
  """Attention to a neighborhood around the source position.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If none then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()
    context: optional context.

  Returns:
    a Tensor with the shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128)
  # NOTE(review): logging the raw kwargs looks like leftover debugging.
  tf.logging.info(attention_kwargs)
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)
  def _reshape_query(x):
    # [..., length] -> [..., num_blocks, query_block_length]
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])
  def _reshape_memory(x):
    # [..., length] -> [..., num_blocks, memory_block_length], then extend
    # each block with `radius` positions from neighboring block(s) so a
    # query can see across its block boundary.
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if fully_autoregressive
            else mtf.halo_exchange)(
                x, num_blocks, memory_block_length, radius)
  q = _reshape_query(q)
  k = _reshape_memory(k)
  if v:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    # Broadcast a scalar/constant sequence_id over length_dim.
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)
  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)
  relative_position = m_pos - q_pos
  # Visible iff same sequence and within the local window.
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible,
        mtf.greater_equal(read_priority, write_priority))
  bias = attention.visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention.attention(q, k, v, padded_memory_block_length,
                          key_dim, value_dim, bias, context=context,
                          **attention_kwargs)
  # Collapse the blocked dims back into the original length dimension.
  return mtf.replace_dimensions(o, [num_blocks, query_block_length],
                                length_dim)