Пример #1
0
def mnist_model(image, labels, mesh):
  """Two-hidden-layer MLP MNIST classifier built with Mesh TensorFlow.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when labels is None
  """
  batch = mtf.Dimension("batch", FLAGS.batch_size)
  height = mtf.Dimension("rows", 28)
  width = mtf.Dimension("cols", 28)
  num_classes = mtf.Dimension("classes", 10)

  pixels = mtf.import_tf_tensor(
      mesh,
      tf.reshape(image, [-1, 28, 28]),
      mtf.Shape([batch, height, width]))

  # The first dense layer reduces over both spatial axes at once, which
  # flattens the 28x28 grid without an explicit reshape.
  hidden = mtf_layers.dense(
      pixels, mtf.Dimension("hidden1", FLAGS.hidden_size),
      reduced_dims=[height, width], activation=mtf.relu, name="hidden1")
  hidden = mtf_layers.dense(
      hidden, mtf.Dimension("hidden2", FLAGS.hidden_size),
      activation=mtf.relu, name="hidden2")
  logits = mtf_layers.dense(hidden, num_classes, name="logits")

  # Inference mode: no labels, so no loss.
  if labels is None:
    return logits, None

  mtf_labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch]))
  per_example_loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(mtf_labels, num_classes), num_classes)
  return logits, mtf.reduce_mean(per_example_loss)
Пример #2
0
def mnist_model(image, labels, mesh):
    """Conv + MLP MNIST classifier built with Mesh TensorFlow.

    Two 3x3 convolutions (32 filters each) are followed by an average over
    the filter axis and two ReLU dense layers.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when labels is None
    """
    batch = mtf.Dimension("batch", FLAGS.batch_size)
    height = mtf.Dimension("rows", 28)
    width = mtf.Dimension("cols", 28)
    num_classes = mtf.Dimension("classes", 10)
    channel = mtf.Dimension("one_channel", 1)

    pixels = mtf.import_tf_tensor(mesh, tf.reshape(image, [-1, 28, 28]),
                                  mtf.Shape([batch, height, width]))
    # Append a single-channel axis so the tensor can be convolved.
    pixels = mtf.reshape(pixels, [batch, height, width, channel])

    # Convolutional layers, included to demonstrate that convolution works.
    # TODO(noam): get spatially-partitioned convolution working.
    kernel_h = mtf.Dimension("fh", 3)
    kernel_w = mtf.Dimension("fw", 3)
    conv1_out = mtf.Dimension("filters1", 32)
    conv2_out = mtf.Dimension("filters2", 32)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [kernel_h, kernel_w, channel, conv1_out])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [kernel_h, kernel_w, conv1_out, conv2_out])

    activations = mtf.relu(mtf.conv2d(pixels, kernel1))
    activations = mtf.relu(mtf.conv2d(activations, kernel2))
    # Collapse the filter axis by averaging over it.
    pooled = mtf.reduce_mean(activations, reduced_dim=conv2_out)

    # Fully-connected layers; the first one reduces over both spatial axes.
    hidden = mtf_layers.dense(pooled,
                              mtf.Dimension("hidden1", FLAGS.hidden_size),
                              reduced_dims=[height, width],
                              activation=mtf.relu,
                              name="hidden1")
    hidden = mtf_layers.dense(hidden,
                              mtf.Dimension("hidden2", FLAGS.hidden_size),
                              activation=mtf.relu,
                              name="hidden2")
    logits = mtf_layers.dense(hidden, num_classes, name="logits")

    # Inference mode: no labels, so no loss.
    if labels is None:
        return logits, None

    mtf_labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch]))
    per_example_loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(mtf_labels, num_classes), num_classes)
    return logits, mtf.reduce_mean(per_example_loss)
