def mnist_model(image, labels, mesh):
  """Builds the VGG classifier and, when labels are given, its loss.

  Args:
    image: tf.Tensor that is reshaped to
      [FLAGS.batch_size, image_height, image_width, num_channels].
    labels: a tf.Tensor that is reshaped to [FLAGS.batch_size], or None to
      skip building the loss.
    mesh: a mtf.Mesh

  Returns:
    logits: the output of VGG cast to tf.float32.
    loss: the mean softmax cross-entropy as a mtf.Tensor, or None when
      labels is None.
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  image_dims = [
      batch_dim,
      mtf.Dimension("rows_size", image_height),
      mtf.Dimension("cols_size", image_width),
      mtf.Dimension("image_channel", num_channels),
  ]
  classes_dim = mtf.Dimension(name='classesnum', size=classesnum)

  tf_image = tf.reshape(
      image, [FLAGS.batch_size, image_height, image_width, num_channels])
  x = mtf.import_tf_tensor(mesh, tf_image, mtf.Shape(image_dims))

  logits = mtf.cast(
      VGG(x, classes_dim=classes_dim, depth=depth), dtype=tf.float32)

  # Inference graph: nothing to compare the logits against.
  if labels is None:
    return logits, None

  labels = mtf.import_tf_tensor(
      mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
  per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(labels, classes_dim), classes_dim)
  return logits, mtf.reduce_mean(per_example_loss)
def mnist_model(image, labels, mesh):
  """The model: one hidden ReLU layer followed by a linear classifier.

  Args:
    image: tf.Tensor imported with shape [batch, rows, cols], i.e.
      [100, 28, 28] given the dimensions defined below.
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
  batch_dim = mtf.Dimension("batch", 100)
  rows_dim = mtf.Dimension("rows", 28)
  cols_dim = mtf.Dimension("cols", 28)
  hidden_dim = mtf.Dimension("hidden", 1024)
  classes_dim = mtf.Dimension("classes", 10)

  images = mtf.import_tf_tensor(
      mesh, image, shape=[batch_dim, rows_dim, cols_dim])
  labels = mtf.import_tf_tensor(mesh, labels, [batch_dim])

  w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
  w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])

  # einsum is a generalization of matrix multiplication (see numpy.einsum).
  # The operands are passed as a single list, matching the other model
  # builders; passing them as separate positional arguments would collide
  # with the output_shape keyword.
  hidden = mtf.relu(
      mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]))
  logits = mtf.einsum([hidden, w2], output_shape=[batch_dim, classes_dim])

  loss = mtf.reduce_mean(
      mtf.layers.softmax_cross_entropy_with_logits(
          logits, mtf.one_hot(labels, classes_dim), classes_dim))
  return logits, loss
def call(self, context, x, losses=None):
  """Call the layer.

  Computes q from x and k/v from the memory tensor. In "incremental" mode
  the freshly computed k/v are written into the previously recorded states
  at context.position; in "incremental" and "first_part" modes the resulting
  k/v are recorded as new states.
  """
  incremental = context.mode == "incremental"
  memory_length = self.memory_length(context)
  q = self.compute_q(context, x)

  # The memory tensor is x itself when decoding one step at a time; otherwise
  # it is x with the length dimension swapped for the memory-length one.
  m = x if incremental else mtf.replace_dimensions(
      x, context.length_dim, memory_length)
  k = self.compute_k(context, m)
  v = self.compute_v(context, m)

  if incremental:
    write_mask = mtf.one_hot(
        context.position, memory_length, dtype=context.activation_dtype)
    keep_mask = 1.0 - write_mask
    old_k, old_v = context.get_states(2)
    k = old_k * keep_mask + k * write_mask
    v = old_v * keep_mask + v * write_mask
    memory_position = mtf.range(context.mesh, memory_length, tf.int32)
  else:
    memory_position = self.rename_length_to_memory_length(
        context.position, context)

  if context.mode in ("incremental", "first_part"):
    context.record_new_states([k, v])

  bias = self.compute_bias(
      context, memory_position, x, self.softmax_heads_dims, q)
  return self.attention_internal(
      context, x, m, q, k, v, memory_length, bias)
