def layer_prepostprocess_dropout(x, hparams):
  """Applies dropout whose mask varies only over the first and last dims of x.

  The noise shape is restricted to [x.shape.dims[0], x.shape.dims[-1]], so the
  same mask is broadcast over every dimension in between.

  Args:
    x: an mtf.Tensor; its first dimension is treated as the batch dimension
      and its last as the model dimension.
    hparams: object exposing a float `layer_prepostprocess_dropout` (the drop
      rate; keep_prob is 1.0 minus this value).

  Returns:
    an mtf.Tensor with the same shape as `x`.
  """
  dims = x.shape.dims
  batch_dim, model_dim = dims[0], dims[-1]
  keep_prob = 1.0 - hparams.layer_prepostprocess_dropout
  noise_shape = mtf.Shape([batch_dim, model_dim])
  return mtf.dropout(x, keep_prob=keep_prob, noise_shape=noise_shape)
Exemple #2
0
    def __init__(self,
                 config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None,
                 scope=None,
                 mesh_shape="",
                 layout=""):
        """Builds the BERT-style encoder graph for `input_ids`.

        Creates the embedding, encoder and pooler variables under `scope`
        (default name "bert") and stores the resulting tensors as attributes:
        `embedding_table`, `word_embedding_output`, `embedding_output`,
        `all_encoder_layers`, `sequence_output` and `pooled_output`.

        Args:
          config: model configuration. It is deep-copied, so disabling dropout
            below does not modify the caller's object.
          is_training: bool. When False, every dropout probability on the
            copied config is set to 0.0.
          input_ids: int mtf.Tensor with exactly two dimensions. The second
            one is used as the sequence dimension (presumably
            [batch, seq] — confirm against callers).
          input_mask: optional mtf.Tensor that must contain the sequence
            dimension. Positions where it is 0 get a -10000.0 additive
            attention bias (assumes 0/1 values).
          token_type_ids: optional int mtf.Tensor. Defaults to all zeros with
            the shape of `input_ids`.
          scope: optional variable scope name (default "bert").
          mesh_shape: passed through to `self.moe` for "moe" layers.
          layout: passed through to `self.moe` for "moe" layers.

        Raises:
          AssertionError: if `input_ids` does not have exactly two dims.
          ValueError: if `config.block_layers` contains an unknown layer type.
        """
        self.config = copy.deepcopy(config)
        del config
        if not is_training:
            # Disable all dropout outside training; only the deep copy above
            # is modified.
            self.config.layer_output_dropout_prob = 0.0
            self.config.attention_probs_dropout_prob = 0.0
            self.config.feedforward_intermediate_dropout_prob = 0.0
        input_shape = input_ids.shape
        assert input_shape.ndims == 2

        self._seq_dim = input_shape.dims[1]
        # Same size as the sequence dim but a distinct name, so attention
        # tensors can carry both a query axis and a key/value ("memory") axis.
        self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
        self._extra_losses = []
        mesh = input_ids.mesh

        if token_type_ids is None:
            token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

        with tf.variable_scope(scope, default_name="bert"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                self.embedding_table = mtf.get_variable(
                    mesh,
                    "word_embeddings",
                    mtf.Shape([self.vocab_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                self.word_embedding_output = mtf.gather(
                    self.embedding_table, input_ids, self.vocab_dim)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.embedding_output = self.word_embedding_output

                token_type_table = mtf.get_variable(
                    mesh,
                    "token_type_embeddings",
                    mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                # NOTE(review): always true here, because token_type_ids was
                # defaulted to zeros above.
                if token_type_ids is not None:
                    self.embedding_output += mtf.gather(
                        token_type_table, token_type_ids,
                        self.token_type_vocab_dim)
                if self.config.position_signal == "embedding":
                    # Learned absolute positions. The table is sized for the
                    # maximum length; only its first seq_dim.size rows are used,
                    # renamed to the sequence dim so they broadcast over the
                    # remaining dims of embedding_output.
                    full_position_table = mtf.get_variable(
                        mesh,
                        "position_embeddings",
                        mtf.Shape(
                            [self.max_position_embeddings_dim,
                             self.model_dim]),
                        initializer=self.embedding_initializer)
                    short_position_table = mtf.rename_dimension(
                        mtf.slice(full_position_table, 0, self.seq_dim.size,
                                  self.max_position_embeddings_dim.name),
                        self.max_position_embeddings_dim.name,
                        self.seq_dim.name)
                    self.embedding_output += short_position_table
                self.embedding_output = self.normalize(self.embedding_output)
                self.embedding_output = mtf.dropout(
                    self.embedding_output,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob)

            with tf.variable_scope("encoder"):
                # Additive attention-logit biases; summed with add_n below.
                attention_biases = []
                # NOTE(review): truthiness test on a tensor; presumably meant as
                # `input_mask is not None` — confirm.
                if input_mask:
                    # [batch_dim, memory_seq_dim]
                    # 0.0 where the mask is 1 and -10000.0 where it is 0, so
                    # masked keys get (near) zero weight after the softmax.
                    attention_biases.append((1.0 - mtf.to_float(
                        mtf.replace_dimensions(input_mask, self.seq_dim,
                                               self.memory_seq_dim))) *
                                            -10000.0)
                if self.config.position_signal == "relative_attention_bias":
                    # Learned per-head scalar bias, indexed by the
                    # (memory position - query position) offset after bucketing
                    # by _relative_position_bucket into 32 buckets.
                    buckets_dim = mtf.Dimension("buckets", 32)
                    rp_bucket = _relative_position_bucket(
                        mtf.range(mesh, self.memory_seq_dim, tf.int32) -
                        mtf.range(mesh, self.seq_dim, tf.int32),
                        num_buckets=buckets_dim.size)
                    bias_var = mtf.get_variable(
                        mesh,
                        "relative_attention_bias",
                        [self.num_heads_dim, buckets_dim],
                        initializer=tf.zeros_initializer())
                    attention_biases.append(
                        mtf.gather(bias_var, rp_bucket, buckets_dim))
                # NOTE(review): with no mask and no relative bias this list is
                # empty — confirm mtf.add_n accepts an empty list.
                attention_bias = mtf.add_n(attention_biases)
                prev_layer_output = self.embedding_output
                self.all_encoder_layers = []
                for block_num in range(self.config.num_blocks):
                    with tf.variable_scope("block_%d" % block_num):
                        for layer_idx, layer_type in enumerate(
                                self.config.block_layers):
                            # Repeated layer types within a block get a numeric
                            # suffix ("attention", "attention_1", ...) so that
                            # each one has its own variable scope.
                            layer_name = layer_type
                            count = self.config.block_layers[:layer_idx].count(
                                layer_type)
                            if count:
                                layer_name += "_%d" % count
                            with tf.variable_scope(layer_name):
                                x = prev_layer_output
                                # "direct": normalize the sublayer input
                                # (pre-norm). "original": normalize after the
                                # residual add (post-norm), see below.
                                if self.config.residual_structure == "direct":
                                    x = self.normalize(x)
                                if layer_type == "attention":
                                    x = self.self_attention(x, attention_bias)
                                elif layer_type == "feedforward":
                                    x = self.feedforward(x)
                                elif layer_type == "moe":
                                    x = self.moe(x, layout, mesh_shape,
                                                 input_mask, is_training)
                                else:
                                    raise ValueError("unknown layer type " +
                                                     layer_type)
                                x = mtf.dropout(
                                    x,
                                    keep_prob=1.0 -
                                    self.config.layer_output_dropout_prob)
                                # Residual connection.
                                layer_output = prev_layer_output + x
                                if self.config.residual_structure == "original":
                                    layer_output = self.normalize(layer_output)
                                prev_layer_output = layer_output
                    # Records the output of the last layer of each block.
                    # NOTE(review): layer_output is unbound here if
                    # config.block_layers is empty.
                    self.all_encoder_layers.append(layer_output)

            self.sequence_output = prev_layer_output
            # With pre-norm ("direct"), the final residual stream has not been
            # normalized yet, so it is normalized once here.
            if self.config.residual_structure == "direct":
                self.sequence_output = self.normalize(self.sequence_output)

            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_dim, seq_dim, hidden_size] to a tensor of shape
            # [batch_dim, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state
                # corresponding to the first token (index 0 along seq_dim). We
                # assume that this has been pre-trained to summarize the
                # segment.
                first_token_tensor = mtf.gather(self.sequence_output, 0,
                                                self.seq_dim)
                self.pooled_output = mtf.layers.dense(
                    first_token_tensor,
                    reduced_dims=[self.model_dim],
                    new_dims=[self.model_dim],
                    activation=mtf.tanh,
                    kernel_initializer=self.dense_initializer,
                    use_bias=self.config.use_bias)
Exemple #3
0
def transformer_moe_layer_v1(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None,
                             activation=mtf.relu):
    """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_gating: a string (only "top_2" is supported)
    hparams.moe_dropout_rate: dropout rate on the expert hidden layer
      (applied only when train is True)
    hparams.moe_loss_coef: multiplier applied to the returned auxiliary loss
    + all hyperparameters used by _top_2_gating()

  The number of parameters in the gating network is approximately:
    (input_dim.size * hparams.moe_num_experts)
  (the gating variables are created in _top_2_gating; see there for details).

  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences. If the
  first dimension is named "outer_batch", it is treated as expert replication.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each group sends the same number of elements (expert_capacity) to
    each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  B: batch dim(s)
  L: original sequence length
  M: input depth
  N: output depth
  G: number of groups
  S: group size
  E: number of experts
  C: expert capacity

  Args:
    inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).
    activation: a function.

  Returns:
    outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
    loss: a mtf scalar, already multiplied by hparams.moe_loss_coef

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    # pylint: disable=line-too-long
    #
    # The outer_batch (O) dimension can be used for expert replication, e.g.
    # outer_batch=4 for placing 128 experts on 512 cores with 4 replicas of each
    # expert.
    #
    # E.g. 16x16 basic example:
    #   moe_num_experts=512, num_groups=1024, batch=4096, length=256, d_model=1024
    # ---
    # Below ` indicates common way of splitting along mesh dimension.
    #
    # orig_inputs      OB`LM Tensor
    #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
    #                  v (reshaped)
    # inputs           OG`SM
    #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
    #
    # combine_tensor,
    # dispatch_tensor  OG`SEC
    #                  Shape[outer_batch=1, batch=1024, group=1024, expert_unsplit=512, expert_capacity=4]
    #
    # (dispatched inputs)
    # expert_inputs    OEG`CM
    #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
    #                  v (re-split via ReshapeOperation)
    #                  OE`GCM
    #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
    #
    # (hidden representation)
    # h                OE`GCH
    #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, expert_hidden=8192]
    #
    # expert_output    OE`GCM
    #                  Shape[outer_batch=1, experts=512, batch_unsplit=1024, expert_capacity=4, d_model=1024]
    #                  v (re-split via ReshapeOperation)
    #                  OEG`CM
    #                  Shape[outer_batch=1, expert_unsplit=512, batch=1024, expert_capacity=4, d_model=1024]
    #
    # (combined expert_output)
    # output           OG`SM
    #                  Shape[outer_batch=1, batch=1024, group=1024, d_model=1024]
    #                  v (reshape)
    #                  OB`LM
    #                  Shape[outer_batch=1, batch=4096, length=256, d_model=1024]
    #
    # pylint: enable=line-too-long
    orig_inputs = inputs
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups is a multiple of the mesh dimension
    # over which those groups are split.
    batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                        orig_inputs.shape.dims[-1])
    # Hack: we assume that
    #   "outer_batch" == replication of experts
    #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
    #
    # We then require num_groups to be a multiple of mesh_dim_size.
    if orig_inputs.shape.dims[0].name == "outer_batch":
        outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
    else:
        # No replication: use a size-1 outer_batch dim so the shapes below are
        # uniform.
        outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                           orig_inputs.shape.dims[0])

    # Number of MoE inputs (total number of positions across
    # batch_and_length_dims) per replica.
    n = 1
    for d in batch_and_length_dims:
        n *= d.size

    n = n // outer_batch_dim.size

    mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                    orig_batch_dim)
    num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
                                                mesh_dim_size)

    group_size_dim = mtf.Dimension("group", group_size)
    # The groups dim reuses the original batch dim's name, so it keeps that
    # dim's mesh split under the layout rules.
    num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

    moe_input_dims = [
        outer_batch_dim, num_groups_dim, group_size_dim, input_dim
    ]
    # OGSM Tensor
    inputs = mtf.reshape(inputs, moe_input_dims)

    # Each group sends expert_capacity positions to each expert:
    # group_size * capacity_factor / num_experts, truncated and capped at
    # group_size.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    # Renamed copies of the experts / groups dims. Reshaping between the two
    # names moves the mesh split from groups to experts and back (see the
    # "re-split via ReshapeOperation" steps in the cheat sheet above).
    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
    if nonpadding is not None:
        # Broadcast nonpadding to the full batch/length shape (by adding it to
        # zeros), then regroup it like `inputs` minus the depth dim.
        nonpadding = mtf.zeros(inputs.mesh,
                               batch_and_length_dims,
                               dtype=inputs.dtype) + nonpadding
        nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
    if hparams.moe_gating == "top_2":
        # combine_tensor,
        # dispatch_tensor  OG`SEC Tensors
        # (G is generally split along mesh dim)
        dispatch_tensor, combine_tensor, loss = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Gather each expert's inputs via an einsum with the dispatch tensor
    # (sums over the group-position dim S).
    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([
                                   outer_batch_dim, experts_dim_unsplit,
                                   num_groups_dim, expert_capacity_dim,
                                   input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([
            outer_batch_dim, experts_dim, batch_dim_unsplit,
            expert_capacity_dim, input_dim
        ]))

    # Now feed the expert inputs through the experts.
    h = mtf.layers.dense_product(expert_inputs,
                                 reduced_dims=expert_inputs.shape.dims[-1:],
                                 new_dims=[hidden_dim],
                                 expert_dims=[experts_dim],
                                 activation_functions=activation,
                                 use_bias=False,
                                 variable_dtype=variable_dtype,
                                 name="wi")

    if train and hparams.moe_dropout_rate != 0.0:
        h = mtf.dropout(h, 1.0 - hparams.moe_dropout_rate)

    expert_output = mtf.layers.dense(h,
                                     output_dim,
                                     expert_dims=[experts_dim],
                                     use_bias=False,
                                     reduced_dims=h.shape.dims[-1:],
                                     variable_dtype=variable_dtype,
                                     name="wo")

    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))

    # Scatter expert outputs back to their group positions, weighted by the
    # combine tensor (sums over experts and capacity slots).
    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])

    return output, loss * hparams.moe_loss_coef
Exemple #4
0
def hybrid_attention(q,
                     k,
                     v,
                     context,
                     memory_length_dim,
                     key_dim,
                     value_dim,
                     bias=None,
                     dropout_rate=0.0,
                     dropout_broadcast_dims=None,
                     extra_logit=None):
    """Blend of ordinary and doubly-normalized dot-product attention.

    The q.k logits are turned into two weight tensors:
      * standard: softmax over `memory_length_dim`;
      * doubly normalized: log-softmax over a query axis named "length" (sized
        like `memory_length_dim`), followed by a softmax over
        `memory_length_dim`.
    They are mixed by a learned scalar variable "doubly_coeff" (initialized to
    0.5 and clipped to [0, 1]). The result is:
    coeff * doubly + (1 - coeff) * standard.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims.
    Typically other_query_dims={batch, heads, length} and
    other_memory_dims={batch, heads}.

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      context: attention-layer context; its `mesh` and `variable_dtype` are
        used to create the mixing variable.
      memory_length_dim: a Dimension over the key/value pairs
      key_dim: a Dimension over query/key channels
      value_dim: a Dimension over value channels
      bias: optional Tensor added to the attention logits
      dropout_rate: a float
      dropout_broadcast_dims: optional list of mtf.Dimension excluded from the
        dropout noise shape
      extra_logit: an optional scalar or tensor, passed to every softmax

    Returns:
      Tensor with shape q.shape - key_dim + value_dim
    """
    scores = mtf.einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        scores += bias

    # Query axis for the column-wise normalization; assumed to be named
    # "length" and to be as long as the memory axis.
    length_dim = mtf.Dimension("length", memory_length_dim.size)

    # Learned mixing weight, clipped into [0, 1].
    mix = mtf.get_variable(context.mesh,
                           "doubly_coeff", [],
                           initializer=tf.constant_initializer(0.5),
                           dtype=context.variable_dtype)
    mix = mtf.maximum(mtf.minimum(mix, 1.), 0.)

    standard = mtf.softmax(scores,
                           memory_length_dim,
                           extra_logit=extra_logit)

    per_query_log = mtf.log_softmax(scores,
                                    length_dim,
                                    extra_logit=extra_logit)
    doubly = mtf.softmax(per_query_log,
                         memory_length_dim,
                         extra_logit=extra_logit)

    weights = mix * doubly + (1. - mix) * standard
    if dropout_rate != 0.0:
        kept_fraction = 1.0 - dropout_rate
        weights = mtf.dropout(
            weights,
            kept_fraction,
            noise_shape=weights.shape - dropout_broadcast_dims)
    return mtf.einsum([weights, v], q.shape - key_dim + value_dim)
Exemple #5
0
def attention(q,
              k,
              v,
              memory_length_dim,
              key_dim,
              value_dim,
              bias=None,
              dropout_rate=0.0,
              dropout_broadcast_dims=None,
              extra_logit=None,
              context=None,
              float32_logits=True):
    """Dot-product attention over named dimensions (no positional dims used).

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims.
    Typically other_query_dims={batch, heads, length} and
    other_memory_dims={batch, heads}.

    The inputs may first be reshaped by
    maybe_reshape_attention_input_for_2d_sharding. The result is reshaped back
    so that it always matches the original q.

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      memory_length_dim: a Dimension over the key/value pairs
      key_dim: a Dimension over query/key channels
      value_dim: a Dimension over value channels
      bias: optional Tensor added to the attention logits (cast to the logits'
        dtype)
      dropout_rate: a float
      dropout_broadcast_dims: optional list of mtf.Dimension excluded from the
        dropout noise shape
      extra_logit: an optional scalar or tensor passed to the softmax
      context: an optional Transformer.Context
      float32_logits: if True, q and k are cast to float32 before the logits
        are computed, to avoid bfloat16 numerical issues

    Returns:
      Tensor with shape q.shape - key_dim + value_dim
    """
    original_q_shape = q.shape
    q, k, v, bias = maybe_reshape_attention_input_for_2d_sharding(
        context, q, k, v, bias, [key_dim, value_dim])
    if float32_logits:
        k = mtf.cast(k, tf.float32)
        q = mtf.cast(q, tf.float32)
    scores = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        scores += mtf.cast(bias, scores.dtype)
    # Normalize over the memory axis, then return to the value dtype.
    probs = mtf.cast(
        mtf.softmax(scores, memory_length_dim, extra_logit=extra_logit),
        v.dtype)
    if dropout_rate != 0.0:
        probs = mtf.dropout(probs,
                            1.0 - dropout_rate,
                            noise_shape=probs.shape - dropout_broadcast_dims)
    attended = mtf.einsum([probs, v], q.shape - key_dim + value_dim)
    return mtf.reshape(attended, original_q_shape - key_dim + value_dim)
Exemple #6
0
def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
    """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

    Replaces or augments the query-key dot product with attention logits that
    are either learned directly ("random" variants) or predicted from the
    queries alone ("dense" variants).

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Supported `synthesize_mode` values (when `synthesize` is True):
      "random": logits are a learned "R" table of shape
        [length, heads, memory_length], each of size `max_length`. The values
        of q and k are not used.
      "factorized": logits come from two learned tables "R1"/"R2" with an
        inner "tmp" dimension of size `factorized_dim` (table length fixed at
        512). The values of q and k are not used.
      "dense_minus": logits are a bias-free dense projection ("pi") of
        relu(q) onto a memory_length axis of size `max_length`.
      "random_plus" / "random_plus_alpha": q.k logits plus a learned "R" table
        (table length fixed at 512).
      "dense_plus" / "dense_plus_alpha": q.k logits plus a dense projection of
        relu(q) (memory_length fixed at 512).
      In the "_alpha" variants the two terms are mixed as
      (1 - a) * qk + a * synthetic, with a = sigmoid(learned scalar "alpha").

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor
      synthesize: flag to use synthetic attention. If False, plain dot-product
        attention is computed; synthesize_mode, factorized_dim, max_length and
        context are then unused.
      synthesize_mode: which variant of synthesizer to use (see above)
      factorized_dim: inner dimension of the "factorized" variant
      max_length: max length of input sequence ("random" and "dense_minus")
      context: Transformer context. It is required when synthesize is True
        (its mesh, variable_dtype, mode and position are read).

    Returns:
      Tensor with shape q.shape - key_dim + value_dim

    Raises:
      ValueError: if synthesize is True and synthesize_mode is not supported.
    """

    def crop_random_table(r):
        # Crop a learned [length, heads, memory_length] table to the current
        # memory length. Then either pick the row of the current decoding
        # position (incremental mode) or crop the query length too.
        r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
        if context.mode == "incremental":
            return mtf.gather(r, context.position,
                              r.shape.get_dim_by_name("length"))
        length_dim = q.shape.get_dim_by_name("length")
        return mtf.slice(r, 0, length_dim.size, "length")

    def crop_dense_logits(r):
        # Dense logits are computed from q. Only the memory axis (and, outside
        # incremental mode, the query length axis) is cropped.
        r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
        if context.mode != "incremental":
            length_dim = q.shape.get_dim_by_name("length")
            r = mtf.slice(r, 0, length_dim.size, "length")
        return r

    def mix_in(qk_logits, synthetic):
        # "_alpha" modes: learned sigmoid gate; otherwise a plain sum.
        if "alpha" in synthesize_mode:
            alpha = mtf.get_variable(context.mesh,
                                     "alpha",
                                     mtf.Shape([mtf.Dimension("alpha", 1)]),
                                     initializer=tf.zeros_initializer(),
                                     dtype=context.variable_dtype)
            alpha = mtf.sigmoid(alpha)
            return ((1 - alpha) * qk_logits) + (alpha * synthetic)
        return qk_logits + synthetic

    if not synthesize:
        # Plain dot-product attention. Previously `logits` was never assigned
        # on this path, so synthesize=False raised UnboundLocalError.
        logits = mtf.einsum([q, k], reduced_dims=[key_dim])
    else:
        num_heads = v.shape.get_dim_by_name("heads")
        tf.logging.info("Using synthesizer")
        if synthesize_mode == "random":
            tf.logging.info("Using Random Synthesizers")
            r_shape = mtf.Shape([
                mtf.Dimension("length", max_length),
                mtf.Dimension("heads", num_heads.size),
                mtf.Dimension("memory_length", max_length)
            ])
            r = mtf.get_variable(context.mesh,
                                 "R",
                                 r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            logits = crop_random_table(r)
        elif synthesize_mode == "factorized":
            tf.logging.info("Using Factorized Random Synthesizers")
            # NOTE(review): R1 and R2 share all their dims, so the einsum
            # reduces only "tmp". "length" is not a dim of either input, so
            # the result does not vary with query position. This does not
            # match the R1 @ R2^T factorization in the paper — confirm. It is
            # left unchanged because fixing it changes variable shapes.
            r1_shape = mtf.Shape([
                mtf.Dimension("tmp", factorized_dim),
                mtf.Dimension("heads", num_heads.size),
                mtf.Dimension("memory_length", 512)
            ])
            r2_shape = mtf.Shape([
                mtf.Dimension("tmp", factorized_dim),
                mtf.Dimension("heads", num_heads.size),
                mtf.Dimension("memory_length", 512)
            ])
            r_shape = mtf.Shape([
                mtf.Dimension("length", 512),
                mtf.Dimension("heads", num_heads.size),
                mtf.Dimension("memory_length", 512)
            ])
            r1 = mtf.get_variable(context.mesh,
                                  "R1",
                                  r1_shape,
                                  initializer=None,
                                  dtype=context.variable_dtype)
            r2 = mtf.get_variable(context.mesh,
                                  "R2",
                                  r2_shape,
                                  initializer=None,
                                  dtype=context.variable_dtype)
            r = mtf.einsum([r1, r2], r_shape)
            logits = crop_random_table(r)
        elif synthesize_mode == "dense_minus":
            # Dense Synthesizer Model
            tmp_dim = mtf.Dimension("memory_length", max_length)
            logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                      use_bias=False,
                                      name="pi",
                                      reduced_dims=[key_dim],
                                      variable_dtype=None)
            logits = crop_dense_logits(logits)
        elif synthesize_mode in ("random_plus_alpha", "random_plus"):
            # Mixture Random Synthesizer with learnable Alpha
            tf.logging.info("Using Random Plus Alpha")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            num_heads = logits.shape.get_dim_by_name("heads")
            # Table length kept at 512 (not max_length) so that existing
            # variable shapes are unchanged.
            r_shape = mtf.Shape([
                mtf.Dimension("length", 512),
                mtf.Dimension("heads", num_heads.size),
                mtf.Dimension("memory_length", 512)
            ])
            r = mtf.get_variable(context.mesh,
                                 "R",
                                 r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            logits = mix_in(logits, crop_random_table(r))
        elif synthesize_mode in ("dense_plus_alpha", "dense_plus"):
            # Mixture Dense Synthesizer with learnable alpha
            tf.logging.info("Using Dense Plus Alpha Scaling")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            tmp_dim = mtf.Dimension("memory_length", 512)
            r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                 use_bias=False,
                                 name="pi",
                                 reduced_dims=[key_dim],
                                 variable_dtype=None)
            logits = mix_in(logits, crop_dense_logits(r))
        else:
            # Previously fell through and failed later on an unbound `logits`.
            raise ValueError("unknown synthesize_mode=%s" % synthesize_mode)

    if bias is not None:
        logits += bias

    weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
    if dropout_rate != 0.0:
        weights = mtf.dropout(weights,
                              1.0 - dropout_rate,
                              noise_shape=weights.shape -
                              dropout_broadcast_dims)

    if synthesize and "plus" not in synthesize_mode:
        # NOTE(review): assumes key_dim is the last dim of q.
        if synthesize_mode == "dense_minus":
            outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
        else:
            outputs_shape = mtf.Shape(q.shape.dims[:-1] +
                                      [num_heads, value_dim])
    else:
        outputs_shape = q.shape - [key_dim] + value_dim

    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
 def layer_prepostprocess_dropout(x):
   """Applies pre/post-process dropout to x, except during incremental decoding.

   The noise shape is limited to self.batch_dims + [self.model_dim], so one
   mask is broadcast over the other dims of x (presumably the length dim).
   Reads `is_incremental`, `hparams` and `self` from the enclosing scope.
   """
   if not is_incremental:
     keep_prob = 1.0 - hparams.layer_prepostprocess_dropout
     noise_shape = mtf.Shape(self.batch_dims + [self.model_dim])
     x = mtf.dropout(x, keep_prob=keep_prob, noise_shape=noise_shape)
   return x
Exemple #8
0
 def layer_prepostprocess_dropout(x):
   """Applies pre/post-process dropout to x.

   The noise shape is limited to self.batch_dims + [self.model_dim], so one
   mask is broadcast over the other dims of x. Reads `hparams` and `self` from
   the enclosing scope.
   """
   keep_prob = 1.0 - hparams.layer_prepostprocess_dropout
   noise_shape = mtf.Shape(self.batch_dims + [self.model_dim])
   return mtf.dropout(x, keep_prob=keep_prob, noise_shape=noise_shape)
Exemple #9
0
def model(mtf_features,
          other_features,
          params,
          mesh,
          variable_dtype,
          context=None):
    """A GPT style model implemented in mesh tensorflow.

    Embeds the input tokens (token + positional embeddings), runs them through
    ``params["n_layer"]`` transformer blocks, projects to the vocabulary and,
    in "train"/"eval" mode, computes the cross-entropy loss.

    Args:
        mtf_features: dict of mtf Tensors consumed by ``parse_inputs``; in
            "train"/"eval" mode it must also hold "labels".
        other_features: dict of auxiliary values; "attn_bias" and
            "memory_length_dim" are read here.
        params: model configuration dict.
        mesh: the mtf.Mesh on which variables are created.
        variable_dtype: mtf.VariableDType giving master/slice/activation dtypes.
        context: optional decoding context. When ``is_incremental_inference``
            is true only the token at ``context.position - 1`` is processed.

    Returns:
        (logits, loss, loss_batch): logits cast to
        ``variable_dtype.master_dtype``; loss and loss_batch are None outside
        "train"/"eval" mode.
    """

    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(
        mtf_features, other_features)

    if is_incremental_inference(context):
        # Incremental inference: keep only the current position's token and
        # drop the sequence dimension.
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = params["axial_pos_emb"] is not None

    if not use_axial_pos_emb:
        # Use standard (learned) position encoding
        wpe = mtf.get_variable(
            mesh,
            "wpe",
            mtf.Shape([embed_sequence_dim, embd_dim]),
            initializer=tf.random_normal_initializer(stddev=0.01),
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Token embedding table; also used as the output projection when the
    # embeddings are tied (final einsum below).
    wte = mtf.get_variable(
        mesh,
        "wte",
        mtf.Shape([vocab_dim, embd_dim]),
        initializer=tf.random_normal_initializer(stddev=0.02),
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h,
                            rate=params["embed_dropout"],
                            name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Positional embedding
        if is_incremental_inference(context):
            position_indices = context.position - 1
        else:
            position_indices = mtf.range(mesh, sequence_dim, tf.int64)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            # This dropout acts on the positional ("wpe") embedding.
            pos_emb = mtf.dropout(pos_emb,
                                  rate=params["embed_dropout"],
                                  name="wpe_dropout")
        h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MOE models)

    # Loop-invariant configuration, computed once rather than per layer.
    # A missing/None "share_parameters" counts as False.
    share_parameters = params.get("share_parameters") == True
    # Gradient checkpointing is only enabled in train mode.
    recompute_grad = bool(params["recompute_grad"]) and params["mode"] == "train"

    for layer in range(params["n_layer"]):
        # attn blocks; an empty scope is used when parameters are shared.
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params,
                         scope=block_scope,
                         layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         variable_dtype=variable_dtype,
                         context=context)

        if recompute_grad:
            h, loss = mtf.recompute_grad(block_fn, [h])
        else:
            h, loss = block_fn(h)
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        # NOTE(review): unlike the tied branch, this path does not apply the
        # final "ln_f" layer norm before projecting. Confirm this is intended;
        # adding it would change the checkpoint variable set.
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h,
                            "linear_out",
                            vocab_dim,
                            variable_dtype=variable_dtype,
                            params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(
            context) else mtf.Dimension("sequence", 1)
        with tf.variable_scope("wte_final_einsum"):
            # Project onto the tied token embedding table (equivalent to
            # tf.matmul with wte transposed).
            logits = mtf.einsum([h, wte],
                                output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ["train", "eval"]:
        labels = mtf_features["labels"]
        z_loss = params.get(
            "z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits,
                                 targets=labels,
                                 vocab_dim=logits.shape[-1],
                                 z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id),
                                   loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
Exemple #10
0
    def fn(x):
        """Apply one transformer block to ``x``.

        Closure over the enclosing builder's settings (``scope``, ``params``,
        ``layer_num``, ``bias``, ``context``, ...). It applies, in order:
          1. an optional macaron half-step MLP (``macaron_attention``);
          2. the attention sublayer for ``params["attention_types"][layer_num]``,
             skipped entirely when that type is "none";
          3. an MLP or mixture-of-experts sublayer.
        Each sublayer's output is added back to ``x`` as a residual.

        Returns:
            (x, aux_loss): the block output and the auxiliary loss. This is the
            MoE load-balancing loss when ``use_moe``, otherwise a scalar zero.
        """
        with tf.variable_scope(scope):
            nx = x.shape[-1]  # Grab last dimension from input

            # Pre-sublayer normalisation: none with ReZero, else ScaleNorm or
            # LayerNorm.
            if use_rezero:
                prenorm = identity
            elif use_scale_norm:
                prenorm = scale_norm
            else:
                prenorm = layer_norm

            # With ReZero, rezero() is applied to each residual branch before
            # it is added to x.
            pre_residual_fn = rezero if use_rezero else identity

            attention_type = params["attention_types"][layer_num]

            if macaron_attention:
                # Macaron block: half-weighted MLP before attention; ``mult``
                # also halves the final MLP below.
                mult = 0.5
                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)
                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension(
                    "intermediate_expanded", intermediate_size)
                m = mlp_fn(x,
                           "mlp_macaron",
                           dim_intermediate_expanded,
                           variable_dtype=variable_dtype,
                           params=params)

                x = x + (m * mult)
            else:
                mult = 1

            if attention_type != "none":
                res_x = prenorm(x,
                                "norm_1",
                                variable_dtype=variable_dtype,
                                params=params)
                a = attn(res_x,
                         "attn",
                         nx,
                         attention_type=attention_type,
                         params=params,
                         bias=bias,
                         dim_seq=sequence_dim,
                         memory_length_dim=memory_length_dim,
                         variable_dtype=variable_dtype,
                         context=context)
                x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype)
            # For attention_type == "none" the attention sublayer is skipped.
            # Adding x back onto itself here would double the residual stream.

            res_x = prenorm(x,
                            "norm_2",
                            variable_dtype=variable_dtype,
                            params=params)

            if use_moe:
                moe_params = mtf.transformer.moe.HParams()
                mtf.transformer.moe.set_default_moe_hparams(moe_params)
                moe_params.add_hparam("moe_min_expert_capacity", 1)
                moe_params.add_hparam("moe_use_experts_attention", False)

                # Override defaults
                for k, v in params["moe_params"].items():
                    moe_params.add_hparam(k, v)

                moe_train = params["mode"] == "train"

                m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(
                    res_x,
                    x.shape[-1],
                    moe_params,
                    train=moe_train,
                    mesh_shape=params["mesh_shape"],
                    layout=params["layout"],
                    activation=params.get("moe_activation", "relu"),
                    variable_dtype=variable_dtype,
                    num_microbatches=params["num_microbatches"])
                # Residual dropout only while training (same rule as attn()).
                if params["mode"] == "train" and params["res_dropout"] > 0:
                    m = mtf.dropout(m,
                                    rate=params["res_dropout"],
                                    name="moe_dropout")
            else:

                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)

                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension(
                    "intermediate_expanded", intermediate_size)

                m = mlp_fn(res_x,
                           "mlp",
                           dim_intermediate_expanded,
                           variable_dtype=variable_dtype,
                           params=params)
                aux_loss = mtf.zeros(x.mesh,
                                     mtf.Shape([]),
                                     dtype=variable_dtype.slice_dtype)

            x = x + pre_residual_fn(
                (m * mult), "norm_rezero_2", dtype=variable_dtype)
            return x, aux_loss
Exemple #11
0
def attn(x,
         scope,
         n_state,
         *,
         attention_type,
         params,
         bias,
         dim_seq,
         memory_length_dim,
         variable_dtype,
         context=None):
    """Multi-head self-attention sublayer.

    Args:
        x: input tensor, [batch, seq, n_embd].
        scope: variable scope name for this sublayer's weights.
        n_state: mtf.Dimension whose size must be divisible by params["n_head"].
        attention_type: "local", "global" or "linear".
        params: model configuration dict.
        bias: optional attention bias for "global" attention. When None, no
            bias is passed to the attention op.
        dim_seq: the sequence dimension of ``x``.
        memory_length_dim: dimension that replaces the key/value sequence
            dimension for "global" attention.
        variable_dtype: mtf.VariableDType for created variables.
        context: optional decoding context. New k/v states are recorded on it.
            During incremental inference, the slot at ``context.position - 1``
            of the cached k/v is overwritten with the current step's values.

    Returns:
        The attention result passed through ``compute_output(a, x.shape)``,
        plus a learned output bias "o_b", with residual dropout in "train"
        mode.

    Raises:
        NotImplementedError: for an unrecognised ``attention_type``.
    """
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0, (
        "n_state ({}) must be divisible by n_head ({})".format(
            n_state.size, params["n_head"]))

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head",
                               params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype)
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            # Write this step's k/v into the cached states at the current
            # position, keeping every other position unchanged.
            one_hot = mtf.one_hot(context.position - 1,
                                  dim_seq,
                                  dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q,
                    k,
                    v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":

                # Stays None when no bias is supplied. Previously this name
                # was only bound inside the branch below, so bias=None raised
                # UnboundLocalError at the attention call.
                broadcasted_bias = None

                # TODO: pass in fake context
                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(
                            bias, [
                                dim_batch, dim_heads, bias.shape[-2],
                                bias.shape[-1]
                            ])
                    else:
                        # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(
                            bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch,
                                             dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = params["attn_dropout"] if params[
                    "mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q,
                    k,
                    v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate)

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params[
                    "causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError(
                    "Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(
                x.mesh,
                "o_b", [dim_embd],
                initializer=tf.constant_initializer(0),
                master_dtype=variable_dtype.master_dtype,
                slice_dtype=variable_dtype.slice_dtype,
                activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
Exemple #12
0
def _rand_1_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="rand_1_gating",
    num_microbatches=None):
  """Compute a random top-1 gating.

  Each position of `inputs` is routed to at most one expert. The policy is
  `hparams.moe_rand_1_policy_train` when `train` is true, otherwise
  `hparams.moe_rand_1_policy_eval`:
    - "argmax": take the highest-probability expert.
    - "input_dropout" / "input_jitter": when training, perturb the gate inputs
      (dropout with rate `hparams.moe_rand_1_dropout`, or multiplicative
      jitter `hparams.moe_rand_1_jitter`), then take the argmax.
    - "sample": sample an expert from the gate logits at temperature
      `hparams.moe_rand_1_temperature`.

  Args:
    inputs: a mtf.Tensor. Its second-to-last dimension is used as the group
      dimension. Presumably [<outer dims>, group, input_dim]; TODO confirm at
      callers.
    outer_expert_dims: passed as `expert_dims` to the gating dense layer.
    experts_dim: a mtf.Dimension enumerating the experts.
    expert_capacity_dim: a mtf.Dimension. Its size is the number of positions
      per group each expert accepts; positions beyond that are dropped.
    hparams: hyperparameters (policies and dropout/jitter/temperature values).
    train: a boolean. Selects the train/eval policy and enables input
      perturbations.
    variable_dtype: a mtf.VariableDType for the gating weights.
    importance: optional mtf.Tensor. Positions whose importance is not exactly
      1.0 are masked out of the routing.
    name: name of the gating dense layer and prefix of the entropy summary.
    num_microbatches: if > 1, the load-balancing loss is divided by it.

  Returns:
    dispatch_tensor: combine_tensor cast to bool and back, i.e. a 0/1 mask.
    combine_tensor: gate weight per (position, expert, capacity slot). It is
      zero for unrouted or over-capacity positions; cast to inputs.dtype.
    loss: scalar load-balancing loss, cast to inputs.dtype.

  Raises:
    ValueError: if the selected policy is not recognized.
  """
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations
  if train and policy == "input_dropout":
    gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_rand_1_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_rand_1_jitter)

  # Bias-free projection of the (possibly perturbed) inputs to one logit per
  # expert.
  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
    expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
  elif policy == "sample":
    expert_index = mtf.sample_with_temperature(
        gate_logits, experts_dim, temperature=hparams.moe_rand_1_temperature)
    # The gate value is the softmax probability of the sampled expert.
    expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
  else:
    raise ValueError("Unknown rand_1 policy %s" % policy)

  expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)

  # LOAD BALANCING LOSS
  # TODO(liamfedus): Check entropy loss.
  group_size_dim = inputs.shape[-2]
  # NOTE(review): density_1 is computed from expert_mask *before* the
  # importance mask below is applied, so masked-out positions still count
  # towards it. Confirm this is intended.
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    # Only positions with importance == 1.0 are routed.
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  # Mean over experts of (fraction routed) * (mean gate prob), scaled by E^2.
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging (summaries are only emitted when training)
  if train:
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    # Per-expert fraction of routed positions, one summary per expert.
    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Experts have a limited capacity, ensure we do not exceed it. Construct
  # the batch indices, to each expert, with position_in_expert
  # (exclusive cumsum over the group: 0 for the first position routed to an
  # expert, 1 for the second, ...).
  position_in_expert = mtf.cumsum(
      expert_mask, group_size_dim, exclusive=True) * expert_mask
  position_in_expert = mtf.cast(position_in_expert, dtype=raw_gates.dtype)
  # Keep only tokens that fit within expert_capacity.
  expert_capacity_float = float(expert_capacity_dim.size)
  expert_mask *= mtf.cast(
      mtf.less(position_in_expert, expert_capacity_float),
      dtype=raw_gates.dtype)
  # 1 where the position is still routed to some expert, 0 otherwise.
  expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)

  # Mask out the experts that have overflowed expert capacity. Sparsify the
  # expert_gate.
  expert_gate *= expert_mask_flat

  # NOTE(review): expert_gate was already multiplied by expert_mask_flat just
  # above; since the mask is 0/1 the second multiply here is redundant but
  # harmless.
  combine_tensor = (
      expert_gate * expert_mask_flat *
      mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
      mtf.one_hot(
          mtf.to_int32(position_in_expert),
          expert_capacity_dim,
          dtype=raw_gates.dtype))

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
Exemple #13
0
    def attention(self,
                  x,
                  n_state,
                  mask,
                  attention_type="global",
                  name="attn"):
        """Multi-head self-attention sublayer.

        Args:
            x: input tensor, [batch, seq, n_embd].
            n_state: mtf.Dimension whose size must be divisible by
                ``self.n_heads``.
            mask: optional attention bias for "global" attention. When None,
                no bias is passed to the attention op.
            attention_type: "local" or "global".
            name: variable scope name for this sublayer.

        Returns:
            The attention result passed through ``compute_output(a, x.shape)``,
            plus a learned output bias "o_b". Residual dropout
            (``params["residual_dropout"]``) is applied in "train" mode.

        Raises:
            NotImplementedError: for an unrecognised ``attention_type``.
        """
        # x :: [batch, seq, n_embd]
        batch_dim, seq_dim, embd_dim = x_shape = x.shape
        assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads"
        with tf.variable_scope(name):
            # Compute attention inputs
            mtfparams = mtf.transformer.attention.attention_params_simple(
                x.mesh,
                io_dim=self.dimensions["embed_dim"],
                kv_dim=self.dimensions["kv_dim"],
                heads_dim=self.dimensions["heads_dim"],
                variable_dtype=self.variable_dtype)
            q = mtfparams.compute_q(x)
            k = mtfparams.compute_k(x)
            v = mtfparams.compute_v(x)

            if self.is_incremental_inference:
                # Overwrite the cached k/v at the current position only.
                one_hot = mtf.one_hot(self.context.position - 1,
                                      seq_dim,
                                      dtype=self.variable_dtype.master_dtype)
                inv_one_hot = 1.0 - one_hot
                old_k, old_v = self.context.get_states(2)
                k = old_k * inv_one_hot + k * one_hot
                v = old_v * inv_one_hot + v * one_hot

            if exists(self.context):
                self.context.record_new_states([k, v])

            with tf.variable_scope("attention"):
                if attention_type == "local":
                    # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                    radius = self.params.get("local_attention_radius", 256)
                    if self.is_incremental_inference:
                        q *= one_hot
                    a = mtf_transformer.attention.local_attention_1d(
                        q,
                        k,
                        v,
                        length_dim=k.shape[1],
                        key_dim=self.dimensions["kv_dim"],
                        value_dim=self.dimensions["kv_dim"],
                        radius=radius,
                        length_dim_num_splits=1,
                        fully_autoregressive=True,
                        attention_kwargs={},
                    )
                    if self.is_incremental_inference:
                        a = mtf.gather(a, self.context.position - 1, seq_dim)

                elif attention_type == "global":
                    # Stays None when no mask is supplied. Previously this name
                    # was only bound inside the branch below, so mask=None
                    # raised UnboundLocalError at the attention call.
                    broadcasted_mask = None
                    if exists(mask):
                        if not self.is_incremental_inference:
                            broadcasted_mask = mtf.broadcast(
                                mask, [
                                    batch_dim, self.dimensions["heads_dim"],
                                    mask.shape[-2], mask.shape[-1]
                                ])  # TODO: not sure this is correct
                        else:
                            # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                            mask = mtf.gather(mask, self.context.position - 1,
                                              seq_dim)
                            broadcasted_mask = mtf.broadcast(
                                mask, [
                                    batch_dim, self.dimensions["heads_dim"],
                                    mask.shape[-1]
                                ])

                    k = mtf.replace_dimensions(
                        k, k.shape[1], self.dimensions["memory_len_dim"])
                    v = mtf.replace_dimensions(
                        v, v.shape[1], self.dimensions["memory_len_dim"])

                    attn_dropout_rate = self.params.get(
                        "attention_dropout", 0) if self.mode == "train" else 0

                    a = mtf_transformer.attention.attention(
                        q,
                        k,
                        v,
                        memory_length_dim=self.dimensions["memory_len_dim"],
                        key_dim=self.dimensions["kv_dim"],
                        value_dim=self.dimensions["kv_dim"],
                        bias=broadcasted_mask,
                        dropout_rate=attn_dropout_rate)
                else:
                    raise NotImplementedError(
                        "Unknown attention type {}!".format(attention_type))

            with tf.variable_scope("compute_output"):
                a = mtfparams.compute_output(a, x_shape)

            with tf.variable_scope("compute_output_bias"):
                b = mtf.get_variable(
                    x.mesh,
                    "o_b", [embd_dim],
                    initializer=tf.constant_initializer(0),
                    master_dtype=self.variable_dtype.master_dtype,
                    slice_dtype=self.variable_dtype.slice_dtype,
                    activation_dtype=self.variable_dtype.activation_dtype)
                a += b
            residual_dropout = self.params.get("residual_dropout", 0)
            if self.mode == "train" and residual_dropout > 0:
                a = mtf.dropout(a, rate=residual_dropout, name="res_dropout")
            return a
Exemple #14
0
def transformer_moe_layer_v1(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None,
                             activation=mtf.relu,
                             num_microbatches=None,
                             token_embeddings=None,
                             context=None):
    """Local heterogenous mixture of experts.

  See transformer_moe_layer_v1 in moe.py for a more detailed explanation for
  a generic moe layer.

  The heterogeneous mask outputted by generate_heterogeneous_expert_masks has
  dimension [maximum hidden size, maximum # layers, # experts] and its shape
  will overwrite the parameters moe_num_layers and moe_hidden_size in hparams.
  The layer-specific mask slice is applied at each expert layer to the
  activation which is [expert width, # experts]. If the heterogeneous_mask_info
  is None, there is no mask applied and the code is equivalent to the
  homogeneous case.


  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Dimensions cheat sheet:
  B: batch dim(s)
  L: original sequence length
  M: input depth
  N: output depth
  G: number of groups
  S: group size
  E: number of experts
  C: expert capacity

  Args:
    inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).
    activation: a function.
    num_microbatches: number of microbatches.
    token_embeddings: a mtf.Tensor with shape
      [batch_dim(s), length_dim, input_dim]. These are the word embeddings for
      that correspond to the inputs. These can optionally be used to make
      routing decisions.
    context: a Context.

  Returns:
    outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    orig_inputs = inputs

    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    if hparams.moe_heterogeneous_mask_info is not None:
        tf.logging.info("moe_heterogeneous_mask_info: {}".format(
            hparams.moe_heterogeneous_mask_info))
        heterogeneous_mask = generate_heterogeneous_expert_masks(
            hparams.moe_heterogeneous_mask_info,
            hparams.moe_num_experts,
            experts_dim,
            mesh=inputs.mesh,
            expert_width=hparams.moe_hidden_size)
        # overwrite depth and width with the mask maximum dimension
        hparams.moe_num_layers = heterogeneous_mask.shape[1].size
        hparams.moe_hidden_size = heterogeneous_mask.shape[0].size
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups is a multiple of the mesh dimension
    # over which those groups are split.
    batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                        orig_inputs.shape.dims[-1])
    # Hack: we assume that
    #   "outer_batch" == replication of experts
    #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
    #
    # We then reqire num_groups to be a multiple of mesh_dim_size.
    if orig_inputs.shape.dims[0].name == "outer_batch":
        outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
    else:
        outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                           orig_inputs.shape.dims[0])

    # Number of MoE inputs (total number of position across batch_and_length_dims
    # per replica.
    n = 1
    for d in batch_and_length_dims:
        n *= d.size

    n = n // outer_batch_dim.size

    mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                    orig_batch_dim)
    num_groups, group_size = moe._split_into_groups(  # pylint: disable=protected-access
        n, hparams.moe_group_size, mesh_dim_size)
    # TODO(barretzoph): implementation without pylint calls?

    group_size_dim = mtf.Dimension("group", group_size)
    num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

    moe_input_dims = [
        outer_batch_dim, num_groups_dim, group_size_dim, input_dim
    ]
    # OGSM Tensor
    inputs = mtf.reshape(inputs, moe_input_dims)

    # Token embeddings that can be optionally used in the router for determining
    # where to send tokens.
    if hparams.moe_word_embed_mode is not None:
        token_embeddings = mtf.cast(
            mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

    # Each sequence sends expert_capacity positions to each expert.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
    tf.logging.info("expert_capacity: %d" % expert_capacity)
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
    if nonpadding is not None:
        nonpadding = mtf.zeros(inputs.mesh,
                               batch_and_length_dims,
                               dtype=inputs.dtype) + nonpadding
        nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
    if hparams.moe_gating == "top_2":
        # combine_tensor,
        # dispatch_tensor  OG`SEC Tensors
        # (G is generally split along mesh dim)
        dispatch_tensor, combine_tensor, loss = moe._top_2_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "top_n":
        dispatch_tensor, combine_tensor, loss = moe._top_n_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "switch":
        dispatch_tensor, combine_tensor, loss = moe._switch_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "ntlb":
        dispatch_tensor, combine_tensor, loss = moe._ntlb_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "switch_max":
        dispatch_tensor, combine_tensor, loss = moe._switch_max_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    elif hparams.moe_gating == "expert_selection":
        dispatch_tensor, combine_tensor, loss = moe._expert_selection_gating(  # pylint: disable=protected-access
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            group_size_dim=group_size_dim,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding,
            name="expert_selection_gating",
            num_microbatches=num_microbatches,
            token_embeddings=token_embeddings)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([
                                   outer_batch_dim, experts_dim_unsplit,
                                   num_groups_dim, expert_capacity_dim,
                                   input_dim
                               ]))

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for non-
    # model-parallel versions, this has no effect.
    d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([
            outer_batch_dim, experts_dim, batch_dim_unsplit,
            expert_capacity_dim, d_model_split_dim
        ]))

    # Split over batch -> split over experts
    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([
            outer_batch_dim, experts_dim, batch_dim_unsplit,
            expert_capacity_dim, input_dim
        ]))

    # Pretend we have heterogeneous_mask with shape [moe_num_layers, num_experts]
    for layer_idx in range(hparams.moe_num_layers):
        with tf.variable_scope("expert_layer_{}".format(layer_idx)):
            res_h = 0.0
            if layer_idx > 0:
                res_h = expert_inputs
                expert_inputs = transformer.sublayer_rms_norm(
                    expert_inputs, None, context)

            # Now feed the expert inputs through the experts.
            h = mtf.layers.dense_product(
                expert_inputs,
                reduced_dims=expert_inputs.shape.dims[-1:],
                new_dims=[hidden_dim],
                expert_dims=[experts_dim],
                activation_functions=activation,
                use_bias=False,
                variable_dtype=variable_dtype,
                name="wi")

            # apply dropout
            if hparams.moe_dropout_rate != 0.0:
                h = mtf.dropout(h,
                                is_training=train,
                                keep_prob=1.0 - hparams.moe_dropout_rate)
            # only if heterogeneous
            if hparams.moe_heterogeneous_mask_info is not None:
                # Get mask for current layer by slicing heterogeneous mask
                heterogeneous_mask_slice = mtf.slice(heterogeneous_mask,
                                                     layer_idx, 1,
                                                     "num_expert_layers")

                # Get rid of the expert layers dimension.
                heterogeneous_mask_slice = mtf.reshape(
                    heterogeneous_mask_slice, [
                        heterogeneous_mask_slice.shape[0],
                        heterogeneous_mask_slice.shape[-1]
                    ])
                h *= mtf.cast(heterogeneous_mask_slice, h.dtype)
            expert_output = mtf.layers.dense(h,
                                             output_dim,
                                             expert_dims=[experts_dim],
                                             use_bias=False,
                                             reduced_dims=h.shape.dims[-1:],
                                             variable_dtype=variable_dtype,
                                             name="wo")

            if layer_idx < (hparams.moe_num_layers - 1):
                expert_output = transformer.sublayer_dropout(
                    expert_output, None, context)
            expert_output += res_h
            expert_inputs = expert_output

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for non-
    # model-parallel versions, this has no effect.
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
            expert_capacity_dim, d_model_split_dim
        ]))

    # Split over experts -> split over batch
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))
    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
    return output, loss * hparams.moe_loss_coef
