def mnist_model(image, labels, mesh):
  """Two-hidden-layer MLP classifier for MNIST, built as a mesh-tensorflow graph.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
    loss is skipped when labels is None (inference mode).
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when labels is None
  """
  # Named dimensions of the computation graph.
  dim_batch = mtf.Dimension("batch", FLAGS.batch_size)
  dim_rows = mtf.Dimension("rows", 28)
  dim_cols = mtf.Dimension("cols", 28)
  dim_classes = mtf.Dimension("classes", 10)
  dim_hidden1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  dim_hidden2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

  # Bring the flat image into the mesh as a [batch, rows, cols] tensor.
  net = mtf.import_tf_tensor(
      mesh,
      tf.reshape(image, [-1, 28, 28]),
      mtf.Shape([dim_batch, dim_rows, dim_cols]))

  # Two ReLU layers (the first collapses the spatial dims), then a linear
  # projection to per-class logits.
  net = mtf_layers.dense(
      net, dim_hidden1,
      reduced_dims=[dim_rows, dim_cols],
      activation=mtf.relu,
      name="hidden1")
  net = mtf_layers.dense(net, dim_hidden2, activation=mtf.relu, name="hidden2")
  logits = mtf_layers.dense(net, dim_classes, name="logits")

  if labels is None:
    return logits, None

  # Softmax cross-entropy against one-hot targets, averaged over the batch.
  mtf_labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([dim_batch]))
  loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(mtf_labels, dim_classes), dim_classes)
  loss = mtf.reduce_mean(loss)
  return logits, loss
def mnist_model(image, labels, mesh):
  """MNIST classifier with two conv layers followed by an MLP head.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when labels is None
  """
  # Named dimensions for the graph.
  dim_batch = mtf.Dimension("batch", FLAGS.batch_size)
  dim_rows = mtf.Dimension("rows", 28)
  dim_cols = mtf.Dimension("cols", 28)
  dim_classes = mtf.Dimension("classes", 10)
  dim_channel = mtf.Dimension("one_channel", 1)

  # Import the image as [batch, rows, cols], then append a channel axis
  # so it can be fed to conv2d.
  net = mtf.import_tf_tensor(
      mesh,
      tf.reshape(image, [-1, 28, 28]),
      mtf.Shape([dim_batch, dim_rows, dim_cols]))
  net = mtf.reshape(net, [dim_batch, dim_rows, dim_cols, dim_channel])

  # add some convolutional layers to demonstrate that convolution works.
  # TODO(noam): get spatially-partitioned convolution working.
  dim_fh = mtf.Dimension("fh", 3)
  dim_fw = mtf.Dimension("fw", 3)
  dim_filters1 = mtf.Dimension("filters1", 32)
  dim_filters2 = mtf.Dimension("filters2", 32)
  conv_kernel1 = mtf.get_variable(
      mesh, "kernel1", [dim_fh, dim_fw, dim_channel, dim_filters1])
  conv_kernel2 = mtf.get_variable(
      mesh, "kernel2", [dim_fh, dim_fw, dim_filters1, dim_filters2])
  net = mtf.relu(mtf.conv2d(net, conv_kernel1))
  net = mtf.relu(mtf.conv2d(net, conv_kernel2))
  # Average away the filter dimension, leaving [batch, rows, cols].
  net = mtf.reduce_mean(net, reduced_dim=dim_filters2)

  # add some fully-connected dense layers.
  dim_hidden1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  dim_hidden2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
  net = mtf_layers.dense(
      net, dim_hidden1,
      reduced_dims=[dim_rows, dim_cols],
      activation=mtf.relu,
      name="hidden1")
  net = mtf_layers.dense(net, dim_hidden2, activation=mtf.relu, name="hidden2")
  logits = mtf_layers.dense(net, dim_classes, name="logits")

  if labels is None:
    return logits, None

  # Softmax cross-entropy against one-hot targets, averaged over the batch.
  mtf_labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([dim_batch]))
  loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(mtf_labels, dim_classes), dim_classes)
  loss = mtf.reduce_mean(loss)
  return logits, loss
