Example #1
0
def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
    """Compute alpha-entmax of ``x`` along ``dim`` by bisection.

    The output is ``p = relu((alpha - 1) * x - tau) ** (1 / (alpha - 1))``,
    where the threshold ``tau`` (one value per position outside ``dim``) is
    located by ``n_iter`` bisection steps so that ``p`` sums to approximately
    one along ``dim``. ``p`` is then renormalized to sum to exactly one.

    Args:
      x: an mtf.Tensor of scores.
      alpha: a float strictly between 1 and 2.
      dim: the mtf.Dimension to normalize over. Defaults to the last
        dimension of ``x``.
      n_iter: number of bisection iterations; must be at least 1.

    Returns:
      an mtf.Tensor with the same shape as ``x``, non-negative and summing to
      one along ``dim``.

    Raises:
      ValueError: if ``alpha`` is not strictly between 1 and 2, or if
        ``n_iter`` is less than 1.
    """
    # Explicit checks instead of `assert`, which is removed under `python -O`.
    if not 1 < alpha < 2:
        raise ValueError('alpha must be between 1 and 2')
    if n_iter < 1:
        # With zero iterations p_m below would never be assigned.
        raise ValueError('n_iter must be at least 1')

    def _gp(y, a):
        return y ** (a - 1)

    def _gp_inv(y, a):
        return mtf.pow(y, (1 / (a - 1)))

    def _p(y, a):
        return _gp_inv(mtf.relu(y), a)

    dim = x.shape[-1] if dim is None else dim
    d = dim.size

    x = x * (alpha - 1)

    max_val = mtf.reduce_max(x, reduced_dim=dim)

    # Bracket the threshold. At tau_lo the largest entry alone contributes
    # _p(1) = 1, so f(tau_lo) >= 0. At tau_hi every entry contributes at most
    # 1 / d, so f(tau_hi) <= 0.
    tau_lo = max_val - _gp(1, alpha)
    tau_hi = max_val - _gp(1 / d, alpha)

    f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1

    dm = tau_hi - tau_lo

    for _ in range(n_iter):
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = _p(x - tau_m, alpha)
        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1

        # Move tau_lo up to the midpoint wherever f keeps the sign of f_lo.
        mask = mtf.greater_equal((f_m * f_lo), 0)
        tau_lo = mtf.where(mask, tau_m, tau_lo)

    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
    return p_m
Example #2
0
    def cond_fn(position, ids, *unused_states):
        """Return whether the decoding loop should run another iteration.

        A sequence counts as finished once ``position`` reaches
        ``length_dim.size``. When ``max_steps`` is truthy, it is also finished
        once ``position - initial_position`` reaches ``max_steps``. When
        ``stop_at_token`` is set, it is also finished once it holds more
        ``stop_at_token`` ids than ``partial_sequences_eos_count``. The loop
        keeps going as long as not every sequence is finished.
        """
        finished = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            steps_taken = position - initial_position
            finished = mtf.logical_or(
                finished, mtf.greater_equal(steps_taken, max_steps))
        if stop_at_token is not None:
            is_stop = mtf.to_int32(mtf.equal(ids, stop_at_token))
            stop_count = mtf.reduce_sum(is_stop, reduced_dim=length_dim)
            finished = mtf.logical_or(
                finished, mtf.greater(stop_count, partial_sequences_eos_count))
        return mtf.logical_not(mtf.reduce_all(finished))
Example #3
0
 def ids_to_embedding(self, ids):
   """Look up embeddings, giving ids outside this cluster the zero vector.

   Ids are shifted so that self._start_token_id maps to 0. Shifted ids that
   are negative (below the cluster) are replaced with self._end_token_id.
   Per the original note, the mtf.gather inside the wrapped embedding's
   ids_to_embedding turns ids at or beyond the cluster vocab size into an
   all-zero one-hot row, so those ids embed to the zero vector.
   NOTE(review): relies on self._end_token_id being >= the cluster vocab
   size — confirm.
   """
   shifted = ids - self._start_token_id
   non_negative = mtf.greater_equal(shifted, 0)
   shifted = mtf.where(non_negative, shifted, self._end_token_id)
   return self._embedding.ids_to_embedding(shifted)