def mnist_model(image, labels, mesh):
  """The model: a single linear layer followed by a ReLU.

  Args:
    image: tf.Tensor that is reshaped to [FLAGS.batch_size, 28, 28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None to
      build the graph without a loss.
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when labels is None
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  rows_dim = mtf.Dimension("rows", 28)
  cols_dim = mtf.Dimension("cols", 28)
  classes_dim = mtf.Dimension("classes", 10)

  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28]),
      [batch_dim, rows_dim, cols_dim])

  w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, classes_dim])
  b1 = mtf.get_variable(mesh, "b1", [classes_dim])

  logits = mtf.relu(mtf.einsum([x, w1], [batch_dim, classes_dim]) + b1)

  if labels is None:
    loss = None
  else:
    # Import the labels only once we know they exist: reshaping None would
    # fail before the None check could ever be reached.
    y = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), [batch_dim])
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(y, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
def body_fn(position, ids, *states):
  """One step in the decode loop.

  Runs the model in "incremental" mode on the id at `position - 1`, samples
  the next id, and adds it into `ids` at `position`.

  Returns:
    A list [new_position, new_ids] followed by the states recorded during
    this step.
  """
  step_context = Context(
      mesh=inputs.mesh,
      batch_dims=batch_dims,
      length_dim=length_dim,
      model_dim=self.model_dim,
      variable_dtype=variable_dtype,
      mode="incremental",
      autoregressive=self.autoregressive,
      position=position,
      states=states,
      new_states=[],
      sequence_id=sequence_id,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      constant_states=constant_states,
      shared_params=shared_params,
      layout=self.layout,
      mesh_shape=self.mesh_shape,
      encoder_layer_outputs=encoder_layer_outputs)

  # The input for this step is the id at the previous position.
  previous_ids = mtf.gather(ids, position - 1, length_dim)
  with tf.variable_scope(self.name, reuse=True):
    logits = self._call_internal(step_context, previous_ids)
  sampled_ids = mtf.sample_with_temperature(
      logits, self.output_vocab_dim, temperature)

  new_position = position + 1
  position_one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
  new_ids = ids + sampled_ids * position_one_hot
  return [new_position, new_ids] + step_context.new_states
def get_project_to_cluster_length(self, cluster_mask, dtype):
  """Returns projection from length dim to the shorter cluster length dim."""
  seq_length_dim = cluster_mask.shape.get_dim_by_name("length")
  cluster_length_dim = self.get_cluster_length_dim(seq_length_dim)

  mask = mtf.cast(cluster_mask, dtype)
  # Inclusive running sum of the mask along the sequence, minus one.
  running_index = mtf.cumsum(
      mtf.cast(cluster_mask, tf.int32), seq_length_dim) - 1
  index_one_hot = mtf.one_hot(
      running_index, output_dim=cluster_length_dim, dtype=dtype)
  # Multiplying by the mask zeroes the rows of positions not in the cluster.
  return mask * index_one_hot
def get_log_softmax_prefix(self, log_softmax, end_index):
  """Returns first end_index entries in log_softmax along the vocab dim."""
  vocab_dim = self._vocab_dim
  # Same name as the vocab dimension, but only end_index entries long.
  prefix_dim = mtf.Dimension(vocab_dim.name, end_index)

  vocab_positions = mtf.mtf_range(
      log_softmax.mesh, dim=vocab_dim, dtype=tf.int32)
  in_prefix = mtf.less(vocab_positions, end_index)
  # Positions at or beyond end_index are replaced by the index -1.
  prefix_indices = mtf.where(in_prefix, vocab_positions, -1)

  projection = mtf.one_hot(
      prefix_indices, prefix_dim, dtype=log_softmax.dtype)
  return mtf.einsum([log_softmax, projection], reduced_dims=[vocab_dim])
def call(self, context, x, losses=None):
  """Call the layer.

  Supports a single shared key/value projection (self.shared_kv) as well as
  separate k and v projections, and dispatches to hybrid or regular
  attention according to self.attention_func.
  """
  incremental = context.mode == "incremental"
  params = self.make_params(context)
  q = params.compute_q(x)
  memory_length = self.memory_length(context)

  m = x if incremental else mtf.replace_dimensions(
      x, context.length_dim, memory_length)

  if self.shared_kv:
    kv = params.compute_kv(m)
  else:
    k = params.compute_k(m)
    v = params.compute_v(m)

  if incremental:
    write_mask = mtf.one_hot(
        context.position, memory_length, dtype=context.activation_dtype)
    keep_mask = 1.0 - write_mask
    if self.shared_kv:
      old_kv = context.get_states(1)
      kv = old_kv * keep_mask + kv * write_mask
    else:
      old_k, old_v = context.get_states(2)
      k = old_k * keep_mask + k * write_mask
      v = old_v * keep_mask + v * write_mask
    memory_position = mtf.range(context.mesh, memory_length, tf.int32)
  else:
    memory_position = self.rename_length_to_memory_length(
        context.position, context)

  if context.mode in ("incremental", "first_part"):
    context.record_new_states([kv] if self.shared_kv else [k, v])

  if self.shared_kv:
    # One projection serves as both keys and values.
    k = v = kv

  bias = self.compute_bias(
      context, memory_position, x, params.query_heads_dims)
  extra_kwargs = self.attention_kwargs_from_context(context)
  if self.attention_func == "hybrid":
    o = attention.hybrid_attention(
        q, k, v, context, memory_length, self.kv_dim, self.kv_dim,
        bias, **extra_kwargs)
  else:
    o = attention.attention(
        q, k, v, memory_length, self.kv_dim, self.kv_dim,
        bias, **extra_kwargs)
  return params.compute_output(o, output_shape=x.shape)
def mnist_model(image, labels, mesh, hs_t):
  """The model: a simple RNN that reads the image one row per time step.

  Args:
    image: tf.Tensor that is reshaped to [FLAGS.batch_size, 28, 28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None to
      build the graph without a loss.
    mesh: a mtf.Mesh
    hs_t: the initial hidden state, imported with shape [batch, hidden_1]

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when labels is None
    hs_t: the hidden state after the last time step
  """
  input_num = 28
  timesteps_num = 28
  classes_num = 10

  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  input_dim = mtf.Dimension("input", input_num)
  timesteps_dim = mtf.Dimension("timesteps", timesteps_num)
  classes_dim = mtf.Dimension("classes", classes_num)
  hidden_dim_1 = mtf.Dimension("hidden_1", FLAGS.hidden_size)
  hidden_dim_2 = mtf.Dimension("hidden_2", FLAGS.hidden_size)

  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28]),
      [batch_dim, timesteps_dim, input_dim])
  hs_t = mtf.import_tf_tensor(mesh, hs_t, [batch_dim, hidden_dim_1])

  Wxh = mtf.get_variable(mesh, "Wxh", [input_dim, hidden_dim_2])
  Whh = mtf.get_variable(mesh, "Whh", [hidden_dim_1, hidden_dim_2])
  Why = mtf.get_variable(mesh, "Why", [hidden_dim_2, classes_dim])
  bh = mtf.get_variable(mesh, "bh", [hidden_dim_2])
  by = mtf.get_variable(mesh, "by", [classes_dim])

  # NOTE(review): after the first step hs_t carries hidden_2 rather than
  # hidden_1, while Whh is [hidden_1, hidden_2] -- confirm that the
  # recurrence contracts over the intended dimension on later steps.
  x_list = mtf.unstack(x, timesteps_dim)
  for xs_t in x_list:
    hs_t = mtf.tanh(
        mtf.einsum([xs_t, Wxh], [batch_dim, hidden_dim_2]) +
        mtf.einsum([hs_t, Whh], [batch_dim, hidden_dim_2]) +
        bh)
  logits = mtf.einsum([hs_t, Why], [batch_dim, classes_dim]) + by

  if labels is None:
    loss = None
  else:
    # Import the labels only once we know they exist: reshaping None would
    # fail before the None check could ever be reached.
    y = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), [batch_dim])
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(y, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss, hs_t
def call(self, context, x, losses=None):
  """Call the layer.

  Multihead attention of x over a memory tensor. The memory is x itself in
  "incremental" mode and x with its length dimension renamed to
  "memory_length" otherwise. Autoregressive and sequence-id masks are added
  as large negative biases.
  """
  incremental = context.mode == "incremental"
  activation_dtype = context.activation_dtype

  wq, wk, wv, wo = mtf.layers.multihead_attention_params(
      context.mesh, self.heads_dim, context.model_dim,
      self.kv_dim, context.variable_dtype)
  memory_length = mtf.Dimension("memory_length", context.length_dim.size)

  q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
  if incremental:
    m = x
  else:
    m = mtf.rename_dimension(x, context.length_dim.name, "memory_length")
  k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
  v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])

  if incremental:
    # Write this step's k/v into the recorded states at the current position.
    old_k, old_v = context.get_states(2)
    write_mask = mtf.one_hot(
        context.position, memory_length, dtype=activation_dtype)
    keep_mask = 1.0 - write_mask
    k = old_k * keep_mask + k * write_mask
    v = old_v * keep_mask + v * write_mask
  if context.mode in ("incremental", "first_part"):
    context.record_new_states([k, v])

  masks = []
  if context.autoregressive:
    # Penalize memory positions that lie after the query position.
    memory_positions = mtf.range(
        context.mesh, memory_length, dtype=tf.int32)
    is_future = mtf.less(context.position, memory_positions)
    masks.append(mtf.cast(is_future, activation_dtype) * -1e9)

  sequence_id = context.sequence_id
  if (sequence_id is not None and
      isinstance(sequence_id, mtf.Tensor) and
      context.length_dim in sequence_id.shape):
    # Penalize attention between positions with different sequence ids.
    other_sequence = mtf.not_equal(
        sequence_id,
        mtf.layers.rename_length_to_memory_length(sequence_id))
    masks.append(mtf.cast(other_sequence, activation_dtype) * -1e9)

  mask = mtf.add_n(masks) if masks else None
  dropout_rate = self.dropout_rate if context.train else 0.0
  o = mtf.layers.dot_product_attention_v2(
      q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
      dropout_rate, [context.length_dim])
  return mtf.einsum(
      [o, wo], x.shape, reduced_dims=[self.heads_dim, self.kv_dim])
def compute_loss(self, decoder, hidden, targets, context):
  """Computes the loss during training."""
  activation_dtype = context.activation_dtype
  logits = self._embedding.hidden_to_logits(hidden, context=context)

  # Target ids are shifted by the start token id before one-hot encoding.
  shifted_targets = targets - self._start_token_id
  soft_targets = mtf.one_hot(
      shifted_targets, self._vocab_dim, dtype=activation_dtype)
  per_position_loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self._vocab_dim, z_loss=decoder.z_loss)

  # Weight each position's loss by the nonzero-target mask.
  padding_mask = mtf.layers.weights_nonzero(targets, dtype=activation_dtype)
  total_loss = mtf.reduce_sum(per_position_loss * padding_mask)
  denominator = decoder.loss_denominator(targets, context.num_microbatches)
  return total_loss / denominator
def call(self, context, x, losses=None):
  """Call the layer.

  Variant that keeps a single memory tensor `m` as state (instead of separate
  keys and values) and hands it to attention_internal.
  """
  incremental = context.mode == "incremental"
  memory_length = self.memory_length(context)
  q = self.compute_q(context, x)

  if incremental:
    # Write this step's x into the recorded memory at the current position.
    write_mask = mtf.one_hot(
        context.position, memory_length, dtype=context.activation_dtype)
    keep_mask = 1.0 - write_mask
    (old_m,) = context.get_states(1)
    m = old_m * keep_mask + write_mask * x
    memory_position = mtf.range(context.mesh, memory_length, tf.int32)
  else:
    m = mtf.replace_dimensions(x, context.length_dim, memory_length)
    memory_position = self.rename_length_to_memory_length(
        context.position, context)

  if context.mode in ("incremental", "first_part"):
    context.record_new_states([m])

  bias = self.compute_bias(context, memory_position, x, self.heads_dims, q)
  return self.attention_internal(context, q, m, memory_length, bias)
def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
  """Cross-entropy between `targets` and entmax(logits) along vocab_dim.

  Args:
    logits: a mtf.Tensor whose shape contains vocab_dim.
    targets: a mtf.Tensor. If its dtype is an integer type it is treated as
      hard targets (its dims must equal the dims of logits without vocab_dim)
      and is one-hot encoded over vocab_dim; otherwise it is treated as soft
      targets and must have the same dims as logits.
    vocab_dim: a mtf.Dimension.
    z_loss: accepted for signature compatibility; it is not used by this
      function.

  Returns:
    A mtf.Tensor: the negated sum over vocab_dim of log(entmax(logits))
    times the targets.

  Raises:
    ValueError: if the shapes are incompatible.
  """
  if targets.dtype.is_integer:
    # Hard targets: one class index per position.
    if (set(targets.shape.dims) !=
        set(logits.shape.dims).difference([vocab_dim])):
      raise ValueError(
          "entmax_cross_entropy_with_logits with hard targets "
          "dims in targets=%s should be dims in logits=%s other than "
          "vocab_dim=%s" % (targets, logits, vocab_dim))
    targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
  elif set(targets.shape.dims) != set(logits.shape.dims):
    raise ValueError(
        "entmax_cross_entropy_with_logits with soft targets "
        "dims in targets=%s should be dims in logits=%s" % (targets, logits))

  if vocab_dim not in logits.shape.dims:
    raise ValueError("vocab_dim must be in logits.shape.dims")

  # NOTE(review): if entmax can return exact zeros, log() yields -inf there --
  # confirm how those entries are meant to interact with zero targets.
  log_entmax = mtf.log(entmax(logits, dim=vocab_dim))

  loss = mtf.negative(
      mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))
  return loss
def model_backbone(image, labels, mesh):
  """Builds the network selected by args_opt.model and its loss.

  Args:
    image: tf.Tensor that is reshaped to [args_opt.batch_size, 224, 224, 3]
    labels: a tf.Tensor that is reshaped to [args_opt.batch_size], or None
      to skip building the loss.
    mesh: a mtf.Mesh

  Returns:
    logits: the network output cast to tf.float32.
    loss: the mean softmax cross-entropy as a mtf.Tensor, or None when
      labels is None.
  """
  batch_size = args_opt.batch_size
  batch_dim = mtf.Dimension("batch", batch_size)
  image_shape = mtf.Shape([
      batch_dim,
      mtf.Dimension("rows_size", 224),
      mtf.Dimension("cols_size", 224),
      mtf.Dimension("image_channel", 3),
  ])
  classes_dim = mtf.Dimension(name='classesnum', size=args_opt.class_num)

  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [batch_size, 224, 224, 3]), image_shape)

  # Half-precision variable dtype is only built when fp16 is requested.
  variable_dtype = None
  if args_opt.fp16:
    variable_dtype = mtf.VariableDType(tf.float16, tf.float16, tf.float16)

  # Batch norm is turned off for model names containing 'vgg'.
  build_network = network[args_opt.model]
  logits = build_network(
      x,
      classes_dim=classes_dim,
      float16=variable_dtype,
      batch_norm='vgg' not in args_opt.model)
  logits = mtf.cast(logits, dtype=tf.float32)

  if labels is None:
    return logits, None

  labels = mtf.import_tf_tensor(
      mesh, tf.reshape(labels, [batch_size]), mtf.Shape([batch_dim]))
  per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(labels, classes_dim), classes_dim)
  return logits, mtf.reduce_mean(per_example_loss)
def compute_loss(logits, positions):
  """Mean negative log-probability of `positions` under a softmax over seq_dim."""
  position_one_hot = mtf.one_hot(positions, output_dim=seq_dim)
  log_probs = mtf.log_softmax(logits, seq_dim)
  # The one-hot product picks out the log-probability at each given position.
  selected_log_probs = mtf.reduce_sum(
      position_one_hot * log_probs, reduced_dim=seq_dim)
  mean_log_prob = mtf.reduce_mean(selected_log_probs)
  return -mean_log_prob
def _top_2_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="top_2_gating"):
  """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold_train: a float
    hparams.moe_second_threshold_eval: a float

  The returned forward assignment is a tensor used to map (via einsum) from the
  inputs to the expert_inputs.  Likewise, the returned combine_tensor is
  used to map (via einsum) from the expert outputs to the outputs.  Both the
  forward and backward assignments are mostly zeros.  The shapes of the tensors
  are as follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs:
    [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector.  If the importance of an input is 1.0, then we send it to
  up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
  one expert.  If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts.  Inputs to the first-choice expert-group get importance
  1.0.  Inputs to the second-choice expert group get importance 0.5.
  Inputs that represent padding get importance 0.0.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean; selects the *_train or *_eval second-place policy and
      threshold.
    variable_dtype: a mtf.VariableDType
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]
    name: an optional string

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters (an unknown second-place policy)
  """
  group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

  raw_gates = mtf.layers.dense(
      inputs, experts_dim, use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(raw_gates, experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  expert_capacity_f = float(expert_capacity_dim.size)

  # FIND TOP 2 EXPERTS PER POSITION
  # Find the top expert for each position. shape=[batch, group]
  index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
  density_1_proxy = raw_gates
  if importance is not None:
    # Only inputs with importance exactly 1.0 keep a first-choice expert.
    mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
  # Zero out the winning expert so the next top_1 finds the runner-up.
  gates_without_top_1 = raw_gates * (1.0 - mask_1)
  # [batch, group]
  index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
  if importance is not None:
    mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

  # Renormalize the two gates to sum to (almost) one; the epsilon guards
  # against division by zero.
  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy, reduced_dim=group_size_dim)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))

  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (
        mtf.reduce_sum(gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  if train:
    policy = hparams.moe_second_policy_train
    threshold = hparams.moe_second_threshold_train
  else:
    policy = hparams.moe_second_policy_eval
    threshold = hparams.moe_second_threshold_eval
  if policy == "all":
    # Use second-place experts for all examples.
    pass
  elif policy == "none":
    # Never use second-place experts for all examples.
    mask_2 = mtf.zeros_like(mask_2)
  elif policy == "threshold":
    # Use second-place experts if gate_2 > threshold.
    mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
  elif policy == "random":
    # Use second-place experts with probability min(1.0, gate_2 / threshold).
    mask_2 *= mtf.to_float(
        mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                 gate_2 / max(threshold, 1e-9)))
  else:
    raise ValueError("Unknown policy %s" % policy)

  # COMPUTE ASSIGNMENT TO EXPERTS
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(
      mask_1, group_size_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit. [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # [batch, group, experts]
  # Second-choice slots are offset by mask_1_count, the number of first-choice
  # assignments already made to each expert.
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  gate_2 *= mask_2_flat
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)

  # [batch, group, experts, expert_capacity]
  # NOTE(review): gate_1 and gate_2 were already multiplied by their *_flat
  # masks above, so the masks are applied a second time here.
  combine_tensor = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
      gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)

  # The dispatch tensor is the combine tensor cast to bool and back, in the
  # combine tensor's dtype.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
def mnist_model(image, labels, mesh):
  """Builds a small blocked-convolution MNIST classifier.

  The image is split into a 4x4 grid of 7x7 blocks, passed through two 9x9
  convolutions with ReLU, averaged over the filter dimension and fed through
  one hidden dense layer and a linear output layer.

  Args:
    image: tf.Tensor that is reshaped to [FLAGS.batch_size, 4, 7, 4, 7, 1]
    labels: a tf.Tensor that is reshaped to [FLAGS.batch_size], or None to
      skip building the loss.
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when labels is None
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  row_blocks_dim = mtf.Dimension("row_blocks", 4)
  col_blocks_dim = mtf.Dimension("col_blocks", 4)
  rows_dim = mtf.Dimension("rows_size", 7)
  cols_dim = mtf.Dimension("cols_size", 7)
  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 1)

  # Import in the layout produced by the reshape, then move both block
  # dimensions in front of the within-block dimensions.
  imported_shape = mtf.Shape([
      batch_dim, row_blocks_dim, rows_dim,
      col_blocks_dim, cols_dim, one_channel_dim])
  blocked_order = [
      batch_dim, row_blocks_dim, col_blocks_dim,
      rows_dim, cols_dim, one_channel_dim]
  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]),
      imported_shape)
  x = mtf.transpose(x, blocked_order)

  # Two convolutional layers to demonstrate that convolution works.
  fh_dim = mtf.Dimension("fh", 9)
  fw_dim = mtf.Dimension("fw", 9)
  filters1_dim = mtf.Dimension("filters1", 16)
  filters2_dim = mtf.Dimension("filters2", 16)
  kernel1 = mtf.get_variable(
      mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim])
  kernel2 = mtf.get_variable(
      mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim])

  conv1 = mtf.conv2d_with_blocks(
      x, kernel1, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)
  f1 = mtf.relu(conv1)
  conv2 = mtf.conv2d_with_blocks(
      f1, kernel2, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)
  f2 = mtf.relu(conv2)
  x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

  # Fully-connected layers; the hidden layer reduces over the last four
  # dimensions of x.
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  h1 = mtf.layers.dense(
      x, hidden_dim1,
      reduced_dims=x.shape.dims[-4:],
      activation=mtf.relu, name="hidden1")
  logits = mtf.layers.dense(h1, classes_dim, name="logits")

  if labels is None:
    return logits, None

  labels = mtf.import_tf_tensor(
      mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
  per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(labels, classes_dim), classes_dim)
  return logits, mtf.reduce_mean(per_example_loss)
