def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
  """Multi layer transformer, sharded by the data parallelism dp."""
  x = inputs
  extra_loss = tf.constant(0.0)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention
      if attention_type == AttentionType.LOCAL_2D:
        y = dp(local_attention_2d,
               dp(common_layers.layer_preprocess, x, hparams),
               hparams,
               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = dp(local_attention_1d,
               dp(common_layers.layer_preprocess, x, hparams),
               hparams,
               attention_type="local_mask_right",
               q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = dp(local_global_attention,
               dp(common_layers.layer_preprocess, x, hparams),
               self_attention_bias, hparams,
               q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOBAL:
        self_attention_bias = dp(get_self_attention_bias, x)
        y = dp(full_self_attention,
               dp(common_layers.layer_preprocess, x, hparams),
               self_attention_bias, hparams,
               q_padding="LEFT", kv_padding="LEFT")
      x = dp(common_layers.layer_postprocess, x, y, hparams)
      if enc_output is not None:
        y = dp(encdec_attention_1d,
               dp(common_layers.layer_preprocess, x, hparams),
               enc_output, None, hparams)
        x = dp(common_layers.layer_postprocess, x, y, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_decoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              ps_devices,
              dp(common_layers.layer_preprocess, x, hparams),
              hparams.mode == tf.estimator.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
          x = dp(common_layers.layer_postprocess, x, y, hparams)
        else:
          y = dp(ffn_layer,
                 dp(common_layers.layer_preprocess, x, hparams),
                 hparams)
          x = dp(common_layers.layer_postprocess, x, y, hparams)
  return dp(common_layers.layer_preprocess, x, hparams), extra_loss
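All of these examples use tensor2tensor's data-parallelism wrapper as dp: calling dp(fn, *args, **kwargs) runs fn once per device shard, splitting list arguments across shards and repeating scalar ones. Below is a minimal sketch of that calling convention, assuming the expert_utils.Parallelism class; the device names and tensor shapes are illustrative assumptions, not taken from the examples here.

import tensorflow as tf
from tensor2tensor.utils import expert_utils

# Two-shard data parallelism; each dp(...) call runs once per device.
dp = expert_utils.Parallelism(["/device:GPU:0", "/device:GPU:1"], reuse=True)

# One tensor per shard (shapes made up for illustration).
sharded_x = [tf.random_normal([8, 16, 64]) for _ in range(dp.n)]

# Mirrors the dp(tf.nn.dropout, x, 1.0 - dropout) calls above: the list
# argument is split per shard, the scalar keep-probability is repeated.
sharded_y = dp(tf.nn.dropout, sharded_x, 0.9)
assert len(sharded_y) == dp.n  # one output tensor per shard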
Example 2
def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
  """Multi layer transformer, sharded by the data parallelism dp."""
  x = inputs
  extra_loss = tf.constant(0.0)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in xrange(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention
      if attention_type == AttentionType.LOCAL_2D:
        y = dp(local_attention_2d,
               dp(common_layers.layer_preprocess, x, hparams),
               hparams,
               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = dp(local_attention_1d,
               dp(common_layers.layer_preprocess, x, hparams),
               hparams,
               attention_type="local_mask_right",
               q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = dp(local_global_attention,
               dp(common_layers.layer_preprocess, x, hparams),
               self_attention_bias, hparams,
               q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOBAL:
        self_attention_bias = dp(get_self_attention_bias, x)
        y = dp(full_self_attention,
               dp(common_layers.layer_preprocess, x, hparams),
               self_attention_bias, hparams,
               q_padding="LEFT", kv_padding="LEFT")
      x = dp(common_layers.layer_postprocess, x, y, hparams)
      if enc_output is not None:
        y = dp(encdec_attention_1d,
               dp(common_layers.layer_preprocess, x, hparams),
               enc_output, None, hparams)
        x = dp(common_layers.layer_postprocess, x, y, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_decoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              ps_devices,
              dp(common_layers.layer_preprocess, x, hparams),
              hparams.mode == tf.estimator.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
          x = dp(common_layers.layer_postprocess, x, y, hparams)
        else:
          y = dp(ffn_layer,
                 dp(common_layers.layer_preprocess, x, hparams),
                 hparams)
          x = dp(common_layers.layer_postprocess, x, y, hparams)
  return dp(common_layers.layer_preprocess, x, hparams), extra_loss
Example 3
def conv_experts(xs, hparams, dp, ps, padding, mask, layer_id):
  """Convolutions + Mixture-of-Experts layer."""
  del layer_id  # Unused.
  train = hparams.mode == tf.estimator.ModeKeys.TRAIN
  conv_out = dp(conv_res_step, xs, hparams, padding, mask)
  loss = 0.0
  moe_hidden_sizes = [hparams.filter_size]
  expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size, moe_hidden_sizes,
                                         hparams.hidden_size)
  moe_out, loss = expert_utils.distributed_moe(
      dp,
      ps,
      xs,
      train,
      input_size=hparams.hidden_size,
      expert_fn=expert_fn,
      num_experts=hparams.moe_num_experts,
      k=hparams.moe_k,
      loss_coef=1.0)
  return dp(residual_fn3, xs, moe_out, conv_out, hparams), loss
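The expert_fn built by expert_utils.ffn_expert_fn(input_size, hidden_sizes, output_size) is applied independently to every token that the gating network routes to a given expert; conceptually each expert is a small ReLU feed-forward network. The following is a rough single-hidden-layer sketch of what one such expert computes, written directly with TF1 ops; the function and layer names are assumptions for illustration, not the library's own variables.

def simple_ffn_expert(x, hidden_size, output_size):
  # x: [num_tokens_routed_to_this_expert, input_size]
  h = tf.layers.dense(x, hidden_size, activation=tf.nn.relu,
                      name="expert_hidden")
  # Project back to the model width expected by the residual connection.
  return tf.layers.dense(h, output_size, name="expert_out")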
Example 5
  def body_sharded(self, sharded_features):
    train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN
    dp = self._data_parallelism
    hparams = self._hparams

    def project_to_hidden(inputs):
      return common_layers.conv_block(
          inputs,
          hparams.hidden_size, [((1, 1), (3, 3))],
          first_relu=False,
          padding="SAME",
          force2d=True)

    def flatten(inputs):
      return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)

    # Project to hidden size if necessary.
    if (sharded_features["inputs"][0].get_shape().as_list()[-1] !=
        hparams.hidden_size):
      inputs = dp(project_to_hidden, sharded_features["inputs"])
    else:
      inputs = sharded_features["inputs"]

    inputs = dp(flatten, inputs)
    inputs_pad = dp(slicenet.embedding_to_padding, inputs)
    inputs_mask = dp(lambda x: 1.0 - x, inputs_pad)
    inputs_encoded = dp(common_layers.add_timing_signal, inputs)
    expert_loss = 0.0
    for i in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("enc_layer_%d" % i):
        inputs_encoded, moe_loss = conv_experts(inputs_encoded, hparams, dp,
                                                self._ps_devices, "SAME",
                                                inputs_mask, i)
        expert_loss += tf.reduce_mean(moe_loss) * hparams.moe_loss_coef

    # If we're just predicting a class, there is no use for a decoder; return.
    if isinstance(hparams.problems[self._problem_idx].target_modality,
                  modalities.ClassLabelModality):
      return inputs_encoded, tf.reduce_mean(expert_loss)

    # Decoder.
    inputs3d = dp(tf.squeeze, inputs, 2)
    inputs_encoded3d = dp(tf.squeeze, inputs_encoded, 2)
    encoder_padding = dp(common_attention.embedding_to_padding, inputs3d)
    encoder_attention_bias = dp(common_attention.attention_bias_ignore_padding,
                                encoder_padding)
    targets = dp(common_layers.flatten4d3d, sharded_features["targets"])
    target_space_emb = dp(slicenet.embed_target_space,
                          sharded_features["target_space_id"],
                          hparams.hidden_size)

    (decoder_input, decoder_self_attention_bias) = dp(prepare_decoder, targets,
                                                      target_space_emb)

    moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
    expert_fn = expert_utils.ffn_expert_fn(
        hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
    x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.dropout)
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("dec_layer_%d" % layer):
        with tf.variable_scope("attention"):
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              decoder_self_attention_bias,
              hparams.hidden_size,
              hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              name="decoder_self_attention")
          z = dp(
              common_attention.multihead_attention,
              y,
              inputs_encoded3d,
              encoder_attention_bias,
              hparams.hidden_size,
              hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              name="encdec_attention")
          x = dp(residual_fn3, x, y, z, hparams)
        with tf.variable_scope("ffn"):
          if str(layer) in hparams.moe_layers.split(","):
            y, moe_loss = expert_utils.distributed_moe(
                dp,
                self._ps_devices,
                x,
                train,
                input_size=hparams.hidden_size,
                expert_fn=expert_fn,
                num_experts=hparams.moe_num_experts,
                k=hparams.moe_k,
                loss_coef=hparams.moe_loss_coef)
            expert_loss += tf.reduce_mean(moe_loss)
          else:
            y = dp(
                common_layers.conv_hidden_relu,
                x,
                hparams.filter_size,
                hparams.hidden_size,
                dropout=hparams.dropout)
          x = dp(residual_fn2, x, y, hparams)

    x = dp(tf.expand_dims, x, 2)
    return x, tf.reduce_mean(expert_loss)
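Whether a given decoder layer uses the mixture-of-experts FFN or the dense conv_hidden_relu FFN is decided purely by string membership in hparams.moe_layers. A hypothetical configuration follows; the specific values are illustrative assumptions, not settings taken from the code above.

# Use MoE FFNs in decoder layers 1 and 3, dense FFNs everywhere else.
hparams.moe_layers = "1,3"
hparams.moe_num_experts = 16   # experts placed across the ps_devices
hparams.moe_k = 2              # top-k gating: route each token to 2 experts
hparams.moe_loss_coef = 1e-2   # weight of the expert load-balancing loss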
Example 6
    def model_fn_body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
        x = dp(tf.squeeze, sharded_features["inputs"], 2)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0
        ffn_hidden_sizes = [
            int(s) for s in hparams.ffn_hidden_sizes.split(",")
        ]
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        if hparams.mask_right:

            def _bias(x):
                return common_attention.attention_bias_lower_triangle(
                    tf.shape(x)[1])

            bias = dp(_bias, x)
        else:
            bias = tf.zeros([1, 1, 1, 1])
        if hparams.diet_experts:
            hsize, = moe_hidden_sizes

            def _diet_expert(x):
                return diet.diet_expert(x, hsize,
                                        diet.diet_adam_optimizer_params())

            expert_fn = _diet_expert
        else:
            expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   moe_hidden_sizes,
                                                   hparams.hidden_size)

        batch_coordinate = dp(get_batch_coordinate, x)

        layers = hparams.layers.strip(",").split(",")
        for layer_num, layer_type in enumerate(layers):
            with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
                if _should_preprocess(layer_type):
                    x = preprocess(x)
                if layer_type == "timing":
                    y = dp(common_attention.add_timing_signal_nd, x)
                elif layer_type == "pos_emb":
                    y = dp(common_attention.add_positional_embedding_nd,
                           x,
                           hparams.max_length,
                           name="pos_emb")
                elif layer_type == "att":
                    y = dp(
                        common_attention.multihead_attention,
                        x,
                        None,
                        bias,  # bias
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                elif layer_type == "att_grouped":
                    multiplicative_overhead = (
                        hparams.multiplicative_overhead
                        if hparams.mode == ModeKeys.TRAIN else
                        hparams.multiplicative_overhead_eval)
                    y, loss = dp(
                        common_attention.grouped_attention_multihead,
                        x,
                        x,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        num_groups=hparams.attention_num_groups,
                        memory_target_density=hparams.memory_target_density,
                        multiplicative_overhead=multiplicative_overhead,
                        make_image_summary=hparams.attention_image_summary,
                        mask_right=hparams.mask_right,
                    )
                    extra_loss += tf.add_n(loss) / dp.n
                elif layer_type == "att_memory_efficient":
                    assert hparams.layer_preprocess_sequence == "n"
                    y = dp(
                        common_attention.
                        multihead_self_attention_memory_efficient, x, bias,
                        hparams.num_heads)
                elif layer_type == "att_local":
                    y = dp(
                        common_attention.multihead_attention,
                        x,
                        None,
                        None,  # bias
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=("local_mask_right"
                                        if hparams.mask_right else
                                        "local_unmasked"),
                        block_length=hparams.local_attention_window,
                        block_width=hparams.local_attention_window)
                elif layer_type == "att_pseudolocal":
                    # This is an inefficient implementation of local attention, for the
                    # purpose of testing model quality.
                    def _pseudolocal_bias(x):
                        return common_attention.attention_bias_local(
                            tf.shape(x)[1], hparams.local_attention_window,
                            0 if hparams.mask_right else
                            hparams.local_attention_window)

                    pseudolocal_bias = dp(_pseudolocal_bias, x)
                    y = dp(
                        common_attention.multihead_attention, x, None,
                        pseudolocal_bias, hparams.attention_key_channels
                        or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size, hparams.hidden_size,
                        hparams.num_heads, hparams.attention_dropout)
                elif layer_type == "att_local_expert":
                    y, loss = dp(
                        common_attention.local_expert_attention,
                        x,
                        k=hparams.attention_moe_k,
                        loss_coef=hparams.attention_load_balance,
                        attention_num_experts=hparams.attention_num_experts,
                        train=hparams.mode == ModeKeys.TRAIN,
                        batch_coordinate=batch_coordinate,
                        mask_right=hparams.mask_right,
                        split_batch=bool(hparams.attention_split_batch),
                        attention_kq_size=hparams.attention_kq_size,
                        attention_v_size=hparams.attention_v_size)
                    # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                    extra_loss += tf.add_n(loss) / dp.n
                elif layer_type == "att_lsh":
                    if hparams.lsh_truncated:
                        attention_fn = common_attention.multihead_attention_sparse_truncated
                    else:
                        attention_fn = common_attention.multihead_attention_sparse_dot_prod
                    y, loss = dp(
                        attention_fn,
                        x,
                        None,
                        None,  # Bias is computed inside
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,

                        # Additional parameters
                        bi=[
                            common_attention.BatchInfo(
                                coordinates=batch_coordinate[i],
                                order=None,  # No future mask
                            ) for i in range(dp.n)
                        ],
                        use_map_fn=False,
                        experts_params=dict(nb_hyperplanes=4, ))
                    extra_loss += tf.add_n(loss) / dp.n
                elif layer_type == "moe":
                    y, loss = expert_utils.distributed_moe(
                        dp,
                        self._ps_devices,
                        x,
                        hparams.mode == ModeKeys.TRAIN,
                        input_size=hparams.hidden_size,
                        expert_fn=expert_fn,
                        num_experts=hparams.moe_num_experts,
                        k=hparams.moe_k,
                        loss_coef=hparams.moe_loss_coef)
                    extra_loss += loss
                elif layer_type == "ffn":
                    y = dp(
                        expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   ffn_hidden_sizes,
                                                   hparams.hidden_size),
                        dp(expert_utils.flatten_all_but_last, x))
                    y = dp(expert_utils.reshape_like, y, x)
                elif layer_type == "conv":
                    y = dp(
                        common_layers.conv1d,
                        x,
                        hparams.hidden_size,
                        hparams.kernel_height,
                        activation=tf.nn.relu,
                        padding="SAME",
                    )
                else:
                    assert False, "unknown sublayer %s" % layer_type
                if _should_postprocess(layer_type):
                    x = postprocess(x, y)
                else:
                    x = y
        x = preprocess(x)

        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
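In this body the architecture is driven entirely by the comma-separated hparams.layers string: each entry selects one of the layer_type branches handled above (timing, pos_emb, att, att_grouped, att_local, att_lsh, moe, ffn, conv, ...). A hypothetical override is shown below; the specific sequence is an illustrative assumption, not a recommended setting.

# Interleave a timing signal, self-attention, a mixture-of-experts block and
# a plain FFN; every name must match a branch handled in the loop above.
hparams.layers = "timing,att,moe,att,ffn"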
Example 7
    def model_fn_body_sharded(self, sharded_features):
        hparams = self._hparams
        dp = self._data_parallelism
        targets = sharded_features["targets"]
        inputs = sharded_features["inputs"]
        target_space = sharded_features["target_space_id"]

        inputs = dp(common_layers.flatten4d3d, inputs)
        targets = dp(common_layers.flatten4d3d, targets)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = dp(
            transformer.transformer_prepare_encoder,
            inputs, target_space, hparams)
        (decoder_input, decoder_self_attention_bias) = dp(
            transformer.transformer_prepare_decoder, targets, hparams)
        encoder_input = dp(tf.nn.dropout, encoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)
        decoder_input = dp(tf.nn.dropout, decoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0
        moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
        expert_fn = expert_utils.ffn_expert_fn(
            hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
        x = encoder_input
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("encoder_layer_%d" % layer):
                with tf.variable_scope("encoder_self_attention"):
                    y = dp(
                        common_attention.multihead_attention,
                        preprocess(x),
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers_encoder.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    else:
                        y = dp(
                            common_layers.conv_hidden_relu,
                            preprocess(x),
                            hparams.filter_size,
                            hparams.hidden_size,
                            dropout=hparams.relu_dropout)
                    x = postprocess(x, y)
        encoder_output = preprocess(x)
        x = decoder_input
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("decoder_layer_%d" % layer):
                with tf.variable_scope("decoder_self_attention"):
                    y = dp(
                        common_attention.multihead_attention,
                        preprocess(x),
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                    x = postprocess(x, y)
                with tf.variable_scope("encoder_decoder_attention"):
                    y = dp(
                        common_attention.multihead_attention,
                        preprocess(x),
                        encoder_output,
                        encoder_decoder_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers_decoder.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    else:
                        y = dp(
                            common_layers.conv_hidden_relu,
                            preprocess(x),
                            hparams.filter_size,
                            hparams.hidden_size,
                            dropout=hparams.relu_dropout)
                    x = postprocess(x, y)
        x = preprocess(x)
        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
Example 8
    def model_fn_body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
        if hparams.use_inputs:
            decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
            decoder_self_attention_bias = None
        else:
            targets = sharded_features["targets"]
            targets = dp(tf.squeeze, targets, 2)
            (decoder_input, decoder_self_attention_bias,
             pad_remover) = dp(attention_lm_moe_prepare_decoder, targets,
                               hparams)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        x = dp(tf.nn.dropout, decoder_input,
               1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        if hparams.diet_experts:
            hsize, = moe_hidden_sizes

            def _diet_expert(x):
                return diet.diet_expert(x, hsize,
                                        diet.diet_adam_optimizer_params())

            expert_fn = _diet_expert
        else:
            expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   moe_hidden_sizes,
                                                   hparams.hidden_size)

        if not hparams.use_inputs:
            # As preprocess and postprocess are called with a batch of size one (all
            # batches concatenated), we just make sure that batch norm is not used
            # (it should not be used either way).
            assert hparams.norm_type != "batch"

            tf.logging.info(
                "Applying Padding Remover for the attention experts")

            dp_remove_pad = functools.partial(dp,
                                              remove_pad,
                                              pad_remover=pad_remover,
                                              mode=hparams.mode)
            dp_restore_pad = functools.partial(dp,
                                               restore_pad,
                                               ref_x=x,
                                               pad_remover=pad_remover,
                                               mode=hparams.mode)
        else:
            # Using identity function: No effect
            dp_remove_pad = lambda x: x
            dp_restore_pad = lambda x: x

        if hparams.attention_exp_factor != 0:
            tf.logging.info(
                "Expand/compress tokens before sending them to experts")
            dp_expand_bc = lambda x: dp(  # pylint: disable=g-long-lambda
                expand_batch_coordinates, x, hparams.attention_exp_factor)
            dp_expand_x = lambda x: dp(  # pylint: disable=g-long-lambda
                common_attention.deconv_elems_1d, x, hparams.
                attention_exp_factor, hparams.attention_exp_inputdim)
            dp_compress_x = lambda x, l: dp(  # pylint: disable=g-long-lambda
                common_attention.conv_elems_1d, x, hparams.
                attention_exp_factor, l)
        else:
            dp_expand_bc = lambda x: x
            dp_expand_x = lambda x: x
            dp_compress_x = lambda x, l: x

        def print_shape(x, suffix, debug=False):
            # To help debugging, print the input/output shapes at inference and eval.
            # Inference for long sequences can take a long time, so this helps to
            # see the progression of the generation.
            if not debug and hparams.mode == ModeKeys.TRAIN:
                return x
            return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

        with tf.name_scope("batch_coordinate_preprocess"):
            batch_coordinate = dp(get_batch_coordinate, x)
            batch_coordinate = dp_remove_pad(batch_coordinate)
            batch_coordinate = dp_expand_bc(batch_coordinate)
            batch_order = dp(get_batch_coordinate, x, axis=-1)
            batch_order = dp_remove_pad(batch_order)
            batch_order = dp_expand_bc(batch_order)

        x = dp(print_shape, x, "in")

        assert hparams.batch_size >= hparams.max_length

        num_hidden_layers = (len(hparams.attention_layers)
                             or hparams.num_hidden_layers)
        for layer in xrange(num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):

                # Use the layer type defined in attention_layers
                if hparams.attention_layers:
                    attention_type = LAYER_SYMBOLS[
                        hparams.attention_layers[layer]]
                else:
                    attention_type = hparams.attention_type

                with tf.variable_scope("attention_{}".format(attention_type)):
                    if attention_type in [
                            AttentionType.MULTIHEAD,
                            AttentionType.MULTIHEAD_FULL
                    ]:
                        attention_dot_type = ("local_mask_right"
                                              if hparams.attention_local else
                                              "dot_product")
                        if attention_type == AttentionType.MULTIHEAD_FULL:
                            attention_dot_type = "dot_product"
                        y = dp(common_attention.multihead_attention,
                               preprocess(x),
                               None,
                               decoder_self_attention_bias,
                               hparams.attention_key_channels
                               or hparams.hidden_size,
                               hparams.attention_value_channels
                               or hparams.hidden_size,
                               hparams.hidden_size,
                               hparams.num_heads,
                               hparams.attention_dropout,
                               attention_type=attention_dot_type,
                               block_length=hparams.attention_block_length,
                               name="decoder_self_attention")
                    elif attention_type == AttentionType.SPARSE_MULTIHEAD:
                        x_in = preprocess(x)
                        x_in = dp_remove_pad(x_in)
                        y, loss_experts = dp(
                            common_attention.
                            multihead_attention_sparse_dot_prod,
                            x_in,
                            None,
                            None,  # Bias is computed inside
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,

                            # Additional parameters
                            bi=[
                                common_attention.BatchInfo(
                                    coordinates=batch_coordinate[i],
                                    order=batch_order[i],  # No future mask
                                ) for i in range(dp.n)
                            ],
                            use_map_fn=hparams.lsh_use_map_fn,
                            experts_params=dict(
                                nb_hyperplanes=hparams.lsh_num_hyperplanes, ),
                        )
                        y = dp_restore_pad(y)

                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss_experts) / dp.n
                    elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED:
                        x_in = preprocess(x)
                        y, loss_experts = dp(
                            common_attention.
                            multihead_attention_sparse_truncated,
                            x_in,
                            None,
                            None,  # Bias is computed inside
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,

                            # Additional parameters
                            bi=[
                                common_attention.BatchInfo(
                                    coordinates=batch_coordinate[i],
                                    order=batch_order[i],  # No future mask
                                ) for i in range(dp.n)
                            ],
                            mask_right=True,
                            experts_params=dict(
                                nb_hyperplanes=hparams.lsh_num_hyperplanes, ),
                        )

                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss_experts) / dp.n
                    elif attention_type == AttentionType.MEMORY_EFFICIENT:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(common_attention.
                               multihead_self_attention_memory_efficient,
                               x,
                               decoder_self_attention_bias,
                               hparams.num_heads,
                               name="decoder_self_attention")
                    elif attention_type == AttentionType.MULTIHEAD_REDUCED:
                        y = dp(
                            common_attention.multihead_self_attention_reduced,
                            preprocess(x),
                            factor=hparams.attention_red_factor,
                            reduction_type=hparams.attention_reduction_type,
                            nonlinearity=hparams.attention_nonlinearity,
                            multihead_params=dict(
                                total_key_depth=hparams.attention_key_channels
                                or hparams.hidden_size,
                                total_value_depth=hparams.
                                attention_value_channels
                                or hparams.hidden_size,
                                num_heads=hparams.num_heads,
                                dropout_rate=hparams.attention_dropout,
                            ))
                    elif attention_type == AttentionType.LOCAL_EXPERTS:
                        x_in = preprocess(x)
                        x_in = dp_remove_pad(x_in)
                        x_in = dp_expand_x(x_in)
                        y, loss = dp(
                            common_attention.local_expert_attention,
                            x_in,
                            k=hparams.attention_moe_k,
                            loss_coef=hparams.attention_load_balance,
                            attention_num_experts=hparams.
                            attention_num_experts,
                            train=hparams.mode == ModeKeys.TRAIN,
                            batch_coordinate=batch_coordinate,
                            mask_right=not hparams.use_inputs,
                            split_batch=bool(hparams.attention_split_batch),
                            attention_num_head=hparams.attention_num_head,
                            attention_kq_size=hparams.attention_kq_size,
                            attention_v_size=hparams.attention_v_size)
                        y = dp_compress_x(y, x[0].get_shape().as_list()[-1])
                        y = dp_restore_pad(y)
                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss) / dp.n
                    else:
                        raise ValueError("Only {} supported for now.".format(
                            AttentionType.get_choices()))
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    elif hparams.memory_efficient_ffn:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(common_layers.conv_hidden_relu_memory_efficient,
                               x, hparams.filter_size)
                    else:
                        additional_conv_params = dict()
                        if hparams.use_sepconv:
                            additional_conv_params = dict(
                                padding="LEFT",
                                # Parameters copied from the transformer model
                                kernel_size=(3, 1),
                                second_kernel_size=(31, 1),
                            )
                        y = dp(common_layers.conv_hidden_relu,
                               preprocess(x),
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.relu_dropout,
                               **additional_conv_params)
                    x = postprocess(x, y)
        x = preprocess(x)

        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
Example 9
    def model_fn_body_sharded(self, sharded_features):
        train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN
        dp = self._data_parallelism
        hparams = self._hparams

        def flatten(inputs):
            return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)

        inputs = dp(flatten, sharded_features["inputs"])
        inputs_pad = dp(slicenet.embedding_to_padding, inputs)
        inputs_mask = dp(lambda x: 1.0 - x, inputs_pad)
        inputs_encoded = dp(common_layers.add_timing_signal, inputs)
        expert_loss = 0.0
        for i in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("enc_layer_%d" % i):
                inputs_encoded, moe_loss = conv_experts(
                    inputs_encoded, hparams, dp, self._ps_devices, "SAME",
                    inputs_mask, i)
                expert_loss += tf.reduce_mean(moe_loss) * hparams.moe_loss_coef

        # If we're just predicting a class, there is no use for a decoder; return.
        if isinstance(hparams.problems[self._problem_idx].target_modality,
                      modalities.ClassLabelModality):
            return inputs_encoded, tf.reduce_mean(expert_loss)

        # Decoder.
        inputs3d = dp(tf.squeeze, inputs, 2)
        inputs_encoded3d = dp(tf.squeeze, inputs_encoded, 2)
        encoder_padding = dp(common_attention.embedding_to_padding, inputs3d)
        encoder_attention_bias = dp(
            common_attention.attention_bias_ignore_padding, encoder_padding)
        targets = dp(common_layers.flatten4d3d, sharded_features["targets"])
        target_space_emb = dp(slicenet.embed_target_space,
                              sharded_features["target_space_id"],
                              hparams.hidden_size)

        (decoder_input,
         decoder_self_attention_bias) = dp(prepare_decoder, targets,
                                           target_space_emb)

        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                               moe_hidden_sizes,
                                               hparams.hidden_size)
        x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.dropout)
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("dec_layer_%d" % layer):
                with tf.variable_scope("attention"):
                    y = dp(common_attention.multihead_attention,
                           x,
                           None,
                           decoder_self_attention_bias,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.num_heads,
                           hparams.attention_dropout,
                           name="decoder_self_attention")
                    z = dp(common_attention.multihead_attention,
                           y,
                           inputs_encoded3d,
                           encoder_attention_bias,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.num_heads,
                           hparams.attention_dropout,
                           name="encdec_attention")
                    x = dp(residual_fn3, x, y, z, hparams)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, moe_loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            x,
                            train,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        expert_loss += tf.reduce_mean(moe_loss)
                    else:
                        y = dp(common_layers.conv_hidden_relu,
                               x,
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.dropout)
                    x = dp(residual_fn2, x, y, hparams)

        x = dp(tf.expand_dims, x, 2)
        return x, tf.reduce_mean(expert_loss)
Example 10
  def body_sharded(self, sharded_features):
    # Remove dropout if not training
    hparams = self._hparams
    dp = self._data_parallelism
    if hparams.use_inputs:
      decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
      decoder_self_attention_bias = None
    else:
      targets = sharded_features["targets"]
      targets = dp(tf.squeeze, targets, 2)
      (decoder_input, decoder_self_attention_bias, pad_remover) = dp(
          attention_lm_moe_prepare_decoder, targets, hparams)

    def preprocess(x):
      return dp(common_layers.layer_preprocess, x, hparams)

    def postprocess(x, y):
      return dp(common_layers.layer_postprocess, x, y, hparams)

    x = dp(tf.nn.dropout, decoder_input,
           1.0 - hparams.layer_prepostprocess_dropout)
    extra_loss = 0.0
    moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
    if hparams.diet_experts:
      hsize, = moe_hidden_sizes

      def _diet_expert(x):
        return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())

      expert_fn = _diet_expert
    else:
      expert_fn = expert_utils.ffn_expert_fn(
          hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)

    if not hparams.use_inputs:
      # As preprocess and postprocess are called with a batch of size one (all
      # batches concatenated), we just make sure that batch norm is not used
      # (it should not be used either way).
      assert hparams.norm_type != "batch"

      tf.logging.info("Applying Padding Remover for the attention experts")

      dp_remove_pad = functools.partial(
          dp, remove_pad, pad_remover=pad_remover, mode=hparams.mode)
      dp_restore_pad = functools.partial(
          dp, restore_pad, ref_x=x, pad_remover=pad_remover, mode=hparams.mode)
    else:
      # Using identity function: No effect
      dp_remove_pad = lambda x: x
      dp_restore_pad = lambda x: x

    if hparams.attention_exp_factor != 0:
      tf.logging.info("Expand/compress tokens before sending them to experts")
      dp_expand_bc = lambda x: dp(  # pylint: disable=g-long-lambda
          expand_batch_coordinates,
          x,
          hparams.attention_exp_factor)
      dp_expand_x = lambda x: dp(  # pylint: disable=g-long-lambda
          common_attention.deconv_elems_1d,
          x,
          hparams.attention_exp_factor,
          hparams.attention_exp_inputdim)
      dp_compress_x = lambda x, l: dp(  # pylint: disable=g-long-lambda
          common_attention.conv_elems_1d,
          x,
          hparams.attention_exp_factor,
          l)
    else:
      dp_expand_bc = lambda x: x
      dp_expand_x = lambda x: x
      dp_compress_x = lambda x, l: x

    def print_shape(x, suffix, debug=False):
      # To help debugging, print the input/output shapes at inference and eval.
      # Inference for long sequences can take a long time, so this helps to
      # see the progression of the generation.
      if not debug and hparams.mode == ModeKeys.TRAIN:
        return x
      return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

    with tf.name_scope("batch_coordinate_preprocess"):
      batch_coordinate = dp(get_batch_coordinate, x)
      batch_coordinate = dp_remove_pad(batch_coordinate)
      batch_coordinate = dp_expand_bc(batch_coordinate)
      batch_order = dp(get_batch_coordinate, x, axis=-1)
      batch_order = dp_remove_pad(batch_order)
      batch_order = dp_expand_bc(batch_order)

    x = dp(print_shape, x, "in")

    assert hparams.batch_size >= hparams.max_length

    num_hidden_layers = (
        len(hparams.attention_layers) or hparams.num_hidden_layers)
    for layer in xrange(num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):

        # Use the layer type defined in attention_layers
        if hparams.attention_layers:
          attention_type = LAYER_SYMBOLS[hparams.attention_layers[layer]]
        else:
          attention_type = hparams.attention_type

        with tf.variable_scope(
            "attention_{}".format(attention_type)):
          if attention_type in [
              AttentionType.MULTIHEAD, AttentionType.MULTIHEAD_FULL]:
            attention_dot_type = (
                "local_mask_right" if hparams.attention_local else
                "dot_product")
            if attention_type == AttentionType.MULTIHEAD_FULL:
              attention_dot_type = "dot_product"
            y = dp(
                common_attention.multihead_attention,
                preprocess(x),
                None,
                decoder_self_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                attention_type=attention_dot_type,
                block_length=hparams.attention_block_length,
                name="decoder_self_attention")
          elif attention_type == AttentionType.SPARSE_MULTIHEAD:
            x_in = preprocess(x)
            x_in = dp_remove_pad(x_in)
            y, loss_experts = dp(
                common_attention.multihead_attention_sparse_dot_prod,
                x_in,
                None,
                None,  # Bias is computed inside
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,

                # Additional parameters
                bi=[common_attention.BatchInfo(
                    coordinates=batch_coordinate[i],
                    order=batch_order[i],  # No future mask
                ) for i in range(dp.n)],
                use_map_fn=hparams.lsh_use_map_fn,
                experts_params=dict(
                    nb_hyperplanes=hparams.lsh_num_hyperplanes,
                ),
            )
            y = dp_restore_pad(y)

            # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
            extra_loss += tf.add_n(loss_experts) / dp.n
          elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED:
            x_in = preprocess(x)
            y, loss_experts = dp(
                common_attention.multihead_attention_sparse_truncated,
                x_in,
                None,
                None,  # Bias is computed inside
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,

                # Additional parameters
                bi=[common_attention.BatchInfo(
                    coordinates=batch_coordinate[i],
                    order=batch_order[i],  # No future mask
                ) for i in range(dp.n)],
                mask_right=True,
                experts_params=dict(
                    nb_hyperplanes=hparams.lsh_num_hyperplanes,
                ),
            )

            # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
            extra_loss += tf.add_n(loss_experts) / dp.n
          elif attention_type == AttentionType.MEMORY_EFFICIENT:
            assert hparams.layer_preprocess_sequence == "n"
            y = dp(
                common_attention.multihead_self_attention_memory_efficient,
                x,
                decoder_self_attention_bias,
                hparams.num_heads,
                name="decoder_self_attention")
          elif attention_type == AttentionType.MULTIHEAD_REDUCED:
            y = dp(
                common_attention.multihead_self_attention_reduced,
                preprocess(x),
                factor=hparams.attention_red_factor,
                reduction_type=hparams.attention_reduction_type,
                nonlinearity=hparams.attention_nonlinearity,
                multihead_params=dict(
                    total_key_depth=
                    hparams.attention_key_channels or hparams.hidden_size,
                    total_value_depth=
                    hparams.attention_value_channels or hparams.hidden_size,
                    num_heads=hparams.num_heads,
                    dropout_rate=hparams.attention_dropout,
                ))
          elif attention_type == AttentionType.LOCAL_EXPERTS:
            x_in = preprocess(x)
            x_in = dp_remove_pad(x_in)
            x_in = dp_expand_x(x_in)
            y, loss = dp(
                common_attention.local_expert_attention,
                x_in,
                k=hparams.attention_moe_k,
                loss_coef=hparams.attention_load_balance,
                attention_num_experts=hparams.attention_num_experts,
                train=hparams.mode == ModeKeys.TRAIN,
                batch_coordinate=batch_coordinate,
                mask_right=not hparams.use_inputs,
                split_batch=bool(hparams.attention_split_batch),
                attention_num_head=hparams.attention_num_head,
                attention_kq_size=hparams.attention_kq_size,
                attention_v_size=hparams.attention_v_size)
            y = dp_compress_x(y, x[0].get_shape().as_list()[-1])
            y = dp_restore_pad(y)
            # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
            extra_loss += tf.add_n(loss) / dp.n
          else:
            raise ValueError("Only {} supported for now.".format(
                AttentionType.get_choices()))
          x = postprocess(x, y)
        with tf.variable_scope("ffn"):
          if str(layer) in hparams.moe_layers.split(","):
            y, loss = expert_utils.distributed_moe(
                dp,
                self._ps_devices,
                preprocess(x),
                hparams.mode == ModeKeys.TRAIN,
                input_size=hparams.hidden_size,
                expert_fn=expert_fn,
                num_experts=hparams.moe_num_experts,
                k=hparams.moe_k,
                loss_coef=hparams.moe_loss_coef)
            extra_loss += loss
          elif hparams.memory_efficient_ffn:
            assert hparams.layer_preprocess_sequence == "n"
            y = dp(
                common_layers.conv_hidden_relu_memory_efficient,
                x,
                hparams.filter_size)
          else:
            additional_conv_params = dict()
            if hparams.use_sepconv:
              additional_conv_params = dict(
                  padding="LEFT",
                  # Parameters copied from the transformer model
                  kernel_size=(3, 1),
                  second_kernel_size=(31, 1),
              )
            y = dp(
                common_layers.conv_hidden_relu,
                preprocess(x),
                hparams.filter_size,
                hparams.hidden_size,
                dropout=hparams.relu_dropout,
                **additional_conv_params
            )
          x = postprocess(x, y)
    x = preprocess(x)

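    # Add back the singleton axis 2 so the decoder output is 4-D again, matching
    # the shape of the sharded features.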
    decoder_output = dp(tf.expand_dims, x, 2)
    return decoder_output, extra_loss
Example no. 11

    def model_fn_body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
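        # Targets come in with a singleton axis 2; squeeze it out so the layers
        # below operate on [batch, length, hidden] tensors.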
        targets = sharded_features["targets"]
        targets = dp(tf.squeeze, targets, 2)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        (decoder_input,
         decoder_self_attention_bias) = dp(attention_lm_moe_prepare_decoder,
                                           targets, hparams)

        x = dp(tf.nn.dropout, decoder_input,
               1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
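        # Diet experts store their parameters with tensor2tensor's memory-saving
        # "diet" variables and optimizer; otherwise use a regular FFN expert.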
        if hparams.diet_experts:
            hsize, = moe_hidden_sizes

            def _diet_expert(x):
                return diet.diet_expert(x, hsize,
                                        diet.diet_adam_optimizer_params())

            expert_fn = _diet_expert
        else:
            expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   moe_hidden_sizes,
                                                   hparams.hidden_size)
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("attention_{}".format(
                        hparams.attention_moe_type)):
                    x = preprocess(x)
                    if hparams.attention_moe_type == AttentionMoeType.NONE:
                        y = dp(common_attention.multihead_attention,
                               x,
                               None,
                               decoder_self_attention_bias,
                               hparams.attention_key_channels
                               or hparams.hidden_size,
                               hparams.attention_value_channels
                               or hparams.hidden_size,
                               hparams.hidden_size,
                               hparams.num_heads,
                               hparams.attention_dropout,
                               name="decoder_self_attention")
                    elif hparams.attention_moe_type == AttentionMoeType.LOCAL:
                        y, loss = dp(
                            common_attention.local_expert_attention,
                            x,
                            k=2,
                            loss_coef=1e-2,
                            attention_num_experts=hparams.attention_num_experts,
                            train=hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                            mask_right=True,
                            attention_kq_size=hparams.attention_kq_size,
                            attention_v_size=hparams.attention_v_size)
                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss) / dp.n
                    else:
                        raise ValueError("Only {} supported for now.".format(
                            AttentionMoeType.get_choices()))
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    else:
                        y = dp(common_layers.conv_hidden_relu,
                               preprocess(x),
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.relu_dropout)
                    x = postprocess(x, y)
        x = preprocess(x)
        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
Example no. 12
    def model_fn_body_sharded(self, sharded_features):
        # Remove dropout if not training
        hparams = self._hparams
        dp = self._data_parallelism
        if hparams.use_inputs:
            decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
            decoder_self_attention_bias = None
        else:
            targets = sharded_features["targets"]
            targets = dp(tf.squeeze, targets, 2)
            (decoder_input, decoder_self_attention_bias,
             pad_remover) = dp(attention_lm_moe_prepare_decoder, targets,
                               hparams)

        def preprocess(x):
            return dp(common_layers.layer_preprocess, x, hparams)

        def postprocess(x, y):
            return dp(common_layers.layer_postprocess, x, y, hparams)

        x = dp(tf.nn.dropout, decoder_input,
               1.0 - hparams.layer_prepostprocess_dropout)
        extra_loss = 0.0
        moe_hidden_sizes = [
            int(s) for s in hparams.moe_hidden_sizes.split(",")
        ]
        if hparams.diet_experts:
            hsize, = moe_hidden_sizes

            def _diet_expert(x):
                return diet.diet_expert(x, hsize,
                                        diet.diet_adam_optimizer_params())

            expert_fn = _diet_expert
        else:
            expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                                   moe_hidden_sizes,
                                                   hparams.hidden_size)

        if (hparams.attention_type == AttentionType.LOCAL_EXPERTS
                and not hparams.use_inputs):
            # As preprocess and postprocess are called with a batch of size one (all
            # batches concatenated), we just make sure that batch norm is not used
            # (it should not be used either way).
            assert hparams.norm_type != "batch"

            tf.logging.info(
                "Applying Padding Remover for the attention experts")

            dp_remove_pad = functools.partial(dp,
                                              remove_pad,
                                              pad_remover=pad_remover,
                                              mode=hparams.mode)
            dp_restore_pad = functools.partial(dp,
                                               restore_pad,
                                               ref_x=x,
                                               pad_remover=pad_remover,
                                               mode=hparams.mode)
        else:
            # Using identity function: No effect
            dp_remove_pad = lambda x: x
            dp_restore_pad = lambda x: x

        def print_shape(x, suffix, debug=False):
            # To help debugging, print the input/output shapes at inference and eval
            # time. Inference for long sequences can take a long time, so this helps
            # to follow the progression of the generation.
            if not debug and hparams.mode == ModeKeys.TRAIN:
                return x
            return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

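        # Tag every position with the index of the example it belongs to, so that
        # expert attention never mixes tokens from different examples once the
        # padding is removed and the shard is flattened.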
        batch_coordinate = dp(get_batch_coordinate, x)
        batch_coordinate = dp_remove_pad(batch_coordinate)

        x = dp(print_shape, x, "in")
        x = dp_remove_pad(x)
        x = dp(print_shape, x, "in_flat")

        assert hparams.batch_size >= hparams.max_length

        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("attention_{}".format(
                        hparams.attention_type)):
                    if hparams.attention_type == AttentionType.MULTIHEAD:
                        y = dp(common_attention.multihead_attention,
                               preprocess(x),
                               None,
                               decoder_self_attention_bias,
                               hparams.attention_key_channels
                               or hparams.hidden_size,
                               hparams.attention_value_channels
                               or hparams.hidden_size,
                               hparams.hidden_size,
                               hparams.num_heads,
                               hparams.attention_dropout,
                               attention_type=("local_mask_right"
                                               if hparams.attention_local else
                                               "dot_product"),
                               name="decoder_self_attention")
                    elif hparams.attention_type == AttentionType.MEMORY_EFFICIENT:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(
                            common_attention.multihead_self_attention_memory_efficient,
                            x,
                            decoder_self_attention_bias,
                            hparams.num_heads,
                            name="decoder_self_attention")
                    elif hparams.attention_type == AttentionType.LOCAL_EXPERTS:
                        y, loss = dp(
                            common_attention.local_expert_attention,
                            preprocess(x),
                            k=hparams.attention_moe_k,
                            loss_coef=hparams.attention_load_balance,
                            attention_num_experts=hparams.attention_num_experts,
                            train=hparams.mode == ModeKeys.TRAIN,
                            batch_coordinate=batch_coordinate,
                            mask_right=not hparams.use_inputs,
                            split_batch=bool(hparams.attention_split_batch),
                            attention_kq_size=hparams.attention_kq_size,
                            attention_v_size=hparams.attention_v_size)
                        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
                        extra_loss += tf.add_n(loss) / dp.n
                    else:
                        raise ValueError("Only {} supported for now.".format(
                            AttentionType.get_choices()))
                    x = postprocess(x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, loss = expert_utils.distributed_moe(
                            dp,
                            self._ps_devices,
                            preprocess(x),
                            hparams.mode == ModeKeys.TRAIN,
                            input_size=hparams.hidden_size,
                            expert_fn=expert_fn,
                            num_experts=hparams.moe_num_experts,
                            k=hparams.moe_k,
                            loss_coef=hparams.moe_loss_coef)
                        extra_loss += loss
                    elif hparams.memory_efficient_ffn:
                        assert hparams.layer_preprocess_sequence == "n"
                        y = dp(common_layers.conv_hidden_relu_memory_efficient,
                               x, hparams.filter_size)
                    else:
                        x_in = preprocess(x)
                        additional_conv_params = dict()
                        if hparams.use_sepconv:
                            # Restore padding so sequences don't attend to each other.
                            # restore_pad applies a reshape like ref_x to recover the
                            # original shape. This works here because the last dimension
                            # is the same for the attention output and the original
                            # input, but that need not be the case in general.
                            x_in = dp_restore_pad(x_in)
                            additional_conv_params = dict(
                                padding="LEFT",
                                # Parameters copied from the transformer model
                                kernel_size=(3, 1),
                                second_kernel_size=(31, 1),
                            )
                        y = dp(common_layers.conv_hidden_relu,
                               x_in,
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.relu_dropout,
                               **additional_conv_params)
                        if hparams.use_sepconv:
                            y = dp_remove_pad(y)
                    x = postprocess(x, y)
        x = preprocess(x)

        x = dp_restore_pad(x)

        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
Example no. 13
  def body_sharded(self, sharded_features):
    # Remove dropout if not training
    hparams = self._hparams
    dp = self._data_parallelism
    x = dp(tf.squeeze, sharded_features["inputs"], 2)

    def preprocess(x):
      return dp(common_layers.layer_preprocess, x, hparams)

    def postprocess(x, y):
      return dp(common_layers.layer_postprocess, x, y, hparams)

    x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
    extra_loss = 0.0
    ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")]
    moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
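    # For left-to-right prediction, use a lower-triangular bias so each position
    # only attends to earlier positions; otherwise use a zero (no-op) bias.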
    if hparams.mask_right:

      def _bias(x):
        return common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(x)[1])

      bias = dp(_bias, x)
    else:
      bias = tf.zeros([1, 1, 1, 1])
    if hparams.diet_experts:
      hsize, = moe_hidden_sizes

      def _diet_expert(x):
        return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())

      expert_fn = _diet_expert
    else:
      expert_fn = expert_utils.ffn_expert_fn(
          hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)

    batch_coordinate = dp(get_batch_coordinate, x)

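    # The architecture is given by hparams.layers as a comma-separated list of
    # sublayer names (e.g. "timing", "att", "moe", "ffn", "conv"), applied in order.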
    layers = hparams.layers.strip(",").split(",")
    for layer_num, layer_type in enumerate(layers):
      with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
        if _should_preprocess(layer_type):
          x = preprocess(x)
        if layer_type == "timing":
          y = dp(common_attention.add_timing_signal_nd, x)
        elif layer_type == "pos_emb":
          y = dp(
              common_attention.add_positional_embedding_nd,
              x,
              hparams.max_length,
              name="pos_emb")
        elif layer_type == "att":
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              bias,  # bias
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout)
        elif layer_type == "att_grouped":
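          # Grouped attention restricts each query to a learned group of memory
          # positions and returns an auxiliary loss; the allowed compute overhead
          # differs between training and evaluation.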
          multiplicative_overhead = (
              hparams.multiplicative_overhead if hparams.mode == ModeKeys.TRAIN
              else hparams.multiplicative_overhead_eval)
          y, loss = dp(
              common_attention.grouped_attention_multihead,
              x,
              x,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              num_groups=hparams.attention_num_groups,
              memory_target_density=hparams.memory_target_density,
              multiplicative_overhead=multiplicative_overhead,
              make_image_summary=hparams.attention_image_summary,
              mask_right=hparams.mask_right,
          )
          extra_loss += tf.add_n(loss) / dp.n
        elif layer_type == "att_memory_efficient":
          assert hparams.layer_preprocess_sequence == "n"
          y = dp(common_attention.multihead_self_attention_memory_efficient, x,
                 bias, hparams.num_heads)
        elif layer_type == "att_local":
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              None,  # bias
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=("local_mask_right"
                              if hparams.mask_right else "local_unmasked"),
              block_length=hparams.local_attention_window,
              block_width=hparams.local_attention_window)
        elif layer_type == "att_pseudolocal":
          # This is an inefficient implementation of local attention, for the
          # purpose of testing model quality.
          def _pseudolocal_bias(x):
            return common_attention.attention_bias_local(
                common_layers.shape_list(x)[1], hparams.local_attention_window,
                0 if hparams.mask_right else hparams.local_attention_window)

          pseudolocal_bias = dp(_pseudolocal_bias, x)
          y = dp(common_attention.multihead_attention, x, None,
                 pseudolocal_bias, hparams.attention_key_channels or
                 hparams.hidden_size, hparams.attention_value_channels or
                 hparams.hidden_size, hparams.hidden_size, hparams.num_heads,
                 hparams.attention_dropout)
        elif layer_type == "att_local_expert":
          y, loss = dp(
              common_attention.local_expert_attention,
              x,
              k=hparams.attention_moe_k,
              loss_coef=hparams.attention_load_balance,
              attention_num_experts=hparams.attention_num_experts,
              train=hparams.mode == ModeKeys.TRAIN,
              batch_coordinate=batch_coordinate,
              mask_right=hparams.mask_right,
              split_batch=bool(hparams.attention_split_batch),
              attention_kq_size=hparams.attention_kq_size,
              attention_v_size=hparams.attention_v_size)
          # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
          extra_loss += tf.add_n(loss) / dp.n
        elif layer_type == "att_lsh":
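          # LSH attention buckets tokens using hyperplane hashing and attends
          # within buckets; both variants return an auxiliary loss that is
          # averaged over the shards.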
          if hparams.lsh_truncated:
            attention_fn = common_attention.multihead_attention_sparse_truncated
          else:
            attention_fn = common_attention.multihead_attention_sparse_dot_prod
          y, loss = dp(
              attention_fn,
              x,
              None,
              None,  # Bias is computed inside
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,

              # Additional parameters
              bi=[
                  common_attention.BatchInfo(
                      coordinates=batch_coordinate[i],
                      order=None,  # No future mask
                  ) for i in range(dp.n)
              ],
              use_map_fn=False,
              experts_params=dict(nb_hyperplanes=4,))
          extra_loss += tf.add_n(loss) / dp.n
        elif layer_type == "moe":
          y, loss = expert_utils.distributed_moe(
              dp,
              self._ps_devices,
              x,
              hparams.mode == ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
        elif layer_type == "ffn":
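          # Flatten all but the last dimension, apply the feed-forward expert,
          # then reshape back to the shape of x.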
          y = dp(
              expert_utils.ffn_expert_fn(hparams.hidden_size, ffn_hidden_sizes,
                                         hparams.hidden_size),
              dp(expert_utils.flatten_all_but_last, x))
          y = dp(expert_utils.reshape_like, y, x)
        elif layer_type == "conv":
          y = dp(
              common_layers.conv1d,
              x,
              hparams.hidden_size,
              hparams.kernel_height,
              activation=tf.nn.relu,
              padding="SAME",
          )
        else:
          assert False, "unknown sublayer %s" % layer_type
        if _should_postprocess(layer_type):
          x = postprocess(x, y)
        else:
          x = y
    x = preprocess(x)

    decoder_output = dp(tf.expand_dims, x, 2)
    return decoder_output, extra_loss