def _mtf_model_fn(self, features, mesh):
  """Builds the transformer (encoder-)decoder graph and training loss.

  Args:
    features: a features dict; must contain "targets", and optionally
      "inputs" plus the packed-dataset keys ("*_segmentation", "*_position").
    mesh: a mtf.Mesh on which all mtf tensors are created.

  Returns:
    logits: a mtf.Tensor of logits over the target vocabulary.
    loss: a scalar mtf.Tensor (weighted mean cross-entropy plus any
      extra losses accumulated by the layer stacks).
  """
  # Shallow copy so that the pad-to-max-length rewrites below don't
  # mutate the caller's dict.
  features = copy.copy(features)
  hparams = self._hparams
  targets = tf.to_int32(features["targets"])
  # Targets may arrive as 4-D [batch, length, 1, 1]; flatten to 2-D.
  if len(targets.get_shape()) > 2:
    tf.logging.info("targets = %s" % targets)
    targets = tf.squeeze(targets, [2, 3])

  # pad targets to max_length
  def pad_to_max_length(x):
    # NOTE(review): assumes x is [batch, length] with
    # length <= hparams.max_length — negative padding would fail.
    extra_length = hparams.max_length - tf.shape(x)[1]
    x = tf.pad(x, [[0, 0], [0, extra_length]])
    # Reshape pins the static shape to [batch_size, max_length].
    x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
    return x

  targets = pad_to_max_length(targets)
  # Packed-dataset bookkeeping tensors must be padded the same way.
  for key in ["targets_segmentation", "targets_position",
              "inputs_segmentation", "inputs_position"]:
    if key in features:
      features[key] = pad_to_max_length(features[key])
  # Decoder input is the target sequence shifted right by one position.
  shifted_targets = common_layers.shift_right_2d(targets)

  targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
  shifted_targets = self._import_to_batch_by_length(
      shifted_targets, "shifted_targets", mesh, hparams)

  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = self._import_to_batch_by_length(
        features["targets_segmentation"], "targets_segmentation",
        mesh, hparams)
    targets_position = self._import_to_batch_by_length(
        features["targets_position"], "targets_position",
        mesh, hparams)
    # Autoregressive mask plus a cross-segment mask so attention stays
    # within each packed example.
    decoder_self_attention_mask = (
        mtf_layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype) +
        mtf_layers.attention_mask_same_segment(
            targets_segmentation, dtype=self.activation_dtype))
  else:
    targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
    decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
        targets_position, dtype=self.activation_dtype)

  def layer_prepostprocess_dropout(x):
    # Dropout noise is broadcast over the length dimension (noise_shape
    # covers only batch and model dims).
    return mtf.dropout(
        x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

  extra_losses = []
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)

  if self.has_input:
    # ENCODER: embed inputs, run the encoder layer stack, and build the
    # masks used by encoder self-attention and decoder cross-attention.
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = pad_to_max_length(inputs)
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = self._import_to_batch_by_length(
          features["inputs_segmentation"], "inputs_segmentation",
          mesh, hparams)
      inputs_position = self._import_to_batch_by_length(
          features["inputs_position"], "inputs_position",
          mesh, hparams)
      encoder_self_attention_mask = (
          mtf_layers.attention_mask_same_segment(
              inputs_segmentation, dtype=self.activation_dtype))
      encoder_decoder_attention_mask = (
          mtf_layers.attention_mask_same_segment(
              targets_segmentation, inputs_segmentation,
              dtype=self.activation_dtype))
    else:
      inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      # Mask out padding positions in the inputs.
      encoder_self_attention_mask = (
          mtf_layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      encoder_decoder_attention_mask = encoder_self_attention_mask

    # Token embedding + (gathered) positional embedding.
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.gather(positional_embedding_var, inputs_position,
                    self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.num_encoder_layers,
                            self_attention_mask=encoder_self_attention_mask,
                            losses=extra_losses)
    # Rename length -> memory_length so decoder cross-attention can
    # distinguish query positions from memory positions.
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
  else:
    encoder_output = None
    encoder_decoder_attention_mask = None

  # DECODER
  x = (mtf.gather(
      targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
       mtf.gather(
           positional_embedding_var, targets_position, self.max_length_dim))
  x = layer_prepostprocess_dropout(x)

  # Decoder
  with tf.variable_scope("decoder"):
    x = self._layer_stack(
        x,
        hparams.num_decoder_layers,
        encoder_output=encoder_output,
        self_attention_mask=decoder_self_attention_mask,
        encdec_attention_mask=encoder_decoder_attention_mask,
        losses=extra_losses)
  # Output projection shares softmax_var produced alongside the embeddings.
  logits = mtf.matmul(x, softmax_var)

  # Label smoothing: distribute hparams.label_smoothing mass uniformly
  # over the vocabulary, keeping the rest on the true class.
  off_value = hparams.label_smoothing / self._targets_vocab_size
  on_value = 1.0 - hparams.label_smoothing + off_value
  soft_targets = mtf.one_hot(
      targets, self.targets_vocab_dim, on_value=on_value,
      off_value=off_value, dtype=self.activation_dtype)
  loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self.targets_vocab_dim)
  # Ignore loss contributions from padding (zero) target positions.
  weights = mtf_layers.weights_nonzero(
      targets, dtype=self.activation_dtype)
  loss = mtf.reduce_mean(loss * weights)
  for l in extra_losses:
    loss += l
  return logits, loss
def mtf_model_fn(self, features, mesh):
  """Builds a blocked-convolution ResNet-style classifier graph.

  Args:
    features: a features dict with "inputs" (image tensor) and "targets"
      (class labels, presumably shaped [batch, 1, 1] — TODO confirm from
      the squeeze below).
    mesh: a mtf.Mesh on which all mtf tensors are created.

  Returns:
    logits: a mtf.Tensor with shape [batch, one_channel, classes].
    loss: a scalar mtf.Tensor (mean softmax cross-entropy).
  """
  # Shallow copy to avoid mutating the caller's dict.
  features = copy.copy(features)
  tf.logging.info("features = %s" % features)
  hparams = self._hparams
  activation_dtype = self.set_activation_type()
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  # Declare all the dimensions
  batch_dim = mtf.Dimension("batch", hparams.batch_size)
  hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
  filter_h_dim = mtf.Dimension("filter_height", 7)
  filter_w_dim = mtf.Dimension("filter_width", 7)
  filters = mtf.Dimension("filters", hparams.filter_sizes[0])
  # NOTE(review): rows/cols sizes are hard-coded to 32x96 here while the
  # reshape below uses hparams.rows_size/cols_size — confirm they agree.
  rows_dim = mtf.Dimension("rows_size", 32)
  cols_dim = mtf.Dimension("cols_size", 96)
  row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
  col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 1)

  inputs = features["inputs"]
  # Split the image into a grid of row/col blocks:
  # [batch, row_blocks, rows, col_blocks, cols, 1].
  x = mtf.import_tf_tensor(
      mesh,
      tf.reshape(inputs, [
          hparams.batch_size,
          hparams.row_blocks,
          hparams.rows_size // hparams.row_blocks,
          hparams.col_blocks,
          hparams.num_channels * hparams.cols_size // hparams.col_blocks,
          1
      ]),
      mtf.Shape([
          batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
          one_channel_dim
      ]))
  # Move both block dims ahead of the spatial dims.
  x = mtf.transpose(x, [
      batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
      one_channel_dim
  ])
  x = mtf.to_float(x)

  # Initial 7x7 convolution (blocked along the column-blocks dim only).
  initial_filters = mtf.get_variable(
      mesh, "init_filters",
      mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
  x = mtf.conv2d_with_blocks(
      x,
      initial_filters,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=None,
      w_blocks_dim=col_blocks_dim)
  x = batch_norm_relu(x, is_training)

  # Conv blocks
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_layers):
    layer_name = "block_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Residual block layer
      x = block_layer(
          inputs=x,
          filters=hparams.filter_sizes[0],
          blocks=hparams.layer_sizes[0],
          strides=[1, 1, 1, 1],
          is_training=is_training,
          name="block_layer1",
          row_blocks_dim=None,
          col_blocks_dim=None)
      x = block_layer(
          inputs=x,
          filters=hparams.filter_sizes[1],
          blocks=hparams.layer_sizes[1],
          strides=[1, 2, 2, 1],
          is_training=is_training,
          name="block_layer2",
          row_blocks_dim=None,
          col_blocks_dim=None)
      x = block_layer(
          inputs=x,
          filters=hparams.filter_sizes[2],
          blocks=hparams.layer_sizes[2],
          strides=[1, 2, 2, 1],
          is_training=is_training,
          name="block_layer3",
          row_blocks_dim=None,
          col_blocks_dim=None)

  # Calculate the logits and loss.
  out = x
  # Collapse the last five dims (blocks, spatial, filters) into one
  # hidden feature vector per example.
  outputs = mtf_layers.dense(
      out, hidden_dim,
      reduced_dims=out.shape.dims[-5:],
      activation=mtf.relu,
      name="dense")

  # We assume fixed vocab size for targets
  labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
  labels = mtf.import_tf_tensor(
      mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

  logits = mtf_layers.dense(outputs, classes_dim, name="logits")
  soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
  loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, classes_dim)

  # Reshape logits so it doesn't break inside t2t.
  logits = mtf.reshape(
      logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
  loss = mtf.reduce_mean(loss)
  return logits, loss
def mtf_model_fn(self, features, mesh):
  """Builds an Image-Transformer decoder graph over flattened pixels.

  Args:
    features: a features dict with "targets" (image pixels) and
      optionally "inputs" (class labels for conditional generation).
    mesh: a mtf.Mesh on which all mtf tensors are created.

  Returns:
    logits: a mtf.Tensor of logits over the 256-way output vocabulary.
    loss: a scalar mtf.Tensor (mean cross-entropy plus extra losses).
  """
  # Shallow copy to avoid mutating the caller's dict.
  features = copy.copy(features)
  tf.logging.info("features = %s" % features)
  hparams = self._hparams
  activation_dtype = self.set_activation_type()

  # We assume fixed vocab size for targets
  # NOTE(review): this value is overwritten below by
  # 256 * hparams.num_channels — this first assignment looks dead; verify.
  targets_vocab_size = self._problem_hparams.target_modality._vocab_size  # pylint: disable=protected-access
  targets = tf.to_int32(features["targets"])

  # Image preprocessing, reshape into a 1D sequence and shift right.
  length = hparams.img_len * hparams.img_len * hparams.num_channels
  targets = tf.reshape(targets, [hparams.batch_size, length])
  # Decoder input is the pixel sequence shifted right by one position.
  shifted_targets = common_layers.shift_right_2d(targets)

  # Declare all the dimensions
  model_dim = mtf.Dimension("d_model", hparams.hidden_size)
  batch_dim = mtf.Dimension("batch", hparams.batch_size)
  length_dim = mtf.Dimension("length", length)
  max_length_dim = mtf.Dimension("max_length", hparams.max_length)
  filter_dim = mtf.Dimension("d_ff", hparams.d_ff)
  kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
  heads = mtf.Dimension("heads", hparams.num_heads)

  def import_to_batch_by_length(x, name):
    # Import a [batch, length] tf.Tensor into the mesh.
    return mtf.import_tf_tensor(
        mesh, x, mtf.Shape([batch_dim, length_dim]), name=name)

  def layer_prepostprocess_dropout(x):
    # Dropout noise broadcast over the length dimension.
    return mtf.dropout(
        x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape([batch_dim, model_dim]))

  targets = import_to_batch_by_length(targets, "targets")
  shifted_targets = import_to_batch_by_length(
      shifted_targets, "shifted_targets")

  extra_losses = []

  # Create targets content and position embeddings.
  targets_vocab_size = 256 * hparams.num_channels
  targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
  outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

  # Create embedding var for targets and positions and do a gather.
  targets_embedding_var = mtf.get_variable(
      mesh, "targets_embedding",
      mtf.Shape([targets_vocab_dim, model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=activation_dtype)

  x = mtf.gather(targets_embedding_var, shifted_targets, targets_vocab_dim)

  # Add positional embeddings
  x += mtf.reshape(
      self.create_positional_emb_2d(targets, max_length_dim, model_dim),
      [length_dim, model_dim])

  # If conditional and input is given, add the input embedding to the target.
  # TODO(nikip): Verify conditional.
  if self.has_input and not hparams.unconditional:
    vocab_size = hparams.num_classes
    inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = import_to_batch_by_length(inputs, "inputs")

    # Input embeddings
    inputs_embedding_var = mtf_layers.embedding(
        mesh, "input_embedding",
        mtf.Shape([inputs_vocab_dim, model_dim]),
        activation_dtype=activation_dtype)
    inputs_emb = mtf.gather(
        inputs_embedding_var, inputs, inputs_vocab_dim)
    x += inputs_emb

  # Image Transformer Decoder
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf_layers.masked_local_attention_1d(
              mtf_layers.layer_norm(x, model_dim,
                                    name="layer_norm_self_att"),
              None,
              kv_channels,
              heads,
              block_length=hparams.block_length,
              name="self_att"))
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf_layers.dense_relu_dense(
              mtf_layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              filter_dim,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]))

  x = mtf_layers.layer_norm(x, model_dim, name="decoder_final_layer_norm")

  # Calculate the logits and loss.
  logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
  soft_targets = mtf.one_hot(
      targets, outputs_vocab_dim, dtype=activation_dtype)
  loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, outputs_vocab_dim)
  loss = mtf.reduce_mean(loss)
  for l in extra_losses:
    loss += l
  return logits, loss
def mnist_model(image, labels, mesh):
  """MNIST classifier using spatially-blocked convolutions plus an MLP head.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when labels is None
  """
  # Named dimensions: the 28x28 image is carved into a 4x4 grid of
  # 7x7 blocks so convolution can be partitioned across the mesh.
  dim_batch = mtf.Dimension("batch", FLAGS.batch_size)
  dim_row_blocks = mtf.Dimension("row_blocks", 4)
  dim_col_blocks = mtf.Dimension("col_blocks", 4)
  dim_rows = mtf.Dimension("rows_size", 7)
  dim_cols = mtf.Dimension("cols_size", 7)
  dim_classes = mtf.Dimension("classes", 10)
  dim_channel = mtf.Dimension("one_channel", 1)

  # Import as [batch, row_blocks, rows, col_blocks, cols, channel], then
  # move both block dims ahead of the per-block spatial dims.
  net = mtf.import_tf_tensor(
      mesh,
      tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
      mtf.Shape([
          dim_batch, dim_row_blocks, dim_rows, dim_col_blocks, dim_cols,
          dim_channel
      ]))
  net = mtf.transpose(net, [
      dim_batch, dim_row_blocks, dim_col_blocks, dim_rows, dim_cols,
      dim_channel
  ])

  # add some convolutional layers to demonstrate that convolution works.
  dim_fh = mtf.Dimension("fh", 9)
  dim_fw = mtf.Dimension("fw", 9)
  dim_filters1 = mtf.Dimension("filters1", 16)
  dim_filters2 = mtf.Dimension("filters2", 16)
  conv_kernel1 = mtf.get_variable(
      mesh, "kernel1", [dim_fh, dim_fw, dim_channel, dim_filters1])
  conv_kernel2 = mtf.get_variable(
      mesh, "kernel2", [dim_fh, dim_fw, dim_filters1, dim_filters2])
  net = mtf.relu(
      mtf.conv2d_with_blocks(
          net, conv_kernel1,
          strides=[1, 1, 1, 1],
          padding="SAME",
          h_blocks_dim=dim_row_blocks,
          w_blocks_dim=dim_col_blocks))
  net = mtf.relu(
      mtf.conv2d_with_blocks(
          net, conv_kernel2,
          strides=[1, 1, 1, 1],
          padding="SAME",
          h_blocks_dim=dim_row_blocks,
          w_blocks_dim=dim_col_blocks))
  # Average away the filter dimension.
  net = mtf.reduce_mean(net, reduced_dim=dim_filters2)

  # add some fully-connected dense layers.
  dim_hidden1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  dim_hidden2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
  # First dense layer collapses the last four dims
  # (block and spatial dims) into one hidden vector per example.
  net = mtf_layers.dense(
      net, dim_hidden1,
      reduced_dims=net.shape.dims[-4:],
      activation=mtf.relu,
      name="hidden1")
  net = mtf_layers.dense(net, dim_hidden2, activation=mtf.relu, name="hidden2")
  logits = mtf_layers.dense(net, dim_classes, name="logits")

  if labels is None:
    return logits, None

  # Softmax cross-entropy against one-hot targets, averaged over the batch.
  mtf_labels = mtf.import_tf_tensor(
      mesh,
      tf.reshape(labels, [FLAGS.batch_size]),
      mtf.Shape([dim_batch]))
  loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(mtf_labels, dim_classes), dim_classes)
  loss = mtf.reduce_mean(loss)
  return logits, loss