def attention(x, dim_head, dim_features_head, scope='attn', causal=False):
    """Multi-head self-attention over the sequence dimension of ``x``.

    Args:
        x: mtf.Tensor whose shape unpacks into exactly three dimensions,
            read here as (batch, seq, dim).
        dim_head: mtf.Dimension giving the number of attention heads.
        dim_features_head: mtf.Dimension giving the per-head feature size.
        scope: variable scope name for this layer's parameters.
        causal: if True, query position i may only attend to memory
            positions j <= i + (memory_length - seq_length).

    Returns:
        mtf.Tensor with dimensions [batch, seq, dim], i.e. the same
        dimensions as ``x``.
    """
    with tf.variable_scope(scope):
        mesh, batch, seq, dim = x.mesh, *x.shape

        dim_heads = mtf.Dimension('dim_heads', dim_head.size * dim_features_head.size)
        dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)
        # A single fused projection produces q, k and v; split it into thirds.
        qkv = linear(x, dim_intermediate, bias=False, scope='to_qkv')
        q, k, v = mtf.split(qkv, dim_intermediate, 3)

        q, k, v = map(
            lambda t: mtf.reshape(t, [batch, seq, dim_head, dim_features_head]),
            (q, k, v))
        q, k, v = map(
            lambda t: mtf.transpose(t, [batch, dim_head, seq, dim_features_head]),
            (q, k, v))

        # Give keys/values their own sequence dimension name so the einsum
        # below yields a [seq, memory_length] score matrix instead of
        # contracting over seq.
        k, v = map(
            lambda t: mtf.rename_dimension(t, seq.name, 'memory_length'), (k, v))
        mem_len_dim = v.shape[-2]

        dots = mtf.layers.us_einsum([q, k], [batch, dim_head, seq, mem_len_dim])

        if causal:
            i = mtf.range(mesh, seq, tf.int32)
            j = mtf.range(mesh, mem_len_dim, tf.int32)
            i, j = map(lambda t: mtf.broadcast(t, [seq, mem_len_dim]), (i, j))
            # Masked where j > i + (mem_len - seq): memory is right-aligned
            # to the queries.
            mask = mtf.less(i + mem_len_dim.size - seq.size, j)
            # Build the additive mask in float32, then cast it to the score
            # dtype. Mesh TensorFlow binary ops expect matching dtypes, so a
            # float32 mask cannot be added to bfloat16/float16 scores as-is.
            # This cast is a no-op when the scores are already float32.
            mask = mtf.cast(mask, tf.float32) * -1e10
            mask = mtf.cast(mask, dots.dtype)
            dots += mask

        attn = mtf.softmax(dots, mem_len_dim)
        out = mtf.einsum([attn, v], [batch, dim_head, seq, dim_features_head])

        # Merge the heads back into a single feature dimension.
        out = mtf.transpose(out, [batch, seq, dim_head, dim_features_head])
        out = mtf.reshape(out, [batch, seq, dim_heads])

        combined_out = linear(out, dim, scope='combine_output')
        return combined_out
