Example #1
def _super_stack(inputs,
                 attention_bias,
                 hparams,
                 mp,
                 padding="LEFT"):
  """A stack of super_lm layers.

  Args:
    inputs: a list of Tensors
    attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    mp: a Parallelism object
    padding: a string

  Returns:
    y: a list of Tensors
    extra_loss: an optional scalar
  """
  layers = hparams.layers.strip(",").split(",")
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  if hparams.diet_experts:
    hsize, = moe_hidden_sizes
    def _diet_expert(x):
      return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())
    expert_fn = _diet_expert
  else:
    expert_fn = expert_utils.ffn_expert_fn(
        hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  # scaled_dot_product_attention_simple uses a 3d attention bias
  # (no heads), whereas multihead_attention uses a 4d attention bias.
  attention_bias_3d = mp(tf.squeeze, attention_bias, 1)
  mix_size = int(hparams.mix_fraction * hparams.hidden_size)
  accumulator = inputs
  x = inputs
  extra_losses = []
  for layer_num, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
      tf.logging.info("%s_%d" % (layer_type, layer_num))
      if layer_type == "a":
        # accumulate
        accumulator = mp(tf.add, x, accumulator)
        x = accumulator
      elif layer_type == "n":
        # normalize
        x = mp(common_layers.apply_norm,
               x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon)
      elif layer_type == "d":
        # dropout
        x = mp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
      elif layer_type == "m":
        # mix across shards
        def _split(t):
          return tuple(tf.split(
              t, [mix_size, hparams.hidden_size - mix_size], 2))
        to_mix, to_keep = mp(_split, x)
        mixed = common_layers.all_reduce_ring(to_mix, mp)
        mixed = mp(tf.multiply, mixed, mp.n ** -0.5)
        x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
      elif layer_type == "att":
        # single-head attention
        q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="q_transform")
        x = mp(
            common_attention.scaled_dot_product_attention_simple,
            q, x, x, attention_bias_3d)
        x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="o_transform")
      elif layer_type == "multihead-att":
        # multi-head attention
        x = mp(
            common_attention.multihead_attention,
            x,
            None,
            attention_bias,  # bias
            hparams.multihead_attention_key_channels or hparams.hidden_size,
            hparams.multihead_attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.multihead_attention_num_heads,
            hparams.attention_dropout)
      elif layer_type == "ffn":
        x = mp(
            common_layers.dense_relu_dense, x,
            hparams.filter_size, hparams.hidden_size)
      elif layer_type == "conv":
        # convolution
        x = mp(
            common_layers.conv1d,
            x,
            hparams.hidden_size,
            hparams.kernel_height,
            activation=tf.nn.relu,
            padding=padding,
        )
      elif layer_type == "moe":
        # mixture of experts - each model shard has its own local MoE.
        x, loss = mp(
            expert_utils.local_moe,
            x,
            train=hparams.mode == tf.estimator.ModeKeys.TRAIN,
            expert_fn=expert_fn,
            num_experts=hparams.moe_num_experts,
            k=hparams.moe_k,
            loss_coef=hparams.moe_loss_coef)
        extra_losses.extend(loss)
      else:
        assert False, "unknown sublayer %s" % layer_type
  if extra_losses:
    extra_loss = tf.add_n(extra_losses)
  else:
    extra_loss = None
  return x, extra_loss
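The stack above is driven entirely by hparams.layers, a comma-separated string of sublayer codes ("a", "n", "d", "m", "att", "multihead-att", "ffn", "conv", "moe") that the loop dispatches on. Below is a minimal, framework-free sketch of that dispatch pattern; the toy sublayers and the example layer string are illustrative assumptions, not the tensor2tensor implementation.

# A minimal sketch of the layer-spec dispatch used in _super_stack above.
# The toy callables stand in for the real sublayers; only the parsing and
# the "a" (accumulate) logic mirror the code above.
def run_stack(layer_spec, x):
  sublayers = {
      "n": lambda v: v / max(abs(v), 1e-6),  # stand-in for "normalize"
      "d": lambda v: v * 0.9,                # stand-in for "dropout"
      "ffn": lambda v: max(v, 0.0) * 2.0,    # stand-in for a feed-forward layer
  }
  accumulator = x
  for layer_type in layer_spec.strip(",").split(","):
    if layer_type == "a":
      # accumulate: add the current value into a running residual
      accumulator = x + accumulator
      x = accumulator
    elif layer_type in sublayers:
      x = sublayers[layer_type](x)
    else:
      raise ValueError("unknown sublayer %s" % layer_type)
  return x

print(run_stack("n,ffn,d,a,n,ffn,d,a", 3.0))  # ~6.6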
Example #2
  def body(self, features):
    # Remove dropout if not training
    hparams = self._hparams
    ps_devices = self._ps_devices
    assert hparams.num_model_shards % len(ps_devices) == 0
    shards_per_device = hparams.num_model_shards // len(ps_devices)
    model_devices = [ps_devices[i // shards_per_device]
                     for i in range(hparams.num_model_shards)]
    print("model_devices = %s" % model_devices)
    mp = expert_utils.Parallelism(model_devices, reuse=False)
    vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
    # squeeze out channels, heights
    targets = features["targets_raw"]
    targets = tf.squeeze(targets, 3)
    targets = tf.squeeze(targets, 2)
    shifted_targets = common_layers.shift_right_2d(targets)
    # Bypass the symbol modality and use a different embedding on each shard.
    decoder_input = mp(
        common_layers.embedding, shifted_targets, vocab_size,
        hparams.hidden_size,
        multiplier=hparams.hidden_size**0.5,
        symbol_dropout_rate=hparams.symbol_dropout)
    decoder_self_attention_bias = mp(
        common_attention.attention_bias_lower_triangle,
        tf.shape(targets)[1])
    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = features["targets_segmentation"]
      targets_position = features["targets_position"]
      decoder_self_attention_bias = mp(
          tf.add, decoder_self_attention_bias,
          mp(common_attention.attention_bias_same_segment,
             targets_segmentation, targets_segmentation))
    else:
      targets_position = None

    if hparams.pos == "timing":
      if targets_position is None:
        decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)
      else:
        decoder_input = mp(
            common_attention.add_timing_signal_1d_given_position,
            decoder_input, targets_position)

    decoder_input = mp(
        tf.nn.dropout, decoder_input,
        1.0 - hparams.layer_prepostprocess_dropout)
    decoder_output, extra_loss = _super_stack(
        decoder_input, decoder_self_attention_bias, hparams, mp)
    # Bypass the symbol modality and compute logits directly.
    # We compute a different set of logits on each shard, and sum them.
    logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits")
    logits = common_layers.all_reduce_ring(logits, mp)
    logits = mp(tf.multiply, logits, mp.n ** -0.5)
    # We now have identical logits on all shards.
    # Shard 0 gets returned to the estimator.
    logits_shard_0 = logits[0]
    logits_shard_0 = tf.expand_dims(logits_shard_0, 2)
    logits_shard_0 = tf.expand_dims(logits_shard_0, 3)
    # On each device, we compute the loss for a part of the batch.
    # This is faster than computing the whole loss on one shard.
    mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0])
    def _loss_for_shard(logits, targets, shard):
      if mp.n > 1:
        logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
        targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
      return common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)
    num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
    # override training loss so that it is not computed externally.
    losses = {"training": tf.add_n(num) / tf.add_n(denom)}
    if extra_loss is not None:
      losses["extra"] = extra_loss
    return logits_shard_0, losses
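Two things are worth noting in the body above: every shard ends up with identical logits after the all-reduce (so only shard 0 is returned to the estimator), and the training loss is computed by splitting the batch across shards and summing per-shard (numerator, denominator) pairs. The NumPy sketch below illustrates only that second pattern; approximate_split and the toy padded cross-entropy are stand-ins for the common_layers functions, not their actual implementations.

# A NumPy sketch of the per-shard loss pattern: split the batch, get a
# (loss sum, token count) pair per shard, then divide the totals.
import numpy as np

def approximate_split(x, n):
  # Split the leading axis into n nearly equal parts.
  return np.array_split(x, n, axis=0)

def padded_cross_entropy(logits, targets, pad_id=0):
  # Toy version: returns (sum of token losses, number of non-pad tokens).
  probs = np.exp(logits - logits.max(-1, keepdims=True))
  probs /= probs.sum(-1, keepdims=True)
  picked = probs[np.arange(len(targets)), targets]
  weights = (targets != pad_id).astype(np.float64)
  return np.sum(-np.log(picked) * weights), np.sum(weights)

rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 5))              # batch of 8, vocab of 5
targets = np.array([1, 2, 0, 3, 4, 1, 0, 2])  # 0 is the padding id

n_shards = 2
pairs = [padded_cross_entropy(l, t)
         for l, t in zip(approximate_split(logits, n_shards),
                         approximate_split(targets, n_shards))]
num = sum(p[0] for p in pairs)
denom = sum(p[1] for p in pairs)
print("training loss:", num / denom)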
Example #3
def _layer_stack(mp,
                 inputs,
                 self_attention_bias,
                 layers,
                 hparams,
                 encoder_output=None,
                 encoder_decoder_attention_bias=None):
    """A stack of layers.

  Args:
    mp: a Parallelism object
    inputs: a list of Tensors
    self_attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    layers: a string
    hparams: hyperparameters for model
    encoder_output: optional list of tensors
    encoder_decoder_attention_bias: optional list of tensors

  Returns:
    y: a list of Tensors
  """
    layers = layers.strip(",").split(",")

    # scaled_dot_product_attention_simple uses a 3d attention bias
    # (no heads), whereas multihead_attention uses a 4d attention bias.
    self_attention_bias_3d = mp(tf.squeeze, self_attention_bias, 1)
    if encoder_decoder_attention_bias is not None:
        encoder_decoder_attention_bias_3d = mp(tf.squeeze,
                                               encoder_decoder_attention_bias,
                                               1)
    relu_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "relu_dropout_broadcast_dims", "")))
    mix_size = int(hparams.mix_fraction * hparams.hidden_size)
    accumulator = inputs
    x = inputs
    for layer_num, layer_type in enumerate(layers):
        with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
            tf.logging.info("%s_%d" % (layer_type, layer_num))
            if layer_type == "a":
                # accumulate
                accumulator = mp(tf.add, x, accumulator)
                x = accumulator
            elif layer_type == "n":
                # normalize
                x = mp(common_layers.apply_norm, x, hparams.norm_type,
                       hparams.hidden_size, hparams.norm_epsilon)
            elif layer_type == "d":
                # dropout
                x = mp(tf.nn.dropout, x,
                       1.0 - hparams.layer_prepostprocess_dropout)
            elif layer_type == "m":
                if mix_size > 0:
                    # mix across shards
                    def _split(t):
                        return tuple(
                            tf.split(
                                t, [mix_size, hparams.hidden_size - mix_size],
                                2))

                    to_mix, to_keep = mp(_split, x)
                    mixed = common_layers.all_reduce_ring(to_mix, mp)
                    mixed = mp(tf.multiply, mixed, mp.n**-0.5)
                    x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
            elif layer_type == "att":
                # single-head attention
                q = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="q_transform")
                x = mp(common_attention.scaled_dot_product_attention_simple, q,
                       x, x, self_attention_bias_3d)
                x = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="o_transform")
            elif layer_type == "enc-att":
                # single-head attention over encoder
                q = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="q_transform")
                assert encoder_output is not None
                x = mp(common_attention.scaled_dot_product_attention_simple, q,
                       encoder_output, encoder_output,
                       encoder_decoder_attention_bias_3d)
                x = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="o_transform")
            elif layer_type == "multihead-att":
                # multi-head attention
                x = mp(
                    common_attention.multihead_attention,
                    x,
                    None,
                    self_attention_bias,  # bias
                    hparams.multihead_attention_key_channels
                    or hparams.hidden_size,
                    hparams.multihead_attention_value_channels
                    or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.multihead_attention_num_heads,
                    hparams.attention_dropout)
            elif layer_type == "enc-multihead-att":
                # multi-head attention
                x = mp(
                    common_attention.multihead_attention,
                    x,
                    encoder_output,
                    encoder_decoder_attention_bias,  # bias
                    hparams.multihead_attention_key_channels
                    or hparams.hidden_size,
                    hparams.multihead_attention_value_channels
                    or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.multihead_attention_num_heads,
                    hparams.attention_dropout)
            elif layer_type == "ffn":
                x = mp(common_layers.dense_relu_dense,
                       x,
                       hparams.filter_size,
                       hparams.hidden_size,
                       dropout=hparams.relu_dropout,
                       dropout_broadcast_dims=[relu_dropout_broadcast_dims] *
                       mp.n)
            else:
                assert False, "unknown sublayer %s" % layer_type
    return x
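The "m" sublayer above is the only point where the shards exchange information inside the stack: when mix_size > 0, each shard contributes its first mix_size channels to an all-reduce, the sum is rescaled by n ** -0.5, and the remaining channels stay local to the shard. A NumPy sketch of that exchange, with a plain sum standing in for all_reduce_ring and arbitrary shapes:

# A NumPy sketch of the "m" (mix across shards) sublayer.
import numpy as np

def mix_across_shards(shards, mix_size):
  n = len(shards)
  to_mix = [s[..., :mix_size] for s in shards]   # channels to exchange
  to_keep = [s[..., mix_size:] for s in shards]  # channels kept per shard
  mixed = sum(to_mix) * n ** -0.5                # stand-in for all_reduce_ring
  return [np.concatenate([mixed, keep], axis=-1) for keep in to_keep]

rng = np.random.default_rng(0)
shards = [rng.normal(size=(2, 3, 8)) for _ in range(4)]  # 4 shards, hidden_size=8
out = mix_across_shards(shards, mix_size=2)              # mix_fraction = 0.25
print([o.shape for o in out])                            # all still (2, 3, 8)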
Example #4
    def body(self, features):
        hparams = self._hparams
        ps_devices = self._ps_devices
        single_device = (len(ps_devices) == 1)
        assert hparams.num_model_shards % len(ps_devices) == 0
        shards_per_device = hparams.num_model_shards // len(ps_devices)
        model_devices = [
            ps_devices[i // shards_per_device]
            for i in range(hparams.num_model_shards)
        ]
        print("model_devices = %s" % model_devices)
        mp = expert_utils.Parallelism(model_devices, reuse=False)
        targets_vocab_size = self._problem_hparams.vocabulary[
            "targets"].vocab_size
        # squeeze out channels, heights
        targets = tf.squeeze(features["targets_raw"], [2, 3])
        targets_embedding_var = mp(
            tf.get_variable,
            "embedding", [[targets_vocab_size, hparams.hidden_size]] * mp.n,
            initializer=tf.random_normal_initializer(
                0.0, hparams.hidden_size**-0.5))
        shifted_targets = common_layers.shift_right_2d(targets)
        # Bypass the symbol modality and use a different embedding on each shard.
        if single_device:
            targets_embedding_var_combined = tf.concat(targets_embedding_var,
                                                       1)
            decoder_input_combined = common_layers.embedding(
                shifted_targets,
                targets_vocab_size,
                hparams.hidden_size * mp.n,
                multiplier=hparams.hidden_size**0.5,
                embedding_var=targets_embedding_var_combined,
            )
            decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
        else:
            targets_embedding_var_combined = None
            decoder_input = mp(
                common_layers.embedding,
                shifted_targets,
                targets_vocab_size,
                hparams.hidden_size,
                multiplier=hparams.hidden_size**0.5,
                embedding_var=targets_embedding_var,
            )
        decoder_self_attention_bias = mp(
            common_attention.attention_bias_lower_triangle,
            tf.shape(targets)[1])
        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = features["targets_segmentation"]
            targets_position = features["targets_position"]
            decoder_self_attention_bias = mp(
                tf.add, decoder_self_attention_bias,
                mp(common_attention.attention_bias_same_segment,
                   targets_segmentation, targets_segmentation))
            decoder_input = mp(
                common_attention.add_timing_signal_1d_given_position,
                decoder_input, targets_position)
        else:
            targets_position = None
            decoder_self_attention_bias = mp(
                common_attention.attention_bias_lower_triangle,
                tf.shape(targets)[1])
            decoder_input = mp(common_attention.add_timing_signal_1d,
                               decoder_input)

        if self.has_input:
            inputs = tf.squeeze(features["inputs_raw"], [2, 3])
            inputs_vocab_size = self._problem_hparams.vocabulary[
                "inputs"].vocab_size
            # share everything for now
            share_inputs_and_targets_embedding = True
            if share_inputs_and_targets_embedding:
                assert inputs_vocab_size == targets_vocab_size
                inputs_embedding_var = targets_embedding_var
                inputs_embedding_var_combined = targets_embedding_var_combined
            if single_device:
                encoder_input_combined = common_layers.embedding(
                    inputs,
                    inputs_vocab_size,
                    hparams.hidden_size * mp.n,
                    multiplier=hparams.hidden_size**0.5,
                    embedding_var=inputs_embedding_var_combined,
                )
                encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
            else:
                encoder_input = mp(
                    common_layers.embedding,
                    inputs,
                    inputs_vocab_size,
                    hparams.hidden_size,
                    multiplier=hparams.hidden_size**0.5,
                    embedding_var=inputs_embedding_var,
                )
            if "inputs_segmentation" in features:
                # "Packed" dataset - keep the examples from seeing each other.
                inputs_segmentation = features["inputs_segmentation"]
                inputs_position = features["inputs_position"]
                encoder_self_attention_bias = mp(
                    common_attention.attention_bias_same_segment,
                    inputs_segmentation, inputs_segmentation)
                encoder_decoder_attention_bias = mp(
                    common_attention.attention_bias_same_segment,
                    targets_segmentation, inputs_segmentation)
                encoder_input = mp(
                    common_attention.add_timing_signal_1d_given_position,
                    encoder_input, inputs_position)
            else:
                encoder_padding = tf.to_float(tf.equal(inputs, 0))
                ignore_padding = common_attention.attention_bias_ignore_padding(
                    encoder_padding)
                encoder_self_attention_bias = ignore_padding
                encoder_decoder_attention_bias = ignore_padding
                inputs_position = None
                encoder_input = mp(common_attention.add_timing_signal_1d,
                                   encoder_input)

            # encoder stack here
            with tf.variable_scope("encoder"):
                encoder_input = mp(tf.nn.dropout, encoder_input,
                                   1.0 - hparams.layer_prepostprocess_dropout)
                encoder_output = _layer_stack(mp, encoder_input,
                                              encoder_self_attention_bias,
                                              hparams.encoder_layers, hparams)
        else:
            encoder_decoder_attention_bias = None
            encoder_output = None

        with tf.variable_scope("decoder"):
            decoder_input = mp(tf.nn.dropout, decoder_input,
                               1.0 - hparams.layer_prepostprocess_dropout)
            decoder_output = _layer_stack(
                mp,
                decoder_input,
                decoder_self_attention_bias,
                layers=hparams.decoder_layers,
                hparams=hparams,
                encoder_output=encoder_output,
                encoder_decoder_attention_bias=encoder_decoder_attention_bias)

        # Bypass the symbol modality and compute logits directly.
        # We compute a different set of logits on each shard, and sum them.
        # Share the weights with the target embedding.
        output_var = targets_embedding_var
        output_var_combined = targets_embedding_var_combined
        if single_device:
            decoder_output = tf.concat(decoder_output, 2)
            logits = tf.tensordot(decoder_output, output_var_combined,
                                  [[2], [1]])
            num, denom = common_layers.padded_cross_entropy(
                logits, targets, hparams.label_smoothing)
            training_loss = num / denom
        else:
            logits = mp(tf.tensordot, decoder_output, output_var,
                        [[[2], [1]]] * mp.n)
            logits = common_layers.all_reduce_ring(logits, mp)
            # On each device, we compute the loss for a part of the batch.
            # This is faster than computing the whole loss on one shard.
            mp, logits = common_layers.reduce_by_device(
                mp, logits, lambda l: l[0])

            def _loss_for_shard(logits, targets, shard):
                logits = common_layers.approximate_split(logits, mp.n,
                                                         0)[shard]
                targets = common_layers.approximate_split(targets, mp.n,
                                                          0)[shard]
                return common_layers.padded_cross_entropy(
                    logits, targets, hparams.label_smoothing)

            num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
            training_loss = tf.add_n(num) / tf.add_n(denom)
            logits = logits[0]
        logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
        # override training loss so that it is not computed externally.
        losses = {"training": training_loss}
        return logits, losses
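In the body above the output projection is tied to the target embedding: on a single device the logits come from one tensordot against the concatenated embedding, while on multiple devices each shard computes partial logits against its own slice of the hidden dimension and the partial results are summed by the all-reduce. The NumPy check below shows that the two routes agree; the sizes are arbitrary and the plain sum stands in for all_reduce_ring.

# A NumPy check that summed per-shard tensordots against embedding slices
# equal one tensordot against the concatenated embedding.
import numpy as np

rng = np.random.default_rng(0)
n, hidden, vocab = 4, 8, 11
decoder_output = [rng.normal(size=(2, 5, hidden)) for _ in range(n)]  # per shard
embedding_var = [rng.normal(size=(vocab, hidden)) for _ in range(n)]  # per shard

# Multi-device route: per-shard partial logits, then an all-reduce (sum).
partial = [np.tensordot(d, e, axes=[[2], [1]])
           for d, e in zip(decoder_output, embedding_var)]
logits_sharded = sum(partial)

# Single-device route: concatenate along the hidden dimension, one tensordot.
combined_output = np.concatenate(decoder_output, axis=2)    # (2, 5, n * hidden)
combined_embedding = np.concatenate(embedding_var, axis=1)  # (vocab, n * hidden)
logits_combined = np.tensordot(combined_output, combined_embedding, [[2], [1]])

print(np.allclose(logits_sharded, logits_combined))         # True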
Example #5
def _super_stack(inputs, attention_bias, hparams, mp, padding="LEFT"):
    """A stack of super_lm layers.

  Args:
    inputs: a list of Tensors
    attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    mp: a Parallelism object
    padding: a string

  Returns:
    y: a Tensors
  """
    layers = hparams.layers.strip(",").split(",")
    ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")]
    # scaled_dot_product_attention_simple uses a 3d attention bias
    # (no heads), whereas multihead_attention uses a 4d attention bias.
    mix_size = int(hparams.mix_fraction * hparams.hidden_size)
    attention_bias_3d = mp(tf.squeeze, attention_bias, 1)
    accumulator = inputs
    x = inputs
    for layer_num, layer_type in enumerate(layers):
        with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
            tf.logging.info("%s_%d" % (layer_type, layer_num))
            if layer_type == "a":
                # accumulate
                accumulator = mp(tf.add, x, accumulator)
                x = accumulator
            elif layer_type == "n":
                # normalize
                x = mp(common_layers.apply_norm, x, hparams.norm_type,
                       hparams.hidden_size, hparams.norm_epsilon)
            elif layer_type == "d":
                # dropout
                x = mp(tf.nn.dropout, x,
                       1.0 - hparams.layer_prepostprocess_dropout)
            elif layer_type == "m":
                # mix across shards
                def _split(t):
                    return tuple(
                        tf.split(t, [mix_size, hparams.hidden_size - mix_size],
                                 2))

                to_mix, to_keep = mp(_split, x)
                mixed = common_layers.all_reduce_ring(to_mix, mp)
                mixed = mp(tf.multiply, mixed, mp.n**-0.5)
                x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
            elif layer_type == "att":
                # single-head attention
                q = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="q_transform")
                x = mp(common_attention.scaled_dot_product_attention_simple, q,
                       x, x, attention_bias_3d)
                x = mp(tf.layers.dense,
                       x,
                       hparams.hidden_size,
                       use_bias=False,
                       name="o_transform")
            elif layer_type == "multihead-att":
                # multi-head attention
                x = mp(
                    common_attention.multihead_attention,
                    x,
                    None,
                    attention_bias,  # bias
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout)
            elif layer_type == "ffn":
                y = mp(
                    expert_utils.ffn_expert_fn(hparams.hidden_size,
                                               ffn_hidden_sizes,
                                               hparams.hidden_size),
                    mp(expert_utils.flatten_all_but_last, x))
                x = mp(expert_utils.reshape_like, y, x)
            elif layer_type == "conv":
                # convolution
                x = mp(
                    common_layers.conv1d,
                    x,
                    hparams.hidden_size,
                    hparams.kernel_height,
                    activation=tf.nn.relu,
                    padding=padding,
                )
            else:
                assert False, "unknown sublayer %s" % layer_type
    return x

  def body(self, features):
    hparams = self._hparams
    ps_devices = self._ps_devices
    single_device = (len(ps_devices) == 1)
    assert hparams.num_model_shards % len(ps_devices) == 0
    shards_per_device = hparams.num_model_shards // len(ps_devices)
    model_devices = [ps_devices[i // shards_per_device]
                     for i in range(hparams.num_model_shards)]
    print("model_devices = %s" % model_devices)
    mp = expert_utils.Parallelism(model_devices, reuse=False)
    targets_vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
    # squeeze out channels, heights
    targets = tf.squeeze(features["targets_raw"], [2, 3])
    targets_embedding_var = mp(
        tf.get_variable, "embedding",
        [[targets_vocab_size, hparams.hidden_size]] * mp.n,
        initializer=tf.random_normal_initializer(
            0.0, hparams.hidden_size**-0.5))
    shifted_targets = common_layers.shift_right_2d(targets)
    # Bypass the symbol modality and use a different embedding on each shard.
    if single_device:
      targets_embedding_var_combined = tf.concat(targets_embedding_var, 1)
      decoder_input_combined = common_layers.embedding(
          shifted_targets, targets_vocab_size,
          hparams.hidden_size * mp.n,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=targets_embedding_var_combined,
      )
      decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
    else:
      targets_embedding_var_combined = None
      decoder_input = mp(
          common_layers.embedding, shifted_targets, targets_vocab_size,
          hparams.hidden_size,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=targets_embedding_var,
      )
    decoder_self_attention_bias = mp(
        common_attention.attention_bias_lower_triangle,
        tf.shape(targets)[1])
    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = features["targets_segmentation"]
      targets_position = features["targets_position"]
      decoder_self_attention_bias = mp(
          tf.add, decoder_self_attention_bias,
          mp(common_attention.attention_bias_same_segment,
             targets_segmentation, targets_segmentation))
      decoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          decoder_input, targets_position)
    else:
      targets_position = None
      decoder_self_attention_bias = mp(
          common_attention.attention_bias_lower_triangle,
          tf.shape(targets)[1])
      decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)

    if self.has_input:
      inputs = tf.squeeze(features["inputs_raw"], [2, 3])
      inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size
      # share everything for now
      share_inputs_and_targets_embedding = True
      if share_inputs_and_targets_embedding:
        assert inputs_vocab_size == targets_vocab_size
        inputs_embedding_var = targets_embedding_var
        inputs_embedding_var_combined = targets_embedding_var_combined
      if single_device:
        encoder_input_combined = common_layers.embedding(
            inputs, inputs_vocab_size,
            hparams.hidden_size * mp.n,
            multiplier=hparams.hidden_size**0.5,
            embedding_var=inputs_embedding_var_combined,
        )
        encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
      else:
        encoder_input = mp(
            common_layers.embedding, inputs, inputs_vocab_size,
            hparams.hidden_size,
            multiplier=hparams.hidden_size**0.5,
            embedding_var=inputs_embedding_var,
        )
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        encoder_self_attention_bias = mp(
            common_attention.attention_bias_same_segment,
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = mp(
            common_attention.attention_bias_same_segment,
            targets_segmentation, inputs_segmentation)
        encoder_input = mp(
            common_attention.add_timing_signal_1d_given_position,
            encoder_input, inputs_position)
      else:
        encoder_padding = tf.to_float(tf.equal(inputs, 0))
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
        encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input)

      # encoder stack here
      with tf.variable_scope("encoder"):
        encoder_input = mp(
            tf.nn.dropout, encoder_input,
            1.0 - hparams.layer_prepostprocess_dropout)
        encoder_output = _layer_stack(
            mp,
            encoder_input,
            encoder_self_attention_bias,
            hparams.encoder_layers,
            hparams)
    else:
      encoder_decoder_attention_bias = None
      encoder_output = None

    with tf.variable_scope("decoder"):
      decoder_input = mp(
          tf.nn.dropout, decoder_input,
          1.0 - hparams.layer_prepostprocess_dropout)
      decoder_output = _layer_stack(
          mp,
          decoder_input,
          decoder_self_attention_bias,
          layers=hparams.decoder_layers,
          hparams=hparams,
          encoder_output=encoder_output,
          encoder_decoder_attention_bias=encoder_decoder_attention_bias)

    # Bypass the symbol modality and compute logits directly.
    # We compute a different set of logits on each shard, and sum them.
    # Share the weights with the target embedding.
    output_var = targets_embedding_var
    output_var_combined = targets_embedding_var_combined
    if single_device:
      decoder_output = tf.concat(decoder_output, 2)
      logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]])
      num, denom = common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)
      training_loss = num / denom
    else:
      logits = mp(
          tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n)
      logits = common_layers.all_reduce_ring(logits, mp)
      # On each device, we compute the loss for a part of the batch.
      # This is faster than computing the whole loss on one shard.
      mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0])
      def _loss_for_shard(logits, targets, shard):
        logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
        targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
        return common_layers.padded_cross_entropy(
            logits, targets, hparams.label_smoothing)
      num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
      training_loss = tf.add_n(num) / tf.add_n(denom)
      logits = logits[0]
    logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
    # override training loss so that it is not computed externally.
    losses = {"training": training_loss}
    return logits, losses

def _layer_stack(mp,
                 inputs,
                 self_attention_bias,
                 layers,
                 hparams,
                 encoder_output=None,
                 encoder_decoder_attention_bias=None):
  """A stack of layers.

  Args:
    mp: a Parallelism object
    inputs: a list of Tensors
    self_attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    layers: a string
    hparams: hyperparameters for model
    encoder_output: optional list of tensors
    encoder_decoder_attention_bias: optional list of tensors

  Returns:
    y: a list of Tensors
  """
  layers = layers.strip(",").split(",")

  # scaled_dot_product_attention_simple uses a 3d attention bias
  # (no heads), whereas multihead_attention uses a 4d attention bias.
  self_attention_bias_3d = mp(tf.squeeze, self_attention_bias, 1)
  if encoder_decoder_attention_bias is not None:
    encoder_decoder_attention_bias_3d = mp(
        tf.squeeze, encoder_decoder_attention_bias, 1)
  relu_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "relu_dropout_broadcast_dims", "")))
  mix_size = int(hparams.mix_fraction * hparams.hidden_size)
  accumulator = inputs
  x = inputs
  for layer_num, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
      tf.logging.info("%s_%d" % (layer_type, layer_num))
      if layer_type == "a":
        # accumulate
        accumulator = mp(tf.add, x, accumulator)
        x = accumulator
      elif layer_type == "n":
        # normalize
        x = mp(common_layers.apply_norm,
               x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon)
      elif layer_type == "d":
        # dropout
        x = mp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
      elif layer_type == "m":
        if mix_size > 0:
          # mix across shards
          def _split(t):
            return tuple(tf.split(
                t, [mix_size, hparams.hidden_size - mix_size], 2))
          to_mix, to_keep = mp(_split, x)
          mixed = common_layers.all_reduce_ring(to_mix, mp)
          mixed = mp(tf.multiply, mixed, mp.n ** -0.5)
          x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
      elif layer_type == "att":
        # single-head attention
        q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="q_transform")
        x = mp(
            common_attention.scaled_dot_product_attention_simple,
            q, x, x, self_attention_bias_3d)
        x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="o_transform")
      elif layer_type == "enc-att":
        # single-head attention over encoder
        q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="q_transform")
        assert encoder_output is not None
        x = mp(
            common_attention.scaled_dot_product_attention_simple,
            q, encoder_output, encoder_output,
            encoder_decoder_attention_bias_3d)
        x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="o_transform")
      elif layer_type == "multihead-att":
        # multi-head attention
        x = mp(
            common_attention.multihead_attention,
            x,
            None,
            self_attention_bias,  # bias
            hparams.multihead_attention_key_channels or hparams.hidden_size,
            hparams.multihead_attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.multihead_attention_num_heads,
            hparams.attention_dropout)
      elif layer_type == "enc-multihead-att":
        # multi-head attention
        x = mp(
            common_attention.multihead_attention,
            x,
            encoder_output,
            encoder_decoder_attention_bias,  # bias
            hparams.multihead_attention_key_channels or hparams.hidden_size,
            hparams.multihead_attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.multihead_attention_num_heads,
            hparams.attention_dropout)
      elif layer_type == "ffn":
        x = mp(
            common_layers.dense_relu_dense, x,
            hparams.filter_size, hparams.hidden_size,
            dropout=hparams.relu_dropout,
            dropout_broadcast_dims=[relu_dropout_broadcast_dims] * mp.n)
      else:
        assert False, "unknown sublayer %s" % layer_type
  return x
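The "att" sublayer in both stacks is single-head attention: a bias-free dense query projection, scaled_dot_product_attention_simple against a 3d additive bias of shape (batch, length, length), and a bias-free dense output projection. The NumPy sketch below reproduces that shape of computation; the random projection matrices and the lower-triangle bias construction are illustrative assumptions, not the common_attention implementation.

# A NumPy sketch of the single-head "att" sublayer with a 3d additive bias.
import numpy as np

def softmax(x, axis=-1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

def attention_bias_lower_triangle(length):
  # (1, length, length): -1e9 above the diagonal, 0 on and below it.
  return -1e9 * (1.0 - np.tril(np.ones((1, length, length))))

def scaled_dot_product_attention_simple(q, k, v, bias):
  depth = q.shape[-1]
  logits = q @ np.swapaxes(k, -1, -2) / np.sqrt(depth) + bias
  return softmax(logits) @ v

rng = np.random.default_rng(0)
batch, length, hidden = 2, 6, 8
x = rng.normal(size=(batch, length, hidden))
w_q = rng.normal(size=(hidden, hidden))  # stand-in for the "q_transform" dense
w_o = rng.normal(size=(hidden, hidden))  # stand-in for the "o_transform" dense

bias = attention_bias_lower_triangle(length)
y = scaled_dot_product_attention_simple(x @ w_q, x, x, bias) @ w_o
print(y.shape)                           # (2, 6, 8)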