Example 1
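The `logits_fn` callback used for incremental decoding: each step embeds the token emitted at the previous step, adds the positional embedding for the current step, runs the decoder layer stack with its cached states, and projects the output onto the vocabulary via the shared softmax variable.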
 def logits_fn(step_num, ids, states):
   """Produce logits for this step, and new states."""
   ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
   x = (mtf.gather(targets_embedding_var, ids_this_step,
                   self.targets_vocab_dim) +
        mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
   with tf.variable_scope("decoder"):
     x, new_states = self._layer_stack(
         x,
         hparams.decoder_layers,
         encdec_attention_mask=encoder_attention_mask,
         step_num=step_num,
         encdec_tensors=encdec_tensors,
         states=states)
   logits = mtf.matmul(x, softmax_var)
   return logits, new_states
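
In the surrounding Transformer code, a callback like this is handed to the Mesh TensorFlow decoding utilities. Below is a minimal sketch of how it is typically driven, assuming `mtf.beam_search.greedy_decode` and illustrative `mesh`, `batch_dims`, `length_dim`, and `initial_states` values (none of which appear in the excerpt above):

  # Sketch only: drive logits_fn with mtf's greedy decoder.
  initial_ids = mtf.zeros(
      mesh, mtf.Shape(batch_dims + [length_dim]), dtype=tf.int32)
  sampled_ids = mtf.beam_search.greedy_decode(
      logits_fn,
      initial_ids,
      temperature=0.0,
      initial_states=initial_states,
      forced_ids=None,
      use_tpu=True)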
Example 2
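A variant of the same callback that threads the incremental self-attention cache through as one flat list: the first `hparams.num_decoder_layers` entries hold the per-layer key tensors, the rest hold the value tensors, and the updated halves are concatenated back in the same order on return.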
 def logits_fn(step_num, ids, states):
   """Produce logits for this step, and new states."""
   self_attention_k = states[:hparams.num_decoder_layers]
   self_attention_v = states[hparams.num_decoder_layers:]
   ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
   x = (mtf.gather(targets_embedding_var, ids_this_step,
                   self.targets_vocab_dim) +
        mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
   with tf.variable_scope("decoder"):
     x, new_self_attention_k, new_self_attention_v = (
         self._decoder_layer_stack_incremental(
             x,
             step_num,
             encdec_tensors,
             self_attention_k,
             self_attention_v,
             encdec_attention_mask=encoder_attention_mask))
   logits = mtf.matmul(x, softmax_var)
   return logits, new_self_attention_k + new_self_attention_v
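
The flat cache layout this variant assumes can be stated in plain Python: `states` holds the per-layer key tensors followed by the per-layer value tensors, and the return value concatenates the updated halves in the same order. A runnable illustration (strings stand in for mtf Tensors):

  num_decoder_layers = 6
  states = (["k%d" % i for i in range(num_decoder_layers)] +
            ["v%d" % i for i in range(num_decoder_layers)])
  self_attention_k = states[:num_decoder_layers]  # per-layer key caches
  self_attention_v = states[num_decoder_layers:]  # per-layer value caches
  assert self_attention_k + self_attention_v == states  # layout is preserved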
Example 3
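The full `_mtf_model_fn` of the Mesh TensorFlow Transformer: it pads targets (and inputs) to `hparams.max_length`, imports them onto the mesh, builds the self-attention and encoder-decoder attention masks (with special handling for "packed" datasets), runs the encoder and/or decoder stacks according to `hparams.transformer_type`, and returns label-smoothed softmax cross-entropy averaged over non-padding positions.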
    def _mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        hparams = self._hparams
        targets = tf.to_int32(features["targets"])
        if len(targets.get_shape()) > 2:
            tf.logging.info("targets = %s" % targets)
            targets = tf.squeeze(targets, [2, 3])
        # pad targets to max_length
        def pad_to_max_length(x):
            extra_length = hparams.max_length - tf.shape(x)[1]
            x = tf.pad(x, [[0, 0], [0, extra_length]])
            x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
            return x

        targets = pad_to_max_length(targets)
        for key in [
                "targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"
        ]:
            if key in features:
                features[key] = pad_to_max_length(features[key])
        shifted_targets = common_layers.shift_right_2d(targets)

        targets = self._import_to_batch_by_length(targets, "targets", mesh,
                                                  hparams)
        shifted_targets = self._import_to_batch_by_length(
            shifted_targets, "shifted_targets", mesh, hparams)

        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = self._import_to_batch_by_length(
                features["targets_segmentation"], "targets_segmentation", mesh,
                hparams)
            targets_position = self._import_to_batch_by_length(
                features["targets_position"], "targets_position", mesh,
                hparams)
            decoder_self_attention_mask = (
                mtf.layers.attention_mask_autoregressive(
                    targets_position, dtype=self.activation_dtype) +
                mtf.layers.attention_mask_same_segment(
                    targets_segmentation, dtype=self.activation_dtype))
        else:
            targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

        extra_losses = []
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if hparams.transformer_type == "decoder":
            encoder_output = None
            encoder_decoder_attention_mask = None
        else:
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = pad_to_max_length(inputs)
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            if "inputs_segmentation" in features:
                # "Packed" dataset - keep the examples from seeing each other.
                inputs_segmentation = self._import_to_batch_by_length(
                    features["inputs_segmentation"], "inputs_segmentation",
                    mesh, hparams)
                inputs_position = self._import_to_batch_by_length(
                    features["inputs_position"], "inputs_position", mesh,
                    hparams)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        inputs_segmentation, dtype=self.activation_dtype))
            else:
                inputs_position = mtf.range(mesh,
                                            self.length_dim,
                                            dtype=tf.int32)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_ignore_padding(
                        inputs, dtype=self.activation_dtype))

            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.gather(positional_embedding_var, inputs_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.encoder_layers,
                    self_attention_mask=encoder_self_attention_mask,
                    losses=extra_losses)

        if hparams.transformer_type == "encdec":
            if "inputs_segmentation" in features:
                encoder_decoder_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        targets_segmentation,
                        inputs_segmentation,
                        dtype=self.activation_dtype))
            else:
                encoder_decoder_attention_mask = encoder_self_attention_mask
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)

        if hparams.transformer_type != "encoder":
            # DECODER
            x = (mtf.gather(targets_embedding_var, shifted_targets,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, targets_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("decoder"):
                x = self._layer_stack(
                    x,
                    hparams.decoder_layers,
                    encoder_output=encoder_output,
                    self_attention_mask=decoder_self_attention_mask,
                    encdec_attention_mask=encoder_decoder_attention_mask,
                    losses=extra_losses)
        logits = mtf.matmul(x, softmax_var)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
        off_value = hparams.label_smoothing / self._targets_vocab_size
        on_value = 1.0 - hparams.label_smoothing + off_value
        soft_targets = mtf.one_hot(targets,
                                   self.targets_vocab_dim,
                                   on_value=on_value,
                                   off_value=off_value,
                                   dtype=self.activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.targets_vocab_dim)
        weights = mtf.layers.weights_nonzero(targets,
                                             dtype=self.activation_dtype)
        loss = mtf.reduce_mean(loss * weights)
        for l in extra_losses:
            loss += l
        logits = mtf.to_float(logits)
        # combine batch dims
        if len(self.batch_dims) > 1:
            combined_batch_dim = mtf.Dimension(self.batch_dims[0].name,
                                               mtf.Shape(self.batch_dims).size)
            logits = mtf.reshape(logits,
                                 [combined_batch_dim] + logits.shape.dims[-2:])
        return logits, loss
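
The label-smoothing arithmetic above spreads `hparams.label_smoothing` probability mass uniformly over the vocabulary. A quick worked example (the values are illustrative, not taken from the code):

  label_smoothing = 0.1
  vocab_size = 32768
  off_value = label_smoothing / vocab_size      # per wrong token, ~3.05e-06
  on_value = 1.0 - label_smoothing + off_value  # true token, ~0.9000031
  # The smoothed targets still form a proper distribution:
  assert abs(on_value + (vocab_size - 1) * off_value - 1.0) < 1e-9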
Example 4
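A later revision of the same model function. It adds a "denoising" decoder type (the decoder reads a corrupted copy of the targets instead of right-shifted targets), passes an explicit `is_training` flag to dropout, and works around a TPU slowdown via `reshape_logits_hack`, which doubles the batch dimension and halves the length dimension around the final matmul during training.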
  def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
      shifted_targets = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
      shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
      raise ValueError(
          "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          targets_segmentation, dtype=self.activation_dtype)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
      else:
        decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
      encoder_output = None
      encoder_decoder_attention_mask = None
    else:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
      # DECODER
      x = (mtf.gather(
          targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
           mtf.gather(
               positional_embedding_var, targets_position, self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # For some reason, the logits computation is extremely slow on TPU
      # in some cases where the batch size per core is 1.  Reshape the logits
      # and the targets to double the batch size and halve the length.
      # TODO(noam): file a bug.
      old_dims = self.batch_dims + [self.length_dim]
      new_dims = self.batch_dims[:-1] + [
          mtf.Dimension(self.batch_dims[-1].name,
                        self.batch_dims[-1].size * 2),
          mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
      x = mtf.reshape(x, new_dims + [self.model_dim])
      targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
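
In the autoregressive branch above, `mtf.shift(targets, offset=1, dim=self.length_dim, wrap=False)` builds the usual teacher-forcing inputs: a zero-filled right shift along the length dimension, the same thing `common_layers.shift_right_2d` computed in the previous example. A NumPy sketch of the operation:

  import numpy as np

  targets = np.array([[11, 12, 13, 14]])        # [batch, length]
  shifted = np.pad(targets, [(0, 0), (1, 0)])[:, :-1]  # zero-filled right shift
  assert shifted.tolist() == [[0, 11, 12, 13]]  # step t sees only tokens < t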
Example 5
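Nearly identical to the previous example; the only substantive differences are that the `mode`/`is_training` bookkeeping at the top is absent and `mtf.dropout` is called without the explicit `is_training` argument.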
  def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
      shifted_targets = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
      shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
      raise ValueError(
          "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          targets_segmentation, dtype=self.activation_dtype)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
      else:
        decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
      encoder_output = None
      encoder_decoder_attention_mask = None
    else:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
      # DECODER
      x = (mtf.gather(
          targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
           mtf.gather(
               positional_embedding_var, targets_position, self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # For some reason, the logits computation is extremely slow on TPU
      # in some cases where the batch size per core is 1.  Reshape the logits
      # and the targets to double the batch size and halve the length.
      # TODO(noam): file a bug.
      old_dims = self.batch_dims + [self.length_dim]
      new_dims = self.batch_dims[:-1] + [
          mtf.Dimension(self.batch_dims[-1].name,
                        self.batch_dims[-1].size * 2),
          mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
      x = mtf.reshape(x, new_dims + [self.model_dim])
      targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
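
The `reshape_logits_hack` in the last two examples is value-preserving: doubling the batch dimension while halving the length dimension keeps every element in place, so reshaping the logits back afterwards restores the original layout. A NumPy analogy:

  import numpy as np

  x = np.arange(2 * 8 * 4).reshape(2, 8, 4)     # [batch=2, length=8, model=4]
  y = x.reshape(4, 4, 4)                        # batch doubled, length halved
  assert np.array_equal(y.reshape(2, 8, 4), x)  # round-trips exactly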