def make_dense_layers(x, classes_dim):
    """Apply a three-layer fully-connected head to ``x``.

    Two 4096-wide dense layers, each followed by relu, then a final dense
    projection onto ``classes_dim`` with no activation.

    Args:
      x: a mtf.Tensor.  Every dense layer uses mtf.layers.dense's default
        reduced dims (presumably only the last dim -- confirm for the mtf
        version in use).
      classes_dim: a mtf.Dimension for the output classes.

    Returns:
      a mtf.Tensor projected onto ``classes_dim`` (unnormalized logits).
    """
    hidden_dims = [
        mtf.Dimension(name="dense_dim1", size=4096),
        mtf.Dimension(name="dense_dim2", size=4096),
    ]
    out = x
    for layer_idx, width in enumerate(hidden_dims):
        out = mtf.layers.dense(out, width, name="dense-{}".format(layer_idx))
        out = mtf.relu(out, name="relu-dense-{}".format(layer_idx))
    return mtf.layers.dense(out, classes_dim, name="dense-2")
def BasicBlock(x, order, out_channels, strides):
    """Residual block of two 3x3 convolutions with batch norm.

    conv3x3(strides) -> batch norm -> relu -> conv3x3(stride 1) -> batch norm,
    then the shortcut is added and a final relu applied.  The name, dtype and
    shape info of intermediate tensors are printed to stdout.

    Args:
      x: a mtf.Tensor.  The shortcut reshape indexes dims 0..3, so x is
        assumed to be 4-D (e.g. [batch, h, w, channels] -- TODO confirm).
      order: suffix (converted with str()) that makes every op and dimension
        name in this block unique.
      out_channels: int, output channels of both convolutions (expansion 1).
      strides: strides of the first convolution.

    Returns:
      a mtf.Tensor with the same shape as the second batch-norm output.

    NOTE(review): the shortcut is only reshaped, never projected, so this
    appears to work only when the first conv keeps the spatial size and the
    input channel count equals ``out_channels`` -- confirm with callers.
    NOTE(review): ``float16`` is not a parameter here; it must resolve to a
    module-level name.
    """
    suffix = '-' + str(order)
    prefix = "BasicBlock" + suffix
    expansion = 1
    inner_channels = out_channels // expansion
    shortcut = x

    def _report(t):
        # Debug output, kept identical to the original prints.
        print(t.name)
        print(t.dtype)

    # First 3x3 conv carries the block stride.
    out = mtf.layers.conv2d(
        x,
        output_dim=mtf.Dimension(name=prefix + '-filters1', size=inner_channels),
        filter_size=(3, 3),
        strides=strides,
        name="conv3x3_BB_1" + suffix,
        variable_dtype=float16)
    _report(out)
    out, _ = mtf.layers.batch_norm(
        out, is_training=True, momentum=0.99, epsilon=1e-5,
        name="batch_norm_BB_1" + suffix)
    out = mtf.relu(out, name="relu_BB_1" + suffix)

    out = mtf.layers.conv2d(
        out,
        output_dim=mtf.Dimension(name=prefix + '-filters2', size=out_channels),
        filter_size=(3, 3),
        strides=(1, 1),
        name="conv3x3_BB_2" + suffix,
        variable_dtype=float16)
    _report(out)
    out, _ = mtf.layers.batch_norm(
        out, is_training=True, momentum=0.99, epsilon=1e-5,
        name="batch_norm_BB_2" + suffix)

    # Give the shortcut the same channel Dimension as ``out`` so mtf.add can
    # align them.
    sdims = shortcut.shape.dims
    shortcut = mtf.reshape(
        shortcut,
        new_shape=[sdims[0], sdims[1], sdims[2], out.shape.dims[3]],
        name="reshape_BB" + str(order))
    out = mtf.add(out, shortcut, output_shape=out.shape,
                  name="add_BB_1" + suffix)
    out = mtf.relu(out, name="relu_BB_2" + suffix)
    _report(out)
    return out
def mnist_model(image, labels, mesh):
    """Single-layer MNIST model: logits = relu(image . w1 + b1).

    Args:
      image: tf.Tensor with FLAGS.batch_size * 28 * 28 elements, e.g. shape
        [batch, 28*28]; it is reshaped to [batch, 28, 28].
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
        (e.g. at inference time).
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when labels is None.
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28]),
        [batch_dim, rows_dim, cols_dim])

    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, classes_dim])
    b1 = mtf.get_variable(mesh, "b1", [classes_dim])
    # NOTE(review): relu applied to the final logits is unusual; kept as-is.
    logits = mtf.relu(mtf.einsum([x, w1], [batch_dim, classes_dim]) + b1)

    if labels is None:
        loss = None
    else:
        # Import labels only when they exist: tf.reshape(None, ...) raises,
        # so importing them up front made the None branch unreachable.
        y = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [FLAGS.batch_size]), [batch_dim])
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
def make_conv_layers(x, mode, batch_norm=True):
    """Build a VGG-style convolutional trunk from a layer specification.

    Args:
      x: a mtf.Tensor fed to the first layer.
      mode: iterable spec.  The entry "M" inserts a 2x2 max pool; any other
        entry is the output channel count of a 3x3, stride-1 convolution.
      batch_norm: if true, apply training-mode batch norm after every conv.
        A relu follows every conv either way.

    Returns:
      the output mtf.Tensor.
    """
    pools_seen = 0
    convs_seen = 0
    for entry in mode:
        if entry == "M":
            x = mtf.layers.max_pool2d(
                x, ksize=(2, 2), name="maxpool-" + str(pools_seen))
            pools_seen += 1
            continue

        tag = str(convs_seen)
        x = mtf.layers.conv2d(
            x,
            output_dim=mtf.Dimension(name="filters-" + tag, size=entry),
            filter_size=(3, 3),
            strides=(1, 1),
            name="conv3-" + tag)
        if batch_norm:
            x, _ = mtf.layers.batch_norm(
                x, is_training=True, momentum=0.99, epsilon=1e-5,
                name="batch_norm-" + tag)
        x = mtf.relu(x, name="relu-conv-" + tag)
        convs_seen += 1
    return x
def make_dense_layers(x, classes_dim, float16=None):
    """Three-layer fully-connected head with debug logging after every op.

    Args:
      x: a mtf.Tensor.  The first dense layer contracts its last three dims
        (presumably the spatial/channel dims of a conv feature map -- confirm).
      classes_dim: a mtf.Dimension for the output classes.
      float16: passed to mtf.layers.dense as ``variable_dtype`` for all
        three layers.

    Returns:
      a mtf.Tensor projected onto ``classes_dim`` (unnormalized logits).
    """
    def _logged(t):
        logger.debug("[output tensor] (name,shape):({},{})".format(t.name, t.shape))
        return t

    first_dim = mtf.Dimension(name="dense_dim1", size=4096)
    second_dim = mtf.Dimension(name="dense_dim2", size=4096)

    out = _logged(mtf.layers.dense(
        x, first_dim, name="dense-0",
        reduced_dims=x.shape.dims[-3:], variable_dtype=float16))
    out = _logged(mtf.relu(out, name="relu-dense-0"))
    out = _logged(mtf.layers.dense(
        out, second_dim, name="dense-1", variable_dtype=float16))
    out = _logged(mtf.relu(out, name="relu-dense-1"))
    return _logged(mtf.layers.dense(
        out, classes_dim, name="dense-2", variable_dtype=float16))
def mnist_model(image, labels, mesh, batch_size=100):
    """Two-layer MLP for MNIST (1024 hidden units, relu).

    Args:
      image: tf.Tensor with batch_size * 28 * 28 elements, e.g. shape
        [batch, 28*28] or [batch, 28, 28]; it is reshaped to [batch, 28, 28]
        before import.
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None.
      mesh: a mtf.Mesh
      batch_size: static batch size of ``image`` and ``labels``.  The
        default of 100 matches the previously hard-coded value.

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when labels is None.
    """
    batch_dim = mtf.Dimension("batch", batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    hidden_dim = mtf.Dimension("hidden", 1024)
    classes_dim = mtf.Dimension("classes", 10)

    # Reshape so that both flat [batch, 784] and [batch, 28, 28] inputs match
    # the declared mtf shape.
    images = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [batch_size, 28, 28]),
        shape=[batch_dim, rows_dim, cols_dim])

    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
    w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])
    # einsum is a generalization of matrix multiplication (see numpy.einsum)
    hidden = mtf.relu(
        mtf.einsum(images, w1, output_shape=[batch_dim, hidden_dim]))
    logits = mtf.einsum(hidden, w2, output_shape=[batch_dim, classes_dim])

    if labels is None:
        return logits, None

    mtf_labels = mtf.import_tf_tensor(mesh, labels, [batch_dim])
    loss = mtf.reduce_mean(
        mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim))
    return logits, loss
def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
    """Forward pass of alpha-entmax, found by bisection on the threshold tau.

    Computes p = relu((alpha - 1) * x - tau) ** (1 / (alpha - 1)), with tau
    located by ``n_iter`` bisection steps so that p sums to ~1 along ``dim``.
    The result is then renormalized to sum exactly to 1 along ``dim``.

    Args:
      x: a mtf.Tensor of scores.
      alpha: float in the open interval (1, 2).
      dim: the mtf.Dimension to normalize over; defaults to the last dim of x.
      n_iter: number of bisection iterations; must be >= 1.

    Returns:
      a mtf.Tensor with the same shape as x.

    Raises:
      ValueError: if alpha is not in (1, 2) or n_iter < 1.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not (alpha > 1 and alpha < 2):
        raise ValueError('alpha must be between 1 and 2')
    if n_iter < 1:
        # With zero iterations p_m would never be computed.
        raise ValueError('n_iter must be at least 1')

    def _gp(t, a):
        return t ** (a - 1)

    def _gp_inv(t, a):
        return mtf.pow(t, (1 / (a - 1)))

    def _p(t, a):
        return _gp_inv(mtf.relu(t), a)

    dim = x.shape[-1] if dim is None else dim
    d = dim.size

    x = x * (alpha - 1)
    max_val = mtf.reduce_max(x, reduced_dim=dim)

    # Initial bracket [tau_lo, tau_hi] for the threshold.
    tau_lo = max_val - _gp(1, alpha)
    tau_hi = max_val - _gp(1 / d, alpha)
    f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1
    dm = tau_hi - tau_lo

    for _ in range(n_iter):
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = _p(x - tau_m, alpha)
        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1
        # Move the lower end up wherever f(tau_m) has the same sign as f(tau_lo).
        mask = mtf.greater_equal((f_m * f_lo), 0)
        tau_lo = mtf.where(mask, tau_m, tau_lo)

    # Renormalize so the output sums exactly to 1 along `dim`.
    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
    return p_m
