Example #1
def projection_shortcut(inputs, kernel):
    """Project identity branch."""
    inputs = mtf.conv2d_with_blocks(
        inputs,
        kernel,
        strides=strides,
        padding="SAME",
        h_blocks_dim=None,
        w_blocks_dim=col_blocks_dim)
    return batch_norm_relu(inputs, is_training, relu=False)
Example #2
def projection_shortcut(inputs, kernel):
  """Project identity branch."""
  inputs = mtf.conv2d_with_blocks(
      inputs,
      kernel,
      strides=strides,
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)
  return batch_norm_relu(
      inputs, is_training, relu=False)
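
Both versions of projection_shortcut above close over strides, col_blocks_dim and is_training from an enclosing scope, and call a batch_norm_relu helper defined elsewhere on this page. A minimal sketch of that enclosing context, assuming "import mesh_tensorflow as mtf"; the wrapper name project_identity and its argument list are hypothetical, not from the original source:

def project_identity(inputs, filters, is_training, strides, col_blocks_dim):
  """Hypothetical wrapper showing the scope projection_shortcut closes over."""
  filters_dim = mtf.Dimension("filtersp", filters)

  def projection_shortcut(inputs, kernel):
    """Project identity branch (uses strides / col_blocks_dim from above)."""
    inputs = mtf.conv2d_with_blocks(
        inputs, kernel, strides=strides, padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)
    return batch_norm_relu(inputs, is_training, relu=False)

  # 1x1 kernel mapping the input channel dimension to the projected one.
  kernel = mtf.get_variable(
      inputs.mesh, "kernel",
      mtf.Shape([mtf.Dimension("filter_height", 1),
                 mtf.Dimension("filter_width", 1),
                 inputs.shape.dims[-1], filters_dim]))
  return projection_shortcut(inputs, kernel)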
Example #3
def mnist_model(image, labels, mesh):
    """The model.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape []
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # add some convolutional layers to demonstrate that convolution works.
    fh_dim = mtf.Dimension("fh", 9)
    fw_dim = mtf.Dimension("fw", 9)
    filters1_dim = mtf.Dimension("filters1", 16)
    filters2_dim = mtf.Dimension("filters2", 16)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    f1 = mtf.relu(
        mtf.conv2d_with_blocks(x,
                               kernel1,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))
    f2 = mtf.relu(
        mtf.conv2d_with_blocks(f1,
                               kernel2,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    #hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-4:],
                          activation=mtf.relu,
                          name="hidden1")
    #h2 = mtf.layers.dense(
    #    h1, hidden_dim2,
    #    activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
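
mnist_model only builds the Mesh TensorFlow computation; to run it, the mtf graph still has to be lowered onto a concrete mesh. A minimal sketch of that step, assuming a single-process placement mesh; the mesh-shape and layout strings below are illustrative, not taken from the original source:

import mesh_tensorflow as mtf

def lower_mnist_model(image, labels):
  """Sketch: lower mnist_model onto a two-device placement mesh."""
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  logits, loss = mnist_model(image, labels, mesh)

  # Split the "batch" tensor dimension over a 2-way mesh dimension "all".
  mesh_shape = mtf.convert_to_shape("all:2")
  layout_rules = mtf.convert_to_layout_rules("batch:all")
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, [""] * mesh_shape.size)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  return (lowering.export_to_tf_tensor(logits),
          lowering.export_to_tf_tensor(loss))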
Example #4
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck block variant for residual networks with BN after convolutions.

    Args:
      inputs: a `mtf.Tensor` of shape
          `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
      filters: `int` number of filters for the first two convolutions. Note
          that the third and final convolution will use 4 times as many filters.
      is_training: `bool` for whether the model is in training mode.
      strides: `int` block stride. If greater than 1, this block will ultimately
          downsample the input.
      projection_shortcut: `function` to use for projection shortcuts (typically
          a 1x1 convolution to match the filter dimensions). If None, no
          projection is used and the input is passed unchanged through the
          shortcut connection.
      row_blocks_dim: a mtf.Dimension, the row dimension that is
          spatially partitioned along the mesh axis.
      col_blocks_dim: a mtf.Dimension, the column dimension that is
          spatially partitioned along the mesh axis.

    Returns:
      The output `Tensor` of the block.
    """
    shortcut = inputs

    filter_h_dim = mtf.Dimension("filter_height", 3)
    filter_w_dim = mtf.Dimension("filter_width", 3)
    one_h_dim = mtf.Dimension("filter_height", 1)
    one_w_dim = mtf.Dimension("filter_width", 1)

    if projection_shortcut is not None:
        filters_dim = mtf.Dimension("filtersp", filters)
        kernel = mtf.get_variable(
            inputs.mesh, "kernel",
            mtf.Shape(
                [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
        shortcut = projection_shortcut(inputs, kernel)

    # First conv block
    filters1_dim = mtf.Dimension("filters1", filters)
    kernel1 = mtf.get_variable(
        inputs.mesh, "kernel1",
        mtf.Shape([one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel1,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    # TODO(nikip): Add Dropout?
    inputs = batch_norm_relu(inputs, is_training)

    # Second conv block
    filters2_dim = mtf.Dimension("filters2", filters)
    kernel2 = mtf.get_variable(
        inputs.mesh, "kernel2",
        mtf.Shape([filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel2,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=row_blocks_dim,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training)

    # Third wide conv filter block
    filters3_dim = mtf.Dimension("filters3", filters)
    filters3_kernel = mtf.get_variable(
        inputs.mesh, "wide_kernel",
        mtf.Shape([one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    filters3_kernel,
                                    strides,
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training, relu=False)

    # TODO(nikip): Maybe add residual with a projection?
    return mtf.relu(inputs + mtf.rename_dimension(
        shortcut, shortcut.shape.dims[-1].name, inputs.shape.dims[-1].name))
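
In the mtf_model_fn examples further down, blocks like this one are stacked by a block_layer helper that is not shown on this page. The following is a rough sketch, under stated assumptions, of how such a helper could combine bottleneck_block with a projection shortcut on the first block; the body is an assumption, not the original implementation, and it reuses the mtf/tf imports and batch_norm_relu helper from the surrounding examples:

def block_layer(inputs, filters, blocks, strides, is_training, name,
                row_blocks_dim=None, col_blocks_dim=None):
  """Hypothetical sketch: stack `blocks` bottleneck blocks under one scope."""
  with tf.variable_scope(name):
    def projection_shortcut(inputs, kernel):
      inputs = mtf.conv2d_with_blocks(
          inputs, kernel, strides=strides, padding="SAME",
          h_blocks_dim=None, w_blocks_dim=col_blocks_dim)
      return batch_norm_relu(inputs, is_training, relu=False)

    # Only the first block projects the shortcut (and may be strided).
    inputs = bottleneck_block(
        inputs, filters, is_training, strides,
        projection_shortcut=projection_shortcut,
        row_blocks_dim=row_blocks_dim, col_blocks_dim=col_blocks_dim)
    for i in range(1, blocks):
      with tf.variable_scope("block_%d" % i):
        inputs = bottleneck_block(
            inputs, filters, is_training, strides=[1, 1, 1, 1],
            row_blocks_dim=row_blocks_dim, col_blocks_dim=col_blocks_dim)
    return inputs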
Example #5
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        filter_h_dim = mtf.Dimension("filter_height", 7)
        filter_w_dim = mtf.Dimension("filter_width", 7)
        filters = mtf.Dimension("filters", hparams.filter_sizes[0])
        rows_dim = mtf.Dimension("rows_size", 32)
        cols_dim = mtf.Dimension("cols_size", 96)
        row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
        col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
        classes_dim = mtf.Dimension("classes", 10)
        one_channel_dim = mtf.Dimension("one_channel", 1)

        inputs = features["inputs"]
        x = mtf.import_tf_tensor(
            mesh,
            tf.reshape(inputs, [
                hparams.batch_size, hparams.row_blocks,
                hparams.rows_size // hparams.row_blocks, hparams.col_blocks,
                hparams.num_channels * hparams.cols_size // hparams.col_blocks,
                1
            ]),
            mtf.Shape([
                batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
                one_channel_dim
            ]))
        x = mtf.transpose(x, [
            batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
            one_channel_dim
        ])

        x = mtf.to_float(x)
        initial_filters = mtf.get_variable(
            mesh, "init_filters",
            mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
        x = mtf.conv2d_with_blocks(x,
                                   initial_filters,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   h_blocks_dim=None,
                                   w_blocks_dim=col_blocks_dim)

        x = batch_norm_relu(x, is_training)

        # Conv blocks
        # [block layer - strided block layer - strided block layer] x n
        for layer in range(hparams.num_layers):
            layer_name = "block_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Residual block layer
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[0],
                                blocks=hparams.layer_sizes[0],
                                strides=[1, 1, 1, 1],
                                is_training=is_training,
                                name="block_layer1",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[1],
                                blocks=hparams.layer_sizes[1],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer2",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[2],
                                blocks=hparams.layer_sizes[2],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer3",
                                row_blocks_dim=None,
                                col_blocks_dim=None)

        # Calculate the logits and loss.
        out = x
        outputs = mtf.layers.dense(out,
                                   hidden_dim,
                                   reduced_dims=out.shape.dims[-5:],
                                   activation=mtf.relu,
                                   name="dense")

        # We assume fixed vocab size for targets
        labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [hparams.batch_size]),
                                      mtf.Shape([batch_dim]))

        logits = mtf.layers.dense(outputs, classes_dim, name="logits")
        soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, classes_dim)

        # Reshape logits so it doesn't break inside t2t.
        logits = mtf.reshape(
            logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
        loss = mtf.reduce_mean(loss)
        return logits, loss
Example #6
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
  """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first two convolutions. Note
        that the third and final convolution will use 4 times as many filters.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). If None, no
        projection is used and the input is passed unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, the row dimension that is
        spatially partitioned along the mesh axis.
    col_blocks_dim: a mtf.Dimension, the column dimension that is
        spatially partitioned along the mesh axis.

  Returns:
    The output `Tensor` of the block.
  """
  shortcut = inputs

  filter_h_dim = mtf.Dimension("filter_height", 3)
  filter_w_dim = mtf.Dimension("filter_width", 3)
  one_h_dim = mtf.Dimension("filter_height", 1)
  one_w_dim = mtf.Dimension("filter_width", 1)

  if projection_shortcut is not None:
    filters_dim = mtf.Dimension("filtersp", filters)
    kernel = mtf.get_variable(
        inputs.mesh, "kernel", mtf.Shape(
            [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
    shortcut = projection_shortcut(inputs, kernel)

  # First conv block
  filters1_dim = mtf.Dimension("filters1", filters)
  kernel1 = mtf.get_variable(
      inputs.mesh, "kernel1", mtf.Shape(
          [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      kernel1,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Add Dropout?
  inputs = batch_norm_relu(inputs, is_training)

  # Second conv block
  filters2_dim = mtf.Dimension("filters2", 4*filters)
  kernel2 = mtf.get_variable(
      inputs.mesh, "kernel2", mtf.Shape(
          [filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      kernel2,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)

  inputs = batch_norm_relu(inputs, is_training)

  # Third wide conv filter block
  filters3_dim = mtf.Dimension("filters3", filters)
  filters3_kernel = mtf.get_variable(
      inputs.mesh, "wide_kernel", mtf.Shape(
          [one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      filters3_kernel,
      strides,
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Although the original resnet code has this batch norm, in our
  # setup it causes no gradients to be passed. Investigate further.
  # inputs = batch_norm_relu(inputs, is_training, relu=True)

  # TODO(nikip): Maybe add residual with a projection?
  return mtf.relu(
      shortcut + mtf.rename_dimension(
          inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
Example #7
  def mtf_model_fn(self, features, mesh):
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
    cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", 3)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    x = mtf.import_tf_tensor(
        mesh, tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks,
            hparams.rows_size // hparams.row_blocks,
            hparams.col_blocks,
            hparams.num_channels*hparams.cols_size // hparams.col_blocks,
            hparams.num_channels]),
        mtf.Shape(
            [batch_dim, row_blocks_dim, rows_dim,
             col_blocks_dim, cols_dim, channels_dim]))
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])

    x = mtf.to_float(x)
    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x,
        initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block - strided block layer - strided block layer] x n
    for layer in range(hparams.num_layers):
      layer_name = "block_layer_%d" % layer
      with tf.variable_scope(layer_name):
        # Residual block layer
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[0],
            blocks=hparams.layer_sizes[0],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer1",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[1],
            blocks=hparams.layer_sizes[1],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer2",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[2],
            blocks=hparams.layer_sizes[2],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer3",
            row_blocks_dim=None,
            col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
Example #8
def cifar_model(features, labels, mesh):
  """The model.

  Args:
    features: a dict of tf.Tensors with keys 'image' (32*32*3 values per
        example) and 'label' (shape [batch], dtype tf.int32)
    labels: unused; labels are read from features['label']
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
  features = copy.copy(features)
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  row_blocks_dim = mtf.Dimension("row_blocks", 4)
  col_blocks_dim = mtf.Dimension("col_blocks", 4)
  rows_dim = mtf.Dimension("rows_size", 8)
  cols_dim = mtf.Dimension("cols_size", 8)

  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 3)


  # image = features['input']
  # with tf.device('/cpu:0'):
  image = features['image']
  labels = features['label']

  image = bnorm(image)

  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 4, 8, 4, 8, 3]),
      mtf.Shape(
          [batch_dim, row_blocks_dim, rows_dim,
           col_blocks_dim, cols_dim, one_channel_dim]))
  x = mtf.transpose(x, [
      batch_dim, row_blocks_dim, col_blocks_dim,
      rows_dim, cols_dim, one_channel_dim])

  # add some convolutional layers to demonstrate that convolution works.
  fh_dim = mtf.Dimension("fh", 7)
  fw_dim = mtf.Dimension("fw", 7)
  filters1_dim = mtf.Dimension("filters1", 32)
  filters2_dim = mtf.Dimension("filters2", 32)

  kernel1 = mtf.get_variable(
      mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim])
  kernel2 = mtf.get_variable(
      mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim])


  f1 = mtf.relu(mtf.conv2d_with_blocks(
      x, kernel1, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))


  f2 = mtf.relu(mtf.conv2d_with_blocks(
      f1, kernel2, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters3_dim = mtf.Dimension("filters3", 64)
  kernel3 = mtf.get_variable(
      mesh, "kernel3", [fh_dim, fw_dim, filters2_dim, filters3_dim])  

  f3 = mtf.relu(mtf.conv2d_with_blocks(
      f2, kernel3, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters4_dim = mtf.Dimension("filters4", 64)
  kernel4 = mtf.get_variable(
      mesh, "kernel4", [fh_dim, fw_dim, filters3_dim, filters4_dim])  

  f4 = mtf.relu(mtf.conv2d_with_blocks(
      f3, kernel4, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters5_dim = mtf.Dimension("filters5", 128)
  kernel5 = mtf.get_variable(
      mesh, "kernel5", [fh_dim, fw_dim, filters4_dim, filters5_dim])  

  f5 = mtf.relu(mtf.conv2d_with_blocks(
      f4, kernel5, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))    

  filters6_dim = mtf.Dimension("filters6", 128)
  kernel6 = mtf.get_variable(
      mesh, "kernel6", [fh_dim, fw_dim, filters5_dim, filters6_dim])  

  f6 = mtf.relu(mtf.conv2d_with_blocks(
      f5, kernel6, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters7_dim = mtf.Dimension("filters7", 128)
  kernel7 = mtf.get_variable(
      mesh, "kernel7", [fh_dim, fw_dim, filters6_dim, filters7_dim])  

  f7 = mtf.relu(mtf.conv2d_with_blocks(
      f6, kernel7, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters8_dim = mtf.Dimension("filters8", 128)
  kernel8 = mtf.get_variable(
      mesh, "kernel8", [fh_dim, fw_dim, filters7_dim, filters8_dim])  

  f8 = mtf.relu(mtf.conv2d_with_blocks(
      f7, kernel8, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters9_dim = mtf.Dimension("filters9", 128)
  kernel9 = mtf.get_variable(
      mesh, "kernel9", [fh_dim, fw_dim, filters8_dim, filters9_dim])  

  f9 = mtf.relu(mtf.conv2d_with_blocks(
      f8, kernel9, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters10_dim = mtf.Dimension("filters10", 128)
  kernel10 = mtf.get_variable(
      mesh, "kernel10", [fh_dim, fw_dim, filters9_dim, filters10_dim])  

  f10 = mtf.relu(mtf.conv2d_with_blocks(
      f9, kernel10, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))                              
 

  filters11_dim = mtf.Dimension("filters11", 256)
  kernel11 = mtf.get_variable(
      mesh, "kernel11", [fh_dim, fw_dim, filters10_dim, filters11_dim])  

  f11 = mtf.relu(mtf.conv2d_with_blocks(
      f10, kernel11, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters12_dim = mtf.Dimension("filters12", 256)
  kernel12 = mtf.get_variable(
      mesh, "kernel12", [fh_dim, fw_dim, filters11_dim, filters12_dim])  

  f12 = mtf.relu(mtf.conv2d_with_blocks(
      f11, kernel12, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))                                            
 

  filters13_dim = mtf.Dimension("filters13", 256)
  kernel13 = mtf.get_variable(
      mesh, "kernel13", [fh_dim, fw_dim, filters12_dim, filters13_dim])  

  f13 = mtf.relu(mtf.conv2d_with_blocks(
      f12, kernel13, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))     

  filters14_dim = mtf.Dimension("filters14", 256)
  kernel14 = mtf.get_variable(
      mesh, "kernel14", [fh_dim, fw_dim, filters13_dim, filters14_dim])  

  f14 = mtf.relu(mtf.conv2d_with_blocks(
      f13, kernel14, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))   

  filters15_dim = mtf.Dimension("filters15", 256)
  kernel15 = mtf.get_variable(
      mesh, "kernel15", [fh_dim, fw_dim, filters14_dim, filters15_dim])  

  f15 = mtf.relu(mtf.conv2d_with_blocks(
      f14, kernel15, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters16_dim = mtf.Dimension("filters16", 256)
  kernel16 = mtf.get_variable(
      mesh, "kernel16", [fh_dim, fw_dim, filters15_dim, filters16_dim])  
  f16 = mtf.relu(mtf.conv2d_with_blocks(
      f15, kernel16, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))  

  filters17_dim = mtf.Dimension("filters17", 256)
  kernel17 = mtf.get_variable(
      mesh, "kernel17", [fh_dim, fw_dim, filters16_dim, filters17_dim])  

  f17 = mtf.relu(mtf.conv2d_with_blocks(
      f16, kernel17, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) 

  filters18_dim = mtf.Dimension("filters18", 256)
  kernel18 = mtf.get_variable(
      mesh, "kernel18", [fh_dim, fw_dim, filters17_dim, filters18_dim])  

  f18 = mtf.relu(mtf.conv2d_with_blocks(
      f17, kernel18, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))        

  x = mtf.reduce_mean(f18, reduced_dim=filters18_dim)

  # add some fully-connected dense layers.
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

  h1 = mtf.layers.dense(
      x, hidden_dim1,
      reduced_dims=x.shape.dims[-4:],
      activation=mtf.relu, name="hidden1")
  h2 = mtf.layers.dense(
      h1, hidden_dim2,
      activation=mtf.relu, name="hidden2")

  hidden_dim3 = mtf.Dimension("hidden3", FLAGS.hidden_size)
  hidden_dim4 = mtf.Dimension("hidden4", FLAGS.hidden_size)
  hidden_dim5 = mtf.Dimension("hidden5", FLAGS.hidden_size)
  hidden_dim6 = mtf.Dimension("hidden6", FLAGS.hidden_size)
  hidden_dim7 = mtf.Dimension("hidden7", FLAGS.hidden_size)
  hidden_dim8 = mtf.Dimension("hidden8", FLAGS.hidden_size)

  h3 = mtf.layers.dense(
      h2, hidden_dim3,
      activation=mtf.relu, name="hidden3")

  h4 = mtf.layers.dense(
      h3, hidden_dim4,
      activation=mtf.relu, name="hidden4")

  h5 = mtf.layers.dense(
      h4, hidden_dim5,
      activation=mtf.relu, name="hidden5")

  h6 = mtf.layers.dense(
      h5, hidden_dim6,
      activation=mtf.relu, name="hidden6")

  h7 = mtf.layers.dense(
      h6, hidden_dim7,
      activation=mtf.relu, name="hidden7")

  h8 = mtf.layers.dense(
      h7, hidden_dim8,
      activation=mtf.relu, name="hidden8")

  logits = mtf.layers.dense(h8, classes_dim, name="logits")
  
  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
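
Like the other models on this page, cifar_model only returns mtf tensors; gradients and the optimizer update also stay in the mtf graph until lowering. A sketch of a training step built on top of it, assuming the placement-mesh setup shown after Example #3 and TF1-style imports; the learning rate and the choice of SGD are illustrative:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

def cifar_train_op(features, labels, mesh, graph, mesh_impl,
                   learning_rate=0.01):
  """Sketch: loss, gradients and SGD update for cifar_model, then lowering."""
  logits, loss = cifar_model(features, labels, mesh)

  # Differentiate the scalar loss w.r.t. every trainable mtf variable.
  var_grads = mtf.gradients(
      [loss], [v.outputs[0] for v in graph.trainable_variables])
  optimizer = mtf.optimize.SgdOptimizer(learning_rate)
  update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  # Lower to plain TensorFlow ops/tensors for use in a session or Estimator.
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
  return tf_loss, tf.group(*tf_update_ops)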