Example #1
    def _mtf_model_fn(self, features, mesh):
        self._original_features = features
        hparams = self._hparams

        def import_feature(key):
            return self._import_feature(features, mesh, key)

        targets = import_feature("targets")
        if self.autoregressive:
            inputs = mtf.shift(targets,
                               offset=1,
                               dim=self.length_dim,
                               wrap=False)
        else:
            inputs = import_feature("inputs")
            # TODO(noam): options for bert-style masking here?
        sequence_id = import_feature("targets_segmentation")
        model = self.model()
        logits, loss = model.call_simple(inputs=inputs,
                                         targets=targets,
                                         compute_loss=True,
                                         mode=hparams.mode,
                                         variable_dtype=self.variable_dtype,
                                         sequence_id=sequence_id)
        return logits, loss
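The autoregressive branch above builds the decoder inputs by shifting the targets one position to the right without wrap-around. Below is a small NumPy sketch (illustrative only, not mtf code; shift_right is a made-up helper) of what mtf.shift(targets, offset=1, dim=length_dim, wrap=False) produces: each position sees the previous target token, and the first position is filled with 0.

import numpy as np

def shift_right(targets):
    # Hypothetical helper: move every token one step to the right along the
    # length axis and fill the first slot with 0 (wrap=False behaviour).
    shifted = np.roll(targets, 1, axis=-1)
    shifted[..., 0] = 0
    return shifted

targets = np.array([[11, 12, 13, 1],    # 1 = EOS
                    [21, 22, 23, 1]])
print(shift_right(targets))
# [[ 0 11 12 13]
#  [ 0 21 22 23]]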
Example #2
    def _mtf_model_fn(self, features, mesh):
        self._original_features = features
        hparams = self._hparams

        def import_feature(key):
            return self._import_feature(features, mesh, key)

        targets = import_feature("targets")
        sequence_id = import_feature("targets_segmentation")
        position = import_feature("targets_position")
        if self.autoregressive:
            inputs = mtf.shift(targets,
                               offset=1,
                               dim=self.length_dim,
                               wrap=False)
            if position is not None:
                # first input in later sequences should be 0
                inputs *= mtf.to_int32(mtf.not_equal(position, 0))
        else:
            inputs = import_feature("inputs")
            # TODO(noam): options for bert-style masking here?
        model = self.model()
        logits, loss = model.call_simple(inputs=inputs,
                                         targets=targets,
                                         compute_loss=True,
                                         mode=hparams.mode,
                                         variable_dtype=self.variable_dtype,
                                         sequence_id=sequence_id,
                                         position=position)
        return logits, loss
Example #3
    def _mtf_model_fn(self, features, mesh):
        self._original_features = features
        hparams = self._hparams

        def import_feature(key):
            return self._import_feature(features, mesh, key)

        targets = import_feature("targets")
        sequence_id = import_feature("targets_segmentation")
        if hparams.use_global_position_in_packed_sequence:
            position = None
        else:
            position = import_feature("targets_position")
        if self.autoregressive:
            inputs = mtf.shift(targets,
                               offset=1,
                               dim=self.length_dim,
                               wrap=False)
            # We should have a 0 at the beginning of each sequence rather than the
            # shifted EOS (1) from the previous sequence.
            inputs -= mtf.to_int32(mtf.equal(inputs, 1))
        else:
            inputs = import_feature("inputs")
            # TODO(noam): options for bert-style masking here?
        model = self.model()
        logits, loss = model.call_simple(inputs=inputs,
                                         targets=targets,
                                         compute_loss=True,
                                         mode=hparams.mode,
                                         variable_dtype=self.variable_dtype,
                                         sequence_id=sequence_id,
                                         position=position)
        return logits, loss
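Examples #2 and #3 repair the same boundary artifact in packed datasets: after the shift, the first input of every packed-in sequence would otherwise contain the last token (typically the EOS, id 1) of the previous sequence. The sketch below (plain NumPy, illustrative only, assuming EOS id 1) shows that zeroing by packed position (Example #2) and subtracting the leaked EOS (Example #3) give the same result.

import numpy as np

targets  = np.array([[5, 6, 1, 7, 8, 1]])          # two sequences packed together
position = np.array([[0, 1, 2, 0, 1, 2]])          # per-sequence positions
inputs = np.roll(targets, 1, axis=-1)
inputs[..., 0] = 0                                 # wrap=False

fixed_by_position = inputs * (position != 0)       # Example #2
fixed_by_eos = inputs - (inputs == 1).astype(int)  # Example #3
print(inputs)             # [[0 5 6 1 7 8]]
print(fixed_by_position)  # [[0 5 6 0 7 8]]
print(fixed_by_eos)       # [[0 5 6 0 7 8]]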
Example #4
    def call_simple(self,
                    inputs,
                    targets,
                    compute_loss,
                    mode=tf.estimator.ModeKeys.TRAIN,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    encoder_sequence_id=None,
                    decoder_sequence_id=None):
        """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      encoder_sequence_id: an optional Tensor
      decoder_sequence_id: an optional Tensor

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
        shared_params = self._shared_params(inputs.mesh, variable_dtype)
        encoder_output, encoder_loss = self.encoder.call_simple(
            inputs,
            None,
            compute_loss,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=encoder_sequence_id,
            shared_params=shared_params)
        encoder_output = mtf.layers.rename_length_to_memory_length(
            encoder_output)
        if encoder_sequence_id is not None:
            encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
                encoder_sequence_id)
        length_dim = targets.shape.dims[-1]
        shifted_targets = mtf.shift(targets,
                                    offset=1,
                                    dim=length_dim,
                                    wrap=False)
        logits, loss = self.decoder.call_simple(
            shifted_targets,
            targets,
            compute_loss,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=decoder_sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params)
        if loss is not None and encoder_loss is not None:
            loss += encoder_loss
        return logits, loss
Example #5
def halo_reduce(x, blocks_dim, block_size_dim, halo_size, wrap=True):
    """Reduce each block with the margins of adjacent blocks.

  Shift the margins of adjacent blocks along blocks_dim and sum the
  overlapping regions along block_size_dim.
  Only supports a halo size of at most block_size/2.

  Args:
    x: a Tensor.
    blocks_dim: a Dimension in x.shape
    block_size_dim: a Dimension in x.shape
    halo_size: an integer
    wrap: a boolean

  Returns:
    a Tensor with the same shape as x.
  """
    if halo_size == 0:
        return x
    block_size = block_size_dim.size
    assert halo_size <= block_size // 2

    left_margin = mtf.slice(x, 0, 2 * halo_size, block_size_dim.name)
    right_margin = mtf.slice(x, block_size_dim.size - 2 * halo_size,
                             2 * halo_size, block_size_dim.name)
    center = mtf.slice(x, 2 * halo_size, block_size_dim.size - 4 * halo_size,
                       block_size_dim.name)

    # Perform halo exchange: sum the margins of adjacent blocks
    left = mtf.shift(right_margin, 1, blocks_dim, wrap) + left_margin
    right = mtf.shift(left_margin, -1, blocks_dim, wrap) + right_margin

    # Recompose block
    left = mtf.pad(left, [0, block_size_dim.size - 2 * halo_size],
                   block_size_dim.name)
    right = mtf.pad(right, [block_size_dim.size - 2 * halo_size, 0],
                    block_size_dim.name)
    center = mtf.pad(center, [2 * halo_size, 2 * halo_size],
                     block_size_dim.name)
    x = left + center + right
    return x
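A NumPy sketch of the same reduction may make the data movement easier to follow. It is illustrative only (not mtf), operates on a [blocks, block_size] array with wrap=False, and shows that each block adds the 2*halo_size margins it shares with its neighbours while the output keeps the input's shape.

import numpy as np

def halo_reduce_np(x, halo_size):
    b, n = x.shape
    h2 = 2 * halo_size
    left_margin, center, right_margin = x[:, :h2], x[:, h2:n - h2], x[:, n - h2:]
    shift_down = np.zeros_like(right_margin)   # right margin of the previous block
    shift_down[1:] = right_margin[:-1]
    shift_up = np.zeros_like(left_margin)      # left margin of the next block
    shift_up[:-1] = left_margin[1:]
    left = shift_down + left_margin
    right = shift_up + right_margin
    return np.concatenate([left, center, right], axis=1)

x = np.arange(12).reshape(2, 6)
print(halo_reduce_np(x, halo_size=1))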
Example #6
            def model_fn(mtf_features):
                """The kind of function we need for mtf.serialize_training_step.

        Args:
          mtf_features: a dictionary
        Returns:
          a dictionary
        """
                targets = mtf_features["targets"]
                if model_type == "lm":
                    _, _, length_dim = targets.shape
                    inputs = mtf.shift(targets,
                                       offset=1,
                                       dim=length_dim,
                                       wrap=False)
                else:
                    inputs = mtf_features["inputs"]

                if isinstance(transformer_model, transformer.Unitransformer):
                    position_kwargs = dict(
                        sequence_id=mtf_features.get("targets_segmentation",
                                                     None),
                        position=mtf_features.get("targets_position", None),
                    )
                elif isinstance(transformer_model, transformer.Bitransformer
                                ) or model_type == "bi_student_teacher":
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
                else:
                    raise ValueError("unrecognized class")

                logits, loss = transformer_model.call_simple(
                    inputs=inputs,
                    targets=targets,
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)
                if num_microbatches > 1:
                    loss /= float(num_microbatches)
                del logits
                return {"loss": loss}
Example #7
def causal_depthwise_conv(x, context, kernel_size=3):
    """Causal depthwise convolution."""
    def scale_var(shift_distance):
        return mtf.get_variable(
            context.mesh,
            "conv_%s" % shift_distance,
            mtf.Shape(context.model.ensemble_dims + x.shape.dims[-1:]),
            initializer=tf.constant_initializer(0.5 if shift_distance ==
                                                0 else 0.5 / kernel_size),
            dtype=context.variable_dtype)

    ret = x * scale_var(0)
    for shift_distance in range(1, kernel_size):
        x = mtf.shift(x, 1, context.length_dim, wrap=False)
        ret += x * scale_var(shift_distance)
    return ret
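The loop above implements a causal depthwise convolution as a sum of progressively right-shifted copies of x, each scaled by its own per-channel weight. A NumPy sketch (illustrative only, not mtf; the weight values mirror the scale_var initializer above for kernel_size=3):

import numpy as np

def causal_depthwise_conv_np(x, weights):
    # x: [length, channels]; weights: one [channels] vector per shift distance.
    out = x * weights[0]
    shifted = x
    for w in weights[1:]:
        # shift one step toward later positions, zero-padding at the front
        shifted = np.concatenate([np.zeros_like(shifted[:1]), shifted[:-1]], axis=0)
        out = out + shifted * w
    return out

x = np.arange(8, dtype=float).reshape(4, 2)
weights = [np.full(2, 0.5), np.full(2, 0.5 / 3), np.full(2, 0.5 / 3)]
print(causal_depthwise_conv_np(x, weights))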
Example #8
    def _mtf_model_fn(self, features, mesh):
        self._original_features = features
        features = copy.copy(features)
        hparams = self._hparams
        targets = tf.to_int32(features["targets"])
        if len(targets.get_shape()) > 2:
            tf.logging.info("targets = %s" % targets)
            targets = tf.squeeze(targets, [2, 3])
        # pad targets to max_length
        def pad_to_length(x):
            extra_length = self.length_dim.size - tf.shape(x)[1]
            x = tf.pad(x, [[0, 0], [0, extra_length]])
            x = tf.reshape(x, [hparams.batch_size, self.length_dim.size])
            return x

        targets = pad_to_length(targets)
        targets = self._import_to_batch_by_length(targets, "targets", mesh)
        for key in [
                "targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"
        ]:
            if key in features:
                features[key] = pad_to_length(features[key])
        if hparams.decoder_type == "autoregressive":
            shifted_targets = mtf.shift(targets,
                                        offset=1,
                                        dim=self.length_dim,
                                        wrap=False)
        else:
            raise ValueError("unknown hparams.decoder_type = %s" %
                             hparams.decoder_type)
        model = self.model()
        logits, loss = model.call_simple(inputs=shifted_targets,
                                         targets=targets,
                                         compute_loss=True,
                                         mode=hparams.mode,
                                         variable_dtype=self.variable_dtype)
        # mesh_shape=hparams.mesh_shape,
        # layout=hparams.layout,
        return logits, loss
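The pad_to_length helper simply right-pads every example with zeros so each batch has the static [batch_size, length] shape the mesh expects. A short NumPy equivalent (illustrative only):

import numpy as np

def pad_to_length_np(x, max_length):
    # x: [batch, current_length] -> [batch, max_length], zero-padded on the right
    return np.pad(x, [(0, 0), (0, max_length - x.shape[1])])

print(pad_to_length_np(np.array([[3, 4, 1]]), max_length=6))
# [[3 4 1 0 0 0]]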
Example #9
    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        nonlocal sampling_keep_top_k

        context = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="incremental",
            position=position,
            position_is_default=True,
            states=states,
            new_states=[],
            initial_position=position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=ids,
            encoder_inputs=encoder_inputs) if not slow_sampling else None

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)

        # sampling_keep_top_k == -2 means use the default: keep the top 10%
        # of the vocabulary.
        if sampling_keep_top_k == -2:
            sampling_keep_top_k = int(logits.shape[-1].size * 0.1)

        if sampling_keep_top_k != -1:
            if sampling_keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits,
                n=sampling_keep_top_k,
                reduced_dim=other_features["vocab_dim"])
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, other_features["vocab_dim"], temperature)

        if slow_sampling:
            ids_this_step = mtf.shift(ids_this_step,
                                      offset=1,
                                      dim=length_dim,
                                      wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, (batch_dims))

        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret
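In the code above, top-k truncation finds the k-th largest logit with mtf.nth_largest_element and pushes everything at or below it to -1e6 before sampling; sampling_keep_top_k == -2 selects a default k of 10% of the vocabulary. A NumPy sketch of the idea (illustrative only; tie handling differs slightly from the mtf version):

import numpy as np

def truncate_to_top_k(logits, k):
    # value of the k-th largest logit along the vocab axis
    kth_largest = np.sort(logits, axis=-1)[..., -k][..., None]
    # everything below it becomes effectively impossible to sample
    return np.where(logits < kth_largest, -1e6, logits)

logits = np.array([[2.0, 0.5, 1.5, -1.0]])
print(truncate_to_top_k(logits, k=2))
# [[ 2.0e+00 -1.0e+06  1.5e+00 -1.0e+06]]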
Example #10
 def call_simple(self,
                 inputs,
                 targets,
                 compute_loss,
                 attributes=None,
                 mode=tf.estimator.ModeKeys.TRAIN,
                 variable_dtype=mtf.VariableDType(tf.float32),
                 sequence_id=None,
                 subsequence_id=None,
                 position=None,
                 encoder_output=None,
                 encoder_sequence_id=None,
                 encoder_inputs=None,
                 shared_params=None,
                 layer_outputs=None,
                 encoder_layer_outputs=None,
                 z=None):
     """Compute logits based on inputs (all positions in parallel).
     This is called during training and evaluation.
     Args:
       inputs: an int32 Tensor with shape [<batch_dims>, length_dim].  For
         training autoregressive models this should be equal to
         mtf.shift(targets, offset=1, dim=length_dim, wrap=False).
       targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
       compute_loss: a boolean
       attributes: an optional int32 Tensor with shape
         [<batch_dims>, length_dim] or [<batch_dims>]
       mode: a tf.estimator.ModeKeys
       variable_dtype: a mtf.VariableDType
       sequence_id: an optional Tensor
       subsequence_id: an optional Tensor
       position: an optional Tensor
       encoder_output: an optional Tensor
       encoder_sequence_id: an optional Tensor
       encoder_inputs: an optional Tensor
       shared_params: an optional dictionary
       layer_outputs: an optional list to append Tensor layer activations to
       encoder_layer_outputs: optional - readonly list of tensor activations when
         decoding, one per each input layer + the embedding layer
     Returns:
       logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
       loss: an optional Scalar (if compute_loss=True)
     """
     batch_dims = inputs.shape.dims[:-1]
     length_dim = inputs.shape.dims[-1]
     length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32)
     if not self.positional_embedding:
         # To make relative attention faster, we drop the information about the
         #   position in the subsequence.  The relative attention code then
         #   assumes that the positions are given by index in the tensor,
         #   which still leads to the correct computation of relative position.
         position = None
     if position is None:
         position_is_default = True
         position = length_range
     else:
         position_is_default = False
     if self.input_full_attention:
         # The inputs part of each sequence can fully attend within itself.
         full_attention_region = delimited_lm_inputs_mask(targets)
         # We can include one additional position to the right - the position
         #   where the final EOS of the inputs is read and the first target token
         #   is predicted.
         full_attention_region = mtf.logical_or(
             full_attention_region,
             mtf.shift(full_attention_region,
                       offset=1,
                       dim=length_dim,
                       wrap=False))
         # We set read_priority and write_priority to 0 in the full-attention
         #   region and equal to the position elsewhere.
         read_priority = write_priority = length_range * mtf.cast(
             mtf.logical_not(full_attention_region), tf.int32)
     elif self.autoregressive:
         # Vanilla autoregressive model - each position can see previous positions.
         read_priority = write_priority = length_range
     else:
         read_priority = write_priority = None
     context = Context(model=self,
                       mesh=inputs.mesh,
                       batch_dims=batch_dims,
                       length_dim=length_dim,
                       variable_dtype=variable_dtype,
                       mode=mode,
                       losses=[] if compute_loss else None,
                       sequence_id=sequence_id,
                       subsequence_id=subsequence_id,
                       position=position,
                       position_is_default=position_is_default,
                       encoder_output=encoder_output,
                       encoder_sequence_id=encoder_sequence_id,
                       shared_params=shared_params,
                       layer_outputs=layer_outputs,
                       encoder_layer_outputs=encoder_layer_outputs,
                       write_priority=write_priority,
                       read_priority=read_priority,
                       inputs=inputs,
                       encoder_inputs=encoder_inputs)
     with tf.variable_scope(self.name):
         logits = self._call_internal(context,
                                      inputs,
                                      targets,
                                      attributes,
                                      z=z)
     if compute_loss:
         loss = mtf.add_n(context.losses)
     else:
         loss = None
     return logits, loss
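The read/write priorities computed above control how far each position may attend: they are forced to 0 inside the full-attention region of a prefix LM (so that region attends bidirectionally) and are equal to the position index elsewhere (so the rest stays causal). A NumPy sketch of just that arithmetic, with a hand-made full_attention_region mask standing in for delimited_lm_inputs_mask(targets) plus the one-position shift:

import numpy as np

length = 8
length_range = np.arange(length)
full_attention_region = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=bool)
read_priority = write_priority = length_range * (~full_attention_region)
print(read_priority)   # [0 0 0 0 4 5 6 7]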
Example #11
  def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
      shifted_targets = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
      shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
      raise ValueError(
          "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          targets_segmentation, dtype=self.activation_dtype)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
      else:
        decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
      encoder_output = None
      encoder_decoder_attention_mask = None
    else:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
      # DECODER
      x = (mtf.gather(
          targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
           mtf.gather(
               positional_embedding_var, targets_position, self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # For some reason, the logits computation is extremely slow on TPU
      # in some cases where the batch size per core is 1.  Reshape the logits
      # and the targets to double the batch size and halve the length.
      # TODO(noam): file a bug.
      old_dims = self.batch_dims + [self.length_dim]
      new_dims = self.batch_dims[:-1] + [
          mtf.Dimension(self.batch_dims[-1].name,
                        self.batch_dims[-1].size * 2),
          mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
      x = mtf.reshape(x, new_dims + [self.model_dim])
      targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
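The final loss above combines label smoothing with a padding mask: each target becomes a soft distribution (on_value on the true token, off_value elsewhere), cross entropy is taken against the softmax of the logits, and positions whose target id is 0 receive zero weight. A self-contained NumPy sketch of that computation (illustrative only, not mtf):

import numpy as np

def smoothed_loss(logits, targets, vocab_size, label_smoothing=0.1):
    off_value = label_smoothing / vocab_size
    on_value = 1.0 - label_smoothing + off_value
    soft_targets = np.full(logits.shape, off_value)
    np.put_along_axis(soft_targets, targets[..., None], on_value, axis=-1)
    z = logits - logits.max(axis=-1, keepdims=True)       # stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    xent = -(soft_targets * log_probs).sum(axis=-1)
    weights = (targets != 0).astype(float)                # zero out padding
    return np.mean(xent * weights)

rng = np.random.default_rng(0)
logits = rng.standard_normal((2, 4, 8))                   # [batch, length, vocab]
targets = np.array([[3, 5, 1, 0], [2, 2, 1, 0]])
print(smoothed_loss(logits, targets, vocab_size=8))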
Example #12
        def logits_and_loss(mtf_features):
            """Compute logits and loss.
            Args:
              mtf_features: a dictionary
            Returns:
              logits: a mtf.Tensor
              loss: a mtf.Tensor
            """
            if model_type == "lm":  # TOTRY Adapt that to our case
                if "inputs" in mtf_features:
                    mtf_features = _dynamic_text2self(mtf_features)
                _, _, length_dim = mtf_features["targets"].shape
                inputs = mtf.shift(mtf_features["targets"],
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
            else:
                inputs = mtf_features["inputs"]

            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if control_codes:
                codeprefixedtargets = mtf_features["codeprefixedtargets"]
            else:
                codeprefixedtargets = None

            if isinstance(transformer_model, transformer.Unitransformer):
                position_kwargs = dict(
                    sequence_id=mtf_features.get("targets_segmentation", None),
                    position=mtf_features.get("targets_position", None),
                )
            elif isinstance(transformer_model, transformer.Bitransformer
                            ) or model_type == "bi_student_teacher":
                if control_codes:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "codeprefixedtargets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "codeprefixedtargets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "codeprefixedtargets_position", None),
                    )
                else:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "targets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
            else:
                raise ValueError("unrecognized class")

            if isinstance(transformer_model, Bitransformer_ll):
                if cycle_consistency_loss:
                    logits_ae, l_ae = transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    if has_partial_sequences:
                        controlcodes = mtf_features["controlcode"]
                    else:
                        controlcodes = None

                    with gin.config_scope('training'):
                        mtf_samples = transformer_model.decode(
                            inputs,
                            attributes=attributes,
                            controlcodes=controlcodes,
                            has_partial_sequences=has_partial_sequences,
                            remove_partial_sequences=remove_partial_sequences,
                            variable_dtype=get_variable_dtype())
                        # mtf_samples = mtf.anonymize(mtf_samples)
                    outputs = mtf_samples

                    logits_cycle, l_cycle = transformer_model.call_simple(
                        inputs=outputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
                    return logits_cycle, loss_ae_cycle
                else:
                    return transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)
            else:
                return transformer_model.call_simple(
                    inputs=inputs,
                    targets=mtf_features["targets"],
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    num_microbatches=num_microbatches,
                    **position_kwargs)
Example #13
    def sample_autoregressive(self,
                              partial_sequences,
                              dst_attributes=None,
                              stop_at_token=1,
                              max_steps=None,
                              temperature=0.0,
                              variable_dtype=mtf.VariableDType(tf.float32),
                              encoder_output=None,
                              encoder_sequence_id=None,
                              encoder_inputs=None,
                              shared_params=None,
                              has_partial_sequences=True,
                              encoder_layer_outputs=None,
                              never_end=False,
                              remove_partial_sequences=False,
                              sampling_keep_top_k=-1,
                              z=None):
        """Sample randomly one token at a time.
        The partial_sequences represent partial sequences to be continued.  The
        first tokens of each sequence are nonzero representing the given partial
        sequences and the last tokens of each sequence are zeros, representing what
        needs to be filled in.
        If there are no partial sequences (you want to sample from the beginning),
        then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
        has_partial_sequences=False (so we can skip computation).
        dst_attributes specifies the destination attributes for which we want to generate sequences.
        Args:
          partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
          dst_attributes: an optional int32 Tensor with shape
            [<batch_dims>, length_dim] or [<batch_dims>]
          stop_at_token: an optional integer eos id.  Stop when we produce it.
          max_steps: an optional integer, the max number of steps to decode.
          temperature: an optional floating point value between 0.0 and 1.0.
            0.0 means argmax, 1.0 means sample according to the predicted
            distribution.
          variable_dtype: a mtf.VariableDType
          encoder_output: an optional Tensor
          encoder_sequence_id: an optional Tensor
          encoder_inputs: an optional Tensor
          shared_params: an optional dictionary
          has_partial_sequences: a boolean
          encoder_layer_outputs: optional - readonly list of tensor activations when
            decoding, one per each input layer + the embedding layer
          never_end: a boolean - if set, then avoid generating stop_at_token
          remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
          sampling_keep_top_k: an integer - if not -1, only sample from the top k
            logits.
        Returns:
          a Tensor with shape [<batch_dims>, length_dim]
        """
        if not self.autoregressive:
            raise ValueError("must be autoregressive")

        inputs = partial_sequences
        attributes = dst_attributes
        batch_dims = inputs.shape.dims[:-1]
        length_dim = inputs.shape.dims[-1]
        initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal(
            inputs, 0)),
                                          reduced_dim=length_dim)
        sequence_id = 1 if encoder_sequence_id is not None else None

        length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
        if self.input_full_attention:
            read_priority = write_priority = length_range * mtf.to_int32(
                mtf.greater(length_range, initial_position))
        else:
            read_priority = write_priority = length_range

        context_first_part = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        shifted_inputs = mtf.shift(inputs,
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
        with tf.variable_scope(self.name):
            logits = self._call_internal(context_first_part,
                                         shifted_inputs,
                                         attributes=attributes,
                                         z=z)
        del logits
        constant_states = context_first_part.constant_states
        if not has_partial_sequences:
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
            partial_sequences_eos_count = 0
        else:
            initial_states = context_first_part.new_states
            partial_sequences_eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
                reduced_dim=length_dim)

        def cond_fn(position, ids, *unused_states):
            """Should we run another loop iteration."""
            past_end = mtf.greater_equal(position, length_dim.size)
            if max_steps:
                past_end = mtf.logical_or(
                    past_end,
                    mtf.greater_equal(position - initial_position, max_steps))

            is_done = past_end
            if stop_at_token is not None:
                eos_count = mtf.reduce_sum(mtf.to_int32(
                    mtf.equal(ids, stop_at_token)),
                                           reduced_dim=length_dim)
                has_additional_eos = mtf.greater(eos_count,
                                                 partial_sequences_eos_count)
                is_done = mtf.logical_or(is_done, has_additional_eos)
            all_done = mtf.reduce_all(is_done)
            return mtf.logical_not(all_done)

        def body_fn(position, ids, *states):
            """One step in the decode loop."""
            inputs_this_step = mtf.gather(ids, position - 1, length_dim)
            if self.attribute_embedding:
                attributes_this_step = mtf.gather(attributes, position - 1,
                                                  length_dim)
            else:
                attributes_this_step = None
            # raise ValueError("inputs_this_step shape=%s , ids shape=%s, position - 1 shape=%s, length_dim=%s" % (inputs_this_step.shape, ids.shape, (position - 1).shape, length_dim))
            context_incremental = Context(
                model=self,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                states=states,
                new_states=[],
                sequence_id=sequence_id,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                constant_states=constant_states,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=position,
                inputs=inputs_this_step,
                encoder_inputs=encoder_inputs)

            with tf.variable_scope(self.name, reuse=True):
                logits = self._call_internal(context_incremental,
                                             inputs_this_step,
                                             attributes=attributes_this_step,
                                             z=z)
                if never_end:
                    logits += mtf.one_hot(mtf.constant(logits.mesh,
                                                       stop_at_token,
                                                       dtype=tf.int32),
                                          self.output_vocab_dim,
                                          on_value=-1e9,
                                          off_value=0.0,
                                          dtype=logits.dtype)

            # TBD whether this should be before or after never_end:
            # Note for adding top_p sampling in the future, in other code bases, the
            # option to apply temperature is done before the top-k truncation. This
            # implementation does this in the opposite order. For top-k this doesn't
            # matter, but for top_p it will.
            if sampling_keep_top_k != -1:
                if sampling_keep_top_k <= 0:
                    raise ValueError(
                        "sampling_keep_top_k must either be -1 or positive.")
                k_largest = mtf.nth_largest_element(
                    logits,
                    n=sampling_keep_top_k,
                    reduced_dim=self.output_vocab_dim)
                logits = mtf.where(mtf.less_equal(logits, k_largest),
                                   mtf.ones_like(logits) * -1e6, logits)

            ids_this_step = mtf.sample_with_temperature(
                logits, self.output_vocab_dim, temperature)
            new_position = position + 1
            new_ids = ids + ids_this_step * mtf.one_hot(
                position, length_dim, dtype=tf.int32)
            return [new_position, new_ids] + context_incremental.new_states

        while_loop_inputs = [initial_position, inputs] + initial_states
        final_position, outputs = mtf.while_loop(cond_fn, body_fn,
                                                 while_loop_inputs)[:2]
        del final_position
        if has_partial_sequences and remove_partial_sequences:
            # remove partial sequences from outputs
            partial_length = mtf.reduce_sum(mtf.to_int32(
                mtf.not_equal(partial_sequences, 0)),
                                            reduced_dim=length_dim)
            outputs = mtf.dynamic_shift(outputs,
                                        -partial_length,
                                        length_dim,
                                        wrap=False)
        return outputs
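Stripped of meshes, batching, and cached states, the decode loop above reduces to: start at the first empty position after the given partial sequence, repeatedly feed the previous token to the model, sample the next token, write it in place, and stop on EOS or when the buffer is full. A single-sequence plain-Python sketch (illustrative only; logits_fn is a stand-in for the incremental model call):

import numpy as np

def sample_autoregressive_np(partial, logits_fn, eos_id=1, rng=np.random.default_rng(0)):
    ids = partial.copy()
    position = int(np.sum(ids != 0))             # initial_position
    while position < ids.shape[-1]:
        logits = logits_fn(ids[position - 1] if position > 0 else 0)
        probs = np.exp(logits) / np.sum(np.exp(logits))
        token = rng.choice(len(probs), p=probs)  # temperature-1 sampling
        ids[position] = token
        position += 1
        if token == eos_id:
            break
    return ids

vocab = 5
fake_logits_fn = lambda prev_token: np.ones(vocab)   # uniform stub model
print(sample_autoregressive_np(np.array([3, 4, 0, 0, 0, 0]), fake_logits_fn))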
Example #14
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
    """Implements GBSWT from Charformer.

  Args:
    x: a Tensor containing length_dim
    length_dim: a Dimension
    max_subword_length: integer
    downsample: integer.
    use_offsets: boolean.
    consider_chars_as_blocks: boolean.
    use_block_pos_embedding: boolean.
    share_block_kernel: boolean.
    memory_embeddings: integer.
    context: Context.
    block_mixing_mode: Str for block mixing.
    activation: Str for block ranking.
    downsample_function: Str, supports mean/linformer for now.

  Returns:
    a Tensor with the same shape as x.

  Raises:
    ValueError: if an unsupported downsample_function is given.
  """
    # don't use this for now.
    del max_subword_length
    del memory_embeddings
    all_blocks = []
    all_scores = []
    tf.logging.info("GSW block layer")

    def _tile(x, n, tile_dim):
        # Simple tile function in MTF.
        return mtf.concat([x] * n, tile_dim.name)

    def _repeat(x, n, repeat_dim):
        # repeat function in MTF
        tmp_dim = mtf.Dimension("tmp", 1)
        expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
        x = mtf.reshape(x, expand_shape)
        x = _tile(x, n, tmp_dim)
        output_shape = []
        for dim in x.shape.dims:
            if dim.name == "tmp":
                continue
            if dim.name == repeat_dim.name:
                dim = mtf.Dimension(dim.name, dim.size * n)
            output_shape.append(dim)
        output_shape = mtf.Shape(output_shape)
        x = mtf.reshape(x, output_shape)
        return x

    def _combined_dim(dims):
        return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

    # compute all subword blocks
    # TODO(yitay): handle offsets to get all blocks
    if activation == "sigtanh":
        # one score for sigmoid
        tmp_dim = mtf.Dimension("block_score", 2)
    else:
        tmp_dim = mtf.Dimension("block_score", 1)

    model_dim = x.shape[-1]
    subword_blocks_width = [2, 3, 4]

    if consider_chars_as_blocks:
        subword_blocks_width += [1]

    if share_block_kernel:
        block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
        block_kernel = mtf.get_variable(x.mesh,
                                        "block_kernel",
                                        block_kernel_shape,
                                        initializer=None,
                                        dtype=context.variable_dtype)
    else:
        block_kernel = None

    for subword_len in subword_blocks_width:
        if use_block_pos_embedding:
            # This is turned off by default. It is meant to support cases like
            # parameterized pooling or other features.
            block_len_dim = mtf.Dimension(length_dim.name, subword_len)
            # TODO(vqtran): Consider other positional embeddings.
            block_pos_emb = sinusoid_positional_embedding_weights(
                context.mesh, block_len_dim, x.shape[-1],
                context.variable_dtype.activation_dtype)
            block_pos_emb = _repeat(
                block_pos_emb, math.ceil(length_dim.size / float(subword_len)),
                block_len_dim)
        if use_offsets:
            offset_space = subword_len
        else:
            offset_space = 1
        for offsets in range(offset_space):
            if offsets > 0:
                xoff = mtf.shift(x, offsets, length_dim, wrap=False)
                if use_block_pos_embedding:
                    block_pos_emb = mtf.shift(block_pos_emb,
                                              offsets,
                                              block_pos_emb.shape[-2],
                                              wrap=False)
            else:
                xoff = x
            tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
            if length_dim.size % subword_len != 0:
                tf.logging.info("Not divisible by length")
                # add extra padding tokens
                pad_amt = int(subword_len) - int(length_dim.size % subword_len)
                kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
            else:
                kp = xoff

            if use_block_pos_embedding:
                kp += block_pos_emb

            bx = mtf.pool_tensor_1d(
                kp,
                pool_dim=kp.shape.get_dim_by_name("length"),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(subword_len))
            block_score = mtf.layers.dense(bx, [tmp_dim],
                                           use_bias=False,
                                           name="bx",
                                           reduced_dims=[model_dim],
                                           variable_dtype=None,
                                           kernel_weights=block_kernel)

            expand_bx = _repeat(bx, subword_len, length_dim)
            expand_scores = _repeat(block_score, subword_len, length_dim)
            if offsets > 0:
                # add offset.
                expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [offsets, 0],
                                        length_dim.name)
            new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
            if new_len.size < length_dim.size:
                # pad on the right up to the original length
                pad_amt = length_dim.size - new_len.size
                expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [0, pad_amt],
                                        length_dim.name)
            elif new_len.size > length_dim.size:
                expand_bx = mtf.slice(expand_bx, 0, length_dim.size,
                                      length_dim.name)
                expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                          length_dim.name)

            new_tmp_dim = mtf.Dimension("extra_dim", 1)
            expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
            expand_scores_shape = mtf.Shape(expand_scores.shape.dims +
                                            [new_tmp_dim])
            expand_bx = mtf.reshape(expand_bx, expand_shape)
            expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
            all_blocks.append(expand_bx)
            all_scores.append(expand_scores)

    all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
    all_scores = mtf.concat(all_scores, new_tmp_dim.name)
    tf.logging.info(all_blocks)
    new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
    combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
    block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
    block_net = mtf.reshape(all_scores, block_net_shape)

    if block_mixing_mode == "score_attention":
        tf.logging.info("Using score attention")
        att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
        tf.logging.info(block_net)
        att = mtf.softmax(att, reduced_dim=att.shape[-1])
        block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
        tf.logging.info(block_net)

    if activation == "softmax":
        block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
    elif activation == "tanh":
        tf.logging.info("Using tanh")
        block_net = mtf.tanh(block_net)

    all_blocks = block_net * all_blocks
    all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
    output = all_blocks

    if downsample:
        output_length = output.shape.get_dim_by_name("length")
        if output_length.size % int(downsample) != 0:
            pad_amt = int(downsample) - int(
                output_length.size % int(downsample))
            output = mtf.pad(output, [0, pad_amt], output_length.name)
        if downsample_function == "mean":
            output = mtf.pool_tensor_1d(
                output,
                pool_dim=output.shape.get_dim_by_name("length"),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(downsample))
        else:
            raise ValueError("Downsampling function not implemeneted.")

    return output
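At its core, the GBST layer above does the following for every candidate block width: mean-pool the sequence into blocks, score each block, upsample both back to the original length, then softmax the scores across the candidate widths and take a score-weighted sum of the pooled representations. A compact NumPy sketch of that core (illustrative only; it omits offsets, block position embeddings, score attention, and downsampling, and uses a fixed stand-in scoring kernel):

import numpy as np

def gbst_np(x, widths=(2, 3)):
    length, d = x.shape
    scorer = np.ones((d, 1)) / d                     # stand-in block-score kernel
    blocks, scores = [], []
    for w in widths:
        pad = (-length) % w
        xp = np.pad(x, [(0, pad), (0, 0)])
        pooled = xp.reshape(-1, w, d).mean(axis=1)   # mean-pool each block
        score = pooled @ scorer                      # [num_blocks, 1]
        up = np.repeat(pooled, w, axis=0)[:length]   # expand back to char length
        up_score = np.repeat(score, w, axis=0)[:length]
        blocks.append(up)
        scores.append(up_score)
    blocks = np.stack(blocks, axis=1)                # [length, num_widths, d]
    scores = np.stack(scores, axis=1)                # [length, num_widths, 1]
    weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
    return (weights * blocks).sum(axis=1)            # [length, d]

x = np.random.randn(7, 4)
print(gbst_np(x).shape)   # (7, 4)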
Example #15
    def beam_search(self,
                    inputs,
                    decode_length,
                    dst_attributes=None,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    encoder_output=None,
                    encoder_sequence_id=None,
                    encoder_inputs=None,
                    alpha=0.6,
                    shared_params=None,
                    encoder_layer_outputs=None,
                    z=None):
        """Beam search.
        Args:
          inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim,
            length_dim].
          decode_length: an int32 mtf scalar.  Maximum decode length.
          dst_attributes: an optional int32 zero-Tensor with shape
            [<batch_dims>, beam_dim, length_dim], [<batch_dims>] or
            [<batch_dims>, beam_dim].
          variable_dtype: a mtf.VariableDType
          encoder_output: an optional Tensor
          encoder_sequence_id: an optional Tensor
          encoder_inputs: an optional Tensor
          alpha: a floating point value (length bonus)
          shared_params: an optional dictionary
          encoder_layer_outputs: optional - readonly list of tensor activations when
            decoding, one per each input layer + the embedding layer
        Returns:
          a Tensor with shape [<batch_dims>, beam_dim, length_dim]
        """
        attributes = dst_attributes
        if not self.autoregressive:
            raise ValueError("must be autoregressive")

        batch_dims = inputs.shape.dims[:-2]
        if len(batch_dims) != 1:
            raise NotImplementedError(
                "beam search supports exactly one batch dimension.")
        beam_dim = inputs.shape.dims[-2]
        length_dim = inputs.shape.dims[-1]
        length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
        initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal(
            inputs, 0)),
                                          reduced_dim=length_dim)
        sequence_id = 1 if encoder_sequence_id is not None else None

        if self.input_full_attention:
            # This only makes sense in the case of beam search with given partial
            # sequences, which is not yet implemented.
            # TODO(noam): implement
            raise NotImplementedError(
                "Beam search for language models not yet implemented")
        else:
            read_priority = write_priority = length_range

        context_first_part = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims + [beam_dim],
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        shifted_inputs = mtf.shift(inputs,
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
        with tf.variable_scope(self.name):
            logits = self._call_internal(context_first_part,
                                         shifted_inputs,
                                         attributes=attributes,
                                         z=z)
        del logits
        # There are no partial targets.
        # Replace initial states by zeros to avoid computing them.
        initial_states = [
            mtf.zeros_like(t) for t in context_first_part.new_states
        ]
        constant_states = context_first_part.constant_states

        def logits_fn(step_num, ids, states):
            """logits_fn for mtf.beam_search.beam_search()."""
            inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)

            if self.attribute_embedding:
                attributes_this_step = mtf.gather(attributes, step_num - 1,
                                                  length_dim)
            else:
                attributes_this_step = None

            context_incremental = Context(
                model=self,
                mesh=inputs.mesh,
                batch_dims=batch_dims + [beam_dim],
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=step_num,
                states=states,
                new_states=[],
                sequence_id=sequence_id,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                constant_states=constant_states,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=step_num,
                inputs=inputs_this_step,
                encoder_inputs=encoder_inputs)
            with tf.variable_scope(self.name, reuse=True):
                logits = self._call_internal(context_incremental,
                                             inputs_this_step,
                                             attributes=attributes_this_step,
                                             z=z)
            return mtf.to_float(logits), context_incremental.new_states

        beams, unused_scores = mtf.beam_search.beam_search(
            logits_fn,
            inputs,
            alpha,
            states=initial_states,
            decode_length=decode_length,
            use_tpu=True,
            dtype=tf.float32,
            mesh_shape=self.mesh_shape,
            layout=self.layout)
        return mtf.gather(beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32),
                          beam_dim)
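
A minimal usage sketch, not taken from the source: it assumes the method's leading parameters are inputs, decode_length and dst_attributes as the docstring describes, and that `model` is an already-constructed instance of this class. The mesh and dimension sizes below are placeholders.

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf  # TF1-style API, matching these examples

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "sketch_mesh")
batch_dim = mtf.Dimension("batch", 8)
beam_dim = mtf.Dimension("beam", 4)
length_dim = mtf.Dimension("length", 128)
ids_shape = mtf.Shape([batch_dim, beam_dim, length_dim])

# All-zero ids and attributes: beam search fills in the whole sequence.
zero_ids = mtf.zeros(mesh, ids_shape, dtype=tf.int32)
zero_attributes = mtf.zeros(mesh, ids_shape, dtype=tf.int32)
top_beam = model.beam_search(
    zero_ids,
    decode_length=mtf.constant(mesh, length_dim.size, dtype=tf.int32),
    dst_attributes=zero_attributes,
    alpha=0.6)  # length bonus, as documented above
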
Example #16
0
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: a dict provided by the Estimator; when use_tpu is True it
        contains the TPU context under the key "context"
      config: a tf.estimator.RunConfig (unused)

    Returns:
      a tpu_estimator.TPUEstimatorSpec (a tf.estimator.EstimatorSpec in the
      non-TPU training case)
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key):
            """Import a feature from the features dictionary into a mtf.Tensor.

      Args:
        key: a string

      Returns:
        a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim]
      """
            batch_dim = mtf.Dimension("batch", batch_size)
            length_dim = mtf.Dimension("length", length)
            mtf_shape = mtf.Shape([batch_dim, length_dim])
            if key not in features:
                return None
            x = tf.to_int32(features[key])
            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            if text2self:
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs,
                    variable_dtype=variable_dtype,
                    temperature=temperature)
            else:
                mtf_samples = transformer_model.decode(
                    inputs,
                    variable_dtype=variable_dtype,
                    beam_size=beam_size,
                    alpha=alpha,
                    temperature=temperature)
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if text2self:
            _, length_dim = targets.shape
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if text2self:
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation"),
                position=_import_feature("targets_position"),
            )
        else:
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation"),
                decoder_sequence_id=_import_feature("targets_segmentation"),
                encoder_position=_import_feature("inputs_position"),
                decoder_position=_import_feature("targets_position"),
            )

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=variable_dtype,
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer()
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "step, tf_loss")
        if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[restore_hook, saver_hook])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[restore_hook, saver_hook])
            elif mode == tf.estimator.ModeKeys.EVAL:

                def padded_neg_log_perplexity(logits, labels):
                    weights = tf.to_float(tf.not_equal(labels, 0))
                    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=labels, logits=logits)
                    return {
                        "neg_log_perplexity": tf.metrics.mean(-xent, weights)
                    }

                labels = lowering.export_to_tf_tensor(anon_targets)
                eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics)
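
A hedged sketch, not part of the example, of how a model_fn like the one above is typically wired into a TPUEstimator via the same `tpu_estimator` module the example already uses. `run_config`, `batch_size`, `train_steps` and `my_input_fn` are placeholders, not names from the source.

# `run_config` is assumed to be a TPU-aware RunConfig built elsewhere;
# batch sizes and step counts are illustrative values only.
estimator = tpu_estimator.TPUEstimator(
    model_fn=my_model_fn,
    config=run_config,
    use_tpu=use_tpu,
    train_batch_size=batch_size,
    eval_batch_size=batch_size,
    predict_batch_size=batch_size)
estimator.train(input_fn=my_input_fn, max_steps=train_steps)
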
Example #17
0
  def sample_autoregressive(self,
                            partial_sequences,
                            stop_at_token=1,
                            max_steps=None,
                            temperature=1.0,
                            variable_dtype=mtf.VariableDType(tf.float32),
                            encoder_output=None,
                            encoder_sequence_id=None,
                            shared_params=None,
                            has_partial_sequences=True,
                            encoder_layer_outputs=None):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued.  The
    first tokens of each sequence are nonzero representing the given partial
    sequences and the last tokens of each sequence are zeros, representing what
    needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
      partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
      stop_at_token: an optional integer eos id.  Stop when we produce it.
      max_steps: an optional integer
      temperature: an optional floating point value between 0.0 and 1.0.  0.0
        means argmax, 1.0 means sample according to predicted distribution.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      shared_params: an optional dictionary
      has_partial_sequences: a boolean
      encoder_layer_outputs: optional - readonly list of tensor activations when
        decoding, one per each input layer + the embedding layer

    Returns:
      a Tensor with shape [<batch_dims>, length_dim]
    """
    del max_steps  # TODO(noam): implement
    if not self.autoregressive:
      raise ValueError("must be autoregressive")

    inputs = partial_sequences
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    context_first_part = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        autoregressive=self.autoregressive,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
      logits = self._call_internal(context_first_part, shifted_inputs)
    del logits
    constant_states = context_first_part.constant_states
    if not has_partial_sequences:
      initial_states = [
          mtf.zeros_like(t) for t in context_first_part.new_states]
    else:
      initial_states = context_first_part.new_states

    def cond_fn(position, ids, *unused_states):
      """Should we run another loop iteration."""
      past_end = mtf.greater_equal(position, length_dim.size)
      is_done = past_end
      if stop_at_token is not None:
        has_eos = mtf.reduce_any(
            mtf.equal(ids, stop_at_token), reduced_dim=length_dim)
        is_done = mtf.logical_or(is_done, has_eos)
      all_done = mtf.reduce_all(is_done)
      return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
      """One step in the decode loop."""
      context_incremental = Context(
          mesh=inputs.mesh,
          batch_dims=batch_dims,
          length_dim=length_dim,
          model_dim=self.model_dim,
          variable_dtype=variable_dtype,
          mode="incremental",
          autoregressive=self.autoregressive,
          position=position,
          states=states,
          new_states=[],
          sequence_id=sequence_id,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          constant_states=constant_states,
          shared_params=shared_params,
          layout=self.layout,
          mesh_shape=self.mesh_shape,
          encoder_layer_outputs=encoder_layer_outputs)
      inputs_this_step = mtf.gather(ids, position - 1, length_dim)
      with tf.variable_scope(self.name, reuse=True):
        logits = self._call_internal(context_incremental, inputs_this_step)
      ids_this_step = mtf.sample_with_temperature(
          logits, self.output_vocab_dim, temperature)
      new_position = position + 1
      new_ids = ids + ids_this_step * mtf.one_hot(
          position, length_dim, dtype=tf.int32)
      return [new_position, new_ids] + context_incremental.new_states
    while_loop_inputs = [initial_position, inputs] + initial_states
    final_position, outputs = mtf.while_loop(
        cond_fn, body_fn, while_loop_inputs)[:2]
    del final_position
    return outputs
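
A minimal usage sketch that follows the docstring above rather than the source: sampling from scratch by passing an all-zero partial_sequences Tensor. `model`, `mesh` and the dimension sizes are placeholders.

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf  # TF1-style API, matching these examples

batch_dim = mtf.Dimension("batch", 8)
length_dim = mtf.Dimension("length", 256)
# No given prefixes: every token is zero, so decoding starts at position 0.
empty_partial = mtf.zeros(
    mesh, mtf.Shape([batch_dim, length_dim]), dtype=tf.int32)
samples = model.sample_autoregressive(
    empty_partial,
    stop_at_token=1,              # stop once every sequence has produced EOS (id 1)
    temperature=0.0,              # 0.0 means argmax decoding, per the docstring
    has_partial_sequences=False)  # lets the prefix-state computation be skipped
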
Example #18
0
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: a dict provided by the Estimator; when use_tpu is True it
        contains the TPU context under the key "context"
      config: a tf.estimator.RunConfig (unused)

    Returns:
      a tpu_estimator.TPUEstimatorSpec (a tf.estimator.EstimatorSpec in the
      non-TPU training case)
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key, allow_missing=False):
            """Import a feature from the features dictionary into a mtf.Tensor.

      Args:
        key: a string
        allow_missing: a boolean

      Returns:
        a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim]
      """
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            length_dim = mtf.Dimension("length", sequence_length)

            mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
            if key not in features:
                if allow_missing:
                    return None
                else:
                    raise ValueError("feature not found %s - features %s = " %
                                     (key, features))
            tf.logging.info("Import feature %s: %s" % (key, features[key]))

            x = tf.to_int32(features[key])
            x = tf.reshape(
                x, [outer_batch_size, batch_size // outer_batch_size, -1])

            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if model_type == "lm":
            _, length_dim = targets.shape
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if mode == tf.estimator.ModeKeys.EVAL:
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            labels = lowering.export_to_tf_tensor(anon_targets)
            restore_hook = mtf.MtfRestoreHook(lowering)

            # Assigning to metric_names directly would make it a local variable
            # and shadow the outer value, so use a separate local with a
            # default metric.
            local_metric_names = metric_names or ["token_accuracy"]

            def metric_fn(labels, outputs):
                return get_metric_fns(local_metric_names, labels, outputs)

            eval_metrics = (metric_fn, [labels, outputs])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                # Unfortunately TPUEstimatorSpec requires us to provide a value for
                # loss when in EVAL mode. Since we are sampling or decoding from the
                # model, we don't have a loss to report.
                loss=tf.constant(0.),
                evaluation_hooks=[restore_hook],
                eval_metrics=eval_metrics)

        if isinstance(transformer_model, transformer.Unitransformer):
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation", True),
                position=_import_feature("targets_position", True),
            )
        elif isinstance(transformer_model, transformer.Bitransformer):
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation",
                                                    True),
                decoder_sequence_id=_import_feature("targets_segmentation",
                                                    True),
                encoder_position=_import_feature("inputs_position", True),
                decoder_position=_import_feature("targets_position", True),
            )
        else:
            raise ValueError("unrecognized class")

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer(
                learning_rate=learning_rate)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "step, tf_loss")

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=checkpoints_to_keep,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir, summarize_config=True)

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
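
A hedged illustration, not from the source, of the outer-batch reshape performed inside `_import_feature` above: a flat [batch_size, length] feature is split into [outer_batch_size, batch_size // outer_batch_size, length] before being imported as a fully replicated mtf.Tensor. The sizes are placeholders.

import tensorflow.compat.v1 as tf  # TF1-style API, matching these examples

outer_batch_size = 2
batch_size = 8
sequence_length = 16
flat = tf.zeros([batch_size, sequence_length], dtype=tf.int32)
stacked = tf.reshape(
    flat, [outer_batch_size, batch_size // outer_batch_size, -1])
# stacked has static shape (2, 4, 16), matching
# [outer_batch_dim, batch_dim, length_dim] in the helper above.
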
Example #19
0
  def beam_search(self,
                  inputs,
                  decode_length,
                  variable_dtype=mtf.VariableDType(tf.float32),
                  encoder_output=None,
                  encoder_sequence_id=None,
                  alpha=0.6,
                  shared_params=None,
                  encoder_layer_outputs=None):
    """Beam search.

    Args:
      inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim,
        length_dim].
      decode_length: an int32 mtf scalar.  Maximum decode length.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      alpha: a floating point value (length bonus)
      shared_params: an optional dictionary
      encoder_layer_outputs: optional - readonly list of tensor activations when
        decoding, one per each input layer + the embedding layer

    Returns:
      a Tensor with shape [<batch_dims>, length_dim] (the highest-scoring beam)
    """
    if not self.autoregressive:
      raise ValueError("must be autoregressive")

    batch_dims = inputs.shape.dims[:-2]
    if len(batch_dims) != 1:
      raise NotImplementedError(
          "beam search supports exactly one batch dimension.")
    beam_dim = inputs.shape.dims[-2]
    length_dim = inputs.shape.dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    context_first_part = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        autoregressive=self.autoregressive,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
      logits = self._call_internal(context_first_part, shifted_inputs)
    del logits
    # There are no partial targets.
    # Replace initial states by zeros to avoid computing them.
    initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
    constant_states = context_first_part.constant_states

    def logits_fn(step_num, ids, states):
      """logits_fn for mtf.beam_search.beam_search()."""
      context_incremental = Context(
          mesh=inputs.mesh,
          batch_dims=batch_dims + [beam_dim],
          length_dim=length_dim,
          model_dim=self.model_dim,
          variable_dtype=variable_dtype,
          mode="incremental",
          autoregressive=self.autoregressive,
          position=step_num,
          states=states,
          new_states=[],
          sequence_id=sequence_id,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          constant_states=constant_states,
          shared_params=shared_params,
          layout=self.layout,
          mesh_shape=self.mesh_shape,
          encoder_layer_outputs=encoder_layer_outputs)
      inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
      with tf.variable_scope(self.name, reuse=True):
        logits = self._call_internal(context_incremental, inputs_this_step)
      return mtf.to_float(logits), context_incremental.new_states

    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)
    return mtf.gather(
        beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
Example #20
0
  def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
      shifted_targets = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
      shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
      raise ValueError(
          "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          targets_segmentation, dtype=self.activation_dtype)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
      else:
        decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
      encoder_output = None
      encoder_decoder_attention_mask = None
    else:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
      # DECODER
      x = (mtf.gather(
          targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
           mtf.gather(
               positional_embedding_var, targets_position, self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # For some reason, the logits computation is extremely slow on TPU
      # in some cases where the batch size per core is 1.  Reshape the logits
      # and the targets to double the batch size and halve the length.
      # TODO(noam): file a bug.
      old_dims = self.batch_dims + [self.length_dim]
      new_dims = self.batch_dims[:-1] + [
          mtf.Dimension(self.batch_dims[-1].name,
                        self.batch_dims[-1].size * 2),
          mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
      x = mtf.reshape(x, new_dims + [self.model_dim])
      targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
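
A short numeric check, not from the source, of the label-smoothed targets built above, using placeholder values label_smoothing=0.1 and a targets vocabulary of size 4.

label_smoothing = 0.1
targets_vocab_size = 4
off_value = label_smoothing / targets_vocab_size      # 0.025
on_value = 1.0 - label_smoothing + off_value          # 0.925
# Every smoothed one-hot row still sums to 1.
row_sum = on_value + (targets_vocab_size - 1) * off_value
assert abs(row_sum - 1.0) < 1e-9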