def Alexnet(img, labels, num_nodes, num_gpus, args):
  """Builds the AlexNet training graph for the selected partition strategy.

  Args:
    img: forwarded to CreateMeshes.
    labels: forwarded to CreateMeshes.
    num_nodes: forwarded to CreateMeshes.
    num_gpus: forwarded to CreateMeshes; also used to round up the number of
      classes for strategies 2 and 3.
    args: forwarded to CreateMeshes; `args.strategy` (0, 1, 2 or 3) selects
      how the fully-connected dimensions are named.

  Returns:
    A tuple (graph, mesh_to_impl, mtf_loss) where mtf_loss is the mean
    softmax cross-entropy.

  Raises:
    ValueError: if args.strategy is not one of 0, 1, 2, 3.
  """
  num_classes = 1000
  keep_prob = 0.5

  graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
      img, labels, num_nodes, num_gpus, args)
  RenameFC = lambda x: mt.rename_dimension(
      x, x.shape[-1].name, utils.RandName())

  strategy = args.strategy
  if strategy == 0:
    fc6_units = mtf.Dimension(utils.RandName(), 4096)
    fc7_units = mtf.Dimension(utils.RandName(), 4096)
    fc8_units = mtf.Dimension(utils.RandName(), num_classes)
  elif strategy == 1:
    fc6_units = mtf.Dimension('axis1', 4096)
    fc7_units = mtf.Dimension('axis0', 4096)
    fc8_units = mtf.Dimension('axis1', num_classes)
  elif strategy == 2:
    num_classes = utils.RoundUp(num_classes, num_gpus)
    fc6_units = mtf.Dimension('axis0', 4096)
    fc7_units = mtf.Dimension('axis0', 4096)
    fc8_units = mtf.Dimension('axis0', num_classes)
  elif strategy == 3:
    num_classes = utils.RoundUp(num_classes, num_gpus // 2)
    fc6_units = mtf.Dimension('axis1', 4096)
    fc7_units = mtf.Dimension('axis1', 4096)
    fc8_units = mtf.Dimension('axis1', num_classes)
  else:
    # Without this the fc*_units names stay unbound and the failure only
    # surfaces later, at the first dense layer.
    raise ValueError("Unknown strategy %s" % strategy)

  with tf.variable_scope('alexnet'):
    # Conv1 + ReLU + maxpool1
    conv1 = mt.Conv2d(mtf_img, GetFilterShape(mtf_img, (11, 11, 3, 96)),
                      (4, 4), 'VALID', activation=mtf.relu, name='conv1')
    pool1 = mt.MaxPool(conv1, (3, 3), (2, 2), 'VALID', name='pool1')

    # Conv2 + ReLU + maxpool2
    conv2 = mt.Conv2d(pool1, GetFilterShape(pool1, (5, 5, 96, 256)),
                      (1, 1), 'SAME', activation=mtf.relu, name='conv2')
    pool2 = mt.MaxPool(conv2, (3, 3), (2, 2), name='pool2')

    # Conv3 + ReLU
    conv3 = mt.Conv2d(pool2, GetFilterShape(pool2, (3, 3, 256, 384)),
                      padding='SAME', activation=mtf.relu, name='conv3')

    # Conv4 + ReLU
    conv4 = mt.Conv2d(conv3, GetFilterShape(conv3, (3, 3, 384, 384)),
                      padding='SAME', activation=mtf.relu, name='conv4')

    # Conv5 + ReLU + maxpool5
    conv5 = mt.Conv2d(conv4, GetFilterShape(conv4, (3, 3, 384, 256)),
                      padding='SAME', activation=mtf.relu, name='conv5')
    pool5 = mt.MaxPool(conv5, (3, 3), (2, 2), name='pool5')

    # Rename dims
    if strategy == 1:
      # Flatten everything but the first dimension into one new dimension.
      k_dim = mtf.Dimension(
          utils.RandName(), utils.Prod(pool5.shape.to_integer_list[1:]))
      pool5 = mtf.reshape(pool5, mtf.Shape([pool5.shape[0], k_dim]))
      pool5 = ReplaceMeshWithIndependentAxes(
          pool5, meshes[1], (utils.RandName(), 'axis0'))
    elif strategy == 2:
      pool5 = mt.rename_dimension(
          pool5, pool5.shape[0].name, utils.RandName())
    elif strategy == 3:
      assert pool5.shape[0].name == 'axis0'
      pool5 = ReplaceMeshWithConcatSplit(pool5, meshes[1])

    # FC + ReLU + dropout
    fc_activation = lambda x: mtf.dropout(mtf.relu(x), keep_prob)
    fc6 = mtf.layers.dense(pool5, fc6_units, activation=fc_activation,
                           reduced_dims=pool5.shape[1:], name='fc6')
    if strategy in (2, 3):
      fc6 = RenameFC(fc6)

    fc7 = mtf.layers.dense(fc6, fc7_units, activation=fc_activation,
                           reduced_dims=fc6.shape.dims[-1:], name='fc7')
    if strategy in (2, 3):
      fc7 = RenameFC(fc7)

    fc8 = mtf.layers.dense(fc7, fc8_units,
                           reduced_dims=fc7.shape.dims[-1:], name='fc8')
    fc8 = mtf.dropout(fc8, keep_prob)

    if strategy == 1:
      assert fc8.shape[-1].name == 'axis1'
      fc8 = ReplaceMeshWithDuplicates(fc8, meshes[2])

  with tf.variable_scope('loss'):
    # Give fc8's first dimension the same name as the labels' first dimension.
    if fc8.shape[0] != mtf_labels.shape[0]:
      fc8 = mt.rename_dimension(
          fc8, fc8.shape[0].name, mtf_labels.shape[0].name)
    one_hot_labels = mtf.one_hot(mtf_labels, fc8.shape[-1])
    mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
        fc8, one_hot_labels, fc8.shape[-1])
    mtf_loss = mtf.reduce_mean(mtf_cross_ent)

  return graph, mesh_to_impl, mtf_loss