def batch_norm_relu(inputs, is_training, relu=True):
    """Apply batch norm, then relu when ``relu`` is truthy.

    Uses the module-level BATCH_NORM_DECAY and BATCH_NORM_EPSILON constants.

    NOTE(review): the result of mtf.layers.batch_norm is used directly as a
    tensor, while other code in this file unpacks an (output, updates) tuple
    from it -- confirm which mtf version this block targets.
    """
    normed = mtf.layers.batch_norm(
        inputs, is_training, BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON)
    return mtf.relu(normed) if relu else normed
def batch_norm_relu(inputs, is_training, relu=True):
    """Apply batch norm, then relu when ``relu`` is truthy.

    ``init_zero`` is enabled exactly when relu is disabled.  Presumably this
    zero-initializes the batch-norm scale of the last layer in a residual
    branch -- confirm against mtf.layers.batch_norm.  Uses the module-level
    BATCH_NORM_DECAY and BATCH_NORM_EPSILON constants.
    """
    zero_init = not relu
    normed = mtf.layers.batch_norm(
        inputs,
        is_training,
        BATCH_NORM_DECAY,
        epsilon=BATCH_NORM_EPSILON,
        init_zero=zero_init)
    if not relu:
        return normed
    return mtf.relu(normed)
def make_conv_layers(x, mode, batch_norm=True, float16=None):
    """VGG-style convolutional trunk with debug logging after every op.

    Args:
      x: a mtf.Tensor fed to the first layer.
      mode: iterable spec.  The entry "M" inserts a 2x2 max pool; any other
        entry is the output channel count of a 3x3, stride-1 convolution
        built with the module-level ``conv2d`` helper.
      batch_norm: if true, apply training-mode batch norm after every conv.
        A relu follows every conv either way.
      float16: forwarded as ``variable_dtype`` to ``conv2d``.

    Returns:
      the output mtf.Tensor.
    """
    def _log(t):
        logger.debug("[output tensor] (name,shape):({},{})".format(t.name, t.shape))

    pool_idx = 0
    conv_idx = 0
    for spec in mode:
        if spec == "M":
            x = mtf.layers.max_pool2d(
                x, ksize=(2, 2), name="maxpool-" + str(pool_idx))
            _log(x)
            pool_idx += 1
            continue

        tag = str(conv_idx)
        x = conv2d(
            x,
            output_dim=mtf.Dimension(name="filters-" + tag, size=spec),
            filter_size=(3, 3),
            strides=(1, 1),
            name="conv3-" + tag,
            variable_dtype=float16)
        _log(x)
        if batch_norm:
            x, _ = mtf.layers.batch_norm(
                x, is_training=True, momentum=0.99, epsilon=1e-5,
                name="batch_norm-" + tag)
            _log(x)
        x = mtf.relu(x, name="relu-conv-" + tag)
        _log(x)
        conv_idx += 1
    return x
