def _concat_equal_sizes(xs, dim, new_dim_name):
    axis = xs[0].shape.dims.index(dim)
    ret = mtf.stack(xs, "tmp_concat", axis)
    new_shape = mtf.TensorShape(
        xs[0].shape.dims[:axis] +
        [mtf.Dimension(new_dim_name, dim.size * len(xs))] +
        xs[0].shape.dims[axis + 1:])
    return mtf.reshape(ret, new_shape)
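
A brief NumPy sketch (illustration only; the array shapes are assumed) of why the stack-then-reshape trick above is equivalent to concatenation when every part has the same size along the concatenated dimension:

import numpy as np

xs = [np.random.rand(2, 3, 5) for _ in range(4)]
axis = 1                                      # concatenate along the size-3 axis
stacked = np.stack(xs, axis=axis)             # shape (2, 4, 3, 5): new axis of size len(xs)
merged = stacked.reshape(2, 4 * 3, 5)         # merge (len(xs), 3) into one axis of size 12
assert np.array_equal(merged, np.concatenate(xs, axis=axis))
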
Example #2
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image, [-1, 28, 28]),
                             mtf.Shape([batch_dim, rows_dim, cols_dim]))
    x = mtf.reshape(x, [batch_dim, rows_dim, cols_dim, one_channel_dim])

    # add some convolutional layers to demonstrate that convolution works.
    # TODO(noam): get spatially-partitioned convolution working.
    fh_dim = mtf.Dimension("fh", 3)
    fw_dim = mtf.Dimension("fw", 3)
    filters1_dim = mtf.Dimension("filters1", 32)
    filters2_dim = mtf.Dimension("filters2", 32)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    f1 = mtf.relu(mtf.conv2d(x, kernel1))
    f2 = mtf.relu(mtf.conv2d(f1, kernel2))
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

    h1 = mtf_layers.dense(x,
                          hidden_dim1,
                          reduced_dims=[rows_dim, cols_dim],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf_layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf_layers.dense(h2, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch_dim]))
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
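
Below is a rough NumPy analogue (assumed toy sizes, not the mtf implementation) of the shape bookkeeping in mnist_model: the filter dimension is averaged away and the first dense layer reduces over both rows and cols at once, which is what reduced_dims=[rows_dim, cols_dim] expresses.

import numpy as np

batch, hidden = 4, 256
f2 = np.random.rand(batch, 28, 28, 32)        # stand-in for the conv output [batch, rows, cols, filters2]
x = f2.mean(axis=-1)                          # reduce_mean over filters2 -> [batch, rows, cols]
w1 = np.random.rand(28, 28, hidden)           # dense kernel covering both reduced dims
h1 = np.maximum(np.einsum("brc,rch->bh", x, w1), 0.0)   # relu(dense(...)) -> [batch, hidden]
print(h1.shape)                               # (4, 256)
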
Example #3
def moe_v0(inputs,
           hidden_dim,
           output_dim,
           experts_dim,
           loss_coef=1e-3,
           overhead=1.0):
    """Local mixture of experts that works well on TPU.

  See https://arxiv.org/abs/1701.06538

  There are num_experts expert networks, each containing a relu-activated
  hidden layer of size hidden_size, followed by an output projection.

  The number of parameters is thus:
    num_experts * (input_size * hidden_size + hidden_size * output_size)

  The input is 3d: [batch, length, depth], consisting of the representations
  of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    as opposed to on individual sequences.  This would allow more freedom
    for individual sequences to be unbalanced.  Unfortunately, that would
    slow down our hacked-up gather-by-matmul implementation.

    TODO(noam): There is no real reason for a single sequence to be the unit
      of equal allocation.  Reshaping the inputs would allow us to pick a
      different unit of equal allocation.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.  We also want to integrate this
  gating/dispatching logic into multi-device mixtures-of-experts.

  Args:
    inputs: a mtf.Tensor with shape [batch_dim, length_dim, input_dim]
    hidden_dim: a mtf.Dimension
    output_dim: a mtf.Dimension
    experts_dim: a mtf.Dimension
    loss_coef: a float scalar
    overhead: multiplicative factor of how much spare capacity to assign

  Returns:
    outputs: a Tensor with shape [batch_dim, length_dim, output_dim]
    loss: a mtf scalar
  """
    batch_dim, length_dim, input_dim = inputs.shape.dims

    # Each sequence sends expert_capacity positions to each expert.
    expert_capacity = min(
        length_dim.size,
        int((length_dim.size * 2 * overhead) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

    # This is the learned gating function.
    # shape = [batch_dim, length_dim, experts_dim_unsplit]
    gates = mtf.softmax(dense(inputs, experts_dim_unsplit),
                        experts_dim_unsplit)

    assignment_shape = mtf.TensorShape(
        [batch_dim, length_dim, experts_dim_unsplit, expert_capacity_dim])

    backward_assignment = mtf.slicewise(functools.partial(
        _truncated_top_2_gating, expert_capacity=expert_capacity), [gates],
                                        output_shape=assignment_shape,
                                        splittable_dims=[batch_dim],
                                        name="backward_assignment")

    forward_assignment = mtf.cast(mtf.cast(backward_assignment, tf.bool),
                                  inputs.dtype)

    # put num_experts dimension first to make split easier in alltoall
    expert_inputs = mtf.einsum([inputs, forward_assignment],
                               mtf.TensorShape([
                                   experts_dim_unsplit, batch_dim,
                                   expert_capacity_dim, input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.TensorShape(
            [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = dense(expert_inputs,
              hidden_dim,
              expert_dims=[experts_dim],
              activation=mtf.relu,
              name="x0")
    expert_output = dense(h, output_dim, expert_dims=[experts_dim], name="x1")

    expert_output = mtf.reshape(
        expert_output,
        mtf.TensorShape(
            [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

    output = mtf.einsum([expert_output, backward_assignment],
                        mtf.TensorShape([batch_dim, length_dim, output_dim]))

    importance = mtf.reduce_sum(backward_assignment,
                                output_shape=mtf.TensorShape(
                                    [batch_dim, experts_dim_unsplit]))

    loss = cv_squared(importance) * loss_coef
    return output, loss
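
As a numeric illustration of the capacity and balance-loss arithmetic described in the docstring above (toy sizes; cv_squared_np below is a plain NumPy stand-in for the squared coefficient of variation used by the helper, not the library function itself):

import numpy as np

length, num_experts, overhead = 256, 16, 1.0
expert_capacity = min(length, int(length * 2 * overhead / num_experts))   # 32 slots per expert

def cv_squared_np(x, eps=1e-10):
    # variance / mean^2: near zero when expert usage is balanced, large when a
    # few experts receive most of the traffic.
    x = np.asarray(x, dtype=np.float64)
    return x.var() / (x.mean() ** 2 + eps)

balanced = np.full(num_experts, length / num_experts)
skewed = np.zeros(num_experts)
skewed[0] = length
print(expert_capacity, cv_squared_np(balanced), cv_squared_np(skewed))   # 32 0.0 ~15.0
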
Example #4
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
    """Attention to the source position and a neighborhood to the left of it.

  The sequence is divided into blocks of length block_length.
  Attention for a given query position can only see memory positions
  less than or equal to the query position, in the corresponding block
  and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional). Currently, memory_length
      must have the same size as query_length, but a different name.
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    block_length: an integer, the length of each attention block (this sets the
      receptive field of the attention).
    name: an optional string.

  Returns:
    a Tensor of shape [batch, query_length, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        batch, query_length, io_channels = query_antecedent.shape.dims
        q_var, k_var, v_var, o_var = multihead_attention_vars(
            query_antecedent.mesh, heads, io_channels, kv_channels,
            query_antecedent.dtype)

        if memory_antecedent is None:
            memory_antecedent = rename_length_to_memory_length(
                query_antecedent, query_length.name)
        memory_batch, memory_length, memory_channels = memory_antecedent.shape.dims
        if memory_batch != batch:
            raise ValueError("memory batch must equal query batch")
        if memory_channels != io_channels:
            raise ValueError("memory channels must equal query channels")

        # Get query q, keys k and values v.
        q = mtf.einsum([query_antecedent, q_var],
                       mtf.TensorShape(
                           [batch, heads, query_length, kv_channels]))
        k = mtf.einsum([memory_antecedent, k_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))
        v = mtf.einsum([memory_antecedent, v_var],
                       mtf.TensorShape(
                           [batch, heads, memory_length, kv_channels]))

        # Let's assume for now we don't have padding and the block length equally
        # divides the memory length.
        block_length = (query_length.size if
                        query_length.size < block_length * 2 else block_length)
        blength = mtf.Dimension("block_length", block_length)
        mlength = mtf.Dimension("mem_block_length", block_length)
        num_blocks = mtf.Dimension("num_blocks",
                                   query_length.size // block_length)

        q = mtf.reshape(
            q,
            mtf.TensorShape([batch, heads, num_blocks, blength, kv_channels]))
        k = mtf.reshape(
            k,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))
        v = mtf.reshape(
            v,
            mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))

        # compute attention for the first query block.
        def first_block_attention():
            """Compute attention for the first block."""
            first_q = mtf.slice(q, 0, 1, num_blocks.name)
            first_k = mtf.slice(k, 0, 1, num_blocks.name)
            first_v = mtf.slice(v, 0, 1, num_blocks.name)
            block = first_q.shape.dims[2]

            first_logits = mtf.einsum(
                [first_q, first_k],
                mtf.TensorShape([batch, heads, block, blength, mlength]))
            weights = mtf.softmax(first_logits, mlength)
            first_output = mtf.einsum(
                [weights, first_v],
                mtf.TensorShape([batch, heads, block, blength, kv_channels]))
            return first_output

        # Attention for first block, since query_length = key_length.
        first_output = first_block_attention()

        # Concatenate two adjacent blocks to compute the overlapping memory block.
        def local(x):
            """Helper function to get memory blocks."""
            prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
            cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
            local_block = mtf.concat([prev_block, cur_block], mlength.name)
            return local_block

        local_k = local(k)
        local_v = local(v)
        mblocks = local_k.shape.dims[2]
        mlength = local_k.shape.dims[3]
        # Calculate the causal mask to avoid peeking into the future. We compute
        # this once and reuse it for all blocks since the block_size is known.
        mask = attention_bias_local_block(query_antecedent.mesh, blength,
                                          mlength)

        # Remove the first block from q since we already computed that.
        tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)

        # Compatibility between q and k for rest of the blocks.
        # Shape [batch, heads, num_blocks - 1, block_length, local_length]
        attention = mtf.einsum([tail_q, local_k],
                               mtf.TensorShape(
                                   [batch, heads, mblocks, blength, mlength]))
        attention += mask
        attention = mtf.softmax(attention, mlength)

        # Run attention for rest of the blocks.
        # Shape [batch, heads, num_blocks-1, block_length, kv_channels]
        output = mtf.einsum([attention, local_v],
                            mtf.TensorShape(
                                [batch, heads, mblocks, blength, kv_channels]))
        # Now concatenate the first and rest of the blocks.
        final_output = mtf.concat([first_output, output], num_blocks.name)
        final_output = mtf.reshape(
            final_output,
            mtf.TensorShape([batch, heads, query_length, kv_channels]))
        return mtf.einsum([final_output, o_var],
                          mtf.TensorShape([batch, query_length, io_channels]))
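
The blocking pattern described in the docstring can be checked with a small NumPy sketch (assumed toy sizes; this mirrors the shapes of the einsums above, not the mtf kernels): each query block attends to the concatenation of the previous block and itself, with a causal mask inside that doubled memory block.

import numpy as np

length, block = 8, 4
positions = np.arange(length).reshape(-1, block)           # [num_blocks, block_length]
for b in range(1, len(positions)):                         # block 0 is handled separately above
    mem = np.concatenate([positions[b - 1], positions[b]]) # prev block + current block
    mask = mem[None, :] <= positions[b][:, None]           # query p may only see memory <= p
    for q, row in zip(positions[b], mask):
        print("query", q, "sees", mem[row].tolist())
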
Example #5
  def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
      inputs = features["inputs"]
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf_layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.num_encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      encdec_tensors = []
      for layer_num in xrange(hparams.num_decoder_layers):
        with tf.variable_scope("decoder/layer_%d/encdec_attention" % layer_num):
          q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      partial_targets = None
    else:
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)

    if hparams.beam_size == 1:
      ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
    else:
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_kv_states = (
        [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)]
        * (2 * hparams.num_decoder_layers))
    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      self_attention_k = states[:hparams.num_decoder_layers]
      self_attention_v = states[hparams.num_decoder_layers:]
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_self_attention_k, new_self_attention_v = (
            self._decoder_layer_stack_incremental(
                x,
                step_num,
                encdec_tensors,
                self_attention_k,
                self_attention_v,
                encdec_attention_mask=encoder_attention_mask))
      logits = mtf.matmul(x, softmax_var)
      return logits, new_self_attention_k + new_self_attention_v

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf_beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_kv_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if self.has_input:
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf_beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_kv_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
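
A small NumPy illustration (assumed toy sizes) of the static-shape padding performed above: the actual batch and sequence length are zero-padded up to hparams.batch_size and hparams.max_length before being imported into the mesh, so the mtf graph always sees fixed shapes.

import numpy as np

batch_size, max_length = 4, 6                 # stand-ins for hparams.batch_size / hparams.max_length
inputs = np.array([[7, 8, 9],
                   [5, 6, 2]])                # actual shape [2, 3]
pad = [(0, batch_size - inputs.shape[0]),
       (0, max_length - inputs.shape[1])]
inputs = np.pad(inputs, pad)                  # zero-pad to the static shape
print(inputs.shape)                           # (4, 6)
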
Example #6
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        filter_h_dim = mtf.Dimension("filter_height", 7)
        filter_w_dim = mtf.Dimension("filter_width", 7)
        filters = mtf.Dimension("filters", hparams.filter_sizes[0])
        rows_dim = mtf.Dimension("rows_size", 32)
        cols_dim = mtf.Dimension("cols_size", 96)
        row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
        col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
        classes_dim = mtf.Dimension("classes", 10)
        one_channel_dim = mtf.Dimension("one_channel", 1)

        inputs = features["inputs"]
        x = mtf.import_tf_tensor(
            mesh,
            tf.reshape(inputs, [
                hparams.batch_size, hparams.row_blocks,
                hparams.rows_size // hparams.row_blocks, hparams.col_blocks,
                hparams.num_channels * hparams.cols_size // hparams.col_blocks,
                1
            ]),
            mtf.Shape([
                batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
                one_channel_dim
            ]))
        x = mtf.transpose(x, [
            batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
            one_channel_dim
        ])

        x = mtf.to_float(x)
        initial_filters = mtf.get_variable(
            mesh, "init_filters",
            mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
        x = mtf.conv2d_with_blocks(x,
                                   initial_filters,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   h_blocks_dim=None,
                                   w_blocks_dim=col_blocks_dim)

        x = batch_norm_relu(x, is_training)

        # Conv blocks
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_layers):
            layer_name = "block_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Residual block layer
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[0],
                                blocks=hparams.layer_sizes[0],
                                strides=[1, 1, 1, 1],
                                is_training=is_training,
                                name="block_layer1",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[1],
                                blocks=hparams.layer_sizes[1],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer2",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[2],
                                blocks=hparams.layer_sizes[2],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer3",
                                row_blocks_dim=None,
                                col_blocks_dim=None)

        # Calculate the logits and loss.
        out = x
        outputs = mtf_layers.dense(out,
                                   hidden_dim,
                                   reduced_dims=out.shape.dims[-5:],
                                   activation=mtf.relu,
                                   name="dense")

        # We assume fixed vocab size for targets
        labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [hparams.batch_size]),
                                      mtf.Shape([batch_dim]))

        logits = mtf_layers.dense(outputs, classes_dim, name="logits")
        soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, classes_dim)

        # Reshape logits so it doesn't break inside t2t.
        logits = mtf.reshape(
            logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
        loss = mtf.reduce_mean(loss)
        return logits, loss
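
The reshape/transpose pair at the top of the model above splits the image into spatial blocks and then moves the block dimensions next to the batch dimension. A NumPy sketch with assumed sizes (channels folded into the column axis, as in the code):

import numpy as np

batch, rows, cols, row_blocks, col_blocks = 2, 32, 96, 4, 3
x = np.random.rand(batch, rows, cols)
x = x.reshape(batch, row_blocks, rows // row_blocks,
              col_blocks, cols // col_blocks, 1)
x = x.transpose(0, 1, 3, 2, 4, 5)   # [batch, row_blocks, col_blocks, rows/blk, cols/blk, 1]
print(x.shape)                      # (2, 4, 3, 8, 32, 1)
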
Example #7
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()

        # We assume fixed vocab size for targets
        targets_vocab_size = self._problem_hparams.target_modality._vocab_size  # pylint: disable=protected-access
        targets = tf.to_int32(features["targets"])

        # Image preprocessing, reshape into a 1D sequence and shift right.
        length = hparams.img_len * hparams.img_len * hparams.num_channels
        targets = tf.reshape(targets, [hparams.batch_size, length])
        shifted_targets = common_layers.shift_right_2d(targets)

        # Declare all the dimensions
        model_dim = mtf.Dimension("d_model", hparams.hidden_size)
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        length_dim = mtf.Dimension("length", length)
        max_length_dim = mtf.Dimension("max_length", hparams.max_length)
        filter_dim = mtf.Dimension("d_ff", hparams.d_ff)
        kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
        heads = mtf.Dimension("heads", hparams.num_heads)

        def import_to_batch_by_length(x, name):
            return mtf.import_tf_tensor(mesh,
                                        x,
                                        mtf.Shape([batch_dim, length_dim]),
                                        name=name)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(x,
                               keep_prob=1.0 -
                               hparams.layer_prepostprocess_dropout,
                               noise_shape=mtf.Shape([batch_dim, model_dim]))

        targets = import_to_batch_by_length(targets, "targets")
        shifted_targets = import_to_batch_by_length(shifted_targets,
                                                    "shifted_targets")

        extra_losses = []

        # Create targets content and position embeddings.
        targets_vocab_size = 256 * hparams.num_channels
        targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
        outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

        # Create embedding var for targets and positions and do a gather.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([targets_vocab_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        x = mtf.gather(targets_embedding_var, shifted_targets,
                       targets_vocab_dim)
        # Add positional embeddings
        x += mtf.reshape(
            self.create_positional_emb_2d(targets, max_length_dim, model_dim),
            [length_dim, model_dim])

        # If conditional and input is given, add the input embedding to the target.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            vocab_size = hparams.num_classes
            inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = import_to_batch_by_length(inputs, "inputs")

            # Input embeddings
            inputs_embedding_var = mtf_layers.embedding(
                mesh,
                "input_embedding",
                mtf.Shape([inputs_vocab_dim, model_dim]),
                activation_dtype=activation_dtype)
            inputs_emb = mtf.gather(inputs_embedding_var, inputs,
                                    inputs_vocab_dim)
            x += inputs_emb

        # Image Transformer Decoder
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_decoder_layers):
            layer_name = "decoder_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Self attention layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.masked_local_attention_1d(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_self_att"),
                        None,
                        kv_channels,
                        heads,
                        block_length=hparams.block_length,
                        name="self_att"))
                # ffn layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.dense_relu_dense(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_ffn"),
                        filter_dim,
                        hparams.dropout,
                        dropout_broadcast_dims=[length_dim]))

        x = mtf_layers.layer_norm(x,
                                  model_dim,
                                  name="decoder_final_layer_norm")

        # Calculate the logits and loss.
        logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
        soft_targets = mtf.one_hot(targets,
                                   outputs_vocab_dim,
                                   dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, outputs_vocab_dim)

        loss = mtf.reduce_mean(loss)
        for l in extra_losses:
            loss += l
        return logits, loss
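
A NumPy stand-in (assumed toy sizes) for the preprocessing above: the image is flattened to a 1-D sequence of length img_len * img_len * num_channels, and shifting right prepends a zero and drops the last position so each position is predicted from the pixels before it. This approximates common_layers.shift_right_2d rather than reproducing it.

import numpy as np

batch, img_len, channels = 2, 3, 1
targets = np.arange(batch * img_len * img_len * channels).reshape(batch, -1)
shifted = np.pad(targets, [(0, 0), (1, 0)])[:, :-1]   # prepend 0, drop last position
print(targets[0])   # [0 1 2 3 4 5 6 7 8]
print(shifted[0])   # [0 0 1 2 3 4 5 6 7]
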
Example #8
    def grow_topk(i, alive_seq, alive_log_probs, states=None):
        r"""Inner beam search loop.

    This function takes the current alive sequences and grows them to topk
    sequences where k = 2*beam. We use 2*beam because all beam_size sequences
    could hit <EOS>, leaving no alive sequences to continue; with 2*beam_size
    candidates this cannot happen, provided the vocab size is greater than the
    beam size, since then at least beam_size of the top 2*beam extensions are
    non-<EOS>.
    The length penalty is ((5 + len(decode)) / 6) ^ alpha (see
    https://arxiv.org/abs/1609.08144); scores are the log probs divided by it.

    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far [batch, beam, length]
      alive_log_probs: probabilities of these sequences. [batch, beam]
      states: optional list of mtf.Tensor
    Returns:
      Tuple of
        (Topk sequences extended by the next word,
         The log probs of these sequences,
         The scores with length penalty of these sequences,
         Flags indicating which of these sequences have finished decoding,
         list of transformed decoding states)
    """
        logits, new_states = logits_fn(i, alive_seq, states)
        batch_dim, beam_dim, vocab_dim = logits.shape.dims

        # Convert logits to normalized log probs
        candidate_log_probs = mtf.log_softmax(logits, vocab_dim)

        # Multiply the probabilities by the current probabilities of the beam.
        # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
        log_probs = candidate_log_probs + alive_log_probs

        length_penalty = mtf.pow(((5. + mtf.to_float(i + 1)) / 6.), alpha)

        curr_scores = log_probs / length_penalty

        # scores have shape [batch, beam, vocab]
        beam_and_vocab_dim = mtf.Dimension("beam_and_vocab",
                                           beam_dim.size * vocab_dim.size)
        flat_shape = mtf.Shape([batch_dim, beam_and_vocab_dim])
        double_beam = mtf.Dimension("double_beam", beam_dim.size * 2)
        # Flatten out (beam_size, vocab_size) probs in to a list of possibilities
        flat_curr_scores = mtf.reshape(curr_scores, flat_shape)

        top_ids, top_scores = mtf.top_k(flat_curr_scores,
                                        reduced_dim=beam_and_vocab_dim,
                                        new_dim=double_beam)

        # Recovering the log probs because we will need to send them back
        top_log_probs = top_scores * length_penalty

        # Work out what beam the top probs are in.
        top_beam_index = top_ids // vocab_dim.size
        top_ids %= vocab_dim.size  # Unflatten the ids

        def my_gather(tensor):
            return mtf.gather(tensor,
                              top_beam_index,
                              beam_dim,
                              output_shape=mtf.Shape([
                                  double_beam if d == beam_dim else d
                                  for d in tensor.shape.dims
                              ]))

        # Gather up the most probable 2*beams both for the ids and finished_in_alive
        # bools
        top_seq = my_gather(alive_seq)

        if states:
            states = [my_gather(state) for state in new_states]

        # Append the most probable alive
        top_seq += top_ids * mtf.one_hot(i, length_dim, dtype=tf.int32)
        top_finished = mtf.equal(top_ids, eos_id)

        return top_seq, top_log_probs, top_scores, top_finished, states
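
The flatten / top-2k / unflatten step in grow_topk can be mimicked in NumPy (assumed toy sizes; the running alive log probs are omitted for brevity): scores of shape [beam, vocab] are flattened, the best 2*beam entries are selected, and the originating beam and token ids are recovered with integer division and modulo, exactly as in the // and % lines above.

import numpy as np

beam, vocab, alpha, step = 2, 5, 0.6, 3
log_probs = np.log(np.random.dirichlet(np.ones(vocab), size=beam))   # [beam, vocab]
length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha
scores = (log_probs / length_penalty).reshape(-1)                    # [beam * vocab]
top = np.argsort(-scores)[:2 * beam]                                 # top 2*beam flat indices
top_beam_index, top_ids = top // vocab, top % vocab                  # unflatten
top_log_probs = scores[top] * length_penalty                         # recover log probs
print(top_beam_index, top_ids, top_log_probs)
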
Example #9
def transformer_moe_layer_v1(inputs, output_dim, hparams, train):
    """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  The number of parameters in the gating network is:
    (input_dim.size * hparams.num_experts)

  The number of parameters in the experts themselves is:
    (hparams.num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims...>, length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean

  Returns:
    outputs: a Tensor with shape [<batch_dims...>, length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    orig_inputs = inputs
    input_dim = inputs.shape.dims[-1]
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)
    group_size_dim = mtf.Dimension("group", hparams.moe_group_size)
    batch_dim = mtf.Dimension(
        orig_inputs.shape[0].name,
        orig_inputs.shape.size // (group_size_dim.size * input_dim.size))
    inputs = mtf.reshape(inputs, [batch_dim, group_size_dim, input_dim])

    # Each sequence sends expert_capacity positions to each expert.
    capacity_factor = (hparams.moe_capacity_factor_train
                       if train else hparams.moe_capacity_factor_eval)
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

    if hparams.moe_gating == "top_2":
        dispatch_tensor, combine_tensor, loss = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # put num_experts dimension first to make split easier in alltoall
    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([
                                   experts_dim_unsplit, batch_dim,
                                   expert_capacity_dim, input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape(
            [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = mtf_layers.dense(expert_inputs,
                         hidden_dim,
                         expert_dims=[experts_dim],
                         activation=mtf.relu,
                         use_bias=False,
                         name="x0")
    expert_output = mtf_layers.dense(h,
                                     output_dim,
                                     expert_dims=[experts_dim],
                                     use_bias=False,
                                     name="x1")

    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape(
            [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape([batch_dim, group_size_dim, output_dim]))

    output = mtf.reshape(output, orig_inputs.shape.dims[:-1] + [output_dim])

    return output, loss * hparams.moe_loss_coef
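
The "gather-by-matmul" dispatch that the docstring refers to can be illustrated with a toy NumPy example (assumed sizes; _top_2_gating itself is not shown, and the group/batch dimensions are dropped): a 0/1 dispatch tensor of shape [position, expert, capacity] routes each position into a per-expert buffer via einsum, and the same tensor (here with unit weights, in place of the gated combine tensor) brings the expert outputs back.

import numpy as np

positions, experts, capacity, depth = 4, 2, 2, 3
inputs = np.random.rand(positions, depth)                        # [s, m]
dispatch = np.zeros((positions, experts, capacity))               # [s, x, c]
dispatch[0, 0, 0] = dispatch[1, 0, 1] = 1.0                       # positions 0, 1 -> expert 0
dispatch[2, 1, 0] = dispatch[3, 1, 1] = 1.0                       # positions 2, 3 -> expert 1
expert_inputs = np.einsum("sm,sxc->xcm", inputs, dispatch)        # per-expert buffers [x, c, m]
outputs = np.einsum("xcm,sxc->sm", expert_inputs, dispatch)       # combine back to [s, m]
assert np.allclose(outputs, inputs)                               # identity routing round-trips
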
Example #10
def transformer_moe_layer_v2(inputs, output_dim, hparams, train):
    """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  There is one set of gating parameters for the first level, and a separate
  set of second-level gating parameters per first-level expert.
  The number of parameters in the gating network is:
    (input_dim.size * hparams.num_experts[0]) +
    (input_dim.size * hparams.num_experts[0] * hparams.num_experts[1])


  The number of parameters in the experts themselves is:
    (hparams.num_experts[0] * hparams.num_experts[1]
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  a, b: batch size
  l: original sequence length
  m: input depth
  n: output depth
  g, h: number of groups
  s, t: group size
  x, y: number of experts
  c, d: expert capacity

  input: [a0, b1, l, m]
  input: [a0, g1, s, m]
  dispatch_tensor_x: [a0, g1, s, x, c]
  expert_input: [a0, g1, x, c, m]
  alltoall: [a0, g, x1, c, m]
  transpose: [x1, a0, g, c, m]
  reshape: [x1, h0, s, m]
  assignment2: [x1, h0, t, y, d]
  expert_input2: [x1, h0, y, d, m]
  alltoall: [x1, h, y0, d, m]
  ...
  reverse of that

  gating params 0: [m, x]
  gating params 1: [x1, m, y]

  expert params:
     [x1, y0, m, hidden]
     [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean

  Returns:
    outputs: a Tensor with shape [a, b, l, n]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
    if insert_outer_batch_dim:
        inputs = mtf.reshape(inputs, [mtf.Dimension("outer_batch", 1)] +
                             inputs.shape.dims)

    assert len(hparams.moe_num_experts) == 2
    a0, b1, l, m = inputs.shape.dims
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
    y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
    x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
    y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
    n = output_dim

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (g.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        b1.size * l.size, hparams.moe_group_size,
        _tensor_dim_to_mesh_dim_size(hparams, b1))
    g1 = mtf.Dimension(b1.name, num_groups)
    g = mtf.Dimension(b1.name + "_unsplit", g1.size)
    s = mtf.Dimension("group_size_x", group_size)

    # Each sequence sends (at most?) expert_capacity positions to each expert.
    # Static expert_capacity dimension is needed for expert batch sizes
    capacity_factor = (hparams.moe_capacity_factor_train
                       if train else hparams.moe_capacity_factor_eval)
    expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
    c = mtf.Dimension("expert_capacity_x", expert_capacity)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (h.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        a0.size * g.size * c.size, hparams.moe_group_size,
        _tensor_dim_to_mesh_dim_size(hparams, a0))
    t = mtf.Dimension("group_size_y", group_size)
    h0 = mtf.Dimension(a0.name, num_groups)
    h = mtf.Dimension(a0.name + "_unsplit", h0.size)

    expert_capacity = min(
        t.size,
        int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
    d = mtf.Dimension("expert_capacity_y", expert_capacity)

    # First level of expert routing
    # Reshape the inner batch size to a multiple of group_dim g1 and
    # group_size_dim s.
    inputs = mtf.reshape(inputs, [a0, g1, s, m])

    # Get the assignments for the first level.
    # dispatch_tensor_x has shape [a0, g1, s, x, c]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=x,
            expert_capacity_dim=c,
            hparams=hparams,
            train=train)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x],
                                 [x, a0, g1, c, m])

    # we construct an "importance" Tensor for the inputs to the second-level
    # gating.  The importance of an input is 1.0 if it represents the
    # first-choice expert-group and 0.5 if it represents the second-choice expert
    # group.  This is used by the second-level gating.
    importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
    importance = 0.5 * (mtf.to_float(mtf.greater(importance, 0.5)) +
                        mtf.to_float(mtf.greater(importance, 0.0)))

    # First level, all to all. Here we change the split dimension from g1 to x1.
    expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape([x1, a0, g, c,
                                                              m]))
    importance = mtf.reshape(importance, [x1, a0, g, c])

    # Second level of expert routing
    # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
    # and group_size_dim t.
    inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
    importance = mtf.reshape(importance, [x1, h0, t])

    # Get the assignments for the second level.
    # dispatch_tensor_y has shape [x1, h0, t, y, d]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
            inputs=inputs_y,
            outer_expert_dims=[x1],
            experts_dim=y,
            expert_capacity_dim=d,
            hparams=hparams,
            train=train,
            importance=importance)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                                 [y, x1, h0, d, m])

    # Second level, all to all. Here we change the split dimension from h0 to y0.
    expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape([y0, x1, h, d,
                                                              m]))

    # Now feed the expert inputs through the experts.
    hidden_output = mtf_layers.dense(expert_inputs_y,
                                     hidden_dim,
                                     expert_dims=[y0, x1],
                                     activation=mtf.relu,
                                     use_bias=False,
                                     name="expert0")
    expert_output = mtf_layers.dense(hidden_output,
                                     output_dim,
                                     expert_dims=[y0, x1],
                                     use_bias=False,
                                     name="expert1")

    # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
    # expert_output has shape [y0, x1, h, d, n]

    # alltoall
    expert_output = mtf.reshape(expert_output, mtf.Shape([y, x1, h0, d, n]))

    # combine results from inner level
    output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

    # Reshape the combined tensor from inner level to now contain outer_batch_dim
    # a0 and group_dim g
    output = mtf.reshape(output_y, [x1, a0, g, c, n])

    # alltoall from expert_dim x to group_dim g1
    expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

    # combine results from outer level
    output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

    # Reshape the combined tensor to now contain inner_batch_dim
    # b1 and the original sequence length
    output = mtf.reshape(output_x, [a0, b1, l, n])
    if insert_outer_batch_dim:
        output = mtf.reshape(output, [b1, l, n])
    return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
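
For concreteness, a small numeric walk-through (with assumed hparams and group sizes; the real values come from hparams and the mesh layout via _split_into_groups) of the two-level capacity arithmetic used above for c and d:

num_experts = (4, 4)                          # x.size, y.size
group_size_x, group_size_y = 128, 64          # s.size and t.size after grouping
capacity_factor_train, capacity_factor_second = 1.25, 2.0

c = min(group_size_x, int(group_size_x * capacity_factor_train / num_experts[0]))
d = min(group_size_y, int(group_size_y * capacity_factor_second / num_experts[1]))
print(c, d)   # 40 slots per first-level expert, 32 per second-level expert
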