Example #1
def make_dense_layers(x, classes_dim):

    dense_dim1 = mtf.Dimension(name="dense_dim1", size=4096)
    dense_dim2 = mtf.Dimension(name="dense_dim2", size=4096)
    x = mtf.layers.dense(x, dense_dim1, name="dense-0")
    x = mtf.relu(x, name="relu-dense-0")
    x = mtf.layers.dense(x, dense_dim2, name="dense-1")
    x = mtf.relu(x, name="relu-dense-1")
    x = mtf.layers.dense(x, classes_dim, name="dense-2")
    return x
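A minimal call-site sketch, assuming `mesh_tensorflow` is imported as `mtf` and `x` is the flattened feature tensor produced by an earlier convolutional stack (both assumptions, not part of the example above):

# Hypothetical usage of make_dense_layers; the names below are illustrative.
classes_dim = mtf.Dimension("classes", 10)
logits = make_dense_layers(x, classes_dim)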
Example #2
def BasicBlock(x, order, out_channels, strides):
    name = "BasicBlock"
    expansion = 1
    out_chls = out_channels // expansion
    identity = x

    x = mtf.layers.conv2d(x,
                          output_dim=mtf.Dimension(
                              name=name + '-' + str(order) + '-' + 'filters1',
                              size=out_chls),
                          filter_size=(3, 3),
                          strides=strides,
                          name="conv3x3_BB_1" + '-' + str(order),
                          variable_dtype=float16)
    print(x.name)
    print(x.dtype)
    x, _ = mtf.layers.batch_norm(x,
                                 is_training=True,
                                 momentum=0.99,
                                 epsilon=1e-5,
                                 name="batch_norm_BB_1" + '-' + str(order))
    x = mtf.relu(x, name="relu_BB_1" + '-' + str(order))

    x = mtf.layers.conv2d(x,
                          output_dim=mtf.Dimension(
                              name=name + '-' + str(order) + '-' + 'filters2',
                              size=out_channels),
                          filter_size=(3, 3),
                          strides=(1, 1),
                          name="conv3x3_BB_2" + '-' + str(order),
                          variable_dtype=float16)
    print(x.name)
    print(x.dtype)
    x, _ = mtf.layers.batch_norm(x,
                                 is_training=True,
                                 momentum=0.99,
                                 epsilon=1e-5,
                                 name="batch_norm_BB_2" + '-' + str(order))
    identity = mtf.reshape(identity,
                           new_shape=[
                               identity.shape.dims[0], identity.shape.dims[1],
                               identity.shape.dims[2], x.shape.dims[3]
                           ],
                           name="reshape_BB" + str(order))

    x = mtf.add(x,
                identity,
                output_shape=x.shape,
                name="add_BB_1" + '-' + str(order))
    x = mtf.relu(x, name="relu_BB_2" + '-' + str(order))
    print(x.name)
    print(x.dtype)
    return x
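A hedged sketch of how such blocks might be chained into a ResNet stage; the input tensor `x` and the channel count are assumptions for illustration, and `float16` is expected to exist at module level as in the surrounding repository:

# Illustrative only: two BasicBlocks with matching channel counts form one stage.
x = BasicBlock(x, order=0, out_channels=64, strides=(1, 1))
x = BasicBlock(x, order=1, out_channels=64, strides=(1, 1))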
Example #3
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image,
                                              [FLAGS.batch_size, 28, 28]),
                             [batch_dim, rows_dim, cols_dim])
    y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                             [batch_dim])

    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, classes_dim])
    b1 = mtf.get_variable(mesh, "b1", [classes_dim])

    logits = mtf.relu(mtf.einsum([x, w1], [batch_dim, classes_dim]) + b1)

    if labels is None:
        loss = None
    else:
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Example #4
def make_conv_layers(x, mode, batch_norm=True):
    maxpool_count = 0
    conv2d_count = 0
    for size in mode:
        if size == "M":
            x = mtf.layers.max_pool2d(x,
                                      ksize=(2, 2),
                                      name='maxpool' + '-' +
                                      str(maxpool_count))
            maxpool_count += 1
        else:
            x = mtf.layers.conv2d(x,
                                  output_dim=mtf.Dimension(
                                      name='filters' + '-' + str(conv2d_count),
                                      size=size),
                                  filter_size=(3, 3),
                                  strides=(1, 1),
                                  name="conv3" + '-' + str(conv2d_count))

            if batch_norm:
                x, _ = mtf.layers.batch_norm(x,
                                             is_training=True,
                                             momentum=0.99,
                                             epsilon=1e-5,
                                             name="batch_norm" + '-' +
                                             str(conv2d_count))
            x = mtf.relu(x, name="relu-conv" + '-' + str(conv2d_count))
            conv2d_count += 1
    return x
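Here `mode` is a VGG-style configuration list: each integer is the number of output filters for a 3x3 convolution, and "M" inserts a 2x2 max-pooling step. A hedged usage sketch (the concrete channel counts are illustrative, not taken from the example):

# VGG-16-like configuration; ints are conv widths, "M" marks a max pool.
vgg_cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
           512, 512, 512, "M", 512, 512, 512, "M"]
x = make_conv_layers(x, vgg_cfg, batch_norm=True)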
Example #5
def make_dense_layers(x, classes_dim, float16=None):
    dense_dim1 = mtf.Dimension(name="dense_dim1", size=4096)
    dense_dim2 = mtf.Dimension(name="dense_dim2", size=4096)
    x = mtf.layers.dense(x, dense_dim1, name="dense-0",
                         reduced_dims=x.shape.dims[-3:],
                         variable_dtype=float16)
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.relu(x, name="relu-dense-0")
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x, dense_dim2, name="dense-1", variable_dtype=float16)
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.relu(x, name="relu-dense-1")
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x, classes_dim, name="dense-2", variable_dtype=float16)
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    return x
Example #6
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: a tf.Tensor with shape [batch, 28, 28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """

    # tf_images is a tf.Tensor with shape [batch, 28, 28] and dtype tf.float32
    # tf_labels is a tf.Tensor with shape [batch] and dtype tf.int32
    batch_dim = mtf.Dimension("batch", 100)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    hidden_dim = mtf.Dimension("hidden", 1024)
    classes_dim = mtf.Dimension("classes", 10)
    images = mtf.import_tf_tensor(mesh,
                                  image,
                                  shape=[batch_dim, rows_dim, cols_dim])
    labels = mtf.import_tf_tensor(mesh, labels, [batch_dim])
    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
    w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])
    # einsum is a generalization of matrix multiplication (see numpy.einsum)
    hidden = mtf.relu(
        mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]))
    logits = mtf.einsum([hidden, w2], output_shape=[batch_dim, classes_dim])
    loss = mtf.reduce_mean(
        mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim))

    return logits, loss
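A hedged sketch of how this model function is wired into a Mesh TensorFlow graph, following the graph/mesh construction shown in the Mesh TensorFlow README; the placeholder inputs are assumptions standing in for a real input pipeline:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

# Illustrative placeholders matching the shapes the model imports above.
tf_images = tf.placeholder(tf.float32, [100, 28, 28])
tf_labels = tf.placeholder(tf.int32, [100])

logits, loss = mnist_model(tf_images, tf_labels, mesh)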
Example #7
def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
    assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'

    _gp = lambda x, alpha: x ** (alpha - 1)
    _gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))
    _p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)

    dim = x.shape[-1] if dim is None else dim
    d = dim.size

    x = x * (alpha - 1)

    max_val = mtf.reduce_max(x, reduced_dim=dim)

    tau_lo = max_val - _gp(1, alpha)
    tau_hi = max_val - _gp(1 / d, alpha)

    f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1

    dm = tau_hi - tau_lo

    for _ in range(n_iter):
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = _p(x - tau_m, alpha)
        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1

        mask = mtf.greater_equal((f_m * f_lo), 0)
        tau_lo = mtf.where(mask, tau_m, tau_lo)

    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
    return p_m
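A hedged usage sketch, assuming `logits` is an mtf.Tensor whose last dimension should be normalized; for alpha in (1, 2), entmax interpolates between softmax and sparsemax, and the bisection loop above solves for the threshold tau:

# Illustrative only: 1.5-entmax over the final dimension of `logits`.
vocab_dim = logits.shape.dims[-1]
probs = entmax_forward(logits, alpha=1.5, dim=vocab_dim, n_iter=30)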
Example #8
def batch_norm_relu(inputs, is_training, relu=True):
    """Block of batch norm and relu."""
    inputs = mtf.layers.batch_norm(inputs,
                                   is_training,
                                   BATCH_NORM_DECAY,
                                   epsilon=BATCH_NORM_EPSILON)
    if relu:
        inputs = mtf.relu(inputs)
    return inputs
Example #9
def batch_norm_relu(inputs, is_training, relu=True):
  """Block of batch norm and relu."""
  inputs = mtf.layers.batch_norm(
      inputs,
      is_training,
      BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON,
      init_zero=(not relu))
  if relu:
    inputs = mtf.relu(inputs)
  return inputs
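Both versions of `batch_norm_relu` above refer to module-level constants that the snippets do not show; a hedged sketch of typical definitions (the exact values are assumptions modeled on common ResNet settings):

# Assumed module-level constants used by batch_norm_relu.
BATCH_NORM_DECAY = 0.997   # momentum of the batch-norm moving averages
BATCH_NORM_EPSILON = 1e-5  # numerical-stability epsilon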
Example #10
def make_conv_layers(x, mode, batch_norm=True, float16=None):
    maxpool_count = 0
    conv2d_count = 0
    for size in mode:
        if size == "M":
            x = mtf.layers.max_pool2d(x,
                                      ksize=(2, 2),
                                      name='maxpool' + '-' +
                                      str(maxpool_count))
            logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
            maxpool_count += 1
        else:
            x = conv2d(x,
                       output_dim=mtf.Dimension(
                           name='filters' + '-' + str(conv2d_count),
                           size=size),
                       filter_size=(3, 3),
                       strides=(1, 1),
                       name="conv3" + '-' + str(conv2d_count),
                       variable_dtype=float16)
            logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
            if batch_norm:
                x, _ = mtf.layers.batch_norm(x,
                                             is_training=True,
                                             momentum=0.99,
                                             epsilon=1e-5,
                                             name="batch_norm" + '-' +
                                             str(conv2d_count))
                logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
            x = mtf.relu(x, name="relu-conv" + '-' + str(conv2d_count))
            logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
            conv2d_count += 1
    return x
Example #11
def cifar_model(features, labels, mesh):
  """The model.

  Args:
    features: a dict of tf.Tensors with keys 'image' and 'label'
    labels: unused; the labels are read from features['label']
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
  features = copy.copy(features)
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  row_blocks_dim = mtf.Dimension("row_blocks", 4)
  col_blocks_dim = mtf.Dimension("col_blocks", 4)
  rows_dim = mtf.Dimension("rows_size", 8)
  cols_dim = mtf.Dimension("cols_size", 8)

  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 3)


  # image = features['input']
  # with tf.device('/cpu:0'):
  image = features['image']
  labels = features['label']

  image = bnorm(image)

  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 4, 8, 4, 8, 3]),
      mtf.Shape(
          [batch_dim, row_blocks_dim, rows_dim,
           col_blocks_dim, cols_dim, one_channel_dim]))
  x = mtf.transpose(x, [
      batch_dim, row_blocks_dim, col_blocks_dim,
      rows_dim, cols_dim, one_channel_dim])

  # add some convolutional layers to demonstrate that convolution works.
  fh_dim = mtf.Dimension("fh", 7)
  fw_dim = mtf.Dimension("fw", 7)
  filters1_dim = mtf.Dimension("filters1", 32)
  filters2_dim = mtf.Dimension("filters2", 32)

  kernel1 = mtf.get_variable(
      mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim])
  kernel2 = mtf.get_variable(
      mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim])


  f1 = mtf.relu(mtf.conv2d_with_blocks(
      x, kernel1, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))


  f2 = mtf.relu(mtf.conv2d_with_blocks(
      f1, kernel2, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters3_dim = mtf.Dimension("filters3", 64)
  kernel3 = mtf.get_variable(
      mesh, "kernel3", [fh_dim, fw_dim, filters2_dim, filters3_dim])  

  f3 = mtf.relu(mtf.conv2d_with_blocks(
      f2, kernel3, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters4_dim = mtf.Dimension("filters4", 64)
  kernel4 = mtf.get_variable(
      mesh, "kernel4", [fh_dim, fw_dim, filters3_dim, filters4_dim])  

  f4 = mtf.relu(mtf.conv2d_with_blocks(
      f3, kernel4, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters5_dim = mtf.Dimension("filters5", 128)
  kernel5 = mtf.get_variable(
      mesh, "kernel5", [fh_dim, fw_dim, filters4_dim, filters5_dim])  

  f5 = mtf.relu(mtf.conv2d_with_blocks(
      f4, kernel5, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))    

  filters6_dim = mtf.Dimension("filters6", 128)
  kernel6 = mtf.get_variable(
      mesh, "kernel6", [fh_dim, fw_dim, filters5_dim, filters6_dim])  

  f6 = mtf.relu(mtf.conv2d_with_blocks(
      f5, kernel6, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters7_dim = mtf.Dimension("filters7", 128)
  kernel7 = mtf.get_variable(
      mesh, "kernel7", [fh_dim, fw_dim, filters6_dim, filters7_dim])  

  f7 = mtf.relu(mtf.conv2d_with_blocks(
      f6, kernel7, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters8_dim = mtf.Dimension("filters8", 128)
  kernel8 = mtf.get_variable(
      mesh, "kernel8", [fh_dim, fw_dim, filters7_dim, filters8_dim])  

  f8 = mtf.relu(mtf.conv2d_with_blocks(
      f7, kernel8, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters9_dim = mtf.Dimension("filters9", 128)
  kernel9 = mtf.get_variable(
      mesh, "kernel9", [fh_dim, fw_dim, filters8_dim, filters9_dim])  

  f9 = mtf.relu(mtf.conv2d_with_blocks(
      f8, kernel9, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters10_dim = mtf.Dimension("filters10", 128)
  kernel10 = mtf.get_variable(
      mesh, "kernel10", [fh_dim, fw_dim, filters9_dim, filters10_dim])  

  f10 = mtf.relu(mtf.conv2d_with_blocks(
      f9, kernel10, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))                              
 

  filters11_dim = mtf.Dimension("filters11", 256)
  kernel11 = mtf.get_variable(
      mesh, "kernel11", [fh_dim, fw_dim, filters10_dim, filters11_dim])  

  f11 = mtf.relu(mtf.conv2d_with_blocks(
      f10, kernel11, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters12_dim = mtf.Dimension("filters12", 256)
  kernel12 = mtf.get_variable(
      mesh, "kernel12", [fh_dim, fw_dim, filters11_dim, filters12_dim])  

  f12 = mtf.relu(mtf.conv2d_with_blocks(
      f11, kernel12, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))                                            
 

  filters13_dim = mtf.Dimension("filters13", 256)
  kernel13 = mtf.get_variable(
      mesh, "kernel13", [fh_dim, fw_dim, filters12_dim, filters13_dim])  

  f13 = mtf.relu(mtf.conv2d_with_blocks(
      f12, kernel13, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))     

  filters14_dim = mtf.Dimension("filters14", 256)
  kernel14 = mtf.get_variable(
      mesh, "kernel14", [fh_dim, fw_dim, filters13_dim, filters14_dim])  

  f14 = mtf.relu(mtf.conv2d_with_blocks(
      f13, kernel14, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))   

  filters15_dim = mtf.Dimension("filters15", 256)
  kernel15 = mtf.get_variable(
      mesh, "kernel15", [fh_dim, fw_dim, filters14_dim, filters15_dim])  

  f15 = mtf.relu(mtf.conv2d_with_blocks(
      f14, kernel15, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  filters16_dim = mtf.Dimension("filters16", 256)
  kernel16 = mtf.get_variable(
      mesh, "kernel16", [fh_dim, fw_dim, filters15_dim, filters16_dim])  
  f16 = mtf.relu(mtf.conv2d_with_blocks(
      f15, kernel16, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))  

  filters17_dim = mtf.Dimension("filters17", 256)
  kernel17 = mtf.get_variable(
      mesh, "kernel17", [fh_dim, fw_dim, filters16_dim, filters17_dim])  

  f17 = mtf.relu(mtf.conv2d_with_blocks(
      f16, kernel17, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) 

  filters18_dim = mtf.Dimension("filters18", 256)
  kernel18 = mtf.get_variable(
      mesh, "kernel18", [fh_dim, fw_dim, filters17_dim, filters18_dim])  

  f18 = mtf.relu(mtf.conv2d_with_blocks(
      f17, kernel18, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))        

  x = mtf.reduce_mean(f18, reduced_dim=filters18_dim)

  # add some fully-connected dense layers.
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

  h1 = mtf.layers.dense(
      x, hidden_dim1,
      reduced_dims=x.shape.dims[-4:],
      activation=mtf.relu, name="hidden1")
  h2 = mtf.layers.dense(
      h1, hidden_dim2,
      activation=mtf.relu, name="hidden2")

  hidden_dim3 = mtf.Dimension("hidden3", FLAGS.hidden_size)
  hidden_dim4 = mtf.Dimension("hidden4", FLAGS.hidden_size)
  hidden_dim5 = mtf.Dimension("hidden5", FLAGS.hidden_size)
  hidden_dim6 = mtf.Dimension("hidden6", FLAGS.hidden_size)
  hidden_dim7 = mtf.Dimension("hidden7", FLAGS.hidden_size)
  hidden_dim8 = mtf.Dimension("hidden8", FLAGS.hidden_size)

  h3 = mtf.layers.dense(
      h2, hidden_dim3,
      activation=mtf.relu, name="hidden3")

  h4 = mtf.layers.dense(
      h3, hidden_dim4,
      activation=mtf.relu, name="hidden4")

  h5 = mtf.layers.dense(
    h4, hidden_dim5,
    activation=mtf.relu, name="hidden5")

  h6 = mtf.layers.dense(
    h5, hidden_dim6,
    activation=mtf.relu, name="hidden6")

  h7 = mtf.layers.dense(
    h6, hidden_dim7,
    activation=mtf.relu, name="hidden7") 

  h8 = mtf.layers.dense(
    h7, hidden_dim8,
    activation=mtf.relu, name="hidden8")                        

  logits = mtf.layers.dense(h8, classes_dim, name="logits")
  
  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
Example #12
def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
  """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    synthesize: flag to use synthetic attention or not
    synthesize_mode: which variant of synthesizer to use
    factorized_dim: factorized dim for synthesizers
    max_length: max length of input sequence
    context: context since we need context mode

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """

  if synthesize:
    num_heads = v.shape.get_dim_by_name("heads")
    tf.logging.info("Using synthesizer")
    if synthesize_mode == "random":
      tf.logging.info("Using Random Synthesizers")
      r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", max_length)])
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, "length")
      logits = r
      r_shape = logits.shape
    elif synthesize_mode == "factorized":
      tf.logging.info("Using Factorized Random Synthesizers")
      k = factorized_dim
      r1_shape = mtf.Shape([mtf.Dimension("tmp", k),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("memory_length", 512)])
      r2_shape = mtf.Shape([mtf.Dimension("tmp", k),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("memory_length", 512)])
      r_shape = mtf.Shape([mtf.Dimension("length", 512),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", 512)])
      r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r = mtf.einsum([r1, r2], r_shape)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, "length")
      logits = r
    elif synthesize_mode == "dense_minus":
      # Dense Synthesizer Model
      tmp_dim = mtf.Dimension("memory_length", max_length)
      logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                use_bias=False,
                                name="pi",
                                reduced_dims=[key_dim],
                                variable_dtype=None)
      logits = mtf.slice(logits, 0, memory_length_dim.size,
                         memory_length_dim.name)
      if context.mode == "incremental":
        pass
      else:
        length_dim = q.shape.get_dim_by_name("length")
        logits = mtf.slice(logits, 0, length_dim.size, "length")
    elif synthesize_mode == "random_plus_alpha" or \
        synthesize_mode == "random_plus":
      # Mixture Random Synthesizer with learnable Alpha
      tf.logging.info("Using Random Plus Alpha")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      num_heads = logits.shape.get_dim_by_name("heads")
      r_shape = mtf.Shape([mtf.Dimension("length", 512),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", 512)])
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, length_dim.name)
      if "alpha" in synthesize_mode:
        alpha = mtf.get_variable(context.mesh,
                                 "alpha",
                                 mtf.Shape([mtf.Dimension("alpha", 1)]),
                                 initializer=tf.zeros_initializer(),
                                 dtype=context.variable_dtype)
        alpha = mtf.sigmoid(alpha)
        logits = ((1-alpha) * logits) + (alpha * r)
      else:
        logits = logits + r
    elif synthesize_mode == "dense_plus_alpha" or \
        synthesize_mode == "dense_plus":
      # Mixture Dense Synthesizer with learnable alpha
      tf.logging.info("Using Dense Plus Alpha Scaling")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      tmp_dim = mtf.Dimension("memory_length", 512)
      r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                           use_bias=False,
                           name="pi",
                           reduced_dims=[key_dim],
                           variable_dtype=None)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        pass
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, "length")
      if "alpha" in synthesize_mode:
        alpha = mtf.get_variable(context.mesh,
                                 "alpha",
                                 mtf.Shape([mtf.Dimension("alpha", 1)]),
                                 initializer=tf.zeros_initializer(),
                                 dtype=context.variable_dtype)
        alpha = mtf.sigmoid(alpha)
        logits = ((1-alpha) * logits) + (alpha * r)
      else:
        logits = logits + r
  if bias is not None:
    logits += bias

  weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
  weights = mtf.dropout(
      weights, context.train, 1.0 - dropout_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)

  if synthesize and "plus" not in synthesize_mode:
    if synthesize_mode == "dense_minus":
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
    else:
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
  else:
    outputs_shape = q.shape - [key_dim] + value_dim

  outputs = mtf.einsum([weights, v], outputs_shape)
  return outputs
Example #13
def cifar_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 32*32*3]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 1)
    col_blocks_dim = mtf.Dimension("col_blocks", 1)
    rows_dim = mtf.Dimension("rows_size", 32)
    cols_dim = mtf.Dimension("cols_size", 32)
    init = 60

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 3)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 1, 32, 1, 32, 3]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # add some convolutional layers to demonstrate that convolution works.
    filters1_dim = mtf.Dimension("filters1", init)
    filters2_dim = mtf.Dimension("filters2", init)
    f1 = mtf.relu(
        mtf.layers.conv2d_with_blocks(x,
                                      filters1_dim,
                                      filter_size=[3, 3],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv0"))
    #print("conv:, ", f1.shape)

    f2 = mtf.relu(
        mtf.layers.conv2d_with_blocks(f1,
                                      filters2_dim,
                                      filter_size=[3, 3],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv1"))

    x = mtf.layers.max_pool2d(f2, ksize=(2, 2), name="maxpool0")

    #print(x.shape)

    filters3_dim = mtf.Dimension("filters3", init * 2)
    filters4_dim = mtf.Dimension("filters4", init * 2)

    f3 = mtf.relu(
        mtf.layers.conv2d_with_blocks(x,
                                      filters3_dim,
                                      filter_size=[3, 3],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv2"))
    f4 = mtf.relu(
        mtf.layers.conv2d_with_blocks(f3,
                                      filters4_dim,
                                      filter_size=[3, 3],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv3"))

    x = mtf.layers.max_pool2d(f4, ksize=(2, 2), name="maxpool1")

    #print(x.shape)
    filters5_dim = mtf.Dimension("filters5", init * 4)
    filters6_dim = mtf.Dimension("filters6", init * 4)

    f5 = mtf.relu(
        mtf.layers.conv2d_with_blocks(x,
                                      filters5_dim,
                                      filter_size=[3, 3],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv4"))
    f6 = mtf.relu(
        mtf.layers.conv2d_with_blocks(f5,
                                      filters6_dim,
                                      filter_size=[3, 3],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv5"))

    x = mtf.layers.max_pool2d(f6, ksize=(2, 2), name="maxpool2")
    #print(x.shape)

    filters7_dim = mtf.Dimension("filters7", init * 8)
    filters8_dim = mtf.Dimension("filters8", init * 8)

    f7 = mtf.relu(
        mtf.layers.conv2d_with_blocks(x,
                                      filters7_dim,
                                      filter_size=[3, 3],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv6"))
    f8 = mtf.relu(
        mtf.layers.conv2d_with_blocks(f7,
                                      filters8_dim,
                                      filter_size=[3, 3],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv7"))

    x = mtf.layers.max_pool2d(f8, ksize=(2, 2), name="maxpool3")
    #  x = mtf.reduce_mean(f8, reduced_dim=filters8_dim)
    # add some fully-connected dense layers.
    #hidden_dim1 = mtf.Dimension("hidden1", init*8)
    hidden_dim1 = mtf.Dimension("hidden1", 256)
    hidden_dim2 = mtf.Dimension("hidden2", init * 8)

    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-5:],
                          activation=mtf.relu,
                          name="hidden1")
    #h2 = mtf.layers.dense(
    #h1, hidden_dim2,
    #activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)

    all_filters = [[
        init, init, init * 2, init * 2, init * 4, init * 4, init * 8, init * 8
    ]]
    return logits, loss, all_filters
Example #14
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first two convolutions. Note
        that the third and final convolution will use 4 times as many filters.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). If None, no
        projection is used and the input is passed as unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along mesh axis
    col_blocks_dim: a mtf.Dimension, column dimension which is
        spatially partitioned along mesh axis

  Returns:
    The output `Tensor` of the block.
  """
    shortcut = inputs

    filter_h_dim = mtf.Dimension("filter_height", 3)
    filter_w_dim = mtf.Dimension("filter_width", 3)
    one_h_dim = mtf.Dimension("filter_height", 1)
    one_w_dim = mtf.Dimension("filter_width", 1)

    if projection_shortcut is not None:
        filters_dim = mtf.Dimension("filtersp", filters)
        kernel = mtf.get_variable(
            inputs.mesh, "kernel",
            mtf.Shape(
                [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
        shortcut = projection_shortcut(inputs, kernel)

    # First conv block
    filters1_dim = mtf.Dimension("filters1", filters)
    kernel1 = mtf.get_variable(
        inputs.mesh, "kernel1",
        mtf.Shape([one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel1,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    # TODO(nikip): Add Dropout?
    inputs = batch_norm_relu(inputs, is_training)

    # Second conv block
    filters2_dim = mtf.Dimension("filters2", filters)
    kernel2 = mtf.get_variable(
        inputs.mesh, "kernel2",
        mtf.Shape([filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel2,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=row_blocks_dim,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training)

    # Third wide conv filter block
    filters3_dim = mtf.Dimension("filters3", filters)
    filters3_kernel = mtf.get_variable(
        inputs.mesh, "wide_kernel",
        mtf.Shape([one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    filters3_kernel,
                                    strides,
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training, relu=False)

    # TODO(nikip): Maybe add residual with a projection?
    return mtf.relu(inputs + mtf.rename_dimension(
        shortcut, shortcut.shape.dims[-1].name, inputs.shape.dims[-1].name))
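A hedged sketch of what the `projection_shortcut` callable might look like for this block: it receives the inputs and the 1x1 kernel created above and applies a 1x1 convolution so the shortcut's filter dimension matches the block output (the strides and block partitioning chosen here are assumptions):

# Illustrative projection shortcut: a 1x1 convolution to match filter dims.
def projection_shortcut(inputs, kernel):
    return mtf.conv2d_with_blocks(
        inputs, kernel, strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=None, w_blocks_dim=None)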
Example #15
def mnist_model(image, labels, mesh):
	"""The model.
	Args:
		image: tf.Tensor with shape [batch, 28*28]
		labels: a tf.Tensor with shape [batch] and dtype tf.int32
		mesh: a mtf.Mesh
	Returns:
		logits: a mtf.Tensor with shape [batch, 10]
		loss: a mtf.Tensor with shape []
	"""
	batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
	row_blocks_dim = mtf.Dimension("row_blocks", 4)
	col_blocks_dim = mtf.Dimension("col_blocks", 4)
	rows_dim = mtf.Dimension("rows_size", 7)
	cols_dim = mtf.Dimension("cols_size", 7)

	classes_dim = mtf.Dimension("classes", 10)
	one_channel_dim = mtf.Dimension("one_channel", 1)

	x = mtf.import_tf_tensor(
		mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
		mtf.Shape(
			[batch_dim, row_blocks_dim, rows_dim,
			col_blocks_dim, cols_dim, one_channel_dim]))
	x = mtf.transpose(x, [
		batch_dim, row_blocks_dim, col_blocks_dim,
		rows_dim, cols_dim, one_channel_dim])
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(x.name,x.shape))
	# add some convolutional layers to demonstrate that convolution works.
	filters1_dim = mtf.Dimension("filters1", 16)
	filters2_dim = mtf.Dimension("filters2", 16)
	f1 = mtf.relu(mtf.layers.conv2d_with_blocks(
		x, filters1_dim, filter_size=[9, 9], strides=[1, 1], padding="SAME",
		h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim, name="conv0"))
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(f1.name,f1.shape))
	f2 = mtf.relu(mtf.layers.conv2d_with_blocks(
		f1, filters2_dim, filter_size=[9, 9], strides=[1, 1], padding="SAME",
		h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim, name="conv1"))
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(f2.name,f2.shape))
	x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(x.name,x.shape))
	# add some fully-connected dense layers.
	hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
	hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

	h1 = mtf.layers.dense(
		x, hidden_dim1,
		reduced_dims=x.shape.dims[-4:],
		activation=mtf.relu, name="hidden1")
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(h1.name,h1.shape))
	h2 = mtf.layers.dense(
		h1, hidden_dim2,
		activation=mtf.relu, name="hidden2")
	tf.logging.info("[intra variable] (name, shape): ({},{})".format(h2.name,h2.shape))
	logits = mtf.layers.dense(h2, classes_dim, name="logits")
	if labels is None:
		loss = None
	else:
		labels = mtf.import_tf_tensor(
			mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
		loss = mtf.layers.softmax_cross_entropy_with_logits(
			logits, mtf.one_hot(labels, classes_dim), classes_dim)
		loss = mtf.reduce_mean(loss)
	return logits, loss
Example #16
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first two convolutions. Note
        that the third and final convolution will use 4 times as many filters.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). If None, no
        projection is used and the input is passed as unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along mesh axis
    col_blocks_dim: a mtf.Dimension, column dimension which is
        spatially partitioned along mesh axis

  Returns:
    The output `Tensor` of the block.
  """
    shortcut = inputs

    if projection_shortcut is not None:
        filters_dim = mtf.Dimension("filtersp", filters)
        shortcut = projection_shortcut(inputs, filters_dim)

    # First conv block
    inputs = mtf.layers.conv2d_with_blocks(inputs,
                                           mtf.Dimension("filters1", filters),
                                           filter_size=[1, 1],
                                           strides=[1, 1],
                                           padding="SAME",
                                           h_blocks_dim=None,
                                           w_blocks_dim=col_blocks_dim,
                                           name="conv0")

    # TODO(nikip): Add Dropout?
    inputs = batch_norm_relu(inputs, is_training)

    # Second conv block
    inputs = mtf.layers.conv2d_with_blocks(inputs,
                                           mtf.Dimension(
                                               "filters2", 4 * filters),
                                           filter_size=[3, 3],
                                           strides=[1, 1],
                                           padding="SAME",
                                           h_blocks_dim=row_blocks_dim,
                                           w_blocks_dim=col_blocks_dim,
                                           name="conv1")

    inputs = batch_norm_relu(inputs, is_training)

    # Third wide conv filter block
    inputs = mtf.layers.conv2d_with_blocks(inputs,
                                           mtf.Dimension("filters3", filters),
                                           filter_size=[1, 1],
                                           strides=strides,
                                           padding="SAME",
                                           h_blocks_dim=None,
                                           w_blocks_dim=col_blocks_dim,
                                           name="conv2")

    # TODO(nikip): Although the original resnet code has this batch norm, in our
    # setup this is causing no gradients to be passed. Investigate further.
    # inputs = batch_norm_relu(inputs, is_training, relu=True)

    # TODO(nikip): Maybe add residual with a projection?
    return mtf.relu(shortcut + mtf.rename_dimension(
        inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
Example #17
def ResidualBlockWithDown(x,
                          order,
                          out_channels,
                          strides,
                          float16=None,
                          batch_norm=False):
    name = "ResidualBlockWithDown"
    expansion = 4
    out_chls = out_channels // expansion
    identity = x

    x = conv2d(x,
               output_dim=mtf.Dimension(name=name + '-' + str(order) + '-' +
                                        'filters1',
                                        size=out_chls),
               filter_size=(1, 1),
               strides=(1, 1),
               name="conv1x1_RBW_1" + '-' + str(order),
               variable_dtype=float16)
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    if batch_norm:
        x, _ = mtf.layers.batch_norm(x,
                                     is_training=True,
                                     momentum=0.99,
                                     epsilon=1e-5,
                                     name="batch_norm_RBW_1" + '-' +
                                     str(order))
        logger.debug("[output tensor] (name,shape):({},{})".format(
            x.name, x.shape))
    x = mtf.relu(x, name="relu_RBW_1" + '-' + str(order))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = conv2d(x,
               output_dim=mtf.Dimension(name=name + '-' + str(order) + '-' +
                                        'filters2',
                                        size=out_chls),
               filter_size=(3, 3),
               strides=strides,
               name="conv3x3_RBW_1" + '-' + str(order),
               variable_dtype=float16)
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    if batch_norm:
        x, _ = mtf.layers.batch_norm(x,
                                     is_training=True,
                                     momentum=0.99,
                                     epsilon=1e-5,
                                     name="batch_norm_RBW_2" + '-' +
                                     str(order))

        logger.debug("[output tensor] (name,shape):({},{})".format(
            x.name, x.shape))
    x = mtf.relu(x, name="relu_RBW_2" + '-' + str(order))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = conv2d(x,
               output_dim=mtf.Dimension(name=name + '-' + str(order) + '-' +
                                        'filters3',
                                        size=out_channels),
               filter_size=(1, 1),
               strides=(1, 1),
               name="conv1x1-2_RBW_2" + '-' + str(order),
               variable_dtype=float16)
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    if batch_norm:
        x, _ = mtf.layers.batch_norm(x,
                                     is_training=True,
                                     momentum=0.99,
                                     epsilon=1e-5,
                                     name="batch_norm_RBW_3" + '-' +
                                     str(order))
        logger.debug("[output tensor] (name,shape):({},{})".format(
            x.name, x.shape))
    identity = conv2d(identity,
                      output_dim=mtf.Dimension(name=name + '-' + str(order) +
                                               '-' + 'filters3',
                                               size=out_channels),
                      filter_size=(1, 1),
                      strides=strides,
                      name="conv1x1_RBW_3" + '-' + str(order),
                      variable_dtype=float16)
    logger.debug("[output tensor] (name,shape):({},{})".format(
        identity.name, identity.shape))
    if batch_norm:
        identity, _ = mtf.layers.batch_norm(identity,
                                            is_training=True,
                                            momentum=0.99,
                                            epsilon=1e-5,
                                            name="batch_norm_RBW_4" + '-' +
                                            str(order))
        logger.debug("[output tensor] (name,shape):({},{})".format(
            identity.name, identity.shape))
    identity = mtf.reshape(identity,
                           new_shape=[
                               identity.shape.dims[0], identity.shape.dims[1],
                               identity.shape.dims[2], x.shape.dims[3]
                           ],
                           name="reshape_RBW" + str(order))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        identity.name, identity.shape))
    x = mtf.add(x,
                identity,
                output_shape=x.shape,
                name="add_RBW_1" + '-' + str(order))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.relu(x, name="relu_RBW_3" + '-' + str(order))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    return x
Example #18
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows_size", 28)
    cols_dim = mtf.Dimension("cols_size", 28)

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28, 1]),
        mtf.Shape([batch_dim, rows_dim, cols_dim, one_channel_dim]))

    fh_dim = mtf.Dimension("fh", 3)
    fw_dim = mtf.Dimension("fw", 3)
    filters1_dim = mtf.Dimension("filters1", FLAGS.num_filters)
    filters2_dim = mtf.Dimension("filters2", FLAGS.num_filters)
    filters3_dim = mtf.Dimension("filters3", FLAGS.num_filters)
    filters4_dim = mtf.Dimension("filters4", FLAGS.num_filters)
    filters5_dim = mtf.Dimension("filters5", FLAGS.num_filters)
    filters6_dim = mtf.Dimension("filters6", FLAGS.num_filters)

    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])
    kernel3 = mtf.get_variable(mesh, "kernel3",
                               [fh_dim, fw_dim, filters2_dim, filters3_dim])
    kernel4 = mtf.get_variable(mesh, "kernel4",
                               [fh_dim, fw_dim, filters3_dim, filters4_dim])
    kernel5 = mtf.get_variable(mesh, "kernel5",
                               [fh_dim, fw_dim, filters4_dim, filters5_dim])
    kernel6 = mtf.get_variable(mesh, "kernel6",
                               [fh_dim, fw_dim, filters5_dim, filters6_dim])

    x = mtf.relu(mtf.conv2d(x, kernel1, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel2, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel3, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel4, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel5, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.relu(mtf.conv2d(x, kernel6, strides=[1, 1, 1, 1], padding="SAME"))
    x = mtf.reduce_mean(x, reduced_dim=filters6_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-2:],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf.layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h2, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Example #19
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # add some convolutional layers to demonstrate that convolution works.
    fh_dim = mtf.Dimension("fh", 9)
    fw_dim = mtf.Dimension("fw", 9)
    filters1_dim = mtf.Dimension("filters1", 16)
    filters2_dim = mtf.Dimension("filters2", 16)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    f1 = mtf.relu(
        mtf.conv2d_with_blocks(x,
                               kernel1,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))
    f2 = mtf.relu(
        mtf.conv2d_with_blocks(f1,
                               kernel2,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    #hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-4:],
                          activation=mtf.relu,
                          name="hidden1")
    #h2 = mtf.layers.dense(
    #    h1, hidden_dim2,
    #    activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
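The reshape to [batch, 4, 7, 4, 7, 1] followed by the transpose above is what carves each 28x28 image into a 4x4 grid of 7x7 blocks, so the block dimensions can be split across the mesh. A quick numpy check of that index bookkeeping (illustrative only, not part of the snippet):

import numpy as np

image = np.arange(28 * 28).reshape(28, 28)
blocked = image.reshape(4, 7, 4, 7)      # [row_blocks, rows, col_blocks, cols]
blocked = blocked.transpose(0, 2, 1, 3)  # [row_blocks, col_blocks, rows, cols]
# Block (i, j) is the 7x7 patch whose top-left pixel is (7*i, 7*j).
assert (blocked[1, 2] == image[7:14, 14:21]).all()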
Ejemplo n.º 20
0
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
    all_filters: a list of per-stage filter counts, [[init, init * 2, init * 4]]
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 1)
    col_blocks_dim = mtf.Dimension("col_blocks", 1)
    rows_dim = mtf.Dimension("rows_size", 28)
    cols_dim = mtf.Dimension("cols_size", 28)
    init = 60

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 1, 28, 1, 28, 1]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # add some convolutional layers to demonstrate that convolution works.
    filters1_dim = mtf.Dimension("filters1", 60)
    f1 = mtf.relu(
        mtf.layers.conv2d_with_blocks(x,
                                      filters1_dim,
                                      filter_size=[7, 7],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv0"))

    # f1 = mtf.reshape(f1, [FLAGS.batch_size, 1, 30, 3, 10, 1])
    filters2_dim = mtf.Dimension("filters2", 120)
    f2 = mtf.relu(
        mtf.layers.conv2d_with_blocks(f1,
                                      filters2_dim,
                                      filter_size=[5, 5],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv1"))

    filters3_dim = mtf.Dimension("filters3", 240)
    f3 = mtf.relu(
        mtf.layers.conv2d_with_blocks(f2,
                                      filters3_dim,
                                      filter_size=[3, 3],
                                      strides=[1, 1],
                                      padding="SAME",
                                      h_blocks_dim=row_blocks_dim,
                                      w_blocks_dim=col_blocks_dim,
                                      name="conv2"))

    x = mtf.layers.avg_pool2d(f3, ksize=(2, 2), name="averagePool")

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", 128)
    print(x.shape)
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-5:],
                          activation=mtf.relu,
                          name="hidden1")
    #  h1=x
    #  print(h1.shape)

    logits = mtf.layers.dense(h1, classes_dim, name="logits")
    #  logits = mtf.layers.dense(h1, classes_dim, reduced_dims=x.shape.dims[-5:], name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)

    all_filters = [[init, init * 2, init * 4]]
    return logits, loss, all_filters
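A back-of-the-envelope check, assuming the 2x2 average pool halves the rows/cols dims: because reduced_dims=x.shape.dims[-5:] tells the dense layer to contract everything except the batch dimension, "hidden1" sees 1 * 1 * 14 * 14 * 240 inputs.

# Hypothetical shape bookkeeping for the "hidden1" weight matrix above.
row_blocks, col_blocks = 1, 1
rows = cols = 28 // 2          # after the 2x2 average pool
filters3 = 240
reduced_size = row_blocks * col_blocks * rows * cols * filters3
assert reduced_size == 47040   # times 128 hidden units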
Ejemplo n.º 21
0
def Alexnet(img, labels, num_nodes, num_gpus, args):
    num_classes = 1000
    keep_prob = 0.5
    learning_rate = 0.01
    graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
        img, labels, num_nodes, num_gpus, args)
    RenameFC = lambda x: mt.rename_dimension(x, x.shape[-1].name,
                                             utils.RandName())

    strategy = args.strategy
    if strategy == 0:
        fc6_units = mtf.Dimension(utils.RandName(), 4096)
        fc7_units = mtf.Dimension(utils.RandName(), 4096)
        fc8_units = mtf.Dimension(utils.RandName(), num_classes)

    elif strategy == 1:
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    elif strategy == 2:
        num_classes = utils.RoundUp(num_classes, num_gpus)
        fc6_units = mtf.Dimension('axis0', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis0', num_classes)

    elif strategy == 3:
        num_classes = utils.RoundUp(num_classes, num_gpus // 2)
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis1', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    with tf.variable_scope('alexnet'):
        # Conv1 + ReLU + maxpool1
        conv1 = mt.Conv2d(mtf_img,
                          GetFilterShape(mtf_img, (11, 11, 3, 96)), (4, 4),
                          'VALID',
                          activation=mtf.relu,
                          name='conv1')
        pool1 = mt.MaxPool(conv1, (3, 3), (2, 2), 'VALID', name='pool1')

        # Conv2 + ReLU + maxpool2
        conv2 = mt.Conv2d(pool1,
                          GetFilterShape(pool1, (5, 5, 96, 256)), (1, 1),
                          'SAME',
                          activation=mtf.relu,
                          name='conv2')
        pool2 = mt.MaxPool(conv2, (3, 3), (2, 2), name='pool2')

        # Conv3 + ReLU
        conv3 = mt.Conv2d(pool2,
                          GetFilterShape(pool2, (3, 3, 256, 384)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv3')

        # Conv4 + ReLU
        conv4 = mt.Conv2d(conv3,
                          GetFilterShape(conv3, (3, 3, 384, 384)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv4')

        # Conv5 + ReLU + maxpool5
        conv5 = mt.Conv2d(conv4,
                          GetFilterShape(conv4, (3, 3, 384, 256)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv5')
        pool5 = mt.MaxPool(conv5, (3, 3), (2, 2), name='pool5')

        # Rename dims
        if strategy == 1:
            k_dim = mtf.Dimension(utils.RandName(),
                                  utils.Prod(pool5.shape.to_integer_list[1:]))
            pool5 = mtf.reshape(pool5, mtf.Shape([pool5.shape[0], k_dim]))
            pool5 = ReplaceMeshWithIndependentAxes(pool5, meshes[1],
                                                   (utils.RandName(), 'axis0'))

        elif strategy == 2:
            pool5 = mt.rename_dimension(pool5, pool5.shape[0].name,
                                        utils.RandName())

        elif strategy == 3:
            assert pool5.shape[0].name == 'axis0'
            #dim_names = pool5.shape.rename_dimension('axis0', utils.RandName())
            #pool5 = ReplaceMeshWithIndependentAxes(pool5, meshes[1], dim_names)
            pool5 = ReplaceMeshWithConcatSplit(pool5, meshes[1])

        # FC + ReLU + dropout
        fc_activation = lambda x: mtf.dropout(mtf.relu(x), keep_prob)
        fc6 = mtf.layers.dense(pool5,
                               fc6_units,
                               activation=fc_activation,
                               reduced_dims=pool5.shape[1:],
                               name='fc6')
        if strategy in (2, 3):
            fc6 = RenameFC(fc6)

        fc7 = mtf.layers.dense(fc6,
                               fc7_units,
                               activation=fc_activation,
                               reduced_dims=fc6.shape.dims[-1:],
                               name='fc7')
        if strategy in (2, 3):
            fc7 = RenameFC(fc7)

        fc8 = mtf.layers.dense(fc7,
                               fc8_units,
                               reduced_dims=fc7.shape.dims[-1:],
                               name='fc8')
        fc8 = mtf.dropout(fc8, keep_prob)

        if strategy == 1:
            assert fc8.shape[-1].name == 'axis1'
            fc8 = ReplaceMeshWithDuplicates(fc8, meshes[2])

    with tf.variable_scope('loss'):
        if fc8.shape[0] != mtf_labels.shape[0]:
            fc8 = mt.rename_dimension(fc8, fc8.shape[0].name,
                                      mtf_labels.shape[0].name)
        one_hot_labels = mtf.one_hot(mtf_labels, fc8.shape[-1])
        mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            fc8, one_hot_labels, fc8.shape[-1])
        mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    return graph, mesh_to_impl, mtf_loss
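Why the FC dimensions above are deliberately named 'axis0'/'axis1': in Mesh TensorFlow a tensor dimension is only sharded if the layout rules map its name onto a mesh axis, so naming the fc output dims after mesh axes is how strategies 1-3 request model parallelism for those weight matrices. Below is a minimal sketch of such a mesh and layout, assuming an 8-GPU 2x4 mesh; the CreateMeshes, utils and ReplaceMeshWith* helpers in this snippet come from the author's own repo and are not reproduced here.

import mesh_tensorflow as mtf

mesh_shape = mtf.convert_to_shape("axis0:2;axis1:4")  # 8 devices as a 2x4 mesh
layout_rules = mtf.convert_to_layout_rules("axis0:axis0;axis1:axis1")

# With these rules, a dense layer whose output dim is named "axis1"
# has its 4096-wide weight matrix split 4 ways across that mesh axis:
fc6_units = mtf.Dimension("axis1", 4096)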
Ejemplo n.º 22
0
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
  """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first two convolutions. Note
        that the third and final convolution will use 4 times as many filters.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). If None, no
        projection is used and the input is passed as unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, the row-blocks dimension that is
        spatially partitioned along a mesh axis
    col_blocks_dim: a mtf.Dimension, the column-blocks dimension that is
        spatially partitioned along a mesh axis

  Returns:
    The output `Tensor` of the block.
  """
  shortcut = inputs

  filter_h_dim = mtf.Dimension("filter_height", 3)
  filter_w_dim = mtf.Dimension("filter_width", 3)
  one_h_dim = mtf.Dimension("one_height", 1)
  one_w_dim = mtf.Dimension("one_width", 1)

  if projection_shortcut is not None:
    filters_dim = mtf.Dimension("filtersp", filters)
    kernel = mtf.get_variable(
        inputs.mesh, "kernel", mtf.Shape(
            [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
    shortcut = projection_shortcut(inputs, kernel)

  # First conv block
  filters1_dim = mtf.Dimension("filters1", filters)
  kernel1 = mtf.get_variable(
      inputs.mesh, "kernel1", mtf.Shape(
          [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      kernel1,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Add Dropout?
  inputs = batch_norm_relu(inputs, is_training)

  # Second conv block
  filters2_dim = mtf.Dimension("filters2", 4*filters)
  kernel2 = mtf.get_variable(
      inputs.mesh, "kernel2", mtf.Shape(
          [filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      kernel2,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)

  inputs = batch_norm_relu(inputs, is_training)

  # Third wide conv filter block
  filters3_dim = mtf.Dimension("filters3", filters)
  filters3_kernel = mtf.get_variable(
      inputs.mesh, "wide_kernel", mtf.Shape(
          [one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      filters3_kernel,
      strides,
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Although the original resnet code has this batch norm, in our
  # setup it causes no gradients to be passed. Investigate further.
  # inputs = batch_norm_relu(inputs, is_training, relu=True)

  # TODO(nikip): Maybe add residual with a projection?
  return mtf.relu(
      shortcut + mtf.rename_dimension(
          inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
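A minimal sketch of a projection_shortcut compatible with the call above (an assumption, not the original helper): the block already builds the 1x1 kernel and passes it in, so the callable only needs to apply that kernel to the inputs; if strides > 1 it would also need the same strides so the spatial dims of the two branches line up before the residual add.

import mesh_tensorflow as mtf

def projection_shortcut(inputs, kernel):
    # 1x1 convolution on the identity branch so its channel count
    # matches the main path's output channels.
    return mtf.conv2d_with_blocks(
        inputs,
        kernel,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None,
        w_blocks_dim=None)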
Ejemplo n.º 23
0
def backbone(x,
             layerlist,
             chalist,
             strilist,
             classes_dim,
             blocklist,
             float16=None,
             batch_norm=False):
    name = "backbone"

    x = conv2d(
        x,
        output_dim=mtf.Dimension(name=name + '-' + 'filters', size=64),
        filter_size=(7, 7),
        strides=(2, 2),
        # padding="VALID",
        name="conv7x7_backbone",
        variable_dtype=float16)
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    if batch_norm:
        x, _ = mtf.layers.batch_norm(x,
                                     is_training=True,
                                     momentum=0.99,
                                     epsilon=1e-5,
                                     name="batch_norm_backbone")
        logger.debug("[output tensor] (name,shape):({},{})".format(
            x.name, x.shape))
    x = mtf.relu(x, name="relu_backbone")
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.layers.max_pool2d(x, ksize=(2, 2), name="maxpool_backbone")
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    shortcuttype1 = 0
    shortcuttype2 = 0
    for layer, channel, strides in zip(layerlist, chalist, strilist):
        x = blocklist[0](x,
                         order=shortcuttype1,
                         out_channels=channel,
                         strides=(strides, strides),
                         float16=float16,
                         batch_norm=batch_norm)
        shortcuttype1 += 1
        for _ in range(layer - 1):
            x = blocklist[1](x,
                             order=shortcuttype2,
                             out_channels=channel,
                             strides=(1, 1),
                             float16=float16,
                             batch_norm=batch_norm)
            shortcuttype2 += 1

    # x = mtf.einsum([x], output_shape=[list(x.shape.dims)[0],list(x.shape.dims)[3]], name="einsum_backbone")
    x = mtf.layers.avg_pool2d(x,
                              ksize=(x.shape.dims[1].size,
                                     x.shape.dims[2].size))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.reshape(
        x,
        new_shape=[
            x.shape.dims[0],
            mtf.Dimension(name="flatten",
                          size=x.shape.dims[1].size * x.shape.dims[2].size *
                          x.shape.dims[3].size)
        ],
        name="flatten")
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    logit = mtf.layers.dense(x,
                             classes_dim,
                             name="dense_backbone",
                             variable_dtype=float16)
    logger.debug("[output tensor] (name,shape):({},{})".format(
        logit.name, logit.shape))
    return logit
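A hypothetical call sketch (the names DownsampleBlock and IdentityBlock are invented here, not from the original repo): a ResNet-18-style stack would pass per-stage layer/channel/stride lists plus two block constructors matching the keyword arguments backbone() forwards (order, out_channels, strides, float16, batch_norm).

classes_dim = mtf.Dimension("classes", 1000)
logits = backbone(
    x,
    layerlist=[2, 2, 2, 2],        # residual blocks per stage (ResNet-18 layout)
    chalist=[64, 128, 256, 512],   # output channels per stage
    strilist=[1, 2, 2, 2],         # stride of the first block in each stage
    classes_dim=classes_dim,
    blocklist=[DownsampleBlock, IdentityBlock],  # hypothetical block constructors
    float16=None,
    batch_norm=True)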