def model_fn_body_sharded(self, sharded_features, train):
        # Remove dropout if not training
        hparams = copy.copy(self._hparams)
        if not train:
            hparams.attention_dropout = 0.
            hparams.relu_dropout = 0.
            hparams.residual_dropout = 0.
        dp = self._data_parallelism
        targets = sharded_features["targets"]
        targets = dp(tf.squeeze, targets, 2)

        (decoder_input,
         decoder_self_attention_bias) = dp(attention_lm_moe_prepare_decoder,
                                           targets, hparams)

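        # Post-norm residual connection: dropout the sublayer output, add the
        # shortcut, then apply layer normalization.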
        def residual_fn(x, y):
            return common_layers.layer_norm(
                x + tf.nn.dropout(y, 1.0 - hparams.residual_dropout))

        x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.residual_dropout)
        extra_loss = 0.0
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("attention"):
                    y = dp(common_attention.multihead_attention,
                           x,
                           None,
                           decoder_self_attention_bias,
                           hparams.attention_key_channels
                           or hparams.hidden_size,
                           hparams.attention_value_channels
                           or hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.num_heads,
                           hparams.attention_dropout,
                           summaries=True,
                           name="decoder_self_attention")
                    x = dp(residual_fn, x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, loss = common_layers.moe_layer(
                            dp, self._ps_devices, x, train,
                            hparams.hidden_size, hparams.moe_hidden_size,
                            hparams.moe_n1, hparams.moe_n2,
                            hparams.moe_loss_coef)
                        extra_loss += loss
                    else:
                        y = dp(common_layers.conv_hidden_relu,
                               x,
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.relu_dropout)
                    x = dp(residual_fn, x, y)
        decoder_output = dp(tf.expand_dims, x, 2)
        return decoder_output, extra_loss
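All of the bodies in this listing are methods of tensor2tensor model classes and rely on module-level imports the snippets omit. A minimal sketch of those imports, assuming the early tensor2tensor layout in which common_layers and common_attention lived under tensor2tensor.models:

import copy

from six.moves import xrange  # the layer loops below use Python 2-style xrange
import tensorflow as tf

from tensor2tensor.models import common_attention
from tensor2tensor.models import common_layers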
Example 2
    def model_fn_body_sharded(self, sharded_features):
        hparams = self._hparams
        dp = self._data_parallelism
        targets = sharded_features["targets"]
        targets = dp(tf.squeeze, targets, 2)
        inputs = sharded_features["inputs"]
        inputs = dp(tf.squeeze, inputs, 2)

        decoder_input = dp(long_answer_prepare_decoder, inputs, targets,
                           hparams)

        def residual_fn(x, y):
            return common_layers.layer_norm(
                x + tf.nn.dropout(y, 1.0 - hparams.residual_dropout))

        x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.residual_dropout)
        extra_loss = 0.0
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("attention"):
                    y = dp(common_attention.multihead_attention,
                           x,
                           None,
                           None,
                           hparams.attention_key_channels
                           or hparams.hidden_size,
                           hparams.attention_value_channels
                           or hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.num_heads,
                           hparams.attention_dropout,
                           attention_type="local_mask_right",
                           block_length=hparams.block_length,
                           name="decoder_self_attention")
                    x = dp(residual_fn, x, y)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, loss = common_layers.moe_layer(
                            dp, self._ps_devices, x,
                            hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                            hparams.hidden_size, hparams.moe_hidden_size,
                            hparams.moe_n1, hparams.moe_n2,
                            hparams.moe_loss_coef)
                        extra_loss += loss
                    else:
                        y = dp(common_layers.conv_hidden_relu,
                               x,
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.relu_dropout)
                    x = dp(residual_fn, x, y)
        x = dp(long_answer_output, x, inputs)
        return x, extra_loss
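Unlike Example 1, this body passes None as the self-attention bias and instead restricts attention with attention_type="local_mask_right", so each position attends causally within blocks of hparams.block_length. Below is a runnable sketch of the hyperparameters this body reads; the names come from the snippet, but the values and the moe_n1/moe_n2 semantics are assumptions:

import tensorflow as tf

hparams = tf.contrib.training.HParams(
    block_length=256,        # window for causal local attention
    moe_layers="2",          # comma-separated layer indices that use MoE
    moe_n1=32,               # assumed: first-level expert count
    moe_n2=0,                # assumed: second-level experts (0 = flat MoE)
    moe_hidden_size=2048,    # hidden size inside each expert
    moe_loss_coef=1e-2,      # weight on the MoE load-balancing loss
)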
Example 3
def conv_experts(xs, hparams, dp, ps, padding, mask, layer_id):
    """Convolutions + Mixture-of-Experts layer."""
    del layer_id  # Unused.
    train = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
    conv_out = dp(conv_res_step, xs, hparams, padding, mask)
    moe_out, loss = common_layers.moe_layer(dp, ps, xs, train,
                                            hparams.hidden_size,
                                            hparams.filter_size,
                                            hparams.moe_n1, hparams.moe_n2,
                                            1.0)
    return dp(residual_fn3, xs, moe_out, conv_out, hparams), loss
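conv_experts and the decoder in Example 5 call residual_fn2 and residual_fn3, which are not included in these snippets. A plausible reconstruction, following the dropout-add-layer-norm pattern of the residual_fn defined in the earlier bodies (this is a guess, not the repo's actual definition):

import tensorflow as tf
from tensor2tensor.models import common_layers  # assumed early-t2t module path

def residual_fn2(x, y, hparams):
    """Hypothetical: dropout on the new branch, residual add, layer norm."""
    y = tf.nn.dropout(y, 1.0 - hparams.dropout)
    return common_layers.layer_norm(x + y)

def residual_fn3(x, y, z, hparams):
    """Hypothetical three-way variant used by conv_experts above."""
    y = tf.nn.dropout(y, 1.0 - hparams.dropout)
    z = tf.nn.dropout(z, 1.0 - hparams.dropout)
    return common_layers.layer_norm(x + y + z)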
Example 5
    def model_fn_body_sharded(self, sharded_features):
        train = self._hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
        dp = self._data_parallelism
        hparams = self._hparams

        def flatten(inputs):
            return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)

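        # Flatten the 4-D features, derive a non-padding mask from the
        # embeddings, and add positional timing signals before the encoder stack.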
        inputs = dp(flatten, sharded_features["inputs"])
        inputs_pad = dp(slicenet.embedding_to_padding, inputs)
        inputs_mask = dp(lambda x: 1.0 - x, inputs_pad)
        inputs_encoded = dp(common_layers.add_timing_signal, inputs)
        expert_loss = 0.0
        for i in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("enc_layer_%d" % i):
                inputs_encoded, moe_loss = conv_experts(
                    inputs_encoded, hparams, dp, self._ps_devices, "SAME",
                    inputs_mask, i)
                expert_loss += tf.reduce_mean(moe_loss) * hparams.moe_loss_coef

        # If we're only predicting a class, there is no need for a decoder; return early.
        if isinstance(hparams.problems[self._problem_idx].target_modality,
                      modalities.ClassLabelModality):
            return inputs_encoded, tf.reduce_mean(expert_loss)

        # Decoder.
        inputs3d = dp(tf.squeeze, inputs, 2)
        inputs_encoded3d = dp(tf.squeeze, inputs_encoded, 2)
        encoder_padding = dp(common_attention.embedding_to_padding, inputs3d)
        encoder_attention_bias = dp(
            common_attention.attention_bias_ignore_padding, encoder_padding)
        targets = dp(common_layers.flatten4d3d, sharded_features["targets"])
        target_space_emb = dp(slicenet.embed_target_space,
                              sharded_features["target_space_id"],
                              hparams.hidden_size)

        (decoder_input,
         decoder_self_attention_bias) = dp(prepare_decoder, targets,
                                           target_space_emb)

        x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.dropout)
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("dec_layer_%d" % layer):
                with tf.variable_scope("attention"):
                    y = dp(common_attention.multihead_attention,
                           x,
                           None,
                           decoder_self_attention_bias,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.num_heads,
                           hparams.attention_dropout,
                           name="decoder_self_attention")
                    z = dp(common_attention.multihead_attention,
                           y,
                           inputs_encoded3d,
                           encoder_attention_bias,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.hidden_size,
                           hparams.num_heads,
                           hparams.attention_dropout,
                           name="encdec_attention")
                    x = dp(residual_fn3, x, y, z, hparams)
                with tf.variable_scope("ffn"):
                    if str(layer) in hparams.moe_layers.split(","):
                        y, moe_loss = common_layers.moe_layer(
                            dp, self._ps_devices, x, train,
                            hparams.hidden_size, hparams.filter_size,
                            hparams.moe_n1, hparams.moe_n2,
                            hparams.moe_loss_coef)
                        expert_loss += tf.reduce_mean(moe_loss)
                    else:
                        y = dp(common_layers.conv_hidden_relu,
                               x,
                               hparams.filter_size,
                               hparams.hidden_size,
                               dropout=hparams.dropout)
                    x = dp(residual_fn2, x, y, hparams)

        x = dp(tf.expand_dims, x, 2)
        return x, tf.reduce_mean(expert_loss)
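Each body returns a pair of (sharded output, extra_loss); the framework is expected to add the scalar extra_loss, which carries the MoE load-balancing penalty, to the main objective. A hedged, self-contained sketch of that contract (the names below are illustrative, not the actual T2TModel internals):

import tensorflow as tf

def combine_losses(main_loss, extra_loss):
    # The auxiliary MoE loss is simply summed into the training objective.
    return main_loss + extra_loss

main_loss = tf.constant(2.3)    # stand-in for the per-token cross-entropy
extra_loss = tf.constant(0.01)  # stand-in for the returned MoE penalty
total_loss = combine_losses(main_loss, extra_loss)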