Exemple #15
0
    def self_attention(self, x, attention_bias):
        """Multi-headed self-attention followed by an output projection.

        Args:
          x: output of the previous layer.
          attention_bias: optional float32 Tensor, broadcastable to
            x.shape - self.model_dim + self.memory_seq_dim, that is added to
            the attention logits; it may be used to mask out padding regions
            of the memory.

        Returns:
          a float Tensor with the same shape as x.
        """

        def _project_to_heads(inputs, layer_name):
            # Query, key and value projections share every setting except the
            # source tensor and the variable-scope name.
            return mtf.layers.dense(
                inputs,
                reduced_dims=[self.model_dim],
                new_dims=[self.num_heads_dim, self.size_per_head_dim],
                kernel_initializer=self.dense_initializer,
                name=layer_name,
                use_bias=self.config.use_bias)

        q = _project_to_heads(x, "query")
        # Keys and values are indexed by the memory sequence dimension so the
        # attention einsums can tell query positions from memory positions.
        k = _project_to_heads(
            mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim),
            "key")
        v = _project_to_heads(
            mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim),
            "value")

        # Scaled dot-product logits between every query and memory position.
        logits = mtf.einsum([q, k], reduced_dims=[self.size_per_head_dim])
        logits *= self.size_per_head_dim.size**-0.5
        if attention_bias is not None:
            logits += attention_bias

        weights = mtf.softmax(logits, self.memory_seq_dim)
        # Dropout is applied directly to the normalized attention
        # probabilities, controlled by config.attention_probs_dropout_prob.
        weights = mtf.dropout(
            weights,
            keep_prob=1.0 - self.config.attention_probs_dropout_prob)

        attended = mtf.einsum([weights, v], reduced_dims=[self.memory_seq_dim])

        # Merge the heads back into the model dimension.
        projected = mtf.layers.dense(
            attended,
            reduced_dims=[self.num_heads_dim, self.size_per_head_dim],
            new_dims=[self.model_dim],
            kernel_initializer=self.dense_initializer,
            name="output",
            use_bias=self.config.use_bias)
        return mtf.transpose(projected, x.shape)
