def projection_shortcut(inputs, kernel):
  """Project the identity branch with a 1x1 convolution.

  Uses `strides`, `col_blocks_dim` and `is_training` from the enclosing
  scope.  Batch norm is applied without a ReLU, since the activation
  comes after the residual addition.
  """
  return batch_norm_relu(
      mtf.conv2d_with_blocks(inputs, kernel, strides=strides, padding="SAME",
                             h_blocks_dim=None, w_blocks_dim=col_blocks_dim),
      is_training, relu=False)
def projection_shortcut(inputs, kernel):
  """Project identity branch to the block's output channels.

  Closure over `strides`, `col_blocks_dim` and `is_training` from the
  enclosing scope.  Batch norm is applied without a ReLU; the activation
  is applied after the residual add.
  """
  projected = mtf.conv2d_with_blocks(
      inputs,
      kernel,
      strides=strides,
      padding="SAME",
      h_blocks_dim=None,
      w_blocks_dim=col_blocks_dim)
  return batch_norm_relu(projected, is_training, relu=False)
def mnist_model(image, labels, mesh):
  """Builds the MNIST model graph.

  The 28x28 image is split into a 4x4 grid of 7x7 spatial blocks so the
  spatial dimensions can be distributed across the mesh.

  Args:
    image: a tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None if labels is None
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  row_blocks_dim = mtf.Dimension("row_blocks", 4)
  col_blocks_dim = mtf.Dimension("col_blocks", 4)
  rows_dim = mtf.Dimension("rows_size", 7)
  cols_dim = mtf.Dimension("cols_size", 7)
  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 1)

  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
      mtf.Shape([
          batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
          one_channel_dim
      ]))
  x = mtf.transpose(x, [
      batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
      one_channel_dim
  ])

  # add some convolutional layers to demonstrate that convolution works.
  fh_dim = mtf.Dimension("fh", 9)
  fw_dim = mtf.Dimension("fw", 9)
  filters1_dim = mtf.Dimension("filters1", 16)
  filters2_dim = mtf.Dimension("filters2", 16)
  kernel1 = mtf.get_variable(mesh, "kernel1",
                             [fh_dim, fw_dim, one_channel_dim, filters1_dim])
  kernel2 = mtf.get_variable(mesh, "kernel2",
                             [fh_dim, fw_dim, filters1_dim, filters2_dim])
  f1 = mtf.relu(
      mtf.conv2d_with_blocks(x, kernel1, strides=[1, 1, 1, 1], padding="SAME",
                             h_blocks_dim=row_blocks_dim,
                             w_blocks_dim=col_blocks_dim))
  f2 = mtf.relu(
      mtf.conv2d_with_blocks(f1, kernel2, strides=[1, 1, 1, 1],
                             padding="SAME",
                             h_blocks_dim=row_blocks_dim,
                             w_blocks_dim=col_blocks_dim))
  # Average out the channel dimension; the remaining block + spatial dims
  # are reduced by the first dense layer below.
  x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

  # add some fully-connected dense layers.
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  h1 = mtf.layers.dense(x, hidden_dim1, reduced_dims=x.shape.dims[-4:],
                        activation=mtf.relu, name="hidden1")
  logits = mtf.layers.dense(h1, classes_dim, name="logits")
  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
  """Bottleneck block variant for residual networks with BN after convolutions.

  Layout: 1x1 conv -> BN+ReLU -> 3x3 conv -> BN+ReLU -> 1x1 conv -> BN,
  followed by a ReLU over (block output + shortcut).

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters.  NOTE(review): an earlier docstring
        said the third convolution uses 4 times as many filters, but this
        implementation uses `filters` for all three convolutions — confirm
        which is intended.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will
        ultimately downsample the input (the stride is applied in the third
        convolution).
    projection_shortcut: `function` to use for projection shortcuts
        (typically a 1x1 convolution to match the filter dimensions). If
        None, no projection is used and the input is passed as unchanged
        through the shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is spatially
        partitioned along mesh axis.
    col_blocks_dim: a mtf.Dimension, column dimension which is spatially
        partitioned along mesh axis.

  Returns:
    The output `Tensor` of the block.
  """
  shortcut = inputs

  filter_h_dim = mtf.Dimension("filter_height", 3)
  filter_w_dim = mtf.Dimension("filter_width", 3)
  # NOTE(review): these reuse the names "filter_height"/"filter_width" with
  # size 1 while the dims above use size 3 — mtf identifies dimensions by
  # name, so confirm this does not conflict.
  one_h_dim = mtf.Dimension("filter_height", 1)
  one_w_dim = mtf.Dimension("filter_width", 1)

  if projection_shortcut is not None:
    # Project the identity branch to match the block's output channels.
    filters_dim = mtf.Dimension("filtersp", filters)
    kernel = mtf.get_variable(
        inputs.mesh, "kernel", mtf.Shape(
            [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
    shortcut = projection_shortcut(inputs, kernel)

  # First conv block: 1x1 convolution.
  filters1_dim = mtf.Dimension("filters1", filters)
  kernel1 = mtf.get_variable(
      inputs.mesh, "kernel1",
      mtf.Shape([one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
  inputs = mtf.conv2d_with_blocks(inputs,
                                  kernel1,
                                  strides=[1, 1, 1, 1],
                                  padding="SAME",
                                  h_blocks_dim=None,
                                  w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Add Dropout?
  inputs = batch_norm_relu(inputs, is_training)

  # Second conv block: 3x3 convolution.
  filters2_dim = mtf.Dimension("filters2", filters)
  kernel2 = mtf.get_variable(
      inputs.mesh, "kernel2",
      mtf.Shape([filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
  inputs = mtf.conv2d_with_blocks(inputs,
                                  kernel2,
                                  strides=[1, 1, 1, 1],
                                  padding="SAME",
                                  h_blocks_dim=row_blocks_dim,
                                  w_blocks_dim=col_blocks_dim)
  inputs = batch_norm_relu(inputs, is_training)

  # Third wide conv filter block: 1x1 convolution carrying the block stride.
  filters3_dim = mtf.Dimension("filters3", filters)
  filters3_kernel = mtf.get_variable(
      inputs.mesh, "wide_kernel",
      mtf.Shape([one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
  inputs = mtf.conv2d_with_blocks(inputs,
                                  filters3_kernel,
                                  strides,
                                  padding="SAME",
                                  h_blocks_dim=None,
                                  w_blocks_dim=col_blocks_dim)
  # Final BN without ReLU; the activation is applied after the residual add.
  inputs = batch_norm_relu(inputs, is_training, relu=False)

  # TODO(nikip): Maybe add residual with a projection?
  # Rename the shortcut's channel dimension so it can be added to `inputs`.
  return mtf.relu(inputs + mtf.rename_dimension(
      shortcut, shortcut.shape.dims[-1].name, inputs.shape.dims[-1].name))
def mtf_model_fn(self, features, mesh):
  """Builds the model graph: stem conv + residual block layers + dense head.

  Args:
    features: dict of tf.Tensors; reads "inputs" (image batch) and
      "targets" (class labels).
    mesh: a mtf.Mesh to build the graph on.

  Returns:
    logits: a mtf.Tensor of shape [batch, one_channel, classes]
      (reshaped so it doesn't break inside t2t).
    loss: a scalar mtf.Tensor.
  """
  features = copy.copy(features)
  tf.logging.info("features = %s" % features)
  hparams = self._hparams
  activation_dtype = self.set_activation_type()
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  # Declare all the dimensions
  batch_dim = mtf.Dimension("batch", hparams.batch_size)
  hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
  filter_h_dim = mtf.Dimension("filter_height", 7)
  filter_w_dim = mtf.Dimension("filter_width", 7)
  filters = mtf.Dimension("filters", hparams.filter_sizes[0])
  # NOTE(review): rows/cols sizes are hard-coded (32, 96) while the reshape
  # below derives the per-block sizes from hparams — confirm they agree
  # (e.g. rows_size=128 with row_blocks=4, 3-channel 32-wide images).
  rows_dim = mtf.Dimension("rows_size", 32)
  cols_dim = mtf.Dimension("cols_size", 96)
  row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
  col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 1)

  inputs = features["inputs"]
  # Split the image into a (row_blocks x col_blocks) grid of spatial blocks
  # so the spatial dims can be distributed across the mesh; channels are
  # folded into the column dimension.
  x = mtf.import_tf_tensor(
      mesh,
      tf.reshape(inputs, [
          hparams.batch_size,
          hparams.row_blocks,
          hparams.rows_size // hparams.row_blocks,
          hparams.col_blocks,
          hparams.num_channels * hparams.cols_size // hparams.col_blocks,
          1
      ]),
      mtf.Shape([
          batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
          one_channel_dim
      ]))
  x = mtf.transpose(x, [
      batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
      one_channel_dim
  ])
  x = mtf.to_float(x)

  # Stem: 7x7 convolution followed by batch norm + ReLU.
  initial_filters = mtf.get_variable(
      mesh, "init_filters",
      mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
  x = mtf.conv2d_with_blocks(x,
                             initial_filters,
                             strides=[1, 1, 1, 1],
                             padding="SAME",
                             h_blocks_dim=None,
                             w_blocks_dim=col_blocks_dim)
  x = batch_norm_relu(x, is_training)

  # Conv blocks
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_layers):
    layer_name = "block_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Residual block layer
      x = block_layer(inputs=x,
                      filters=hparams.filter_sizes[0],
                      blocks=hparams.layer_sizes[0],
                      strides=[1, 1, 1, 1],
                      is_training=is_training,
                      name="block_layer1",
                      row_blocks_dim=None,
                      col_blocks_dim=None)
      # The second and third block layers downsample via stride 2.
      x = block_layer(inputs=x,
                      filters=hparams.filter_sizes[1],
                      blocks=hparams.layer_sizes[1],
                      strides=[1, 2, 2, 1],
                      is_training=is_training,
                      name="block_layer2",
                      row_blocks_dim=None,
                      col_blocks_dim=None)
      x = block_layer(inputs=x,
                      filters=hparams.filter_sizes[2],
                      blocks=hparams.layer_sizes[2],
                      strides=[1, 2, 2, 1],
                      is_training=is_training,
                      name="block_layer3",
                      row_blocks_dim=None,
                      col_blocks_dim=None)

  # Calculate the logits and loss.
  out = x
  # The last 5 dims (blocks + spatial + channels) are reduced into hidden_dim.
  outputs = mtf.layers.dense(out,
                             hidden_dim,
                             reduced_dims=out.shape.dims[-5:],
                             activation=mtf.relu,
                             name="dense")

  # We assume fixed vocab size for targets
  labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
  labels = mtf.import_tf_tensor(mesh,
                                tf.reshape(labels, [hparams.batch_size]),
                                mtf.Shape([batch_dim]))

  logits = mtf.layers.dense(outputs, classes_dim, name="logits")
  soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, classes_dim)

  # Reshape logits so it doesn't break inside t2t.
  logits = mtf.reshape(
      logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
  loss = mtf.reduce_mean(loss)
  return logits, loss
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
  """Bottleneck block variant for residual networks with BN after convolutions.

  Layout: 1x1 conv (filters) -> BN+ReLU -> 3x3 conv (4*filters) -> BN+ReLU
  -> 1x1 conv (filters), followed by a ReLU over (shortcut + block output).

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first and third convolutions.
        NOTE(review): an earlier docstring said the *final* convolution uses
        4x filters; in this implementation it is the middle 3x3 convolution
        that uses `4*filters` — confirm which is intended.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will
        ultimately downsample the input (the stride is applied in the third
        convolution).
    projection_shortcut: `function` to use for projection shortcuts
        (typically a 1x1 convolution to match the filter dimensions). If
        None, no projection is used and the input is passed as unchanged
        through the shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is spatially
        partitioned along mesh axis.
    col_blocks_dim: a mtf.Dimension, column dimension which is spatially
        partitioned along mesh axis.

  Returns:
    The output `Tensor` of the block.
  """
  shortcut = inputs

  filter_h_dim = mtf.Dimension("filter_height", 3)
  filter_w_dim = mtf.Dimension("filter_width", 3)
  # NOTE(review): these reuse the names "filter_height"/"filter_width" with
  # size 1 while the dims above use size 3 — mtf identifies dimensions by
  # name, so confirm this does not conflict.
  one_h_dim = mtf.Dimension("filter_height", 1)
  one_w_dim = mtf.Dimension("filter_width", 1)

  if projection_shortcut is not None:
    # Project the identity branch to match the block's output channels.
    filters_dim = mtf.Dimension("filtersp", filters)
    kernel = mtf.get_variable(
        inputs.mesh, "kernel", mtf.Shape(
            [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
    shortcut = projection_shortcut(inputs, kernel)

  # First conv block: 1x1 convolution.
  filters1_dim = mtf.Dimension("filters1", filters)
  kernel1 = mtf.get_variable(
      inputs.mesh, "kernel1", mtf.Shape(
          [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      kernel1,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=None,
      w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Add Dropout?
  inputs = batch_norm_relu(inputs, is_training)

  # Second conv block: 3x3 convolution with 4x the filters.
  filters2_dim = mtf.Dimension("filters2", 4*filters)
  kernel2 = mtf.get_variable(
      inputs.mesh, "kernel2", mtf.Shape(
          [filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      kernel2,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=row_blocks_dim,
      w_blocks_dim=col_blocks_dim)
  inputs = batch_norm_relu(inputs, is_training)

  # Third wide conv filter block: 1x1 convolution carrying the block stride.
  filters3_dim = mtf.Dimension("filters3", filters)
  filters3_kernel = mtf.get_variable(
      inputs.mesh, "wide_kernel", mtf.Shape(
          [one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      filters3_kernel,
      strides,
      padding="SAME",
      h_blocks_dim=None,
      w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Althought the original resnet code has this batch norm, in our
  # setup this is causing no gradients to be passed. Investigate further.
  # inputs = batch_norm_relu(inputs, is_training, relu=True)

  # TODO(nikip): Maybe add residual with a projection?
  # Rename inputs' channel dimension so it can be added to the shortcut.
  return mtf.relu(
      shortcut + mtf.rename_dimension(
          inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
def mtf_model_fn(self, features, mesh):
  """Builds the model graph: stem conv + residual block layers + dense head.

  Args:
    features: dict of tf.Tensors; reads "inputs" (3-channel image batch) and
      "targets" (class labels).
    mesh: a mtf.Mesh to build the graph on.

  Returns:
    logits: a mtf.Tensor of shape [batch, one_channel, classes]
      (reshaped so it doesn't break inside t2t).
    loss: a scalar mtf.Tensor.
  """
  features = copy.copy(features)
  tf.logging.info("features = %s" % features)
  hparams = self._hparams
  activation_dtype = self.set_activation_type()
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  # Declare all the dimensions
  batch_dim = mtf.Dimension("batch", hparams.batch_size)
  hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
  filter_h_dim = mtf.Dimension("filter_height", 7)
  filter_w_dim = mtf.Dimension("filter_width", 7)
  filters = mtf.Dimension("filters", hparams.filter_sizes[0])
  # NOTE(review): rows_dim/cols_dim come straight from hparams here, but the
  # reshape below uses rows_size // row_blocks (and folds channels into the
  # column dim) — confirm the hparams are set so these sizes agree.
  rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
  cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
  row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
  col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
  classes_dim = mtf.Dimension("classes", 10)
  channels_dim = mtf.Dimension("channels", 3)
  one_channel_dim = mtf.Dimension("one_channel", 1)

  inputs = features["inputs"]
  # Split the image into a (row_blocks x col_blocks) grid of spatial blocks
  # so the spatial dims can be distributed across the mesh.
  x = mtf.import_tf_tensor(
      mesh,
      tf.reshape(inputs, [
          hparams.batch_size,
          hparams.row_blocks,
          hparams.rows_size // hparams.row_blocks,
          hparams.col_blocks,
          hparams.num_channels*hparams.cols_size // hparams.col_blocks,
          hparams.num_channels]),
      mtf.Shape(
          [batch_dim, row_blocks_dim, rows_dim,
           col_blocks_dim, cols_dim, channels_dim]))
  x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                        rows_dim, cols_dim, channels_dim])
  x = mtf.to_float(x)

  # Stem: 7x7 convolution followed by batch norm + ReLU.
  initial_filters = mtf.get_variable(
      mesh, "init_filters",
      mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
  x = mtf.conv2d_with_blocks(
      x,
      initial_filters,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=None,
      w_blocks_dim=col_blocks_dim)
  x = batch_norm_relu(x, is_training)

  # Conv blocks
  # [block - strided block layer - strided block layer] x n
  for layer in range(hparams.num_layers):
    layer_name = "block_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Residual block layer
      x = block_layer(
          inputs=x,
          filters=hparams.filter_sizes[0],
          blocks=hparams.layer_sizes[0],
          strides=[1, 1, 1, 1],
          is_training=is_training,
          name="block_layer1",
          row_blocks_dim=None,
          col_blocks_dim=None)
      x = block_layer(
          inputs=x,
          filters=hparams.filter_sizes[1],
          blocks=hparams.layer_sizes[1],
          strides=[1, 1, 1, 1],
          is_training=is_training,
          name="block_layer2",
          row_blocks_dim=None,
          col_blocks_dim=None)
      x = block_layer(
          inputs=x,
          filters=hparams.filter_sizes[2],
          blocks=hparams.layer_sizes[2],
          strides=[1, 1, 1, 1],
          is_training=is_training,
          name="block_layer3",
          row_blocks_dim=None,
          col_blocks_dim=None)

  # Calculate the logits and loss.
  out = x
  # The last 5 dims (blocks + spatial + channels) are reduced into hidden_dim.
  outputs = mtf.layers.dense(
      out, hidden_dim,
      reduced_dims=out.shape.dims[-5:],
      activation=mtf.relu,
      name="dense")

  # We assume fixed vocab size for targets
  labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
  labels = mtf.import_tf_tensor(
      mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

  logits = mtf.layers.dense(outputs, classes_dim, name="logits")
  soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, classes_dim)

  # Reshape logits so it doesn't break inside t2t.
  logits = mtf.reshape(
      logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
  loss = mtf.reduce_mean(loss)
  return logits, loss
def cifar_model(features, labels, mesh):
  """Builds the CIFAR model graph.

  The 32x32x3 image is split into a 4x4 grid of 8x8 spatial blocks so the
  spatial dimensions can be distributed across the mesh, then run through a
  stack of 18 convolutions and an 8-layer dense head.

  Args:
    features: a dict with keys 'image' (a tf.Tensor holding the image batch)
      and 'label' (a tf.Tensor with shape [batch] and dtype tf.int32).
    labels: unused -- the labels are read from features['label'] instead.
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
  features = copy.copy(features)
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  row_blocks_dim = mtf.Dimension("row_blocks", 4)
  col_blocks_dim = mtf.Dimension("col_blocks", 4)
  rows_dim = mtf.Dimension("rows_size", 8)
  cols_dim = mtf.Dimension("cols_size", 8)
  classes_dim = mtf.Dimension("classes", 10)
  # NOTE(review): despite the name, this dimension holds the 3 RGB channels.
  one_channel_dim = mtf.Dimension("one_channel", 3)

  image = features['image']
  labels = features['label']
  image = bnorm(image)

  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 4, 8, 4, 8, 3]),
      mtf.Shape(
          [batch_dim, row_blocks_dim, rows_dim,
           col_blocks_dim, cols_dim, one_channel_dim]))
  x = mtf.transpose(x, [
      batch_dim, row_blocks_dim, col_blocks_dim,
      rows_dim, cols_dim, one_channel_dim])

  # A stack of 18 7x7 same-padded convolutions with ReLU.  The loop
  # reproduces the original unrolled code exactly: variable names
  # "kernel1".."kernel18" and dimension names "filters1".."filters18" are
  # preserved so existing checkpoints remain compatible.
  fh_dim = mtf.Dimension("fh", 7)
  fw_dim = mtf.Dimension("fw", 7)
  filter_counts = [32, 32, 64, 64] + [128] * 6 + [256] * 8
  in_dim = one_channel_dim
  for i, num_filters in enumerate(filter_counts, start=1):
    out_dim = mtf.Dimension("filters%d" % i, num_filters)
    kernel = mtf.get_variable(
        mesh, "kernel%d" % i, [fh_dim, fw_dim, in_dim, out_dim])
    x = mtf.relu(mtf.conv2d_with_blocks(
        x, kernel,
        strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))
    in_dim = out_dim
  # Average out the final channel dimension before the dense head.
  x = mtf.reduce_mean(x, reduced_dim=in_dim)

  # add some fully-connected dense layers.
  # The first dense layer reduces the remaining block + spatial dims; the
  # rest are plain hidden layers (names "hidden2".."hidden8" preserved).
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  h = mtf.layers.dense(
      x, hidden_dim1,
      reduced_dims=x.shape.dims[-4:],
      activation=mtf.relu, name="hidden1")
  for i in range(2, 9):
    hidden_dim = mtf.Dimension("hidden%d" % i, FLAGS.hidden_size)
    h = mtf.layers.dense(
        h, hidden_dim, activation=mtf.relu, name="hidden%d" % i)

  logits = mtf.layers.dense(h, classes_dim, name="logits")
  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]),
        mtf.Shape([batch_dim]))
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss