Example #1
  def testDenseReluDense(self):
    batch = 2
    channels = 3
    hidden = 5
    inputs = tf.random_normal([batch, channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)
    hidden_dim = mtf.Dimension("hidden", hidden)

    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
    mtf_outputs = mtf_layers.dense_relu_dense(mtf_inputs,
                                              hidden_channels=hidden_dim)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      sess.run(tf_group)
      actual = sess.run(actual_outputs)

    self.assertEqual(actual.shape, inputs.shape)
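For reference, here is a minimal sketch of the computation dense_relu_dense performs: a dense projection up to the hidden dimension followed by a ReLU, then a dense projection back to the input's final dimension. This is an illustration built from the mtf_layers.dense and mtf.relu primitives used elsewhere on this page, not the library's actual implementation; the function name and the variable-scope names "wi" and "wo" are placeholders.

def dense_relu_dense_sketch(x, hidden_dim):
  io_dim = x.shape.dims[-1]  # project back to the input's last dimension
  h = mtf.relu(mtf_layers.dense(x, hidden_dim, name="wi"))  # expand + ReLU
  return mtf_layers.dense(h, io_dim, name="wo")  # contract back down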
Example #2
  def _feedforward_layer(self, x, losses=None):
    """Feed-forward layer.

    Args:
      x: a mtf.Tensor with shape [batch_dim, length_dim, model_dim]
      losses: an optional list to which auxiliary losses are appended
    Returns:
      a mtf.Tensor with shape [batch_dim, length_dim, model_dim]
    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    feedforward_layer = hparams.feedforward_layer
    if feedforward_layer == "dense_relu_dense":
      return mtf_layers.dense_relu_dense(
          x, self.feedforward_dim, dropout=hparams.relu_dropout,
          dropout_broadcast_dims=[self.length_dim])
    elif feedforward_layer == "moe":
      overhead = (
          hparams.moe_overhead_train
          if hparams.mode == tf.estimator.ModeKeys.TRAIN else
          hparams.moe_overhead_eval)
      output, loss = mtf_layers.moe_v0(
          x,
          self.feedforward_dim,
          self.model_dim,
          self.experts_dim,
          loss_coef=hparams.moe_loss_coef,
          overhead=overhead)
      if losses is not None:
        losses.append(loss)
      return output
    else:
      raise ValueError(
          "hparams.feedforward_layer not recognized %s" % feedforward_layer)
Example #3
    def _feedforward_layer(self, x, losses=None):
        """Feed-forward layer.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      losses: a list to be appended-to
    Returns:
      a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    Raises:
      ValueError: if hparams make no sense
    """
        hparams = self._hparams
        feedforward_layer = hparams.feedforward_layer
        if feedforward_layer == "dense_relu_dense":
            return mtf_layers.dense_relu_dense(
                x,
                self.feedforward_dim,
                dropout=hparams.relu_dropout,
                dropout_broadcast_dims=[self.length_dim])
        elif feedforward_layer == "moe":
            output, loss = moe.transformer_moe_layer_v1(
                x, self.model_dim, hparams,
                hparams.mode == tf.estimator.ModeKeys.TRAIN)
        elif feedforward_layer == "hmoe":
            output, loss = moe.transformer_moe_layer_v2(
                x, self.model_dim, hparams,
                hparams.mode == tf.estimator.ModeKeys.TRAIN)
            if losses is not None:
                losses.append(loss)
            return output
        else:
            raise ValueError("hparams.feedforward_layer not recognized %s" %
                             feedforward_layer)
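The branch taken above is driven entirely by hparams. A sketch of the fields involved, with illustrative values only (the real defaults live in the model's hparams definition, not shown here):

hparams.feedforward_layer = "moe"  # one of "dense_relu_dense", "moe", "hmoe"
hparams.relu_dropout = 0.1  # used only by the dense_relu_dense branch
hparams.mode = tf.estimator.ModeKeys.TRAIN  # train vs. eval behavior in the MoE layers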
Example #4
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()

        # We assume fixed vocab size for targets
        targets_vocab_size = self._problem_hparams.target_modality._vocab_size  # pylint: disable=protected-access
        targets = tf.to_int32(features["targets"])

        # Image preprocessing: reshape into a 1D sequence and shift right.
        length = hparams.img_len * hparams.img_len * hparams.num_channels
        targets = tf.reshape(targets, [hparams.batch_size, length])
        shifted_targets = common_layers.shift_right_2d(targets)

        # Declare all the dimensions
        model_dim = mtf.Dimension("d_model", hparams.hidden_size)
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        length_dim = mtf.Dimension("length", length)
        max_length_dim = mtf.Dimension("max_length", hparams.max_length)
        filter_dim = mtf.Dimension("d_ff", hparams.d_ff)
        kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
        heads = mtf.Dimension("heads", hparams.num_heads)

        def import_to_batch_by_length(x, name):
            return mtf.import_tf_tensor(mesh,
                                        x,
                                        mtf.Shape([batch_dim, length_dim]),
                                        name=name)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(x,
                               keep_prob=1.0 -
                               hparams.layer_prepostprocess_dropout,
                               noise_shape=mtf.Shape([batch_dim, model_dim]))

        targets = import_to_batch_by_length(targets, "targets")
        shifted_targets = import_to_batch_by_length(shifted_targets,
                                                    "shifted_targets")

        extra_losses = []

        # Create targets content and position embeddings.
        targets_vocab_size = 256 * hparams.num_channels
        targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
        outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

        # Create embedding var for targets and positions and do a gather.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([targets_vocab_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        x = mtf.gather(targets_embedding_var, shifted_targets,
                       targets_vocab_dim)
        # Add positional embeddings
        x += mtf.reshape(
            self.create_positional_emb_2d(targets, max_length_dim, model_dim),
            [length_dim, model_dim])

        # If the model is conditional and an input is given, add the input embedding to the target.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            vocab_size = hparams.num_classes
            inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = import_to_batch_by_length(inputs, "inputs")

            # Input embeddings
            inputs_embedding_var = mtf_layers.embedding(
                mesh,
                "input_embedding",
                mtf.Shape([inputs_vocab_dim, model_dim]),
                activation_dtype=activation_dtype)
            inputs_emb = mtf.gather(inputs_embedding_var, inputs,
                                    inputs_vocab_dim)
            x += inputs_emb

        # Image Transformer decoder:
        # [self-attention -> FFN -> residual + dropout] x n
        for layer in range(hparams.num_decoder_layers):
            layer_name = "decoder_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Self attention layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.masked_local_attention_1d(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_self_att"),
                        None,
                        kv_channels,
                        heads,
                        block_length=hparams.block_length,
                        name="self_att"))
                # ffn layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.dense_relu_dense(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_ffn"),
                        filter_dim,
                        hparams.dropout,
                        dropout_broadcast_dims=[length_dim]))

        x = mtf_layers.layer_norm(x,
                                  model_dim,
                                  name="decoder_final_layer_norm")

        # Calculate the logits and loss.
        logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
        soft_targets = mtf.one_hot(targets,
                                   outputs_vocab_dim,
                                   dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, outputs_vocab_dim)

        loss = mtf.reduce_mean(loss)
        for l in extra_losses:
            loss += l
        return logits, loss
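mtf_model_fn returns Mesh TensorFlow tensors, so before they can be evaluated the graph must be lowered to ordinary TF tensors, mirroring the pattern in Example #1. A sketch with a trivial single-device placement (the empty mesh shape and layout are placeholders):

mesh_impl = placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
tf_logits = lowering.export_to_tf_tensor(logits)
tf_loss = lowering.export_to_tf_tensor(loss)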