def cifar_model(features, labels, mesh):
    """Deep CIFAR-10 CNN using spatially partitioned (blocked) convolutions.

    The 32x32x3 image is split into a 4x4 grid of 8x8 blocks, passed through
    18 7x7 "SAME" convolutions (each followed by relu), averaged over
    channels, then fed through 8 dense relu layers and a logits layer.

    Args:
      features: dict-like holding 'image' (a tf.Tensor reshapeable to
        [batch, 32, 32, 3]; it is first passed through ``bnorm``) and
        'label' (a tf.Tensor with shape [batch] and dtype tf.int32, or None).
      labels: ignored -- labels are always read from features['label'].
        Presumably kept for model_fn signature compatibility.
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when features['label'] is None
    """
    features = copy.copy(features)
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 8)
    cols_dim = mtf.Dimension("cols_size", 8)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 3)

    image = features['image']
    labels = features['label']
    image = bnorm(image)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 4, 8, 4, 8, 3]),
        mtf.Shape(
            [batch_dim, row_blocks_dim, rows_dim,
             col_blocks_dim, cols_dim, one_channel_dim]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim,
        rows_dim, cols_dim, one_channel_dim])

    # Convolutional stack.  The variable names (kernel1..kernel18) and
    # channel dim names (filters1..filters18) must stay stable so previously
    # saved variables still match.
    fh_dim = mtf.Dimension("fh", 7)
    fw_dim = mtf.Dimension("fw", 7)
    conv_filters = [32] * 2 + [64] * 2 + [128] * 6 + [256] * 8
    in_dim = one_channel_dim
    for layer, num_filters in enumerate(conv_filters, start=1):
        out_dim = mtf.Dimension("filters{}".format(layer), num_filters)
        kernel = mtf.get_variable(
            mesh, "kernel{}".format(layer), [fh_dim, fw_dim, in_dim, out_dim])
        x = mtf.relu(mtf.conv2d_with_blocks(
            x, kernel, strides=[1, 1, 1, 1], padding="SAME",
            h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))
        in_dim = out_dim
    # Average away the channel dim of the last conv output.
    x = mtf.reduce_mean(x, reduced_dim=in_dim)

    # Fully-connected stack: hidden1 contracts the four remaining non-batch
    # dims; hidden2..hidden8 use mtf.layers.dense's default reduced dims.
    hidden_dims = [mtf.Dimension("hidden{}".format(i), FLAGS.hidden_size)
                   for i in range(1, 9)]
    h = mtf.layers.dense(
        x, hidden_dims[0],
        reduced_dims=x.shape.dims[-4:],
        activation=mtf.relu, name="hidden1")
    for i, hidden_dim in enumerate(hidden_dims[1:], start=2):
        h = mtf.layers.dense(
            h, hidden_dim, activation=mtf.relu, name="hidden{}".format(i))
    logits = mtf.layers.dense(h, classes_dim, name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
    """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    When ``synthesize`` is False the logits are the plain dot product of q
    and k over key_dim, i.e. standard dot-product attention.

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor
      synthesize: flag to use synthetic attention or not
      synthesize_mode: one of "random", "factorized", "dense_minus",
        "random_plus_alpha", "random_plus", "dense_plus_alpha", "dense_plus"
      factorized_dim: inner (rank) dim of the "factorized" synthesizer
      max_length: max length of input sequence; sizes the "random",
        "factorized" and "dense_minus" parameters (the "*_plus*" modes use a
        fixed 512)
      context: layer context.  Its ``train`` attribute is always read, and
        ``mesh``, ``variable_dtype``, ``mode`` and ``position`` are read by
        the synthesizer modes, so it must not be None.

    Returns:
      Tensor.  For synthesize=False and the "*_plus*" modes the shape is
      q.shape - key_dim + value_dim.  For "dense_minus" it is
      q.shape.dims[:-1] + [value_dim].  For "random"/"factorized" it is
      q.shape.dims[:-1] + [heads, value_dim].  NOTE(review): the latter
      duplicates "heads" if q already has a heads dim -- confirm callers.

    Raises:
      ValueError: if synthesize is True and synthesize_mode is unknown.
    """
    if synthesize:
        num_heads = v.shape.get_dim_by_name("heads")
        tf.logging.info("Using synthesizer")
        if synthesize_mode == "random":
            tf.logging.info("Using Random Synthesizers")
            r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                                 mtf.Dimension("heads", num_heads.size),
                                 mtf.Dimension("memory_length", max_length)])
            r = mtf.get_variable(context.mesh, "R", r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
            if context.mode == "incremental":
                r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
            else:
                length_dim = q.shape.get_dim_by_name("length")
                r = mtf.slice(r, 0, length_dim.size, "length")
            logits = r
        elif synthesize_mode == "factorized":
            tf.logging.info("Using Factorized Random Synthesizers")
            # Low-rank R = R1^T R2 contracted over "tmp".  R1 carries the
            # query "length" axis and R2 the "memory_length" axis; with both
            # factors on "memory_length" the einsum output's "length" axis
            # would come from neither input.  A separate name keeps the key
            # tensor `k` untouched.
            tmp_dim = mtf.Dimension("tmp", factorized_dim)
            heads_dim = mtf.Dimension("heads", num_heads.size)
            r1_shape = mtf.Shape([tmp_dim, heads_dim,
                                  mtf.Dimension("length", max_length)])
            r2_shape = mtf.Shape([tmp_dim, heads_dim,
                                  mtf.Dimension("memory_length", max_length)])
            r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                                 heads_dim,
                                 mtf.Dimension("memory_length", max_length)])
            r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
                                  initializer=None,
                                  dtype=context.variable_dtype)
            r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
                                  initializer=None,
                                  dtype=context.variable_dtype)
            r = mtf.einsum([r1, r2], r_shape)
            r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
            if context.mode == "incremental":
                r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
            else:
                length_dim = q.shape.get_dim_by_name("length")
                r = mtf.slice(r, 0, length_dim.size, "length")
            logits = r
        elif synthesize_mode == "dense_minus":
            # Dense Synthesizer Model
            tmp_dim = mtf.Dimension("memory_length", max_length)
            logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                      use_bias=False,
                                      name="pi",
                                      reduced_dims=[key_dim],
                                      variable_dtype=None)
            logits = mtf.slice(logits, 0, memory_length_dim.size,
                               memory_length_dim.name)
            if context.mode != "incremental":
                length_dim = q.shape.get_dim_by_name("length")
                logits = mtf.slice(logits, 0, length_dim.size, "length")
        elif synthesize_mode in ("random_plus_alpha", "random_plus"):
            # Mixture Random Synthesizer with learnable Alpha
            tf.logging.info("Using Random Plus Alpha")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            num_heads = logits.shape.get_dim_by_name("heads")
            # NOTE: 512 is hard-coded here rather than max_length; changing it
            # would change the shape of the "R" variable.
            r_shape = mtf.Shape([mtf.Dimension("length", 512),
                                 mtf.Dimension("heads", num_heads.size),
                                 mtf.Dimension("memory_length", 512)])
            r = mtf.get_variable(context.mesh, "R", r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
            if context.mode == "incremental":
                r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
            else:
                length_dim = q.shape.get_dim_by_name("length")
                r = mtf.slice(r, 0, length_dim.size, length_dim.name)
            if "alpha" in synthesize_mode:
                alpha = mtf.get_variable(context.mesh,
                                         "alpha",
                                         mtf.Shape([mtf.Dimension("alpha", 1)]),
                                         initializer=tf.zeros_initializer(),
                                         dtype=context.variable_dtype)
                alpha = mtf.sigmoid(alpha)
                logits = ((1 - alpha) * logits) + (alpha * r)
            else:
                logits = logits + r
        elif synthesize_mode in ("dense_plus_alpha", "dense_plus"):
            # Mixture Dense Synthesizer with learnable alpha
            tf.logging.info("Using Dense Plus Alpha Scaling")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            tmp_dim = mtf.Dimension("memory_length", 512)
            r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                 use_bias=False,
                                 name="pi",
                                 reduced_dims=[key_dim],
                                 variable_dtype=None)
            r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
            if context.mode != "incremental":
                length_dim = q.shape.get_dim_by_name("length")
                r = mtf.slice(r, 0, length_dim.size, "length")
            if "alpha" in synthesize_mode:
                alpha = mtf.get_variable(context.mesh,
                                         "alpha",
                                         mtf.Shape([mtf.Dimension("alpha", 1)]),
                                         initializer=tf.zeros_initializer(),
                                         dtype=context.variable_dtype)
                alpha = mtf.sigmoid(alpha)
                logits = ((1 - alpha) * logits) + (alpha * r)
            else:
                logits = logits + r
        else:
            # Previously an unknown mode fell through and failed later with
            # an UnboundLocalError on `logits`.
            raise ValueError(
                "Unknown synthesize_mode: {}".format(synthesize_mode))
    else:
        # Without synthesis, use standard dot-product attention logits;
        # previously `logits` was never defined on this path.
        logits = mtf.einsum([q, k], reduced_dims=[key_dim])

    if bias is not None:
        logits += bias
    weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
    weights = mtf.dropout(
        weights, context.train,
        1.0 - dropout_rate,
        noise_shape=weights.shape - dropout_broadcast_dims)

    if synthesize and "plus" not in synthesize_mode:
        if synthesize_mode == "dense_minus":
            outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
        else:
            outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
    else:
        outputs_shape = q.shape - [key_dim] + value_dim

    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
