def attention(x, dim_head, dim_features_head, scope='attn', causal=False):
    """Multi-head self-attention over the sequence dimension of `x`.

    Args:
        x: mtf.Tensor whose shape unpacks into exactly three dimensions,
            used here as (batch, seq, dim).
        dim_head: mtf.Dimension for the number of attention heads.
        dim_features_head: mtf.Dimension for the per-head feature size.
        scope: variable scope name wrapping the projections.
        causal: if True, query position i may only attend to key positions
            j <= i + (memory_length - seq).

    Returns:
        mtf.Tensor projected back to `dim` by the 'combine_output' linear layer.
    """
    with tf.variable_scope(scope):
        mesh, batch, seq, dim = x.mesh, *x.shape

        # Total projection width: heads * features-per-head.
        dim_heads = mtf.Dimension('dim_heads',
                                  dim_head.size * dim_features_head.size)
        # One fused projection produces q, k and v; it is split below.
        dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)
        qkv = linear(x, dim_intermediate, bias=False, scope='to_qkv')

        q, k, v = mtf.split(qkv, dim_intermediate, 3)
        q, k, v = map(
            lambda t: mtf.reshape(t, [batch, seq, dim_head, dim_features_head]),
            (q, k, v))
        q, k, v = map(
            lambda t: mtf.transpose(
                t, [batch, dim_head, seq, dim_features_head]), (q, k, v))

        # Keys/values get their own sequence dimension so the score tensor
        # can carry both the query axis (seq) and the key axis.
        k, v = map(
            lambda t: mtf.rename_dimension(t, seq.name, 'memory_length'),
            (k, v))
        mem_len_dim = v.shape[-2]

        dots = mtf.layers.us_einsum([q, k],
                                    [batch, dim_head, seq, mem_len_dim])

        if causal:
            i = mtf.range(mesh, seq, tf.int32)
            j = mtf.range(mesh, mem_len_dim, tf.int32)
            i, j = map(lambda t: mtf.broadcast(t, [seq, mem_len_dim]), (i, j))
            # True where key j lies beyond query i; the offset aligns the
            # last query with the last key when memory is longer than seq.
            mask = mtf.less(i + mem_len_dim.size - seq.size, j)
            mask = mtf.cast(mask, tf.float32) * -1e10
            # Match the scores' dtype so reduced-precision activations do not
            # hit a float32 vs. bfloat16/float16 dtype mismatch. The large
            # negative value is formed in float32 first, so masked-out zeros
            # stay exactly zero even if -1e10 saturates in the target dtype.
            dots += mtf.cast(mask, dots.dtype)

        attn = mtf.softmax(dots, mem_len_dim)
        out = mtf.einsum([attn, v], [batch, dim_head, seq, dim_features_head])

        # Merge heads back into a single feature dimension.
        out = mtf.transpose(out, [batch, seq, dim_head, dim_features_head])
        out = mtf.reshape(out, [batch, seq, dim_heads])

        combined_out = linear(out, dim, scope='combine_output')
        return combined_out
