예제 #1
0
파일: utils.py 프로젝트: doinker/GPTNeo
def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha = 1.3, dim = None, n_iter = 50):
    """Map the gradient of an entmax output back to a gradient on its input.

    Args:
        explicit_inputs: one-element sequence; it is unpacked (which enforces
            the single-element shape) but its value is not used.
        all_inputs: not used by this function.
        forward_operations: not used by this function.
        outputs: one-element sequence holding the forward output ``y``.
        output_grads: one-element sequence holding the gradient w.r.t. ``y``.
        alpha: exponent parameter; positive outputs are weighted by
            ``y ** (2 - alpha)``.
        dim: dimension the two sums are reduced over.
        n_iter: not used by this function.

    Returns:
        A one-element tuple containing the input gradient.
    """
    (_unused_forward_input,) = explicit_inputs
    (probs,) = outputs
    (upstream_grad,) = output_grads

    # Weight is y^(2 - alpha) where the output is positive and zero elsewhere.
    is_positive = mtf.greater(probs, 0)
    support_weight = mtf.where(is_positive, mtf.pow(probs, (2 - alpha)), mtf.zeros_like(probs))
    weighted_grad = upstream_grad * support_weight

    # Subtract the weighted mean of the gradient along `dim`.
    weighted_grad_total = mtf.reduce_sum(weighted_grad, reduced_dim = dim)
    support_weight_total = mtf.reduce_sum(support_weight, reduced_dim = dim)
    correction = weighted_grad_total / support_weight_total

    return (weighted_grad - correction * support_weight,)
예제 #2
0
  def beam_search(self,
                  inputs,
                  decode_length,
                  variable_dtype=mtf.VariableDType(tf.float32),
                  encoder_output=None,
                  encoder_sequence_id=None,
                  alpha=0.6,
                  shared_params=None,
                  encoder_layer_outputs=None):
    """Decodes with beam search and returns the beam stored at index 0.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, beam_dim, length_dim].
        Its nonzero entries are counted along length_dim to get the initial
        position.
      decode_length: maximum decode length, forwarded to
        mtf.beam_search.beam_search.
      variable_dtype: a mtf.VariableDType.
      encoder_output: an optional Tensor, forwarded to the Context.
      encoder_sequence_id: an optional Tensor.  When it is given, sequence_id
        is set to 1; otherwise sequence_id is None.
      alpha: a floating point value (length bonus), forwarded to
        mtf.beam_search.beam_search.
      shared_params: an optional dictionary, forwarded to the Context.
      encoder_layer_outputs: optional list of tensor activations, forwarded to
        the Context.

    Returns:
      The Tensor obtained by gathering index 0 along beam_dim from the beams
      produced by mtf.beam_search.beam_search.

    Raises:
      ValueError: if this model is not autoregressive.
      NotImplementedError: if inputs does not have exactly one batch dimension.
    """
    if not self.autoregressive:
      raise ValueError("must be autoregressive")

    all_dims = inputs.shape.dims
    batch_dims = all_dims[:-2]
    if len(batch_dims) != 1:
      raise NotImplementedError(
          "beam search supports exactly one batch dimension.")
    beam_dim = all_dims[-2]
    length_dim = all_dims[-1]

    nonzero_mask = mtf.to_int32(mtf.not_equal(inputs, 0))
    initial_position = mtf.reduce_sum(nonzero_mask, reduced_dim=length_dim)

    if encoder_sequence_id is not None:
      sequence_id = 1
    else:
      sequence_id = None

    def make_context(**phase_kwargs):
      """Builds a Context from the shared settings plus per-phase extras."""
      return Context(
          mesh=inputs.mesh,
          batch_dims=batch_dims + [beam_dim],
          length_dim=length_dim,
          model_dim=self.model_dim,
          variable_dtype=variable_dtype,
          autoregressive=self.autoregressive,
          new_states=[],
          sequence_id=sequence_id,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          shared_params=shared_params,
          layout=self.layout,
          mesh_shape=self.mesh_shape,
          encoder_layer_outputs=encoder_layer_outputs,
          **phase_kwargs)

    context_first_part = make_context(
        mode="first_part",
        initial_position=initial_position,
        constant_states=[])

    # Run the model once over the shifted inputs.  The logits are discarded;
    # the pass is needed for the states it records on the context.
    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
      self._call_internal(context_first_part, shifted_inputs)

    # No partial targets are given, so zeros of matching shape stand in for
    # the recorded states.
    initial_states = []
    for recorded_state in context_first_part.new_states:
      initial_states.append(mtf.zeros_like(recorded_state))
    constant_states = context_first_part.constant_states

    def logits_fn(step_num, ids, states):
      """Per-step callback handed to mtf.beam_search.beam_search()."""
      context_incremental = make_context(
          mode="incremental",
          position=step_num,
          states=states,
          constant_states=constant_states)
      previous_ids = mtf.gather(ids, step_num - 1, length_dim)
      with tf.variable_scope(self.name, reuse=True):
        step_logits = self._call_internal(context_incremental, previous_ids)
      return mtf.to_float(step_logits), context_incremental.new_states

    beams, _ = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)

    first_beam_index = mtf.constant(inputs.mesh, 0, dtype=tf.int32)
    return mtf.gather(beams, first_beam_index, beam_dim)
