def testWeightsNonzero(self):
        """mtf_layers.weights_nonzero must match common_layers.weights_nonzero.

        The same constant tensor is fed through the Mesh-TensorFlow layer
        (lowered back to a tf.Tensor) and through the plain TensorFlow layer,
        and the two evaluated results are compared element-wise.
        """
        inputs = tf.constant([[3, 1, 0], [1, 0, 0]])
        static_shape = inputs.shape.as_list()

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        # One named mtf dimension per axis of the constant input.
        mtf_shape = mtf.Shape([
            mtf.Dimension("batch", static_shape[0]),
            mtf.Dimension("channels", static_shape[1]),
        ])
        mtf_result = mtf_layers.weights_nonzero(
            mtf.import_tf_tensor(mesh, inputs, shape=mtf_shape))

        # Lower the mtf graph with an empty mesh shape/layout and a single
        # device entry so the result can be exported as a regular tf.Tensor.
        placement = placement_mesh_impl.PlacementMeshImpl(
            shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: placement})
        lowered_result = lowering.export_to_tf_tensor(mtf_result)

        reference_result = common_layers.weights_nonzero(inputs)

        # Run the master-to-slice copy op before fetching any outputs.
        self.evaluate(lowering.copy_masters_to_slices())
        got, want = self.evaluate([lowered_result, reference_result])

        self.assertAllEqual(got, want)
Exemple #2
0
  def _mtf_model_fn(self, features, mesh):
    """Builds the Mesh-TensorFlow training graph and returns logits and loss.

    The method pads the TF feature tensors to hparams.max_length, imports them
    onto `mesh`, optionally runs an encoder stack (when self.has_input is
    true), runs the decoder stack on the right-shifted targets, and computes a
    label-smoothed cross-entropy loss.

    Args:
      features: dict of TF tensors. "targets" is always read; "inputs" is read
        when self.has_input is true. The optional keys "targets_segmentation",
        "targets_position", "inputs_segmentation" and "inputs_position" are
        used when present. The dict is shallow-copied, so the padded entries
        written here do not replace entries in the caller's dict.
      mesh: the mesh handed to the import, range and variable helpers
        (presumably an mtf.Mesh -- confirm against callers).

    Returns:
      A (logits, loss) tuple. `logits` is the product of the decoder output
      with the softmax variable; `loss` is the mean of the weighted
      cross entropy plus every entry of `extra_losses`.
    """
    features = copy.copy(features)
    hparams = self._hparams
    targets = tf.to_int32(features["targets"])
    # Targets with more than two axes have axes 2 and 3 squeezed away.
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # Pad axis 1 up to hparams.max_length (tf.pad defaults, i.e. zero fill),
    # then reshape to the static shape [batch_size, max_length].
    # Assumes axis 1 is never longer than hparams.max_length -- TODO confirm.
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    # The optional segmentation/position features get the same padding so
    # that they line up with the padded targets/inputs.
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    # The shifted copy is what the decoder embeds as its input; the unshifted
    # `targets` are used further down for the loss.
    shifted_targets = common_layers.shift_right_2d(targets)

    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    shifted_targets = self._import_to_batch_by_length(
        shifted_targets, "shifted_targets", mesh, hparams)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      # The decoder mask is the sum of the autoregressive mask and the
      # same-segment mask.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = (
          mtf_layers.attention_mask_autoregressive(
              targets_position, dtype=self.activation_dtype) +
          mtf_layers.attention_mask_same_segment(
              targets_segmentation, dtype=self.activation_dtype))
    else:
      # No segmentation feature: positions are a plain range over the length
      # dimension and only the autoregressive mask is applied.
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)

    # Dropout helper shared by the encoder and decoder embeddings. The noise
    # shape lists only the batch and model dimensions.
    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

    # Collects additional loss terms; passed to both layer stacks and summed
    # into the final loss.
    extra_losses = []
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
      # ENCODER
      # Unlike targets, inputs are squeezed on axes 2 and 3 unconditionally.
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf_layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
        # NOTE(review): targets_segmentation is only bound above when
        # "targets_segmentation" is in features; if inputs are packed but
        # targets are not, this reference is unbound. Confirm that packed
        # datasets always provide both keys.
        encoder_decoder_attention_mask = (
            mtf_layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf_layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
        # Without segmentation the same padding mask is reused for
        # encoder-decoder attention.
        encoder_decoder_attention_mask = encoder_self_attention_mask

      # Encoder input: token embedding plus positional embedding.
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.num_encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)
      # Rename the length dimension to the memory-length dimension before the
      # encoder output is handed to the decoder stack.
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
    else:
      # No inputs: the decoder stack receives no encoder output or mask.
      encoder_output = None
      encoder_decoder_attention_mask = None

    # DECODER
    # Decoder input: embedding of the shifted targets plus positional
    # embedding.
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)

    # Decoder layer stack, created under its own variable scope.
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x,
          hparams.num_decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)
    logits = mtf.matmul(x, softmax_var)
    # Label smoothing: every class gets label_smoothing / vocab_size, and the
    # true class additionally gets 1 - label_smoothing.
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    # Presumably weights are zero where the target id is 0 (padding) --
    # confirm against mtf_layers.weights_nonzero.
    weights = mtf_layers.weights_nonzero(
        targets, dtype=self.activation_dtype)
    # The mean is taken over loss * weights directly; it is not divided by
    # the sum of the weights.
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    return logits, loss