def cifar_model(image, labels, mesh):
    """VGG-style CIFAR-10 CNN: four stages of two 3x3 convs plus a 2x2 max pool.

    Stage widths are init, 2*init, 4*init and 8*init channels (init = 60).
    The pooled features go through one 256-unit dense relu layer and a
    logits layer.

    Args:
      image: tf.Tensor reshapeable to [batch, 32, 32, 3]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when labels is None
      all_filters: a one-element list holding the list of output channel
        counts of the 8 convolutions, in layer order.
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 1)
    col_blocks_dim = mtf.Dimension("col_blocks", 1)
    rows_dim = mtf.Dimension("rows_size", 32)
    cols_dim = mtf.Dimension("cols_size", 32)
    init = 60
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 3)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 1, 32, 1, 32, 3]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim,
            col_blocks_dim, cols_dim, one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim,
        rows_dim, cols_dim, one_channel_dim
    ])

    # Four stages; the width doubles each stage.  Names follow the pattern
    # filters1..filters8 / conv0..conv7 / maxpool0..maxpool3.
    conv_filters = []
    for stage in range(4):
        width = init * 2 ** stage
        for _ in range(2):
            conv_index = len(conv_filters)
            filters_dim = mtf.Dimension("filters{}".format(conv_index + 1), width)
            x = mtf.relu(
                mtf.layers.conv2d_with_blocks(
                    x, filters_dim,
                    filter_size=[3, 3], strides=[1, 1], padding="SAME",
                    h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
                    name="conv{}".format(conv_index)))
            conv_filters.append(width)
        x = mtf.layers.max_pool2d(x, ksize=(2, 2), name="maxpool{}".format(stage))

    # Dense head: contract all five non-batch dims of the pooled features.
    hidden_dim1 = mtf.Dimension("hidden1", 256)
    h1 = mtf.layers.dense(x, hidden_dim1,
                          reduced_dims=x.shape.dims[-5:],
                          activation=mtf.relu,
                          name="hidden1")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)

    all_filters = [conv_filters]
    return logits, loss, all_filters
def bottleneck_block(inputs, filters, is_training, strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck residual block built from explicit kernel variables.

    1x1 conv -> BN+relu -> 3x3 conv -> BN+relu -> 1x1 conv (block stride)
    -> BN (no relu), then the shortcut is added and relu is applied.  All
    three convolutions output ``filters`` channels.

    Args:
      inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
      filters: `int` number of output channels of each convolution.
      is_training: `bool` for whether the model is in training mode.
      strides: passed as-is as the strides of the final 1x1 conv.  Presumably
        a 4-element list for mtf.conv2d_with_blocks -- confirm with callers.
      projection_shortcut: optional callable invoked as
        ``projection_shortcut(inputs, kernel)`` with a 1x1 kernel variable
        named "kernel"; if None the input is used unchanged as the shortcut.
      row_blocks_dim: a mtf.Dimension, row dimension which is spatially
        partitioned along a mesh axis (used by the 3x3 conv only).
      col_blocks_dim: a mtf.Dimension, column dimension which is spatially
        partitioned along a mesh axis.

    Returns:
      The output `Tensor` of the block.
    """
    mesh = inputs.mesh
    k3_h = mtf.Dimension("filter_height", 3)
    k3_w = mtf.Dimension("filter_width", 3)
    k1_h = mtf.Dimension("filter_height", 1)
    k1_w = mtf.Dimension("filter_width", 1)

    def _kernel(name, height, width, in_dim, out_dim):
        # Creates a [height, width, in, out] conv kernel variable.
        return mtf.get_variable(
            mesh, name, mtf.Shape([height, width, in_dim, out_dim]))

    shortcut = inputs
    if projection_shortcut is not None:
        proj_dim = mtf.Dimension("filtersp", filters)
        shortcut = projection_shortcut(
            inputs, _kernel("kernel", k1_h, k1_w, inputs.shape.dims[-1], proj_dim))

    # First conv block (1x1).
    f1_dim = mtf.Dimension("filters1", filters)
    out = mtf.conv2d_with_blocks(
        inputs, _kernel("kernel1", k1_h, k1_w, inputs.shape.dims[-1], f1_dim),
        strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)
    # TODO(nikip): Add Dropout?
    out = batch_norm_relu(out, is_training)

    # Second conv block (3x3).
    f2_dim = mtf.Dimension("filters2", filters)
    out = mtf.conv2d_with_blocks(
        out, _kernel("kernel2", k3_h, k3_w, f1_dim, f2_dim),
        strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)
    out = batch_norm_relu(out, is_training)

    # Third conv block (1x1, carries the block stride).
    f3_dim = mtf.Dimension("filters3", filters)
    out = mtf.conv2d_with_blocks(
        out, _kernel("wide_kernel", k1_h, k1_w, f2_dim, f3_dim),
        strides, padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)
    out = batch_norm_relu(out, is_training, relu=False)

    # TODO(nikip): Maybe add residual with a projection?
    # Rename the shortcut's channel dim so it lines up with `out` for the add.
    aligned = mtf.rename_dimension(
        shortcut, shortcut.shape.dims[-1].name, out.shape.dims[-1].name)
    return mtf.relu(out + aligned)