예제 #3
0
  def sample_autoregressive(self,
                            partial_sequences,
                            stop_at_token=1,
                            max_steps=None,
                            temperature=1.0,
                            variable_dtype=mtf.VariableDType(tf.float32),
                            encoder_output=None,
                            encoder_sequence_id=None,
                            shared_params=None,
                            has_partial_sequences=True,
                            encoder_layer_outputs=None):
    """Fills in sequences by sampling one token per loop iteration.

    Each row of partial_sequences starts with the nonzero tokens of a given
    prefix and is padded with zeros; the zeros mark the positions to fill in.

    To sample from scratch, pass
    partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) together with
    has_partial_sequences=False; the states recorded by the first pass are
    then replaced by zeros.

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim].
      stop_at_token: an optional integer eos id.  The loop stops once every
        sequence contains it or the position reaches the end of length_dim.
      max_steps: accepted but currently ignored.
      temperature: a floating point value forwarded to
        mtf.sample_with_temperature.
      variable_dtype: a mtf.VariableDType.
      encoder_output: an optional Tensor, forwarded to the Context.
      encoder_sequence_id: an optional Tensor.  When it is given, sequence_id
        is set to 1; otherwise sequence_id is None.
      shared_params: an optional dictionary, forwarded to the Context.
      has_partial_sequences: a boolean; see above.
      encoder_layer_outputs: optional list of tensor activations, forwarded to
        the Context.

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]

    Raises:
      ValueError: if this model is not autoregressive.
    """
    del max_steps  # TODO(noam): implement
    if not self.autoregressive:
      raise ValueError("must be autoregressive")

    inputs = partial_sequences
    input_dims = inputs.shape.dims
    batch_dims = input_dims[:-1]
    length_dim = input_dims[-1]

    nonzero_mask = mtf.to_int32(mtf.not_equal(inputs, 0))
    initial_position = mtf.reduce_sum(nonzero_mask, reduced_dim=length_dim)

    if encoder_sequence_id is not None:
      sequence_id = 1
    else:
      sequence_id = None

    def make_context(**phase_kwargs):
      """Builds a Context from the shared settings plus per-phase extras."""
      return Context(
          mesh=inputs.mesh,
          batch_dims=batch_dims,
          length_dim=length_dim,
          model_dim=self.model_dim,
          variable_dtype=variable_dtype,
          autoregressive=self.autoregressive,
          new_states=[],
          sequence_id=sequence_id,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          shared_params=shared_params,
          layout=self.layout,
          mesh_shape=self.mesh_shape,
          encoder_layer_outputs=encoder_layer_outputs,
          **phase_kwargs)

    context_first_part = make_context(
        mode="first_part",
        initial_position=initial_position,
        constant_states=[])

    # Run the model once over the shifted prefix.  The logits are discarded;
    # the pass is needed for the states it records on the context.
    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
      self._call_internal(context_first_part, shifted_inputs)

    constant_states = context_first_part.constant_states
    if has_partial_sequences:
      initial_states = context_first_part.new_states
    else:
      initial_states = [
          mtf.zeros_like(recorded) for recorded in context_first_part.new_states]

    def cond_fn(position, ids, *unused_states):
      """Returns whether at least one sequence still needs another step."""
      finished = mtf.greater_equal(position, length_dim.size)
      if stop_at_token is not None:
        eos_seen = mtf.reduce_any(
            mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
        finished = mtf.logical_or(finished, eos_seen)
      return mtf.logical_not(mtf.reduce_all(finished))

    def body_fn(position, ids, *states):
      """Samples the token at `position` and writes it into `ids`."""
      context_incremental = make_context(
          mode="incremental",
          position=position,
          states=states,
          constant_states=constant_states)
      previous_ids = mtf.gather(ids, position - 1, length_dim)
      with tf.variable_scope(self.name, reuse=True):
        step_logits = self._call_internal(context_incremental, previous_ids)
      sampled_ids = mtf.sample_with_temperature(
          step_logits, self.output_vocab_dim, temperature)
      next_position = position + 1
      # A one-hot over length_dim places the sampled id at `position`.
      slot = mtf.one_hot(position, length_dim, dtype=tf.int32)
      updated_ids = ids + sampled_ids * slot
      return [next_position, updated_ids] + context_incremental.new_states

    loop_vars = [initial_position, inputs] + initial_states
    _, outputs = mtf.while_loop(cond_fn, body_fn, loop_vars)[:2]
    return outputs
예제 #4
0
def _top_2_gating(inputs,
                  outer_expert_dims,
                  experts_dim,
                  expert_capacity_dim,
                  hparams,
                  train,
                  variable_dtype,
                  importance=None,
                  name="top_2_gating"):
    """Compute gating for mixture-of-experts in TensorFlow.

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters object in order not to complicate the interface in
    mtf_transformer.py .  Once this code moves out of "research", we should
    pass the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_use_second_place_loss: a boolean
      hparams.moe_second_policy_train: a string
      hparams.moe_second_policy_eval: a string
      hparams.moe_second_threshold_train: a float
      hparams.moe_second_threshold_eval: a float

    The policy strings accepted are "all", "none", "threshold" and "random".

    The returned dispatch_tensor is a tensor used to map (via einsum) from the
    inputs to the expert_inputs.  Likewise, the returned combine_tensor is
    used to map (via einsum) from the expert outputs to the outputs.  Both
    tensors are mostly zeros.  The shapes of the tensors are as follows.

    inputs: [<batch_dims>, group_size_dim, input_dim]
    importance: [<batch_dims>, group_size_dim]
    dispatch_tensor:
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    expert_inputs:
      [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

    expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
    combine_tensor:
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    outputs: [<batch_dims>, group_size_dim, output_dim]

    "importance" is an optional tensor with one floating-point value for each
    input vector.  If the importance of an input is 1.0, then we send it to
    up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
    one expert.  If importance == 0.0, then we send it to no experts.

    We use "importance" at the second-level gating function of a hierarchical
    mixture of experts.  Inputs to the first-choice expert-group get importance
    1.0.  Inputs to the second-choice expert group get importance 0.5.
    Inputs that represent padding get importance 0.0.

    Args:
      inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
      outer_expert_dims: an optional list of dimensions.  This is for the case
        where we are at an inner level of a hierarchical MoE.
      experts_dim: a Dimension (the number of experts)
      expert_capacity_dim: a Dimension (number of examples per group per expert)
      hparams: model hyperparameters.
      train: a boolean
      variable_dtype: a mtf.VariableDType
      importance: an optional tensor with shape [<batch_dims>, group_size_dim]
      name: an optional string

    Returns:
      dispatch_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      combine_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      loss: a mtf scalar

    Raises:
      ValueError: if the selected second-place policy is not recognized
    """
    group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

    # One gate value per expert, normalized with a softmax over experts_dim.
    raw_gates = mtf.layers.dense(inputs,
                                 experts_dim,
                                 use_bias=False,
                                 expert_dims=outer_expert_dims,
                                 variable_dtype=variable_dtype,
                                 name=name)
    raw_gates = mtf.softmax(raw_gates, experts_dim)

    # The internals of this function run in float32.
    #   bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    density_1_proxy = raw_gates
    if importance is not None:
        # Only inputs with importance exactly 1.0 keep their first choice.
        mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
    # Zero out the first-choice expert so that top_1 finds the runner-up.
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
        # Inputs with importance 0.0 get no second choice either.
        mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

    # Renormalize the two gate values; the epsilon avoids division by zero.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                      reduced_dim=group_size_dim)
    loss = (mtf.reduce_mean(density_1_proxy * density_1) *
            float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as the
        # second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top one
        # has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized,
                                          reduced_dim=group_size_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2) *
                  float(experts_dim.size * experts_dim.size))
        # The second-place loss is added with half the weight of the first.
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.
    if train:
        policy = hparams.moe_second_policy_train
        threshold = hparams.moe_second_threshold_train
    else:
        policy = hparams.moe_second_policy_eval
        threshold = hparams.moe_second_threshold_eval
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    position_in_expert_1 = mtf.cumsum(mask_1, group_size_dim,
                                      exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    # Second choices are placed after the first choices already kept for the
    # same expert, hence the mask_1_count offset.
    position_in_expert_2 = (
        mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
    position_in_expert_2 *= mask_2
    # Remove the second choices that don't fit. [batch, group, experts]
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # mask_2_count = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    # [batch, group]
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    # Weight assigned to second expert.  [batch, group]
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # [batch, group, experts, expert_capacity]
    combine_tensor = (
        gate_1 * mask_1_flat * mtf.one_hot(index_1, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat * mtf.one_hot(index_2, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    # Return results in the dtype of the inputs.
    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)

    # The dispatch tensor is the combine tensor with every nonzero weight
    # replaced by one (cast to bool and back).
    dispatch_tensor = mtf.cast(mtf.cast(combine_tensor, tf.bool),
                               combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
예제 #5
0
파일: gpt2.py 프로젝트: doinker/GPTNeo
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """Assemble a GPT-style network in mesh tensorflow and, in train/eval mode, its loss.

    Args:
        mtf_features: features handed to parse_inputs; "labels" is read from it
            when params["mode"] is "train" or "eval".
        other_features: dict handed to parse_inputs; "attn_bias" and
            "memory_length_dim" are read from it for every block.
        params: model configuration dict.
        mesh: the mesh on which variables and ranges are created.
        variable_dtype: supplies the master, slice and activation dtypes.
        context: optional decoding context; when is_incremental_inference(context)
            holds, only the token at context.position - 1 is processed.

    Returns:
        A tuple (logits, loss, loss_batch).  loss and loss_batch are None unless
        params["mode"] is "train" or "eval".
    """

    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)

    if is_incremental_inference(context):
        # Incremental decoding feeds only the token at the previous position.
        previous_token = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(previous_token, [batch_dim])

    if params["axial_pos_emb"] is None:
        # Standard position embedding table
        wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
                               initializer=tf.random_normal_initializer(stddev=0.01),
                               master_dtype=variable_dtype.master_dtype,
                               slice_dtype=variable_dtype.slice_dtype,
                               activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Token embedding table
    wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
                           initializer=tf.random_normal_initializer(stddev=0.02),
                           master_dtype=variable_dtype.master_dtype,
                           slice_dtype=variable_dtype.slice_dtype,
                           activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Look up the embedding of every input token
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0:
            if params["mode"] == "train":
                h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Look up the embedding of every position
        if is_incremental_inference(context):
            position_indices = context.position - 1
        else:
            position_indices = mtf.range(mesh, sequence_dim, tf.int64)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0:
            if params["mode"] == "train":
                pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout")
        h += pos_emb

    aux_losses = 0  # running total of the per-block auxiliary losses (for MOE models)

    for layer_idx in range(params["n_layer"]):
        # With shared parameters every block uses the same (empty) scope name.
        share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
        block_scope = "" if share_parameters else f"h{layer_idx}"

        block_fn = block(params=params, scope=block_scope, layer_num=layer_idx,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         variable_dtype=variable_dtype,
                         context=context)

        # Gradient checkpointing is applied only when requested and in train mode
        recompute_grad = params["recompute_grad"] and params["mode"] == "train"
        if recompute_grad:
            h, block_aux_loss = mtf.recompute_grad(block_fn, [h])
        else:
            h, block_aux_loss = block_fn(h)
        aux_losses += block_aux_loss

    if params["no_weight_tie"] == True:
        # Separate output projection
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
    else:
        # Layer norm, then project with the token embedding table
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        if is_incremental_inference(context):
            seq_dim = mtf.Dimension("sequence", 1)
        else:
            seq_dim = sequence_dim
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ("train", "eval"):
        labels = mtf_features["labels"]
        z_loss = params.get("z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy

        # The loss is computed on float32 logits
        logits = mtf.cast(logits, tf.float32)

        if params.get("entmax_loss", False):
            loss_fn = entmax_cross_entropy_with_logits
        else:
            loss_fn = mtf.layers.softmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-causal models (masked language modeling training), positions
        # whose label is the padding token contribute zero loss.
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            is_real_token = mtf.not_equal(labels, padding_id)
            loss_batch = mtf.where(is_real_token, loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
    def act_layer(self, context, x, mask):
        """Build a Universal Transformer ACT (adaptive computation time) layer.

        The halting step is unrolled ``self.act_max_steps + 1`` times in a
        Python loop.  Each step computes a halting probability, applies
        ``self.num_inrecurrence_layers`` transformer layers to the state and
        blends the result into a running weighted state.  The mean number of
        updates is emitted as the "ponder_times" scalar summary.

        Args:
            context: supplies ``mesh``, ``activation_dtype`` and
                ``variable_dtype``; it is also passed on to
                ``step_preprocess`` and ``vanilla_transformer_layer``.
            x: the input state Tensor.
            mask: passed through to ``vanilla_transformer_layer``.

        Returns:
            The weighted state produced by the last unrolled step.
        """
        state = x
        act_max_steps = self.act_max_steps
        threshold = 1.0 - self.act_epsilon
        state_shape_static = state.shape.dims

        # "global" halting keeps the first two state dimensions for the
        # bookkeeping tensors; every other act_type keeps the first three.
        state_slice = slice(0, 3)
        if self.act_type == "global":
            state_slice = slice(0, 2)

        # Shape of the bookkeeping tensors below
        update_shape = state_shape_static[state_slice]

        # Halting probabilities (p_t^n in the paper)
        halting_probability = mtf.zeros(context.mesh,
                                        update_shape,
                                        dtype=context.activation_dtype)

        # Remainders (R(t) in the paper)
        remainders = mtf.zeros(context.mesh,
                               update_shape,
                               dtype=context.activation_dtype)

        # Number of updates performed (N(t) in the paper)
        n_updates = mtf.zeros(context.mesh,
                              update_shape,
                              dtype=context.activation_dtype)

        # Previous cell states (s_t in the paper)
        previous_state = mtf.zeros_like(state)
        step = mtf.constant(context.mesh, 0, dtype=tf.int32)

        def ut_function(state, step, halting_probability, remainders,
                        n_updates, previous_state):
            """Implements one ACT step (position-wise halting).

            Args:
                state: 3-D Tensor: [batch_size, length, channel]
                step: indicates number of steps taken so far
                halting_probability: halting probability
                remainders: act remainders
                n_updates: act n_updates
                previous_state: previous state

            Returns:
                transformed_state: transformed state
                step: step+1
                halting_probability: halting probability
                remainders: act remainders
                n_updates: act n_updates
                new_state: new state
            """
            state = self.step_preprocess(context, state, step)

            if self.act_type == "random":
                # random as halting probability
                # Use the activation dtype so p matches halting_probability,
                # which it is added to below.  context.variable_dtype is the
                # object handed to mtf.layers.dense as `variable_dtype`, not a
                # tensor dtype.
                p = mtf.random_uniform(context.mesh,
                                       shape=halting_probability.shape.dims,
                                       dtype=context.activation_dtype)
            else:
                # Project the last state dimension down to size 1 through a
                # sigmoid to get a halting probability.
                last_dim_name = state.shape.dimension_names[-1]
                new_dims = [mtf.Dimension(last_dim_name, 1)]
                with tf.variable_scope("sigmoid_activation_for_pondering",
                                       reuse=tf.AUTO_REUSE):
                    p = mtf.layers.dense(state,
                                         variable_dtype=context.variable_dtype,
                                         reduced_dims=[state.shape.dims[-1]],
                                         new_dims=new_dims,
                                         activation=mtf.sigmoid,
                                         use_bias=True)
                    if self.act_type == "global":
                        # average over all positions (as a global halting prob)
                        p = mtf.reduce_mean(p, reduced_dim=p.shape.dims[1])
                        p = mtf.squeeze(p)
                    else:
                        # maintain position-wise probabilities
                        new_shape = p.shape.dims[:-1]
                        p = mtf.reshape(p, new_shape)
            # Mask for inputs which have not halted yet
            still_running = mtf.cast(mtf.less(halting_probability, 1.0),
                                     context.activation_dtype)

            # Mask of inputs which halted at this step
            new_halted = mtf.cast(
                mtf.greater(halting_probability + p * still_running,
                            threshold),
                context.activation_dtype) * still_running
            # Mask of inputs which haven't halted, and didn't halt this step
            still_running = mtf.cast(
                mtf.less_equal(halting_probability + p * still_running,
                               threshold),
                context.activation_dtype) * still_running

            # Add the halting probability for this step to the halting
            # probabilities for those input which haven't halted yet
            halting_probability += p * still_running

            # Compute remainders for the inputs which halted at this step
            remainders += new_halted * (1 - halting_probability)

            # Add the remainders to those inputs which halted at this step
            halting_probability += new_halted * remainders

            # Increment n_updates for all inputs which are still running
            n_updates += still_running + new_halted

            # Compute the weight to be applied to the new state and output
            # 0 when the input has already halted
            # p when the input hasn't halted yet
            # the remainders when it halted this step
            input_tensor = p * still_running + new_halted * remainders
            update_weights = input_tensor

            # apply transformation on the state
            transformed_state = state

            for _ in range(self.num_inrecurrence_layers):
                transformed_state = self.vanilla_transformer_layer(
                    context, transformed_state, mask)

            # update running part in the weighted state and keep the rest
            new_state = ((transformed_state * update_weights) +
                         (previous_state * (1 - update_weights)))

            if self.act_type == "accumulated":
                # Add in the weighted state
                new_state = (transformed_state *
                             update_weights) + previous_state

            step += 1

            return (transformed_state, step, halting_probability, remainders,
                    n_updates, new_state)

        # Unroll the recurrence; each call feeds its outputs into the next.
        for _ in range(act_max_steps + 1):
            (state, step, halting_probability, remainders, n_updates,
             previous_state) = ut_function(state, step, halting_probability,
                                           remainders, n_updates,
                                           previous_state)
        ponder_times = n_updates

        mtf.scalar_summary("ponder_times", mtf.reduce_mean(ponder_times))
        # previous_state now holds the new_state returned by the last step.
        return previous_state
예제 #7
0
def sample_autoregressive(
    partial_sequences,
    other_features,
    params,
    stop_at_token=50256,
    max_steps=None,
    temperature=0.9,
    variable_dtype=mtf.VariableDType(tf.float32),
    encoder_output=None,
    encoder_sequence_id=None,
    encoder_inputs=None,
    shared_params=None,
    has_partial_sequences=True,
    encoder_layer_outputs=None,
    never_end=False,
    remove_partial_sequences=False,
    sampling_keep_top_k=-1,
    bos_id=50256,
):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued.  The
    first tokens of each sequence are non-padding tokens representing the given
    partial sequences and the last tokens of each sequence are padding,
    representing what needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
        partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
        other_features: a dictionary forwarded to gpt2.model; its "vocab_dim"
            entry is used as the vocabulary dimension of the logits.
        params: a dictionary forwarded to gpt2.model.  This function reads
            "padding_id" (default 0) and "slow_sampling" (default False).
            With slow_sampling set, no decoding Context is built and the
            model is called with context=None on the full ids at each step.
        stop_at_token: an optional integer eos id.  Stop when we produce it.
        max_steps: an optional integer, the max number of steps to decode.
        temperature: an optional floating point value between 0.0 and 1.0 0.0
            means argmax, 1.0 means sample according to predicted distribution.
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        shared_params: an optional dictionary
        has_partial_sequences: a boolean
        encoder_layer_outputs: optional - readonly list of tensor activations
            when decoding, one per each input layer + the embedding layer
        never_end: a boolean - if set (and stop_at_token is not None), a large
            negative value is added to the stop_at_token logit so that it is
            not sampled.
        remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
        sampling_keep_top_k: an integer - if not -1, only sample from the top k
            logits.  The special value -2 selects k as 10% of the size of the
            last logits dimension.
        bos_id: beginning of sequence id.  Accepted for interface
            compatibility; not read by this function.

    Returns:
        a Tensor with shape [<batch_dims>, length_dim]

    Raises:
        ValueError: if sampling_keep_top_k resolves to a value <= 0 other
            than -1.
    """

    inputs = partial_sequences  # Partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    # Number of non-padding tokens per sequence, i.e. the position where
    # padding starts and where decoding begins.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)),
        reduced_dim=length_dim)

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    input_full_attention = True  # for now hardcode this to true bc lazy
    if input_full_attention:
        # Positions up to and including initial_position get priority 0;
        # later positions keep their index as priority.
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        read_priority = write_priority = length_range

    # Builds context to pass around internally.
    # The 'first part' context records initial states of k / v / x.
    if not slow_sampling:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        # The logits of this pass are discarded; it is run for the states it
        # records on context_first_part.
        with tf.variable_scope("gpt2"):
            gpt2.model({"inputs": inputs},
                       other_features,
                       params,
                       inputs.mesh,
                       variable_dtype=variable_dtype,
                       context=context_first_part)

        if has_partial_sequences:
            initial_states = context_first_part.new_states
        else:
            # Nothing to condition on: start from all-zero states.
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
    else:
        initial_states = []

    # Only read by cond_fn when stop_at_token is not None, in which case it
    # holds the number of stop tokens already present in the prompt.
    partial_sequences_eos_count = 0
    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            # A sequence is finished once it holds more stop tokens than its
            # prompt did.
            has_additional_eos = mtf.greater(eos_count,
                                             partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        context = None
        if not slow_sampling:
            context = mtf_transformer.transformer.Context(
                model=None,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                position_is_default=True,
                states=states,
                new_states=[],
                initial_position=position,
                sequence_id=None,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=read_priority,
                inputs=ids,
                encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)

        vocab_dim = other_features["vocab_dim"]

        if never_end and stop_at_token is not None:
            # Push the stop token's logit far down so it is never sampled.
            logits += mtf.one_hot(mtf.constant(logits.mesh,
                                               stop_at_token,
                                               dtype=tf.int32),
                                  vocab_dim,
                                  on_value=-1e9,
                                  off_value=0.0,
                                  dtype=logits.dtype)

        # Resolve the top-k setting into a local instead of rebinding the
        # enclosing function's argument.
        keep_top_k = sampling_keep_top_k
        if keep_top_k == -2:
            # Sentinel: keep 10% of the last logits dimension.
            keep_top_k = int(logits.shape[-1].size * 0.1)

        if keep_top_k != -1:
            if keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(logits,
                                                n=keep_top_k,
                                                reduced_dim=vocab_dim)
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(logits, vocab_dim,
                                                    temperature)

        if slow_sampling:
            ids_this_step = mtf.shift(ids_this_step,
                                      offset=1,
                                      dim=length_dim,
                                      wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, (batch_dims))

        # Overwrite only the slot at `position` with the freshly sampled id.
        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    outputs = mtf.while_loop(cond_fn, body_fn, while_loop_inputs)[1]
    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs.  The prompt length is the
        # count of non-padding tokens, which is exactly initial_position.
        outputs = mtf.dynamic_shift(outputs,
                                    -initial_position,
                                    length_dim,
                                    wrap=False)
    return outputs
    def beam_search(self,
                    inputs,
                    decode_length,
                    dst_attributes=None,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    encoder_output=None,
                    encoder_sequence_id=None,
                    encoder_inputs=None,
                    alpha=0.6,
                    shared_params=None,
                    encoder_layer_outputs=None,
                    z=None):
        """Run beam search and return the beam at index 0.

        Args:
          inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim,
            length_dim].
          decode_length: an int32 mtf scalar.  Maximum decode length.
          dst_attributes: an int32 zero-Tensor with shape
            [<batch_dims>, beam_dim, length_dim] ([<batch_dims>]
            [<batch_dims>, beam_dim]).  Forwarded to self._call_internal as
            `attributes`; gathered per step along length_dim when
            self.attribute_embedding is set.
          variable_dtype: a mtf.VariableDType
          encoder_output: an optional Tensor
          encoder_sequence_id: an optional Tensor
          encoder_inputs: an optional Tensor
          alpha: a floating point value (length bonus)
          shared_params: an optional dictionary
          encoder_layer_outputs: optional - readonly list of tensor activations
            when decoding, one per each input layer + the embedding layer
          z: forwarded unchanged to self._call_internal.

        Returns:
          The result of gathering index 0 along beam_dim from the beams
          produced by mtf.beam_search.beam_search (presumably shaped
          [<batch_dims>, length_dim] - confirm against mtf.gather).

        Raises:
          ValueError: if the model is not autoregressive.
          NotImplementedError: if there is not exactly one batch dimension, or
            if self.input_full_attention is set.
        """
        if not self.autoregressive:
            raise ValueError("must be autoregressive")

        *leading_dims, beam_dim, length_dim = inputs.shape.dims
        batch_dims = leading_dims
        if len(batch_dims) != 1:
            raise NotImplementedError(
                "beam search supports exactly one batch dimension.")

        length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
        nonzero_mask = mtf.to_int32(mtf.not_equal(inputs, 0))
        initial_position = mtf.reduce_sum(nonzero_mask,
                                          reduced_dim=length_dim)
        sequence_id = None if encoder_sequence_id is None else 1

        if self.input_full_attention:
            # This only makes sense in the case of beam search with given
            # partial sequences, which is not yet implemented.
            # TODO(noam): implement
            raise NotImplementedError(
                "Beam search for language models not yet implemented")
        read_priority = write_priority = length_range

        def make_context(**mode_specific):
            """Build a Context from the settings shared by both passes."""
            return Context(model=self,
                           mesh=inputs.mesh,
                           batch_dims=batch_dims + [beam_dim],
                           length_dim=length_dim,
                           variable_dtype=variable_dtype,
                           new_states=[],
                           sequence_id=sequence_id,
                           encoder_output=encoder_output,
                           encoder_sequence_id=encoder_sequence_id,
                           shared_params=shared_params,
                           encoder_layer_outputs=encoder_layer_outputs,
                           write_priority=write_priority,
                           encoder_inputs=encoder_inputs,
                           **mode_specific)

        first_part_context = make_context(mode="first_part",
                                          position=length_range,
                                          position_is_default=True,
                                          initial_position=initial_position,
                                          constant_states=[],
                                          read_priority=read_priority,
                                          inputs=inputs)

        shifted_inputs = mtf.shift(inputs,
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
        # The logits of this pass are not needed; it is run for the states it
        # records on the context.
        with tf.variable_scope(self.name):
            self._call_internal(first_part_context,
                                shifted_inputs,
                                attributes=dst_attributes,
                                z=z)

        # There are no partial targets.
        # Replace initial states by zeros to avoid computing them.
        zero_states = [
            mtf.zeros_like(state) for state in first_part_context.new_states
        ]
        constant_states = first_part_context.constant_states

        def logits_fn(step_num, ids, states):
            """logits_fn for mtf.beam_search.beam_search()."""
            previous_step = step_num - 1
            inputs_this_step = mtf.gather(ids, previous_step, length_dim)
            attributes_this_step = (mtf.gather(dst_attributes, step_num - 1,
                                               length_dim)
                                    if self.attribute_embedding else None)

            step_context = make_context(mode="incremental",
                                        position=step_num,
                                        states=states,
                                        constant_states=constant_states,
                                        read_priority=step_num,
                                        inputs=inputs_this_step)
            with tf.variable_scope(self.name, reuse=True):
                step_logits = self._call_internal(
                    step_context,
                    inputs_this_step,
                    attributes=attributes_this_step,
                    z=z)
            return mtf.to_float(step_logits), step_context.new_states

        beams, _ = mtf.beam_search.beam_search(logits_fn,
                                               inputs,
                                               alpha,
                                               states=zero_states,
                                               decode_length=decode_length,
                                               use_tpu=True,
                                               dtype=tf.float32,
                                               mesh_shape=self.mesh_shape,
                                               layout=self.layout)
        first_beam_index = mtf.constant(inputs.mesh, 0, dtype=tf.int32)
        return mtf.gather(beams, first_beam_index, beam_dim)
    def sample_autoregressive(self,
                              partial_sequences,
                              dst_attributes=None,
                              stop_at_token=1,
                              max_steps=None,
                              temperature=0.0,
                              variable_dtype=mtf.VariableDType(tf.float32),
                              encoder_output=None,
                              encoder_sequence_id=None,
                              encoder_inputs=None,
                              shared_params=None,
                              has_partial_sequences=True,
                              encoder_layer_outputs=None,
                              never_end=False,
                              remove_partial_sequences=False,
                              sampling_keep_top_k=-1,
                              z=None):
        """Sample randomly one token at a time.

        The partial_sequences represent partial sequences to be continued.  The
        first tokens of each sequence are nonzero representing the given partial
        sequences and the last tokens of each sequence are zeros, representing
        what needs to be filled in.

        If there are no partial sequences (you want to sample from the
        beginning), then pass
        partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
        has_partial_sequences=False (so we can skip computation).

        The dst_attributes represents the destination attributes in which we
        want to generate sequences.

        Args:
          partial_sequences: an int32 Tensor with shape
            [<batch_dims>, length_dim]
          dst_attributes: an int32 Tensor with shape
            [<batch_dims>, length_dim] ([<batch_dims>]).  Forwarded to
            self._call_internal as `attributes`; gathered per step along
            length_dim when self.attribute_embedding is set.
          stop_at_token: an optional integer eos id.  Stop when we produce it.
            May be None, in which case decoding only stops at the end of
            length_dim or after max_steps.
          max_steps: an optional integer, the max number of steps to decode.
          temperature: an optional floating point value between 0.0 and 1.0 0.0
            means argmax, 1.0 means sample according to predicted distribution.
          variable_dtype: a mtf.VariableDType
          encoder_output: an optional Tensor
          encoder_sequence_id: an optional Tensor
          encoder_inputs: an optional Tensor
          shared_params: an optional dictionary
          has_partial_sequences: a boolean
          encoder_layer_outputs: optional - readonly list of tensor activations
            when decoding, one per each input layer + the embedding layer
          never_end: a boolean - if set, then avoid generating stop_at_token
            (no effect when stop_at_token is None)
          remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
          sampling_keep_top_k: an integer - if not -1, only sample from the top
            k logits.
          z: forwarded unchanged to self._call_internal.

        Returns:
          a Tensor with shape [<batch_dims>, length_dim]

        Raises:
          ValueError: if the model is not autoregressive, or if
            sampling_keep_top_k is neither -1 nor positive.
        """
        if not self.autoregressive:
            raise ValueError("must be autoregressive")

        inputs = partial_sequences
        attributes = dst_attributes
        batch_dims = inputs.shape.dims[:-1]
        length_dim = inputs.shape.dims[-1]
        # Number of nonzero tokens per sequence, i.e. where decoding begins.
        initial_position = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
        sequence_id = 1 if encoder_sequence_id is not None else None

        length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
        if self.input_full_attention:
            # Positions up to and including initial_position get priority 0;
            # later positions keep their index as priority.
            read_priority = write_priority = length_range * mtf.to_int32(
                mtf.greater(length_range, initial_position))
        else:
            read_priority = write_priority = length_range

        context_first_part = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        shifted_inputs = mtf.shift(inputs,
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
        # The logits of this pass are not needed; it is run for the states it
        # records on context_first_part.
        with tf.variable_scope(self.name):
            self._call_internal(context_first_part,
                                shifted_inputs,
                                attributes=attributes,
                                z=z)
        constant_states = context_first_part.constant_states
        if has_partial_sequences:
            initial_states = context_first_part.new_states
        else:
            # Nothing to condition on: start from all-zero states.
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]

        # Count stop tokens already present in the prompt so they do not end
        # decoding.  Skipped when stop_at_token is None: the count is then
        # never read, and comparing a Tensor against None is not valid.
        partial_sequences_eos_count = 0
        if has_partial_sequences and stop_at_token is not None:
            partial_sequences_eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
                reduced_dim=length_dim)

        def cond_fn(position, ids, *unused_states):
            """Should we run another loop iteration."""
            past_end = mtf.greater_equal(position, length_dim.size)
            if max_steps:
                past_end = mtf.logical_or(
                    past_end,
                    mtf.greater_equal(position - initial_position, max_steps))

            is_done = past_end
            if stop_at_token is not None:
                eos_count = mtf.reduce_sum(
                    mtf.to_int32(mtf.equal(ids, stop_at_token)),
                    reduced_dim=length_dim)
                # A sequence is finished once it holds more stop tokens than
                # its prompt did.
                has_additional_eos = mtf.greater(eos_count,
                                                 partial_sequences_eos_count)
                is_done = mtf.logical_or(is_done, has_additional_eos)
            all_done = mtf.reduce_all(is_done)
            return mtf.logical_not(all_done)

        def body_fn(position, ids, *states):
            """One step in the decode loop."""
            # The model input for this step is the token at the previous
            # position.
            inputs_this_step = mtf.gather(ids, position - 1, length_dim)
            if self.attribute_embedding:
                attributes_this_step = mtf.gather(attributes, position - 1,
                                                  length_dim)
            else:
                attributes_this_step = None
            context_incremental = Context(
                model=self,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                states=states,
                new_states=[],
                sequence_id=sequence_id,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                constant_states=constant_states,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=position,
                inputs=inputs_this_step,
                encoder_inputs=encoder_inputs)

            with tf.variable_scope(self.name, reuse=True):
                logits = self._call_internal(context_incremental,
                                             inputs_this_step,
                                             attributes=attributes_this_step,
                                             z=z)
                if never_end and stop_at_token is not None:
                    # Push the stop token's logit far down so it is never
                    # sampled.
                    logits += mtf.one_hot(mtf.constant(logits.mesh,
                                                       stop_at_token,
                                                       dtype=tf.int32),
                                          self.output_vocab_dim,
                                          on_value=-1e9,
                                          off_value=0.0,
                                          dtype=logits.dtype)

            # TBD whether this should be before or after never_end:
            # Note for adding top_p sampling in the future, in other code
            # bases, the option to apply temperature is done before the top-k
            # truncation. This implementation does this in the opposite order.
            # For top-k this doesn't matter, but for top_p it will.
            if sampling_keep_top_k != -1:
                if sampling_keep_top_k <= 0:
                    raise ValueError(
                        "sampling_keep_top_k must either be -1 or positive.")
                k_largest = mtf.nth_largest_element(
                    logits,
                    n=sampling_keep_top_k,
                    reduced_dim=self.output_vocab_dim)
                logits = mtf.where(mtf.less_equal(logits, k_largest),
                                   mtf.ones_like(logits) * -1e6, logits)

            ids_this_step = mtf.sample_with_temperature(
                logits, self.output_vocab_dim, temperature)
            new_position = position + 1
            # Add the sampled id into the slot at `position`.
            new_ids = ids + ids_this_step * mtf.one_hot(
                position, length_dim, dtype=tf.int32)
            return [new_position, new_ids] + context_incremental.new_states

        while_loop_inputs = [initial_position, inputs] + initial_states
        outputs = mtf.while_loop(cond_fn, body_fn, while_loop_inputs)[1]
        if has_partial_sequences and remove_partial_sequences:
            # Remove partial sequences from outputs.  The prompt length is the
            # count of nonzero tokens, which is exactly initial_position.
            outputs = mtf.dynamic_shift(outputs,
                                        -initial_position,
                                        length_dim,
                                        wrap=False)
        return outputs