def local_attention2d_spatial_decoder(x, kv_dim, heads_dim, feedforward_dim,
                                      hparams):
    """Image Transformer decoder with local2D spatial layers.

    The flat [batch, length, model] input is regrouped into a grid of
    (block_height x block_width) spatial blocks. It then runs through
    ``hparams.num_decoder_layers`` pre-norm residual layers, each made of
    local 2-D self-attention followed by a ReLU feed-forward network, and a
    final layer norm.

    Args:
        x: mtf.Tensor whose shape unpacks into (batch, length, model) dims.
        kv_dim: mtf.Dimension for the key/value channels.
        heads_dim: mtf.Dimension for the attention heads.
        feedforward_dim: mtf.Dimension for the feed-forward hidden layer.
        hparams: hyperparameters. Reads block_height, block_width, img_len,
            num_channels, num_decoder_layers, dropout and (optionally) mode.

    Returns:
        The final layer-normed activations in the blocked
        [batch, num_h_blocks, num_w_blocks, blocksh, blocksw, model] layout
        (assuming the attention/FFN layers preserve their input shape).
    """
    batch_dim, length_dim, model_dim = x.shape.dims
    block_h = hparams.block_height
    block_w = hparams.block_width
    blocks_h_dim = mtf.Dimension("blocksh", block_h)
    blocks_w_dim = mtf.Dimension("blocksw", block_w)
    num_h_blocks_dim = mtf.Dimension("num_h_blocks", hparams.img_len // block_h)
    # Each image row is taken to span img_len * num_channels positions.
    row_span = hparams.img_len * hparams.num_channels
    num_w_blocks_dim = mtf.Dimension("num_w_blocks", row_span // block_w)

    # View the sequence as interleaved (block index, in-block offset) pairs,
    # then move both block indices in front of the in-block offsets.
    interleaved_shape = mtf.Shape([
        batch_dim, num_h_blocks_dim, blocks_h_dim,
        num_w_blocks_dim, blocks_w_dim, model_dim
    ])
    grouped_shape = mtf.Shape([
        batch_dim, num_h_blocks_dim, num_w_blocks_dim,
        blocks_h_dim, blocks_w_dim, model_dim
    ])
    x = mtf.transpose(mtf.reshape(x, interleaved_shape), grouped_shape)

    mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
    is_training = mode == tf_estimator.ModeKeys.TRAIN

    # Pre-norm residual stack: [self attention -> ffn], each + dropout, x n.
    for layer_idx in range(hparams.num_decoder_layers):
        with tf.variable_scope("decoder_layer_%d" % layer_idx):
            attn_in = mtf.layers.layer_norm(x, model_dim, name="layer_norm_att")
            # NOTE(review): is_training is passed positionally here; confirm
            # the installed mesh_tensorflow signature of
            # local_2d_self_attention_spatial_blocks accepts it in this slot.
            attn_out = mtf.layers.local_2d_self_attention_spatial_blocks(
                attn_in,
                kv_dim,
                heads_dim,
                is_training,
                memory_h_dim=num_h_blocks_dim,
                memory_w_dim=num_w_blocks_dim,
                name="self_att")
            x += layer_prepostprocess_dropout(attn_out, hparams)

            ffn_in = mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn")
            ffn_out = mtf.layers.dense_relu_dense(
                ffn_in,
                feedforward_dim,
                hparams.dropout,
                dropout_broadcast_dims=[length_dim])
            x += layer_prepostprocess_dropout(ffn_out, hparams)

    return mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
def local_attention2d_spatial_decoder(x, kv_dim, heads_dim, feedforward_dim,
                                      hparams):
    """Image Transformer decoder with local2D spatial layers.

    The flat [batch, length, model] input is regrouped into a grid of
    (block_height x block_width) spatial blocks. It then runs through
    ``hparams.num_decoder_layers`` pre-norm residual layers, each made of
    local 2-D self-attention followed by a ReLU feed-forward network, and a
    final layer norm.

    Args:
        x: mtf.Tensor whose shape unpacks into (batch, length, model) dims.
        kv_dim: mtf.Dimension for the key/value channels.
        heads_dim: mtf.Dimension for the attention heads.
        feedforward_dim: mtf.Dimension for the feed-forward hidden layer.
        hparams: hyperparameters. Reads block_height, block_width, img_len,
            num_channels, num_decoder_layers and dropout.

    Returns:
        The final layer-normed activations in the blocked
        [batch, num_h_blocks, num_w_blocks, blocksh, blocksw, model] layout
        (assuming the attention/FFN layers preserve their input shape).
    """
    batch_dim, length_dim, model_dim = x.shape.dims
    block_h = hparams.block_height
    block_w = hparams.block_width
    blocks_h_dim = mtf.Dimension("blocksh", block_h)
    blocks_w_dim = mtf.Dimension("blocksw", block_w)
    num_h_blocks_dim = mtf.Dimension("num_h_blocks", hparams.img_len // block_h)
    # Each image row is taken to span img_len * num_channels positions.
    row_span = hparams.img_len * hparams.num_channels
    num_w_blocks_dim = mtf.Dimension("num_w_blocks", row_span // block_w)

    # View the sequence as interleaved (block index, in-block offset) pairs,
    # then move both block indices in front of the in-block offsets.
    interleaved_shape = mtf.Shape([
        batch_dim, num_h_blocks_dim, blocks_h_dim,
        num_w_blocks_dim, blocks_w_dim, model_dim
    ])
    grouped_shape = mtf.Shape([
        batch_dim, num_h_blocks_dim, num_w_blocks_dim,
        blocks_h_dim, blocks_w_dim, model_dim
    ])
    x = mtf.transpose(mtf.reshape(x, interleaved_shape), grouped_shape)

    # Pre-norm residual stack: [self attention -> ffn], each + dropout, x n.
    for layer_idx in range(hparams.num_decoder_layers):
        with tf.variable_scope("decoder_layer_%d" % layer_idx):
            attn_in = mtf.layers.layer_norm(x, model_dim, name="layer_norm_att")
            attn_out = mtf.layers.local_2d_self_attention_spatial_blocks(
                attn_in,
                kv_dim,
                heads_dim,
                memory_h_dim=num_h_blocks_dim,
                memory_w_dim=num_w_blocks_dim,
                name="self_att")
            x += layer_prepostprocess_dropout(attn_out, hparams)

            ffn_in = mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn")
            ffn_out = mtf.layers.dense_relu_dense(
                ffn_in,
                feedforward_dim,
                hparams.dropout,
                dropout_broadcast_dims=[length_dim])
            x += layer_prepostprocess_dropout(ffn_out, hparams)

    return mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
def compute_output(self, o, output_shape=None):
    """Compute output of multihead attention.

    Contracts the per-head attention output with the output projection
    ``self.wo``.

    Args:
        o: a Tensor with dimensions query_heads_dims + {value_dim} + other_dims
        output_shape: an optional Shape

    Returns:
        a Tensor with shape: {output_dim} + other_dims
    """
    if not self.combine_dims:
        # Contract directly over the separate o_dims.
        return mtf.einsum(
            [o, self.wo], output_shape=output_shape, reduced_dims=self.o_dims)

    # With combined dims, o_dims are first moved to the end of o, then
    # replaced by wo's second-to-last dimension so that one merged dimension
    # is contracted.
    merged_dim = self.wo.shape.dims[-2]
    o = mtf.transpose(o, o.shape - self.o_dims + self.o_dims)
    o = mtf.replace_dimensions(o, self.o_dims, merged_dim)
    return mtf.einsum(
        [o, self.wo], output_shape=output_shape, reduced_dims=[merged_dim])
def mnist_model(image, labels, mesh):
    """The model.

    Args:
        image: tf.Tensor with shape [batch, 28*28]
        labels: a tf.Tensor with shape [batch] and dtype tf.int32
        mesh: a mtf.Mesh

    Returns:
        logits: a mtf.Tensor with shape [batch, 10]
        loss: a mtf.Tensor with shape [] (None when labels is None)
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    # View each 28x28 image as a 4x4 grid of 7x7 single-channel blocks.
    blocked_image = tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1])
    import_shape = mtf.Shape([
        batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
        one_channel_dim
    ])
    x = mtf.import_tf_tensor(mesh, blocked_image, import_shape)
    # Put both block indices ahead of the in-block pixel offsets.
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # Two 9x9 "SAME" convolutions across the spatially partitioned blocks.
    fh_dim = mtf.Dimension("fh", 9)
    fw_dim = mtf.Dimension("fw", 9)
    filters1_dim = mtf.Dimension("filters1", 16)
    filters2_dim = mtf.Dimension("filters2", 16)
    kernel1 = mtf.get_variable(
        mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(
        mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim])

    def blocked_conv_relu(inputs, kernel):
        # Stride-1 convolution with halo exchange between neighbouring blocks.
        return mtf.relu(mtf.conv2d_with_blocks(
            inputs, kernel, strides=[1, 1, 1, 1], padding="SAME",
            h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))

    f1 = blocked_conv_relu(x, kernel1)
    f2 = blocked_conv_relu(f1, kernel2)
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # Dense head: every non-batch dimension feeds the hidden layer.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    h1 = mtf.layers.dense(
        x, hidden_dim1, reduced_dims=x.shape.dims[-4:],
        activation=mtf.relu, name="hidden1")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")

    if labels is None:
        return logits, None

    mtf_labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim)
    return logits, mtf.reduce_mean(per_example_loss)
def self_attention(self, x, attention_bias):
    """Performs multi-headed self-attention with output projection.

    Args:
        x: output of previous layer
        attention_bias: optional float32 Tensor broadcastable to shape
            x.shape - self.model_dim + self.memory_seq_dim to be added to
            attention logits. This may used to mask out padding regions of
            the memory.

    Returns:
        float Tensor with the same shape as x
    """
    def project_to_heads(inputs, name):
        # Linear map from model_dim to [num_heads, size_per_head].
        return mtf.layers.dense(
            inputs,
            reduced_dims=[self.model_dim],
            new_dims=[self.num_heads_dim, self.size_per_head_dim],
            kernel_initializer=self.dense_initializer,
            name=name,
            use_bias=self.config.use_bias)

    queries = project_to_heads(x, "query")
    # Keys and values are indexed by a separate memory-sequence dimension so
    # that the query/key einsum produces a [seq, memory_seq] score matrix.
    memory = mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim)
    keys = project_to_heads(memory, "key")
    values = project_to_heads(memory, "value")

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    attention_scores = mtf.einsum(
        [queries, keys], reduced_dims=[self.size_per_head_dim])
    attention_scores *= self.size_per_head_dim.size ** -0.5

    if attention_bias is not None:
        attention_scores += attention_bias

    # Normalize the attention scores to probabilities.
    attention_probs = mtf.softmax(attention_scores, self.memory_seq_dim)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    # Dropout is active only when a non-zero rate is configured. The former
    # condition (rate == 0.0) was inverted and disabled dropout exactly when
    # it was requested. NOTE(review): this presumes the caller zeroes the
    # rate outside training - confirm.
    dropout_prob = self.config.attention_probs_dropout_prob
    attention_probs = mtf.dropout(
        attention_probs,
        is_training=(dropout_prob > 0.0),
        keep_prob=1.0 - dropout_prob)

    output = mtf.einsum([attention_probs, values],
                        reduced_dims=[self.memory_seq_dim])

    # linear transformation back to shape of query_antecedent
    output = mtf.layers.dense(
        output,
        reduced_dims=[self.num_heads_dim, self.size_per_head_dim],
        new_dims=[self.model_dim],
        kernel_initializer=self.dense_initializer,
        name="output",
        use_bias=self.config.use_bias)
    output = mtf.transpose(output, x.shape)
    return output
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
    """Builds the UNet model graph, train op and eval metrics.

    Args:
        mesh: a MeshTensorflow.mesh object.
        mesh_impl: a mesh implementation, such as SimdMeshImpl and
            PlacementMeshImpl.
        dataset_str: a string of either train or eval. This is used for
            batch_norm.
        images: a laid out Tensor with shape [batch, x, y, num_channels]
            or [batch, x, y, z, num_channels].
        labels: a laid out Tensor with shape [batch, x, y, num_classes]
            or [batch, x, y, z, num_classes].

    Returns:
        A tuple (preds, loss, eval_metrics, all_bn_update_ops):
        preds: list of [liver prob, lesion prob, (lesion ground truth if
            FLAGS.output_ground_truth), per-case intersect, per-case area sum].
            The probability/ground-truth maps are downsampled by
            FLAGS.pred_downsample.
        loss: scalar mix of dice loss and weighted cross-entropy.
        eval_metrics: dict of scalar summary tensors.
        all_bn_update_ops: batch-norm update ops collected from every conv.
    """
    is_training = (dataset_str == 'train')
    if dataset_str == 'train':
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
    else:
        assert dataset_str == 'eval'
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
    image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
    image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
    image_sx_dim = mtf.Dimension('image_sx_block',
                                 FLAGS.ct_resolution // FLAGS.image_nx_block)
    image_sy_dim = mtf.Dimension('image_sy_block',
                                 FLAGS.ct_resolution // FLAGS.image_ny_block)
    image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
    image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
    label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
    mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

    mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
    variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

    # Import input features.
    x = mtf.import_laid_out_tensor(mesh,
                                   mesh_impl.LaidOutTensor(images),
                                   mtf_images_shape)
    x = mtf.cast(x, mtf_dtype)

    # Import ground truth labels.
    t = mtf.import_laid_out_tensor(mesh,
                                   mesh_impl.LaidOutTensor(labels),
                                   mtf_labels_shape)
    t = mtf.cast(t, mtf_dtype)

    # Transpose the blocks.
    if FLAGS.sampled_2d_slices:
        x = mtf.transpose(x, [batch_dim,
                              image_nx_dim, image_ny_dim,
                              image_sx_dim, image_sy_dim,
                              image_c_dim])

        t = mtf.transpose(t, [batch_dim,
                              image_nx_dim, image_ny_dim,
                              image_sx_dim, image_sy_dim,
                              label_c_dim])
    else:
        x = mtf.transpose(x, [batch_dim,
                              image_nx_dim, image_ny_dim,
                              image_sx_dim, image_sy_dim,
                              image_sz_dim, image_c_dim])

        t = mtf.transpose(t, [batch_dim,
                              image_nx_dim, image_ny_dim,
                              image_sx_dim, image_sy_dim,
                              image_sz_dim, label_c_dim])

    # Network.
    levels = []
    all_bn_update_ops = []
    # add levels with convolution or down-sampling
    for depth in range(FLAGS.network_depth):
        for n_conv in range(FLAGS.n_conv_per_block):
            if depth == 0 and n_conv == 0:
                # no dropout in 1st layer.
                dropout_keep_p = 1.0
            else:
                dropout_keep_p = FLAGS.dropout_keep_p
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices,
                image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth),
                dropout_keep_p,
                FLAGS.with_batch_norm,
                is_training,
                'conv_{}_{}'.format(depth, n_conv),
                variable_dtype,
                'conv_down_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)
        # One skip-connection tensor per depth, consumed by the decoder.
        levels.append(x)

        if depth < FLAGS.network_depth - 1:
            if FLAGS.sampled_2d_slices:
                x = mtf.layers.max_pool2d(x, ksize=(2, 2))
            else:
                x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

    # add levels with up-convolution or up-sampling
    for depth in range(FLAGS.network_depth - 1)[::-1]:
        x = deconv_with_spatial_partition(
            x, FLAGS.sampled_2d_slices,
            image_nx_dim, image_ny_dim,
            FLAGS.n_base_filters * (2**depth),
            FLAGS.dropout_keep_p,
            'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
            variable_dtype,
            'deconv_{}_0'.format(depth))
        x = mtf.concat(
            [x, levels[depth]],
            concat_dim_name='conv_{}_{}'.format(depth,
                                                FLAGS.n_conv_per_block - 1))

        for n_conv in range(FLAGS.n_conv_per_block):
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices,
                image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth),
                FLAGS.dropout_keep_p,
                FLAGS.with_batch_norm,
                is_training,
                'conv_{}_{}'.format(depth, n_conv),
                variable_dtype,
                'conv_up_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)

    # no dropout in the final layer.
    if FLAGS.sampled_2d_slices:
        y = mtf.layers.conv2d_with_blocks(
            x, mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1), strides=(1, 1), padding='SAME',
            h_blocks_dim=image_nx_dim, w_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )
    else:
        y = mtf.layers.conv3d_with_blocks(
            x, mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1, 1), strides=(1, 1, 1), padding='SAME',
            d_blocks_dim=image_nx_dim, h_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )

    # use mtf.constant to make sure there is no CPU-side constants.
    def scalar(v, dtype):
        return mtf.constant(mesh, v, shape=[], dtype=dtype)

    argmax_t = mtf.argmax(t, label_c_dim)
    liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
    lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)

    argmax_y = mtf.argmax(y, label_c_dim)
    lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

    # summary of class ratios.
    lesion_pred_ratio = mtf.reduce_mean(lesion_y)
    lesion_label_ratio = mtf.reduce_mean(lesion_t)

    # summary of accuracy.
    accuracy = mtf.reduce_mean(mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

    # Cross-entropy loss. Up-weight the liver region.
    pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
        y, t, label_c_dim)
    # liver_t and lesion_t are derived from the same argmax, so they are
    # mutually exclusive (a lesion pixel has liver_t == 0). Each class
    # therefore adds its own (weight - 1) to the base weight of 1:
    # background -> 1, liver -> xen_liver_weight, lesion -> xen_lesion_weight.
    # The former lesion coefficient (xen_lesion_weight - xen_liver_weight)
    # assumed lesion pixels also received the liver term, which they never
    # do, so lesions were weighted 1 + lesion_w - liver_w.
    pixel_weight = (
        scalar(1, mtf_dtype)
        + liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype)
        + lesion_t * scalar(FLAGS.xen_lesion_weight - 1, mtf_dtype))
    loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

    # Dice loss
    y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
    lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                                 reduced_dim=mtf.Dimension('label_c', 1))
    prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                    output_shape=mtf.Shape([batch_dim]))
    prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                   output_shape=mtf.Shape([batch_dim]))
    loss_dice_per_case = mtf.reduce_mean(
        scalar(-2, mtf_dtype) * prob_intersect / (
            prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
    loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(prob_intersect) / (
        mtf.reduce_sum(prob_area_sum) + scalar(FLAGS.dice_epsilon, mtf_dtype))

    loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(0.5, mtf_dtype)

    loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
        1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

    intersect = mtf.reduce_sum(lesion_y * lesion_t,
                               output_shape=mtf.Shape([batch_dim]))
    area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                              output_shape=mtf.Shape([batch_dim]))
    # summary of dice.
    dice_per_case = mtf.reduce_mean(scalar(2, mtf_dtype) * intersect / (
        area_sum + scalar(0.000001, mtf_dtype)))
    dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
        mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

    eval_metrics = {
        'lesion_pred_ratio': lesion_pred_ratio,
        'lesion_label_ratio': lesion_label_ratio,
        'accuracy_of_all_classes': accuracy,
        'lesion_dice_per_case': dice_per_case,
        'lesion_dice_global': dice_global,
        'loss_xen': loss_xen,
        'loss_dice': loss_dice,
        'loss_dice_per_case': loss_dice_per_case,
        'loss_dice_global': loss_dice_global,
    }

    if FLAGS.sampled_2d_slices:
        y_prob_downsampled = mtf.layers.avg_pool2d(
            y_prob, ksize=(FLAGS.pred_downsample,) * 2)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool2d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample,) * 2)
    else:
        y_prob_downsampled = mtf.layers.avg_pool3d(
            y_prob, ksize=(FLAGS.pred_downsample,) * 3)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool3d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample,) * 3)

    liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
    lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
    preds = [
        mtf.reduce_sum(liver_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1)),
        mtf.reduce_sum(lesion_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1))]

    if FLAGS.output_ground_truth:
        preds.append(mtf.reduce_sum(
            lesion_gt_downsampled, reduced_dim=mtf.Dimension('label_c', 1)))

    preds.extend([intersect, area_sum])

    return preds, loss, eval_metrics, all_bn_update_ops
def mtf_model_fn(self, features, mesh):
    """Builds the blocked ResNet classifier graph.

    Args:
        features: dict with "inputs" (images) and "targets" (class ids).
        mesh: a mtf.Mesh.

    Returns:
        (logits, loss): logits reshaped to [batch, one_channel, classes] and
        the scalar mean softmax cross-entropy loss.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Per-block spatial sizes. Channels are folded into the width axis (the
    # trailing "one_channel" dim has size 1), so each column block holds
    # num_channels * cols_size // col_blocks values. These were previously
    # hard-coded to 32 and 96, which only matched the tf.reshape below for
    # 32x32x3 inputs with a single row/column block.
    rows_per_block = hparams.rows_size // hparams.row_blocks
    cols_per_block = (
        hparams.num_channels * hparams.cols_size // hparams.col_blocks)

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", rows_per_block)
    cols_dim = mtf.Dimension("cols_size", cols_per_block)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    x = mtf.import_tf_tensor(
        mesh,
        tf.reshape(inputs, [
            hparams.batch_size, hparams.row_blocks, rows_per_block,
            hparams.col_blocks, cols_per_block, 1
        ]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    x = mtf.to_float(x)
    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
    x = mtf.conv2d_with_blocks(x,
                               initial_filters,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=None,
                               w_blocks_dim=col_blocks_dim)

    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block layer - strided block layer - strided block layer] x n
    for layer in range(hparams.num_layers):
        layer_name = "block_layer_%d" % layer
        with tf.variable_scope(layer_name):
            # Residual block layer
            x = block_layer(inputs=x,
                            filters=hparams.filter_sizes[0],
                            blocks=hparams.layer_sizes[0],
                            strides=[1, 1, 1, 1],
                            is_training=is_training,
                            name="block_layer1",
                            row_blocks_dim=None,
                            col_blocks_dim=None)
            x = block_layer(inputs=x,
                            filters=hparams.filter_sizes[1],
                            blocks=hparams.layer_sizes[1],
                            strides=[1, 2, 2, 1],
                            is_training=is_training,
                            name="block_layer2",
                            row_blocks_dim=None,
                            col_blocks_dim=None)
            x = block_layer(inputs=x,
                            filters=hparams.filter_sizes[2],
                            blocks=hparams.layer_sizes[2],
                            strides=[1, 2, 2, 1],
                            is_training=is_training,
                            name="block_layer3",
                            row_blocks_dim=None,
                            col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    outputs = mtf.layers.dense(out,
                               hidden_dim,
                               reduced_dims=out.shape.dims[-5:],
                               activation=mtf.relu,
                               name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(mesh,
                                  tf.reshape(labels, [hparams.batch_size]),
                                  mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(logits,
                         mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
def mnist_model(image, labels, mesh):
    """The model.

    Logs the name and shape of each intermediate tensor as it is built.

    Args:
        image: tf.Tensor with shape [batch, 28*28]
        labels: a tf.Tensor with shape [batch] and dtype tf.int32
        mesh: a mtf.Mesh

    Returns:
        logits: a mtf.Tensor with shape [batch, 10]
        loss: a mtf.Tensor with shape [] (None when labels is None)
    """
    def log_tensor(tensor):
        tf.logging.info("[intra variable] (name, shape): ({},{})".format(
            tensor.name, tensor.shape))

    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    # View each 28x28 image as a 4x4 grid of 7x7 single-channel blocks.
    blocked_image = tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1])
    import_shape = mtf.Shape([
        batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
        one_channel_dim
    ])
    x = mtf.import_tf_tensor(mesh, blocked_image, import_shape)
    # Put both block indices ahead of the in-block pixel offsets.
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])
    log_tensor(x)

    # Two 9x9 "SAME" convolutions across the spatially partitioned blocks.
    filters1_dim = mtf.Dimension("filters1", 16)
    filters2_dim = mtf.Dimension("filters2", 16)

    def blocked_conv_relu(inputs, out_dim, layer_name):
        return mtf.relu(mtf.layers.conv2d_with_blocks(
            inputs, out_dim,
            filter_size=[9, 9], strides=[1, 1], padding="SAME",
            h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
            name=layer_name))

    f1 = blocked_conv_relu(x, filters1_dim, "conv0")
    log_tensor(f1)
    f2 = blocked_conv_relu(f1, filters2_dim, "conv1")
    log_tensor(f2)
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)
    log_tensor(x)

    # Dense head: every non-batch dimension feeds the first hidden layer.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    h1 = mtf.layers.dense(
        x, hidden_dim1, reduced_dims=x.shape.dims[-4:],
        activation=mtf.relu, name="hidden1")
    log_tensor(h1)
    h2 = mtf.layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    log_tensor(h2)
    logits = mtf.layers.dense(h2, classes_dim, name="logits")

    if labels is None:
        return logits, None

    mtf_labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim)
    return logits, mtf.reduce_mean(per_example_loss)
def mtf_model_fn(self, features, mesh):
    """Builds the blocked ResNet classifier graph with an explicit channel dim.

    Args:
        features: dict with "inputs" (images) and "targets" (class ids).
        mesh: a mtf.Mesh.

    Returns:
        (logits, loss): logits reshaped to [batch, one_channel, classes] and
        the scalar mean softmax cross-entropy loss.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Per-block spatial sizes. Channels live in their own trailing dimension
    # here, so they must NOT also be folded into the column size. Folding
    # them in as well counted the channels twice, and the reshape could not
    # match the input's element count.
    rows_per_block = hparams.rows_size // hparams.row_blocks
    cols_per_block = hparams.cols_size // hparams.col_blocks

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", rows_per_block)
    cols_dim = mtf.Dimension("cols_size", cols_per_block)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", hparams.num_channels)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    # NOTE(review): assumes inputs are laid out [batch, rows, cols, channels]
    # - confirm against the input pipeline.
    x = mtf.import_tf_tensor(
        mesh,
        tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks, rows_per_block,
            hparams.col_blocks, cols_per_block,
            hparams.num_channels]),
        mtf.Shape(
            [batch_dim, row_blocks_dim, rows_dim,
             col_blocks_dim, cols_dim, channels_dim]))
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])

    x = mtf.to_float(x)
    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x,
        initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block - strided block layer - strided block layer] x n
    for layer in range(hparams.num_layers):
        layer_name = "block_layer_%d" % layer
        with tf.variable_scope(layer_name):
            # Residual block layer
            x = block_layer(
                inputs=x,
                filters=hparams.filter_sizes[0],
                blocks=hparams.layer_sizes[0],
                strides=[1, 1, 1, 1],
                is_training=is_training,
                name="block_layer1",
                row_blocks_dim=None,
                col_blocks_dim=None)
            x = block_layer(
                inputs=x,
                filters=hparams.filter_sizes[1],
                blocks=hparams.layer_sizes[1],
                strides=[1, 1, 1, 1],
                is_training=is_training,
                name="block_layer2",
                row_blocks_dim=None,
                col_blocks_dim=None)
            x = block_layer(
                inputs=x,
                filters=hparams.filter_sizes[2],
                blocks=hparams.layer_sizes[2],
                strides=[1, 1, 1, 1],
                is_training=is_training,
                name="block_layer3",
                row_blocks_dim=None,
                col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
def cifar_model(image, labels, mesh):
    """VGG-style classifier built from blocked 2-D convolutions.

    Args:
        image: tf.Tensor holding FLAGS.batch_size images of 32*32*3 values
            (it is reshaped here to [batch, 1, 32, 1, 32, 3]).
        labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None.
        mesh: a mtf.Mesh

    Returns:
        logits: a mtf.Tensor with shape [batch, 10]
        loss: a mtf.Tensor with shape [], or None when labels is None
        all_filters: a one-element list holding the output-channel counts of
            the eight conv layers, in order.
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 1)
    col_blocks_dim = mtf.Dimension("col_blocks", 1)
    rows_dim = mtf.Dimension("rows_size", 32)
    cols_dim = mtf.Dimension("cols_size", 32)
    init = 60
    classes_dim = mtf.Dimension("classes", 10)
    # Despite its name this dimension holds 3 channels. The dimension name
    # string is kept unchanged because it is part of the graph.
    one_channel_dim = mtf.Dimension("one_channel", 3)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 1, 32, 1, 32, 3]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # Four stages, each with two 3x3 "SAME" convolutions followed by a 2x2
    # max-pool. The width doubles at every stage: init, 2*init, 4*init,
    # 8*init. Layers are named conv0..conv7, output dims filters1..filters8,
    # and pools maxpool0..maxpool3, created in that order.
    num_stages = 4
    convs_per_stage = 2
    conv_widths = []
    for stage in range(num_stages):
        width = init * 2 ** stage
        for offset in range(convs_per_stage):
            conv_idx = stage * convs_per_stage + offset
            filters_dim = mtf.Dimension("filters%d" % (conv_idx + 1), width)
            x = mtf.relu(
                mtf.layers.conv2d_with_blocks(x,
                                              filters_dim,
                                              filter_size=[3, 3],
                                              strides=[1, 1],
                                              padding="SAME",
                                              h_blocks_dim=row_blocks_dim,
                                              w_blocks_dim=col_blocks_dim,
                                              name="conv%d" % conv_idx))
            conv_widths.append(width)
        x = mtf.layers.max_pool2d(x, ksize=(2, 2), name="maxpool%d" % stage)

    # Dense head: every non-batch dimension feeds the hidden layer.
    hidden_dim1 = mtf.Dimension("hidden1", 256)
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-5:],
                          activation=mtf.relu,
                          name="hidden1")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    all_filters = [conv_widths]
    return logits, loss, all_filters
