Example #1
    def call(self, context, x, losses=None):
        """Call the layer."""
        if context.model.ensemble_dim:
            raise NotImplementedError("MoE not yet implemented with ensembles")

        has_length_dim = context.length_dim in x.shape.dims
        if not has_length_dim:
            x_shape = x.shape
            shape_with_length = mtf.Shape(x_shape.dims[:-1] +
                                          [mtf.Dimension("length", 1)] +
                                          x_shape.dims[-1:])
            x = mtf.reshape(x, shape_with_length)
        y, loss = transformer_moe_layer_v1(x,
                                           context.model.model_dim,
                                           self._hparams,
                                           context.train,
                                           context.variable_dtype,
                                           layout=context.model.layout,
                                           mesh_shape=context.model.mesh_shape,
                                           nonpadding=context.nonpadding)
        if context.losses is not None:
            context.losses.append(loss)
        if not has_length_dim:
            y = mtf.reshape(y, x_shape)
        return y
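The snippet above shows a pattern that recurs throughout these examples: if the input has no length dimension, a unit "length" dim is spliced in before the feature dim, and the output is reshaped back afterwards. Below is a minimal standalone sketch of that pattern, assuming only the public mesh_tensorflow API (dimension names and sizes are made up for illustration).

import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
x = mtf.zeros(mesh, mtf.Shape([mtf.Dimension("batch", 2),
                               mtf.Dimension("d_model", 8)]))
x_shape = x.shape
# Splice a unit "length" dim in front of the last (feature) dim.
shape_with_length = mtf.Shape(x_shape.dims[:-1] +
                              [mtf.Dimension("length", 1)] +
                              x_shape.dims[-1:])
x = mtf.reshape(x, shape_with_length)   # [batch, length=1, d_model]
# ... a layer that expects a length dim would run here ...
x = mtf.reshape(x, x_shape)             # back to [batch, d_model]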
Example #2
def widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16=None):
    logger.debug("[input tensor] (name,shape):({},{})".format(id_hldr.name,id_hldr.shape))
    logger.debug("[input tensor] (name,shape):({},{})".format(wt_hldr.name,wt_hldr.shape))
    if float16:
        deep_output = mtf.layers.embedding(id_hldr, vocab_dim=vocab_dim, output_dim=embed_dim, variable_dtype=float16, name="deep_embedding")
    else:
        fp32 = mtf.VariableDType(tf.float32,tf.float32,tf.float32)
        deep_output = mtf.layers.embedding(id_hldr, vocab_dim=vocab_dim, output_dim=embed_dim, variable_dtype=fp32, name="deep_embedding")
    logger.debug("[output tensor] (name,shape):({},{})".format(deep_output.name,deep_output.shape))
    expend_dim = mtf.Dimension('expend',size=1)
    embed_dim_one = mtf.Dimension('embed_dim_one',size=1)
    mask = mtf.reshape(wt_hldr, new_shape=[wt_hldr.shape.dims[0],wt_hldr.shape.dims[1],expend_dim], name='mask_reshape')
    logger.debug("[output tensor] (name,shape):({},{})".format(mask.name,mask.shape))
    if float16:
        wide_output = mtf.layers.embedding(id_hldr, vocab_dim=vocab_dim, output_dim=embed_dim_one, variable_dtype=float16, name="wide_embedding")
    else:
        fp32 = mtf.VariableDType(tf.float32,tf.float32,tf.float32)
        wide_output = mtf.layers.embedding(id_hldr, vocab_dim=vocab_dim, output_dim=embed_dim_one, variable_dtype=fp32, name="wide_embedding")
    logger.debug("[output tensor] (name,shape):({},{})".format(wide_output.name,wide_output.shape))

    wide_output = wide(wide_output,mask=mask,float16=float16)
    deep_output = deep(deep_output,mask=mask,float16=float16)
    
    result = mtf.add(wide_output,deep_output)
    result = mtf.reshape(result, new_shape=[wide_output.shape.dims[0],outdim],name='result_reshape')
    logger.debug("[output tensor] (name,shape):({},{})".format(result.name, result.shape))
    return result
Example #3
    def call(self, context, x, losses=None):
        """Call the layer."""

        # Dim cheat sheet:
        # <B>: batch dims, e.g.
        #   [outer_batch_size, batch_size] or
        #   [beam_size, batch_size]
        # L: original length
        # M: model dim
        #
        # x
        #   <B>LM Tensor

        has_length_dim = context.length_dim in x.shape.dims
        if not has_length_dim:
            x_shape = x.shape
            shape_with_length = mtf.Shape(x_shape.dims[:-1] +
                                          [mtf.Dimension("length", 1)] +
                                          x_shape.dims[-1:])
            x = mtf.reshape(x, shape_with_length)
        y, loss = transformer_moe_layer_v1(x,
                                           context.model.model_dim,
                                           self._hparams,
                                           context.train,
                                           context.variable_dtype,
                                           layout=context.model.layout,
                                           mesh_shape=context.model.mesh_shape,
                                           nonpadding=context.nonpadding)
        if context.losses is not None:
            context.losses.append(loss)
        if not has_length_dim:
            y = mtf.reshape(y, x_shape)
        return y
Example #4
def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
                                      feedforward_dim, hparams):
  """Image Transformer decoder with local1D spatial layers."""
  batch_dim, length_dim, model_dim = x.shape.dims
  blocks_w_dim = mtf.Dimension("blocksw", hparams.block_length)
  num_w_blocks_dim = mtf.Dimension("num_wblocks",
                                   length_dim.size // blocks_w_dim.size)
  x = mtf.reshape(
      x, mtf.Shape([batch_dim, num_w_blocks_dim, blocks_w_dim, model_dim]))
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf.layers.local_self_attention_spatial_blocks(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
              kv_dim,
              heads_dim,
              memory_w_dim=blocks_w_dim,
              mask_right=True,
              name="self_att"), hparams)
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf.layers.dense_relu_dense(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              feedforward_dim,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]), hparams)

  output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
  return output
Example #5
  def call(self, context, x, losses=None):
    """Call the layer."""
    io_channels = x.shape.dims[-1]
    hidden_channels = mtf.Dimension("d_ff", self.hidden_size)

    h = dense_product_fixup(
        x,
        reduced_dims=x.shape.dims[-1:],
        new_dims=hidden_channels,
        activation_functions=self.activation,
        use_bias=self.use_bias,
        variable_dtype=context.variable_dtype,
        name="wi",
        kernel_initializer=self.upproject_initializer,
        expert_dims=context.model.ensemble_dims)
    if context.train and self.dropout_rate != 0.0:
      h = mtf.dropout(
          h, 1.0 - self.dropout_rate, noise_shape=h.shape - context.length_dim)
    shift = get_single_scalar_bias(x, "shift")
    h_res = mtf.add(h, shift)
    h = mtf.reshape(h_res, h.shape)
    return mtf.layers.dense(
        h,
        io_channels,
        use_bias=self.use_bias,
        activation=None,
        variable_dtype=context.variable_dtype,
        reduced_dims=h.shape.dims[-1:],
        name="wo",
        expert_dims=context.model.ensemble_dims,
        kernel_initializer=self.downproject_initializer)
Example #6
 def _repeat(x, n, repeat_dim):
     # repeat function in MTF
     tmp_dim = mtf.Dimension("tmp", 1)
     expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
     x = mtf.reshape(x, expand_shape)
     x = _tile(x, n, tmp_dim)
     output_shape = []
     for dim in x.shape.dims:
         if dim.name == "tmp":
             continue
         if dim.name == repeat_dim.name:
             dim = mtf.Dimension(dim.name, dim.size * n)
         output_shape.append(dim)
     output_shape = mtf.Shape(output_shape)
     x = mtf.reshape(x, output_shape)
     return x
Example #7
def local_attention1d_spatial_decoder(x, kv_dim, heads_dim, feedforward_dim,
                                      hparams):
    """Image Transformer decoder with local1D spatial layers."""
    batch_dim, length_dim, model_dim = x.shape.dims
    blocks_w_dim = mtf.Dimension("blocksw", hparams.block_length)
    num_w_blocks_dim = mtf.Dimension("num_wblocks",
                                     length_dim.size // blocks_w_dim.size)
    x = mtf.reshape(
        x, mtf.Shape([batch_dim, num_w_blocks_dim, blocks_w_dim, model_dim]))
    # [ self attention - ffn - residual + dropout] x n
    for layer in range(hparams.num_decoder_layers):
        layer_name = "decoder_layer_%d" % layer
        with tf.variable_scope(layer_name):
            # Self attention layer
            x += layer_prepostprocess_dropout(
                mtf.layers.local_self_attention_spatial_blocks(
                    mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
                    kv_dim,
                    heads_dim,
                    memory_w_dim=blocks_w_dim,
                    mask_right=True,
                    name="self_att"), hparams)
            # ffn layer
            x += layer_prepostprocess_dropout(
                mtf.layers.dense_relu_dense(
                    mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
                    feedforward_dim,
                    hparams.dropout,
                    dropout_broadcast_dims=[length_dim]), hparams)

    output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
    return output
Example #8
def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
    # Use axial position encoding
    axial_dim_1, axial_dim_2 = params["axial_pos_emb"]

    axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2)
    dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_1, axial_dim_2))]

    axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)

    axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)

    axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
                                   (axial_wpe_1, axial_wpe_2))
    wpe = (axial_wpe_1 + axial_wpe_2) / 2

    wpe = mtf.reshape(wpe, [axial_dim, embd_dim])

    return wpe
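A quick back-of-the-envelope check of why the axial factorization is worthwhile: the two small tables are broadcast and averaged into the full position table, so the parameter count scales with the sum of the axial sizes rather than their product. The numbers below are illustrative only, not taken from any particular config.

# Hypothetical sizes: 32 x 64 = 2048 positions, embedding width 512.
axial_dim_1, axial_dim_2, embd = 32, 64, 512
full_table_params = axial_dim_1 * axial_dim_2 * embd      # 1,048,576 parameters
axial_table_params = (axial_dim_1 + axial_dim_2) * embd   # 49,152 parameters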
Example #9
def downsample_hr_to_lr(field, lr_shape, hr_shape, downsampling_factor, halo_size, splittables, mesh):
    # Reshaping array into high resolution mesh
    field = mtf.reshape(field, field.shape+[mtf.Dimension('h_dim', 1)])
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)
    low = mtf.reshape(low, low.shape[:-1])

    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size//2**downsampling_factor, block_size_dim.size//2**downsampling_factor, block_size_dim.name)
    # Hack using a custom reshape because mesh is pretty dumb
    low = mtf.slicewise(lambda x: x[:,0,0,0],
                        [low],
                        output_dtype=field.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=splittables)

    return low
Example #10
def reshape(x, new_shape):
    old_shape = x.shape
    assert len(old_shape) == len(new_shape)
    for o, n in zip(old_shape.dims, new_shape.dims):
        if (o.name != n.name) and (o.name.startswith('axis')
                                   and n.name.startswith('axis')):
            x = mtf.rename_dimension(x, o.name, utils.RandName())
    return mtf.reshape(x, new_shape)
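A minimal sketch of what the helper above automates, using only the public mesh_tensorflow API (the dimension names are illustrative, and utils.RandName is assumed to be a random-name generator from the surrounding module):

import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
x = mtf.zeros(mesh, mtf.Shape([mtf.Dimension("axis0", 4),
                               mtf.Dimension("channels", 8)]))
# Renaming the clashing "axis*" dim first (what the helper does via RandName)
# lets mtf.reshape map it onto a differently named dim of the same total size.
x = mtf.rename_dimension(x, "axis0", "axis0_tmp")
y = mtf.reshape(x, mtf.Shape([mtf.Dimension("axis1", 4),
                              mtf.Dimension("channels", 8)]))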
Example #11
def attention(q,
              k,
              v,
              memory_length_dim,
              key_dim,
              value_dim,
              bias=None,
              dropout_rate=0.0,
              dropout_broadcast_dims=None,
              extra_logit=None,
              context=None):
    """Dot-product attention - doesn't use positional dimensions.

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    context: an optional Transformer.Context

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
    orig_q_shape = q.shape
    q, k, v, bias = maybe_reshape_attention_input_for_2d_sharding(
        context, q, k, v, bias, [key_dim, value_dim])
    logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        logits += bias
    weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
    if dropout_rate != 0.0:
        weights = mtf.dropout(weights,
                              1.0 - dropout_rate,
                              noise_shape=weights.shape -
                              dropout_broadcast_dims)
    outputs_shape = q.shape - key_dim + value_dim
    outputs = mtf.einsum([weights, v], outputs_shape)
    outputs = mtf.reshape(outputs, orig_q_shape - key_dim + value_dim)
    return outputs
Example #12
def sublayer_fixup_scale(x, layer_stack, context):
  """Multiply by single one-initialized scalar."""
  del layer_stack
  dim = mtf.Dimension("single_scale", 1)
  fixup_weight = mtf.get_variable(
      x.mesh, "fixup_scale_weight", shape=mtf.Shape([dim]),
      dtype=context.variable_dtype,
      initializer=tf.constant_initializer(1.))
  return mtf.reshape(x * fixup_weight, x.shape)
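The trailing reshape is needed because multiplying by the size-1 "single_scale" variable broadcasts that dimension into the result; reshaping back to x.shape squeezes it out again. A standalone sketch of the same mechanics with made-up shapes (public mtf API only):

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
x = mtf.zeros(mesh, mtf.Shape([mtf.Dimension("batch", 2),
                               mtf.Dimension("d_model", 4)]))
scale = mtf.get_variable(
    mesh, "demo_scale", shape=mtf.Shape([mtf.Dimension("single_scale", 1)]),
    initializer=tf.constant_initializer(1.))
y = x * scale                 # result picks up the size-1 "single_scale" dim
y = mtf.reshape(y, x.shape)   # squeeze it back out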
Example #13
 def mtf_model_fn(self, features, mesh):
     logits, loss = self._mtf_model_fn(features, mesh)
     # combine batch dims
     if len(self.batch_dims) > 1:
         combined_batch_dim = mtf.Dimension(self.batch_dims[0].name,
                                            mtf.Shape(self.batch_dims).size)
         logits = mtf.reshape(logits,
                              [combined_batch_dim] + logits.shape.dims[-2:])
     return logits, loss
Example #14
 def mtf_model_fn(self, features, mesh):
   with tf.variable_scope("transformer"):
     logits, loss = self._mtf_model_fn(features, mesh)
     # combine batch dims
     if len(self.batch_dims) > 1:
       combined_batch_dim = mtf.Dimension(
           self.batch_dims[0].name, mtf.Shape(self.batch_dims).size)
       logits = mtf.reshape(
           logits, [combined_batch_dim] + logits.shape.dims[-2:])
     return logits, loss
Example #15
def sublayer_fixup_shift(x, layer_stack, context):
  """Shift by single zero-initialized scalar."""
  del layer_stack
  dim = mtf.Dimension("single_bias", 1)
  fixup_bias = mtf.get_variable(
      x.mesh, "fixup_bias", shape=mtf.Shape([dim]),
      dtype=context.variable_dtype,
      initializer=tf.zeros_initializer())
  res = mtf.add(x, fixup_bias)
  res = mtf.reshape(res, x.shape)
  return res
Example #16
def split_scales(field, downsampling_factor=2., antialias=True):
    """
  Performs a multiresolution decomposition of the input field.

  The input field will be decomposed into a low resolution approximation,
  and a details component.
  """
    low = downsample(field, downsampling_factor, antialias)
    high = upsample(low, downsampling_factor)
    high = field - mtf.reshape(high, field.shape)
    return low, high
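A toy 1-D NumPy illustration of the decomposition idea, with average pooling and nearest-neighbour upsampling standing in for the actual downsample/upsample kernels (illustrative only):

import numpy as np

field = np.array([1., 3., 2., 6.])
low = field.reshape(-1, 2).mean(axis=1)    # low-resolution approximation: [2., 4.]
high = field - np.repeat(low, 2)           # details component: [-1., 1., -2., 2.]
reconstructed = np.repeat(low, 2) + high   # recovers the original field exactly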
Example #17
 def call(self, context, x, losses=None):
   """Call the layer."""
   has_length_dim = context.length_dim in x.shape.dims
   if not has_length_dim:
     x_shape = x.shape
     shape_with_length = mtf.Shape(
         x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
         + x_shape.dims[-1:])
     x = mtf.reshape(x, shape_with_length)
   y, loss = transformer_moe_layer_v1(
       x,
       context.model_dim,
       self._hparams,
       context.train,
       context.variable_dtype)
   if context.losses is not None:
     context.losses.append(loss)
   if not has_length_dim:
     y = mtf.reshape(y, x_shape)
   return y
Example #18
def attention(x, dim_head, dim_features_head, scope='attn', causal=False):
    with tf.variable_scope(scope):
        mesh, batch, seq, dim = x.mesh, *x.shape

        dim_heads = mtf.Dimension('dim_heads',
                                  dim_head.size * dim_features_head.size)
        dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)
        qkv = linear(x, dim_intermediate, bias=False, scope='to_qkv')

        q, k, v = mtf.split(qkv, dim_intermediate, 3)
        q, k, v = map(
            lambda t: mtf.reshape(t, [batch, seq, dim_head, dim_features_head]
                                  ), (q, k, v))
        q, k, v = map(
            lambda t: mtf.transpose(
                t, [batch, dim_head, seq, dim_features_head]), (q, k, v))

        k, v = map(
            lambda t: mtf.rename_dimension(t, seq.name, 'memory_length'),
            (k, v))
        mem_len_dim = v.shape[-2]

        dots = mtf.layers.us_einsum([q, k],
                                    [batch, dim_head, seq, mem_len_dim])

        if causal:
            i = mtf.range(mesh, seq, tf.int32)
            j = mtf.range(mesh, mem_len_dim, tf.int32)
            i, j = map(lambda t: mtf.broadcast(t, [seq, mem_len_dim]), (i, j))
            mask = mtf.less(i + mem_len_dim.size - seq.size, j)
            mask = mtf.cast(mask, tf.float32) * -1e10
            dots += mask

        attn = mtf.softmax(dots, mem_len_dim)
        out = mtf.einsum([attn, v], [batch, dim_head, seq, dim_features_head])

        out = mtf.transpose(out, [batch, seq, dim_head, dim_features_head])
        out = mtf.reshape(out, [batch, seq, dim_heads])

        combined_out = linear(out, dim, scope='combine_output')
        return combined_out
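In the causal branch above, the mask assigns a large negative bias to every memory position strictly to the right of the aligned query position. The same predicate rendered in NumPy with toy sizes (illustrative only):

import numpy as np

seq_len, mem_len = 4, 4
i = np.arange(seq_len)[:, None]   # query positions
j = np.arange(mem_len)[None, :]   # memory positions
# mtf.less(i + mem_len - seq_len, j): mask where the memory position is "in the future"
mask = (j > i + (mem_len - seq_len)).astype(np.float32) * -1e10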
Example #19
  def _compute_output(hidden, layer_name):
    """Compute the output of the attention layer from the hidden vector."""
    expert_output = mtf.layers.dense(
        hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
        reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
        name=layer_name)

    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))
    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
    return output
Example #20
def BasicBlock(x, order, out_channels, strides):
    name = "BasicBlock"
    expansion = 1
    out_chls = out_channels // expansion
    identity = x

    x = mtf.layers.conv2d(x,
                          output_dim=mtf.Dimension(
                              name=name + '-' + str(order) + '-' + 'filters1',
                              size=out_chls),
                          filter_size=(3, 3),
                          strides=strides,
                          name="conv3x3_BB_1" + '-' + str(order),
                          variable_dtype=float16)
    print(x.name)
    print(x.dtype)
    x, _ = mtf.layers.batch_norm(x,
                                 is_training=True,
                                 momentum=0.99,
                                 epsilon=1e-5,
                                 name="batch_norm_BB_1" + '-' + str(order))
    x = mtf.relu(x, name="relu_BB_1" + '-' + str(order))

    x = mtf.layers.conv2d(x,
                          output_dim=mtf.Dimension(
                              name=name + '-' + str(order) + '-' + 'filters2',
                              size=out_channels),
                          filter_size=(3, 3),
                          strides=(1, 1),
                          name="conv3x3_BB_2" + '-' + str(order),
                          variable_dtype=float16)
    print(x.name)
    print(x.dtype)
    x, _ = mtf.layers.batch_norm(x,
                                 is_training=True,
                                 momentum=0.99,
                                 epsilon=1e-5,
                                 name="batch_norm_BB_2" + '-' + str(order))
    identity = mtf.reshape(identity,
                           new_shape=[
                               identity.shape.dims[0], identity.shape.dims[1],
                               identity.shape.dims[2], x.shape.dims[3]
                           ],
                           name="reshape_BB" + str(order))

    x = mtf.add(x,
                identity,
                output_shape=x.shape,
                name="add_BB_1" + '-' + str(order))
    x = mtf.relu(x, name="relu_BB_2" + '-' + str(order))
    print(x.name)
    print(x.dtype)
    return x
Example #21
    def _get_decoder_inputs(self, context):
        """Computes the inputs to the decoder when using transparent attention.

    We must cache on the context in order to ensure that we are not replicating
    variables when the layer's call function is called in different tf variable
    scopes.

    Args:
      context: a Context

    Returns:
      a list containing `self.num_decoder_modules` of tensors with shape
        [<batch_dims>, length_dim, output_vocab_dim]
    """
        if hasattr(context, "decoder_layers_per_module"):
            return context.decoder_layers_per_module

        encoder_layer_outputs = [
            mtf.layers.rename_length_to_memory_length(output)
            for output in context.encoder_layer_outputs
        ]

        layers_per_module = self.layers_per_encoder_module
        encoder_module_outputs_dim = mtf.Dimension(
            "encoder_module_outputs", size=self.encoder_num_modules + 1)
        decoder_module_inputs_dim = mtf.Dimension(
            "decoder_module_inputs", size=self.decoder_num_modules)
        encoder_module_outputs = mtf.stack(
            [encoder_layer_outputs[0]] +
            encoder_layer_outputs[layers_per_module::layers_per_module],
            dim_name="encoder_module_outputs")
        w = mtf.get_variable(
            context.mesh,
            "w",
            mtf.Shape([encoder_module_outputs_dim, decoder_module_inputs_dim]),
            initializer=tf.random_normal_initializer(
                stddev=(encoder_module_outputs_dim.size *
                        decoder_module_inputs_dim.size)**-0.5),
            dtype=context.variable_dtype)
        if context.train and self.dropout_rate != 0.0:
            w = mtf.dropout(w, 1.0 - self.dropout_rate)
        s = mtf.softmax(w, reduced_dim=encoder_module_outputs_dim)
        z = mtf.einsum([s, encoder_module_outputs],
                       reduced_dims=[encoder_module_outputs_dim])
        input_per_decoder = mtf.split(
            z,
            split_dim=decoder_module_inputs_dim,
            num_or_size_splits=decoder_module_inputs_dim.size)
        context.decoder_layers_per_module = [
            mtf.reshape(inpt, z.shape.dims[1:]) for inpt in input_per_decoder
        ]
        return context.decoder_layers_per_module
Example #22
  def call(self, context, x, losses=None):
    """Call the layer."""
    if context.model.ensemble_dim:
      raise NotImplementedError("MoE not yet implemented with ensembles")

    has_length_dim = context.length_dim in x.shape.dims
    if not has_length_dim:
      x_shape = x.shape
      shape_with_length = mtf.Shape(
          x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
          + x_shape.dims[-1:])
      x = mtf.reshape(x, shape_with_length)

    # Extract the MoE output dimension
    if self._hparams.moe_output_dim is not None:
      output_dim = self._hparams.moe_output_dim
    else:
      output_dim = context.model.model_dim
    y, loss = transformer_moe_layer_v1(
        x,
        output_dim,
        self._hparams,
        context.train,
        context.variable_dtype,
        layout=context.model.layout,
        mesh_shape=context.model.mesh_shape,
        nonpadding=context.nonpadding,
        activation=self._activation,
        num_microbatches=context.num_microbatches)
    if context.losses is not None:
      context.losses.append(loss)
    if not has_length_dim:
      if self._hparams.moe_use_experts_attention:
        y_reshape = [mtf.reshape(y_out, x_shape) for y_out in y]
        y = y_reshape
      else:
        y = mtf.reshape(y, x_shape)
    return y
Example #23
def local_attention2d_spatial_decoder(x, kv_dim, heads_dim, feedforward_dim,
                                      hparams):
    """Image Transformer decoder with local2D spatial layers."""
    batch_dim, length_dim, model_dim = x.shape.dims
    blocks_h_dim = mtf.Dimension("blocksh", hparams.block_height)
    blocks_w_dim = mtf.Dimension("blocksw", hparams.block_width)
    num_h_blocks_dim = mtf.Dimension("num_h_blocks",
                                     hparams.img_len // hparams.block_height)
    num_w_blocks_dim = mtf.Dimension(
        "num_w_blocks",
        hparams.img_len * hparams.num_channels // hparams.block_width)
    x = mtf.transpose(
        mtf.reshape(
            x,
            mtf.Shape([
                batch_dim, num_h_blocks_dim, blocks_h_dim, num_w_blocks_dim,
                blocks_w_dim, model_dim
            ])),
        mtf.Shape([
            batch_dim, num_h_blocks_dim, num_w_blocks_dim, blocks_h_dim,
            blocks_w_dim, model_dim
        ]))
    mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
    is_training = mode == tf_estimator.ModeKeys.TRAIN
    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    for layer in range(hparams.num_decoder_layers):
        layer_name = "decoder_layer_%d" % layer
        with tf.variable_scope(layer_name):
            # Self attention layer
            x += layer_prepostprocess_dropout(
                mtf.layers.local_2d_self_attention_spatial_blocks(
                    mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
                    kv_dim,
                    heads_dim,
                    is_training,
                    memory_h_dim=num_h_blocks_dim,
                    memory_w_dim=num_w_blocks_dim,
                    name="self_att"), hparams)
            # ffn layer
            x += layer_prepostprocess_dropout(
                mtf.layers.dense_relu_dense(
                    mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
                    feedforward_dim,
                    hparams.dropout,
                    dropout_broadcast_dims=[length_dim]), hparams)

    output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
    return output
Example #24
def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
                                      feedforward_dim, hparams):
  """Image Transformer decoder with local2D spatial layers."""
  batch_dim, length_dim, model_dim = x.shape.dims
  blocks_h_dim = mtf.Dimension("blocksh", hparams.block_height)
  blocks_w_dim = mtf.Dimension("blocksw", hparams.block_width)
  num_h_blocks_dim = mtf.Dimension("num_h_blocks",
                                   hparams.img_len // hparams.block_height)
  num_w_blocks_dim = mtf.Dimension(
      "num_w_blocks",
      hparams.img_len * hparams.num_channels // hparams.block_width)
  x = mtf.transpose(
      mtf.reshape(
          x,
          mtf.Shape([
              batch_dim, num_h_blocks_dim, blocks_h_dim,
              num_w_blocks_dim, blocks_w_dim, model_dim
          ])),
      mtf.Shape([
          batch_dim, num_h_blocks_dim, num_w_blocks_dim,
          blocks_h_dim, blocks_w_dim, model_dim
      ]))
  # Image Transformer Decoder
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf.layers.local_2d_self_attention_spatial_blocks(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
              kv_dim,
              heads_dim,
              memory_h_dim=num_h_blocks_dim,
              memory_w_dim=num_w_blocks_dim,
              name="self_att"), hparams)
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf.layers.dense_relu_dense(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              feedforward_dim,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]), hparams)

  output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
  return output
Example #25
def VGG(x, classes_dim, depth, batch_norm=True):
    if depth not in vgg_dict:
        raise ValueError("VGG-{} is not supported!".format(depth))
    x = make_conv_layers(x, mode=vgg_dict[depth], batch_norm=batch_norm)

    x = mtf.reshape(
        x,
        new_shape=[
            x.shape.dims[0],
            mtf.Dimension(name="flatten",
                          size=x.shape.dims[1].size * x.shape.dims[2].size *
                          x.shape.dims[3].size)
        ],
        name="flatten")

    x = make_dense_layers(x, classes_dim=classes_dim)
    print(x.name)
    print(x.shape)
    return x
Example #26
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    del features

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list, )

    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                mesh_devices,
                                                ctx.device_assignment)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    with mtf.utils.outside_all_rewrites():
        field = nbody_model(mesh)
        batch_dim, x_dim, y_dim, z_dim = field.shape
        x_dim_nosplit = mtf.Dimension("nx_nosplit", FLAGS.cube_size)
        y_dim_nosplit = mtf.Dimension("ny_nosplit", FLAGS.cube_size)

        # Until we implement distributed outputs, we only return one example
        field_slice, _ = mtf.split(field, batch_dim, [1, FLAGS.batch_size - 1])
        field_slice = mtf.reshape(
            field_slice,
            [mtf.Dimension("bs", 1), x_dim_nosplit, y_dim_nosplit, z_dim])
        #field_slice = field

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_field = tf.to_float(lowering.export_to_tf_tensor(field_slice))

    with mtf.utils.outside_all_rewrites():
        return tpu_estimator.TPUEstimatorSpec(mode,
                                              predictions={'field': tf_field})
Example #27
def transformer_moe_layer_v1(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None):
    """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  The number of parameters in the gating network is:
    (input_dim.size * hparams.num_experts) +

  The number of parameters in the experts themselves is:
    (hparams.num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  <B>: batch dims
  L: original sequence length
  M: input depth
  N: output depth
  G: number of groups
  S: group size
  E: number of experts
  C: expert capacity
  (u for unsplit dims)

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims...>, length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [<batch_dims>, length_dim]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).

  Returns:
    outputs: a Tensor with shape [<batch_dims...>, length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    # See "Dimensions cheat sheet"
    # <B>LM Tensor
    orig_inputs = inputs
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups is a multiple of the mesh dimension
    # over which those groups are split.
    batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                        orig_inputs.shape.dims[-1])
    # Hack: we assume that
    #   "outer_batch" == replication of experts
    #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
    #
    # We then require num_groups to be a multiple of mesh_dim_size.
    if orig_inputs.shape.dims[0].name == "outer_batch":
        outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
    else:
        outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                           orig_inputs.shape.dims[0])

    # Number of MoE inputs (total number of positions across
    # batch_and_length_dims per replica).
    n = 1
    for d in batch_and_length_dims:
        n *= d.size

    n = n // outer_batch_dim.size

    mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                    orig_batch_dim)
    num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
                                                mesh_dim_size)

    group_size_dim = mtf.Dimension("group", group_size)
    num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

    moe_input_dims = [
        outer_batch_dim, num_groups_dim, group_size_dim, input_dim
    ]
    # OGSM Tensor
    inputs = mtf.reshape(inputs, moe_input_dims)

    # Each sequence sends expert_capacity positions to each expert.
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
    if nonpadding is not None:
        nonpadding = mtf.zeros(inputs.mesh,
                               batch_and_length_dims,
                               dtype=inputs.dtype) + nonpadding
        nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
    if hparams.moe_gating == "top_2":
        # dispatch_tensor and combine_tensor are
        # <B>GSEC Tensors
        dispatch_tensor, combine_tensor, loss = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([
                                   outer_batch_dim, experts_dim_unsplit,
                                   num_groups_dim, expert_capacity_dim,
                                   input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([
            outer_batch_dim, experts_dim, batch_dim_unsplit,
            expert_capacity_dim, input_dim
        ]))

    # Now feed the expert inputs through the experts.
    h = mtf.layers.dense(expert_inputs,
                         hidden_dim,
                         expert_dims=[experts_dim],
                         activation=mtf.relu,
                         use_bias=False,
                         variable_dtype=variable_dtype,
                         name="wi")

    expert_output = mtf.layers.dense(h,
                                     output_dim,
                                     expert_dims=[experts_dim],
                                     use_bias=False,
                                     variable_dtype=variable_dtype,
                                     name="wo")

    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))

    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])

    return output, loss * hparams.moe_loss_coef
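Worked numbers for the expert_capacity computation in the middle of the function above (the values are illustrative, not defaults of any hparams set):

group_size = 1024
capacity_factor = 1.25     # e.g. hparams.moe_capacity_factor_train
num_experts = 16
expert_capacity = min(group_size,
                      int((group_size * capacity_factor) / num_experts))
# -> min(1024, int(1280.0 / 16)) = 80 positions per expert, per group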
Example #28
    def _sample(self, features, mesh):
        hparams = self._hparams
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if hparams.transformer_type == "encdec":
            inputs = features["inputs"]
            while len(inputs.shape.as_list()) > 2:
                inputs = tf.squeeze(inputs, axis=2)
            actual_batch_size = tf.shape(inputs)[0]
            actual_length = tf.shape(inputs)[1]
            inputs = tf.pad(inputs,
                            [[0, hparams.batch_size - actual_batch_size],
                             [0, hparams.max_length - actual_length]])
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.reshape(positional_embedding_var,
                             mtf.Shape([self.length_dim, self.model_dim])))
            encoder_attention_mask = (mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.encoder_layers,
                    self_attention_mask=encoder_attention_mask)
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)
            encdec_tensors = []
            for layer_num, layer_type in enumerate(hparams.decoder_layers):
                if layer_type == "enc_att":
                    with tf.variable_scope("decoder/enc_att_%d/enc_att" %
                                           layer_num):
                        q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                            mesh, self.heads_dim, self.model_dim, self.kv_dim,
                            self.master_dtype, self.slice_dtype,
                            self.activation_dtype)
                        k = mtf.einsum([encoder_output, k_var],
                                       mtf.Shape(self.batch_dims + [
                                           self.heads_dim,
                                           self.memory_length_dim, self.kv_dim
                                       ]))
                        v = mtf.einsum([encoder_output, v_var],
                                       mtf.Shape(self.batch_dims + [
                                           self.heads_dim,
                                           self.memory_length_dim, self.kv_dim
                                       ]))
                    encdec_tensors.append((q_var, o_var, k, v))
                else:
                    encdec_tensors.append(None)
            partial_targets = None
        elif hparams.transformer_type == "decoder":
            encdec_tensors = None
            encoder_output = None
            encoder_attention_mask = None
            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get("inputs", None)
            if partial_targets is None:
                partial_targets = features.get("targets", None)
            if partial_targets is not None:
                partial_targets = common_layers.expand_squeeze_to_nd(
                    partial_targets, 2)
                partial_targets = tf.to_int32(partial_targets)
                partial_targets_batch = tf.shape(partial_targets)[0]
                partial_targets_length = tf.shape(partial_targets)[1]
                partial_targets = tf.pad(
                    partial_targets,
                    [[0, hparams.batch_size - partial_targets_batch],
                     [0, hparams.max_length - partial_targets_length]])
                partial_targets = self._import_to_batch_by_length(
                    partial_targets, "partial_targets", mesh, hparams)
        else:
            raise ValueError("hparams.model_type = %s not yet supported" %
                             hparams.transformer_type)

        local_attention_window = mtf.Dimension(
            "local_attention_window", hparams.local_attention_window_size)
        if hparams.beam_size == 1:
            ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
            kv_shape = mtf.Shape(
                self.batch_dims +
                [self.heads_dim, self.memory_length_dim, self.kv_dim])
            local_kv_shape = mtf.Shape(
                self.batch_dims +
                [self.heads_dim, local_attention_window, self.kv_dim])
        else:
            beam_dim = mtf.Dimension("beam", hparams.beam_size)
            ids_shape = mtf.Shape(self.batch_dims +
                                  [beam_dim, self.length_dim])
            kv_shape = mtf.Shape(self.batch_dims + [
                beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim
            ])
            local_kv_shape = mtf.Shape(self.batch_dims + [
                beam_dim, self.heads_dim, local_attention_window, self.kv_dim
            ])

        initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
        initial_states = []
        for layer in hparams.decoder_layers:
            if layer == "att":
                initial_states.extend(
                    [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] *
                    2)
            elif layer == "local_att":
                initial_states.extend([
                    mtf.zeros(
                        mesh, local_kv_shape, dtype=self.activation_dtype)
                ] * 2)

        def logits_fn(step_num, ids, states):
            """Produce logits for this step, and new states."""
            ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
            x = (mtf.gather(targets_embedding_var, ids_this_step,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, step_num,
                            self.max_length_dim))
            with tf.variable_scope("decoder"):
                x, new_states = self._layer_stack(
                    x,
                    hparams.decoder_layers,
                    encdec_attention_mask=encoder_attention_mask,
                    step_num=step_num,
                    encdec_tensors=encdec_tensors,
                    states=states)
            logits = mtf.matmul(x, softmax_var)
            return logits, new_states

        if hparams.beam_size == 1:
            temperature = (0.0 if hparams.sampling_method == "argmax" else
                           hparams.sampling_temp)
            return mtf.beam_search.greedy_decode(logits_fn,
                                                 initial_ids,
                                                 temperature=temperature,
                                                 initial_states=initial_states,
                                                 forced_ids=partial_targets,
                                                 use_tpu=hparams.use_tpu)
        else:
            if hparams.transformer_type == "encdec":
                input_length = mtf.reduce_sum(mtf.to_float(
                    mtf.cast(inputs, tf.bool)),
                                              reduced_dim=self.length_dim)
                max_input_length = mtf.reduce_max(input_length)
                decode_length = mtf.cast(
                    max_input_length * hparams.decode_length_multiplier +
                    hparams.decode_length_constant, tf.int32)
            else:
                decode_length = None
            beams, unused_scores = mtf.beam_search.beam_search(
                logits_fn,
                initial_ids,
                hparams.alpha,
                states=initial_states,
                decode_length=decode_length,
                use_tpu=hparams.use_tpu,
                dtype=self.activation_dtype)
            return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32),
                              beam_dim)
Example #29
    def _layer_stack(self,
                     x,
                     layers,
                     encoder_output=None,
                     self_attention_mask=None,
                     encdec_attention_mask=None,
                     losses=None,
                     step_num=None,
                     encdec_tensors=None,
                     states=None):
        """Encoder or decoder stack.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      layers: a list of strings
      encoder_output: an optional mtf.Tensor with shape
        [<batch_dims>, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: a list to be appended-to
      step_num: an optional mtf integer Scalar (used in incremental mode)
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v), (used in incremental mode)
      states: an optional list of Tensors (used in incremental mode)
    Returns:
      a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    Raises:
      ValueError: if hparams make no sense
    """
        hparams = self._hparams
        is_incremental = (step_num is not None)

        def layer_prepostprocess_dropout(x):
            if is_incremental:
                return x
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

        num_layers = len(layers)
        num_layer_norms = num_layers + 1
        layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
        layer_norm_combined_var = mtf.get_variable(
            x.mesh,
            "layer_norm_scale",
            mtf.Shape([layer_norms_dim, self.model_dim]),
            initializer=tf.ones_initializer(),
            activation_dtype=x.dtype)
        layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

        def normalize(x):
            scale = layer_norm_vars.pop(0)
            variance = mtf.reduce_mean(mtf.square(x),
                                       reduced_dim=self.model_dim)
            return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

        if is_incremental:
            states = list(states)
            new_states = []
        tf.logging.info("states = %s" % (states, ))

        for lnum, layer_type in enumerate(layers):
            with tf.variable_scope("%s_%d" % (layer_type, lnum)):
                if layer_type == "att":
                    # Self attention layer
                    if is_incremental:
                        y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                            normalize(x),
                            prev_k=states.pop(0),
                            prev_v=states.pop(0),
                            step_num=step_num,
                            master_dtype=self.master_dtype,
                            slice_dtype=self.slice_dtype,
                            name="att")
                        new_states.append(new_k)
                        new_states.append(new_v)
                        x += y
                    else:
                        x += layer_prepostprocess_dropout(
                            mtf.layers.multihead_attention(
                                normalize(x),
                                None,
                                self_attention_mask,
                                self.kv_dim,
                                self.heads_dim,
                                dropout=hparams.attention_dropout,
                                dropout_broadcast_dims=[self.length_dim],
                                master_dtype=self.master_dtype,
                                slice_dtype=self.slice_dtype,
                                name="att"))
                elif layer_type == "enc_att":
                    # Encoder-Decoder attention layer
                    if is_incremental:
                        # Encoder-Decoder attention layer
                        q_var, o_var, k, v = encdec_tensors[lnum]
                        x += mtf.layers.multihead_encdec_attention_incremental(
                            normalize(x),
                            q_var,
                            o_var,
                            k,
                            v,
                            encdec_attention_mask,
                            name="enc_att")
                    else:
                        x += layer_prepostprocess_dropout(
                            mtf.layers.multihead_attention(
                                normalize(x),
                                encoder_output,
                                encdec_attention_mask,
                                self.kv_dim,
                                self.heads_dim,
                                dropout=hparams.attention_dropout,
                                dropout_broadcast_dims=[self.length_dim],
                                master_dtype=self.master_dtype,
                                slice_dtype=self.slice_dtype,
                                name="enc_att"))
                elif layer_type == "local_att":
                    if is_incremental:
                        y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
                            normalize(x),
                            prev_k=states.pop(0),
                            prev_v=states.pop(0),
                            step_num=step_num,
                            master_dtype=self.master_dtype,
                            slice_dtype=self.slice_dtype,
                            name="local_att")
                        new_states.append(new_k)
                        new_states.append(new_v)
                        x += y
                    else:
                        x += layer_prepostprocess_dropout(
                            mtf.layers.masked_local_attention_1d(
                                normalize(x),
                                self.kv_dim,
                                self.heads_dim,
                                window_size=hparams.
                                local_attention_window_size,
                                master_dtype=self.master_dtype,
                                slice_dtype=self.slice_dtype,
                                length_per_split=mtf.
                                tensor_dim_to_size_per_split(
                                    hparams.layout, hparams.mesh_shape,
                                    self.max_length_dim),
                                name="local_att"))
                else:
                    if is_incremental:
                        # insert length dimension.
                        x_shape = x.shape
                        shape_with_length = mtf.Shape(
                            x_shape.dims[:-1] + [mtf.Dimension("length", 1)] +
                            x_shape.dims[-1:])
                        x = mtf.reshape(x, shape_with_length)
                    # ffn layer
                    x += layer_prepostprocess_dropout(
                        self._feedforward_layer(normalize(x),
                                                layer_type,
                                                losses=losses))
                    if is_incremental:
                        # remove length dimension
                        x = mtf.reshape(x, x_shape)

        x = layer_prepostprocess_dropout(normalize(x))
        assert not layer_norm_vars
        if is_incremental:
            return x, new_states
        else:
            return x
Example #30
    def _mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        hparams = self._hparams
        targets = tf.to_int32(features["targets"])
        if len(targets.get_shape()) > 2:
            tf.logging.info("targets = %s" % targets)
            targets = tf.squeeze(targets, [2, 3])
        # pad targets to max_length
        def pad_to_max_length(x):
            extra_length = hparams.max_length - tf.shape(x)[1]
            x = tf.pad(x, [[0, 0], [0, extra_length]])
            x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
            return x

        targets = pad_to_max_length(targets)
        for key in [
                "targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"
        ]:
            if key in features:
                features[key] = pad_to_max_length(features[key])
        shifted_targets = common_layers.shift_right_2d(targets)

        targets = self._import_to_batch_by_length(targets, "targets", mesh,
                                                  hparams)
        shifted_targets = self._import_to_batch_by_length(
            shifted_targets, "shifted_targets", mesh, hparams)

        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = self._import_to_batch_by_length(
                features["targets_segmentation"], "targets_segmentation", mesh,
                hparams)
            targets_position = self._import_to_batch_by_length(
                features["targets_position"], "targets_position", mesh,
                hparams)
            decoder_self_attention_mask = (
                mtf.layers.attention_mask_autoregressive(
                    targets_position, dtype=self.activation_dtype) +
                mtf.layers.attention_mask_same_segment(
                    targets_segmentation, dtype=self.activation_dtype))
        else:
            targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

        extra_losses = []
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if hparams.transformer_type == "decoder":
            encoder_output = None
            encoder_decoder_attention_mask = None
        else:
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = pad_to_max_length(inputs)
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            if "inputs_segmentation" in features:
                # "Packed" dataset - keep the examples from seeing each other.
                inputs_segmentation = self._import_to_batch_by_length(
                    features["inputs_segmentation"], "inputs_segmentation",
                    mesh, hparams)
                inputs_position = self._import_to_batch_by_length(
                    features["inputs_position"], "inputs_position", mesh,
                    hparams)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        inputs_segmentation, dtype=self.activation_dtype))
            else:
                inputs_position = mtf.range(mesh,
                                            self.length_dim,
                                            dtype=tf.int32)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_ignore_padding(
                        inputs, dtype=self.activation_dtype))

            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.gather(positional_embedding_var, inputs_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.encoder_layers,
                    self_attention_mask=encoder_self_attention_mask,
                    losses=extra_losses)

        if hparams.transformer_type == "encdec":
            if "inputs_segmentation" in features:
                encoder_decoder_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        targets_segmentation,
                        inputs_segmentation,
                        dtype=self.activation_dtype))
            else:
                encoder_decoder_attention_mask = encoder_self_attention_mask
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)

        if hparams.transformer_type != "encoder":
            # DECODER
            x = (mtf.gather(targets_embedding_var, shifted_targets,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, targets_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("decoder"):
                x = self._layer_stack(
                    x,
                    hparams.decoder_layers,
                    encoder_output=encoder_output,
                    self_attention_mask=decoder_self_attention_mask,
                    encdec_attention_mask=encoder_decoder_attention_mask,
                    losses=extra_losses)
        logits = mtf.matmul(x, softmax_var)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
        off_value = hparams.label_smoothing / self._targets_vocab_size
        on_value = 1.0 - hparams.label_smoothing + off_value
        soft_targets = mtf.one_hot(targets,
                                   self.targets_vocab_dim,
                                   on_value=on_value,
                                   off_value=off_value,
                                   dtype=self.activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.targets_vocab_dim)
        weights = mtf.layers.weights_nonzero(targets,
                                             dtype=self.activation_dtype)
        loss = mtf.reduce_mean(loss * weights)
        for l in extra_losses:
            loss += l
        logits = mtf.to_float(logits)
        # combine batch dims
        if len(self.batch_dims) > 1:
            combined_batch_dim = mtf.Dimension(self.batch_dims[0].name,
                                               mtf.Shape(self.batch_dims).size)
            logits = mtf.reshape(logits,
                                 [combined_batch_dim] + logits.shape.dims[-2:])
        return logits, loss
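The label-smoothing arithmetic above spreads hparams.label_smoothing of probability mass uniformly over the target vocabulary. A quick stand-alone check (label_smoothing = 0.1 and a 32k vocabulary are illustrative values, not taken from the snippet):

label_smoothing = 0.1
vocab_size = 32000
off_value = label_smoothing / vocab_size       # mass given to every class
on_value = 1.0 - label_smoothing + off_value   # mass kept on the true class
# each soft-target row sums to 1
assert abs(on_value + (vocab_size - 1) * off_value - 1.0) < 1e-9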
Exemple #31
0
def lpt_init(lr_field,
             hr_field,
             a0,
             kvec_lr,
             kvec_hr,
             halo_size,
             hr_shape,
             lr_shape,
             part_shape,
             antialias=True,
             downsampling_factor=2,
             order=1,
             post_filtering=True,
             cosmology=Planck15):
    a = a0
    batch_dim = hr_field.shape[0]
    lnc = lr_shape[-1].size
    k_dims_lr = [d.shape[0] for d in kvec_lr]
    k_dims_hr = [d.shape[0] for d in kvec_hr]
    k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
    k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]]

    # Create particles on the high resolution grid
    mstate = mesh_ops.mtf_indices(hr_field.mesh,
                                  shape=part_shape,
                                  dtype=tf.float32)
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
    hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr)

    grad_kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(
        lr_kfield, kvec_lr)
    grad_kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(
        hr_kfield, kvec_hr)

    # Reorder the low res FFTs, which were transposed (to y, z, x order)
    grad_kfield_lr = [grad_kfield_lr[2], grad_kfield_lr[0], grad_kfield_lr[1]]
    grad_kfield_hr = [grad_kfield_hr[2], grad_kfield_hr[0], grad_kfield_hr[1]]

    displacement = []
    for f, g in zip(grad_kfield_lr, grad_kfield_hr):
        f = mesh_utils.c2r3d(f, lr_shape[-3:])
        f = mtf.slicewise(
            lambda x: tf.expand_dims(
                tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
            [f],
            output_dtype=tf.float32,
            output_shape=mtf.Shape(hr_shape[0:4] + [
                mtf.Dimension('sx_block', lnc // hr_shape[1].size),
                mtf.Dimension('sy_block', lnc // hr_shape[2].size),
                mtf.Dimension('sz_block', lnc // hr_shape[3].size)
            ]),
            name='my_reshape',
            splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])
        for block_size_dim in hr_shape[-3:]:
            f = mtf.pad(f, [
                halo_size // 2**downsampling_factor,
                halo_size // 2**downsampling_factor
            ], block_size_dim.name)
        for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
            f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim,
                                     halo_size // 2**downsampling_factor)
        f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)])
        f = mesh_utils.upsample(f, downsampling_factor)
        f = mtf.reshape(f, f.shape[:-1])

        g = mesh_utils.c2r3d(g, f.shape[-3:])
        high_shape = g.shape
        # And now we remove the large scales
        g = mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)])
        _low = mesh_utils.downsample(g,
                                     downsampling_factor,
                                     antialias=antialias)
        g = g - mtf.reshape(mesh_utils.upsample(_low, downsampling_factor),
                            g.shape)
        g = mtf.reshape(g, high_shape)

        d = mesh_utils.cic_readout(f + g, X, halo_size)
        displacement.append(d)

    # Readout to particle positions
    displacement = mtf.stack([d for d in displacement], "ndim", axis=4)

    pt = PerturbationGrowth(cosmology, a=[a], a_normalize=1.0)
    DX = pt.D1(a) * displacement
    P = (a**2 * pt.f1(a) * pt.E(a)) * DX
    F = (a**2 * pt.E(a) * pt.gf(a) / pt.D1(a)) * DX
    # TODO: Implement 2nd order LPT

    # Moves the particles according to displacement
    X = X + DX

    return X, P, F
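The block and halo sizes used in lpt_init follow directly from the grid size, the number of blocks per axis and the downsampling factor. Illustrative arithmetic with made-up numbers (a 128^3 low-res grid split into 4 blocks per axis, halo_size = 16, downsampling_factor = 2):

lnc = 128
n_blocks = 4
halo_size = 16
downsampling_factor = 2

block_size = lnc // n_blocks                 # size of the sx/sy/sz_block dims
halo = halo_size // 2**downsampling_factor   # halo width on the downsampled grid
padded = block_size + 2 * halo               # block size after mtf.pad on both sides
print(block_size, halo, padded)              # 32 4 40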
Exemple #32
0
def force(state,
          lr_shape,
          hr_shape,
          kvec_lr,
          kvec_hr,
          halo_size,
          cosmology=Planck15,
          downsampling_factor=2,
          pm_nc_factor=1,
          antialias=True,
          **kwargs):
    """
  Estimate force on the particles given a state.

  Parameters:
  -----------
  state: tensor
    Input state tensor of shape (3, batch_size, npart, 3)

  boxsize: float
    Size of the simulation volume (Mpc/h) TODO: check units

  cosmology: astropy.cosmology
    Cosmology object

  pm_nc_factor: int
    TODO: @modichirag please add doc
  """
    X, P, F = state
    #TODO: support different factor
    assert pm_nc_factor == 1
    lnc = lr_shape[-1].size
    part_shape = X.shape
    k_dims_lr = [d.shape[0] for d in kvec_lr]
    k_dims_hr = [d.shape[0] for d in kvec_hr]
    # Reorder the FFTs, which were transposed (to y, z, x order)
    k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
    k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]]

    # Paint the particles on the high resolution mesh
    field = mtf.zeros(X.mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)
    field = mesh_utils.cic_paint(field, X, halo_size)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim,
                                     halo_size)

    # Split the field into low and high resolution
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)
    low = mtf.reshape(low, low.shape[:-1])
    hr_field = mtf.reshape(high, high.shape[:-1])
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)

    # Hack using a custom reshape because mesh is pretty dumb
    lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                             output_dtype=tf.float32,
                             output_shape=lr_shape,
                             name='my_dumb_reshape',
                             splittable_dims=lr_shape[:-1] + hr_shape[:4])

    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
    hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr)

    kfield_lr = mesh_kernels.apply_longrange_kernel(lr_kfield,
                                                    kvec_lr,
                                                    r_split=0)
    kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)
    kfield_hr = mesh_kernels.apply_longrange_kernel(hr_kfield,
                                                    kvec_hr,
                                                    r_split=0)
    kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(kfield_hr, kvec_hr)

    # Reorder the low res FFTs, which were transposed (to y, z, x order)
    kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]]
    kfield_hr = [kfield_hr[2], kfield_hr[0], kfield_hr[1]]

    displacement = []
    for f, g in zip(kfield_lr, kfield_hr):
        f = mesh_utils.c2r3d(f, lr_shape[-3:])
        f = mtf.slicewise(
            lambda x: tf.expand_dims(
                tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
            [f],
            output_dtype=tf.float32,
            output_shape=mtf.Shape(hr_shape[0:4] + [
                mtf.Dimension('sx_block', lnc // hr_shape[1].size),
                mtf.Dimension('sy_block', lnc // hr_shape[2].size),
                mtf.Dimension('sz_block', lnc // hr_shape[3].size)
            ]),
            name='my_reshape',
            splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])
        for block_size_dim in hr_shape[-3:]:
            f = mtf.pad(f, [
                halo_size // 2**downsampling_factor,
                halo_size // 2**downsampling_factor
            ], block_size_dim.name)
        for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
            f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim,
                                     halo_size // 2**downsampling_factor)
        f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)])
        f = mesh_utils.upsample(f, downsampling_factor)
        f = mtf.reshape(f, f.shape[:-1])

        g = mesh_utils.c2r3d(g, f.shape[-3:])
        high_shape = g.shape
        # And now we remove the large scales
        g = mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)])
        _low = mesh_utils.downsample(g,
                                     downsampling_factor,
                                     antialias=antialias)
        g = g - mtf.reshape(mesh_utils.upsample(_low, downsampling_factor),
                            g.shape)
        g = mtf.reshape(g, high_shape)

        d = mesh_utils.cic_readout(f + g, X, halo_size)
        displacement.append(d)

    # Readout the force to particle positions
    F = mtf.stack([d for d in displacement], "ndim", axis=4)

    F = F * 1.5 * cosmology.Om0
    return X, P, F
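The "remove the large scales" step in both lpt_init and force subtracts an upsampled, downsampled copy of the field, leaving only the short-range component. A conceptual 1-D numpy sketch of that idea (crude block averaging and nearest-neighbour repetition stand in for mesh_utils.downsample/upsample; this is not the library implementation):

import numpy as np

field = np.random.randn(16)
low = field.reshape(8, 2).mean(axis=1)       # crude 2x downsample (block average)
low_up = np.repeat(low, 2)                   # crude 2x upsample (nearest neighbour)
high = field - low_up                        # short-range residual
np.testing.assert_allclose(low_up + high, field)   # the two parts add back up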
Exemple #33
0
  def mtf_model_fn(self, features, mesh):
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.activation_type

    # We assume fixed vocab size for targets
    targets = tf.to_int32(features["targets"])

    # Image preprocessing, reshape into a 1D sequence and shift right.
    length = hparams.img_len*hparams.img_len*hparams.num_channels
    targets = tf.reshape(targets, [hparams.batch_size, length])
    shifted_targets = common_layers.shift_right_2d(targets)

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)

    def import_to_batch_by_length(x, name):
      return mtf.import_tf_tensor(
          mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name)

    targets = import_to_batch_by_length(targets, "targets")
    shifted_targets = import_to_batch_by_length(
        shifted_targets, "shifted_targets")

    extra_losses = []

    # Create targets content and position embeddings.
    # Create embedding var for targets and positions and do a gather.
    targets_embedding_var = mtf.get_variable(
        mesh, "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)

    x = mtf.gather(targets_embedding_var,
                   shifted_targets, self.targets_vocab_dim)

    # Add positional embeddings
    x += mtf.reshape(self.create_positional_emb_2d(targets),
                     [self.length_dim, self.model_dim])

    # If conditional and input is given, add the input embedding to the target.
    # TODO(nikip): Verify conditional.
    if self.has_input and not hparams.unconditional:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = import_to_batch_by_length(inputs, "inputs")

      # Input embeddings
      inputs_embedding_var = mtf.layers.embedding(
          mesh, "input_embedding",
          mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
          activation_dtype=activation_dtype)
      inputs_emb = mtf.gather(
          inputs_embedding_var, inputs, self.inputs_vocab_dim)
      x += inputs_emb

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    if hparams.attention_type == "local1d_spatial":
      decoder_output = local_attention1d_spatial_decoder(
          x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local2d_spatial":
      decoder_output = local_attention2d_spatial_decoder(
          x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local1d":
      decoder_output = local_attention1d_masked_decoder(
          x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    else:
      raise ValueError("Invalid attention type.")

    # Calculate the logits and loss.
    logits = mtf.layers.dense(
        decoder_output, self.outputs_vocab_dim, name="logits")
    # Need a reshape for logits
    logits = mtf.reshape(
        logits, mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
    soft_targets = mtf.one_hot(
        targets, self.outputs_vocab_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.outputs_vocab_dim)
    loss = mtf.reduce_mean(loss)
    for l in extra_losses:
      loss += l

    # Reshape logits to original target shape.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, self.rows_dim, self.orig_cols_dim,
                   self.channels_dim, self.outputs_vocab_dim]))

    return logits, loss
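The image transformer above first flattens each image into a 1-D sequence of sub-pixel intensities and shifts it right for autoregressive decoding. Illustrative shape bookkeeping with made-up hparams (32x32 RGB images, batch of 8):

img_len, num_channels, batch_size = 32, 3, 8
length = img_len * img_len * num_channels    # 3072 positions per image
# targets become a [batch_size, length] tensor after tf.reshape, and
# shift_right_2d provides the previous sub-pixel as decoder input.
print(batch_size, length)                    # 8 3072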
Exemple #34
0
  def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
      inputs = features["inputs"]
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      encdec_tensors = []
      for layer_num, layer_type in enumerate(hparams.decoder_layers):
        if layer_type == "enc_att":
          with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
            q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                mesh, self.heads_dim, self.model_dim,
                self.kv_dim, self.master_dtype, self.slice_dtype,
                self.activation_dtype)
            k = mtf.einsum(
                [encoder_output, k_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
            v = mtf.einsum(
                [encoder_output, v_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
          encdec_tensors.append((q_var, o_var, k, v))
        else:
          encdec_tensors.append(None)
      partial_targets = None
    elif hparams.transformer_type == "decoder":
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)
    else:
      raise ValueError(
          "hparams.model_type = %s not yet supported"
          % hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    if hparams.beam_size == 1:
      ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [self.heads_dim,
                                  local_attention_window, self.kv_dim])
    else:
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [beam_dim, self.heads_dim,
                                  local_attention_window, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_states = []
    for layer in hparams.decoder_layers:
      if layer == "att":
        initial_states.extend(
            [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
      elif layer == "local_att":
        initial_states.extend(
            [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_states = self._layer_stack(
            x,
            hparams.decoder_layers,
            encdec_attention_mask=encoder_attention_mask,
            step_num=step_num,
            encdec_tensors=encdec_tensors,
            states=states)
      logits = mtf.matmul(x, softmax_var)
      return logits, new_states

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf.beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if hparams.transformer_type == "encdec":
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf.beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu,
          dtype=self.activation_dtype)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
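During sampling, _sample allocates one key cache and one value cache per self-attention layer ("att" gets full-length caches, "local_att" gets window-sized ones), while feed-forward layers need no state. A small sketch of that bookkeeping, using a hypothetical decoder layer list ("ffn" here stands for any non-attention layer type):

decoder_layers = ["att", "ffn", "local_att", "ffn"]   # made-up layer list
num_state_tensors = 2 * sum(
    1 for layer in decoder_layers if layer in ("att", "local_att"))
# one k and one v cache per attention layer
print(num_state_tensors)                     # 4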
Exemple #35
0
  def _layer_stack(self,
                   x,
                   layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None,
                   step_num=None,
                   encdec_tensors=None,
                   states=None):
    """Encoder or decoder stack.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      layers: a list of strings
      encoder_output: an optional mtf.Tensor with shape
        [<batch_dims>, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: a list to be appended-to
      step_num: an optional mtf integer Scalar (used in incremental mode)
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v), (used in incremental mode)
      states: an optional list of Tensors (used in incremental mode)
    Returns:
      a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    is_incremental = (step_num is not None)
    def layer_prepostprocess_dropout(x):
      if is_incremental:
        return x
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
    num_layers = len(layers)
    num_layer_norms = num_layers + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    if is_incremental:
      states = list(states)
      new_states = []
    tf.logging.info("states = %s" % (states,))

    for lnum, layer_type in enumerate(layers):
      with tf.variable_scope("%s_%d" % (layer_type, lnum)):
        if layer_type == "att":
          # Self attention layer
          if is_incremental:
            y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), None,
                    self_attention_mask, self.kv_dim, self.heads_dim,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="att"))
        elif layer_type == "enc_att":
          # Encoder-Decoder attention layer
          if is_incremental:
            q_var, o_var, k, v = encdec_tensors[lnum]
            x += mtf.layers.multihead_encdec_attention_incremental(
                normalize(x),
                q_var, o_var, k, v,
                encdec_attention_mask,
                name="enc_att")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), encoder_output,
                    encdec_attention_mask, self.kv_dim, self.heads_dim,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="enc_att"))
        elif layer_type == "local_att":
          if is_incremental:
            y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="local_att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.masked_local_attention_1d(
                    normalize(x),
                    self.kv_dim, self.heads_dim,
                    window_size=hparams.local_attention_window_size,
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    length_per_split=mtf.tensor_dim_to_size_per_split(
                        hparams.layout, hparams.mesh_shape,
                        self.max_length_dim),
                    name="local_att"))
        elif layer_type == "compressed_att":
          if is_incremental:
            raise ValueError("compressed_att incremental not implemented")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_self_attention_memory_compressed(
                    normalize(x),
                    mask_right=True,
                    compression_factor=hparams.compression_factor,
                    kv_channels=self.kv_dim,
                    heads=self.heads_dim,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="compressed_att"))
        else:
          if is_incremental:
            # insert length dimension.
            x_shape = x.shape
            shape_with_length = mtf.Shape(
                x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
                + x_shape.dims[-1:])
            x = mtf.reshape(x, shape_with_length)
          # ffn layer
          x += layer_prepostprocess_dropout(
              self._feedforward_layer(normalize(x), layer_type, losses=losses))
          if is_incremental:
            # remove length dimension
            x = mtf.reshape(x, x_shape)

    x = layer_prepostprocess_dropout(normalize(x))
    assert not layer_norm_vars
    if is_incremental:
      return x, new_states
    else:
      return x
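The normalize() helper above is an RMS-style layer norm: it rescales by the root mean square of the activations (no mean-centering) and multiplies by a learned per-channel scale. A stand-alone numpy restatement (eps and the all-ones scale are placeholders):

import numpy as np

def rms_normalize(x, scale, eps=1e-6):
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return x * (1.0 / np.sqrt(variance + eps)) * scale

x = np.random.randn(2, 8).astype(np.float32)
scale = np.ones(8, dtype=np.float32)
print(rms_normalize(x, scale).shape)         # (2, 8)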
Exemple #36
0
def transformer_moe_layer_v2(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None):
    """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  One set of params for the experts in the first level, and a different set of
  params per expert in the second level.
  The number of parameters in the gating network is:
    (input_dim.size * hparams.num_experts) +
    (moe_hidden_size * hparams.num_experts) * hparams.num_experts

  The number of parameters in the experts themselves is:
    (hparams.num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence sends the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  a, b: batch size
  l: original sequence length
  m: input depth
  n: output depth
  g, h: number of groups
  s, t: group size
  x, y: number of experts
  c, d: expert capacity

  input: [a0, b1, l, m]
  input: [a0, g1, s, m]
  dispatch_tensor_x: [a0, g1, s, x, c]
  expert_input: [a0, g1, x, c, m]
  alltoall: [a0, g, x1, c, m]
  alltoall: [a0, g, x1, c, m]
  transpose: [x1, a0, g, c, m]
  reshape: [x1, h0, s, m]
  assignment2: [x1, h0, t, y, d]
  expert_input2: [x1, h0, y, d, m]
  alltoall: [x1, h, y0, d, m]
  ...
  reverse of that

  gating params 0: [m, x]
  gating params 1: [x1, m, y]

  expert params:
     [x1, y0, m, hidden]
     [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional mtf.Tensor with shape [a, b, l]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).

  Returns:
    outputs: a Tensor with shape [a, b, l, n]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
    if nonpadding is not None:
        nonpadding = mtf.zeros(inputs.mesh,
                               inputs.shape.dims[:-1],
                               dtype=inputs.dtype) + nonpadding
    insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
    if insert_outer_batch_dim:
        inputs = mtf.reshape(inputs, [mtf.Dimension("outer_batch", 1)] +
                             inputs.shape.dims)

    assert len(hparams.moe_num_experts) == 2
    a0, b1, l, m = inputs.shape.dims
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
    y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
    x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
    y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
    n = output_dim

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (g.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        b1.size * l.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
    g1 = mtf.Dimension(b1.name, num_groups)
    g = mtf.Dimension(b1.name + "_unsplit", g1.size)
    s = mtf.Dimension("group_size_x", group_size)

    # Each sequence sends (at most?) expert_capacity positions to each expert.
    # Static expert_capacity dimension is needed for expert batch sizes
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
    expert_capacity = max(expert_capacity, 4)
    c = mtf.Dimension("expert_capacity_x", expert_capacity)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (h.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        a0.size * g.size * c.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
    t = mtf.Dimension("group_size_y", group_size)
    h0 = mtf.Dimension(a0.name, num_groups)
    h = mtf.Dimension(a0.name + "_unsplit", h0.size)

    expert_capacity = min(
        t.size,
        int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
    expert_capacity = max(expert_capacity, 4)
    d = mtf.Dimension("expert_capacity_y", expert_capacity)

    # First level of expert routing
    # Reshape the inner batch size to a multiple of group_dim g1 and
    # group_size_dim s.
    inputs = mtf.reshape(inputs, [a0, g1, s, m])
    if nonpadding is not None:
        nonpadding = mtf.reshape(nonpadding, [a0, g1, s])

    # Get the assignments for the first level.
    # dispatch_tensor_x has shape [a0, g1, s, x, c]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=x,
            expert_capacity_dim=c,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            name="outer_gating",
            importance=nonpadding)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x],
                                 [x, a0, g1, c, m])

    # we construct an "importance" Tensor for the inputs to the second-level
    # gating.  The importance of an input is 1.0 if it represents the
    # first-choice expert-group and 0.5 if it represents the second-choice expert
    # group.  This is used by the second-level gating.
    importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
    importance = 0.5 * (mtf.to_float(mtf.greater(importance, 0.5)) +
                        mtf.to_float(mtf.greater(importance, 0.0)))

    # First level, all to all. Here we change the split dimension from g1 to x1.
    expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape([x1, a0, g, c,
                                                              m]))
    importance = mtf.reshape(importance, [x1, a0, g, c])

    # Second level of expert routing
    # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
    # and group_size_dim t.
    inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
    importance = mtf.reshape(importance, [x1, h0, t])

    # Get the assignments for the second level.
    # dispatch_tensor_y has shape [x1, h0, t, y, d]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
            inputs=inputs_y,
            outer_expert_dims=[x1],
            experts_dim=y,
            expert_capacity_dim=d,
            hparams=hparams,
            train=train,
            variable_dtype=variable_dtype,
            importance=importance,
            name="inner_gating")
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                                 [y, x1, h0, d, m])

    # Second level, all to all. Here we change the split dimension from h0 to y0.
    expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape([y0, x1, h, d,
                                                              m]))

    hidden_output = mtf.layers.dense(expert_inputs_y,
                                     hidden_dim,
                                     expert_dims=[y0, x1],
                                     activation=mtf.relu,
                                     use_bias=False,
                                     variable_dtype=variable_dtype,
                                     name="wi")
    expert_output = mtf.layers.dense(hidden_output,
                                     output_dim,
                                     expert_dims=[y0, x1],
                                     use_bias=False,
                                     variable_dtype=variable_dtype,
                                     name="wo")

    # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
    # expert_output has shape [y0, x1, h, d, n]

    # alltoall
    expert_output = mtf.reshape(expert_output, mtf.Shape([y, x1, h0, d, n]))

    # combine results from inner level
    output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

    # Reshape the combined tensor from inner level to now contain outer_batch_dim
    # a0 and group_dim g
    output = mtf.reshape(output_y, [x1, a0, g, c, n])

    # alltoall from expert_dim x to group_dim g1
    expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

    # combine results from outer level
    output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

    # Reshape the combined tensor to now contain inner_batch_dim
    # b1 and the original sequence length
    output = mtf.reshape(output_x, [a0, b1, l, n])
    if insert_outer_batch_dim:
        output = mtf.reshape(output, [b1, l, n])
    return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
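The expert capacity used at each gating level follows the min/max formula above. Stand-alone arithmetic with illustrative numbers (1024 positions per group, 16 experts, capacity factor 1.25):

group_size = 1024
num_experts = 16
capacity_factor = 1.25

expert_capacity = min(group_size,
                      int((group_size * capacity_factor) / num_experts))
expert_capacity = max(expert_capacity, 4)    # floor of 4, as in the snippet
print(expert_capacity)                       # 80 positions per expert and group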
Exemple #37
0
  def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
      shifted_targets = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
      shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
      raise ValueError(
          "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          targets_segmentation, dtype=self.activation_dtype)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
      else:
        decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
      encoder_output = None
      encoder_decoder_attention_mask = None
    else:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
      # DECODER
      x = (mtf.gather(
          targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
           mtf.gather(
               positional_embedding_var, targets_position, self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # For some reason, the logits computation is extremely slow on TPU
      # in some cases where the batch size per core is 1.  Reshape the logits
      # and the targets to double the batch size and halve the length.
      # TODO(noam): file a bug.
      old_dims = self.batch_dims + [self.length_dim]
      new_dims = self.batch_dims[:-1] + [
          mtf.Dimension(self.batch_dims[-1].name,
                        self.batch_dims[-1].size * 2),
          mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
      x = mtf.reshape(x, new_dims + [self.model_dim])
      targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
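The reshape_logits_hack above doubles the last batch dimension and halves the length dimension before the softmax, so the total number of positions (and hence the loss) is unchanged. Illustrative numbers (per-core batch of 1, length 4096):

batch_size, length = 1, 4096
new_batch, new_length = batch_size * 2, length // 2
assert batch_size * length == new_batch * new_length
print(new_batch, new_length)                 # 2 2048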
Exemple #38
0
  def mtf_model_fn(self, features, mesh):
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
    cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", 3)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    x = mtf.import_tf_tensor(
        mesh, tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks,
            hparams.rows_size // hparams.row_blocks,
            hparams.col_blocks,
            hparams.num_channels*hparams.cols_size // hparams.col_blocks,
            hparams.num_channels]),
        mtf.Shape(
            [batch_dim, row_blocks_dim, rows_dim,
             col_blocks_dim, cols_dim, channels_dim]))
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])

    x = mtf.to_float(x)
    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x,
        initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block - strided block layer - strided block layer] x n
    for layer in range(hparams.num_layers):
      layer_name = "block_layer_%d" % layer
      with tf.variable_scope(layer_name):
        # Residual block layer
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[0],
            blocks=hparams.layer_sizes[0],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer1",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[1],
            blocks=hparams.layer_sizes[1],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer2",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[2],
            blocks=hparams.layer_sizes[2],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer3",
            row_blocks_dim=None,
            col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
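The classification head at the end computes a softmax cross-entropy between the 10-way logits and one-hot targets, then averages over the batch. A tiny numpy restatement of that loss (random logits and labels, purely illustrative):

import numpy as np

logits = np.random.randn(8, 10)              # [batch, classes]
labels = np.random.randint(0, 10, size=8)
one_hot = np.eye(10)[labels]
z = logits - logits.max(axis=-1, keepdims=True)
log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
loss = -(one_hot * log_probs).sum(axis=-1)   # per-example cross entropy
print(loss.mean())                           # scalar, as after mtf.reduce_mean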