def mnist_model(image, labels, mesh):
    """Small conv net for MNIST using spatially partitioned convolutions.

    The 28x28 image is split into a 4x4 grid of 7x7 blocks.  Two 16-channel
    9x9 convolutions (mtf.layers.conv2d_with_blocks, relu) follow, then a mean
    over channels, two dense relu layers and a logits layer.  Intermediate
    tensor names and shapes are logged via tf.logging.info.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when labels is None
    """
    def _trace(t):
        tf.logging.info("[intra variable] (name, shape): ({},{})".format(t.name, t.shape))
        return t

    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    blocked = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
        mtf.Shape([batch_dim, row_blocks_dim, rows_dim,
                   col_blocks_dim, cols_dim, one_channel_dim]))
    net = _trace(mtf.transpose(blocked, [
        batch_dim, row_blocks_dim, col_blocks_dim,
        rows_dim, cols_dim, one_channel_dim]))

    # Two blocked convolutions (conv0, conv1).
    filter_dims = [mtf.Dimension("filters1", 16), mtf.Dimension("filters2", 16)]
    for layer, out_dim in enumerate(filter_dims):
        net = _trace(mtf.relu(mtf.layers.conv2d_with_blocks(
            net, out_dim,
            filter_size=[9, 9], strides=[1, 1], padding="SAME",
            h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
            name="conv{}".format(layer))))
    net = _trace(mtf.reduce_mean(net, reduced_dim=filter_dims[-1]))

    # Dense head.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    h1 = _trace(mtf.layers.dense(
        net, hidden_dim1,
        reduced_dims=net.shape.dims[-4:],
        activation=mtf.relu, name="hidden1"))
    h2 = _trace(mtf.layers.dense(
        h1, hidden_dim2,
        activation=mtf.relu, name="hidden2"))
    logits = mtf.layers.dense(h2, classes_dim, name="logits")

    if labels is None:
        return logits, None

    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    per_example = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    return logits, mtf.reduce_mean(per_example)
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck block variant for residual networks with BN after convolutions.

    Uses ``mtf.layers.conv2d_with_blocks``, which creates its own kernels.

    Args:
      inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
      filters: `int` number of filters for the first and third (1x1)
        convolutions. The middle 3x3 convolution uses `4 * filters`.
      is_training: `bool` for whether the model is in training mode.
      strides: stride of the third (1x1) convolution. Either an `int`, applied
        to both height and width, or a `[stride_h, stride_w]` sequence. If
        greater than 1, this block downsamples the input.
      projection_shortcut: `function` called as
        `projection_shortcut(inputs, filters_dim)`, where `filters_dim` is a
        `"filtersp"` dimension of size `filters`; typically a 1x1 convolution
        that matches the filter dimensions. If None, `inputs` is passed
        unchanged through the shortcut connection.
      row_blocks_dim: a mtf.Dimension, row dimension which is spatially
        partitioned along mesh axis
      col_blocks_dim: a mtf.Dimension, column dimension which is spatially
        partitioned along mesh axis

    Returns:
      The output `Tensor` of the block.
    """
    # mtf.layers.conv2d_with_blocks takes strides as [stride_h, stride_w]; an
    # int stride (the documented form) would otherwise fail inside the layer.
    if isinstance(strides, int):
        strides = [strides, strides]

    shortcut = inputs
    if projection_shortcut is not None:
        filters_dim = mtf.Dimension("filtersp", filters)
        shortcut = projection_shortcut(inputs, filters_dim)

    # First conv block: 1x1, `filters` output channels.
    inputs = mtf.layers.conv2d_with_blocks(
        inputs, mtf.Dimension("filters1", filters),
        filter_size=[1, 1], strides=[1, 1], padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim, name="conv0")
    # TODO(nikip): Add Dropout?
    inputs = batch_norm_relu(inputs, is_training)

    # Second conv block: 3x3, `4 * filters` output channels.
    # NOTE(review): a standard ResNet bottleneck widens the *last* 1x1 conv
    # instead; kept as-is because the residual add below relies on the final
    # conv having `filters` channels.
    inputs = mtf.layers.conv2d_with_blocks(
        inputs, mtf.Dimension("filters2", 4 * filters),
        filter_size=[3, 3], strides=[1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
        name="conv1")
    inputs = batch_norm_relu(inputs, is_training)

    # Third conv block: 1x1 back to `filters` channels; applies `strides`.
    inputs = mtf.layers.conv2d_with_blocks(
        inputs, mtf.Dimension("filters3", filters),
        filter_size=[1, 1], strides=strides, padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim, name="conv2")
    # TODO(nikip): Althought the original resnet code has this batch norm, in our
    # setup this is causing no gradients to be passed. Investigate further.
    # inputs = batch_norm_relu(inputs, is_training, relu=True)

    # TODO(nikip): Maybe add residual with a projection?
    # Give the output channel dim the shortcut's channel dim name so the
    # residual add lines the two tensors up.
    return mtf.relu(shortcut + mtf.rename_dimension(
        inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
def ResidualBlockWithDown(x, order, out_channels, strides, float16=None,
                          batch_norm=False):
    """Bottleneck residual block whose shortcut is a strided 1x1 convolution.

    Main path: 1x1 conv (``out_channels // 4`` filters) -> [BN] -> ReLU ->
    3x3 conv with ``strides`` (``out_channels // 4`` filters) -> [BN] -> ReLU ->
    1x1 conv (``out_channels`` filters) -> [BN].
    Shortcut: 1x1 conv with ``strides`` (``out_channels`` filters) -> [BN].
    The two branches are added and passed through a final ReLU.  Every
    intermediate tensor's name and shape is written with ``logger.debug``.

    Args:
      x: input mtf.Tensor. Presumably 4-D [batch, height, width, channels]
        (the shortcut reshape indexes dims 0-3) -- TODO confirm with callers.
      order: integer index appended to op, variable and dimension names so
        that repeated blocks get unique names.
      out_channels: number of output channels of the block.
      strides: stride pair passed to the 3x3 conv and to the shortcut conv.
      float16: forwarded as ``variable_dtype`` to every ``conv2d`` call.
      batch_norm: if True, apply batch normalization after each convolution.

    Returns:
      The block's output mtf.Tensor.
    """
    name = "ResidualBlockWithDown"
    expansion = 4
    out_chls = out_channels // expansion
    suffix = '-' + str(order)
    filters_prefix = name + suffix + '-'

    def _log(tensor):
        # Debug trace of the tensor just produced.
        logger.debug("[output tensor] (name,shape):({},{})".format(
            tensor.name, tensor.shape))

    def _maybe_bn(tensor, bn_name):
        # Optional batch norm; the BN output is logged only when it is applied.
        if not batch_norm:
            return tensor
        tensor, _ = mtf.layers.batch_norm(
            tensor, is_training=True, momentum=0.99, epsilon=1e-5,
            name=bn_name + suffix)
        _log(tensor)
        return tensor

    identity = x

    # Main path, stage 1: 1x1 reduction.
    x = conv2d(x,
               output_dim=mtf.Dimension(name=filters_prefix + 'filters1',
                                        size=out_chls),
               filter_size=(1, 1),
               strides=(1, 1),
               name="conv1x1_RBW_1" + suffix,
               variable_dtype=float16)
    _log(x)
    x = _maybe_bn(x, "batch_norm_RBW_1")
    x = mtf.relu(x, name="relu_RBW_1" + suffix)
    _log(x)

    # Main path, stage 2: strided 3x3.
    x = conv2d(x,
               output_dim=mtf.Dimension(name=filters_prefix + 'filters2',
                                        size=out_chls),
               filter_size=(3, 3),
               strides=strides,
               name="conv3x3_RBW_1" + suffix,
               variable_dtype=float16)
    _log(x)
    x = _maybe_bn(x, "batch_norm_RBW_2")
    x = mtf.relu(x, name="relu_RBW_2" + suffix)
    _log(x)

    # Main path, stage 3: 1x1 expansion to out_channels.
    x = conv2d(x,
               output_dim=mtf.Dimension(name=filters_prefix + 'filters3',
                                        size=out_channels),
               filter_size=(1, 1),
               strides=(1, 1),
               name="conv1x1-2_RBW_2" + suffix,
               variable_dtype=float16)
    _log(x)
    x = _maybe_bn(x, "batch_norm_RBW_3")

    # Shortcut: 1x1 conv using the same strides as the 3x3 conv above.
    identity = conv2d(identity,
                      output_dim=mtf.Dimension(
                          name=filters_prefix + 'filters3',
                          size=out_channels),
                      filter_size=(1, 1),
                      strides=strides,
                      name="conv1x1_RBW_3" + suffix,
                      variable_dtype=float16)
    _log(identity)
    identity = _maybe_bn(identity, "batch_norm_RBW_4")
    # Keep the shortcut's dims 0-2 but take dim 3 from the main path so that
    # the add below sees matching dimensions.
    identity = mtf.reshape(identity,
                           new_shape=[
                               identity.shape.dims[0],
                               identity.shape.dims[1],
                               identity.shape.dims[2],
                               x.shape.dims[3],
                           ],
                           name="reshape_RBW" + str(order))
    _log(identity)

    x = mtf.add(x, identity, output_shape=x.shape, name="add_RBW_1" + suffix)
    _log(x)
    x = mtf.relu(x, name="relu_RBW_3" + suffix)
    _log(x)
    return x
def mnist_model(image, labels, mesh):
    """Build a six-layer 3x3-convolution MNIST classifier and, optionally, its loss.

    Six ReLU convolutions (each ``FLAGS.num_filters`` wide, "SAME" padding,
    stride 1) are averaged over the last filter dimension.  The result goes
    through two ReLU dense layers of ``FLAGS.hidden_size`` and a linear
    projection to 10 classes.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when ``labels`` is None
    """
    num_conv_layers = 6

    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows_size", 28)
    cols_dim = mtf.Dimension("cols_size", 28)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28, 1]),
        mtf.Shape([batch_dim, rows_dim, cols_dim, one_channel_dim]))

    fh_dim = mtf.Dimension("fh", 3)
    fw_dim = mtf.Dimension("fw", 3)
    # "filters1" ... "filters6"; layer i maps filter_dims[i-1] -> filter_dims[i]
    # (the first layer reads the single input channel).
    filter_dims = [mtf.Dimension("filters%d" % (i + 1), FLAGS.num_filters)
                   for i in range(num_conv_layers)]
    in_dims = [one_channel_dim] + filter_dims[:-1]
    # Create every kernel ("kernel1" ... "kernel6") before applying any
    # convolution so the graph is built in the same order as before.
    kernels = [
        mtf.get_variable(mesh, "kernel%d" % (i + 1),
                         [fh_dim, fw_dim, in_dim, out_dim])
        for i, (in_dim, out_dim) in enumerate(zip(in_dims, filter_dims))
    ]
    for kernel in kernels:
        x = mtf.relu(mtf.conv2d(x, kernel, strides=[1, 1, 1, 1],
                                padding="SAME"))
    x = mtf.reduce_mean(x, reduced_dim=filter_dims[-1])

    # Fully-connected head; the first layer folds in [rows, cols].
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    h1 = mtf.layers.dense(x, hidden_dim1, reduced_dims=x.shape.dims[-2:],
                          activation=mtf.relu, name="hidden1")
    h2 = mtf.layers.dense(h1, hidden_dim2, activation=mtf.relu,
                          name="hidden2")
    logits = mtf.layers.dense(h2, classes_dim, name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [FLAGS.batch_size]),
            mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
def mnist_model(image, labels, mesh):
    """Build an MNIST classifier from explicit-kernel blocked convolutions.

    The image is reshaped into a 4x4 grid of 7x7 blocks.  Two 9x9 kernels are
    created with ``mtf.get_variable`` and applied with the op-level
    ``mtf.conv2d_with_blocks`` (ReLU after each).  The result is averaged over
    the filter dimension, passed through one ReLU dense layer of
    ``FLAGS.hidden_size`` and projected to 10 classes.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when ``labels`` is None
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    # [batch, row_blocks, rows, col_blocks, cols, 1] -> block axes first.
    pixels = mtf.import_tf_tensor(
        mesh,
        tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
        mtf.Shape([batch_dim, row_blocks_dim, rows_dim,
                   col_blocks_dim, cols_dim, one_channel_dim]))
    pixels = mtf.transpose(pixels, [
        batch_dim, row_blocks_dim, col_blocks_dim,
        rows_dim, cols_dim, one_channel_dim])

    # Convolution kernels (both are created before either conv is applied).
    fh_dim, fw_dim = mtf.Dimension("fh", 9), mtf.Dimension("fw", 9)
    filters1_dim = mtf.Dimension("filters1", 16)
    filters2_dim = mtf.Dimension("filters2", 16)
    kernel1 = mtf.get_variable(
        mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(
        mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim])

    def _blocked_relu_conv(tensor, kernel):
        return mtf.relu(mtf.conv2d_with_blocks(
            tensor, kernel,
            strides=[1, 1, 1, 1],
            padding="SAME",
            h_blocks_dim=row_blocks_dim,
            w_blocks_dim=col_blocks_dim))

    features = pixels
    for kernel in (kernel1, kernel2):
        features = _blocked_relu_conv(features, kernel)
    pooled = mtf.reduce_mean(features, reduced_dim=filters2_dim)

    # Single hidden layer over the four spatial axes, then the class logits.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden = mtf.layers.dense(
        pooled, hidden_dim1,
        reduced_dims=pooled.shape.dims[-4:],
        activation=mtf.relu, name="hidden1")
    logits = mtf.layers.dense(hidden, classes_dim, name="logits")

    if labels is None:
        return logits, None

    mtf_labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim)
    return logits, mtf.reduce_mean(per_example_loss)