Пример #3
0
  def _mtf_model_fn(self, features, mesh):
    """Builds the transformer graph on `mesh` and returns (logits, loss).

    Runs an encoder when self.has_input is true; otherwise only the decoder
    is built (encoder_output and the enc-dec mask are None).

    Args:
      features: dict of tf.Tensors. Always reads "targets"; reads "inputs"
        when self.has_input is true; uses "targets_segmentation",
        "targets_position", "inputs_segmentation" and "inputs_position" when
        present (the "packed" dataset case). The caller's dict is not
        modified: padded values are stored in a shallow copy.
      mesh: the mtf.Mesh the tensors are imported into.

    Returns:
      logits: mtf.Tensor, mtf.matmul of the decoder output and softmax_var.
      loss: mtf.Tensor, the mean over all positions of the label-smoothed
        softmax cross-entropy times mtf_layers.weights_nonzero(targets),
        plus every loss appended to extra_losses by the layer stacks.
    """
    features = copy.copy(features)
    hparams = self._hparams
    targets = tf.to_int32(features["targets"])
    # NOTE(review): squeezes axes 2 and 3, so this assumes targets of rank 4
    # (presumably [batch, length, 1, 1]) whenever rank > 2 — TODO confirm.
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    # Pads axis 1 of x with zeros up to hparams.max_length, then pins the
    # static shape to [batch_size, max_length].
    # NOTE(review): if x is already longer than max_length, extra_length is
    # negative; tf.pad is expected to reject that — confirm inputs are capped.
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    # Decoder inputs; presumably targets shifted right by one position.
    shifted_targets = common_layers.shift_right_2d(targets)

    # Move the padded tf.Tensors into mtf; _import_to_batch_by_length is a
    # helper on this model not shown here.
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    shifted_targets = self._import_to_batch_by_length(
        shifted_targets, "shifted_targets", mesh, hparams)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      # Sum of the autoregressive mask and the same-segment mask.
      decoder_self_attention_mask = (
          mtf_layers.attention_mask_autoregressive(
              targets_position, dtype=self.activation_dtype) +
          mtf_layers.attention_mask_same_segment(
              targets_segmentation, dtype=self.activation_dtype))
    else:
      # Unpacked: positions come from mtf.range over self.length_dim.
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)

    # noise_shape omits the length dimension, so presumably one dropout mask
    # is shared across all positions of an example.
    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

    # Auxiliary losses appended by _layer_stack; added to the main loss below.
    extra_losses = []
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf_layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
        # NOTE(review): targets_segmentation is only bound when
        # "targets_segmentation" is in features; a feature dict with
        # inputs_segmentation but no targets_segmentation raises NameError here.
        encoder_decoder_attention_mask = (
            mtf_layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf_layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
        # The same padding mask is reused for encoder-decoder attention.
        encoder_decoder_attention_mask = encoder_self_attention_mask

      # Token embedding plus learned positional embedding.
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.num_encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)
      # Rename the length dim so the encoder output's sequence axis is
      # distinct from the decoder's own length dim.
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
    else:
      encoder_output = None
      encoder_decoder_attention_mask = None

    # DECODER
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)

    # Decoder
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x,
          hparams.num_decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)
    logits = mtf.matmul(x, softmax_var)
    # Label smoothing: on_value + (V - 1) * off_value == 1, so each
    # soft-target row sums to one.
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    # Zero out the loss at positions where weights_nonzero(targets) is 0
    # (presumably padding), then average over all positions.
    weights = mtf_layers.weights_nonzero(
        targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    return logits, loss
Пример #4
0
    def mtf_model_fn(self, features, mesh):
        """Build the Mesh-TensorFlow blocked-convolution image classifier.

        The input image is reshaped into row/column blocks, then passed through:
        an initial 7x7 ``conv2d_with_blocks``, ``batch_norm_relu``,
        ``hparams.num_layers`` repetitions of three ``block_layer`` stages,
        and two dense layers.

        Args:
          features: dict of tf.Tensors; reads "inputs" (image batch) and
            "targets" (class labels).
          mesh: a mtf.Mesh.

        Returns:
          logits: mtf.Tensor with shape [batch, one_channel, classes].
          loss: scalar mtf.Tensor, the mean softmax cross-entropy.
        """
        features = copy.copy(features)
        tf.logging.info("features = %s", features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

        # Per-block spatial extents. They must equal the corresponding axes
        # of the tf.reshape below, so derive them from hparams instead of
        # hard-coding 32 and 96 (which only agreed with the reshape when
        # rows_size // row_blocks == 32 and
        # num_channels * cols_size // col_blocks == 96).
        # Channels are folded into the column axis.
        rows_per_block = hparams.rows_size // hparams.row_blocks
        cols_per_block = (hparams.num_channels * hparams.cols_size //
                          hparams.col_blocks)

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        filter_h_dim = mtf.Dimension("filter_height", 7)
        filter_w_dim = mtf.Dimension("filter_width", 7)
        filters = mtf.Dimension("filters", hparams.filter_sizes[0])
        rows_dim = mtf.Dimension("rows_size", rows_per_block)
        cols_dim = mtf.Dimension("cols_size", cols_per_block)
        row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
        col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
        classes_dim = mtf.Dimension("classes", 10)
        one_channel_dim = mtf.Dimension("one_channel", 1)

        inputs = features["inputs"]
        x = mtf.import_tf_tensor(
            mesh,
            tf.reshape(inputs, [
                hparams.batch_size, hparams.row_blocks, rows_per_block,
                hparams.col_blocks, cols_per_block, 1
            ]),
            mtf.Shape([
                batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
                one_channel_dim
            ]))
        # Move both block axes ahead of the within-block spatial axes.
        x = mtf.transpose(x, [
            batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
            one_channel_dim
        ])

        x = mtf.to_float(x)
        initial_filters = mtf.get_variable(
            mesh, "init_filters",
            mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
        x = mtf.conv2d_with_blocks(x,
                                   initial_filters,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   h_blocks_dim=None,
                                   w_blocks_dim=col_blocks_dim)

        x = batch_norm_relu(x, is_training)

        # Conv blocks: each layer runs three block_layer stages. Stage i uses
        # filter_sizes[i], layer_sizes[i] and the strides below; stages 2
        # and 3 use stride 2 in both spatial axes.
        stage_strides = ((1, 1, 1, 1), (1, 2, 2, 1), (1, 2, 2, 1))
        for layer in range(hparams.num_layers):
            layer_name = "block_layer_%d" % layer
            with tf.variable_scope(layer_name):
                for stage, strides in enumerate(stage_strides):
                    # Residual block layer
                    x = block_layer(inputs=x,
                                    filters=hparams.filter_sizes[stage],
                                    blocks=hparams.layer_sizes[stage],
                                    # Fresh list per call, as before.
                                    strides=list(strides),
                                    is_training=is_training,
                                    name="block_layer%d" % (stage + 1),
                                    row_blocks_dim=None,
                                    col_blocks_dim=None)

        # Calculate the logits and loss.
        out = x
        # Reduce over the last five axes, keeping only the batch axis
        # (assumes `out` is rank 6 like the imported input — TODO confirm).
        outputs = mtf_layers.dense(out,
                                   hidden_dim,
                                   reduced_dims=out.shape.dims[-5:],
                                   activation=mtf.relu,
                                   name="dense")

        # We assume fixed vocab size for targets
        labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [hparams.batch_size]),
                                      mtf.Shape([batch_dim]))

        logits = mtf_layers.dense(outputs, classes_dim, name="logits")
        soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, classes_dim)

        # Reshape logits so it doesn't break inside t2t.
        logits = mtf.reshape(
            logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
        loss = mtf.reduce_mean(loss)
        return logits, loss
Пример #5
0
    def mtf_model_fn(self, features, mesh):
        """Build the Mesh-TensorFlow image-transformer decoder graph.

        The flattening and embedding steps are:
        - The target image is flattened into a 1-D sequence of length
          img_len * img_len * num_channels and shifted right.
        - The shifted sequence is embedded and 2-D positional embeddings are
          added.
        - When conditional, an input embedding is added as well.

        The result then goes through hparams.num_decoder_layers layers. Each
        layer has:
        - a layer-normed local self-attention sublayer;
        - a layer-normed dense-relu-dense sublayer.

        Each sublayer output is added back to x after dropout.

        Args:
          features: dict of tf.Tensors; reads "targets", and "inputs" when
            self.has_input and not hparams.unconditional.
          mesh: a mtf.Mesh.

        Returns:
          logits: mtf.Tensor over the 256-entry "output_vocab" dimension.
          loss: scalar mtf.Tensor, the mean softmax cross-entropy.
        """
        features = copy.copy(features)
        tf.logging.info("features = %s", features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()

        targets = tf.to_int32(features["targets"])

        # Image preprocessing, reshape into a 1D sequence and shift right.
        length = hparams.img_len * hparams.img_len * hparams.num_channels
        targets = tf.reshape(targets, [hparams.batch_size, length])
        shifted_targets = common_layers.shift_right_2d(targets)

        # Declare all the dimensions
        model_dim = mtf.Dimension("d_model", hparams.hidden_size)
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        length_dim = mtf.Dimension("length", length)
        max_length_dim = mtf.Dimension("max_length", hparams.max_length)
        filter_dim = mtf.Dimension("d_ff", hparams.d_ff)
        kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
        heads = mtf.Dimension("heads", hparams.num_heads)

        def import_to_batch_by_length(x, name):
            return mtf.import_tf_tensor(mesh,
                                        x,
                                        mtf.Shape([batch_dim, length_dim]),
                                        name=name)

        # noise_shape omits length_dim, so presumably one dropout mask is
        # shared across all positions of an example.
        def layer_prepostprocess_dropout(x):
            return mtf.dropout(x,
                               keep_prob=1.0 -
                               hparams.layer_prepostprocess_dropout,
                               noise_shape=mtf.Shape([batch_dim, model_dim]))

        targets = import_to_batch_by_length(targets, "targets")
        shifted_targets = import_to_batch_by_length(shifted_targets,
                                                    "shifted_targets")

        extra_losses = []

        # Create targets content and position embeddings.
        # We assume fixed vocab size for targets: 256 intensity values per
        # channel. (The previous read of the problem's target_modality vocab
        # size was dead: it was always overwritten here.)
        targets_vocab_size = 256 * hparams.num_channels
        targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
        outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

        # Create embedding var for targets and positions and do a gather.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([targets_vocab_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        x = mtf.gather(targets_embedding_var, shifted_targets,
                       targets_vocab_dim)
        # Add positional embeddings
        x += mtf.reshape(
            self.create_positional_emb_2d(targets, max_length_dim, model_dim),
            [length_dim, model_dim])

        # If conditional and input is given, add the input embedding to the target.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            vocab_size = hparams.num_classes
            # NOTE(review): this reuses the dimension name "vocab" with a
            # different size than targets_vocab_dim above — confirm mtf
            # accepts that.
            inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            # NOTE(review): importing with [batch, length] assumes inputs has
            # the same length as the flattened targets — TODO confirm.
            inputs = import_to_batch_by_length(inputs, "inputs")

            # Input embeddings
            inputs_embedding_var = mtf_layers.embedding(
                mesh,
                "input_embedding",
                mtf.Shape([inputs_vocab_dim, model_dim]),
                activation_dtype=activation_dtype)
            inputs_emb = mtf.gather(inputs_embedding_var, inputs,
                                    inputs_vocab_dim)
            x += inputs_emb

        # Image Transformer Decoder
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_decoder_layers):
            layer_name = "decoder_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Self attention layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.masked_local_attention_1d(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_self_att"),
                        None,
                        kv_channels,
                        heads,
                        block_length=hparams.block_length,
                        name="self_att"))
                # ffn layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.dense_relu_dense(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_ffn"),
                        filter_dim,
                        hparams.dropout,
                        dropout_broadcast_dims=[length_dim]))

        x = mtf_layers.layer_norm(x,
                                  model_dim,
                                  name="decoder_final_layer_norm")

        # Calculate the logits and loss.
        logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
        soft_targets = mtf.one_hot(targets,
                                   outputs_vocab_dim,
                                   dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, outputs_vocab_dim)

        loss = mtf.reduce_mean(loss)
        for l in extra_losses:
            loss += l
        return logits, loss