Example #4
0
 def cond_fn(position, ids, *unused_states):
   """Return whether the decoding loop should run another iteration.

   A sequence is finished once ``position`` reaches ``length_dim.size`` or,
   when ``stop_at_token`` is set, once any of its ids equals
   ``stop_at_token``. The returned boolean (reduced over all dims) is True
   while not every sequence is finished.
   """
   finished = mtf.greater_equal(position, length_dim.size)
   if stop_at_token is not None:
     saw_stop = mtf.reduce_any(
         mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
     finished = mtf.logical_or(finished, saw_stop)
   return mtf.logical_not(mtf.reduce_all(finished))
 def ids_to_embedding(self, ids, context):
     """Look up embeddings, giving ids outside this cluster the zero vector.

     Ids are shifted so that self._start_token_id maps to 0. Any shifted id
     that is negative, or that is at least the cluster width
     (self._end_token_id - self._start_token_id), is replaced with
     self._vocab_dim.size.

     Per the original note, the mtf.gather inside the wrapped embedding's
     ids_to_embedding yields an all-zero one-hot row for such out-of-range
     ids, so they embed to the zero vector. The upper bound also covers the
     head cluster, whose vocabulary has trailing entries that correspond to
     the tail clusters.
     """
     cluster_width = self._end_token_id - self._start_token_id
     local_ids = ids - self._start_token_id
     in_cluster = mtf.logical_and(
         mtf.greater_equal(local_ids, 0),
         mtf.less(local_ids, cluster_width))
     local_ids = mtf.where(in_cluster, local_ids, self._vocab_dim.size)
     return self._embedding.ids_to_embedding(local_ids, context)
Example #6
0
 def call(self, context, x, losses=None):
     """Call the layer.

     Masked local self-attention with a window of self.window_size
     positions, using projection params from
     mtf.layers.multihead_attention_params.

     In "incremental" mode:
       * the previous key/value states are read from the context;
       * one step is computed with
         mtf.layers.masked_local_attention_1d_incremental;
       * the new states are recorded.

     In any other mode, mtf.layers.masked_local_attention_1d runs over the
     whole input. In "first_part" mode, the keys/values of the positions in
     [context.initial_position - self.window_size, context.initial_position)
     are also collected into window-shaped states. Those states are appended
     to context.new_states.

     Args:
       context: a transformer.Context.
       x: the input Tensor.
       losses: unused by this layer.

     Returns:
       y, the attention output Tensor.
     """
     params = mtf.layers.multihead_attention_params(context.mesh,
                                                    self.heads_dim,
                                                    context.model_dim,
                                                    self.kv_dim,
                                                    context.variable_dtype)
     if context.mode == "incremental":
         prev_k, prev_v = context.get_states(2)
         y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
             x, prev_k, prev_v, context.position, params=params)
         context.record_new_states([new_k, new_v])
         return y
     else:
         # masked_local_attention_1d appears to fill `kv` with [k, v] via
         # return_kv; they are only read below, in "first_part" mode.
         kv = []
         y = mtf.layers.masked_local_attention_1d(x,
                                                  self.kv_dim,
                                                  self.heads_dim,
                                                  self.window_size,
                                                  params=params,
                                                  return_kv=kv)
         if context.mode == "first_part":
             k = kv[0]
             v = kv[1]
             window_dim = mtf.Dimension("window", self.window_size)
             mesh = k.mesh
             window_pos = mtf.range(mesh, window_dim, tf.int32)
             pos = mtf.range(mesh, context.length_dim, tf.int32)
             # select_recent[window, length] is 1 where the window slot equals
             # pos mod window_size and
             # initial_position - window_size <= pos < initial_position,
             # and 0 elsewhere.
             select_recent = mtf.cast(
                 mtf.equal(window_pos, mtf.mod(pos, self.window_size)),
                 k.dtype)
             select_recent *= mtf.cast(
                 mtf.less(pos, context.initial_position), k.dtype)
             select_recent *= mtf.cast(
                 mtf.greater_equal(
                     pos, context.initial_position - self.window_size),
                 k.dtype)
             # All but the last two dims of k, followed by [window, kv].
             # NOTE(review): assumes k's last two dims are [length, kv] — confirm.
             state_shape = k.shape.dims[:-2] + [window_dim, self.kv_dim]
             # Contracting over length scatters each selected position into
             # its window slot.
             k_state = mtf.einsum([k, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             v_state = mtf.einsum([v, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             context.new_states.extend([k_state, v_state])
         return y
 def get_cluster_mask(self, targets):
     """Return a boolean mask that is True where a target id lies in this cluster.

     A target is in the cluster when
     self._start_token_id <= target < self._end_token_id.
     """
     at_or_after_start = mtf.greater_equal(targets, self._start_token_id)
     before_end = mtf.less(targets, self._end_token_id)
     return mtf.logical_and(at_or_after_start, before_end)
Example #8
0
 def call(self, context, x, losses=None):
     """Call the layer.

     Local self-attention: each position attends to a window of nearby
     positions. The computation path depends on context.mode and on the
     sequence length:

       * "incremental": the current step's key/value projections are written
         into the state slot whose index (along self.window_dim) equals
         context.position mod self.radius. The updated states are recorded,
         and q attends over self.window_dim. Only slots with index
         <= context.position are visible.
       * context.length_dim.size <= max(256, 4 * self.radius): plain
         attention over the renamed memory_length dimension, masked by
         self.compute_bias.
       * otherwise: attention.local_attention_1d with self.radius.

     In "first_part" mode, some keys are additionally gathered into
     self.window_dim-shaped states and appended to context.new_states: those
     at positions in [context.initial_position - self.radius,
     context.initial_position). Values are gathered the same way unless
     self.shared_kv is set.

     Args:
       context: a transformer.Context.
       x: the input Tensor.
       losses: unused by this layer.

     Returns:
       a Tensor with the same shape as x.
     """
     params = self.make_params(context)
     q = params.compute_q(x)
     if self.shared_kv:
         kv = params.compute_kv(x)
         k = kv
         v = kv
     else:
         k = params.compute_k(x)
         v = params.compute_v(x)
     if context.mode == "incremental":
         if self.shared_kv:
             prev_kv, = context.get_states(1)
         else:
             prev_k, prev_v = context.get_states(2)
         # True (along window_dim) only at the slot the current position
         # overwrites.
         current_position = mtf.equal(
             mtf.range(context.mesh, self.window_dim, dtype=tf.int32),
             mtf.mod(context.position, self.radius))
         if self.shared_kv:
             kv = mtf.where(current_position,
                            kv,
                            prev_kv,
                            output_shape=prev_kv.shape)
             k = kv
             v = kv
             context.record_new_states([kv])
         else:
             # Reuse the k/v projections computed above (as the shared_kv
             # branch does with kv) instead of projecting x a second time.
             k = mtf.where(current_position,
                           k,
                           prev_k,
                           output_shape=prev_k.shape)
             v = mtf.where(current_position,
                           v,
                           prev_v,
                           output_shape=prev_v.shape)
             context.record_new_states([k, v])
         window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
         # Only slots whose index is <= the current position are visible.
         visible = mtf.greater_equal(context.position, window_pos)
         bias = attention.visibility_mask_to_attention_bias(
             visible, context.activation_dtype)
         o = attention.attention(
             q, k, v, self.window_dim, self.kv_dim, self.kv_dim, bias,
             **self.attention_kwargs_from_context(context))
     elif context.length_dim.size <= max(256, self.radius * 4):
         # nothing fancy - just do full attention and mask
         memory_length = self.rename_length_to_memory_length(
             context.position, context)
         o = attention.attention(
             q, self.rename_length_to_memory_length(k, context),
             self.rename_length_to_memory_length(v, context),
             self.memory_length(context), self.kv_dim, self.kv_dim,
             self.compute_bias(context, memory_length, x),
             **self.attention_kwargs_from_context(context))
     else:
         # fancy local attention algorithm
         # NOTE(review): some versions of local_attention_1d name this
         # parameter `fully_autoregressive`; confirm the keyword matches.
         o = attention.local_attention_1d(
             q=q,
             k=k,
             v=None if self.shared_kv else v,
             length_dim=context.length_dim,
             key_dim=self.kv_dim,
             value_dim=self.kv_dim,
             length_dim_num_splits=1,  # TODO(noam): look at the layout
             autoregressive=context.model.fully_autoregressive,
             radius=self.radius,
             sequence_id=context.sequence_id,
             write_priority=context.write_priority,
             read_priority=context.read_priority,
             attention_kwargs=self.attention_kwargs_from_context(context))
     if context.mode == "first_part":
         window_pos = mtf.range(context.mesh, self.window_dim, tf.int32)
         pos = mtf.range(context.mesh, context.length_dim, tf.int32)
         # select_recent[length, window] is 1 where pos mod radius equals the
         # window slot and initial_position - radius <= pos < initial_position.
         select_recent = mtf.cast(
             mtf.equal(mtf.mod(pos, self.radius), window_pos), x.dtype)
         select_recent *= mtf.cast(mtf.less(pos, context.initial_position),
                                   x.dtype)
         select_recent *= mtf.cast(
             mtf.greater_equal(pos, context.initial_position - self.radius),
             x.dtype)
         # k's shape with [length, kv] replaced by [window, kv].
         state_shape = (k.shape - [context.length_dim, self.kv_dim] +
                        [self.window_dim, self.kv_dim])
         k_state = mtf.einsum([k, select_recent],
                              output_shape=state_shape,
                              reduced_dims=[context.length_dim])
         context.new_states.append(k_state)
         if not self.shared_kv:
             v_state = mtf.einsum([v, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             context.new_states.append(v_state)
     return params.compute_output(o, output_shape=x.shape)
Example #9
0
    def compute_bias(self, context, memory_position, x):
        """Compute attention bias.

        Builds one additive bias term per active constraint and returns their
        sum (mtf.add_n):
          * relative position (memory_position - context.position) at least
            self.min_relative_position(context), when that is not None;
          * relative position at most self.max_relative_position(context),
            when that is not None;
          * context.read_priority >= the memory-renamed
            context.write_priority, when read_priority is set;
          * equal (sub)sequence ids between query and memory, when the id
            tensor contains context.length_dim. subsequence_id takes
            precedence over sequence_id.
          * a relative-position bias looked up per bucket, when
            self.relative_attention_type is set.
        If relative_attention_type is None or "bias_shared", the result is
        memoized in context.cache.

        Args:
          context: a transformer.Context
          memory_position: an int32 tensor containing memory_length dimension.
          x: a Tensor - the query antecedent - required for relative attention
        Returns:
          a Tensor, or None if no bias term applies

        Raises:
          ValueError: if relative_attention_type is set to a value other than
            "bias", "bias_shared" or "contextual".
        """
        min_relative_position = self.min_relative_position(context)
        max_relative_position = self.max_relative_position(context)
        # we can often cache the result of this function between similar layers
        can_cache = (self.relative_attention_type is None
                     or self.relative_attention_type == "bias_shared")
        if can_cache:
            # NOTE(review): the key ignores memory_position and x, so it
            # assumes all callers sharing this context pass equivalent
            # values — confirm.
            cache_key = ("self_attention_mask", min_relative_position,
                         max_relative_position, self.relative_attention_type,
                         self.num_heads)
            if cache_key in context.cache:
                return context.cache[cache_key]
        biases = []
        # Positive values mean the memory position is after the query position.
        relative_position = memory_position - context.position
        if min_relative_position is not None:
            visible = mtf.greater_equal(relative_position,
                                        min_relative_position)
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if max_relative_position is not None:
            visible = mtf.less_equal(relative_position, max_relative_position)
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if context.read_priority is not None:
            visible = mtf.greater_equal(
                context.read_priority,
                mtf.layers.rename_length_to_memory_length(
                    context.write_priority))
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))

        sequence_id = None
        # Subsequence id should only be set if we are in the decoder and have
        # multiple targets per input. This will allow each sub-target to only attend
        # to itself.
        if isinstance(context.subsequence_id, mtf.Tensor):
            sequence_id = context.subsequence_id
        elif isinstance(context.sequence_id, mtf.Tensor):
            sequence_id = context.sequence_id
        # Only mask by id when the ids vary along the length dimension.
        if (sequence_id is not None
                and context.length_dim in sequence_id.shape):
            visible = mtf.equal(
                sequence_id,
                self.rename_length_to_memory_length(sequence_id, context))
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if self.relative_attention_type is not None:
            buckets_dim = mtf.Dimension("buckets",
                                        self.relative_attention_num_buckets)
            heads_dim = mtf.Dimension("heads", self.num_heads)
            bidirectional = not context.model.fully_autoregressive
            # Bucket the relative positions (see _relative_position_bucket).
            rp_bucket = _relative_position_bucket(relative_position,
                                                  bidirectional=bidirectional,
                                                  num_buckets=buckets_dim.size)
            if (self.relative_attention_type == "bias"
                    or self.relative_attention_type == "bias_shared"):
                values = mtf.get_variable(context.mesh,
                                          "relative_attention_bias",
                                          [heads_dim, buckets_dim],
                                          dtype=context.variable_dtype)
            elif self.relative_attention_type == "contextual":
                # Per-query bucket biases computed from x.
                values = layers.dense(x, [buckets_dim, heads_dim],
                                      variable_dtype=context.variable_dtype,
                                      name="relative_attention_contextual")
            else:
                raise ValueError(
                    "unrecognized relative_attention_type \"%s\"" %
                    self.relative_attention_type)
            # Look up each query/memory pair's bucket value.
            biases.append(mtf.gather(values, rp_bucket, buckets_dim))
        ret = mtf.add_n(biases) if biases else None
        if can_cache:
            context.cache[cache_key] = ret
        return ret
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None,
                       autoregressive=None):
  """Attention to a neighborhood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Positions whose sequence_id values differ never attend to each other.

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor, an integer, or None (treated as 1)
    write_priority: an optional Tensor containing length_dim; required when
      read_priority is given
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional dict of keyword arguments for attention();
      None means no extra arguments.
    autoregressive: optional legacy alias for fully_autoregressive. When not
      None it overrides fully_autoregressive. Some callers pass this older
      keyword name.

  Returns:
    a Tensor with the shape q.shape - key_dim + value_dim

  Raises:
    ValueError: per the original docstring, if channels or depth don't match.
      NOTE(review): this function raises nothing itself; presumably this
      comes from attention() — confirm.
  """
  if autoregressive is not None:
    fully_autoregressive = autoregressive
  if attention_kwargs is None:
    # `**None` raises TypeError, so the documented default must become {}.
    attention_kwargs = {}
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128).
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)

  def _reshape_query(x):
    """Split length_dim into [num_blocks, query_block_length]."""
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])

  def _reshape_memory(x):
    """Split length_dim into blocks, each extended by `radius` neighbors.

    Neighbors come from the left only when fully_autoregressive, otherwise
    from both sides.
    """
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if fully_autoregressive
            else mtf.halo_exchange)(
                x, num_blocks, memory_block_length, radius)

  q = _reshape_query(q)
  k = _reshape_memory(k)
  # v=None means the values are the keys.
  v = k if v is None else _reshape_memory(v)
  if sequence_id is None:
    sequence_id = 1
  # Broadcast a scalar (or length-free) sequence_id along length_dim so it
  # can be blocked the same way as q and k.
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  # Memory blocks after the halo exchange: block_length plus radius on the
  # left (and radius on the right unless fully_autoregressive).
  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible, mtf.greater_equal(read_priority, write_priority))

  bias = visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention(q, k, v, padded_memory_block_length,
                key_dim, value_dim, bias, **attention_kwargs)
  return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)