def mnist_model(image, labels, mesh):
    """Build a three-conv MNIST classifier with average pooling.

    The image is imported as a single 28x28 block (1x1 block grid).  Three
    blocked ReLU convolutions (7x7, 5x5, 3x3 with ``init``, ``2*init`` and
    ``4*init`` filters) are followed by 2x2 average pooling, one ReLU dense
    layer of 128 units and a linear projection to 10 classes.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when ``labels`` is None
      all_filters: ``[[init, init * 2, init * 4]]``, the output-channel counts
        of the three conv layers.
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 1)
    col_blocks_dim = mtf.Dimension("col_blocks", 1)
    rows_dim = mtf.Dimension("rows_size", 28)
    cols_dim = mtf.Dimension("cols_size", 28)
    # Base channel count. The conv layers below use exactly the sizes reported
    # in `all_filters`, so the two cannot drift apart.
    init = 60
    filter_sizes = [init, init * 2, init * 4]
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 1, 28, 1, 28, 1]),
        mtf.Shape([batch_dim, row_blocks_dim, rows_dim,
                   col_blocks_dim, cols_dim, one_channel_dim]))
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, one_channel_dim])

    # Convolutional stack.
    filters1_dim = mtf.Dimension("filters1", filter_sizes[0])
    f1 = mtf.relu(mtf.layers.conv2d_with_blocks(
        x, filters1_dim, filter_size=[7, 7], strides=[1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
        name="conv0"))
    filters2_dim = mtf.Dimension("filters2", filter_sizes[1])
    f2 = mtf.relu(mtf.layers.conv2d_with_blocks(
        f1, filters2_dim, filter_size=[5, 5], strides=[1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
        name="conv1"))
    filters3_dim = mtf.Dimension("filters3", filter_sizes[2])
    f3 = mtf.relu(mtf.layers.conv2d_with_blocks(
        f2, filters3_dim, filter_size=[3, 3], strides=[1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
        name="conv2"))
    x = mtf.layers.avg_pool2d(f3, ksize=(2, 2), name="averagePool")

    # Fully-connected head. reduced_dims takes the last five dims of the
    # pooled tensor, i.e. everything after batch (assuming the convs/pool keep
    # the 6-D layout set up by the transpose above -- TODO confirm).
    hidden_dim1 = mtf.Dimension("hidden1", 128)
    print(x.shape)
    h1 = mtf.layers.dense(x, hidden_dim1, reduced_dims=x.shape.dims[-5:],
                          activation=mtf.relu, name="hidden1")
    logits = mtf.layers.dense(h1, classes_dim, name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [FLAGS.batch_size]),
            mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    all_filters = [list(filter_sizes)]
    return logits, loss, all_filters
def Alexnet(img, labels, num_nodes, num_gpus, args):
    """Build an AlexNet graph with Mesh-TensorFlow and return its mean loss.

    Args:
      img: input images, forwarded to ``CreateMeshes``.
      labels: input labels, forwarded to ``CreateMeshes``.
      num_nodes: forwarded to ``CreateMeshes``.
      num_gpus: forwarded to ``CreateMeshes``. Also used with
        ``utils.RoundUp`` to adjust the class count for strategies 2 and 3.
      args: forwarded to ``CreateMeshes``. ``args.strategy`` (0-3) selects
        how the fully-connected layers are laid out:
          0: FC output dims get fresh ``utils.RandName()`` names.
          1: FC output dims named 'axis1' / 'axis0' / 'axis1'. pool5 is
             flattened and passed through ``ReplaceMeshWithIndependentAxes``
             with ``meshes[1]``. fc8 is passed through
             ``ReplaceMeshWithDuplicates`` with ``meshes[2]``.
          2: all FC output dims named 'axis0'. The class count is rounded
             with ``num_gpus``. pool5's first dim and each FC output dim are
             renamed randomly.
          3: all FC output dims named 'axis1'. The class count is rounded
             with ``num_gpus // 2``. pool5 is passed through
             ``ReplaceMeshWithConcatSplit`` with ``meshes[1]``. FC output
             dims are renamed randomly.

    Returns:
      ``(graph, mesh_to_impl, mtf_loss)``. ``graph`` and ``mesh_to_impl`` come
      from ``CreateMeshes``; ``mtf_loss`` is the scalar mean softmax
      cross-entropy of the fc8 logits.

    Raises:
      ValueError: if ``args.strategy`` is not 0, 1, 2 or 3.
    """
    num_classes = 1000
    keep_prob = 0.5

    graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
        img, labels, num_nodes, num_gpus, args)

    def RenameFC(x):
        # Give the last (output-unit) dimension a fresh random name.
        return mt.rename_dimension(x, x.shape[-1].name, utils.RandName())

    strategy = args.strategy
    if strategy == 0:
        fc6_units = mtf.Dimension(utils.RandName(), 4096)
        fc7_units = mtf.Dimension(utils.RandName(), 4096)
        fc8_units = mtf.Dimension(utils.RandName(), num_classes)
    elif strategy == 1:
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)
    elif strategy == 2:
        num_classes = utils.RoundUp(num_classes, num_gpus)
        fc6_units = mtf.Dimension('axis0', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis0', num_classes)
    elif strategy == 3:
        num_classes = utils.RoundUp(num_classes, num_gpus // 2)
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis1', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)
    else:
        # Without this, the fc*_units names stay unbound and the failure only
        # surfaces later as an UnboundLocalError while building fc6.
        raise ValueError('Unsupported strategy: {}'.format(strategy))

    with tf.variable_scope('alexnet'):
        # Conv1 + ReLU + maxpool1
        conv1 = mt.Conv2d(mtf_img, GetFilterShape(mtf_img, (11, 11, 3, 96)),
                          (4, 4), 'VALID', activation=mtf.relu, name='conv1')
        pool1 = mt.MaxPool(conv1, (3, 3), (2, 2), 'VALID', name='pool1')

        # Conv2 + ReLU + maxpool2
        conv2 = mt.Conv2d(pool1, GetFilterShape(pool1, (5, 5, 96, 256)),
                          (1, 1), 'SAME', activation=mtf.relu, name='conv2')
        pool2 = mt.MaxPool(conv2, (3, 3), (2, 2), name='pool2')

        # Conv3 + ReLU
        conv3 = mt.Conv2d(pool2, GetFilterShape(pool2, (3, 3, 256, 384)),
                          padding='SAME', activation=mtf.relu, name='conv3')

        # Conv4 + ReLU
        conv4 = mt.Conv2d(conv3, GetFilterShape(conv3, (3, 3, 384, 384)),
                          padding='SAME', activation=mtf.relu, name='conv4')

        # Conv5 + ReLU + maxpool5
        conv5 = mt.Conv2d(conv4, GetFilterShape(conv4, (3, 3, 384, 256)),
                          padding='SAME', activation=mtf.relu, name='conv5')
        pool5 = mt.MaxPool(conv5, (3, 3), (2, 2), name='pool5')

        # Rename / re-home pool5 according to the strategy.
        if strategy == 1:
            k_dim = mtf.Dimension(utils.RandName(),
                                  utils.Prod(pool5.shape.to_integer_list[1:]))
            pool5 = mtf.reshape(pool5, mtf.Shape([pool5.shape[0], k_dim]))
            pool5 = ReplaceMeshWithIndependentAxes(
                pool5, meshes[1], (utils.RandName(), 'axis0'))
        elif strategy == 2:
            pool5 = mt.rename_dimension(pool5, pool5.shape[0].name,
                                        utils.RandName())
        elif strategy == 3:
            assert pool5.shape[0].name == 'axis0'
            pool5 = ReplaceMeshWithConcatSplit(pool5, meshes[1])

        # FC + ReLU + dropout
        def fc_activation(x):
            return mtf.dropout(mtf.relu(x), keep_prob)

        fc6 = mtf.layers.dense(pool5, fc6_units, activation=fc_activation,
                               reduced_dims=pool5.shape[1:], name='fc6')
        if strategy in (2, 3):
            fc6 = RenameFC(fc6)

        fc7 = mtf.layers.dense(fc6, fc7_units, activation=fc_activation,
                               reduced_dims=fc6.shape.dims[-1:], name='fc7')
        if strategy in (2, 3):
            fc7 = RenameFC(fc7)

        fc8 = mtf.layers.dense(fc7, fc8_units,
                               reduced_dims=fc7.shape.dims[-1:], name='fc8')
        # NOTE(review): dropout is also applied to the fc8 logits, which is
        # unusual for AlexNet (normally only fc6/fc7); kept as-is -- confirm.
        fc8 = mtf.dropout(fc8, keep_prob)

        if strategy == 1:
            assert fc8.shape[-1].name == 'axis1'
            fc8 = ReplaceMeshWithDuplicates(fc8, meshes[2])

    with tf.variable_scope('loss'):
        # Align fc8's batch dim name with the labels' before computing the loss.
        if fc8.shape[0] != mtf_labels.shape[0]:
            fc8 = mt.rename_dimension(fc8, fc8.shape[0].name,
                                      mtf_labels.shape[0].name)
        one_hot_labels = mtf.one_hot(mtf_labels, fc8.shape[-1])
        mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            fc8, one_hot_labels, fc8.shape[-1])
        mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    return graph, mesh_to_impl, mtf_loss
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck block variant for residual networks with BN after convolutions.

    Kernels are created explicitly with ``mtf.get_variable`` and applied with
    the op-level ``mtf.conv2d_with_blocks``.

    Args:
      inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
      filters: `int` number of filters for the first and third (1x1)
        convolutions. The middle 3x3 convolution uses `4 * filters`.
      is_training: `bool` for whether the model is in training mode.
      strides: stride of the third (1x1) convolution. Either an `int`, expanded
        to `[1, strides, strides, 1]`, or a 4-element list in the
        `[1, stride_h, stride_w, 1]` form `mtf.conv2d_with_blocks` takes. If
        greater than 1, this block downsamples the input.
      projection_shortcut: `function` called as
        `projection_shortcut(inputs, kernel)`, where `kernel` is a 1x1 variable
        of shape `[1, 1, in_channels, filtersp(filters)]`; typically a 1x1
        convolution that matches the filter dimensions. If None, `inputs` is
        passed unchanged through the shortcut connection.
      row_blocks_dim: a mtf.Dimension, row dimension which is spatially
        partitioned along mesh axis
      col_blocks_dim: a mtf.Dimension, column dimension which is spatially
        partitioned along mesh axis

    Returns:
      The output `Tensor` of the block.
    """
    # mtf.conv2d_with_blocks expects strides as [1, stride_h, stride_w, 1];
    # expand the documented int form so it is not passed through raw.
    if isinstance(strides, int):
        strides = [1, strides, strides, 1]

    shortcut = inputs

    filter_h_dim = mtf.Dimension("filter_height", 3)
    filter_w_dim = mtf.Dimension("filter_width", 3)
    # 1x1 kernels reuse the same dimension names with size 1.
    one_h_dim = mtf.Dimension("filter_height", 1)
    one_w_dim = mtf.Dimension("filter_width", 1)

    if projection_shortcut is not None:
        filters_dim = mtf.Dimension("filtersp", filters)
        kernel = mtf.get_variable(
            inputs.mesh, "kernel",
            mtf.Shape([one_h_dim, one_w_dim, inputs.shape.dims[-1],
                       filters_dim]))
        shortcut = projection_shortcut(inputs, kernel)

    # First conv block: 1x1, `filters` output channels.
    filters1_dim = mtf.Dimension("filters1", filters)
    kernel1 = mtf.get_variable(
        inputs.mesh, "kernel1",
        mtf.Shape([one_h_dim, one_w_dim, inputs.shape.dims[-1],
                   filters1_dim]))
    inputs = mtf.conv2d_with_blocks(
        inputs, kernel1, strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)
    # TODO(nikip): Add Dropout?
    inputs = batch_norm_relu(inputs, is_training)

    # Second conv block: 3x3, `4 * filters` output channels.
    # NOTE(review): a standard ResNet bottleneck widens the *last* 1x1 conv
    # instead; kept as-is because the residual add below relies on the final
    # conv having `filters` channels.
    filters2_dim = mtf.Dimension("filters2", 4 * filters)
    kernel2 = mtf.get_variable(
        inputs.mesh, "kernel2",
        mtf.Shape([filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
    inputs = mtf.conv2d_with_blocks(
        inputs, kernel2, strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)
    inputs = batch_norm_relu(inputs, is_training)

    # Third wide conv filter block: 1x1 back to `filters` channels; applies
    # `strides`.
    filters3_dim = mtf.Dimension("filters3", filters)
    filters3_kernel = mtf.get_variable(
        inputs.mesh, "wide_kernel",
        mtf.Shape([one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
    inputs = mtf.conv2d_with_blocks(
        inputs, filters3_kernel, strides, padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

    # TODO(nikip): Althought the original resnet code has this batch norm, in our
    # setup this is causing no gradients to be passed. Investigate further.
    # inputs = batch_norm_relu(inputs, is_training, relu=True)

    # TODO(nikip): Maybe add residual with a projection?
    # Give the output channel dim the shortcut's channel dim name so the
    # residual add lines the two tensors up.
    return mtf.relu(
        shortcut + mtf.rename_dimension(
            inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
def backbone(x, layerlist, chalist, strilist, classes_dim, blocklist,
             float16=None, batch_norm=False):
    """Build a ResNet-style classifier from a stem plus configurable stages.

    Stem: 7x7 stride-2 conv (64 filters) -> [BN] -> ReLU -> 2x2 max pool.
    Then, for each ``(layer, channel, strides)`` in
    ``zip(layerlist, chalist, strilist)``, the stage is:
      * one ``blocklist[0]`` block with stride ``(strides, strides)``, then
      * ``layer - 1`` ``blocklist[1]`` blocks with stride ``(1, 1)``.
    All blocks in a stage use ``out_channels=channel``.  The head is an
    average pool whose window covers dims 1 and 2, a flatten, and a dense
    layer to ``classes_dim``.

    Args:
      x: input mtf.Tensor. The pooling/flatten code indexes dims 0-3, so it
        presumably is [batch, height, width, channels] -- TODO confirm.
      layerlist: number of blocks in each stage.
      chalist: output channels of each stage.
      strilist: int stride of the first block of each stage.
      classes_dim: mtf.Dimension of the output logits.
      blocklist: ``(first_block, repeat_block)`` builders. Both are called with
        ``order``, ``out_channels``, ``strides``, ``float16`` and
        ``batch_norm`` keyword arguments.
      float16: forwarded as ``variable_dtype`` to the stem conv and the final
        dense layer, and as ``float16`` to every block.
      batch_norm: if True, apply batch norm after the stem conv; also
        forwarded to every block.

    Returns:
      The logits mtf.Tensor produced by the final dense layer.
    """
    name = "backbone"

    def _log(tensor):
        # Debug trace of the tensor just produced.
        logger.debug("[output tensor] (name,shape):({},{})".format(
            tensor.name, tensor.shape))

    # Stem.
    x = conv2d(x,
               output_dim=mtf.Dimension(name=name + '-' + 'filters', size=64),
               filter_size=(7, 7),
               strides=(2, 2),
               name="conv7x7_backbone",
               variable_dtype=float16)
    _log(x)
    if batch_norm:
        x, _ = mtf.layers.batch_norm(x, is_training=True, momentum=0.99,
                                     epsilon=1e-5, name="batch_norm_backbone")
        _log(x)
    x = mtf.relu(x, name="relu_backbone")
    _log(x)
    x = mtf.layers.max_pool2d(x, ksize=(2, 2), name="maxpool_backbone")
    _log(x)

    # Residual stages. `order` keeps block names unique: first blocks are
    # numbered per stage, repeat blocks are numbered across all stages.
    shortcuttype1 = 0
    shortcuttype2 = 0
    for layer, channel, strides in zip(layerlist, chalist, strilist):
        x = blocklist[0](x, order=shortcuttype1, out_channels=channel,
                         strides=(strides, strides), float16=float16,
                         batch_norm=batch_norm)
        shortcuttype1 += 1
        for _ in range(layer - 1):
            x = blocklist[1](x, order=shortcuttype2, out_channels=channel,
                             strides=(1, 1), float16=float16,
                             batch_norm=batch_norm)
            shortcuttype2 += 1

    # Head: average pool over the full extent of dims 1 and 2, flatten, dense.
    x = mtf.layers.avg_pool2d(
        x, ksize=(x.shape.dims[1].size, x.shape.dims[2].size))
    _log(x)
    x = mtf.reshape(
        x,
        new_shape=[
            x.shape.dims[0],
            mtf.Dimension(name="flatten",
                          size=x.shape.dims[1].size * x.shape.dims[2].size *
                          x.shape.dims[3].size),
        ],
        name="flatten")
    _log(x)
    logit = mtf.layers.dense(x, classes_dim, name="dense_backbone",
                             variable_dtype=float16)
    _log(logit)
    return logit