def testDenseReluDense(self):
  """dense_relu_dense on a placement mesh preserves the input shape."""
  batch = 2
  channels = 3
  hidden = 5
  inputs = tf.random_normal([batch, channels])

  # Build the mtf graph: import the tf tensor and apply the layer.
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  hidden_dim = mtf.Dimension("hidden", hidden)
  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf_layers.dense_relu_dense(
      mtf_inputs, hidden_channels=hidden_dim)

  # Lower onto a trivial single-device placement mesh and export the
  # result back to a plain tf tensor.
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, inputs.shape)
def _feedforward_layer(self, x, losses=None):
  """Feed-forward layer.

  Args:
    x: a mtf.Tensor with shape [batch_dim, length_dim, model_dim]
    losses: a list to be appended-to

  Returns:
    a mtf.Tensor with shape [batch_dim, length_dim, model_dim]

  Raises:
    ValueError: if hparams make no sense
  """
  hparams = self._hparams
  layer_type = hparams.feedforward_layer
  if layer_type == "dense_relu_dense":
    return mtf_layers.dense_relu_dense(
        x, self.feedforward_dim, dropout=hparams.relu_dropout,
        dropout_broadcast_dims=[self.length_dim])
  if layer_type == "moe":
    # Training runs tolerate more expert-capacity overhead than eval.
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      overhead = hparams.moe_overhead_train
    else:
      overhead = hparams.moe_overhead_eval
    output, loss = mtf_layers.moe_v0(
        x, self.feedforward_dim, self.model_dim, self.experts_dim,
        loss_coef=hparams.moe_loss_coef, overhead=overhead)
    # Record the auxiliary load-balancing loss for the caller.
    if losses is not None:
      losses.append(loss)
    return output
  raise ValueError(
      "hparams.feedforward_layer not recognized %s" % layer_type)
def _feedforward_layer(self, x, losses=None):
  """Feed-forward layer.

  Args:
    x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    losses: an optional list; for the "moe"/"hmoe" layer types the
      auxiliary load-balancing loss is appended to it.

  Returns:
    a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]

  Raises:
    ValueError: if hparams make no sense
  """
  hparams = self._hparams
  feedforward_layer = hparams.feedforward_layer
  if feedforward_layer == "dense_relu_dense":
    return mtf_layers.dense_relu_dense(
        x, self.feedforward_dim, dropout=hparams.relu_dropout,
        dropout_broadcast_dims=[self.length_dim])
  train = hparams.mode == tf.estimator.ModeKeys.TRAIN
  if feedforward_layer == "moe":
    output, loss = moe.transformer_moe_layer_v1(
        x, self.model_dim, hparams, train)
  elif feedforward_layer == "hmoe":
    output, loss = moe.transformer_moe_layer_v2(
        x, self.model_dim, hparams, train)
  else:
    raise ValueError(
        "hparams.feedforward_layer not recognized %s" % feedforward_layer)
  # Bug fix: previously only the "hmoe" branch appended the auxiliary
  # loss and returned the output; the "moe" branch fell through and the
  # function implicitly returned None, dropping both output and loss.
  if losses is not None:
    losses.append(loss)
  return output
def mtf_model_fn(self, features, mesh):
  """Builds the image-transformer decoder graph in Mesh TensorFlow.

  Embeds (shifted) image targets, runs a stack of
  [local self-attention -> ffn] decoder layers with residuals and
  dropout, and computes per-position softmax cross-entropy loss.

  Args:
    features: a dict of tf.Tensors; must contain "targets", and
      "inputs" when the model is conditional.
    mesh: a mtf.Mesh.

  Returns:
    logits: a mtf.Tensor of logits over the 256 output intensities.
    loss: a scalar mtf.Tensor (mean cross-entropy plus extra losses).
  """
  features = copy.copy(features)
  tf.logging.info("features = %s" % features)
  hparams = self._hparams
  activation_dtype = self.set_activation_type()

  # Image preprocessing: flatten the image into a 1D sequence and shift
  # right so each position is predicted from strictly earlier positions.
  targets = tf.to_int32(features["targets"])
  length = hparams.img_len * hparams.img_len * hparams.num_channels
  targets = tf.reshape(targets, [hparams.batch_size, length])
  shifted_targets = common_layers.shift_right_2d(targets)

  # Declare all the dimensions.
  model_dim = mtf.Dimension("d_model", hparams.hidden_size)
  batch_dim = mtf.Dimension("batch", hparams.batch_size)
  length_dim = mtf.Dimension("length", length)
  max_length_dim = mtf.Dimension("max_length", hparams.max_length)
  filter_dim = mtf.Dimension("d_ff", hparams.d_ff)
  kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
  heads = mtf.Dimension("heads", hparams.num_heads)

  def import_to_batch_by_length(x, name):
    # Lift a [batch, length] tf.Tensor onto the mesh.
    return mtf.import_tf_tensor(
        mesh, x, mtf.Shape([batch_dim, length_dim]), name=name)

  def layer_prepostprocess_dropout(x):
    # Dropout noise is shared across the length dimension
    # (noise_shape omits length_dim).
    return mtf.dropout(
        x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape([batch_dim, model_dim]))

  targets = import_to_batch_by_length(targets, "targets")
  shifted_targets = import_to_batch_by_length(
      shifted_targets, "shifted_targets")
  extra_losses = []

  # Create targets content and position embeddings.
  # One embedding row per (channel, intensity) pair; note the earlier
  # unused read of self._problem_hparams.target_modality._vocab_size was
  # dead code and has been removed.
  targets_vocab_size = 256 * hparams.num_channels
  targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
  outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

  # Create embedding var for targets and positions and do a gather.
  targets_embedding_var = mtf.get_variable(
      mesh, "targets_embedding",
      mtf.Shape([targets_vocab_dim, model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=activation_dtype)

  x = mtf.gather(targets_embedding_var, shifted_targets, targets_vocab_dim)

  # Add positional embeddings (broadcast over the batch dimension).
  x += mtf.reshape(
      self.create_positional_emb_2d(targets, max_length_dim, model_dim),
      [length_dim, model_dim])

  # If conditional and input is given, add the input embedding to the
  # target. TODO(nikip): Verify conditional.
  if self.has_input and not hparams.unconditional:
    vocab_size = hparams.num_classes
    inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = import_to_batch_by_length(inputs, "inputs")

    # Input embeddings.
    inputs_embedding_var = mtf_layers.embedding(
        mesh, "input_embedding",
        mtf.Shape([inputs_vocab_dim, model_dim]),
        activation_dtype=activation_dtype)
    inputs_emb = mtf.gather(
        inputs_embedding_var, inputs, inputs_vocab_dim)
    x += inputs_emb

  # Image Transformer Decoder
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer (masked, local within block_length).
      x += layer_prepostprocess_dropout(
          mtf_layers.masked_local_attention_1d(
              mtf_layers.layer_norm(
                  x, model_dim, name="layer_norm_self_att"),
              None,
              kv_channels,
              heads,
              block_length=hparams.block_length,
              name="self_att"))
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf_layers.dense_relu_dense(
              mtf_layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              filter_dim,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]))

  x = mtf_layers.layer_norm(x, model_dim, name="decoder_final_layer_norm")

  # Calculate the logits and loss.
  logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
  soft_targets = mtf.one_hot(
      targets, outputs_vocab_dim, dtype=activation_dtype)
  loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, outputs_vocab_dim)
  loss = mtf.reduce_mean(loss)
  for l in extra_losses:
    loss += l
  return logits, loss