def body_fn(position, ids, *states): """One step in the decode loop.""" nonlocal sampling_keep_top_k context = mtf_transformer.transformer.Context( model=None, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=position, position_is_default=True, states=states, new_states=[], initial_position=position, sequence_id=None, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=ids, encoder_inputs=encoder_inputs) if not slow_sampling else None with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE): logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context) # By default, do top_k sampling of 0.9 if sampling_keep_top_k == -2: sampling_keep_top_k = int(logits.shape[-1].size * 0.1) if sampling_keep_top_k != -1: if sampling_keep_top_k <= 0: raise ValueError( "sampling_keep_top_k must either be -1 or positive.") k_largest = mtf.nth_largest_element( logits, n=sampling_keep_top_k, reduced_dim=other_features["vocab_dim"]) logits = mtf.where(mtf.less_equal(logits, k_largest), mtf.ones_like(logits) * -1e6, logits) ids_this_step = mtf.sample_with_temperature( logits, other_features["vocab_dim"], temperature) if slow_sampling: ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False) else: ids_this_step = mtf.reshape(ids_this_step, (batch_dims)) one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot new_ids = (1 - one_hot) * ids + one_new_id new_position = position + 1 ret = [new_position, new_ids] if context is not None: ret += context.new_states return ret
def _mtf_model_fn(self, features, mesh):
    """Build the transformer graph on ``mesh`` and compute logits and loss.

    Args:
        features: dict of tf.Tensors. ``"targets"`` is always read;
            ``"inputs"`` is read unless ``hparams.transformer_type`` is
            ``"decoder"``. The optional ``*_segmentation`` / ``*_position``
            entries switch on the packed-dataset masks.
        mesh: the mesh the tensors are imported onto.

    Returns:
        A ``(logits, loss)`` pair of mtf tensors. Logits are cast to float
        and, when there are several batch dims, reshaped so they are merged
        into one.
    """
    features = copy.copy(features)
    hparams = self._hparams

    raw_targets = tf.to_int32(features["targets"])
    if len(raw_targets.get_shape()) > 2:
        tf.logging.info("targets = %s" % raw_targets)
        raw_targets = tf.squeeze(raw_targets, [2, 3])

    def pad_to_max_length(tensor):
        """Right-pad axis 1 to max_length and pin the static shape."""
        pad_amount = hparams.max_length - tf.shape(tensor)[1]
        padded = tf.pad(tensor, [[0, 0], [0, pad_amount]])
        return tf.reshape(padded, [hparams.batch_size, hparams.max_length])

    raw_targets = pad_to_max_length(raw_targets)
    optional_keys = (
        "targets_segmentation",
        "targets_position",
        "inputs_segmentation",
        "inputs_position",
    )
    for key in optional_keys:
        if key in features:
            features[key] = pad_to_max_length(features[key])
    raw_shifted_targets = common_layers.shift_right_2d(raw_targets)

    targets = self._import_to_batch_by_length(
        raw_targets, "targets", mesh, hparams)
    shifted_targets = self._import_to_batch_by_length(
        raw_shifted_targets, "shifted_targets", mesh, hparams)

    if "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = self._import_to_batch_by_length(
            features["targets_segmentation"], "targets_segmentation",
            mesh, hparams)
        targets_position = self._import_to_batch_by_length(
            features["targets_position"], "targets_position",
            mesh, hparams)
        autoregressive_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
        same_segment_mask = mtf.layers.attention_mask_same_segment(
            targets_segmentation, dtype=self.activation_dtype)
        decoder_self_attention_mask = autoregressive_mask + same_segment_mask
    else:
        targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        decoder_self_attention_mask = (
            mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype))

    def layer_prepostprocess_dropout(tensor):
        """Dropout whose noise shape covers the batch dims and model dim."""
        return mtf.dropout(
            tensor,
            keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    extra_losses = []
    embedding_vars = self._embedding_and_softmax_vars(mesh)
    (inputs_embedding_var, targets_embedding_var, softmax_var,
     positional_embedding_var) = embedding_vars

    if hparams.transformer_type == "decoder":
        encoder_output = None
        encoder_decoder_attention_mask = None
    else:
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = pad_to_max_length(inputs)
        inputs = self._import_to_batch_by_length(
            inputs, "inputs", mesh, hparams)
        if "inputs_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            inputs_segmentation = self._import_to_batch_by_length(
                features["inputs_segmentation"], "inputs_segmentation",
                mesh, hparams)
            inputs_position = self._import_to_batch_by_length(
                features["inputs_position"], "inputs_position",
                mesh, hparams)
            encoder_self_attention_mask = (
                mtf.layers.attention_mask_same_segment(
                    inputs_segmentation, dtype=self.activation_dtype))
        else:
            inputs_position = mtf.range(
                mesh, self.length_dim, dtype=tf.int32)
            encoder_self_attention_mask = (
                mtf.layers.attention_mask_ignore_padding(
                    inputs, dtype=self.activation_dtype))

        token_emb = mtf.gather(
            inputs_embedding_var, inputs, self.inputs_vocab_dim)
        position_emb = mtf.gather(
            positional_embedding_var, inputs_position, self.max_length_dim)
        x = layer_prepostprocess_dropout(token_emb + position_emb)
        with tf.variable_scope("encoder"):
            x = self._layer_stack(
                x,
                hparams.encoder_layers,
                self_attention_mask=encoder_self_attention_mask,
                losses=extra_losses)

    if hparams.transformer_type == "encdec":
        if "inputs_segmentation" in features:
            encoder_decoder_attention_mask = (
                mtf.layers.attention_mask_same_segment(
                    targets_segmentation, inputs_segmentation,
                    dtype=self.activation_dtype))
        else:
            encoder_decoder_attention_mask = encoder_self_attention_mask
        # The decoder addresses the encoder output along the memory length.
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
        # DECODER
        token_emb = mtf.gather(
            targets_embedding_var, shifted_targets, self.targets_vocab_dim)
        position_emb = mtf.gather(
            positional_embedding_var, targets_position, self.max_length_dim)
        x = layer_prepostprocess_dropout(token_emb + position_emb)
        with tf.variable_scope("decoder"):
            x = self._layer_stack(
                x,
                hparams.decoder_layers,
                encoder_output=encoder_output,
                self_attention_mask=decoder_self_attention_mask,
                encdec_attention_mask=encoder_decoder_attention_mask,
                losses=extra_losses)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
        logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)

    # Label smoothing: spread label_smoothing evenly over the vocabulary.
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets,
        self.targets_vocab_dim,
        on_value=on_value,
        off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    # Positions whose target id is zero get zero weight.
    weights = mtf.layers.weights_nonzero(
        targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for extra_loss in extra_losses:
        loss += extra_loss

    logits = mtf.to_float(logits)
    # combine batch dims
    if len(self.batch_dims) > 1:
        combined_batch_dim = mtf.Dimension(
            self.batch_dims[0].name, mtf.Shape(self.batch_dims).size)
        logits = mtf.reshape(
            logits, [combined_batch_dim] + logits.shape.dims[-2:])
    return logits, loss
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None):
    """Apply one attention layer ("local", "global" or "linear") to ``x``.

    Args:
        x: input tensor; per the original author's note, [batch, seq, n_embd].
        scope: name of the variable scope that holds this layer's variables.
        n_state: dimension whose size must be divisible by
            ``params["n_head"]``.
        attention_type: ``"local"``, ``"global"`` or ``"linear"``.
        params: model config dict. Reads ``"n_head"``, ``"n_embd"``,
            ``"mode"`` and ``"res_dropout"``; ``"causal"`` for local/linear
            attention; ``"attn_dropout"`` for global attention; and the
            optional ``"num_mem_kv"`` and ``"local_attention_radius"``.
        bias: optional attention bias. Only used by ``"global"`` attention;
            when it does not exist, no bias is applied.
        dim_seq: sequence dimension used for the incremental one-hot/gather.
        memory_length_dim: dimension that replaces the key/value length
            dimension in ``"global"`` attention.
        variable_dtype: provides the master/slice/activation dtypes.
        context: optional decoding context. When present, the new k/v are
            recorded on it; in incremental inference the previous k/v
            states are read from it.

    Returns:
        The attention output after the output projection, output bias and
        (in training mode only) residual dropout.

    Raises:
        NotImplementedError: if ``attention_type`` is not one of the three
            supported values.
    """
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            # Splice the freshly computed k/v into the cached states at the
            # current position only.
            one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":
                # TODO: pass in fake context
                # Default to "no bias" so that a call without a bias does not
                # hit an unbound local when the attention op is built below.
                broadcasted_bias = None

                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                # Attention dropout is only active in training mode.
                attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError("Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
                                 master_dtype=variable_dtype.master_dtype,
                                 slice_dtype=variable_dtype.slice_dtype,
                                 activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
def _call_internal(self, context, inputs, targets=None):
    """Compute logits based on inputs (all positions in parallel).

    Also updates context if applicable: when both ``targets`` and
    ``context.losses`` are given, the cross-entropy loss is appended to
    ``context.losses``.

    Args:
        context: a Context
        inputs: a Tensor
        targets: an optional Tensor

    Returns:
        logits: a Tensor with shape
            [<batch_dims>, length_dim, output_vocab_dim], or the raw layer
            stack output when ``self.output_vocab_dim`` is None.
    """
    mesh = inputs.mesh
    shared = context.shared_params

    # Token embedding: reuse the shared table if one was supplied.
    if "embedding" not in shared:
        embedding_weights = mtf.layers.embedding_weights(
            mesh,
            self.input_vocab_dim,
            self.model_dim,
            context.variable_dtype,
            name="embedding")
    else:
        embedding_weights = shared["embedding"]
    x = mtf.gather(embedding_weights, inputs, self.input_vocab_dim)

    # Positional embedding table, shared under the same rule.
    if "positional_embedding" not in shared:
        pos_emb_var = mtf.layers.embedding_weights(
            mesh,
            self.max_length_dim,
            self.model_dim,
            context.variable_dtype,
            "positional_embedding")
    else:
        pos_emb_var = shared["positional_embedding"]

    if context.position_is_default:
        # Default positions: take the leading slice of the table and
        # relabel its dimension as the context's length dimension.
        leading_rows = mtf.slice(
            pos_emb_var, 0, context.length_dim.size,
            self.max_length_dim.name)
        pos_emb = mtf.rename_dimension(
            leading_rows, self.max_length_dim.name, context.length_dim.name)
    else:
        pos_emb = mtf.gather(
            pos_emb_var,
            context.position,
            self.max_length_dim,
            output_shape=x.shape)
    x += pos_emb

    x = self.layer_stack.call(context, x)
    if self.output_vocab_dim is None:
        return x

    if self.shared_embedding_and_softmax_weights:
        # Tie the softmax to the embedding table, scaling the activations
        # by model_dim ** -0.5 first.
        scaled_x = x * (self.model_dim.size ** -0.5)
        logits = mtf.einsum(
            [scaled_x, embedding_weights], reduced_dims=[self.model_dim])
    else:
        logits = mtf.layers.dense(
            x,
            self.output_vocab_dim,
            use_bias=False,
            variable_dtype=context.variable_dtype,
            name="logits")

    if targets is None or context.losses is None:
        return logits

    # Label-smoothed targets.
    off_value = self.label_smoothing / self.output_vocab_dim.size
    on_value = 1.0 - self.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets,
        self.output_vocab_dim,
        dtype=context.activation_dtype,
        on_value=on_value,
        off_value=off_value)
    # z_loss is only applied while training.
    z_loss = self.z_loss if context.train else 0.0
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.output_vocab_dim, z_loss=z_loss)
    # Positions whose target id is zero get zero weight.
    weights = mtf.layers.weights_nonzero(
        targets, dtype=context.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    context.losses.append(loss)
    return logits
def Inception(img, labels, num_nodes, num_gpus, args):
    """Build the Inception network on the meshes and return its mean loss.

    Args:
        img: input images, handed to ``CreateMeshes``.
        labels: input labels, handed to ``CreateMeshes``.
        num_nodes: number of nodes, handed to ``CreateMeshes``.
        num_gpus: number of GPUs; also used to round up the class count for
            strategies 2 and 3.
        args: options object; ``args.strategy`` selects how the final
            classifier is laid out.

    Returns:
        A ``(graph, mesh_to_impl, loss)`` tuple, where ``loss`` is the mean
        softmax cross-entropy.
    """
    num_classes = 1000
    graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
        args, img, labels, num_nodes, num_gpus)
    strategy = args.strategy
    label_batch_name = mtf_labels.shape[0].name

    # NOTE(review): the source this was recovered from had lost its
    # indentation; the extent of the 'inception' scope (everything up to
    # the classifier) and the sibling 'loss' scope is assumed - confirm.
    with tf.variable_scope('inception'):
        # Stem.
        conv1a = BasicConv(mtf_img, (3, 3, 3, 32), stride=2, name='conv1a')
        conv2a = BasicConv(conv1a, (3, 3, 32, 32), name='conv2a')
        conv2b = BasicConv(
            conv2a, (3, 3, 32, 64), padding='SAME', name='conv2b')
        pool1 = MaxPool(conv2b, (3, 3), stride=2, name='pool1')
        conv3b = BasicConv(pool1, (1, 1, 64, 80), name='conv3b')
        conv4a = BasicConv(conv3b, (3, 3, 80, 192), name='conv4a')
        pool2 = MaxPool(conv4a, (3, 3), stride=2, name='pool2')

        # Inception modules.
        mixed5b = InceptionA(pool2, 192, 32, name='mixed5b')
        mixed5c = InceptionA(mixed5b, 256, 64, name='mixed5c')
        mixed5d = InceptionA(mixed5c, 288, 64, name='mixed5d')
        mixed6a = InceptionB(mixed5d, 288, name='mixed6a')
        mixed6b = InceptionC(mixed6a, 768, 128, name='mixed6b')
        mixed6c = InceptionC(mixed6b, 768, 160, name='mixed6c')
        mixed6d = InceptionC(mixed6c, 768, 160, name='mixed6d')
        mixed6e = InceptionC(mixed6d, 768, 192, name='mixed6e')
        mixed7a = InceptionD(mixed6e, 768, name='mixed7a')
        mixed7b = InceptionE(mixed7a, 1280, strategy, meshes, name='mixed7b')
        mixed7c = InceptionE(mixed7b, 2048, strategy, name='mixed7c')

        # Average everything except the first and last dimensions.
        pooled_shape = mtf.Shape([mixed7c.shape[0], mixed7c.shape[-1]])
        mean = mtf.reduce_mean(mixed7c, output_shape=pooled_shape)
        assert mean.shape[0].name == 'axis0'
        assert not mean.shape[1].name.startswith('axis')

        if strategy == 1:
            new_shape = mean.shape
            new_shape = new_shape.rename_dimension(
                new_shape[0].name, label_batch_name)
            new_shape = new_shape.rename_dimension(
                new_shape[1].name, 'axis0')
            with tf.variable_scope('reshape_mean'):
                mean = mtf.reshape(mean, new_shape)
            dim_name = 'axis1'
        elif strategy == 2:
            num_classes = utils.RoundUp(num_classes, num_gpus)
            mean = mt.rename_dimension(mean, 'axis0', label_batch_name)
            dim_name = 'axis0'
        elif strategy == 3:
            num_classes = utils.RoundUp(num_classes, num_gpus)
            assert mean.shape[0].name == 'axis0'
            dim_names = mean.shape.rename_dimension(
                'axis0', label_batch_name)
            mean = ReplaceMeshWithConcatSplit(mean, meshes[1], dim_names)
            dim_name = 'axis1'
        else:
            dim_name = utils.RandName()

        classes_dim = mtf.Dimension(dim_name, num_classes)
        fc = mtf.layers.dense(
            mean, classes_dim, reduced_dims=mean.shape[-1:])

    with tf.variable_scope('loss'):
        assert mtf_labels.mesh == fc.mesh
        assert mtf_labels.shape[0] == fc.shape[0]
        one_hot_labels = mtf.one_hot(mtf_labels, fc.shape[-1])
        cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            fc, one_hot_labels, fc.shape[-1])
        loss = mtf.reduce_mean(cross_ent)

    return graph, mesh_to_impl, loss