def local_attention2d_spatial_decoder(x, kv_dim, heads_dim, feedforward_dim,
                                      hparams):
    """Image Transformer decoder built from local 2D spatial attention layers.

    The `length` dimension of `x` is regrouped into a grid of spatial blocks.
    The result is then passed through `hparams.num_decoder_layers` pre-norm
    residual layers. Each layer has a local 2D self-attention sublayer
    followed by a feed-forward sublayer.

    Args:
        x: mtf.Tensor with exactly three dimensions (batch, length, model).
        kv_dim: key/value channel dimension for the attention sublayer.
        heads_dim: attention-heads dimension.
        feedforward_dim: hidden dimension of the feed-forward sublayer.
        hparams: reads block_height, block_width, img_len, num_channels,
            num_decoder_layers, dropout and (optionally) mode.

    Returns:
        The final layer-normalized tensor. It stays in the blocked layout;
        no reshape back to (batch, length, model) happens here.
    """
    batch_dim, length_dim, model_dim = x.shape.dims
    block_h = hparams.block_height
    block_w = hparams.block_width
    blocks_h_dim = mtf.Dimension("blocksh", block_h)
    blocks_w_dim = mtf.Dimension("blocksw", block_w)
    num_h_blocks_dim = mtf.Dimension("num_h_blocks",
                                     hparams.img_len // block_h)
    num_w_blocks_dim = mtf.Dimension(
        "num_w_blocks", hparams.img_len * hparams.num_channels // block_w)

    # Split length into (h_blocks, block_h, w_blocks, block_w), then move the
    # block-count dims ahead of the within-block dims.
    split_shape = mtf.Shape([batch_dim, num_h_blocks_dim, blocks_h_dim,
                             num_w_blocks_dim, blocks_w_dim, model_dim])
    blocked_shape = mtf.Shape([batch_dim, num_h_blocks_dim, num_w_blocks_dim,
                               blocks_h_dim, blocks_w_dim, model_dim])
    x = mtf.transpose(mtf.reshape(x, split_shape), blocked_shape)

    # Missing `mode` is treated as training.
    is_training = (getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
                   == tf_estimator.ModeKeys.TRAIN)

    for layer_idx in range(hparams.num_decoder_layers):
        with tf.variable_scope("decoder_layer_%d" % layer_idx):
            normed = mtf.layers.layer_norm(x, model_dim,
                                           name="layer_norm_att")
            attended = mtf.layers.local_2d_self_attention_spatial_blocks(
                normed,
                kv_dim,
                heads_dim,
                is_training,
                memory_h_dim=num_h_blocks_dim,
                memory_w_dim=num_w_blocks_dim,
                name="self_att")
            x += layer_prepostprocess_dropout(attended, hparams)

            normed = mtf.layers.layer_norm(x, model_dim,
                                           name="layer_norm_ffn")
            ffn_out = mtf.layers.dense_relu_dense(
                normed,
                feedforward_dim,
                hparams.dropout,
                dropout_broadcast_dims=[length_dim])
            x += layer_prepostprocess_dropout(ffn_out, hparams)

    return mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
                                      feedforward_dim, hparams):
  """Image Transformer decoder using local 2D spatial self-attention.

  The flat `length` dimension of `x` is split into a grid of
  (block_height x block_width) blocks. Every decoder layer then applies a
  pre-norm local 2D self-attention sublayer and a pre-norm feed-forward
  sublayer. Each sublayer's output goes through `layer_prepostprocess_dropout`
  and is added back to `x`.

  Args:
    x: mtf.Tensor with exactly three dimensions (batch, length, model).
    kv_dim: key/value channel dimension for attention.
    heads_dim: number-of-heads dimension.
    feedforward_dim: hidden dimension of the feed-forward sublayer.
    hparams: reads block_height, block_width, img_len, num_channels,
      num_decoder_layers and dropout.

  Returns:
    The final layer-normalized tensor, still in the blocked layout; no
    reshape back to (batch, length, model) is performed here.
  """
  batch_dim, length_dim, model_dim = x.shape.dims
  blocks_h_dim = mtf.Dimension("blocksh", hparams.block_height)
  blocks_w_dim = mtf.Dimension("blocksw", hparams.block_width)
  num_h_blocks_dim = mtf.Dimension(
      "num_h_blocks", hparams.img_len // hparams.block_height)
  num_w_blocks_dim = mtf.Dimension(
      "num_w_blocks",
      hparams.img_len * hparams.num_channels // hparams.block_width)

  # First view length as (h_blocks, block_h, w_blocks, block_w), then put the
  # two block-count dims in front of the within-block dims.
  row_major = [batch_dim, num_h_blocks_dim, blocks_h_dim,
               num_w_blocks_dim, blocks_w_dim, model_dim]
  block_major = [batch_dim, num_h_blocks_dim, num_w_blocks_dim,
                 blocks_h_dim, blocks_w_dim, model_dim]
  x = mtf.reshape(x, mtf.Shape(row_major))
  x = mtf.transpose(x, mtf.Shape(block_major))

  def post_dropout(sublayer_output):
    return layer_prepostprocess_dropout(sublayer_output, hparams)

  for i in range(hparams.num_decoder_layers):
    with tf.variable_scope("decoder_layer_%d" % i):
      att_in = mtf.layers.layer_norm(x, model_dim, name="layer_norm_att")
      x += post_dropout(mtf.layers.local_2d_self_attention_spatial_blocks(
          att_in,
          kv_dim,
          heads_dim,
          memory_h_dim=num_h_blocks_dim,
          memory_w_dim=num_w_blocks_dim,
          name="self_att"))
      ffn_in = mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn")
      x += post_dropout(mtf.layers.dense_relu_dense(
          ffn_in,
          feedforward_dim,
          hparams.dropout,
          dropout_broadcast_dims=[length_dim]))

  return mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
Exemple #4
0
  def compute_output(self, o, output_shape=None):
    """Project multihead-attention output `o` through `self.wo`.

    Args:
      o: a Tensor with dimensions
         query_heads_dims + {value_dim} + other_dims
      output_shape: an optional Shape, forwarded to mtf.einsum.
    Returns:
      a Tensor with shape:
         {output_dim} + other_dims
    """
    if not self.combine_dims:
      # Contract directly over the separate head/value dims.
      return mtf.einsum(
          [o, self.wo], output_shape=output_shape, reduced_dims=self.o_dims)

    # Move self.o_dims to the end, then fuse them into the single combined
    # dimension that self.wo carries second-to-last, and contract over it.
    o = mtf.transpose(o, o.shape - self.o_dims + self.o_dims)
    combined_dim = self.wo.shape.dims[-2]
    o = mtf.replace_dimensions(o, self.o_dims, combined_dim)
    return mtf.einsum(
        [o, self.wo], output_shape=output_shape, reduced_dims=[combined_dim])
Exemple #5
0
def mnist_model(image, labels, mesh):
    """Builds the blocked-convolution MNIST classifier.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when `labels` is None
    """
    batch_size = FLAGS.batch_size
    batch_dim = mtf.Dimension("batch", batch_size)
    # The 28x28 image is viewed as a 4x4 grid of 7x7 blocks.
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    tiled = tf.reshape(image, [batch_size, 4, 7, 4, 7, 1])
    x = mtf.import_tf_tensor(
        mesh, tiled,
        mtf.Shape([batch_dim, row_blocks_dim, rows_dim,
                   col_blocks_dim, cols_dim, one_channel_dim]))
    # Bring the block-grid dims ahead of the within-block pixel dims.
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, one_channel_dim])

    # Two 9x9 convolutions, to demonstrate that blocked convolution works.
    fh_dim = mtf.Dimension("fh", 9)
    fw_dim = mtf.Dimension("fw", 9)
    filters1_dim = mtf.Dimension("filters1", 16)
    filters2_dim = mtf.Dimension("filters2", 16)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    def conv_relu(inputs, kernel):
        return mtf.relu(mtf.conv2d_with_blocks(
            inputs, kernel, strides=[1, 1, 1, 1], padding="SAME",
            h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

    conv_out = conv_relu(conv_relu(x, kernel1), kernel2)
    x = mtf.reduce_mean(conv_out, reduced_dim=filters2_dim)

    # Dense head: contract the four remaining non-batch dims into hidden1.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    h1 = mtf.layers.dense(x, hidden_dim1, reduced_dims=x.shape.dims[-4:],
                          activation=mtf.relu, name="hidden1")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")

    if labels is None:
        return logits, None
    mtf_labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [batch_size]), mtf.Shape([batch_dim]))
    per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim)
    return logits, mtf.reduce_mean(per_example_loss)
  def self_attention(self, x, attention_bias):
    """Performs multi-headed self-attention with output projection.

    Args:
      x: output of previous layer
      attention_bias: optional float32 Tensor broadcastable to shape
        x.shape - self.model_dim + self.memory_seq_dim
        to be added to attention logits.
        This may be used to mask out padding regions of the memory.

    Returns:
      float Tensor with the same shape as x
    """
    # Keys and values read a copy of x whose sequence dim is renamed, so the
    # score tensor can carry both self.seq_dim and self.memory_seq_dim.
    memory_antecedent = mtf.replace_dimensions(
        x, self.seq_dim, self.memory_seq_dim)

    def project_to_heads(inputs, name):
      # model_dim -> (num_heads, size_per_head) projection.
      return mtf.layers.dense(
          inputs,
          reduced_dims=[self.model_dim],
          new_dims=[self.num_heads_dim, self.size_per_head_dim],
          kernel_initializer=self.dense_initializer,
          name=name,
          use_bias=self.config.use_bias)

    queries = project_to_heads(x, "query")
    keys = project_to_heads(memory_antecedent, "key")
    values = project_to_heads(memory_antecedent, "value")

    # Take the dot product between "query" and "key" to get the raw
    # attention scores, scaled by 1/sqrt(size_per_head).
    attention_scores = mtf.einsum(
        [queries, keys], reduced_dims=[self.size_per_head_dim])
    attention_scores *= self.size_per_head_dim.size ** -0.5

    if attention_bias is not None:
      attention_scores += attention_bias

    # Normalize the attention scores to probabilities.
    attention_probs = mtf.softmax(attention_scores, self.memory_seq_dim)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    # Dropout is active exactly when a non-zero probability is configured;
    # the previous `== 0.0` test was inverted and disabled it whenever a
    # dropout probability was set.
    dropout_prob = self.config.attention_probs_dropout_prob
    attention_probs = mtf.dropout(
        attention_probs,
        is_training=(dropout_prob > 0.0),
        keep_prob=1.0 - dropout_prob)

    output = mtf.einsum([attention_probs, values],
                        reduced_dims=[self.memory_seq_dim])

    # linear transformation back to shape of query_antecedent
    output = mtf.layers.dense(
        output,
        reduced_dims=[self.num_heads_dim, self.size_per_head_dim],
        new_dims=[self.model_dim],
        kernel_initializer=self.dense_initializer,
        name="output",
        use_bias=self.config.use_bias)
    output = mtf.transpose(output, x.shape)
    return output