Пример #6
0
def mnist_model(image, labels, mesh):
    """Block-partitioned conv + MLP MNIST classifier in Mesh TensorFlow.

    The 28x28 image is cut into a 4x4 grid of 7x7 blocks.
    ``mtf.conv2d_with_blocks`` is then given both block dimensions.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when labels is None
    """
    batch = mtf.Dimension("batch", FLAGS.batch_size)
    h_blocks = mtf.Dimension("row_blocks", 4)
    w_blocks = mtf.Dimension("col_blocks", 4)
    block_h = mtf.Dimension("rows_size", 7)
    block_w = mtf.Dimension("cols_size", 7)
    num_classes = mtf.Dimension("classes", 10)
    channel = mtf.Dimension("one_channel", 1)

    # 28 = 4 blocks * 7 pixels along each spatial axis.
    blocked_image = tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1])
    x = mtf.import_tf_tensor(
        mesh, blocked_image,
        mtf.Shape([batch, h_blocks, block_h, w_blocks, block_w, channel]))
    # Put both block axes ahead of the within-block spatial axes.
    x = mtf.transpose(x, [batch, h_blocks, w_blocks, block_h, block_w, channel])

    def blocked_conv_relu(inputs, kernel):
        # Stride-1 SAME convolution over the blocked layout, then ReLU.
        return mtf.relu(
            mtf.conv2d_with_blocks(inputs,
                                   kernel,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   h_blocks_dim=h_blocks,
                                   w_blocks_dim=w_blocks))

    # Convolutional layers, included to demonstrate that convolution works.
    kernel_h = mtf.Dimension("fh", 9)
    kernel_w = mtf.Dimension("fw", 9)
    conv1_out = mtf.Dimension("filters1", 16)
    conv2_out = mtf.Dimension("filters2", 16)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [kernel_h, kernel_w, channel, conv1_out])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [kernel_h, kernel_w, conv1_out, conv2_out])

    activations = blocked_conv_relu(x, kernel1)
    activations = blocked_conv_relu(activations, kernel2)
    # Collapse the filter axis by averaging over it.
    pooled = mtf.reduce_mean(activations, reduced_dim=conv2_out)

    # Fully-connected layers; the first one reduces over the last four axes.
    hidden = mtf_layers.dense(pooled,
                              mtf.Dimension("hidden1", FLAGS.hidden_size),
                              reduced_dims=pooled.shape.dims[-4:],
                              activation=mtf.relu,
                              name="hidden1")
    hidden = mtf_layers.dense(hidden,
                              mtf.Dimension("hidden2", FLAGS.hidden_size),
                              activation=mtf.relu,
                              name="hidden2")
    logits = mtf_layers.dense(hidden, num_classes, name="logits")

    # Inference mode: no labels, so no loss.
    if labels is None:
        return logits, None

    mtf_labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch]))
    per_example_loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(mtf_labels, num_classes), num_classes)
    return logits, mtf.reduce_mean(per_example_loss)