def mtf_model_fn(self, features, mesh):
    """Build the image-transformer decoder on ``mesh``.

    The targets are flattened to a 1D sequence per example, shifted right,
    embedded, run through ``hparams.num_decoder_layers`` layers of masked
    local attention plus feed-forward, and scored against the unshifted
    targets.

    Args:
        features: dict of tf.Tensors; ``"targets"`` is always read and
            ``"inputs"`` is read when the model is conditional.
        mesh: the mesh the tensors are imported onto.

    Returns:
        A ``(logits, loss)`` pair; logits are reshaped to
        [batch, rows, orig_cols, channels, outputs_vocab].
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.activation_type

    # We assume fixed vocab size for targets
    raw_targets = tf.to_int32(features["targets"])

    # Image preprocessing, reshape into a 1D sequence and shift right.
    flat_length = hparams.img_len * hparams.img_len * hparams.num_channels
    raw_targets = tf.reshape(raw_targets, [hparams.batch_size, flat_length])
    raw_shifted_targets = common_layers.shift_right_2d(raw_targets)

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    batch_by_length = mtf.Shape([batch_dim, self.length_dim])

    def import_to_batch_by_length(tensor, name):
        """Import a [batch, length] tf.Tensor onto the mesh."""
        return mtf.import_tf_tensor(mesh, tensor, batch_by_length, name=name)

    def layer_prepostprocess_dropout(tensor):
        """Dropout whose noise shape covers the batch and model dims."""
        return mtf.dropout(
            tensor,
            keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape([batch_dim, self.model_dim]))

    targets = import_to_batch_by_length(raw_targets, "targets")
    shifted_targets = import_to_batch_by_length(
        raw_shifted_targets, "shifted_targets")

    extra_losses = []

    # Create targets content and position embeddings.
    # Create embedding var for targets and positions and do a gather.
    targets_embedding_var = mtf.get_variable(
        mesh,
        "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)
    x = mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim)

    # Add positional embeddings
    positional_emb = mtf.reshape(
        self.create_positional_emb_2d(targets),
        [self.length_dim, self.model_dim])
    x += positional_emb

    # If conditional and input is given, add the input embedding to the
    # target.
    # TODO(nikip): Verify conditional.
    if self.has_input and not hparams.unconditional:
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = import_to_batch_by_length(inputs, "inputs")

        # Input embeddings
        inputs_embedding_var = mtf.layers.embedding(
            mesh,
            "input_embedding",
            mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
            activation_dtype=activation_dtype)
        x += mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim)

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    for layer_index in range(hparams.num_decoder_layers):
        with tf.variable_scope("decoder_layer_%d" % layer_index):
            # Self attention layer
            attention_in = mtf.layers.layer_norm(
                x, self.model_dim, name="layer_norm_att")
            attention_out = mtf.layers.masked_local_attention_1d(
                attention_in,
                None,
                self.kv_dim,
                self.heads_dim,
                block_length=hparams.block_length,
                name="self_att")
            x += layer_prepostprocess_dropout(attention_out)

            # ffn layer
            ffn_in = mtf.layers.layer_norm(
                x, self.model_dim, name="layer_norm_ffn")
            ffn_out = mtf.layers.dense_relu_dense(
                ffn_in,
                self.feedforward_dim,
                hparams.dropout,
                dropout_broadcast_dims=[self.length_dim])
            x += layer_prepostprocess_dropout(ffn_out)

    x = mtf.layers.layer_norm(x, self.model_dim, name="final_layer_norm")

    # Calculate the logits and loss.
    logits = mtf.layers.dense(x, self.outputs_vocab_dim, name="logits")
    soft_targets = mtf.one_hot(
        targets, self.outputs_vocab_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.outputs_vocab_dim)
    loss = mtf.reduce_mean(loss)
    for extra_loss in extra_losses:
        loss += extra_loss

    # Reshape logits to original target shape.
    image_shape = mtf.Shape([
        batch_dim,
        self.rows_dim,
        self.orig_cols_dim,
        self.channels_dim,
        self.outputs_vocab_dim,
    ])
    logits = mtf.reshape(logits, image_shape)

    return logits, loss
def mtf_model_fn(self, features, mesh):
    """Build the image-transformer graph for the configured attention type.

    Args:
        features: dict of tf.Tensors; ``"targets"`` is always read and
            ``"inputs"`` is read when the model is conditional.
        mesh: the mesh the tensors are imported onto.

    Returns:
        A ``(logits, loss)`` pair; logits are reshaped to
        [batch, rows, orig_cols, channels, outputs_vocab].

    Raises:
        ValueError: if ``hparams.attention_type`` is not one of
            "local1d_spatial", "local2d_spatial" or "local1d".
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.activation_type

    # We assume fixed vocab size for targets
    raw_targets = tf.to_int32(features["targets"])

    # Image preprocessing, reshape into a 1D sequence and shift right.
    flat_length = hparams.img_len * hparams.img_len * hparams.num_channels
    raw_targets = tf.reshape(raw_targets, [hparams.batch_size, flat_length])
    raw_shifted_targets = common_layers.shift_right_2d(raw_targets)

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)

    def import_to_batch_by_length(tensor, name):
        """Import a [batch, length] tf.Tensor onto the mesh."""
        target_shape = mtf.Shape([batch_dim, self.length_dim])
        return mtf.import_tf_tensor(mesh, tensor, target_shape, name=name)

    targets = import_to_batch_by_length(raw_targets, "targets")
    shifted_targets = import_to_batch_by_length(
        raw_shifted_targets, "shifted_targets")

    extra_losses = []

    # Create targets content and position embeddings.
    # Create embedding var for targets and positions and do a gather.
    targets_embedding_var = mtf.get_variable(
        mesh,
        "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)
    x = mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim)

    # Add positional embeddings
    positional_emb = mtf.reshape(
        self.create_positional_emb_2d(targets),
        [self.length_dim, self.model_dim])
    x += positional_emb

    # If conditional and input is given, add the input embedding to the
    # target.
    # TODO(nikip): Verify conditional.
    if self.has_input and not hparams.unconditional:
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = import_to_batch_by_length(inputs, "inputs")

        # Input embeddings
        inputs_embedding_var = mtf.layers.embedding(
            mesh,
            "input_embedding",
            mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
            activation_dtype=activation_dtype)
        x += mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim)

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    decoders_by_type = {
        "local1d_spatial": local_attention1d_spatial_decoder,
        "local2d_spatial": local_attention2d_spatial_decoder,
        "local1d": local_attention1d_masked_decoder,
    }
    decoder_fn = decoders_by_type.get(hparams.attention_type)
    if decoder_fn is None:
        raise ValueError("Invalid attention type.")
    decoder_output = decoder_fn(
        x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)

    # Calculate the logits and loss.
    logits = mtf.layers.dense(
        decoder_output, self.outputs_vocab_dim, name="logits")
    # Need a reshape for logits
    sequence_shape = mtf.Shape(
        [batch_dim, self.length_dim, self.outputs_vocab_dim])
    logits = mtf.reshape(logits, sequence_shape)
    soft_targets = mtf.one_hot(
        targets, self.outputs_vocab_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.outputs_vocab_dim)
    loss = mtf.reduce_mean(loss)
    for extra_loss in extra_losses:
        loss += extra_loss

    # Reshape logits to original target shape.
    image_shape = mtf.Shape([
        batch_dim,
        self.rows_dim,
        self.orig_cols_dim,
        self.channels_dim,
        self.outputs_vocab_dim,
    ])
    logits = mtf.reshape(logits, image_shape)

    return logits, loss