Exemple #7
0
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
    """Builds the UNet model graph, loss and eval metrics.

    Args:
      mesh: a MeshTensorflow.mesh object.
      mesh_impl: a mesh implementation, such as SimdMeshImpl and
        PlacementMeshImpl.
      dataset_str: a string of either train or eval. This is used for batch_norm
        (through `is_training`) and to choose the batch size flag.
      images: a laid out Tensor with shape [batch, x, y, num_channels]
        or [batch, x, y, z, num_channels].
      labels: a laid out Tensor with shape [batch, x, y, num_classes]
        or [batch, x, y, z, num_classes].

    Returns:
      A tuple (preds, loss, eval_metrics, all_bn_update_ops):
        preds: list of mtf Tensors. It holds the downsampled liver (class 1)
          and lesion (class 2) probabilities, then the downsampled lesion
          ground truth if FLAGS.output_ground_truth is set, then the per-case
          `intersect` and `area_sum` of the hard lesion prediction and label.
        loss: scalar blend of the Dice loss and the weighted cross-entropy.
        eval_metrics: dict mapping metric names to scalar mtf Tensors.
        all_bn_update_ops: batch-norm update ops returned by every
          conv_with_spatial_partition call.
    """

    is_training = (dataset_str == 'train')
    if dataset_str == 'train':
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
    else:
        assert dataset_str == 'eval'
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
    # Block-grid dims (nx, ny) and within-block sizes (sx, sy); z is not split.
    image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
    image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
    image_sx_dim = mtf.Dimension('image_sx_block',
                                 FLAGS.ct_resolution // FLAGS.image_nx_block)
    image_sy_dim = mtf.Dimension('image_sy_block',
                                 FLAGS.ct_resolution // FLAGS.image_ny_block)
    image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
    image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
    label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
    mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

    # The same dtype is used for all three VariableDType slots.
    mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
    variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

    # Import input features.
    x = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(images),
                                   mtf_images_shape)
    x = mtf.cast(x, mtf_dtype)

    # Import ground truth labels.
    t = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(labels),
                                   mtf_labels_shape)
    t = mtf.cast(t, mtf_dtype)

    # Transpose the blocks: the block-grid dims (nx, ny) go right after batch,
    # ahead of the within-block spatial dims.
    if FLAGS.sampled_2d_slices:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            label_c_dim
        ])
    else:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, label_c_dim
        ])

    # Network.
    # `levels` keeps each encoder level's output for the decoder's skip
    # connections.
    levels = []
    all_bn_update_ops = []
    # add levels with convolution or down-sampling
    for depth in range(FLAGS.network_depth):
        for n_conv in range(FLAGS.n_conv_per_block):
            if depth == 0 and n_conv == 0:
                # no dropout in 1st layer.
                dropout_keep_p = 1.0
            else:
                dropout_keep_p = FLAGS.dropout_keep_p
            # Filter count doubles with each level of depth.
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_down_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)
        levels.append(x)

        # Downsample by 2 per spatial axis, except after the deepest level.
        if depth < FLAGS.network_depth - 1:
            if FLAGS.sampled_2d_slices:
                x = mtf.layers.max_pool2d(x, ksize=(2, 2))
            else:
                x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

    # add levels with up-convolution or up-sampling
    # Walks depth from network_depth - 2 down to 0.
    for depth in range(FLAGS.network_depth - 1)[::-1]:
        x = deconv_with_spatial_partition(
            x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
            FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
            'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
            variable_dtype, 'deconv_{}_0'.format(depth))
        # Skip connection. The concat dim name is presumably the channel dim
        # produced by the last encoder conv at this depth -- confirm against
        # conv_with_spatial_partition / deconv_with_spatial_partition.
        x = mtf.concat([x, levels[depth]],
                       concat_dim_name='conv_{}_{}'.format(
                           depth, FLAGS.n_conv_per_block - 1))

        for n_conv in range(FLAGS.n_conv_per_block):
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_up_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)

    # no dropout in the final layer.
    # 1x1 conv down to label_c channels; `y` is used as logits below.
    if FLAGS.sampled_2d_slices:
        y = mtf.layers.conv2d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            h_blocks_dim=image_nx_dim,
            w_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )
    else:
        y = mtf.layers.conv3d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1, 1),
            strides=(1, 1, 1),
            padding='SAME',
            d_blocks_dim=image_nx_dim,
            h_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )

    # use mtf.constant to make sure there is no CPU-side constants.
    def scalar(v, dtype):
        return mtf.constant(mesh, v, shape=[], dtype=dtype)

    # Hard class masks from argmax over label_c: 1 = liver, 2 = lesion.
    argmax_t = mtf.argmax(t, label_c_dim)
    liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
    lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)

    argmax_y = mtf.argmax(y, label_c_dim)
    lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

    # summary of class ratios.
    lesion_pred_ratio = mtf.reduce_mean(lesion_y)
    lesion_label_ratio = mtf.reduce_mean(lesion_t)

    # summary of accuracy.
    accuracy = mtf.reduce_mean(
        mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

    # Cross-entropy loss. Up-weight the liver region.
    pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
        y, t, label_c_dim)
    # NOTE(review): liver_t and lesion_t come from different argmax values, so
    # they never overlap. Background pixels get weight 1 and liver pixels get
    # xen_liver_weight. Lesion pixels get 1 + xen_lesion_weight -
    # xen_liver_weight, not xen_lesion_weight, which would only hold if lesion
    # pixels were also counted as liver. Confirm this is intended.
    pixel_weight = scalar(1, mtf_dtype) + \
        liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \
        lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight,
                          mtf_dtype)
    loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

    # Dice loss
    # Soft Dice on the lesion probability, negated so that minimizing the
    # loss maximizes overlap. Slice out channel 2, then sum away the size-1
    # label_c dim the slice leaves behind.
    y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
    lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                                 reduced_dim=mtf.Dimension('label_c', 1))
    # Per-case sums (only the batch dim is kept).
    prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                    output_shape=mtf.Shape([batch_dim]))
    prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                   output_shape=mtf.Shape([batch_dim]))
    # Mean of the per-case Dice scores ...
    loss_dice_per_case = mtf.reduce_mean(
        scalar(-2, mtf_dtype) * prob_intersect /
        (prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
    # ... and a single Dice score pooled over the whole batch.
    loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(
        prob_intersect) / (mtf.reduce_sum(prob_area_sum) +
                           scalar(FLAGS.dice_epsilon, mtf_dtype))

    loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(
        0.5, mtf_dtype)

    # dice_loss_weight blends the Dice and cross-entropy terms.
    loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
        1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

    # Hard (argmax-based) lesion Dice for reporting.
    intersect = mtf.reduce_sum(lesion_y * lesion_t,
                               output_shape=mtf.Shape([batch_dim]))
    area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                              output_shape=mtf.Shape([batch_dim]))
    # summary of dice.
    # The 1e-6 term avoids division by zero when a case has no lesion in
    # either the label or the prediction.
    dice_per_case = mtf.reduce_mean(
        scalar(2, mtf_dtype) * intersect /
        (area_sum + scalar(0.000001, mtf_dtype)))
    dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
        mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

    eval_metrics = {
        'lesion_pred_ratio': lesion_pred_ratio,
        'lesion_label_ratio': lesion_label_ratio,
        'accuracy_of_all_classes': accuracy,
        'lesion_dice_per_case': dice_per_case,
        'lesion_dice_global': dice_global,
        'loss_xen': loss_xen,
        'loss_dice': loss_dice,
        'loss_dice_per_case': loss_dice_per_case,
        'loss_dice_global': loss_dice_global,
    }

    # Average-pool the probabilities (and optionally the lesion label) by
    # FLAGS.pred_downsample along every spatial axis for the prediction output.
    if FLAGS.sampled_2d_slices:
        y_prob_downsampled = mtf.layers.avg_pool2d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 2)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool2d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 2)
    else:
        y_prob_downsampled = mtf.layers.avg_pool3d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 3)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool3d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 3)

    # Channel 1 = liver, channel 2 = lesion; the reduce_sum over the size-1
    # label_c dim just drops that dimension.
    liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
    lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
    preds = [
        mtf.reduce_sum(liver_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1)),
        mtf.reduce_sum(lesion_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1))
    ]

    # lesion_gt_downsampled is only defined under this same flag above.
    if FLAGS.output_ground_truth:
        preds.append(
            mtf.reduce_sum(lesion_gt_downsampled,
                           reduced_dim=mtf.Dimension('label_c', 1)))

    preds.extend([intersect, area_sum])

    return preds, loss, eval_metrics, all_bn_update_ops
