def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
    """Forward pass of alpha-entmax, computed by bisection over the threshold tau."""
    assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'

    _gp = lambda x, alpha: x ** (alpha - 1)
    _gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))
    _p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)

    dim = x.shape[-1] if dim is None else dim
    d = dim.size

    x = x * (alpha - 1)

    max_val = mtf.reduce_max(x, reduced_dim=dim)
    # Bracket the threshold: the true tau lies between tau_lo and tau_hi.
    tau_lo = max_val - _gp(1, alpha)
    tau_hi = max_val - _gp(1 / d, alpha)

    f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1

    dm = tau_hi - tau_lo
    for _ in range(n_iter):
        # Halve the bracket; move the lower endpoint up to the midpoint
        # whenever f at the midpoint has the same sign as f_lo.
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = _p(x - tau_m, alpha)
        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1
        mask = mtf.greater_equal((f_m * f_lo), 0)
        tau_lo = mtf.where(mask, tau_m, tau_lo)

    # Renormalize so the output sums to exactly 1 despite the finite bisection.
    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
    return p_m

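# A minimal NumPy sketch of the same tau-bisection, for reference outside of
# Mesh TensorFlow. This is an illustrative re-derivation, not part of the
# codebase above; `entmax_numpy` is a hypothetical name.
import numpy as np

def entmax_numpy(x, alpha=1.3, n_iter=50):
    x = x * (alpha - 1)
    max_val = x.max(axis=-1, keepdims=True)
    tau_lo = max_val - 1.0                                  # _gp(1, alpha) == 1
    tau_hi = max_val - (1.0 / x.shape[-1]) ** (alpha - 1)   # _gp(1/d, alpha)
    p = lambda tau: np.clip(x - tau, 0.0, None) ** (1.0 / (alpha - 1))
    f_lo = p(tau_lo).sum(axis=-1, keepdims=True) - 1
    dm = tau_hi - tau_lo
    for _ in range(n_iter):
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = p(tau_m)
        f_m = p_m.sum(axis=-1, keepdims=True) - 1
        # Move the lower bracket up whenever f_m has the same sign as f_lo.
        tau_lo = np.where(f_m * f_lo >= 0, tau_m, tau_lo)
    return p_m / p_m.sum(axis=-1, keepdims=True)

# e.g. entmax_numpy(np.array([2.0, 1.0, -1.0])) returns a sparse distribution
# that sums to 1, unlike softmax, which assigns nonzero mass everywhere.
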
def cond_fn(position, ids, *unused_states):
    """Should we run another loop iteration?"""
    past_end = mtf.greater_equal(position, length_dim.size)
    if max_steps:
        past_end = mtf.logical_or(
            past_end, mtf.greater_equal(position - initial_position, max_steps))

    is_done = past_end
    if stop_at_token is not None:
        eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(ids, stop_at_token)),
            reduced_dim=length_dim)
        has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
        is_done = mtf.logical_or(is_done, has_additional_eos)
    all_done = mtf.reduce_all(is_done)
    return mtf.logical_not(all_done)

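# Hypothetical NumPy illustration of the extra-EOS stopping rule above (the
# ids, EOS id, and prompt EOS counts below are made up): a sequence is done
# once it contains more EOS tokens than its prompt already had.
import numpy as np
ids = np.array([[5, 1, 7, 1],
                [3, 4, 1, 0]])                  # EOS id == 1
partial_sequences_eos_count = np.array([1, 0])  # EOS already in the prompts
eos_count = (ids == 1).sum(axis=-1)             # [2, 1]
has_additional_eos = eos_count > partial_sequences_eos_count  # [True, True]
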
def ids_to_embedding(self, ids):
    """Ids to embeddings with ids not in cluster mapped to the zero vector."""
    ids -= self._start_token_id
    # The mtf.gather in the embedding's ids_to_embedding implementation will
    # cause the one-hot representations of tokens greater than the cluster
    # vocab dimension size to be the zero vector. Thus the embeddings for
    # those tokens will be the zero vector.
    ids = mtf.where(mtf.greater_equal(ids, 0), ids, self._end_token_id)
    return self._embedding.ids_to_embedding(ids)

def cond_fn(position, ids, *unused_states):
    """Should we run another loop iteration?"""
    past_end = mtf.greater_equal(position, length_dim.size)
    is_done = past_end
    if stop_at_token is not None:
        has_eos = mtf.reduce_any(
            mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
        is_done = mtf.logical_or(is_done, has_eos)
    all_done = mtf.reduce_all(is_done)
    return mtf.logical_not(all_done)

def ids_to_embedding(self, ids, context):
    """Ids to embeddings with ids not in cluster mapped to the zero vector."""
    ids -= self._start_token_id
    # The mtf.gather in the embedding's ids_to_embedding implementation will
    # cause the one-hot representations of tokens greater than the cluster
    # vocab dimension size to be the zero vector. Thus the embeddings for
    # those tokens will be the zero vector.
    ids = mtf.where(mtf.greater_equal(ids, 0), ids, self._vocab_dim.size)
    # Handle the case of the head cluster where we will have entries at the
    # end corresponding to the tail clusters.
    ids = mtf.where(
        mtf.less(ids, self._end_token_id - self._start_token_id),
        ids,
        self._vocab_dim.size,
    )
    return self._embedding.ids_to_embedding(ids, context)

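# Hypothetical NumPy walk-through of the id remapping above, assuming a head
# cluster with _start_token_id=100, _end_token_id=104 and _vocab_dim.size=6
# (4 in-cluster tokens plus 2 trailing entries for tail clusters; all numbers
# here are made up). Ids outside the cluster end up at index 6, which is out
# of range for the table, so mtf.gather embeds them as the zero vector.
import numpy as np
ids = np.array([99, 100, 103, 105])
shifted = ids - 100                                  # [-1, 0, 3, 5]
shifted = np.where(shifted >= 0, shifted, 6)         # below the cluster -> 6
shifted = np.where(shifted < 104 - 100, shifted, 6)  # tail-cluster entries -> 6
# shifted == [6, 0, 3, 6]; only indices 0 and 3 pick up real embedding rows.
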
def call(self, context, x, losses=None):
    """Call the layer."""
    params = mtf.layers.multihead_attention_params(
        context.mesh, self.heads_dim, context.model_dim,
        self.kv_dim, context.variable_dtype)
    if context.mode == "incremental":
        prev_k, prev_v = context.get_states(2)
        y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
            x, prev_k, prev_v, context.position, params=params)
        context.record_new_states([new_k, new_v])
        return y
    else:
        kv = []
        y = mtf.layers.masked_local_attention_1d(
            x, self.kv_dim, self.heads_dim, self.window_size,
            params=params, return_kv=kv)
        if context.mode == "first_part":
            k = kv[0]
            v = kv[1]
            window_dim = mtf.Dimension("window", self.window_size)
            mesh = k.mesh
            window_pos = mtf.range(mesh, window_dim, tf.int32)
            pos = mtf.range(mesh, context.length_dim, tf.int32)
            select_recent = mtf.cast(
                mtf.equal(window_pos, mtf.mod(pos, self.window_size)), k.dtype)
            select_recent *= mtf.cast(
                mtf.less(pos, context.initial_position), k.dtype)
            select_recent *= mtf.cast(
                mtf.greater_equal(
                    pos, context.initial_position - self.window_size), k.dtype)
            state_shape = k.shape.dims[:-2] + [window_dim, self.kv_dim]
            k_state = mtf.einsum(
                [k, select_recent], output_shape=state_shape,
                reduced_dims=[context.length_dim])
            v_state = mtf.einsum(
                [v, select_recent], output_shape=state_shape,
                reduced_dims=[context.length_dim])
            context.new_states.extend([k_state, v_state])
        return y

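# Hypothetical NumPy illustration of the select_recent matrix built in
# "first_part" mode above (window_size=4, initial_position=6, length=8; all
# numbers made up): the last window_size positions before initial_position
# are kept, and each one is routed to cache slot pos % window_size.
import numpy as np
length, window_size, initial_position = 8, 4, 6
pos = np.arange(length)[:, None]
window_pos = np.arange(window_size)[None, :]
select_recent = ((window_pos == pos % window_size)
                 & (pos < initial_position)
                 & (pos >= initial_position - window_size)).astype(np.float32)
# Rows 2..5 each contain a single 1, at columns 2, 3, 0, 1 respectively; an
# einsum of k with this matrix over the length dim fills the rolling cache.
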
def get_cluster_mask(self, targets):
    """Computes mask over the targets masking out tokens not in the cluster."""
    return mtf.logical_and(
        mtf.greater_equal(targets, self._start_token_id),
        mtf.less(targets, self._end_token_id))

def call(self, context, x, losses=None):
    """Call the layer."""
    params = self.make_params(context)
    q = params.compute_q(x)
    if self.shared_kv:
        kv = params.compute_kv(x)
        k = kv
        v = kv
    else:
        k = params.compute_k(x)
        v = params.compute_v(x)
    if context.mode == "incremental":
        if self.shared_kv:
            prev_kv, = context.get_states(1)
        else:
            prev_k, prev_v = context.get_states(2)
        current_position = mtf.equal(
            mtf.range(context.mesh, self.window_dim, dtype=tf.int32),
            mtf.mod(context.position, self.radius))
        if self.shared_kv:
            kv = mtf.where(current_position, kv, prev_kv,
                           output_shape=prev_kv.shape)
            k = kv
            v = kv
            context.record_new_states([kv])
        else:
            k = mtf.where(current_position, params.compute_k(x), prev_k,
                          output_shape=prev_k.shape)
            v = mtf.where(current_position, params.compute_v(x), prev_v,
                          output_shape=prev_v.shape)
            context.record_new_states([k, v])
        window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
        visible = mtf.greater_equal(context.position, window_pos)
        bias = attention.visibility_mask_to_attention_bias(
            visible, context.activation_dtype)
        o = attention.attention(
            q, k, v, self.window_dim, self.kv_dim, self.kv_dim, bias,
            **self.attention_kwargs_from_context(context))
    elif context.length_dim.size <= max(256, self.radius * 4):
        # nothing fancy - just do full attention and mask
        memory_length = self.rename_length_to_memory_length(
            context.position, context)
        o = attention.attention(
            q,
            self.rename_length_to_memory_length(k, context),
            self.rename_length_to_memory_length(v, context),
            self.memory_length(context), self.kv_dim, self.kv_dim,
            self.compute_bias(context, memory_length, x),
            **self.attention_kwargs_from_context(context))
    else:
        # fancy local attention algorithm
        o = attention.local_attention_1d(
            q=q,
            k=k,
            v=None if self.shared_kv else v,
            length_dim=context.length_dim,
            key_dim=self.kv_dim,
            value_dim=self.kv_dim,
            length_dim_num_splits=1,  # TODO(noam): look at the layout
            fully_autoregressive=context.model.fully_autoregressive,
            radius=self.radius,
            sequence_id=context.sequence_id,
            write_priority=context.write_priority,
            read_priority=context.read_priority,
            attention_kwargs=self.attention_kwargs_from_context(context))
    if context.mode == "first_part":
        window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
        pos = mtf.range(context.mesh, context.length_dim, tf.int32)
        select_recent = mtf.cast(
            mtf.equal(mtf.mod(pos, self.radius), window_pos), x.dtype)
        select_recent *= mtf.cast(
            mtf.less(pos, context.initial_position), x.dtype)
        select_recent *= mtf.cast(
            mtf.greater_equal(pos, context.initial_position - self.radius),
            x.dtype)
        state_shape = (k.shape - [context.length_dim, self.kv_dim]
                       + [self.window_dim, self.kv_dim])
        k_state = mtf.einsum(
            [k, select_recent], output_shape=state_shape,
            reduced_dims=[context.length_dim])
        context.new_states.append(k_state)
        if not self.shared_kv:
            v_state = mtf.einsum(
                [v, select_recent], output_shape=state_shape,
                reduced_dims=[context.length_dim])
            context.new_states.append(v_state)
    return params.compute_output(o, output_shape=x.shape)

def compute_bias(self, context, memory_position, x):
    """Compute attention bias.

    Args:
      context: a transformer.Context
      memory_position: an int32 tensor containing memory_length dimension.
      x: a Tensor - the query antecedent - required for relative attention

    Returns:
      a Tensor or None
    """
    min_relative_position = self.min_relative_position(context)
    max_relative_position = self.max_relative_position(context)
    # We can often cache the result of this function between similar layers.
    can_cache = (
        self.relative_attention_type is None or
        self.relative_attention_type == "bias_shared")
    if can_cache:
        cache_key = ("self_attention_mask",
                     min_relative_position,
                     max_relative_position,
                     self.relative_attention_type,
                     self.num_heads)
        if cache_key in context.cache:
            return context.cache[cache_key]
    biases = []
    relative_position = memory_position - context.position
    if min_relative_position is not None:
        visible = mtf.greater_equal(relative_position, min_relative_position)
        biases.append(attention.visibility_mask_to_attention_bias(
            visible, context.activation_dtype))
    if max_relative_position is not None:
        visible = mtf.less_equal(relative_position, max_relative_position)
        biases.append(attention.visibility_mask_to_attention_bias(
            visible, context.activation_dtype))
    if context.read_priority is not None:
        visible = mtf.greater_equal(
            context.read_priority,
            mtf.layers.rename_length_to_memory_length(context.write_priority))
        biases.append(attention.visibility_mask_to_attention_bias(
            visible, context.activation_dtype))

    sequence_id = None
    # Subsequence id should only be set if we are in the decoder and have
    # multiple targets per input. This will allow each sub-target to only
    # attend to itself.
    if isinstance(context.subsequence_id, mtf.Tensor):
        sequence_id = context.subsequence_id
    elif isinstance(context.sequence_id, mtf.Tensor):
        sequence_id = context.sequence_id
    if (sequence_id is not None and
            context.length_dim in sequence_id.shape):
        visible = mtf.equal(
            sequence_id,
            self.rename_length_to_memory_length(sequence_id, context))
        biases.append(attention.visibility_mask_to_attention_bias(
            visible, context.activation_dtype))
    if self.relative_attention_type is not None:
        buckets_dim = mtf.Dimension(
            "buckets", self.relative_attention_num_buckets)
        heads_dim = mtf.Dimension("heads", self.num_heads)
        bidirectional = not context.model.fully_autoregressive
        rp_bucket = _relative_position_bucket(
            relative_position,
            bidirectional=bidirectional,
            num_buckets=buckets_dim.size)
        if (self.relative_attention_type == "bias" or
                self.relative_attention_type == "bias_shared"):
            values = mtf.get_variable(
                context.mesh, "relative_attention_bias",
                [heads_dim, buckets_dim], dtype=context.variable_dtype)
        elif self.relative_attention_type == "contextual":
            values = layers.dense(
                x, [buckets_dim, heads_dim],
                variable_dtype=context.variable_dtype,
                name="relative_attention_contextual")
        else:
            raise ValueError(
                "unrecognized relative_attention_type \"%s\"" %
                self.relative_attention_type)
        biases.append(mtf.gather(values, rp_bucket, buckets_dim))
    ret = mtf.add_n(biases) if biases else None
    if can_cache:
        context.cache[cache_key] = ret
    return ret

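# compute_bias combines several boolean visibility masks into one additive
# bias term. A minimal NumPy sketch of that idea (not mtf's implementation,
# just the concept): positions that are not visible get a large negative bias
# so the softmax sends their attention weight to ~0.
import numpy as np

def visibility_to_bias(visible):
    return np.where(visible, 0.0, -1e9)

relative_position = np.arange(4)[None, :] - np.arange(4)[:, None]
causal_bias = visibility_to_bias(relative_position <= 0)  # future positions masked
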
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None):
    """Attention to a neighborhood around the source.

    If fully_autoregressive, then query position p can only see memory
    positions in the range (p - radius, p].

    If not fully_autoregressive, then query position p can only see memory
    positions in the range (p - radius, p + radius].

    In addition, if write_priority and read_priority are provided, then
    attention is limited to position pairs where
    read_priority[query position] >= write_priority[memory position]

    Args:
      q: a Tensor containing length_dim
      k: a Tensor containing length_dim
      v: an optional Tensor containing length_dim.  If None then uses v=k.
      length_dim: a Dimension
      key_dim: a Dimension (the channels dimension of q and k)
      value_dim: a Dimension (the channels dimension of v)
      fully_autoregressive: a boolean
      length_dim_num_splits: an optional integer indicating how many ways the
        length dimension is split
      radius: an integer
      sequence_id: a Tensor or an integer
      write_priority: an optional Tensor containing length_dim
      read_priority: an optional Tensor containing length_dim
      attention_kwargs: optional keyword arguments for attention()

    Returns:
      a Tensor with the shape q.shape - key_dim + value_dim

    Raises:
      ValueError: if channels or depth don't match.
    """
    # Choose a suitable block size.
    # We choose the greatest divisor of length_per_split less than or equal
    # to max(radius, 128).
    length_per_split = length_dim.size // length_dim_num_splits
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    query_block_length = mtf.Dimension("query_block_length", block_length)
    memory_block_length = mtf.Dimension("memory_block_length", block_length)
    # The num_blocks dimension gets the same name as the length dimension,
    # so it will be split in the same way.
    num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)

    def _reshape_query(x):
        return mtf.replace_dimensions(
            x, length_dim, [num_blocks, query_block_length])

    def _reshape_memory(x):
        x = mtf.replace_dimensions(
            x, length_dim, [num_blocks, memory_block_length])
        return (mtf.left_halo_exchange if fully_autoregressive
                else mtf.halo_exchange)(
                    x, num_blocks, memory_block_length, radius)

    q = _reshape_query(q)
    k = _reshape_memory(k)
    if v:
        v = _reshape_memory(v)
    else:
        v = k
    if sequence_id is None:
        sequence_id = 1
    if (not isinstance(sequence_id, mtf.Tensor) or
            length_dim not in sequence_id.shape.dims):
        sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
    q_sequence_id = _reshape_query(sequence_id)
    m_sequence_id = _reshape_memory(sequence_id)
    pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
    q_pos = _reshape_query(pos)
    m_pos = _reshape_memory(pos)

    padded_memory_block_length = mtf.Dimension(
        "memory_block_length",
        (1 if fully_autoregressive else 2) * radius + block_length)

    relative_position = m_pos - q_pos
    visible = mtf.equal(q_sequence_id, m_sequence_id)
    visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
    visible = mtf.logical_and(visible, mtf.less_equal(
        relative_position, 0 if fully_autoregressive else radius))
    if read_priority is not None:
        write_priority = _reshape_memory(write_priority)
        read_priority = _reshape_query(read_priority)
        visible = mtf.logical_and(
            visible,
            mtf.greater_equal(read_priority, write_priority))

    bias = visibility_mask_to_attention_bias(visible, q.dtype)
    o = attention(q, k, v, padded_memory_block_length, key_dim, value_dim,
                  bias, **attention_kwargs)
    return mtf.replace_dimensions(o, [num_blocks, query_block_length],
                                  length_dim)

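# Hypothetical helper that isolates the block-length selection at the top of
# local_attention_1d: the greatest divisor of length_per_split that does not
# exceed max(radius, 128).
def pick_block_length(length_per_split, radius):
    block_length = max(radius, 128)
    while length_per_split % block_length != 0:
        block_length -= 1
    return block_length

# pick_block_length(1024, 128) == 128
# pick_block_length(1000, 128) == 125  (125 is the largest divisor of 1000 <= 128)
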