def mtf_model_fn(self, features, mesh):
    """Embed the targets, run the selected decoder, and compute the loss.

    Args:
        features: dict of tf.Tensors; ``"targets"`` is always read and
            ``"inputs"`` is read when the model is conditional.
        mesh: the mesh the tensors are imported onto.

    Returns:
        A ``(logits, loss)`` pair; logits are reshaped to
        [batch, rows, orig_cols, channels, outputs_vocab].

    Raises:
        ValueError: if ``hparams.attention_type`` names no known decoder.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.activation_type

    # We assume fixed vocab size for targets
    tf_targets = tf.to_int32(features["targets"])

    # Image preprocessing, reshape into a 1D sequence and shift right.
    sequence_length = (
        hparams.img_len * hparams.img_len * hparams.num_channels)
    tf_targets = tf.reshape(
        tf_targets, [hparams.batch_size, sequence_length])
    tf_shifted_targets = common_layers.shift_right_2d(tf_targets)

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)

    def import_to_batch_by_length(tf_tensor, name):
        """Import a [batch, length] tf.Tensor onto the mesh."""
        return mtf.import_tf_tensor(
            mesh,
            tf_tensor,
            mtf.Shape([batch_dim, self.length_dim]),
            name=name)

    targets = import_to_batch_by_length(tf_targets, "targets")
    shifted_targets = import_to_batch_by_length(
        tf_shifted_targets, "shifted_targets")

    extra_losses = []

    # Create targets content and position embeddings.
    # Create embedding var for targets and positions and do a gather.
    embedding_shape = mtf.Shape([self.targets_vocab_dim, self.model_dim])
    targets_embedding_var = mtf.get_variable(
        mesh,
        "targets_embedding",
        embedding_shape,
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)
    x = mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim)

    # Add positional embeddings
    x += mtf.reshape(
        self.create_positional_emb_2d(targets),
        [self.length_dim, self.model_dim])

    # If conditional and input is given, add the input embedding to the
    # target.
    # TODO(nikip): Verify conditional.
    conditional = self.has_input and not hparams.unconditional
    if conditional:
        tf_inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = import_to_batch_by_length(tf_inputs, "inputs")

        # Input embeddings
        inputs_embedding_var = mtf.layers.embedding(
            mesh,
            "input_embedding",
            mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
            activation_dtype=activation_dtype)
        inputs_emb = mtf.gather(
            inputs_embedding_var, inputs, self.inputs_vocab_dim)
        x += inputs_emb

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    attention_type = hparams.attention_type
    if attention_type == "local1d_spatial":
        decoder_fn = local_attention1d_spatial_decoder
    elif attention_type == "local2d_spatial":
        decoder_fn = local_attention2d_spatial_decoder
    elif attention_type == "local1d":
        decoder_fn = local_attention1d_masked_decoder
    else:
        raise ValueError("Invalid attention type.")
    decoder_output = decoder_fn(
        x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)

    # Calculate the logits and loss.
    logits = mtf.layers.dense(
        decoder_output, self.outputs_vocab_dim, name="logits")
    # Need a reshape for logits
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
    soft_targets = mtf.one_hot(
        targets, self.outputs_vocab_dim, dtype=activation_dtype)
    per_position_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.outputs_vocab_dim)
    loss = mtf.reduce_mean(per_position_loss)
    for extra_loss in extra_losses:
        loss += extra_loss

    # Reshape logits to original target shape.
    logits = mtf.reshape(
        logits,
        mtf.Shape([
            batch_dim,
            self.rows_dim,
            self.orig_cols_dim,
            self.channels_dim,
            self.outputs_vocab_dim,
        ]))

    return logits, loss
def _mtf_model_fn(self, features, mesh):
    """Build the transformer graph on ``mesh`` and compute logits and loss.

    Supports ``hparams.decoder_type`` of "autoregressive" (targets shifted
    right by one) and "denoising" (decoder input produced by
    ``self._noisy_targets``).

    Args:
        features: dict of tf.Tensors. ``"targets"`` is always read;
            ``"inputs"`` is read unless ``hparams.transformer_type`` is
            ``"decoder"``. The optional ``*_segmentation`` / ``*_position``
            entries switch on the packed-dataset masks. The unmodified dict
            is also kept on ``self._original_features``.
        mesh: the mesh the tensors are imported onto.

    Returns:
        A ``(logits, loss)`` pair of mtf tensors; logits are cast to float.

    Raises:
        ValueError: if ``hparams.decoder_type`` is not recognised.
    """
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []

    tf_targets = tf.to_int32(features["targets"])
    if len(tf_targets.get_shape()) > 2:
        tf.logging.info("targets = %s" % tf_targets)
        tf_targets = tf.squeeze(tf_targets, [2, 3])

    def pad_to_max_length(tensor):
        """Right-pad axis 1 to max_length and pin the static shape."""
        pad_amount = hparams.max_length - tf.shape(tensor)[1]
        padded = tf.pad(tensor, [[0, 0], [0, pad_amount]])
        return tf.reshape(padded, [hparams.batch_size, hparams.max_length])

    tf_targets = pad_to_max_length(tf_targets)
    targets = self._import_to_batch_by_length(
        tf_targets, "targets", mesh, hparams)

    optional_keys = (
        "targets_segmentation",
        "targets_position",
        "inputs_segmentation",
        "inputs_position",
    )
    for key in optional_keys:
        if key in features:
            features[key] = pad_to_max_length(features[key])

    autoregressive = hparams.decoder_type == "autoregressive"
    if autoregressive:
        shifted_targets = mtf.shift(
            targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
        shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
        raise ValueError(
            "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = self._import_to_batch_by_length(
            features["targets_segmentation"], "targets_segmentation",
            mesh, hparams)
        targets_position = self._import_to_batch_by_length(
            features["targets_position"], "targets_position",
            mesh, hparams)
        decoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, dtype=self.activation_dtype))
        if autoregressive:
            decoder_self_attention_mask += (
                mtf.layers.attention_mask_autoregressive(
                    targets_position, dtype=self.activation_dtype))
    else:
        targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        if autoregressive:
            decoder_self_attention_mask = (
                mtf.layers.attention_mask_autoregressive(
                    targets_position, dtype=self.activation_dtype))
        else:
            decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(tensor):
        """Dropout whose noise shape covers the batch dims and model dim."""
        return mtf.dropout(
            tensor,
            keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    embedding_vars = self._embedding_and_softmax_vars(mesh)
    (inputs_embedding_var, targets_embedding_var, softmax_var,
     positional_embedding_var) = embedding_vars

    if hparams.transformer_type == "decoder":
        encoder_output = None
        encoder_decoder_attention_mask = None
    else:
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = pad_to_max_length(inputs)
        inputs = self._import_to_batch_by_length(
            inputs, "inputs", mesh, hparams)
        if "inputs_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            inputs_segmentation = self._import_to_batch_by_length(
                features["inputs_segmentation"], "inputs_segmentation",
                mesh, hparams)
            inputs_position = self._import_to_batch_by_length(
                features["inputs_position"], "inputs_position",
                mesh, hparams)
            encoder_self_attention_mask = (
                mtf.layers.attention_mask_same_segment(
                    inputs_segmentation, dtype=self.activation_dtype))
        else:
            inputs_position = mtf.range(
                mesh, self.length_dim, dtype=tf.int32)
            encoder_self_attention_mask = (
                mtf.layers.attention_mask_ignore_padding(
                    inputs, dtype=self.activation_dtype))

        token_emb = mtf.gather(
            inputs_embedding_var, inputs, self.inputs_vocab_dim)
        position_emb = mtf.gather(
            positional_embedding_var, inputs_position, self.max_length_dim)
        x = layer_prepostprocess_dropout(token_emb + position_emb)
        with tf.variable_scope("encoder"):
            x = self._layer_stack(
                x,
                hparams.encoder_layers,
                self_attention_mask=encoder_self_attention_mask,
                losses=extra_losses)

    if hparams.transformer_type == "encdec":
        if "inputs_segmentation" in features:
            encoder_decoder_attention_mask = (
                mtf.layers.attention_mask_same_segment(
                    targets_segmentation, inputs_segmentation,
                    dtype=self.activation_dtype))
        else:
            encoder_decoder_attention_mask = encoder_self_attention_mask
        # The decoder addresses the encoder output along the memory length.
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
        # DECODER
        token_emb = mtf.gather(
            targets_embedding_var, shifted_targets, self.targets_vocab_dim)
        position_emb = mtf.gather(
            positional_embedding_var, targets_position, self.max_length_dim)
        x = layer_prepostprocess_dropout(token_emb + position_emb)
        with tf.variable_scope("decoder"):
            x = self._layer_stack(
                x,
                hparams.decoder_layers,
                encoder_output=encoder_output,
                self_attention_mask=decoder_self_attention_mask,
                encdec_attention_mask=encoder_decoder_attention_mask,
                losses=extra_losses)

    use_reshape_hack = (
        hparams.reshape_logits_hack
        and hparams.mode == tf.estimator.ModeKeys.TRAIN)
    if use_reshape_hack:
        # For some reason, the logits computation is extremely slow on TPU
        # in some cases where the batch size per core is 1.  Reshape the
        # logits and the targets to double the batch size and halve the
        # length.
        # TODO(noam): file a bug.
        last_batch_dim = self.batch_dims[-1]
        old_dims = self.batch_dims + [self.length_dim]
        new_dims = self.batch_dims[:-1] + [
            mtf.Dimension(last_batch_dim.name, last_batch_dim.size * 2),
            mtf.Dimension(self.length_dim.name, self.length_dim.size // 2),
        ]
        x = mtf.reshape(x, new_dims + [self.model_dim])
        targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
        logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)

    # Label smoothing: spread label_smoothing evenly over the vocabulary.
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets,
        self.targets_vocab_dim,
        on_value=on_value,
        off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    # Positions whose target id is zero get zero weight.
    weights = mtf.layers.weights_nonzero(
        targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for extra_loss in extra_losses:
        loss += extra_loss

    if use_reshape_hack:
        # Undo the reshape so callers see the original layout.
        logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
def call(self, context, x, losses=None):
    """Call the layer.

    Computes queries, keys and values from ``x`` (keys/values optionally
    shared, and queries/keys optionally sharing a representation), merges
    them with the cached states in "incremental" mode, records the new
    states in "incremental" and "first_part" modes, and applies
    ``self.attention_fn``.

    Args:
        context: the Context for this call; its ``mode``, ``position``,
            ``length_dim``, ``mesh`` and ``activation_dtype`` are read.
        x: the layer input tensor.
        losses: optional value forwarded unchanged to
            ``self.layer_output_from_attention_output``.

    Returns:
        Whatever ``self.layer_output_from_attention_output`` returns for
        the projected attention output.
    """
    params = self.make_params(context)
    if self.share_qk_rep:
        q, k = params.mdha_shared_qk(x, context)
    else:
        q = params.mdha_q(x, context)

    memory_length = self.memory_length(context)
    incremental = context.mode == "incremental"

    # `memory` is the tensor keys/values are computed from.
    if incremental:
        memory = x
    else:
        if self.share_qk_rep:
            k = mtf.replace_dimensions(k, context.length_dim, memory_length)
        memory = mtf.replace_dimensions(
            x, context.length_dim, memory_length)

    if self.shared_kv:
        kv = params.compute_kv(memory)
    else:
        if not self.share_qk_rep:
            k = params.mdha_k(memory, context)
        v = params.mdha_v(memory, context)

    if incremental:
        # Write the new entries into the cached states at the current
        # position only; every other position keeps its old value.
        position_mask = mtf.one_hot(
            context.position, memory_length, dtype=context.activation_dtype)
        keep_mask = 1.0 - position_mask
        if self.shared_kv:
            # NOTE(review): get_states(2) below is unpacked as a pair,
            # while the result of get_states(1) is used directly as a
            # tensor - confirm what get_states returns for n == 1.
            old_kv = context.get_states(1)
            kv = old_kv * keep_mask + kv * position_mask
        else:
            old_k, old_v = context.get_states(2)
            k = old_k * keep_mask + k * position_mask
            v = old_v * keep_mask + v * position_mask
        memory_position = mtf.range(context.mesh, memory_length, tf.int32)
    else:
        memory_position = self.rename_length_to_memory_length(
            context.position, context)

    if incremental or context.mode == "first_part":
        if self.shared_kv:
            new_states = [kv]
        else:
            new_states = [k, v]
        context.record_new_states(new_states)

    if self.shared_kv:
        k = kv
        v = kv

    attention_bias = self.compute_bias(
        context, memory_position, x, params.query_heads_dims, q)
    extra_kwargs = self.attention_kwargs_from_context(context)
    o = self.attention_fn(
        q,
        k,
        v,
        context=context,
        memory_length_dim=memory_length,
        key_dim=self.kv_dim,
        value_dim=self.kv_dim,
        bias=attention_bias,
        **extra_kwargs)

    attention_output_shape = self.expected_attention_output_shape(x, params)
    attention_output = params.compute_output(
        o, output_shape=attention_output_shape)
    return self.layer_output_from_attention_output(
        context, attention_output, losses)
def mtf_model_fn(self, features, mesh):
    """Build the blocked-convolution image classifier graph on `mesh`.

    The input image is reshaped into row/column blocks, passed through an
    initial 7x7 convolution with batch-norm + ReLU, then through
    `hparams.num_layers` groups of three `block_layer` stacks, and finally
    through two dense layers that produce 10-way class logits.

    Args:
      features: dict holding the "inputs" and "targets" tensors. A shallow
        copy is taken so the caller's dict is left untouched.
      mesh: the mtf.Mesh on which the graph is built.

    Returns:
      logits: mtf.Tensor reshaped to [batch, one_channel, classes].
      loss: the softmax cross-entropy reduced with `mtf.reduce_mean`.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Named dimensions used throughout the graph.
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)
    channels_dim = mtf.Dimension("channels", 3)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    # NOTE(review): rows_dim/cols_dim take the full hparams sizes while the
    # tf.reshape below uses per-block sizes (and "channels" is fixed at 3
    # rather than hparams.num_channels) - confirm these agree for the
    # hparams this model is run with.
    rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
    cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])

    # Split the image into blocks and bring it onto the mesh.
    blocked_shape = [
        hparams.batch_size,
        hparams.row_blocks,
        hparams.rows_size // hparams.row_blocks,
        hparams.col_blocks,
        hparams.num_channels*hparams.cols_size // hparams.col_blocks,
        hparams.num_channels,
    ]
    blocked_inputs = tf.reshape(features["inputs"], blocked_shape)
    net = mtf.import_tf_tensor(
        mesh,
        blocked_inputs,
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim,
            col_blocks_dim, cols_dim, channels_dim,
        ]))
    # Move both block axes in front of the within-block spatial axes.
    net = mtf.transpose(net, [
        batch_dim, row_blocks_dim, col_blocks_dim,
        rows_dim, cols_dim, channels_dim,
    ])
    net = mtf.to_float(net)

    # Initial convolution.
    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
    net = mtf.conv2d_with_blocks(
        net, initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None,
        w_blocks_dim=col_blocks_dim)
    net = batch_norm_relu(net, is_training)

    # Conv blocks: three unit-stride block layers per group, repeated
    # hparams.num_layers times.
    for layer in range(hparams.num_layers):
        with tf.variable_scope("block_layer_%d" % layer):
            for stage in range(3):
                net = block_layer(
                    inputs=net,
                    filters=hparams.filter_sizes[stage],
                    blocks=hparams.layer_sizes[stage],
                    strides=[1, 1, 1, 1],
                    is_training=is_training,
                    name="block_layer%d" % (stage + 1),
                    row_blocks_dim=None,
                    col_blocks_dim=None)

    # Classification head: collapse the last five dims into `hidden`.
    outputs = mtf.layers.dense(
        net, hidden_dim,
        reduced_dims=net.shape.dims[-5:],
        activation=mtf.relu,
        name="dense")

    # We assume fixed vocab size for targets.
    tf_labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh,
        tf.reshape(tf_labels, [hparams.batch_size]),
        mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(per_example_loss)
    return logits, loss