Exemple #8
0
    def mtf_model_fn(self, features, mesh):
        """Builds the blocked ResNet-style classifier graph.

        Args:
            features: dict with "inputs" (images, reshaped below into row and
                column blocks with the channels folded into the column axis)
                and "targets" (class ids; axes 2 and 3 are squeezed away).
            mesh: the mtf.Mesh to build the graph on.

        Returns:
            (logits, loss): logits reshaped to [batch, one_channel, classes]
            and the mean softmax cross-entropy over the batch.
        """
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

        # Per-block sizes, shared by the mtf dimensions and the tf.reshape
        # below so the two always agree (32 and 96 for 32x32x3 images in a
        # single block).
        block_rows = hparams.rows_size // hparams.row_blocks
        block_cols = (hparams.num_channels * hparams.cols_size //
                      hparams.col_blocks)

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        filter_h_dim = mtf.Dimension("filter_height", 7)
        filter_w_dim = mtf.Dimension("filter_width", 7)
        filters = mtf.Dimension("filters", hparams.filter_sizes[0])
        rows_dim = mtf.Dimension("rows_size", block_rows)
        cols_dim = mtf.Dimension("cols_size", block_cols)
        row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
        col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
        classes_dim = mtf.Dimension("classes", 10)
        one_channel_dim = mtf.Dimension("one_channel", 1)

        inputs = features["inputs"]
        x = mtf.import_tf_tensor(
            mesh,
            tf.reshape(inputs, [
                hparams.batch_size, hparams.row_blocks, block_rows,
                hparams.col_blocks, block_cols, 1
            ]),
            mtf.Shape([
                batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
                one_channel_dim
            ]))
        # Move the block-grid dims ahead of the within-block dims.
        x = mtf.transpose(x, [
            batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
            one_channel_dim
        ])

        x = mtf.to_float(x)
        initial_filters = mtf.get_variable(
            mesh, "init_filters",
            mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
        x = mtf.conv2d_with_blocks(x,
                                   initial_filters,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   h_blocks_dim=None,
                                   w_blocks_dim=col_blocks_dim)

        x = batch_norm_relu(x, is_training)

        # Conv blocks
        # [block layer - strided block layer - strided block layer] x n
        for layer in range(hparams.num_layers):
            layer_name = "block_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Residual block layer
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[0],
                                blocks=hparams.layer_sizes[0],
                                strides=[1, 1, 1, 1],
                                is_training=is_training,
                                name="block_layer1",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[1],
                                blocks=hparams.layer_sizes[1],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer2",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[2],
                                blocks=hparams.layer_sizes[2],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer3",
                                row_blocks_dim=None,
                                col_blocks_dim=None)

        # Calculate the logits and loss.
        out = x
        outputs = mtf.layers.dense(out,
                                   hidden_dim,
                                   reduced_dims=out.shape.dims[-5:],
                                   activation=mtf.relu,
                                   name="dense")

        # We assume fixed vocab size for targets
        labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [hparams.batch_size]),
                                      mtf.Shape([batch_dim]))

        logits = mtf.layers.dense(outputs, classes_dim, name="logits")
        soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, classes_dim)

        # Reshape logits so it doesn't break inside t2t.
        logits = mtf.reshape(
            logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
        loss = mtf.reduce_mean(loss)
        return logits, loss
Exemple #9
0
def mnist_model(image, labels, mesh):
    """The model.

    Args:
        image: tf.Tensor with shape [batch, 28*28]
        labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
        mesh: a mtf.Mesh

    Returns:
        logits: a mtf.Tensor with shape [batch, 10]
        loss: a mtf.Tensor with shape [], or None when `labels` is None
    """
    def log_shape(tensor):
        # Trace each intermediate tensor's name and shape at graph-build time.
        tf.logging.info("[intra variable] (name, shape): ({},{})".format(
            tensor.name, tensor.shape))

    batch_size = FLAGS.batch_size
    batch_dim = mtf.Dimension("batch", batch_size)
    # The 28x28 image is viewed as a 4x4 grid of 7x7 blocks.
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    blocked = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [batch_size, 4, 7, 4, 7, 1]),
        mtf.Shape([batch_dim, row_blocks_dim, rows_dim,
                   col_blocks_dim, cols_dim, one_channel_dim]))
    x = mtf.transpose(blocked, [batch_dim, row_blocks_dim, col_blocks_dim,
                                rows_dim, cols_dim, one_channel_dim])
    log_shape(x)

    def conv_relu(inputs, filters_dim, name):
        # 9x9 blocked convolution followed by ReLU; logged once built.
        out = mtf.relu(mtf.layers.conv2d_with_blocks(
            inputs, filters_dim, filter_size=[9, 9], strides=[1, 1],
            padding="SAME", h_blocks_dim=row_blocks_dim,
            w_blocks_dim=col_blocks_dim, name=name))
        log_shape(out)
        return out

    # Two convolutional layers, to demonstrate that convolution works.
    filters2_dim = mtf.Dimension("filters2", 16)
    f1 = conv_relu(x, mtf.Dimension("filters1", 16), "conv0")
    f2 = conv_relu(f1, filters2_dim, "conv1")
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)
    log_shape(x)

    # Fully-connected layers.
    h1 = mtf.layers.dense(x, mtf.Dimension("hidden1", FLAGS.hidden_size),
                          reduced_dims=x.shape.dims[-4:],
                          activation=mtf.relu, name="hidden1")
    log_shape(h1)
    h2 = mtf.layers.dense(h1, mtf.Dimension("hidden2", FLAGS.hidden_size),
                          activation=mtf.relu, name="hidden2")
    log_shape(h2)
    logits = mtf.layers.dense(h2, classes_dim, name="logits")

    loss = None
    if labels is not None:
        mtf_labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [batch_size]), mtf.Shape([batch_dim]))
        loss = mtf.reduce_mean(mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim))
    return logits, loss
