Example #1
    def body(self, features):
        # Remove dropout if not training
        hparams = self._hparams
        ps_devices = self._ps_devices
        assert hparams.num_model_shards % len(ps_devices) == 0
        shards_per_device = hparams.num_model_shards // len(ps_devices)
        model_devices = [
            ps_devices[i // shards_per_device]
            for i in range(hparams.num_model_shards)
        ]
        print("model_devices = %s" % model_devices)
        mp = expert_utils.Parallelism(model_devices, reuse=False)
        vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
        # squeeze out channels, heights
        targets = features["targets_raw"]
        targets = tf.squeeze(targets, 3)
        targets = tf.squeeze(targets, 2)
        shifted_targets = common_layers.shift_right_2d(targets)
        # Bypass the symbol modality and use a different embedding on each shard.
        decoder_input = mp(common_layers.embedding,
                           shifted_targets,
                           vocab_size,
                           hparams.hidden_size,
                           multiplier=hparams.hidden_size**0.5,
                           symbol_dropout_rate=hparams.symbol_dropout)
        decoder_self_attention_bias = mp(
            common_attention.attention_bias_lower_triangle,
            tf.shape(targets)[1])
        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = features["targets_segmentation"]
            targets_position = features["targets_position"]
            decoder_self_attention_bias = mp(
                tf.add, decoder_self_attention_bias,
                mp(common_attention.attention_bias_same_segment,
                   targets_segmentation, targets_segmentation))
        else:
            targets_position = None

        if hparams.pos == "timing":
            if targets_position is None:
                decoder_input = mp(common_attention.add_timing_signal_1d,
                                   decoder_input)
            else:
                decoder_input = mp(
                    common_attention.add_timing_signal_1d_given_position,
                    decoder_input, targets_position)

        decoder_input = mp(tf.nn.dropout, decoder_input,
                           1.0 - hparams.layer_prepostprocess_dropout)
        decoder_output, extra_loss = _super_stack(decoder_input,
                                                  decoder_self_attention_bias,
                                                  hparams, mp)
        # Bypass the symbol modality and compute logits directly.
        # We compute a different set of logits on each shard, and sum them.
        logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits")
        logits = expert_utils.all_reduce_ring(logits, mp)
        logits = mp(tf.multiply, logits, mp.n**-0.5)
        # We now have identical logits on all shards.
        # Shard 0 gets returned to the estimator.
        logits_shard_0 = logits[0]
        logits_shard_0 = tf.expand_dims(logits_shard_0, 2)
        logits_shard_0 = tf.expand_dims(logits_shard_0, 3)
        # On each device, we compute the loss for a part of the batch.
        # This is faster than computing the whole loss on one shard.
        mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])

        def _loss_for_shard(logits, targets, shard):
            if mp.n > 1:
                logits = common_layers.approximate_split(logits, mp.n,
                                                         0)[shard]
                targets = common_layers.approximate_split(targets, mp.n,
                                                          0)[shard]
            return common_layers.padded_cross_entropy(logits, targets,
                                                      hparams.label_smoothing)

        num, denom = mp(_loss_for_shard, logits, targets, list(range(mp.n)))
        # override training loss so that it is not computed externally.
        losses = {"training": tf.add_n(num) / tf.add_n(denom)}
        if extra_loss is not None:
            losses["extra"] = extra_loss
        return logits_shard_0, losses
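Every example on this page leans on the same primitive: expert_utils.Parallelism. Calling mp(fn, *args) runs fn once per model shard, pins each call to that shard's device, and returns a list of per-shard results; list arguments are distributed one element per shard, non-list arguments are broadcast to all shards, and reuse=False gives every shard its own variable scope (which is how each shard ends up with its own embedding). Below is a minimal, hypothetical sketch of that behaviour, assuming tensor2tensor and a TF1-compatible TensorFlow are installed; the devices and tensors are toy values, not the model's.

# Minimal sketch of the mp(fn, *args) pattern used in the examples above.
import tensorflow.compat.v1 as tf
from tensor2tensor.utils import expert_utils

tf.disable_eager_execution()

# Two shards on the same device; reuse=False means each shard builds its own
# variables under a separate variable scope.
mp = expert_utils.Parallelism(["/cpu:0", "/cpu:0"], reuse=False)

# A list argument is split one element per shard; a scalar is broadcast to all.
xs = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]
ys = mp(tf.multiply, xs, 10.0)  # list of mp.n tensors, one per shard

with tf.Session() as sess:
  print(sess.run(ys))  # [array([10., 20.], ...), array([30., 40.], ...)]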
Example #2
  def body(self, features):
    hparams = self._hparams
    ps_devices = self._ps_devices
    single_device = (len(ps_devices) == 1)
    assert hparams.num_model_shards % len(ps_devices) == 0
    shards_per_device = hparams.num_model_shards // len(ps_devices)
    model_devices = [ps_devices[i // shards_per_device]
                     for i in range(hparams.num_model_shards)]
    print("model_devices = %s" % model_devices)
    mp = expert_utils.Parallelism(model_devices, reuse=False)
    targets_vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
    # squeeze out channels, heights
    targets = tf.squeeze(features["targets_raw"], [2, 3])
    targets_embedding_var = mp(
        tf.get_variable, "embedding",
        [[targets_vocab_size, hparams.hidden_size]] * mp.n,
        initializer=tf.random_normal_initializer(
            0.0, hparams.hidden_size**-0.5))
    shifted_targets = common_layers.shift_right_2d(targets)
    # Bypass the symbol modality and use a different embedding on each shard.
    if single_device:
      targets_embedding_var_combined = tf.concat(targets_embedding_var, 1)
      decoder_input_combined = common_layers.embedding(
          shifted_targets, targets_vocab_size,
          hparams.hidden_size * mp.n,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=targets_embedding_var_combined,
      )
      decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
    else:
      targets_embedding_var_combined = None
      decoder_input = mp(
          common_layers.embedding, shifted_targets, targets_vocab_size,
          hparams.hidden_size,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=targets_embedding_var,
      )
    decoder_self_attention_bias = mp(
        common_attention.attention_bias_lower_triangle,
        tf.shape(targets)[1])
    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = features["targets_segmentation"]
      targets_position = features["targets_position"]
      decoder_self_attention_bias = mp(
          tf.add, decoder_self_attention_bias,
          mp(common_attention.attention_bias_same_segment,
             targets_segmentation, targets_segmentation))
      decoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          decoder_input, targets_position)
    else:
      targets_position = None
      decoder_self_attention_bias = mp(
          common_attention.attention_bias_lower_triangle,
          tf.shape(targets)[1])
      decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)

    if self.has_input:
      inputs = tf.squeeze(features["inputs_raw"], [2, 3])
      inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size
      # share everything for now
      share_inputs_and_targets_embedding = True
      if share_inputs_and_targets_embedding:
        assert inputs_vocab_size == targets_vocab_size
        inputs_embedding_var = targets_embedding_var
        inputs_embedding_var_combined = targets_embedding_var_combined
      if single_device:
        encoder_input_combined = common_layers.embedding(
            inputs, inputs_vocab_size,
            hparams.hidden_size * mp.n,
            multiplier=hparams.hidden_size**0.5,
            embedding_var=inputs_embedding_var_combined,
        )
        encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
      else:
        encoder_input = mp(
            common_layers.embedding, inputs, inputs_vocab_size,
            hparams.hidden_size,
            multiplier=hparams.hidden_size**0.5,
            embedding_var=inputs_embedding_var,
        )
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        encoder_self_attention_bias = mp(
            common_attention.attention_bias_same_segment,
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = mp(
            common_attention.attention_bias_same_segment,
            targets_segmentation, inputs_segmentation)
        encoder_input = mp(
            common_attention.add_timing_signal_1d_given_position,
            encoder_input, inputs_position)
      else:
        encoder_padding = tf.to_float(tf.equal(inputs, 0))
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
        encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input)

      # encoder stack here
      with tf.variable_scope("encoder"):
        encoder_input = mp(
            tf.nn.dropout, encoder_input,
            1.0 - hparams.layer_prepostprocess_dropout)
        encoder_output = _layer_stack(
            mp,
            encoder_input,
            encoder_self_attention_bias,
            hparams.encoder_layers,
            hparams)
    else:
      encoder_decoder_attention_bias = None
      encoder_output = None

    with tf.variable_scope("decoder"):
      decoder_input = mp(
          tf.nn.dropout, decoder_input,
          1.0 - hparams.layer_prepostprocess_dropout)
      decoder_output = _layer_stack(
          mp,
          decoder_input,
          decoder_self_attention_bias,
          layers=hparams.decoder_layers,
          hparams=hparams,
          encoder_output=encoder_output,
          encoder_decoder_attention_bias=encoder_decoder_attention_bias)

    # Bypass the symbol modality and compute logits directly.
    # We compute a different set of logits on each shard, and sum them.
    # Share the weights with the target embedding.
    output_var = targets_embedding_var
    output_var_combined = targets_embedding_var_combined
    if single_device:
      decoder_output = tf.concat(decoder_output, 2)
      logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]])
      num, denom = common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)
      training_loss = num / denom
    else:
      logits = mp(
          tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n)
      logits = expert_utils.all_reduce_ring(logits, mp)
      # On each device, we compute the loss for a part of the batch.
      # This is faster than computing the whole loss on one shard.
      mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])
      def _loss_for_shard(logits, targets, shard):
        logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
        targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
        return common_layers.padded_cross_entropy(
            logits, targets, hparams.label_smoothing)
      num, denom = mp(_loss_for_shard, logits, targets, list(range(mp.n)))
      training_loss = tf.add_n(num) / tf.add_n(denom)
      logits = logits[0]
    logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
    # override training loss so that it is not computed externally.
    losses = {"training": training_loss}
    return logits, losses
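The logits in this example are tied to the target embedding: every shard holds a [targets_vocab_size, hidden_size] slice of the embedding, computes tf.tensordot(decoder_output, embedding_slice, [[2], [1]]) against its own slice, and the partial logits are summed across shards with all_reduce_ring. That sum is mathematically the same matrix product the single_device branch computes in one step against the concatenated embedding, so each device only ever stores 1/n of the softmax weights and does 1/n of the matmul. A small hypothetical sketch of the equivalence in plain TF1-style code (toy shapes and random constants instead of the model's variables):

# Check that summing per-shard tensordots equals one combined tensordot
# against the embedding concatenated along the hidden dimension.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

n, batch, length, hidden, vocab = 2, 3, 5, 4, 7
dec = [tf.constant(np.random.rand(batch, length, hidden), tf.float32)
       for _ in range(n)]                      # per-shard decoder outputs
emb = [tf.constant(np.random.rand(vocab, hidden), tf.float32)
       for _ in range(n)]                      # per-shard embedding slices

# Sharded path: partial logits per shard, then summed across shards.
sharded = tf.add_n([tf.tensordot(d, e, [[2], [1]]) for d, e in zip(dec, emb)])

# Combined (single_device) path: concatenate along hidden, one tensordot.
combined = tf.tensordot(tf.concat(dec, 2), tf.concat(emb, 1), [[2], [1]])

with tf.Session() as sess:
  a, b = sess.run([sharded, combined])
  print(np.allclose(a, b, atol=1e-5))          # True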
Example #4
  def body(self, features):
    # Remove dropout if not training
    hparams = self._hparams
    ps_devices = self._ps_devices
    assert hparams.num_model_shards % len(ps_devices) == 0
    shards_per_device = hparams.num_model_shards // len(ps_devices)
    model_devices = [ps_devices[i // shards_per_device]
                     for i in range(hparams.num_model_shards)]
    print("model_devices = %s" % model_devices)
    mp = expert_utils.Parallelism(model_devices, reuse=False)
    vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
    # squeeze out channels, heights
    targets = features["targets_raw"]
    targets = tf.squeeze(targets, 3)
    targets = tf.squeeze(targets, 2)
    shifted_targets = common_layers.shift_right_2d(targets)
    # Bypass the symbol modality and use a different embedding on each shard.
    decoder_input = mp(
        common_layers.embedding, shifted_targets, vocab_size,
        hparams.hidden_size,
        multiplier=hparams.hidden_size**0.5,
        symbol_dropout_rate=hparams.symbol_dropout)
    decoder_self_attention_bias = mp(
        common_attention.attention_bias_lower_triangle,
        tf.shape(targets)[1])
    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = features["targets_segmentation"]
      targets_position = features["targets_position"]
      decoder_self_attention_bias = mp(
          tf.add, decoder_self_attention_bias,
          mp(common_attention.attention_bias_same_segment,
             targets_segmentation, targets_segmentation))
    else:
      targets_position = None

    if hparams.pos == "timing":
      if targets_position is None:
        decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)
      else:
        decoder_input = mp(
            common_attention.add_timing_signal_1d_given_position,
            decoder_input, targets_position)

    decoder_input = mp(
        tf.nn.dropout, decoder_input,
        1.0 - hparams.layer_prepostprocess_dropout)
    decoder_output, extra_loss = _super_stack(
        decoder_input, decoder_self_attention_bias, hparams, mp)
    # Bypass the symbol modality and compute logits directly.
    # We compute a different set of logits on each shard, and sum them.
    logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits")
    logits = expert_utils.all_reduce_ring(logits, mp)
    logits = mp(tf.multiply, logits, mp.n ** -0.5)
    # We now have identical logits on all shards.
    # Shard 0 gets returned to the estimator.
    logits_shard_0 = logits[0]
    logits_shard_0 = tf.expand_dims(logits_shard_0, 2)
    logits_shard_0 = tf.expand_dims(logits_shard_0, 3)
    # On each device, we compute the loss for a part of the batch.
    # This is faster than computing the whole loss on one shard.
    mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])
    def _loss_for_shard(logits, targets, shard):
      if mp.n > 1:
        logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
        targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
      return common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)
    num, denom = mp(_loss_for_shard, logits, targets, list(range(mp.n)))
    # override training loss so that it is not computed externally.
    losses = {"training": tf.add_n(num) / tf.add_n(denom)}
    if extra_loss is not None:
      losses["extra"] = extra_loss
    return logits_shard_0, losses
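The loss at the end of each example is computed shard-wise over slices of the batch: once the (identical) logits are available on every device, each shard takes its own slice of the batch, computes padded cross-entropy as a (numerator, denominator) pair on that slice, and the training loss is add_n(num) / add_n(denom). A toy, hypothetical sketch of that reduction in plain TF1-style code, where tf.split and an explicit padding mask stand in for common_layers.approximate_split and padded_cross_entropy:

# Split the batch, compute partial (numerator, denominator) sums per shard,
# then combine into a single training loss.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

n, batch, length, vocab = 2, 4, 6, 9
logits = tf.constant(np.random.rand(batch, length, vocab), tf.float32)
targets = tf.constant(np.random.randint(0, vocab, (batch, length)), tf.int32)

def _loss_for_shard(shard):
  lg = tf.split(logits, n, axis=0)[shard]      # this shard's batch slice
  tg = tf.split(targets, n, axis=0)[shard]
  xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tg, logits=lg)
  weights = tf.to_float(tf.not_equal(tg, 0))   # ignore padding (id 0)
  return tf.reduce_sum(xent * weights), tf.reduce_sum(weights)

nums, denoms = zip(*[_loss_for_shard(i) for i in range(n)])
training_loss = tf.add_n(list(nums)) / tf.add_n(list(denoms))

with tf.Session() as sess:
  print(sess.run(training_loss))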