def mnist_model(image, labels, mesh):
    """The model: six 3x3 convolutions followed by three dense layers.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None to
        build the model without a loss
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when `labels` is None
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows_size", 28)
    cols_dim = mtf.Dimension("cols_size", 28)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh,
        tf.reshape(image, [FLAGS.batch_size, 28, 28, 1]),
        mtf.Shape([batch_dim, rows_dim, cols_dim, one_channel_dim]))

    fh_dim = mtf.Dimension("fh", 3)
    fw_dim = mtf.Dimension("fw", 3)

    # Create all kernels first ("kernel1" .. "kernel6"), each mapping the
    # previous layer's filter dimension to a fresh "filtersN" dimension.
    # Every layer needs its own dimension name because a kernel's input and
    # output dimensions must be distinct.
    num_conv_layers = 6
    kernels = []
    in_dim = one_channel_dim
    for index in range(1, num_conv_layers + 1):
        out_dim = mtf.Dimension("filters%d" % index, FLAGS.num_filters)
        kernels.append(
            mtf.get_variable(mesh, "kernel%d" % index,
                             [fh_dim, fw_dim, in_dim, out_dim]))
        in_dim = out_dim

    for kernel in kernels:
        x = mtf.relu(
            mtf.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME"))
    # `in_dim` is now the last conv layer's filter dimension; average it
    # away so only [batch, rows, cols] remains.
    x = mtf.reduce_mean(x, reduced_dim=in_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

    # The first dense layer contracts both spatial dims (rows and cols).
    h1 = mtf.layers.dense(
        x, hidden_dim1,
        reduced_dims=x.shape.dims[-2:],
        activation=mtf.relu, name="hidden1")
    h2 = mtf.layers.dense(
        h1, hidden_dim2,
        activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h2, classes_dim, name="logits")

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh,
            tf.reshape(labels, [FLAGS.batch_size]),
            mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
def _mtf_model_fn(self, features, mesh):
    """Build the transformer graph on `mesh` and return logits and loss.

    Depending on `hparams.transformer_type` this builds only a decoder
    ("decoder"), only an encoder ("encoder"), or an encoder feeding a
    decoder ("encdec"). Targets (and inputs, when an encoder is built) are
    zero-padded to `hparams.max_length` before being imported onto the mesh.

    Args:
      features: dict of tf.Tensors. "targets" is always read; "inputs" is
        read unless `hparams.transformer_type == "decoder"`. The optional
        "*_segmentation" / "*_position" keys mark a packed dataset. The
        caller's dict is stored on `self._original_features`, and a shallow
        copy is modified locally.
      mesh: the mtf.Mesh on which the graph is built.

    Returns:
      logits: the output logits, converted with `mtf.to_float`.
      loss: label-smoothed softmax cross-entropy, weighted by non-zero
        targets and mean-reduced, plus every entry of the collected extra
        losses.

    Raises:
      ValueError: if `hparams.decoder_type` is neither "autoregressive" nor
        "denoising".
    """
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    # NOTE(review): `mode` falls back to TRAIN when hparams has no `mode`
    # attribute, but the code further down reads `hparams.mode` directly -
    # confirm the attribute is always present.
    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    if len(targets.get_shape()) > 2:
        tf.logging.info("targets = %s" % targets)
        targets = tf.squeeze(targets, [2, 3])

    # pad targets to max_length
    def pad_to_max_length(x):
        # Right-pad axis 1 with zeros up to max_length, then pin the static
        # shape to [batch_size, max_length].
        extra_length = hparams.max_length - tf.shape(x)[1]
        x = tf.pad(x, [[0, 0], [0, extra_length]])
        x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
        return x

    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
        if key in features:
            features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
        # Decoder input is the target sequence shifted by one position.
        shifted_targets = mtf.shift(
            targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
        shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
        raise ValueError(
            "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = self._import_to_batch_by_length(
            features["targets_segmentation"], "targets_segmentation",
            mesh, hparams)
        targets_position = self._import_to_batch_by_length(
            features["targets_position"], "targets_position",
            mesh, hparams)
        decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
            targets_segmentation, dtype=self.activation_dtype)
        if hparams.decoder_type == "autoregressive":
            decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)
    else:
        targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        if hparams.decoder_type == "autoregressive":
            decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)
        else:
            decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
        # Dropout with noise shaped [batch dims..., model_dim].
        return mtf.dropout(
            x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
        # Decoder-only model: there is no encoder to attend to.
        encoder_output = None
        encoder_decoder_attention_mask = None
    else:
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = pad_to_max_length(inputs)
        inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
        if "inputs_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            inputs_segmentation = self._import_to_batch_by_length(
                features["inputs_segmentation"], "inputs_segmentation",
                mesh, hparams)
            inputs_position = self._import_to_batch_by_length(
                features["inputs_position"], "inputs_position",
                mesh, hparams)
            encoder_self_attention_mask = (
                mtf.layers.attention_mask_same_segment(
                    inputs_segmentation, dtype=self.activation_dtype))
        else:
            inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            encoder_self_attention_mask = (
                mtf.layers.attention_mask_ignore_padding(
                    inputs, dtype=self.activation_dtype))

        # ENCODER: token embedding + positional embedding.
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.gather(positional_embedding_var, inputs_position,
                        self.max_length_dim))
        x = layer_prepostprocess_dropout(x)
        with tf.variable_scope("encoder"):
            x = self._layer_stack(x,
                                  hparams.encoder_layers,
                                  self_attention_mask=encoder_self_attention_mask,
                                  losses=extra_losses)

    if hparams.transformer_type == "encdec":
        if "inputs_segmentation" in features:
            encoder_decoder_attention_mask = (
                mtf.layers.attention_mask_same_segment(
                    targets_segmentation, inputs_segmentation,
                    dtype=self.activation_dtype))
        else:
            encoder_decoder_attention_mask = encoder_self_attention_mask
        # Rename the encoder output's length dimension to the memory-length
        # name before it is handed to the decoder.
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
        # DECODER
        x = (mtf.gather(
            targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
             mtf.gather(
                 positional_embedding_var, targets_position, self.max_length_dim))
        x = layer_prepostprocess_dropout(x)
        with tf.variable_scope("decoder"):
            x = self._layer_stack(
                x,
                hparams.decoder_layers,
                encoder_output=encoder_output,
                self_attention_mask=decoder_self_attention_mask,
                encdec_attention_mask=encoder_decoder_attention_mask,
                losses=extra_losses)
    if (hparams.reshape_logits_hack and
            hparams.mode == tf.estimator.ModeKeys.TRAIN):
        # For some reason, the logits computation is extremely slow on TPU
        # in some cases where the batch size per core is 1.  Reshape the logits
        # and the targets to double the batch size and halve the length.
        # TODO(noam): file a bug.
        old_dims = self.batch_dims + [self.length_dim]
        new_dims = self.batch_dims[:-1] + [
            mtf.Dimension(self.batch_dims[-1].name,
                          self.batch_dims[-1].size * 2),
            mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
        x = mtf.reshape(x, new_dims + [self.model_dim])
        targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
        logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    # Label smoothing: the true class gets 1 - label_smoothing + off_value,
    # and every other class gets off_value.
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim,
        on_value=on_value, off_value=off_value, dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    # Zero-valued targets (the padding added above) get zero weight.
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
        loss += l
    if (hparams.reshape_logits_hack and
            hparams.mode == tf.estimator.ModeKeys.TRAIN):
        # Undo the reshape hack so callers see the original layout.
        logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
def mtf_model_fn(self, features, mesh):
    """Build the blocked-convolution image classifier graph on `mesh`.

    The image is reshaped into row/column blocks with the colour channels
    folded into the column axis (leaving a single trailing channel), passed
    through an initial 7x7 convolution with batch-norm + ReLU, then through
    `hparams.num_layers` groups of three `block_layer` stacks, and finally
    through two dense layers that produce 10-way class logits.

    Args:
      features: dict holding the "inputs" and "targets" tensors. A shallow
        copy is taken so the caller's dict is left untouched.
      mesh: the mtf.Mesh on which the graph is built.

    Returns:
      logits: mtf.Tensor reshaped to [batch, one_channel, classes].
      loss: the softmax cross-entropy reduced with `mtf.reduce_mean`.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Per-block spatial extents. These feed BOTH the tf.reshape and the mtf
    # dimensions below, so the two can never disagree. They were previously
    # hard-coded to 32 and 96 on the mtf side only, which matched the
    # reshape for just one hparams configuration.
    block_rows = hparams.rows_size // hparams.row_blocks
    block_cols = hparams.num_channels * hparams.cols_size // hparams.col_blocks

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", block_rows)
    cols_dim = mtf.Dimension("cols_size", block_cols)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    x = mtf.import_tf_tensor(
        mesh,
        tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks,
            block_rows,
            hparams.col_blocks,
            block_cols,
            1,
        ]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim,
            col_blocks_dim, cols_dim, one_channel_dim,
        ]))
    # Move both block axes in front of the within-block spatial axes.
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim,
        rows_dim, cols_dim, one_channel_dim,
    ])
    x = mtf.to_float(x)

    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x, initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None,
        w_blocks_dim=col_blocks_dim)
    x = batch_norm_relu(x, is_training)

    # Conv blocks
    # [block layer - strided block layer - strided block layer] x n
    stage_strides = ((1, 1, 1, 1), (1, 2, 2, 1), (1, 2, 2, 1))
    for layer in range(hparams.num_layers):
        layer_name = "block_layer_%d" % layer
        with tf.variable_scope(layer_name):
            # Residual block layers; the 2nd and 3rd use stride 2 on the
            # two middle axes.
            for stage, strides in enumerate(stage_strides):
                x = block_layer(
                    inputs=x,
                    filters=hparams.filter_sizes[stage],
                    blocks=hparams.layer_sizes[stage],
                    # Fresh list per call, as in the hand-unrolled version.
                    strides=list(strides),
                    is_training=is_training,
                    name="block_layer%d" % (stage + 1),
                    row_blocks_dim=None,
                    col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    # Collapse the last five dims into `hidden`.
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # We assume fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh,
        tf.reshape(labels, [hparams.batch_size]),
        mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
def mnist_model(image, labels, mesh):
    """Blocked-convolution MNIST classifier that logs intermediate shapes.

    The 28x28 image is split into a 4x4 grid of 7x7 blocks, run through two
    9x9 blocked convolutions and three dense layers. The name and shape of
    each intermediate tensor are written to the TF log.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None to
        build the model without a loss
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when `labels` is None
    """

    def log_tensor(tensor):
        # One log line per intermediate tensor.
        tf.logging.info(
            "[intra variable] (name, shape): ({},{})".format(
                tensor.name, tensor.shape))

    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    # Import as (batch, row block, row, col block, col, channel) ...
    blocked_image = tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1])
    imported_layout = mtf.Shape([
        batch_dim, row_blocks_dim, rows_dim,
        col_blocks_dim, cols_dim, one_channel_dim,
    ])
    x = mtf.import_tf_tensor(mesh, blocked_image, imported_layout)
    # ... then bring the two block axes together in front of the pixels.
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim,
        rows_dim, cols_dim, one_channel_dim,
    ])
    log_tensor(x)

    # add some convolutional layers to demonstrate that convolution works.
    conv_out = x
    for index in range(2):
        filters_dim = mtf.Dimension("filters%d" % (index + 1), 16)
        conv_out = mtf.relu(
            mtf.layers.conv2d_with_blocks(
                conv_out, filters_dim,
                filter_size=[9, 9], strides=[1, 1], padding="SAME",
                h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim,
                name="conv%d" % index))
        log_tensor(conv_out)
    # `filters_dim` is the last conv layer's filter dimension here.
    x = mtf.reduce_mean(conv_out, reduced_dim=filters_dim)
    log_tensor(x)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

    # The first dense layer contracts all four block/pixel dims.
    h1 = mtf.layers.dense(
        x, hidden_dim1,
        reduced_dims=x.shape.dims[-4:],
        activation=mtf.relu, name="hidden1")
    log_tensor(h1)
    h2 = mtf.layers.dense(
        h1, hidden_dim2,
        activation=mtf.relu, name="hidden2")
    log_tensor(h2)
    logits = mtf.layers.dense(h2, classes_dim, name="logits")

    if labels is None:
        return logits, None

    mtf_labels = mtf.import_tf_tensor(
        mesh,
        tf.reshape(labels, [FLAGS.batch_size]),
        mtf.Shape([batch_dim]))
    per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(per_example_loss)
    return logits, loss