Exemple #10
0
  def mtf_model_fn(self, features, mesh):
    """Builds the blocked ResNet-style classifier graph.

    Args:
      features: dict with "inputs" (images, split below into row/column
        blocks with hparams.num_channels channels) and "targets" (class ids;
        axes 2 and 3 are squeezed away).
      mesh: the mtf.Mesh to build the graph on.

    Returns:
      (logits, loss): logits reshaped to [batch, one_channel, classes] and the
      mean softmax cross-entropy over the batch.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Per-block spatial sizes. The mtf dimensions and the tf.reshape below
    # are derived from the same values so they always agree. The channel
    # count appears only on the last axis; folding it into the column axis
    # as well would make the element count num_channels times too large.
    block_rows = hparams.rows_size // hparams.row_blocks
    block_cols = hparams.cols_size // hparams.col_blocks

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", block_rows)
    cols_dim = mtf.Dimension("cols_size", block_cols)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", hparams.num_channels)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    # Assumes inputs hold batch * rows_size * cols_size * num_channels values
    # laid out row-major as [batch, rows, cols, channels] -- TODO confirm
    # against the problem's input pipeline.
    x = mtf.import_tf_tensor(
        mesh, tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks,
            block_rows,
            hparams.col_blocks,
            block_cols,
            hparams.num_channels]),
        mtf.Shape(
            [batch_dim, row_blocks_dim, rows_dim,
             col_blocks_dim, cols_dim, channels_dim]))
    # Move the block-grid dims ahead of the within-block dims.
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])

    x = mtf.to_float(x)
    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x,
        initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block layer 1 - block layer 2 - block layer 3] x n, all with stride 1
    for layer in range(hparams.num_layers):
      layer_name = "block_layer_%d" % layer
      with tf.variable_scope(layer_name):
        # Residual block layer
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[0],
            blocks=hparams.layer_sizes[0],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer1",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[1],
            blocks=hparams.layer_sizes[1],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer2",
            row_blocks_dim=None,
            col_blocks_dim=None)
        x = block_layer(
            inputs=x,
            filters=hparams.filter_sizes[2],
            blocks=hparams.layer_sizes[2],
            strides=[1, 1, 1, 1],
            is_training=is_training,
            name="block_layer3",
            row_blocks_dim=None,
            col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
Exemple #11
0
def cifar_model(image, labels, mesh):
    """Build a VGG-style CIFAR classifier as a Mesh-TensorFlow graph.

    Four stages, each made of two 3x3 SAME convolutions with ReLU followed by
    a 2x2 max pool; then one ReLU dense layer (256 units) and a linear
    10-way classification layer.

    Args:
      image: tf.Tensor holding FLAGS.batch_size * 32 * 32 * 3 elements,
        laid out as [batch, 32, 32, 3] (rows, cols, channels).
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None.
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when labels is None
      all_filters: a one-element list wrapping the list of output widths of
        the eight conv layers, in order.
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 1)
    col_blocks_dim = mtf.Dimension("col_blocks", 1)
    rows_dim = mtf.Dimension("rows_size", 32)
    cols_dim = mtf.Dimension("cols_size", 32)
    # Width of the first conv stage; each later stage doubles it.
    init = 60

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 3)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 1, 32, 1, 32, 3]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    # conv2d_with_blocks is called with the block dims grouped before the
    # in-block spatial dims, so move them next to batch.
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # Output widths of the eight conv layers: two per stage.
    filter_sizes = [init * mult for mult in (1, 1, 2, 2, 4, 4, 8, 8)]
    num_stages = len(filter_sizes) // 2

    for stage in range(num_stages):
        for layer in (2 * stage, 2 * stage + 1):
            # Dimension/layer names ("filters1".., "conv0"..) are kept
            # identical to the unrolled version so checkpoints still match.
            filters_dim = mtf.Dimension("filters%d" % (layer + 1),
                                        filter_sizes[layer])
            x = mtf.relu(
                mtf.layers.conv2d_with_blocks(x,
                                              filters_dim,
                                              filter_size=[3, 3],
                                              strides=[1, 1],
                                              padding="SAME",
                                              h_blocks_dim=row_blocks_dim,
                                              w_blocks_dim=col_blocks_dim,
                                              name="conv%d" % layer))
        x = mtf.layers.max_pool2d(x, ksize=(2, 2), name="maxpool%d" % stage)

    # Fully-connected head: flatten everything except batch.
    hidden_dim1 = mtf.Dimension("hidden1", 256)
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-5:],
                          activation=mtf.relu,
                          name="hidden1")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)

    # Derived from the same list used to build the layers so it cannot drift.
    all_filters = [filter_sizes]
    return logits, loss, all_filters