def cifar_model(features, labels, mesh):
    """Build a blocked-convolution CIFAR classifier on a Mesh-TensorFlow mesh.

    Args:
      features: dict-like mapping holding the input images under "image"
        (reshaped here to [batch, 4, 8, 4, 8, 3], i.e. 32x32 with 3 channels)
        and, optionally, int class ids under "label".
      labels: fallback tf.Tensor of class ids with shape [batch], used only
        when `features` has no "label" entry; may be None.
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when no labels are available
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 8)
    cols_dim = mtf.Dimension("cols_size", 8)
    classes_dim = mtf.Dimension("classes", 10)
    # The dimension string is "one_channel" but it holds 3 channels; the string
    # is kept unchanged so variable shapes and any layout rules still match.
    channels_dim = mtf.Dimension("one_channel", 3)

    image = bnorm(features["image"])
    # Prefer the label carried in `features`; fall back to the `labels`
    # argument so inputs without a "label" key yield loss=None instead of
    # raising KeyError.
    labels = features.get("label", labels)

    # Split each 32x32 image into a 4x4 grid of 8x8 blocks, then move the
    # block dimensions ahead of the in-block spatial dimensions.
    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 4, 8, 4, 8, 3]),
        mtf.Shape([batch_dim, row_blocks_dim, rows_dim,
                   col_blocks_dim, cols_dim, channels_dim]))
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])

    # 18 stacked 7x7 SAME convolutions, each followed by ReLU. Entry i (1-based)
    # is the output channel count of variable "kernel<i>" / dim "filters<i>".
    conv_filters = (32, 32, 64, 64) + (128,) * 6 + (256,) * 8
    fh_dim = mtf.Dimension("fh", 7)
    fw_dim = mtf.Dimension("fw", 7)
    in_dim = channels_dim
    for i, num_filters in enumerate(conv_filters, start=1):
        out_dim = mtf.Dimension("filters%d" % i, num_filters)
        kernel = mtf.get_variable(
            mesh, "kernel%d" % i, [fh_dim, fw_dim, in_dim, out_dim])
        x = mtf.relu(mtf.conv2d_with_blocks(
            x, kernel, strides=[1, 1, 1, 1], padding="SAME",
            h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))
        in_dim = out_dim

    # Average over the channel dimension of the last conv layer, leaving
    # [batch, row_blocks, col_blocks, rows_size, cols_size].
    x = mtf.reduce_mean(x, reduced_dim=in_dim)

    # Fully-connected head: hidden1 contracts the four non-batch dims, then
    # hidden2..hidden8 are FLAGS.hidden_size-wide ReLU layers.
    h = mtf.layers.dense(
        x, mtf.Dimension("hidden1", FLAGS.hidden_size),
        reduced_dims=x.shape.dims[-4:],
        activation=mtf.relu, name="hidden1")
    for i in range(2, 9):
        h = mtf.layers.dense(
            h, mtf.Dimension("hidden%d" % i, FLAGS.hidden_size),
            activation=mtf.relu, name="hidden%d" % i)
    logits = mtf.layers.dense(h, classes_dim, name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [FLAGS.batch_size]),
            mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
def mnist_model(image, labels, mesh):
    """Build a small blocked-convolution MNIST classifier on a mesh.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when `labels` is None
      all_filters: a one-element list holding the output channel counts of
        the three conv layers, i.e. [[init, init * 2, init * 4]]
    """
    # Base conv width; the three conv layers use 1x, 2x and 4x this value.
    init = 60
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 1)
    col_blocks_dim = mtf.Dimension("col_blocks", 1)
    rows_dim = mtf.Dimension("rows_size", 28)
    cols_dim = mtf.Dimension("cols_size", 28)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    # Treat each 28x28 image as a single (1x1 grid) block so the blocked
    # convolution API can be used unchanged.
    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 1, 28, 1, 28, 1]),
        mtf.Shape([batch_dim, row_blocks_dim, rows_dim,
                   col_blocks_dim, cols_dim, one_channel_dim]))
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, one_channel_dim])

    # (output channels, square kernel size) per conv layer. Deriving the
    # widths from `init` keeps them in sync with the returned all_filters.
    conv_specs = ((init, 7), (init * 2, 5), (init * 4, 3))
    for i, (num_filters, size) in enumerate(conv_specs):
        x = mtf.relu(
            mtf.layers.conv2d_with_blocks(
                x, mtf.Dimension("filters%d" % (i + 1), num_filters),
                filter_size=[size, size], strides=[1, 1], padding="SAME",
                h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
                name="conv%d" % i))
    x = mtf.layers.avg_pool2d(x, ksize=(2, 2), name="averagePool")

    # Contract the last five dims (everything but batch, assuming the pooled
    # tensor keeps the 6-D blocked layout -- TODO confirm) into hidden1.
    h1 = mtf.layers.dense(
        x, mtf.Dimension("hidden1", 128), reduced_dims=x.shape.dims[-5:],
        activation=mtf.relu, name="hidden1")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [FLAGS.batch_size]),
            mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    all_filters = [[num_filters for num_filters, _ in conv_specs]]
    return logits, loss, all_filters