Example #1
    def compute_bias(self, context, memory_position, x):
        """Compute attention bias.

    Args:
      context: a transformer.Context
      memory_position: an int32 tensor containing memory_length dimension.
      x: a Tensor - the query antecedent - required for relative attention
    Returns:
      a Tensor or None
    """
        min_relative_position = self.min_relative_position(context)
        max_relative_position = self.max_relative_position(context)
        # we can often cache the result of this function between similar layers
        can_cache = (self.relative_attention_type is None
                     or self.relative_attention_type == "bias_shared")
        if can_cache:
            cache_key = ("self_attention_mask", min_relative_position,
                         max_relative_position, self.relative_attention_type,
                         self.num_heads)
            if cache_key in context.cache:
                return context.cache[cache_key]
        biases = []
        relative_position = memory_position - context.position
        if min_relative_position is not None:
            visible = mtf.greater_equal(relative_position,
                                        min_relative_position)
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if max_relative_position is not None:
            visible = mtf.less_equal(relative_position, max_relative_position)
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if context.read_priority is not None:
            visible = mtf.greater_equal(
                context.read_priority,
                mtf.layers.rename_length_to_memory_length(
                    context.write_priority))
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))

        sequence_id = None
        # Subsequence id should only be set if we are in the decoder and have
        # multiple targets per input. This will allow each sub-target to only attend
        # to itself.
        if isinstance(context.subsequence_id, mtf.Tensor):
            sequence_id = context.subsequence_id
        elif isinstance(context.sequence_id, mtf.Tensor):
            sequence_id = context.sequence_id
        if (sequence_id is not None
                and context.length_dim in sequence_id.shape):
            visible = mtf.equal(
                sequence_id,
                self.rename_length_to_memory_length(sequence_id, context))
            biases.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))
        if self.relative_attention_type is not None:
            buckets_dim = mtf.Dimension("buckets",
                                        self.relative_attention_num_buckets)
            heads_dim = mtf.Dimension("heads", self.num_heads)
            bidirectional = not context.model.fully_autoregressive
            rp_bucket = _relative_position_bucket(relative_position,
                                                  bidirectional=bidirectional,
                                                  num_buckets=buckets_dim.size)
            if (self.relative_attention_type == "bias"
                    or self.relative_attention_type == "bias_shared"):
                values = mtf.get_variable(context.mesh,
                                          "relative_attention_bias",
                                          [heads_dim, buckets_dim],
                                          dtype=context.variable_dtype)
            elif self.relative_attention_type == "contextual":
                values = layers.dense(x, [buckets_dim, heads_dim],
                                      variable_dtype=context.variable_dtype,
                                      name="relative_attention_contextual")
            else:
                raise ValueError(
                    "unrecognized relative_attention_type \"%s\"" %
                    self.relative_attention_type)
            biases.append(mtf.gather(values, rp_bucket, buckets_dim))
        ret = mtf.add_n(biases) if biases else None
        if can_cache:
            context.cache[cache_key] = ret
        return ret
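
A note on the mask-to-bias conversion used throughout this example: allowed
(query, memory) pairs get a bias of 0 and disallowed pairs a large negative
value, so they vanish under the softmax. Below is a minimal NumPy sketch of
that conversion; the function name and the -1e9 constant mirror mtf's
attention.visibility_mask_to_attention_bias but are reimplemented here purely
as an illustration.

import numpy as np

def visibility_mask_to_attention_bias(visible, neg=-1e9):
    # visible: boolean array [query_length, memory_length];
    # returns 0.0 where attention is allowed, a large negative bias elsewhere.
    return np.where(visible, 0.0, neg)

# A causal (autoregressive) mask for a length-4 sequence.
length = 4
query_pos = np.arange(length)[:, None]
memory_pos = np.arange(length)[None, :]
relative_position = memory_pos - query_pos
bias = visibility_mask_to_attention_bias(relative_position <= 0)
print(bias)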
Example #2
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None):
  """Attention to the a neighborood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None, then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()

  Returns:
    a Tensor with shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128).
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
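  # Worked example (illustration only): with length_per_split=1000 and
  # radius=128, the search starts at max(128, 128) = 128 and steps down to
  # 125, the greatest divisor of 1000 not exceeding 128.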
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)
  def _reshape_query(x):
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])
  def _reshape_memory(x):
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if fully_autoregressive
            else mtf.halo_exchange)(
                x, num_blocks, memory_block_length, radius)
  q = _reshape_query(q)
  k = _reshape_memory(k)
  if v is not None:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible, mtf.greater_equal(read_priority, write_priority))

  bias = visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention(q, k, v, padded_memory_block_length,
                key_dim, value_dim, bias, **(attention_kwargs or {}))
  return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)
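
The window logic above is easier to verify outside the block/halo machinery.
Here is a minimal NumPy sketch (an illustration, not part of the library) that
reproduces which (query, memory) pairs the mask admits for a single unsplit
sequence.

import numpy as np

def local_attention_visible(length, radius, fully_autoregressive=True):
    # True where query position p may attend to memory position m:
    # m in (p - radius, p] if fully_autoregressive,
    # m in (p - radius, p + radius] otherwise.
    q_pos = np.arange(length)[:, None]
    m_pos = np.arange(length)[None, :]
    relative_position = m_pos - q_pos
    visible = relative_position > -radius
    upper = 0 if fully_autoregressive else radius
    return visible & (relative_position <= upper)

print(local_attention_visible(6, radius=2))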
Example #3
        def ut_function(state, step, halting_probability, remainders,
                        n_updates, previous_state):
            """implements act (position-wise halting).

      Args:
        state: 3-D Tensor: [batch_size, length, channel]
        step: indicates number of steps taken so far
        halting_probability: halting probability
        remainders: act remainders
        n_updates: act n_updates
        previous_state: previous state

      Returns:
        transformed_state: transformed state
        step: step+1
        halting_probability: halting probability
        remainders: act remainders
        n_updates: act n_updates
        new_state: new state
      """
            state = self.step_preprocess(context, state, step)

            if self.act_type == "random":
                # random as halting probability
                p = mtf.random_uniform(context.mesh,
                                       shape=halting_probability.shape.dims,
                                       dtype=context.variable_dtype)
            else:
                last_dim_name = state.shape.dimension_names[-1]
                new_dims = [mtf.Dimension(last_dim_name, 1)]
                with tf.variable_scope("sigmoid_activation_for_pondering",
                                       reuse=tf.AUTO_REUSE):
                    p = mtf.layers.dense(state,
                                         variable_dtype=context.variable_dtype,
                                         reduced_dims=[state.shape.dims[-1]],
                                         new_dims=new_dims,
                                         activation=mtf.sigmoid,
                                         use_bias=True)
                    if self.act_type == "global":
                        # average over all positions (as a global halting prob)
                        p = mtf.reduce_mean(p, reduced_dim=p.shape.dims[1])
                        p = mtf.squeeze(p)
                    else:
                        # maintain position-wise probabilities
                        new_shape = p.shape.dims[:-1]
                        p = mtf.reshape(p, new_shape)
            # Mask for inputs which have not halted yet
            still_running = mtf.cast(mtf.less(halting_probability, 1.0),
                                     context.activation_dtype)

            # Mask of inputs which halted at this step
            new_halted = mtf.cast(
                mtf.greater(halting_probability + p * still_running,
                            threshold),
                context.activation_dtype) * still_running
            # Mask of inputs which haven't halted, and didn't halt this step
            still_running = mtf.cast(
                mtf.less_equal(halting_probability + p * still_running,
                               threshold),
                context.activation_dtype) * still_running

            # Add the halting probability for this step to the halting
            # probabilities for those input which haven't halted yet
            halting_probability += p * still_running

            # Compute remainders for the inputs which halted at this step
            remainders += new_halted * (1 - halting_probability)

            # Add the remainders to those inputs which halted at this step
            halting_probability += new_halted * remainders

            # Increment n_updates for all inputs which are still running
            n_updates += still_running + new_halted

            # Compute the weight to be applied to the new state and output
            # 0 when the input has already halted
            # p when the input hasn't halted yet
            # the remainders when it halted this step
            update_weights = p * still_running + new_halted * remainders

            # apply transformation on the state
            transformed_state = state

            for _ in range(self.num_inrecurrence_layers):
                transformed_state = self.vanilla_transformer_layer(
                    context, transformed_state, mask)

            # update running part in the weighted state and keep the rest
            new_state = ((transformed_state * update_weights) +
                         (previous_state * (1 - update_weights)))

            if self.act_type == "accumulated":
                # Add in the weighted state
                new_state = (transformed_state *
                             update_weights) + previous_state

            step += 1

            return (transformed_state, step, halting_probability, remainders,
                    n_updates, new_state)
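
The halting bookkeeping in ut_function does not depend on the transformer
layers, so it can be exercised in isolation. Below is a NumPy sketch of one
ACT step; the names mirror the code above, and the 0.99 threshold is an
assumption standing in for the threshold captured from the enclosing scope.

import numpy as np

def act_step(p, halting_probability, remainders, n_updates, threshold=0.99):
    # p: halting probability emitted at this step, one value per position.
    still_running = (halting_probability < 1.0).astype(np.float32)
    new_halted = (halting_probability + p * still_running
                  > threshold).astype(np.float32) * still_running
    still_running = (halting_probability + p * still_running
                     <= threshold).astype(np.float32) * still_running
    halting_probability = halting_probability + p * still_running
    remainders = remainders + new_halted * (1 - halting_probability)
    halting_probability = halting_probability + new_halted * remainders
    n_updates = n_updates + still_running + new_halted
    update_weights = p * still_running + new_halted * remainders
    return halting_probability, remainders, n_updates, update_weights

hp = np.zeros(3, np.float32)
rem = np.zeros(3, np.float32)
n = np.zeros(3, np.float32)
p = np.array([0.6, 0.995, 0.1], np.float32)
hp, rem, n, w = act_step(p, hp, rem, n)
print(hp, rem, n, w)  # position 1 halts immediately; the others keep running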
Example #4
    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        nonlocal sampling_keep_top_k

        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=ids,
            encoder_inputs=encoder_inputs) if not slow_sampling else None

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)

        # By default (sampling_keep_top_k == -2), keep the top 10% of logits.
        if sampling_keep_top_k == -2:
            sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits,
                n=sampling_keep_top_k,
                reduced_dim=other_features["vocab_dim"])
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, other_features["vocab_dim"], temperature)

        if slow_sampling:
            ids_this_step = mtf.shift(ids_this_step,
                                      offset=1,
                                      dim=length_dim,
                                      wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, batch_dims)

        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret
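
The truncation in body_fn hinges on mtf.nth_largest_element; this sketch
assumes the zero-based convention (n=0 returns the maximum), in which case
masking everything less than or equal to the element at index k keeps exactly
the top k logits. Ties at the threshold are masked as well, matching the
less_equal above.

import numpy as np

def keep_top_k(logits, k, neg=-1e6):
    # threshold is the (k+1)-th largest value (zero-based index k);
    # everything at or below it is pushed to a large negative number.
    threshold = np.sort(logits, axis=-1)[..., ::-1][..., k:k + 1]
    return np.where(logits <= threshold, neg, logits)

logits = np.array([2.0, 5.0, 1.0, 4.0, 3.0])
print(keep_top_k(logits, 2))  # only 5.0 and 4.0 survive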
Example #5
def _switch_gating(inputs,
                   outer_expert_dims,
                   experts_dim,
                   expert_capacity_dim,
                   hparams,
                   train,
                   variable_dtype,
                   importance=None,
                   name="switch_gating",
                   num_microbatches=None):
  """Compute a switch top-1 gating with no-token-left behind behavior."""
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # Input perturbations
  if train and policy == "input_jitter":
    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)

  gate_logits = mtf.layers.dense(
      inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  # Top-k operation
  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
  expert_gate, expert_index = mtf.top_k(
      raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
  expert_mask = mtf.one_hot(expert_index, experts_dim)

  # LOAD BALANCING LOSS
  outer_batch_dim = inputs.shape[0]
  batch_dim = inputs.shape[1]
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    entropy = mtf.reduce_sum(
        -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Iteratively route tokens (no-token-left-behind). The idea is to route as
  # many tokens as possible to top-i before trying top-(i+1).
  top_k_masks = mtf.split(
      expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_gates = mtf.split(
      expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_indices = mtf.split(
      expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

  # Tensors accumulating values over the iterative process.
  combine_tensor = mtf.constant(
      inputs.mesh,
      value=0,
      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
  cum_tokens = mtf.constant(
      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
  tokens_left_to_route = mtf.constant(
      inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])

  expert_capacity_float = float(expert_capacity_dim.size)
  for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
                                                   top_k_indices):
    top_i_mask = mtf.reshape(
        top_i_mask,
        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
    # Operate only on the unrouted tokens.
    top_i_mask *= tokens_left_to_route

    # Record cumulative number of tokens to each expert across iterations.
    cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
        top_i_mask, group_size_dim)

    # 1.0 where the token still fits within the expert's capacity.
    within_capacity = mtf.to_float(
        mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
    output_i_tokens = top_i_mask * within_capacity

    # Update the cumulative tokens routed to each expert.
    cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
    tokens_left_to_route -= (
        mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))

    # Combine-tensor for this iteration
    output_i_tokens_flat = mtf.reduce_sum(
        output_i_tokens, reduced_dim=experts_dim)
    position_in_expert = cumulative_tokens_in_expert - 1
    top_i_combine_tensor = (
        top_i_gate * output_i_tokens_flat *
        mtf.one_hot(top_i_index, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
    combine_tensor += top_i_combine_tensor

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
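
The iterative loop above is the heart of no-token-left-behind routing, but the
mtf shape algebra obscures it. Here is a scalar NumPy sketch of the same idea
for one (batch, group) slice: route each token to its top-i choice only while
that expert has spare capacity, and let later iterations pick up the
leftovers. The names and shapes are illustrative assumptions, not the
library's API.

import numpy as np

def ntlb_route(top_k_indices, num_experts, capacity):
    # top_k_indices: [group_size, k] expert choices per token, best first.
    group_size, k = top_k_indices.shape
    assignment = np.full(group_size, -1)           # expert id per token
    tokens_in_expert = np.zeros(num_experts, int)  # running expert load
    for i in range(k):
        for t in range(group_size):
            if assignment[t] >= 0:
                continue                           # already routed earlier
            e = top_k_indices[t, i]
            if tokens_in_expert[e] < capacity:     # expert still has room
                assignment[t] = e
                tokens_in_expert[e] += 1
    return assignment  # -1 marks tokens dropped by every choice

choices = np.array([[0, 1], [0, 1], [0, 1]])
print(ntlb_route(choices, num_experts=2, capacity=2))  # [0 0 1]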
Example #6
        def body_fn(position, ids, *states):
            """One step in the decode loop."""
            inputs_this_step = mtf.gather(ids, position - 1, length_dim)
            if self.attribute_embedding:
                attributes_this_step = mtf.gather(attributes, position - 1,
                                                  length_dim)
            else:
                attributes_this_step = None
            context_incremental = Context(
                model=self,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                states=states,
                new_states=[],
                sequence_id=sequence_id,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                constant_states=constant_states,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=position,
                inputs=inputs_this_step,
                encoder_inputs=encoder_inputs)

            with tf.variable_scope(self.name, reuse=True):
                logits = self._call_internal(context_incremental,
                                             inputs_this_step,
                                             attributes=attributes_this_step,
                                             z=z)
                if never_end:
                    logits += mtf.one_hot(mtf.constant(logits.mesh,
                                                       stop_at_token,
                                                       dtype=tf.int32),
                                          self.output_vocab_dim,
                                          on_value=-1e9,
                                          off_value=0.0,
                                          dtype=logits.dtype)

            # TBD whether this should be before or after never_end.
            # Note for adding top_p sampling in the future: in other code
            # bases, temperature is applied before the top-k truncation,
            # while this implementation does it in the opposite order. For
            # top-k this doesn't matter, but for top_p it will.
            if sampling_keep_top_k != -1:
                if sampling_keep_top_k <= 0:
                    raise ValueError(
                        "sampling_keep_top_k must either be -1 or positive.")
                k_largest = mtf.nth_largest_element(
                    logits,
                    n=sampling_keep_top_k,
                    reduced_dim=self.output_vocab_dim)
                logits = mtf.where(mtf.less_equal(logits, k_largest),
                                   mtf.ones_like(logits) * -1e6, logits)

            ids_this_step = mtf.sample_with_temperature(
                logits, self.output_vocab_dim, temperature)
            new_position = position + 1
            new_ids = ids + ids_this_step * mtf.one_hot(
                position, length_dim, dtype=tf.int32)
            return [new_position, new_ids] + context_incremental.new_states
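
Both decode loops write the newly sampled id into a fixed-length buffer
through a one-hot mask over the length dimension rather than a dynamic slice
update (Example #6 adds instead of blending, relying on the target slot being
zero). A NumPy sketch of that write:

import numpy as np

def write_id(ids, position, new_id, length):
    # one_hot selects the slot being decoded; every other slot is kept.
    one_hot = (np.arange(length) == position).astype(ids.dtype)
    return ids * (1 - one_hot) + new_id * one_hot

ids = np.array([5, 7, 0, 0])
print(write_id(ids, position=2, new_id=9, length=4))  # [5 7 9 0]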