Exemple #12
0
def cifar_model(features, labels, mesh):
  """Build an 18-conv-layer CIFAR classifier as a Mesh-TensorFlow graph.

  The image is imported as a 4x4 grid of 8x8 spatial blocks
  (row_blocks/col_blocks dims) and passed through 18 7x7 SAME convolutions
  with ReLU (widths 2x32, 2x64, 6x128, 8x256). The last conv output is
  averaged over its filter dim, then fed to 8 ReLU dense layers of width
  FLAGS.hidden_size and a linear 10-way classification layer.

  Args:
    features: mapping holding 'image' (a tf.Tensor with
      FLAGS.batch_size * 32 * 32 * 3 elements, laid out as
      [batch, 32, 32, 3]) and, optionally, 'label'.
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None.
      Only used when `features` has no 'label' entry; otherwise
      features['label'] takes precedence.
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when no labels are available
  """
  features = copy.copy(features)
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  row_blocks_dim = mtf.Dimension("row_blocks", 4)
  col_blocks_dim = mtf.Dimension("col_blocks", 4)
  rows_dim = mtf.Dimension("rows_size", 8)
  cols_dim = mtf.Dimension("cols_size", 8)

  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 3)

  image = features['image']
  # Labels carried in `features` win; otherwise fall back to the argument.
  # (Previously the argument was always discarded and a missing key raised
  # KeyError, so the `labels is None` branch below was unreachable.)
  if 'label' in features:
    labels = features['label']

  image = bnorm(image)

  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 4, 8, 4, 8, 3]),
      mtf.Shape(
          [batch_dim, row_blocks_dim, rows_dim,
           col_blocks_dim, cols_dim, one_channel_dim]))
  x = mtf.transpose(x, [
      batch_dim, row_blocks_dim, col_blocks_dim,
      rows_dim, cols_dim, one_channel_dim])

  fh_dim = mtf.Dimension("fh", 7)
  fw_dim = mtf.Dimension("fw", 7)

  # Output widths of the 18 conv layers; names "filters1".."filters18" and
  # "kernel1".."kernel18" are kept identical to the unrolled version.
  filter_sizes = [32] * 2 + [64] * 2 + [128] * 6 + [256] * 8
  filter_dims = [mtf.Dimension("filters%d" % (i + 1), size)
                 for i, size in enumerate(filter_sizes)]
  # Input channel dim of each layer: the image channels, then the previous
  # layer's filters.
  in_dims = [one_channel_dim] + filter_dims[:-1]

  def _kernel(i):
    """Create the 7x7 kernel variable for 0-based conv layer i."""
    return mtf.get_variable(
        mesh, "kernel%d" % (i + 1),
        [fh_dim, fw_dim, in_dims[i], filter_dims[i]])

  def _conv_relu(inputs, kernel):
    """Apply one blocked, stride-1 SAME convolution followed by ReLU."""
    return mtf.relu(mtf.conv2d_with_blocks(
        inputs, kernel, strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

  # Preserve the original op-creation order: kernel1 and kernel2 are both
  # created before the first convolution; every later kernel is created
  # immediately before its own convolution.
  kernel1 = _kernel(0)
  kernel2 = _kernel(1)
  net = _conv_relu(x, kernel1)
  net = _conv_relu(net, kernel2)
  for i in range(2, len(filter_dims)):
    net = _conv_relu(net, _kernel(i))

  x = mtf.reduce_mean(net, reduced_dim=filter_dims[-1])

  # Fully-connected head: the first layer flattens the four block/spatial
  # dims, the remaining seven are plain ReLU dense layers.
  hidden_dims = [mtf.Dimension("hidden%d" % i, FLAGS.hidden_size)
                 for i in range(1, 9)]
  h = mtf.layers.dense(
      x, hidden_dims[0],
      reduced_dims=x.shape.dims[-4:],
      activation=mtf.relu, name="hidden1")
  for i, hidden_dim in enumerate(hidden_dims[1:], start=2):
    h = mtf.layers.dense(
        h, hidden_dim,
        activation=mtf.relu, name="hidden%d" % i)

  logits = mtf.layers.dense(h, classes_dim, name="logits")

  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
Exemple #13
0
def mnist_model(image, labels, mesh):
    """Build a small convolutional MNIST classifier as a Mesh-TensorFlow graph.

    Three stride-1 SAME convolutions with ReLU (kernels 7x7, 5x5, 3x3; widths
    init, 2*init, 4*init), a 2x2 average pool, one ReLU dense layer (128
    units) and a linear 10-way classification layer.

    Args:
      image: tf.Tensor holding FLAGS.batch_size * 28 * 28 elements,
        e.g. shape [batch, 28*28].
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None.
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when labels is None
      all_filters: a one-element list wrapping the list of output widths of
        the three conv layers, in order.
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 1)
    col_blocks_dim = mtf.Dimension("col_blocks", 1)
    rows_dim = mtf.Dimension("rows_size", 28)
    cols_dim = mtf.Dimension("cols_size", 28)
    # Width of the first conv layer; each later layer doubles it.
    init = 60

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 1, 28, 1, 28, 1]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # Widths are derived from `init` so they always agree with all_filters.
    filter_sizes = [init, init * 2, init * 4]
    kernel_sizes = [[7, 7], [5, 5], [3, 3]]
    for layer, (width, ksize) in enumerate(zip(filter_sizes, kernel_sizes)):
        # Names ("filters1".., "conv0"..) match the unrolled version.
        filters_dim = mtf.Dimension("filters%d" % (layer + 1), width)
        x = mtf.relu(
            mtf.layers.conv2d_with_blocks(x,
                                          filters_dim,
                                          filter_size=ksize,
                                          strides=[1, 1],
                                          padding="SAME",
                                          h_blocks_dim=row_blocks_dim,
                                          w_blocks_dim=col_blocks_dim,
                                          name="conv%d" % layer))

    x = mtf.layers.avg_pool2d(x, ksize=(2, 2), name="averagePool")

    # Fully-connected head: flatten everything except batch.
    hidden_dim1 = mtf.Dimension("hidden1", 128)
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-5:],
                          activation=mtf.relu,
                          name="hidden1")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)

    all_filters = [filter_sizes]
    return logits, loss, all_filters