Exemple #16
0
 def add_dropout(self, x, dropout_prob=0.0):
     """Apply mtf.dropout to x with keep probability 1 - dropout_prob."""
     keep = 1.0 - dropout_prob
     return mtf.dropout(x, keep_prob=keep)
Exemple #17
0
def Alexnet(img, labels, num_nodes, num_gpus, args):
    """Build an AlexNet classifier with Mesh TensorFlow and return its loss.

    Args:
      img: input images, passed to CreateMeshes.
      labels: class labels, passed to CreateMeshes.
      num_nodes: number of nodes, passed to CreateMeshes.
      num_gpus: number of GPUs; strategies 2 and 3 round the class count up
        to a multiple of num_gpus and num_gpus // 2 respectively.
      args: parsed arguments; args.strategy (0, 1, 2 or 3) selects how the
        fully-connected layers are laid out across the meshes.

    Returns:
      A tuple (graph, mesh_to_impl, mtf_loss).

    Raises:
      ValueError: if args.strategy is not 0, 1, 2 or 3.
    """
    num_classes = 1000
    keep_prob = 0.5
    graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
        img, labels, num_nodes, num_gpus, args)

    def RenameFC(x):
        # Give the last (feature) dimension a fresh random name.
        return mt.rename_dimension(x, x.shape[-1].name, utils.RandName())

    strategy = args.strategy
    if strategy == 0:
        fc6_units = mtf.Dimension(utils.RandName(), 4096)
        fc7_units = mtf.Dimension(utils.RandName(), 4096)
        fc8_units = mtf.Dimension(utils.RandName(), num_classes)

    elif strategy == 1:
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    elif strategy == 2:
        num_classes = utils.RoundUp(num_classes, num_gpus)
        fc6_units = mtf.Dimension('axis0', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis0', num_classes)

    elif strategy == 3:
        num_classes = utils.RoundUp(num_classes, num_gpus // 2)
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis1', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    else:
        # Without this, the fc*_units names would be unbound further down.
        raise ValueError("unknown strategy: %r" % (strategy,))

    with tf.variable_scope('alexnet'):
        # Conv1 + ReLU + maxpool1
        conv1 = mt.Conv2d(mtf_img,
                          GetFilterShape(mtf_img, (11, 11, 3, 96)), (4, 4),
                          'VALID',
                          activation=mtf.relu,
                          name='conv1')
        pool1 = mt.MaxPool(conv1, (3, 3), (2, 2), 'VALID', name='pool1')

        # Conv2 + ReLU + maxpool2
        conv2 = mt.Conv2d(pool1,
                          GetFilterShape(pool1, (5, 5, 96, 256)), (1, 1),
                          'SAME',
                          activation=mtf.relu,
                          name='conv2')
        pool2 = mt.MaxPool(conv2, (3, 3), (2, 2), name='pool2')

        # Conv3 + ReLU
        conv3 = mt.Conv2d(pool2,
                          GetFilterShape(pool2, (3, 3, 256, 384)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv3')

        # Conv4 + ReLU
        conv4 = mt.Conv2d(conv3,
                          GetFilterShape(conv3, (3, 3, 384, 384)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv4')

        # Conv5 + ReLU + maxpool5
        conv5 = mt.Conv2d(conv4,
                          GetFilterShape(conv4, (3, 3, 384, 256)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv5')
        pool5 = mt.MaxPool(conv5, (3, 3), (2, 2), name='pool5')

        # Rename dims
        if strategy == 1:
            k_dim = mtf.Dimension(utils.RandName(),
                                  utils.Prod(pool5.shape.to_integer_list[1:]))
            pool5 = mtf.reshape(pool5, mtf.Shape([pool5.shape[0], k_dim]))
            pool5 = ReplaceMeshWithIndependentAxes(pool5, meshes[1],
                                                   (utils.RandName(), 'axis0'))

        elif strategy == 2:
            pool5 = mt.rename_dimension(pool5, pool5.shape[0].name,
                                        utils.RandName())

        elif strategy == 3:
            assert pool5.shape[0].name == 'axis0'
            pool5 = ReplaceMeshWithConcatSplit(pool5, meshes[1])

        # FC + ReLU + dropout (hidden FC layers only).
        # NOTE(review): keep_prob is passed positionally, matching an mtf
        # signature of dropout(x, keep_prob, ...); newer mtf releases take
        # is_training as the second positional argument -- confirm the
        # installed version.
        fc_activation = lambda x: mtf.dropout(mtf.relu(x), keep_prob)
        fc6 = mtf.layers.dense(pool5,
                               fc6_units,
                               activation=fc_activation,
                               reduced_dims=pool5.shape[1:],
                               name='fc6')
        if strategy in (2, 3):
            fc6 = RenameFC(fc6)

        fc7 = mtf.layers.dense(fc6,
                               fc7_units,
                               activation=fc_activation,
                               reduced_dims=fc6.shape.dims[-1:],
                               name='fc7')
        if strategy in (2, 3):
            fc7 = RenameFC(fc7)

        # Final projection to class logits. No dropout is applied here:
        # dropping out the logits would randomly zero/rescale class scores
        # right before the softmax cross-entropy. fc6 and fc7 already
        # receive dropout via fc_activation.
        fc8 = mtf.layers.dense(fc7,
                               fc8_units,
                               reduced_dims=fc7.shape.dims[-1:],
                               name='fc8')

        if strategy == 1:
            assert fc8.shape[-1].name == 'axis1'
            fc8 = ReplaceMeshWithDuplicates(fc8, meshes[2])

    with tf.variable_scope('loss'):
        # Align the batch dimension name with the labels before the loss.
        if fc8.shape[0] != mtf_labels.shape[0]:
            fc8 = mt.rename_dimension(fc8, fc8.shape[0].name,
                                      mtf_labels.shape[0].name)
        one_hot_labels = mtf.one_hot(mtf_labels, fc8.shape[-1])
        mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            fc8, one_hot_labels, fc8.shape[-1])
        mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    return graph, mesh_to_impl, mtf_loss
def layer_prepostprocess_dropout(x, hparams):
    """Dropout whose noise shape keeps only x's first and last dimensions.

    The keep probability is 1 - hparams.layer_prepostprocess_dropout. The
    noise shape is built from x.shape.dims[0] and x.shape.dims[-1] only;
    presumably mtf broadcasts the mask over the omitted middle dimensions
    (TODO confirm against the installed mtf.dropout).
    """
    leading, trailing = x.shape.dims[0], x.shape.dims[-1]
    mask_shape = mtf.Shape([leading, trailing])
    keep = 1.0 - hparams.layer_prepostprocess_dropout
    return mtf.dropout(x, keep_prob=keep, noise_shape=mask_shape)
Exemple #19
0
        def model(self, mesh, x, y, params):
            """Build the toy model on ``mesh`` and return ``(prediction, loss)``.

            Args:
              mesh: the mtf mesh the graph is built on.
              x: input tf Tensor, imported into mtf with shape
                [batch, sequence, vocab].
              y: target used in the squared-error loss.
              params: dict read for "precision", "batch_size", "vocab_size",
                "io" and "io_channels".

            Returns:
              (prediction, loss), where loss is mean((y - prediction) ** 2).
            """
            if params["precision"] == "bfloat16":
                # master has type float32, slice and activation have type bfloat16
                variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16,
                                                   tf.bfloat16)
            else:
                # master, slice and activation are all float32
                variable_dtype = mtf.VariableDType(tf.float32, tf.float32,
                                                   tf.float32)

            # Build the actual model
            batch_dim = mtf.Dimension("batch", params["batch_size"])
            vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
            io_dim = mtf.Dimension("sequence", params["io"])
            io_chan_dim = mtf.Dimension("io", params["io_channels"])

            # from input to mtf
            x = mtf.import_tf_tensor(mesh, x,
                                     mtf.Shape([batch_dim, io_dim, vocab_dim]))

            # Embeddings. tf.variable_scope takes the scope as its first
            # positional argument (name_or_scope); it has no ``scope``
            # keyword, so ``scope="toy"`` raised a TypeError.
            with tf.variable_scope("toy", default_name="seq2seq"):
                with tf.variable_scope("embeddings"):
                    # Perform embedding lookup on the word ids.
                    embedding_table = mtf.get_variable(
                        mesh,
                        "word_embeddings",
                        mtf.Shape([vocab_dim, io_chan_dim]),
                        initializer=self.embedding_initializer,
                    )

                    # NOTE(review): x already carries vocab_dim, which is
                    # also the gather dim, and output_shape is io_chan_dim
                    # alone -- confirm the intended input encoding and the
                    # expected gather output shape.
                    word_embedding_output = mtf.gather(
                        embedding_table,
                        x,
                        dim=vocab_dim,
                        output_shape=io_chan_dim)

                    # Add positional embeddings and token type embeddings, then layer
                    # normalize and perform dropout.
                    embedding_output = word_embedding_output

                    pos_embedding = mtf.get_variable(
                        mesh,
                        "pos_embeddings",
                        mtf.Shape([io_dim, io_chan_dim]),
                        initializer=self.embedding_initializer,
                    )
                    embedding_output = self.normalize(embedding_output)
                    embedding_output = mtf.dropout(
                        embedding_output,
                        keep_prob=1.0 - self.config.layer_output_dropout_prob,
                    )

                # shift token by pos embeddings
                # NOTE(review): this uses word_embedding_output, so the
                # normalized + dropped-out embedding_output above is unused.
                x = word_embedding_output + pos_embedding
                x = mtf.cast(x, variable_dtype.activation_dtype)

                h = x
                for lnum in range(1, self.num_hidden_layers + 2):
                    if lnum + 1 == self.num_hidden_layers + 2:
                        # output layer
                        dim = io_dim
                    elif lnum % 2 == 0:
                        # mtf.Dimension needs an integer size, not a Dimension.
                        dim = mtf.Dimension("hidden_even", io_chan_dim.size)
                    else:
                        dim = mtf.Dimension("hidden_odd", io_chan_dim.size)
                        # NOTE(review): this dense call sits inside the
                        # ``else`` branch, so only odd hidden layers build a
                        # projection; even layers and the output layer never
                        # do. Dedenting it also needs an output dim that h
                        # does not already carry (io_dim is the "sequence"
                        # dim of x).
                        h = mtf.layers.dense(
                            h,
                            dim,
                            use_bias=False,
                            master_dtype=variable_dtype.master_dtype,
                            slice_dtype=variable_dtype.slice_dtype,
                            name="layer_%d" % lnum,
                        )

                prediction = h
                # NOTE(review): no projection back to token dimensions
                # happens here.

                # Mean squared error between the target and the output.
                # NOTE(review): y is used as passed in and is never imported
                # into mtf (unlike x above) -- confirm how it reaches the mesh.
                loss = mtf.reduce_mean(mtf.square(y - prediction))
                return prediction, loss