Example #1
0
def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
  """Multi layer transformer, sharded by the data parallelism dp."""
  x = inputs
  extra_loss = tf.constant(0.0)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in xrange(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention
      if attention_type == AttentionType.LOCAL_2D:
        y = dp(local_attention_2d, common_layers.layer_preprocess(x, hparams),
               hparams,
               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = dp(local_attention_1d, common_layers.layer_preprocess(x, hparams),
               hparams,
               attention_type="local_mask_right",
               q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = dp(local_global_attention,
               common_layers.layer_preprocess(x, hparams), self_attention_bias,
               hparams, q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOBAL:
        self_attention_bias = dp(get_self_attention_bias, x)
        y = dp(full_self_attention, common_layers.layer_preprocess(x, hparams),
               self_attention_bias, hparams,
               q_padding="LEFT", kv_padding="LEFT")
      x = dp(common_layers.layer_postprocess, x, y, hparams)
      if enc_output is not None:
        y = dp(encdec_attention_1d, common_layers.layer_preprocess(x, hparams),
               enc_output, None, hparams)
        x = dp(common_layers.layer_postprocess, x, y, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_decoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              ps_devices,
              common_layers.layer_preprocess(x, hparams),
              hparams.mode == tf.estimator.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
          x = dp(common_layers.layer_postprocess, x, y, hparams)
        else:
          y = dp(ffn_layer, common_layers.layer_preprocess(x, hparams), hparams)
          x = dp(common_layers.layer_postprocess, x, y, hparams)
  return dp(common_layers.layer_preprocess, x, hparams), extra_loss
def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
  """Multi layer transformer, sharded by the data parallelism dp."""
  x = inputs
  extra_loss = tf.constant(0.0)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention
      if attention_type == AttentionType.LOCAL_2D:
        y = dp(local_attention_2d, common_layers.layer_preprocess(x, hparams),
               hparams,
               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = dp(local_attention_1d, common_layers.layer_preprocess(x, hparams),
               hparams,
               attention_type="local_mask_right",
               q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = dp(local_global_attention,
               common_layers.layer_preprocess(x, hparams), self_attention_bias,
               hparams, q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOBAL:
        self_attention_bias = dp(get_self_attention_bias, x)
        y = dp(full_self_attention, common_layers.layer_preprocess(x, hparams),
               self_attention_bias, hparams,
               q_padding="LEFT", kv_padding="LEFT")
      x = dp(common_layers.layer_postprocess, x, y, hparams)
      if enc_output is not None:
        y = dp(encdec_attention_1d, common_layers.layer_preprocess(x, hparams),
               enc_output, None, hparams)
        x = dp(common_layers.layer_postprocess, x, y, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_decoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              ps_devices,
              common_layers.layer_preprocess(x, hparams),
              hparams.mode == tf.estimator.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
          x = dp(common_layers.layer_postprocess, x, y, hparams)
        else:
          y = dp(ffn_layer, common_layers.layer_preprocess(x, hparams), hparams)
          x = dp(common_layers.layer_postprocess, x, y, hparams)
  return dp(common_layers.layer_preprocess, x, hparams), extra_loss
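Both versions above route every sublayer through dp, which in tensor2tensor is an expert_utils.Parallelism object: it takes a function plus that function's arguments, applies the function once per data shard, splits list-valued arguments across shards and broadcasts everything else. A minimal, hypothetical stand-in that mimics this calling convention (illustrative only; the real implementation also handles device placement and variable reuse):

# Hypothetical, minimal stand-in for the sharded-call pattern used above.
class SimpleParallelism(object):

  def __init__(self, n):
    self.n = n  # number of data shards

  def __call__(self, fn, *args, **kwargs):
    def shard(arg, i):
      # Per-shard lists are split; everything else is broadcast unchanged.
      return arg[i] if isinstance(arg, list) else arg
    outputs = [fn(*[shard(a, i) for a in args], **kwargs)
               for i in range(self.n)]
    if isinstance(outputs[0], tuple):
      # If fn returns a tuple, transpose to a tuple of per-shard lists.
      return tuple(list(t) for t in zip(*outputs))
    return outputs

# Usage: dp(some_layer, x, hparams) calls some_layer(x[i], hparams) once per
# shard, which is why call sites pass the function and its arguments separately.
dp = SimpleParallelism(2)
assert dp(lambda t: t * 2, [1, 2]) == [2, 4]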
Example #3
0
def build_a_layer(layer_input, layer_type, is_training, config):
    """Build a single layer."""
    batch_size = tf.shape(layer_input)[0]
    net = common_layers.layer_norm(layer_input)
    if layer_type == 'att':
        net = common_attention.multihead_attention(
            query_antecedent=net,
            memory_antecedent=None,
            bias=None,
            total_key_depth=config['hidden_size'],
            total_value_depth=config['hidden_size'],
            output_depth=config['hidden_size'],
            num_heads=config['attention_heads'],
            dropout_rate=config['attention_dropout'],
            attention_type=config['attention_type'])
    elif layer_type == 'conv':
        if config['preactivate']:
            if config['activation'] == 'relu':
                net = tf.nn.relu(net)
            elif config['activation'] == 'swish':
                net = swish(net)
            elif config['activation'] == 'prelu':
                net = parametric_relu(net)
            net = separable_conv(net, config['kernel_size'], activation=None)
        else:
            if config['activation'] == 'relu':
                net = separable_conv(net,
                                     config['kernel_size'],
                                     activation=tf.nn.relu)
            elif config['activation'] == 'swish':
                net = separable_conv(net,
                                     config['kernel_size'],
                                     activation=swish)
            elif config['activation'] == 'prelu':
                net = separable_conv(net,
                                     config['kernel_size'],
                                     activation=parametric_relu)
    elif layer_type == 'ffn':
        net = tf.reshape(net, [-1, config['hidden_size']])
        net = expert_utils.ffn_expert_fn(
            input_size=config['hidden_size'],
            hidden_sizes=[config['ffn_hs_factor'] * config['hidden_size']],
            output_size=config['hidden_size'],
        )(net)
        net = tf.reshape(net, [batch_size, -1, config['hidden_size']])
    else:
        raise ValueError('Unknown layer type %s' % layer_type)

    if config['layer_output_dropout'] > 0.0 and is_training:
        net = tf.nn.dropout(net, 1.0 - config['layer_output_dropout'])
    return net
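build_a_layer relies on helpers named swish, parametric_relu and separable_conv that are defined elsewhere in the original file. A hedged sketch of the two activation helpers, assuming TF1-style variable scoping; the separable convolution is omitted because its exact configuration is not shown here:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hedged sketch of the activation helpers assumed by build_a_layer; the real
# definitions may differ.
def swish(x):
  # swish(x) = x * sigmoid(x)
  return x * tf.nn.sigmoid(x)

def parametric_relu(x, name="prelu"):
  # PReLU: identity for positive inputs, learned slope alpha for negative ones.
  with tf.variable_scope(name):
    alpha = tf.get_variable(
        "alpha", shape=x.get_shape()[-1:],
        initializer=tf.zeros_initializer(), dtype=x.dtype)
    return tf.nn.relu(x) - alpha * tf.nn.relu(-x)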
Example #4
0
def conv_experts(xs, hparams, dp, ps, padding, mask, layer_id):
  """Convolutions + Mixture-of-Experts layer."""
  del layer_id  # Unused.
  train = hparams.mode == tf.estimator.ModeKeys.TRAIN
  conv_out = dp(conv_res_step, xs, hparams, padding, mask)
  loss = 0.0
  moe_hidden_sizes = [hparams.filter_size]
  expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size, moe_hidden_sizes,
                                         hparams.hidden_size)
  moe_out, loss = expert_utils.distributed_moe(
      dp,
      ps,
      xs,
      train,
      input_size=hparams.hidden_size,
      expert_fn=expert_fn,
      num_experts=hparams.moe_num_experts,
      k=hparams.moe_k,
      loss_coef=1.0)
  return dp(residual_fn3, xs, moe_out, conv_out, hparams), loss
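conv_experts pairs a convolutional residual step with expert_utils.distributed_moe. The heart of such a mixture-of-experts layer is top-k gating: each token is routed to its k highest-scoring experts and the expert outputs are combined with renormalized gate weights. A toy NumPy sketch of that routing for a single shard (illustrative only; the real distributed_moe adds noisy gating, a load-balancing loss scaled by loss_coef, and dispatch across ps devices):

import numpy as np

def topk_moe(x, expert_weights, gate_weights, k=2):
  """Toy routing. x: [tokens, d]; expert_weights: [num_experts, d, d]."""
  logits = x @ gate_weights                      # [tokens, num_experts]
  top_idx = np.argsort(-logits, axis=-1)[:, :k]  # k best experts per token
  out = np.zeros_like(x)
  for t in range(x.shape[0]):
    gates = np.exp(logits[t, top_idx[t]])
    gates /= gates.sum()                         # renormalize selected gates
    for g, e in zip(gates, top_idx[t]):
      out[t] += g * (x[t] @ expert_weights[e])
  return out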
Example #5
0
def conv_experts(xs, hparams, dp, ps, padding, mask, layer_id):
  """Convolutions + Mixture-of-Experts layer."""
  del layer_id  # Unused.
  train = hparams.mode == tf.estimator.ModeKeys.TRAIN
  conv_out = dp(conv_res_step, xs, hparams, padding, mask)
  loss = 0.0
  moe_hidden_sizes = [hparams.filter_size]
  expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size, moe_hidden_sizes,
                                         hparams.hidden_size)
  moe_out, loss = expert_utils.distributed_moe(
      dp,
      ps,
      xs,
      train,
      input_size=hparams.hidden_size,
      expert_fn=expert_fn,
      num_experts=hparams.moe_num_experts,
      k=hparams.moe_k,
      loss_coef=1.0)
  return dp(residual_fn3, xs, moe_out, conv_out, hparams), loss
  def body_sharded(self, sharded_features):
    # Remove dropout if not training
    hparams = self._hparams
    dp = self._data_parallelism
    if hparams.use_inputs:
      decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
      decoder_self_attention_bias = None
    else:
      targets = sharded_features["targets"]
      targets = dp(tf.squeeze, targets, 2)
      (decoder_input, decoder_self_attention_bias, pad_remover) = dp(
          attention_lm_moe_prepare_decoder, targets, hparams)

    def preprocess(x):
      return dp(common_layers.layer_preprocess, x, hparams)

    def postprocess(x, y):
      return dp(common_layers.layer_postprocess, x, y, hparams)

    x = dp(tf.nn.dropout, decoder_input,
           1.0 - hparams.layer_prepostprocess_dropout)
    extra_loss = 0.0
    moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
    if hparams.diet_experts:
      hsize, = moe_hidden_sizes

      def _diet_expert(x):
        return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())

      expert_fn = _diet_expert
    else:
      expert_fn = expert_utils.ffn_expert_fn(
          hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)

    if not hparams.use_inputs:
      # As preprocess and postprocess are called with a batch of size one (all
      # batches concatenated), we just make sure that batch_norm is not used
      # (it should not be either way).
      assert hparams.norm_type != "batch"

      tf.logging.info("Applying Padding Remover for the attention experts")

      dp_remove_pad = functools.partial(
          dp, remove_pad, pad_remover=pad_remover, mode=hparams.mode)
      dp_restore_pad = functools.partial(
          dp, restore_pad, ref_x=x, pad_remover=pad_remover, mode=hparams.mode)
    else:
      # Using identity function: No effect
      dp_remove_pad = lambda x: x
      dp_restore_pad = lambda x: x

    if hparams.attention_exp_factor != 0:
      tf.logging.info("Expand/compress tokens before sending them to experts")
      dp_expand_bc = lambda x: dp(  # pylint: disable=g-long-lambda
          expand_batch_coordinates,
          x,
          hparams.attention_exp_factor)
      dp_expand_x = lambda x: dp(  # pylint: disable=g-long-lambda
          common_attention.deconv_elems_1d,
          x,
          hparams.attention_exp_factor,
          hparams.attention_exp_inputdim)
      dp_compress_x = lambda x, l: dp(  # pylint: disable=g-long-lambda
          common_attention.conv_elems_1d,
          x,
          hparams.attention_exp_factor,
          l)
    else:
      dp_expand_bc = lambda x: x
      dp_expand_x = lambda x: x
      dp_compress_x = lambda x, l: x

    def print_shape(x, suffix, debug=False):
      # To help debugging, print the input/output shapes at inference and eval.
      # Inference for long sequences can take a long time, so this helps to
      # see the progression of the generation.
      if not debug and hparams.mode == ModeKeys.TRAIN:
        return x
      return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

    with tf.name_scope("batch_coordinate_preprocess"):
      batch_coordinate = dp(get_batch_coordinate, x)
      batch_coordinate = dp_remove_pad(batch_coordinate)
      batch_coordinate = dp_expand_bc(batch_coordinate)
      batch_order = dp(get_batch_coordinate, x, axis=-1)
      batch_order = dp_remove_pad(batch_order)
      batch_order = dp_expand_bc(batch_order)

    x = dp(print_shape, x, "in")

    assert hparams.batch_size >= hparams.max_length

    num_hidden_layers = (
        len(hparams.attention_layers) or hparams.num_hidden_layers)
    for layer in xrange(num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):

        # Use the layer type defined in attention_layers
        if hparams.attention_layers:
          attention_type = LAYER_SYMBOLS[hparams.attention_layers[layer]]
        else:
          attention_type = hparams.attention_type

        with tf.variable_scope(
            "attention_{}".format(attention_type)):
          if attention_type in [
              AttentionType.MULTIHEAD, AttentionType.MULTIHEAD_FULL]:
            attention_dot_type = (
                "local_mask_right" if hparams.attention_local else
                "dot_product")
            if attention_type == AttentionType.MULTIHEAD_FULL:
              attention_dot_type = "dot_product"
            y = dp(
                common_attention.multihead_attention,
                preprocess(x),
                None,
                decoder_self_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                attention_type=attention_dot_type,
                block_length=hparams.attention_block_length,
                name="decoder_self_attention")
          elif attention_type == AttentionType.SPARSE_MULTIHEAD:
            x_in = preprocess(x)
            x_in = dp_remove_pad(x_in)
            y, loss_experts = dp(
                common_attention.multihead_attention_sparse_dot_prod,
                x_in,
                None,
                None,  # Bias is computed inside
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,

                # Additional parameters
                bi=[common_attention.BatchInfo(
                    coordinates=batch_coordinate[i],
                    order=batch_order[i],  # No future mask
                ) for i in range(dp.n)],
                use_map_fn=hparams.lsh_use_map_fn,
                experts_params=dict(
                    nb_hyperplanes=hparams.lsh_num_hyperplanes,
                ),
            )
            y = dp_restore_pad(y)

            # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
            extra_loss += tf.add_n(loss_experts) / dp.n
          elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED:
            x_in = preprocess(x)
            y, loss_experts = dp(
                common_attention.multihead_attention_sparse_truncated,
                x_in,
                None,
                None,  # Bias is computed inside
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,

                # Additional parameters
                bi=[common_attention.BatchInfo(
                    coordinates=batch_coordinate[i],
                    order=batch_order[i],  # No future mask
                ) for i in range(dp.n)],
                mask_right=True,
                experts_params=dict(
                    nb_hyperplanes=hparams.lsh_num_hyperplanes,
                ),
            )

            # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
            extra_loss += tf.add_n(loss_experts) / dp.n
          elif attention_type == AttentionType.MEMORY_EFFICIENT:
            assert hparams.layer_preprocess_sequence == "n"
            y = dp(
                common_attention.multihead_self_attention_memory_efficient,
                x,
                decoder_self_attention_bias,
                hparams.num_heads,
                name="decoder_self_attention")
          elif attention_type == AttentionType.MULTIHEAD_REDUCED:
            y = dp(
                common_attention.multihead_self_attention_reduced,
                preprocess(x),
                factor=hparams.attention_red_factor,
                reduction_type=hparams.attention_reduction_type,
                nonlinearity=hparams.attention_nonlinearity,
                multihead_params=dict(
                    total_key_depth=
                    hparams.attention_key_channels or hparams.hidden_size,
                    total_value_depth=
                    hparams.attention_value_channels or hparams.hidden_size,
                    num_heads=hparams.num_heads,
                    dropout_rate=hparams.attention_dropout,
                ))
          elif attention_type == AttentionType.LOCAL_EXPERTS:
            x_in = preprocess(x)
            x_in = dp_remove_pad(x_in)
            x_in = dp_expand_x(x_in)
            y, loss = dp(
                common_attention.local_expert_attention,
                x_in,
                k=hparams.attention_moe_k,
                loss_coef=hparams.attention_load_balance,
                attention_num_experts=hparams.attention_num_experts,
                train=hparams.mode == ModeKeys.TRAIN,
                batch_coordinate=batch_coordinate,
                mask_right=not hparams.use_inputs,
                split_batch=bool(hparams.attention_split_batch),
                attention_num_head=hparams.attention_num_head,
                attention_kq_size=hparams.attention_kq_size,
                attention_v_size=hparams.attention_v_size)
            y = dp_compress_x(y, x[0].get_shape().as_list()[-1])
            y = dp_restore_pad(y)
            # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
            extra_loss += tf.add_n(loss) / dp.n
          else:
            raise ValueError("Only {} supported for now.".format(
                AttentionType.get_choices()))
          x = postprocess(x, y)
        with tf.variable_scope("ffn"):
          if str(layer) in hparams.moe_layers.split(","):
            y, loss = expert_utils.distributed_moe(
                dp,
                self._ps_devices,
                preprocess(x),
                hparams.mode == ModeKeys.TRAIN,
                input_size=hparams.hidden_size,
                expert_fn=expert_fn,
                num_experts=hparams.moe_num_experts,
                k=hparams.moe_k,
                loss_coef=hparams.moe_loss_coef)
            extra_loss += loss
          elif hparams.memory_efficient_ffn:
            assert hparams.layer_preprocess_sequence == "n"
            y = dp(
                common_layers.conv_hidden_relu_memory_efficient,
                x,
                hparams.filter_size)
          else:
            additional_conv_params = dict()
            if hparams.use_sepconv:
              additional_conv_params = dict(
                  padding="LEFT",
                  # Parameters copied from the transformer model
                  kernel_size=(3, 1),
                  second_kernel_size=(31, 1),
              )
            y = dp(
                common_layers.conv_hidden_relu,
                preprocess(x),
                hparams.filter_size,
                hparams.hidden_size,
                dropout=hparams.relu_dropout,
                **additional_conv_params
            )
          x = postprocess(x, y)
    x = preprocess(x)

    decoder_output = dp(tf.expand_dims, x, 2)
    return decoder_output, extra_loss
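The remove_pad / restore_pad pair above strips padded positions before the expert layers and scatters the results back afterwards, so the experts only see real tokens. A small NumPy sketch of the idea, assuming a 0/1 padding mask (the real PadRemover works on flattened [batch*length, depth] tensors with gather/scatter ops):

import numpy as np

def remove_pad_np(x, pad_mask):
  """x: [length, depth]; pad_mask: [length], 1.0 at padded positions."""
  keep = np.nonzero(pad_mask == 0.0)[0]
  return x[keep], keep

def restore_pad_np(y, keep, length):
  out = np.zeros((length, y.shape[-1]), dtype=y.dtype)
  out[keep] = y            # padded positions stay at zero
  return out

x = np.arange(12, dtype=np.float32).reshape(4, 3)
mask = np.array([0.0, 1.0, 0.0, 1.0])
compact, keep = remove_pad_np(x, mask)         # rows 0 and 2 reach the experts
restored = restore_pad_np(compact, keep, 4)    # rows 1 and 3 are zeros again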
Example #7
0
  def body_sharded(self, sharded_features):
    train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN
    dp = self._data_parallelism
    hparams = self._hparams

    def project_to_hidden(inputs):
      return common_layers.conv_block(
          inputs,
          hparams.hidden_size, [((1, 1), (3, 3))],
          first_relu=False,
          padding="SAME",
          force2d=True)

    def flatten(inputs):
      return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)

    # Project to hidden size if necessary
    if (sharded_features["inputs"][0].get_shape().as_list()[-1] !=
        hparams.hidden_size):
      inputs = dp(project_to_hidden, sharded_features["inputs"])

    inputs = dp(flatten, inputs)
    inputs_pad = dp(slicenet.embedding_to_padding, inputs)
    inputs_mask = dp(lambda x: 1.0 - x, inputs_pad)
    inputs_encoded = dp(common_layers.add_timing_signal, inputs)
    expert_loss = 0.0
    for i in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("enc_layer_%d" % i):
        inputs_encoded, moe_loss = conv_experts(inputs_encoded, hparams, dp,
                                                self._ps_devices, "SAME",
                                                inputs_mask, i)
        expert_loss += tf.reduce_mean(moe_loss) * hparams.moe_loss_coef

    # If we're just predicting a class, there is no need for a decoder; return.
    if isinstance(hparams.problems[self._problem_idx].target_modality,
                  modalities.ClassLabelModality):
      return inputs_encoded, tf.reduce_mean(expert_loss)

    # Decoder.
    inputs3d = dp(tf.squeeze, inputs, 2)
    inputs_encoded3d = dp(tf.squeeze, inputs_encoded, 2)
    encoder_padding = dp(common_attention.embedding_to_padding, inputs3d)
    encoder_attention_bias = dp(common_attention.attention_bias_ignore_padding,
                                encoder_padding)
    targets = dp(common_layers.flatten4d3d, sharded_features["targets"])
    target_space_emb = dp(slicenet.embed_target_space,
                          sharded_features["target_space_id"],
                          hparams.hidden_size)

    (decoder_input, decoder_self_attention_bias) = dp(prepare_decoder, targets,
                                                      target_space_emb)

    moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
    expert_fn = expert_utils.ffn_expert_fn(
        hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
    x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.dropout)
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("dec_layer_%d" % layer):
        with tf.variable_scope("attention"):
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              decoder_self_attention_bias,
              hparams.hidden_size,
              hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              name="decoder_self_attention")
          z = dp(
              common_attention.multihead_attention,
              y,
              inputs_encoded3d,
              encoder_attention_bias,
              hparams.hidden_size,
              hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              name="encdec_attention")
          x = dp(residual_fn3, x, y, z, hparams)
        with tf.variable_scope("ffn"):
          if str(layer) in hparams.moe_layers.split(","):
            y, moe_loss = expert_utils.distributed_moe(
                dp,
                self._ps_devices,
                x,
                train,
                input_size=hparams.hidden_size,
                expert_fn=expert_fn,
                num_experts=hparams.moe_num_experts,
                k=hparams.moe_k,
                loss_coef=hparams.moe_loss_coef)
            expert_loss += tf.reduce_mean(moe_loss)
          else:
            y = dp(
                common_layers.conv_hidden_relu,
                x,
                hparams.filter_size,
                hparams.hidden_size,
                dropout=hparams.dropout)
          x = dp(residual_fn2, x, y, hparams)

    x = dp(tf.expand_dims, x, 2)
    return x, tf.reduce_mean(expert_loss)
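The decoder above merges sublayer outputs with residual_fn2 and residual_fn3, which are defined elsewhere in the original model file. A plausible sketch under the assumption that they follow the usual dropout-then-residual-then-layer-norm pattern; the actual definitions may differ:

import tensorflow.compat.v1 as tf
from tensor2tensor.layers import common_layers

# Hypothetical definitions of the residual helpers referenced above.
def residual_fn2(x, y, hparams):
  y = tf.nn.dropout(y, 1.0 - hparams.dropout)
  return common_layers.layer_norm(x + y)

def residual_fn3(x, y, z, hparams):
  y = tf.nn.dropout(y, 1.0 - hparams.dropout)
  z = tf.nn.dropout(z, 1.0 - hparams.dropout)
  return common_layers.layer_norm(x + y + z)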
Example #8
0
    def model_fn_body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
        x = dp(tf.squeeze, sharded_features["inputs"], 2)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0
        ffn_hidden_sizes = [
            int(s) for s in hparams.ffn_hidden_sizes.split(",")
        ]
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        if hparams.mask_right:

            def _bias(x):
                return common_attention.attention_bias_lower_triangle(
                    tf.shape(x)[1])

            bias = dp(_bias, x)
        else:
            bias = tf.zeros([1, 1, 1, 1])
        if hparams.diet_experts:
            hsize, = moe_hidden_sizes

            def _diet_expert(x):
                return diet.diet_expert(x, hsize,
                                        diet.diet_adam_optimizer_params())

            expert_fn = _diet_expert
        else:
            expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   moe_hidden_sizes,
                                                   hparams.hidden_size)

        batch_coordinate = dp(get_batch_coordinate, x)

        layers = hparams.layers.strip(",").split(",")
        for layer_num, layer_type in enumerate(layers):
            with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
                if _should_preprocess(layer_type):
                    x = preprocess(x)
                if layer_type == "timing":
                    y = dp(common_attention.add_timing_signal_nd, x)
                elif layer_type == "pos_emb":
                    y = dp(common_attention.add_positional_embedding_nd,
                           x,
                           hparams.max_length,
                           name="pos_emb")
                elif layer_type == "att":
                    y = dp(
                        common_attention.multihead_attention,
                        x,
                        None,
                        bias,  # bias
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                elif layer_type == "att_grouped":
                    multiplicative_overhead = (
                        hparams.multiplicative_overhead
                        if hparams.mode == ModeKeys.TRAIN else
                        hparams.multiplicative_overhead_eval)
                    y, loss = dp(
                        common_attention.grouped_attention_multihead,
                        x,
                        x,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        num_groups=hparams.attention_num_groups,
                        memory_target_density=hparams.memory_target_density,
                        multiplicative_overhead=multiplicative_overhead,
                        make_image_summary=hparams.attention_image_summary,
                        mask_right=hparams.mask_right,
                    )
                    extra_loss += tf.add_n(loss) / dp.n
                elif layer_type == "att_memory_efficient":
                    assert hparams.layer_preprocess_sequence == "n"
                    y = dp(
                        common_attention.
                        multihead_self_attention_memory_efficient, x, bias,
                        hparams.num_heads)
                elif layer_type == "att_local":
                    y = dp(
                        common_attention.multihead_attention,
                        x,
                        None,
                        None,  # bias
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=("local_mask_right"
                                        if hparams.mask_right else
                                        "local_unmasked"),
                        block_length=hparams.local_attention_window,
                        block_width=hparams.local_attention_window)
                elif layer_type == "att_pseudolocal":
                    # This is an inefficient implementation of local attention, for the
                    # purpose of testing model quality.
                    def _pseudolocal_bias(x):
                        return common_attention.attention_bias_local(
                            tf.shape(x)[1], hparams.local_attention_window,
                            0 if hparams.mask_right else
                            hparams.local_attention_window)

                    pseudolocal_bias = dp(_pseudolocal_bias, x)
                    y = dp(
                        common_attention.multihead_attention, x, None,
                        pseudolocal_bias, hparams.attention_key_channels
                        or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size, hparams.hidden_size,
                        hparams.num_heads, hparams.attention_dropout)
                elif layer_type == "att_local_expert":
                    y, loss = dp(
                        common_attention.local_expert_attention,
                        x,
                        k=hparams.attention_moe_k,
                        loss_coef=hparams.attention_load_balance,
                        attention_num_experts=hparams.attention_num_experts,
                        train=hparams.mode == ModeKeys.TRAIN,
                        batch_coordinate=batch_coordinate,
                        mask_right=hparams.mask_right,
                        split_batch=bool(hparams.attention_split_batch),
                        attention_kq_size=hparams.attention_kq_size,
                        attention_v_size=hparams.attention_v_size)
                    # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                    extra_loss += tf.add_n(loss) / dp.n
                elif layer_type == "att_lsh":
                    if hparams.lsh_truncated:
                        attention_fn = common_attention.multihead_attention_sparse_truncated
                    else:
                        attention_fn = common_attention.multihead_attention_sparse_dot_prod
                    y, loss = dp(
                        attention_fn,
                        x,
                        None,
                        None,  # Bias is computed inside
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,

                        # Additional parameters
                        bi=[
                            common_attention.BatchInfo(
                                coordinates=batch_coordinate[i],
                                order=None,  # No future mask
                            ) for i in range(dp.n)
                        ],
                        use_map_fn=False,
                        experts_params=dict(nb_hyperplanes=4, ))
                    extra_loss += tf.add_n(loss) / dp.n
                elif layer_type == "moe":
                    y, loss = expert_utils.distributed_moe(
                        dp,
                        self._ps_devices,
                        x,
                        hparams.mode == ModeKeys.TRAIN,
                        input_size=hparams.hidden_size,
                        expert_fn=expert_fn,
                        num_experts=hparams.moe_num_experts,
                        k=hparams.moe_k,
                        loss_coef=hparams.moe_loss_coef)
                    extra_loss += loss
                elif layer_type == "ffn":
                    y = dp(
                        expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   ffn_hidden_sizes,
                                                   hparams.hidden_size),
                        dp(expert_utils.flatten_all_but_last, x))
                    y = dp(expert_utils.reshape_like, y, x)
                elif layer_type == "conv":
                    y = dp(
                        common_layers.conv1d,
                        x,
                        hparams.hidden_size,
                        hparams.kernel_height,
                        activation=tf.nn.relu,
                        padding="SAME",
                    )
                else:
                    assert False, "unknown sublayer %s" % layer_type
                if _should_postprocess(layer_type):
                    x = postprocess(x, y)
                else:
                    x = y
        x = preprocess(x)

        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
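_should_preprocess and _should_postprocess gate the standard tensor2tensor pre/post-processing, which interprets hparams.layer_preprocess_sequence and hparams.layer_postprocess_sequence as strings of single-letter ops. A reduced sketch of that convention, covering only the "a" (residual add), "n" (layer norm) and "d" (dropout) ops:

import tensorflow.compat.v1 as tf
from tensor2tensor.layers import common_layers

def layer_prepostprocess_sketch(previous, x, sequence, dropout):
  # sequence is e.g. "n" for preprocess or "da" for postprocess.
  for op in sequence:
    if op == "a":          # residual add
      x = x + previous
    elif op == "n":        # layer normalization
      x = common_layers.layer_norm(x)
    elif op == "d":        # dropout (keep_prob = 1 - dropout)
      x = tf.nn.dropout(x, 1.0 - dropout)
  return x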
Example #9
0
  def body_sharded(self, sharded_features):
    # Remove dropout if not training
    hparams = self._hparams
    dp = self._data_parallelism
    x = dp(tf.squeeze, sharded_features["inputs"], 2)

    def preprocess(x):
      return dp(common_layers.layer_preprocess, x, hparams)

    def postprocess(x, y):
      return dp(common_layers.layer_postprocess, x, y, hparams)

    x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
    extra_loss = 0.0
    ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")]
    moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
    if hparams.mask_right:

      def _bias(x):
        return common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(x)[1])

      bias = dp(_bias, x)
    else:
      bias = tf.zeros([1, 1, 1, 1])
    if hparams.diet_experts:
      hsize, = moe_hidden_sizes

      def _diet_expert(x):
        return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())

      expert_fn = _diet_expert
    else:
      expert_fn = expert_utils.ffn_expert_fn(
          hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)

    batch_coordinate = dp(get_batch_coordinate, x)

    layers = hparams.layers.strip(",").split(",")
    for layer_num, layer_type in enumerate(layers):
      with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
        if _should_preprocess(layer_type):
          x = preprocess(x)
        if layer_type == "timing":
          y = dp(common_attention.add_timing_signal_nd, x)
        elif layer_type == "pos_emb":
          y = dp(
              common_attention.add_positional_embedding_nd,
              x,
              hparams.max_length,
              name="pos_emb")
        elif layer_type == "att":
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              bias,  # bias
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout)
        elif layer_type == "att_grouped":
          multiplicative_overhead = (
              hparams.multiplicative_overhead if hparams.mode == ModeKeys.TRAIN
              else hparams.multiplicative_overhead_eval)
          y, loss = dp(
              common_attention.grouped_attention_multihead,
              x,
              x,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              num_groups=hparams.attention_num_groups,
              memory_target_density=hparams.memory_target_density,
              multiplicative_overhead=multiplicative_overhead,
              make_image_summary=hparams.attention_image_summary,
              mask_right=hparams.mask_right,
          )
          extra_loss += tf.add_n(loss) / dp.n
        elif layer_type == "att_memory_efficient":
          assert hparams.layer_preprocess_sequence == "n"
          y = dp(common_attention.multihead_self_attention_memory_efficient, x,
                 bias, hparams.num_heads)
        elif layer_type == "att_local":
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              None,  # bias
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=("local_mask_right"
                              if hparams.mask_right else "local_unmasked"),
              block_length=hparams.local_attention_window,
              block_width=hparams.local_attention_window)
        elif layer_type == "att_pseudolocal":
          # This is an inefficient implementation of local attention, for the
          # purpose of testing model quality.
          def _pseudolocal_bias(x):
            return common_attention.attention_bias_local(
                common_layers.shape_list(x)[1], hparams.local_attention_window,
                0 if hparams.mask_right else hparams.local_attention_window)

          pseudolocal_bias = dp(_pseudolocal_bias, x)
          y = dp(common_attention.multihead_attention, x, None,
                 pseudolocal_bias, hparams.attention_key_channels or
                 hparams.hidden_size, hparams.attention_value_channels or
                 hparams.hidden_size, hparams.hidden_size, hparams.num_heads,
                 hparams.attention_dropout)
        elif layer_type == "att_local_expert":
          y, loss = dp(
              common_attention.local_expert_attention,
              x,
              k=hparams.attention_moe_k,
              loss_coef=hparams.attention_load_balance,
              attention_num_experts=hparams.attention_num_experts,
              train=hparams.mode == ModeKeys.TRAIN,
              batch_coordinate=batch_coordinate,
              mask_right=hparams.mask_right,
              split_batch=bool(hparams.attention_split_batch),
              attention_kq_size=hparams.attention_kq_size,
              attention_v_size=hparams.attention_v_size)
          # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
          extra_loss += tf.add_n(loss) / dp.n
        elif layer_type == "att_lsh":
          if hparams.lsh_truncated:
            attention_fn = common_attention.multihead_attention_sparse_truncated
          else:
            attention_fn = common_attention.multihead_attention_sparse_dot_prod
          y, loss = dp(
              attention_fn,
              x,
              None,
              None,  # Bias is computed inside
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,

              # Additional parameters
              bi=[
                  common_attention.BatchInfo(
                      coordinates=batch_coordinate[i],
                      order=None,  # No future mask
                  ) for i in range(dp.n)
              ],
              use_map_fn=False,
              experts_params=dict(nb_hyperplanes=4,))
          extra_loss += tf.add_n(loss) / dp.n
        elif layer_type == "moe":
          y, loss = expert_utils.distributed_moe(
              dp,
              self._ps_devices,
              x,
              hparams.mode == ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
        elif layer_type == "ffn":
          y = dp(
              expert_utils.ffn_expert_fn(hparams.hidden_size, ffn_hidden_sizes,
                                         hparams.hidden_size),
              dp(expert_utils.flatten_all_but_last, x))
          y = dp(expert_utils.reshape_like, y, x)
        elif layer_type == "conv":
          y = dp(
              common_layers.conv1d,
              x,
              hparams.hidden_size,
              hparams.kernel_height,
              activation=tf.nn.relu,
              padding="SAME",
          )
        else:
          assert False, "unknown sublayer %s" % layer_type
        if _should_postprocess(layer_type):
          x = postprocess(x, y)
        else:
          x = y
    x = preprocess(x)

    decoder_output = dp(tf.expand_dims, x, 2)
    return decoder_output, extra_loss
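When mask_right is set, _bias builds a causal mask with common_attention.attention_bias_lower_triangle: a [1, 1, length, length] tensor that is zero on and below the diagonal and a large negative number above it, so logits for future positions are effectively -inf before the softmax. A NumPy equivalent, for intuition:

import numpy as np

def lower_triangle_bias(length, neg=-1e9):
  # Zeros on and below the diagonal, large negative values above it,
  # broadcastable against [batch, heads, length, length] attention logits.
  band = np.tril(np.ones((length, length), dtype=np.float32))
  return ((1.0 - band) * neg).reshape(1, 1, length, length)

print(lower_triangle_bias(3)[0, 0])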
Example #10
0
def _super_stack(inputs, attention_bias, hparams, mp, padding="LEFT"):
    """A stack of super_lm layers.

    Args:
      inputs: a list of Tensors
      attention_bias: list of bias Tensor for self-attention
        (see common_attention.attention_bias())
      hparams: hyperparameters for model
      mp: a Parallelism object
      padding: a string

    Returns:
      y: a list of Tensors
      extra_loss: an optional scalar
    """
    layers = hparams.layers.strip(",").split(",")
    moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
    if hparams.diet_experts:
        hsize, = moe_hidden_sizes

        def _diet_expert(x):
            return diet.diet_expert(x, hsize,
                                    diet.diet_adam_optimizer_params())

        expert_fn = _diet_expert
    else:
        expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                               moe_hidden_sizes,
                                               hparams.hidden_size)
    # scaled_dot_product_attention_simple uses a 3d attention bias
    # (no heads), whereas multihead_attention uses a 4d attention bias.
    attention_bias_3d = mp(tf.squeeze, attention_bias, 1)
    mix_size = int(hparams.mix_fraction * hparams.hidden_size)
    accumulator = inputs
    x = inputs
    extra_losses = []
    for layer_num, layer_type in enumerate(layers):
        with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
            tf.logging.info("%s_%d" % (layer_type, layer_num))
            if layer_type == "a":
                # accumulate
                accumulator = mp(tf.add, x, accumulator)
                x = accumulator
            elif layer_type == "n":
                # normalize
                x = mp(common_layers.apply_norm, x, hparams.norm_type,
                       hparams.hidden_size, hparams.norm_epsilon)
            elif layer_type == "d":
                # dropout
                x = mp(tf.nn.dropout, x,
                       1.0 - hparams.layer_prepostprocess_dropout)
            elif layer_type == "m":
                # mix across shards
                def _split(t):
                    return tuple(
                        tf.split(t, [mix_size, hparams.hidden_size - mix_size],
                                 2))

                to_mix, to_keep = mp(_split, x)
                mixed = expert_utils.all_reduce_ring(to_mix, mp)
                mixed = mp(tf.multiply, mixed, mp.n**-0.5)
                x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
            elif layer_type == "att":
                # single-head attention
                q = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="q_transform")
                x = mp(common_attention.scaled_dot_product_attention_simple, q,
                       x, x, attention_bias_3d)
                x = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="o_transform")
            elif layer_type == "multihead-att":
                # multi-head attention
                x = mp(
                    common_attention.multihead_attention,
                    x,
                    None,
                    attention_bias,  # bias
                    hparams.multihead_attention_key_channels
                    or hparams.hidden_size,
                    hparams.multihead_attention_value_channels
                    or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.multihead_attention_num_heads,
                    hparams.attention_dropout)
            elif layer_type == "ffn":
                x = mp(common_layers.dense_relu_dense, x, hparams.filter_size,
                       hparams.hidden_size)
            elif layer_type == "conv":
                # convolution
                x = mp(
                    common_layers.conv1d,
                    x,
                    hparams.hidden_size,
                    hparams.kernel_height,
                    activation=tf.nn.relu,
                    padding=padding,
                )
            elif layer_type == "moe":
                # mixture of experts - each model shard has its own local MoE.
                x, loss = mp(expert_utils.local_moe,
                             x,
                             train=hparams.mode == tf_estimator.ModeKeys.TRAIN,
                             expert_fn=expert_fn,
                             num_experts=hparams.moe_num_experts,
                             k=hparams.moe_k,
                             loss_coef=hparams.moe_loss_coef)
                extra_losses.extend(loss)
            else:
                assert False, "unknown sublayer %s" % layer_type
    if extra_losses:
        extra_loss = tf.add_n(extra_losses)
    else:
        extra_loss = None
    return x, extra_loss
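The "m" sublayer of _super_stack exchanges a fraction of each shard's hidden state: every shard keeps most of its vector, contributes a mix_size slice to an all-reduce (sum) across shards, and the shared part is rescaled by n**-0.5. A NumPy sketch of that mixing step:

import numpy as np

def mix_across_shards(xs, mix_size):
  """xs: list of [length, hidden] arrays, one per shard."""
  to_mix = [x[:, :mix_size] for x in xs]
  to_keep = [x[:, mix_size:] for x in xs]
  mixed = sum(to_mix) * len(xs) ** -0.5   # all-reduce (sum), then rescale
  return [np.concatenate([mixed, keep], axis=-1) for keep in to_keep]

shards = [np.ones((4, 8)), 2 * np.ones((4, 8))]
out = mix_across_shards(shards, mix_size=2)  # both shards now share 2 columns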
Example #11
0
    def model_fn_body_sharded(self, sharded_features):
        hparams = self._hparams
        dp = self._data_parallelism
        targets = sharded_features["targets"]
        inputs = sharded_features["inputs"]
        target_space = sharded_features["target_space_id"]

        inputs = dp(common_layers.flatten4d3d, inputs)
        targets = dp(common_layers.flatten4d3d, targets)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = dp(
            transformer.transformer_prepare_encoder,
            inputs, target_space, hparams)
        (decoder_input, decoder_self_attention_bias) = dp(
            transformer.transformer_prepare_decoder, targets, hparams)
        encoder_input = dp(tf.nn.dropout, encoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)
        decoder_input = dp(tf.nn.dropout, decoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0
        moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
        expert_fn = expert_utils.ffn_expert_fn(
            hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
        x = encoder_input
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("encoder_layer_%d" % layer):
                with tf.variable_scope("encoder_self_attention"):
                    y = dp(
                        common_attention.multihead_attention,
                        preprocess(x),
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers_encoder.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    else:
                        y = dp(
                            common_layers.conv_hidden_relu,
                            preprocess(x),
                            hparams.filter_size,
                            hparams.hidden_size,
                            dropout=hparams.relu_dropout)
                    x = postprocess(x, y)
        encoder_output = preprocess(x)
        x = decoder_input
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("decoder_layer_%d" % layer):
                with tf.variable_scope("decoder_self_attention"):
                    y = dp(
                        common_attention.multihead_attention,
                        preprocess(x),
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                    x = postprocess(x, y)
                with tf.variable_scope("encoder_decoder_attention"):
                    y = dp(
                        common_attention.multihead_attention,
                        preprocess(x),
                        encoder_output,
                        encoder_decoder_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers_decoder.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    else:
                        y = dp(
                            common_layers.conv_hidden_relu,
                            preprocess(x),
                            hparams.filter_size,
                            hparams.hidden_size,
                            dropout=hparams.relu_dropout)
                    x = postprocess(x, y)
        x = preprocess(x)
        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
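transformer_prepare_decoder, called near the top of the function above, produces the decoder input by shifting the targets right by one position, so that step t conditions only on targets before t, and pairs it with a causal self-attention bias. A minimal NumPy sketch of the shift, for intuition:

import numpy as np

def shift_right_3d(x):
  """x: [batch, length, depth]; prepend zeros, drop the last position."""
  return np.pad(x, ((0, 0), (1, 0), (0, 0)))[:, :-1, :]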
    def model_fn_body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
        if hparams.use_inputs:
            decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
            decoder_self_attention_bias = None
        else:
            targets = sharded_features["targets"]
            targets = dp(tf.squeeze, targets, 2)
            (decoder_input, decoder_self_attention_bias,
             pad_remover) = dp(attention_lm_moe_prepare_decoder, targets,
                               hparams)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        x = dp(tf.nn.dropout, decoder_input,
               1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        if hparams.diet_experts:
            hsize, = moe_hidden_sizes

            def _diet_expert(x):
                return diet.diet_expert(x, hsize,
                                        diet.diet_adam_optimizer_params())

            expert_fn = _diet_expert
        else:
            expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   moe_hidden_sizes,
                                                   hparams.hidden_size)

        if not hparams.use_inputs:
            # As preprocess and postprocess are called with a batch of size one
            # (all batches concatenated), we just make sure that batch_norm is
            # not used (it should not be either way).
            assert hparams.norm_type != "batch"

            tf.logging.info(
                "Applying Padding Remover for the attention experts")

            dp_remove_pad = functools.partial(dp,
                                              remove_pad,
                                              pad_remover=pad_remover,
                                              mode=hparams.mode)
            dp_restore_pad = functools.partial(dp,
                                               restore_pad,
                                               ref_x=x,
                                               pad_remover=pad_remover,
                                               mode=hparams.mode)
        else:
            # Using identity function: No effect
            dp_remove_pad = lambda x: x
            dp_restore_pad = lambda x: x

        if hparams.attention_exp_factor != 0:
            tf.logging.info(
                "Expand/compress tokens before sending them to experts")
            dp_expand_bc = lambda x: dp(  # pylint: disable=g-long-lambda
                expand_batch_coordinates, x, hparams.attention_exp_factor)
            dp_expand_x = lambda x: dp(  # pylint: disable=g-long-lambda
                common_attention.deconv_elems_1d, x, hparams.
                attention_exp_factor, hparams.attention_exp_inputdim)
            dp_compress_x = lambda x, l: dp(  # pylint: disable=g-long-lambda
                common_attention.conv_elems_1d, x, hparams.
                attention_exp_factor, l)
        else:
            dp_expand_bc = lambda x: x
            dp_expand_x = lambda x: x
            dp_compress_x = lambda x, l: x

        def print_shape(x, suffix, debug=False):
            # To help debugging, print the input/output shapes at inference and
            # eval time. Inference for long sequences can take a long time, so
            # this helps to track the progress of the generation.
            if not debug and hparams.mode == ModeKeys.TRAIN:
                return x
            return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

        with tf.name_scope("batch_coordinate_preprocess"):
            batch_coordinate = dp(get_batch_coordinate, x)
            batch_coordinate = dp_remove_pad(batch_coordinate)
            batch_coordinate = dp_expand_bc(batch_coordinate)
            batch_order = dp(get_batch_coordinate, x, axis=-1)
            batch_order = dp_remove_pad(batch_order)
            batch_order = dp_expand_bc(batch_order)

        x = dp(print_shape, x, "in")

        assert hparams.batch_size >= hparams.max_length

        num_hidden_layers = (len(hparams.attention_layers)
                             or hparams.num_hidden_layers)
        for layer in xrange(num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):

                # Use the layer type defined in attention_layers
                if hparams.attention_layers:
                    attention_type = LAYER_SYMBOLS[
                        hparams.attention_layers[layer]]
                else:
                    attention_type = hparams.attention_type

                with tf.variable_scope("attention_{}".format(attention_type)):
                    if attention_type in [
                            AttentionType.MULTIHEAD,
                            AttentionType.MULTIHEAD_FULL
                    ]:
                        attention_dot_type = ("local_mask_right"
                                              if hparams.attention_local else
                                              "dot_product")
                        if attention_type == AttentionType.MULTIHEAD_FULL:
                            attention_dot_type = "dot_product"
                        y = dp(common_attention.multihead_attention,
                               preprocess(x),
                               None,
                               decoder_self_attention_bias,
                               hparams.attention_key_channels
                               or hparams.hidden_size,
                               hparams.attention_value_channels
                               or hparams.hidden_size,
                               hparams.hidden_size,
                               hparams.num_heads,
                               hparams.attention_dropout,
                               attention_type=attention_dot_type,
                               block_length=hparams.attention_block_length,
                               name="decoder_self_attention")
                    elif attention_type == AttentionType.SPARSE_MULTIHEAD:
                        x_in = preprocess(x)
                        x_in = dp_remove_pad(x_in)
                        y, loss_experts = dp(
                            common_attention.
                            multihead_attention_sparse_dot_prod,
                            x_in,
                            None,
                            None,  # Bias is computed inside
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,

                            # Additional parameters
                            bi=[
                                common_attention.BatchInfo(
                                    coordinates=batch_coordinate[i],
                                    order=batch_order[i],  # No future mask
                                ) for i in range(dp.n)
                            ],
                            use_map_fn=hparams.lsh_use_map_fn,
                            experts_params=dict(
                                nb_hyperplanes=hparams.lsh_num_hyperplanes, ),
                        )
                        y = dp_restore_pad(y)

                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss_experts) / dp.n
                    elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED:
                        x_in = preprocess(x)
                        y, loss_experts = dp(
                            common_attention.
                            multihead_attention_sparse_truncated,
                            x_in,
                            None,
                            None,  # Bias is computed inside
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,

                            # Additional parameters
                            bi=[
                                common_attention.BatchInfo(
                                    coordinates=batch_coordinate[i],
                                    order=batch_order[i],  # No future mask
                                ) for i in range(dp.n)
                            ],
                            mask_right=True,
                            experts_params=dict(
                                nb_hyperplanes=hparams.lsh_num_hyperplanes, ),
                        )

                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss_experts) / dp.n
                    elif attention_type == AttentionType.MEMORY_EFFICIENT:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(common_attention.
                               multihead_self_attention_memory_efficient,
                               x,
                               decoder_self_attention_bias,
                               hparams.num_heads,
                               name="decoder_self_attention")
                    elif attention_type == AttentionType.MULTIHEAD_REDUCED:
                        y = dp(
                            common_attention.multihead_self_attention_reduced,
                            preprocess(x),
                            factor=hparams.attention_red_factor,
                            reduction_type=hparams.attention_reduction_type,
                            nonlinearity=hparams.attention_nonlinearity,
                            multihead_params=dict(
                                total_key_depth=hparams.attention_key_channels
                                or hparams.hidden_size,
                                total_value_depth=hparams.
                                attention_value_channels
                                or hparams.hidden_size,
                                num_heads=hparams.num_heads,
                                dropout_rate=hparams.attention_dropout,
                            ))
                    elif attention_type == AttentionType.LOCAL_EXPERTS:
                        x_in = preprocess(x)
                        x_in = dp_remove_pad(x_in)
                        x_in = dp_expand_x(x_in)
                        y, loss = dp(
                            common_attention.local_expert_attention,
                            x_in,
                            k=hparams.attention_moe_k,
                            loss_coef=hparams.attention_load_balance,
                            attention_num_experts=hparams.
                            attention_num_experts,
                            train=hparams.mode == ModeKeys.TRAIN,
                            batch_coordinate=batch_coordinate,
                            mask_right=not hparams.use_inputs,
                            split_batch=bool(hparams.attention_split_batch),
                            attention_num_head=hparams.attention_num_head,
                            attention_kq_size=hparams.attention_kq_size,
                            attention_v_size=hparams.attention_v_size)
                        y = dp_compress_x(y, x[0].get_shape().as_list()[-1])
                        y = dp_restore_pad(y)
                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss) / dp.n
                    else:
                        raise ValueError("Only {} supported for now.".format(
                            AttentionType.get_choices()))
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    elif hparams.memory_efficient_ffn:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(common_layers.conv_hidden_relu_memory_efficient,
                               x, hparams.filter_size)
                    else:
                        additional_conv_params = dict()
                        if hparams.use_sepconv:
                            additional_conv_params = dict(
                                padding="LEFT",
                                # Parameters copied from the transformer model
                                kernel_size=(3, 1),
                                second_kernel_size=(31, 1),
                            )
                        y = dp(common_layers.conv_hidden_relu,
                               preprocess(x),
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.relu_dropout,
                               **additional_conv_params)
                    x = postprocess(x, y)
        x = preprocess(x)

        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
    def model_fn_body_sharded(self, sharded_features):
        train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN
        dp = self._data_parallelism
        hparams = self._hparams

        def flatten(inputs):
            return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)

        inputs = dp(flatten, sharded_features["inputs"])
        inputs_pad = dp(slicenet.embedding_to_padding, inputs)
        inputs_mask = dp(lambda x: 1.0 - x, inputs_pad)
        inputs_encoded = dp(common_layers.add_timing_signal, inputs)
        expert_loss = 0.0
        for i in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("enc_layer_%d" % i):
                inputs_encoded, moe_loss = conv_experts(
                    inputs_encoded, hparams, dp, self._ps_devices, "SAME",
                    inputs_mask, i)
                expert_loss += tf.reduce_mean(moe_loss) * hparams.moe_loss_coef

        # If we're just predicting a class, there is no use for a decoder; return.
        if isinstance(hparams.problems[self._problem_idx].target_modality,
                      modalities.ClassLabelModality):
            return inputs_encoded, tf.reduce_mean(expert_loss)

        # Decoder.
        inputs3d = dp(tf.squeeze, inputs, 2)
        inputs_encoded3d = dp(tf.squeeze, inputs_encoded, 2)
        encoder_padding = dp(common_attention.embedding_to_padding, inputs3d)
        encoder_attention_bias = dp(
            common_attention.attention_bias_ignore_padding, encoder_padding)
        targets = dp(common_layers.flatten4d3d, sharded_features["targets"])
        target_space_emb = dp(slicenet.embed_target_space,
                              sharded_features["target_space_id"],
                              hparams.hidden_size)

        (decoder_input,
         decoder_self_attention_bias) = dp(prepare_decoder, targets,
                                           target_space_emb)

        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                               moe_hidden_sizes,
                                               hparams.hidden_size)
        x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.dropout)
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("dec_layer_%d" % layer):
                with tf.variable_scope("attention"):
                    y = dp(common_attention.multihead_attention,
                           x,
                           None,
                           decoder_self_attention_bias,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.num_heads,
                           hparams.attention_dropout,
                           name="decoder_self_attention")
                    z = dp(common_attention.multihead_attention,
                           y,
                           inputs_encoded3d,
                           encoder_attention_bias,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.num_heads,
                           hparams.attention_dropout,
                           name="encdec_attention")
                    x = dp(residual_fn3, x, y, z, hparams)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, moe_loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            x,
                            train,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        expert_loss += tf.reduce_mean(moe_loss)
                    else:
                        y = dp(common_layers.conv_hidden_relu,
                               x,
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.dropout)
                    x = dp(residual_fn2, x, y, hparams)

        x = dp(tf.expand_dims, x, 2)
        return x, tf.reduce_mean(expert_loss)
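In the feed-forward blocks above, whether a given layer uses the distributed mixture-of-experts or the plain conv_hidden_relu path is decided by the comma-separated hparams.moe_layers string. A small self-contained sketch of that dispatch follows (the layer count and moe_layers value are made up for illustration):

# Hypothetical values standing in for hparams.moe_layers / num_hidden_layers.
moe_layers = "2,4"
num_hidden_layers = 6

for layer in range(num_hidden_layers):
    if str(layer) in moe_layers.split(","):
        kind = "distributed_moe"      # expert_utils.distributed_moe branch
    else:
        kind = "conv_hidden_relu"     # dense feed-forward branch
    print("dec_layer_%d -> %s" % (layer, kind))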
Example #14
def transformer_ffn_layer(x,
                          hparams,
                          pad_remover=None,
                          conv_padding="LEFT",
                          nonpadding_mask=None,
                          losses=None,
                          cache=None,
                          decode_loop_step=None,
                          readout_filter_size=0):
  """Feed-forward layer in the transformer.

  Args:
    x: a Tensor of shape [batch_size, length, hparams.hidden_size]
    hparams: hyperparameters for model
    pad_remover: an expert_utils.PadRemover object tracking the padding
      positions. If provided, when using convolutional settings, the padding
      is removed before applying the convolution, and restored afterward. This
      can give a significant speedup.
    conv_padding: a string - either "LEFT" or "SAME".
    nonpadding_mask: an optional Tensor with shape [batch_size, length].
      needed for convolutional layers with "SAME" padding.
      Contains 1.0 in positions corresponding to nonpadding.
    losses: optional list onto which to append extra training losses
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop.
        Only used for inference on TPU.
    readout_filter_size: if it's greater than 0, then it will be used instead of
      filter_size
  Returns:
    a Tensor of shape [batch_size, length, hparams.hidden_size]

  Raises:
    ValueError: If losses arg is None, but layer generates extra losses.
  """
  ffn_layer = hparams.ffn_layer
  relu_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "relu_dropout_broadcast_dims", "")))
  if ffn_layer == "conv_hidden_relu":
    # Backwards compatibility
    ffn_layer = "dense_relu_dense"
  if ffn_layer == "dense_relu_dense":
    # In simple convolution mode, use `pad_remover` to speed up processing.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_FILTER_DENSE,
        value={
            "filter_size": hparams.filter_size,
            "use_bias": "True",
            "activation": mlperf_log.RELU
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_OUTPUT_DENSE,
        value={
            "hidden_size": hparams.hidden_size,
            "use_bias": "True",
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_RELU_DROPOUT, value=hparams.relu_dropout)
    if pad_remover:
      original_shape = common_layers.shape_list(x)
      # Collapse `x` across examples, and remove padding positions.
      x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0))
      x = tf.expand_dims(pad_remover.remove(x), axis=0)
    conv_output = common_layers.dense_relu_dense(
        x,
        hparams.filter_size,
        hparams.hidden_size,
        dropout=hparams.relu_dropout,
        dropout_broadcast_dims=relu_dropout_broadcast_dims)
    if pad_remover:
      # Restore `conv_output` to the original shape of `x`, including padding.
      conv_output = tf.reshape(
          pad_remover.restore(tf.squeeze(conv_output, axis=0)), original_shape)
    return conv_output
  elif ffn_layer == "conv_relu_conv":
    return common_layers.conv_relu_conv(
        x,
        readout_filter_size or hparams.filter_size,
        hparams.hidden_size,
        first_kernel_size=hparams.conv_first_kernel,
        second_kernel_size=1,
        padding=conv_padding,
        nonpadding_mask=nonpadding_mask,
        dropout=hparams.relu_dropout,
        cache=cache,
        decode_loop_step=decode_loop_step)
  elif ffn_layer == "parameter_attention":
    return common_attention.parameter_attention(
        x, hparams.parameter_attention_key_channels or hparams.hidden_size,
        hparams.parameter_attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, readout_filter_size or hparams.filter_size,
        hparams.num_heads,
        hparams.attention_dropout)
  elif ffn_layer == "conv_hidden_relu_with_sepconv":
    return common_layers.conv_hidden_relu(
        x,
        readout_filter_size or hparams.filter_size,
        hparams.hidden_size,
        kernel_size=(3, 1),
        second_kernel_size=(31, 1),
        padding="LEFT",
        dropout=hparams.relu_dropout)
  elif ffn_layer == "sru":
    return common_layers.sru(x)
  elif ffn_layer == "local_moe_tpu":
    overhead = (
        hparams.moe_overhead_train
        if hparams.mode == tf.estimator.ModeKeys.TRAIN else
        hparams.moe_overhead_eval)
    ret, loss = expert_utils.local_moe_tpu(
        x,
        hparams.filter_size // 2,
        hparams.hidden_size,
        hparams.moe_num_experts,
        overhead=overhead,
        loss_coef=hparams.moe_loss_coef)
    # Record the MoE auxiliary loss and return, as in the local_moe branch below.
    losses.append(loss)
    return ret
  elif ffn_layer == "local_moe":
    overhead = (
        hparams.moe_overhead_train
        if hparams.mode == tf.estimator.ModeKeys.TRAIN else
        hparams.moe_overhead_eval)
    ret, loss = expert_utils.local_moe(
        x,
        True,
        expert_utils.ffn_expert_fn(hparams.hidden_size, [hparams.filter_size],
                                   hparams.hidden_size),
        hparams.moe_num_experts,
        k=hparams.moe_k,
        hparams=hparams)
    losses.append(loss)
    return ret
  else:
    assert ffn_layer == "none"
    return x
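The pad_remover path of the dense_relu_dense branch collapses the batch into one long sequence, drops the padding positions before the position-wise feed-forward, and scatters the results back afterwards. Here is a rough NumPy sketch of that remove/compute/restore idea (the mask and weights are invented; this is not the expert_utils.PadRemover API itself):

import numpy as np

x = np.random.randn(2, 4, 8)                      # [batch, length, hidden_size]
nonpad = np.array([[1, 1, 0, 0], [1, 1, 1, 0]])   # 1 where tokens are real

flat = x.reshape(-1, 8)                           # collapse batch and length
keep = nonpad.reshape(-1).astype(bool)
compact = flat[keep]                              # "remove": keep real tokens only

w = np.random.randn(8, 8)                         # placeholder FFN weights
compact = np.maximum(compact @ w, 0.0)            # position-wise feed-forward

out = np.zeros_like(flat)
out[keep] = compact                               # "restore": scatter back
out = out.reshape(x.shape)                        # padding rows stay at zero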
Example #16
    def model_fn_body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
        targets = sharded_features["targets"]
        targets = dp(tf.squeeze, targets, 2)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        (decoder_input,
         decoder_self_attention_bias) = dp(attention_lm_moe_prepare_decoder,
                                           targets, hparams)

        x = dp(tf.nn.dropout, decoder_input,
               1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        if hparams.diet_experts:
            hsize, = moe_hidden_sizes

            def _diet_expert(x):
                return diet.diet_expert(x, hsize,
                                        diet.diet_adam_optimizer_params())

            expert_fn = _diet_expert
        else:
            expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   moe_hidden_sizes,
                                                   hparams.hidden_size)
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("attention_{}".format(
                        hparams.attention_moe_type)):
                    x = preprocess(x)
                    if hparams.attention_moe_type == AttentionMoeType.NONE:
                        y = dp(common_attention.multihead_attention,
                               x,
                               None,
                               decoder_self_attention_bias,
                               hparams.attention_key_channels
                               or hparams.hidden_size,
                               hparams.attention_value_channels
                               or hparams.hidden_size,
                               hparams.hidden_size,
                               hparams.num_heads,
                               hparams.attention_dropout,
                               name="decoder_self_attention")
                    elif hparams.attention_moe_type == AttentionMoeType.LOCAL:
                        y, loss = dp(
                            common_attention.local_expert_attention,
                            x,
                            k=2,
                            loss_coef=1e-2,
                            attention_num_experts=hparams.
                            attention_num_experts,
                            train=hparams.mode ==
                            tf.contrib.learn.ModeKeys.TRAIN,
                            mask_right=True,
                            attention_kq_size=hparams.attention_kq_size,
                            attention_v_size=hparams.attention_v_size)
                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss) / dp.n
                    else:
                        raise ValueError("Only {} supported for now.".format(
                            AttentionMoeType.get_choices()))
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    else:
                        y = dp(common_layers.conv_hidden_relu,
                               preprocess(x),
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.relu_dropout)
                    x = postprocess(x, y)
        x = preprocess(x)
        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
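Here expert_utils.ffn_expert_fn(hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) builds each expert as a small ReLU feed-forward network going from hidden_size through the listed hidden sizes and back to hidden_size. A minimal NumPy sketch of one such expert (random placeholder weights; the real experts are trained variables):

import numpy as np

def make_ffn_expert(input_size, hidden_sizes, output_size):
    # One expert: input -> hidden_sizes ... -> output, ReLU between layers.
    sizes = [input_size] + list(hidden_sizes) + [output_size]
    weights = [np.random.randn(a, b) * 0.01 for a, b in zip(sizes[:-1], sizes[1:])]

    def expert(x):
        for i, w in enumerate(weights):
            x = x @ w
            if i < len(weights) - 1:
                x = np.maximum(x, 0.0)   # ReLU on hidden layers only
        return x

    return expert

expert = make_ffn_expert(8, [32], 8)      # e.g. moe_hidden_sizes == [32]
y = expert(np.random.randn(5, 8))         # [num_tokens, hidden_size]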
Example #17
def _super_stack(inputs, attention_bias, hparams, mp, padding="LEFT"):
    """A stack of super_lm layers.

  Args:
    inputs: a list of Tensors
    attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    mp: a Parallelism object
    padding: a string

  Returns:
    y: a Tensors
  """
    layers = hparams.layers.strip(",").split(",")
    ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")]
    # scaled_dot_product_attention_with_projections uses a 3d attention bias
    # (no heads), where multihead_attention uses 4d attention bias.
    mix_size = int(hparams.mix_fraction * hparams.hidden_size)
    attention_bias_3d = mp(tf.squeeze, attention_bias, 1)
    accumulator = inputs
    x = inputs
    for layer_num, layer_type in enumerate(layers):
        with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
            tf.logging.info("%s_%d" % (layer_type, layer_num))
            if layer_type == "a":
                # accumulate
                accumulator = mp(tf.add, x, accumulator)
                x = accumulator
            elif layer_type == "n":
                # normalize
                x = mp(common_layers.apply_norm, x, hparams.norm_type,
                       hparams.hidden_size, hparams.norm_epsilon)
            elif layer_type == "d":
                # dropout
                x = mp(tf.nn.dropout, x,
                       1.0 - hparams.layer_prepostprocess_dropout)
            elif layer_type == "m":
                # mix across shards
                def _split(t):
                    return tuple(
                        tf.split(t, [mix_size, hparams.hidden_size - mix_size],
                                 2))

                to_mix, to_keep = mp(_split, x)
                mixed = common_layers.all_reduce_ring(to_mix, mp)
                mixed = mp(tf.multiply, mixed, mp.n**-0.5)
                x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
            elif layer_type == "att":
                # single-head attention
                q = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="q_transform")
                x = mp(common_attention.scaled_dot_product_attention_simple, q,
                       x, x, attention_bias_3d)
                x = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="o_transform")
            elif layer_type == "multihead-att":
                # multi-head attention
                x = mp(
                    common_attention.multihead_attention,
                    x,
                    None,
                    attention_bias,  # bias
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout)
            elif layer_type == "ffn":
                y = mp(
                    expert_utils.ffn_expert_fn(hparams.hidden_size,
                                               ffn_hidden_sizes,
                                               hparams.hidden_size),
                    mp(expert_utils.flatten_all_but_last, x))
                x = mp(expert_utils.reshape_like, y, x)
            elif layer_type == "conv":
                # convolution
                x = mp(
                    common_layers.conv1d,
                    x,
                    hparams.hidden_size,
                    hparams.kernel_height,
                    activation=tf.nn.relu,
                    padding=padding,
                )
            else:
                assert False, "unknown sublayer %s" % layer_type
    return x
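_super_stack is driven entirely by hparams.layers, a comma-separated string of sublayer codes ("a" accumulate, "n" normalize, "d" dropout, "m" mix across shards, "att" / "multihead-att" attention, "ffn" feed-forward, "conv" convolution). A tiny sketch of how such a string unrolls into the loop above (the example string is made up):

# Hypothetical layer string; the real value comes from hparams.layers.
layers = "n,att,d,a,n,ffn,d,a,".strip(",").split(",")

for layer_num, layer_type in enumerate(layers):
    scope = "%s_%d" % (layer_type, layer_num)   # matches the variable_scope names
    print(scope)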
Example #18
def _super_stack(inputs,
                 attention_bias,
                 hparams,
                 mp,
                 padding="LEFT"):
  """A stack of super_lm layers.

  Args:
    inputs: a list of Tensors
    attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    mp: a Parallelism object
    padding: a string

  Returns:
    y: a list of Tensors
    extra_loss: an optional scalar
  """
  layers = hparams.layers.strip(",").split(",")
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  if hparams.diet_experts:
    hsize, = moe_hidden_sizes
    def _diet_expert(x):
      return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())
    expert_fn = _diet_expert
  else:
    expert_fn = expert_utils.ffn_expert_fn(
        hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  # scaled_dot_product_attention_with_projections uses a 3d attention bias
  # (no heads), where multihead_attention uses 4d attention bias.
  attention_bias_3d = mp(tf.squeeze, attention_bias, 1)
  mix_size = int(hparams.mix_fraction * hparams.hidden_size)
  accumulator = inputs
  x = inputs
  extra_losses = []
  for layer_num, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
      tf.logging.info("%s_%d" % (layer_type, layer_num))
      if layer_type == "a":
        # accumulate
        accumulator = mp(tf.add, x, accumulator)
        x = accumulator
      elif layer_type == "n":
        # normalize
        x = mp(common_layers.apply_norm,
               x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon)
      elif layer_type == "d":
        # dropout
        x = mp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
      elif layer_type == "m":
        # mix across shards
        def _split(t):
          return tuple(tf.split(
              t, [mix_size, hparams.hidden_size - mix_size], 2))
        to_mix, to_keep = mp(_split, x)
        mixed = common_layers.all_reduce_ring(to_mix, mp)
        mixed = mp(tf.multiply, mixed, mp.n ** -0.5)
        x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
      elif layer_type == "att":
        # single-head attention
        q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="q_transform")
        x = mp(
            common_attention.scaled_dot_product_attention_simple,
            q, x, x, attention_bias_3d)
        x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="o_transform")
      elif layer_type == "multihead-att":
        # multi-head attention
        x = mp(
            common_attention.multihead_attention,
            x,
            None,
            attention_bias,  # bias
            hparams.multihead_attention_key_channels or hparams.hidden_size,
            hparams.multihead_attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.multihead_attention_num_heads,
            hparams.attention_dropout)
      elif layer_type == "ffn":
        x = mp(
            common_layers.dense_relu_dense, x,
            hparams.filter_size, hparams.hidden_size)
      elif layer_type == "conv":
        # convolution
        x = mp(
            common_layers.conv1d,
            x,
            hparams.hidden_size,
            hparams.kernel_height,
            activation=tf.nn.relu,
            padding=padding,
        )
      elif layer_type == "moe":
        # mixture of experts - each model shard has its own local MoE.
        x, loss = mp(
            expert_utils.local_moe,
            x,
            train=hparams.mode == tf.estimator.ModeKeys.TRAIN,
            expert_fn=expert_fn,
            num_experts=hparams.moe_num_experts,
            k=hparams.moe_k,
            loss_coef=hparams.moe_loss_coef)
        extra_losses.extend(loss)
      else:
        assert False, "unknown sublayer %s" % layer_type
  if extra_losses:
    extra_loss = tf.add_n(extra_losses)
  else:
    extra_loss = None
  return x, extra_loss
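The "m" sublayer in this variant mixes a fraction of every shard's hidden state across the model-parallel shards: it splits off mix_size channels, sums them over all shards (all_reduce_ring), rescales by n**-0.5, and concatenates the mixed part back onto the untouched channels. A NumPy sketch of that step for a hypothetical set of shards:

import numpy as np

n_shards, mix_size, hidden_size = 4, 2, 8
shards = [np.random.randn(1, 5, hidden_size) for _ in range(n_shards)]

to_mix = [s[..., :mix_size] for s in shards]        # _split: channels to mix
to_keep = [s[..., mix_size:] for s in shards]       # channels kept local
mixed = sum(to_mix) * n_shards ** -0.5              # all-reduce (sum) + rescale
shards = [np.concatenate([mixed, k], axis=-1) for k in to_keep]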
Example #19
    def model_fn_body_sharded(self, sharded_features):

        # ========= Prepare the input and target =========

        hparams = self._hparams
        dp = self._data_parallelism
        targets = sharded_features["targets"]
        inputs = sharded_features["inputs"]
        target_space = sharded_features["target_space_id"]

        inputs = dp(common_layers.flatten4d3d, inputs)
        targets = dp(common_layers.flatten4d3d, targets)

        def dp_preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def dp_postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = dp(
             transformer.transformer_prepare_encoder, inputs, target_space,
             hparams)
        (decoder_input, decoder_self_attention_bias) = dp(
            transformer.transformer_prepare_decoder, targets, hparams)
        encoder_input = dp(tf.nn.dropout, encoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)
        decoder_input = dp(tf.nn.dropout, decoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)
        cache = dict(extra_loss=0)
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                               moe_hidden_sizes,
                                               hparams.hidden_size)

        # ========= Define some utils decorators =========

        def prepostprocess(fct):
            """Add pre and post processing."""
            # WARNING: Should be applied after dp_wrapper (pre/post-processing uses
            # dp itself, so it can wrap functions which don't use dp).
            @functools.wraps(fct)
            def decorated(x, *args, **kwargs):
                x = dp_preprocess(x)
                y = fct(x, *args, **kwargs)
                return dp_postprocess(x, y)

            return decorated

        def dp_wrapper(fct):
            """Encapsulate the function in a data parallelism object."""
            @functools.wraps(fct)
            def decorated(*args, **kwargs):
                return dp(fct, *args, **kwargs)

            return decorated

        def add_kwargs(
                fct,
                enco_kwargs=None,
                deco_kwargs=None,
                endeco_kwargs=None,  # Enco-deco attention: overwrite deco_kwargs
        ):
            """Allow to have different arguments for the encoder and decoder."""
            # WARNING: If this decorator is applied before dp_wrapper, the kwargs
            # may not be correctly dispatched across the devices.
            @functools.wraps(fct)
            def decorated(*args, **kwargs):
                current_scope = tf.contrib.framework.get_name_scope()
                if "/encoder/" in current_scope:
                    kwargs.update(enco_kwargs or {})
                elif "/decoder/" in current_scope:
                    kwargs.update(deco_kwargs or {})
                    if "/att_ende_" in current_scope:
                        kwargs.update(endeco_kwargs or {})
                return fct(*args, **kwargs)

            return decorated

        def capture_extra_loss(fct, loss_coef=1.0):
            """Capture the additional loss."""
            @functools.wraps(fct)
            def decorated(*args, **kwargs):
                y, loss = fct(*args, **kwargs)
                cache["extra_loss"] += loss * loss_coef
                return y

            return decorated

        def remove_kwargs(fct, extra_params):
            """Remove some unused parameters."""
            @functools.wraps(fct)
            def decorated(*args, **kwargs):
                for k in extra_params:  # Remove the extra params
                    kwargs.pop(k, None)
                return fct(*args, **kwargs)

            return decorated

        # def pad_remover(fct):
        #   """Remove/restore the padding on the input."""
        #   @functools.wraps(fct)
        #   def decorated(x, *args, **kwargs):
        #     x = pad_remover.remove(x)
        #     x = fct(x, *args, **kwargs)
        #     x = pad_remover.restore(x)
        #     return x
        #   return decorated

        # ========= Define the available layers =========
        total_key_depth = hparams.attention_key_channels or hparams.hidden_size
        total_value_depth = hparams.attention_value_channels or hparams.hidden_size

        # Multi-head full attention layer
        multihead_attention = partial(
            common_attention.multihead_attention,
            total_key_depth=total_key_depth,
            total_value_depth=total_value_depth,
            output_depth=hparams.hidden_size,
            num_heads=hparams.num_heads,
            dropout_rate=hparams.attention_dropout,
        )
        multihead_attention = dp_wrapper(multihead_attention)
        multihead_attention = add_kwargs(  # After dp to correctly dispatch kwargs
            multihead_attention,
            enco_kwargs={"bias": encoder_self_attention_bias},
            deco_kwargs={"bias": decoder_self_attention_bias},
            endeco_kwargs={"bias": encoder_decoder_attention_bias},
        )
        multihead_attention = prepostprocess(multihead_attention)

        # Local attention layer
        # Reuse same parameters as multihead_attention (dp and pre/post-processing
        # already applied)
        # Only works for self attention. Always mask the future.
        local_attention = partial(
            multihead_attention,
            block_length=hparams.attention_loc_block_length,
            attention_type="local_mask_right",
        )

        # Memory-compressed multihead self attention layer
        # Only works for self attention. Always mask the future.
        compressed_attention = partial(
            common_attention.multihead_self_attention_reduced,
            factor=hparams.attention_red_factor,
            nonlinearity=hparams.attention_red_nonlinearity,
            reduction_type=hparams.attention_red_type,
            multihead_params=dict(
                total_key_depth=total_key_depth,
                total_value_depth=total_value_depth,
                num_heads=hparams.num_heads,
                dropout_rate=hparams.attention_dropout,
            ))
        compressed_attention = remove_kwargs(compressed_attention,
                                             ["memory_antecedent"])
        compressed_attention = dp_wrapper(compressed_attention)
        compressed_attention = prepostprocess(compressed_attention)

        # Mixture of expert layer
        distributed_moe = partial(
            expert_utils.distributed_moe,
            dp,
            self._ps_devices,
            train=hparams.mode == tf.estimator.ModeKeys.TRAIN,
            input_size=hparams.hidden_size,
            expert_fn=expert_fn,
            num_experts=hparams.moe_num_experts,
            k=hparams.moe_k,
            loss_coef=hparams.moe_loss_coef)
        distributed_moe = capture_extra_loss(distributed_moe)
        distributed_moe = prepostprocess(distributed_moe)

        # FC layer
        conv_hidden_relu = partial(
            common_layers.conv_hidden_relu,
            hidden_size=hparams.filter_size,
            output_size=hparams.hidden_size,
            dropout=hparams.relu_dropout,
        )
        conv_hidden_relu = dp_wrapper(conv_hidden_relu)
        conv_hidden_relu = prepostprocess(conv_hidden_relu)

        # Separable convolution layer
        # Reuse conv_hidden_relu (dp and pre/post-processing already applied)
        # Mask the future for the decoder only
        sep_conv_relu = partial(
            conv_hidden_relu,
            # Parameters copied from the transformer model, could add hparams
            kernel_size=(3, 1),
            second_kernel_size=(31, 1),
        )
        sep_conv_relu = add_kwargs(
            sep_conv_relu,
            enco_kwargs={"padding": "SAME"},
            deco_kwargs={"padding": "LEFT"},  # Mask future for decoder
        )

        # This dictionary contains the list of all available layers
        available_layers = dict(
            # Attention layers
            a=multihead_attention,  # Standard multihead full attention
            loc=local_attention,  # Local attention
            red=compressed_attention,  # Memory-compressed attention
            mem=None,  # Memory efficient
            # Feed-forward layers
            moe=distributed_moe,  # Mixture of expert layer
            sep=sep_conv_relu,  # Separable convolution
            fc=conv_hidden_relu,  # Fully connected
        )

        def extract_layer_types(layer_types):
            """Parse the layer string.

            Args:
              layer_types (str): String containing the network architecture. See
                top file comment for examples of format.

            Returns:
              list[tuple[str, str]]: Encoder layers: list of (attention, feed-forward)
              list[tuple[str, str, str]]: Decoder layers: list of (self-attention,
                enc-dec attention, feed-forward)
            """
            # If the architecture has not explicitly been set, we just construct a
            # standard transformer with the fallback values
            if not layer_types:
                layer_types = SEP_LAYER.join([hparams.default_att] *
                                             hparams.num_hidden_layers)

            # If the encoder is not explicitly defined, it will have the same
            # structure as the decoder
            layer_types = layer_types.split(SEP_ENCODEC)
            if len(layer_types) == 1:
                layer_types *= 2

            # Some models don't need the encoder (e.g. language modeling)
            # TODO(epot): What are the other conditions (has_input ?)
            if hparams.prepend_mode != "none":
                layer_types[0] = ""

            # Extend the blocks and fill them with the default values if not specified
            final_layers = ([], [])
            for i, blocks_str in enumerate(layer_types):
                for blocks_str in blocks_str.split(SEP_LAYER):
                    if not blocks_str:
                        continue
                    blocks_list = blocks_str.split(SEP_FF)
                    # Fall back to the default values for the layer types if they
                    # are not specified. If the encoder is empty, do not use the
                    # enco-deco attention.
                    self_att = blocks_list[0] or hparams.default_att
                    ende_att = hparams.default_att if layer_types[0] else "_"
                    ff = hparams.default_ff
                    if len(blocks_list) > 1:
                        ff = blocks_list[-1]
                    if len(blocks_list) == 3:
                        ende_att = blocks_list[1]
                    if i == 0:  # Encoder
                        blocks_tuple = (self_att, ff)
                    elif i == 1:  # Decoder
                        blocks_tuple = (self_att, ende_att, ff)
                    final_layers[i].append(blocks_tuple)

            return final_layers

        # ========= Construct the transformer encoder and decoder =========

        encoder_layers, decoder_layers = extract_layer_types(
            hparams.layer_types)

        # Display the encoder-decoder architecture
        def print_layer(name, layers):
            tf.logging.info("{} architecture:".format(name))
            for i, l in enumerate(layers):
                tf.logging.info(" * Layer {}: {}".format(i, " - ".join(l)))

        print_layer("Encoder", encoder_layers)
        print_layer("Decoder", decoder_layers)

        encoder_outputs = []

        x = encoder_input
        with tf.variable_scope("encoder"):
            for layer_num, block_types in enumerate(encoder_layers):
                # Each encoder layer is composed of two blocks:
                # * self-attention block
                # * feed-forward block
                att_type, ff_type = block_types
                with tf.variable_scope("layer_{}".format(layer_num)):
                    with tf.variable_scope("att_{}".format(att_type)):
                        x = available_layers[att_type](
                            x,
                            memory_antecedent=None,
                        )
                    with tf.variable_scope("ff_{}".format(ff_type)):
                        x = available_layers[ff_type](x)
                encoder_outputs.append(x)
            if encoder_outputs:
                encoder_outputs[-1] = dp_preprocess(x)

        x = decoder_input
        with tf.variable_scope("decoder"):
            for layer_num, block_types in enumerate(decoder_layers):
                # Each decoder layer is composed of three blocks:
                # * self-attention block
                # * enco-deco attention block (optional)
                # * feed-forward block
                self_att_type, att_ende_type, ff_type = block_types
                with tf.variable_scope("layer_{}".format(layer_num)):
                    with tf.variable_scope(
                            "self_att_{}".format(self_att_type)):
                        x = available_layers[self_att_type](
                            x,
                            memory_antecedent=None,
                        )
                    with tf.variable_scope(
                            "att_ende_{}".format(att_ende_type)):
                        # Only add the enco-deco attention layer if there is an encoder
                        if encoder_outputs:
                            x = available_layers[att_ende_type](
                                x,
                                memory_antecedent=encoder_outputs[-1],
                            )
                    with tf.variable_scope("ff_{}".format(ff_type)):
                        x = available_layers[ff_type](x)
            # If normalization is done in layer_preprocess, then it should also be
            # done on the output, since the output can grow very large, being the sum
            # of a whole stack of unnormalized layer outputs.
            x = dp_preprocess(x)
        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, cache["extra_loss"]
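The layer registry above is assembled by stacking small decorators (dp_wrapper, add_kwargs, prepostprocess, capture_extra_loss) around the raw layer functions, and the comments stress that the order of wrapping matters. A self-contained sketch of the same composition idea with hypothetical wrappers (these are not the tensor2tensor helpers):

import functools

def logged(fct):
    # Outermost wrapper sees the call first, like prepostprocess in the model above.
    @functools.wraps(fct)
    def decorated(*args, **kwargs):
        print("calling", fct.__name__, kwargs)
        return fct(*args, **kwargs)
    return decorated

def with_default(fct, **defaults):
    # Injects keyword arguments, like add_kwargs injecting the attention bias.
    @functools.wraps(fct)
    def decorated(*args, **kwargs):
        kwargs = dict(defaults, **kwargs)
        return fct(*args, **kwargs)
    return decorated

def layer(x, scale=1.0):
    return [v * scale for v in x]

layer = with_default(layer, scale=2.0)   # applied first (innermost)
layer = logged(layer)                    # applied last (outermost)
print(layer([1, 2, 3]))                  # prints the call, then [2.0, 4.0, 6.0]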
Example #20
    def model_fn_body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
        if hparams.use_inputs:
            decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
            decoder_self_attention_bias = None
        else:
            targets = sharded_features["targets"]
            targets = dp(tf.squeeze, targets, 2)
            (decoder_input, decoder_self_attention_bias,
             pad_remover) = dp(attention_lm_moe_prepare_decoder, targets,
                               hparams)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        x = dp(tf.nn.dropout, decoder_input,
               1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        if hparams.diet_experts:
            hsize, = moe_hidden_sizes

            def _diet_expert(x):
                return diet.diet_expert(x, hsize,
                                        diet.diet_adam_optimizer_params())

            expert_fn = _diet_expert
        else:
            expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   moe_hidden_sizes,
                                                   hparams.hidden_size)

        if (hparams.attention_type == AttentionType.LOCAL_EXPERTS
                and not hparams.use_inputs):
            # As preprocess and postprocess are called with a batch of size one
            # (all batches concatenated), we just make sure that batch norm is
            # not used (it should not be either way).
            assert hparams.norm_type != "batch"

            tf.logging.info(
                "Applying Padding Remover for the attention experts")

            dp_remove_pad = functools.partial(dp,
                                              remove_pad,
                                              pad_remover=pad_remover,
                                              mode=hparams.mode)
            dp_restore_pad = functools.partial(dp,
                                               restore_pad,
                                               ref_x=x,
                                               pad_remover=pad_remover,
                                               mode=hparams.mode)
        else:
            # Using identity function: No effect
            dp_remove_pad = lambda x: x
            dp_restore_pad = lambda x: x

        def print_shape(x, suffix, debug=False):
            # To help debugging, print the input/output shapes at inference and
            # eval. Inference for long sequences can take a long time, so this
            # helps to see the progression of the generation.
            if not debug and hparams.mode == ModeKeys.TRAIN:
                return x
            return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

        batch_coordinate = dp(get_batch_coordinate, x)
        batch_coordinate = dp_remove_pad(batch_coordinate)

        x = dp(print_shape, x, "in")
        x = dp_remove_pad(x)
        x = dp(print_shape, x, "in_flat")

        assert hparams.batch_size >= hparams.max_length

        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("attention_{}".format(
                        hparams.attention_type)):
                    if hparams.attention_type == AttentionType.MULTIHEAD:
                        y = dp(common_attention.multihead_attention,
                               preprocess(x),
                               None,
                               decoder_self_attention_bias,
                               hparams.attention_key_channels
                               or hparams.hidden_size,
                               hparams.attention_value_channels
                               or hparams.hidden_size,
                               hparams.hidden_size,
                               hparams.num_heads,
                               hparams.attention_dropout,
                               attention_type=("local_mask_right"
                                               if hparams.attention_local else
                                               "dot_product"),
                               name="decoder_self_attention")
                    elif hparams.attention_type == AttentionType.MEMORY_EFFICIENT:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(common_attention.
                               multihead_self_attention_memory_efficient,
                               x,
                               decoder_self_attention_bias,
                               hparams.num_heads,
                               name="decoder_self_attention")
                    elif hparams.attention_type == AttentionType.LOCAL_EXPERTS:
                        y, loss = dp(
                            common_attention.local_expert_attention,
                            preprocess(x),
                            k=hparams.attention_moe_k,
                            loss_coef=hparams.attention_load_balance,
                            attention_num_experts=hparams.
                            attention_num_experts,
                            train=hparams.mode == ModeKeys.TRAIN,
                            batch_coordinate=batch_coordinate,
                            mask_right=not hparams.use_inputs,
                            split_batch=bool(hparams.attention_split_batch),
                            attention_kq_size=hparams.attention_kq_size,
                            attention_v_size=hparams.attention_v_size)
                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss) / dp.n
                    else:
                        raise ValueError("Only {} supported for now.".format(
                            AttentionType.get_choices()))
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    elif hparams.memory_efficient_ffn:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(common_layers.conv_hidden_relu_memory_efficient,
                               x, hparams.filter_size)
                    else:
                        x_in = preprocess(x)
                        additional_conv_params = dict()
                        if hparams.use_sepconv:
                            # Restore padding so sequences don't attend to each
                            # other. restore_pad applies a reshape like ref_x to
                            # recover the original shape. This works here
                            # because the last dimension is constant between the
                            # output of attention and the original input, but
                            # that need not be the case in general.
                            x_in = dp_restore_pad(x_in)
                            additional_conv_params = dict(
                                padding="LEFT",
                                # Parameters copied from the transformer model
                                kernel_size=(3, 1),
                                second_kernel_size=(31, 1),
                            )
                        y = dp(common_layers.conv_hidden_relu,
                               x_in,
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.relu_dropout,
                               **additional_conv_params)
                        if hparams.use_sepconv:
                            y = dp_remove_pad(y)
                    x = postprocess(x, y)
        x = preprocess(x)

        x = dp_restore_pad(x)

        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
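
The dp_remove_pad / dp_restore_pad pair in this example drops padded positions
before the expert layers and scatters them back afterwards, so the experts
never see padding tokens. Below is a simplified, self-contained sketch of that
trick on a flattened [batch * length, depth] tensor; sketch_remove_pad and
sketch_restore_pad are hypothetical names for illustration, not the library's
PadRemover API.

# Simplified sketch of the pad-removal trick: rows corresponding to padding
# tokens are dropped before the expensive computation and scattered back
# afterwards (padding rows come back as zeros).
import tensorflow as tf


def sketch_remove_pad(x_flat, nonpad_ids):
    # x_flat: [batch * length, depth]; nonpad_ids: [num_nonpad] int32 indices
    # of the rows that hold real (non-padding) tokens.
    return tf.gather(x_flat, nonpad_ids)


def sketch_restore_pad(x_nonpad, nonpad_ids, total_rows, depth):
    # Scatter the processed rows back to their original positions. tf.stack
    # lets total_rows / depth be Python ints or scalar tensors.
    return tf.scatter_nd(
        indices=tf.expand_dims(nonpad_ids, axis=1),
        updates=x_nonpad,
        shape=tf.stack([total_rows, depth]))

In the example above, the same remove/restore calls also wrap batch_coordinate
and the sepconv feed-forward input, as the inline comments note.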