def dense(x, output_dim, reduced_dims=None, expert_dims=None,
          use_bias=True, activation=None, name=None):
    """Apply a learned linear map to `x`, then an optional bias and activation.

    Args:
      x: a mtf.Tensor whose shape contains every dimension in reduced_dims.
      output_dim: a mtf.Dimension appended to the output shape.
      reduced_dims: an optional list of mtf.Dimensions of x to contract away.
        When omitted, the last dimension of x is contracted.
      expert_dims: an optional list of mtf.Dimensions prepended to the kernel
        and bias shapes, so that each expert gets its own weights. When
        omitted, no expert dimensions are used.
      use_bias: a boolean. When true, a zero-initialized "bias" variable is
        added to the product.
      activation: an optional function from mtf.Tensor to mtf.Tensor, applied
        last.
      name: a string, the variable scope. Defaults to "dense".

    Returns:
      a mtf.Tensor whose shape is the dimensions of x that are not in
      reduced_dims, followed by output_dim.
    """
    experts = [] if expert_dims is None else expert_dims
    contracted = x.shape.dims[-1:] if reduced_dims is None else reduced_dims
    surviving = [dim for dim in x.shape.dims if dim not in contracted]

    kernel_shape = mtf.TensorShape(experts + contracted + [output_dim])
    result_shape = mtf.TensorShape(surviving + [output_dim])

    with tf.variable_scope(name, default_name="dense"):
        # Scale the initializer by 1/sqrt(fan_in), where fan_in is the product
        # of the sizes of the contracted dimensions.
        fan_in = mtf.list_product(dim.size for dim in contracted)
        kernel = mtf.get_variable(
            x.mesh,
            "kernel",
            kernel_shape,
            initializer=tf.random_normal_initializer(stddev=fan_in ** -0.5),
            activation_dtype=x.dtype)
        out = mtf.matmul(x, kernel, output_shape=result_shape)
        if use_bias:
            bias_shape = mtf.TensorShape(experts + [output_dim])
            bias = mtf.get_variable(
                x.mesh,
                "bias",
                bias_shape,
                initializer=tf.zeros_initializer(),
                activation_dtype=x.dtype)
            out += bias
        if activation is not None:
            out = activation(out)
        return out
def logits_fn(step_num, ids, states):
    """Run one incremental decoding step and return (logits, new states).

    This is a closure: hparams, self, targets_embedding_var,
    positional_embedding_var, encdec_tensors, encoder_attention_mask and
    softmax_var all come from the enclosing scope.

    Args:
      step_num: the current decoding step. It selects the positional
        embedding, and `step_num - 1` selects the id fed into this step.
      ids: the ids decoded so far, gathered along self.length_dim.
      states: a sequence whose first hparams.num_decoder_layers entries are
        passed on as the self-attention keys and whose remaining entries are
        passed on as the self-attention values.

    Returns:
      a pair (logits, new_states), where new_states is the updated keys joined
      with the updated values using `+` (the same layout as `states`,
      presuming both are lists).
    """
    num_layers = hparams.num_decoder_layers
    prev_keys = states[:num_layers]
    prev_values = states[num_layers:]

    previous_ids = mtf.gather(ids, step_num - 1, self.length_dim)
    token_embedding = mtf.gather(
        targets_embedding_var, previous_ids, self.targets_vocab_dim)
    position_embedding = mtf.gather(
        positional_embedding_var, step_num, self.max_length_dim)
    decoder_input = token_embedding + position_embedding

    with tf.variable_scope("decoder"):
        layer_outputs = self._decoder_layer_stack_incremental(
            decoder_input,
            step_num,
            encdec_tensors,
            prev_keys,
            prev_values,
            encdec_attention_mask=encoder_attention_mask)
        decoder_output, next_keys, next_values = layer_outputs

    logits = mtf.matmul(decoder_output, softmax_var)
    return logits, next_keys + next_values
def _mtf_model_fn(self, features, mesh):
    """Build the encoder/decoder graph on `mesh` and return logits and loss.

    Args:
      features: a dictionary of tf.Tensors. Must contain "targets", and must
        contain "inputs" when self.has_input is true. It may also contain
        "targets_segmentation" / "targets_position" and
        "inputs_segmentation" / "inputs_position" (the "packed" dataset case).
        The dictionary is shallow-copied before padded features are written
        back, so the caller's dictionary is not modified.
      mesh: the mesh handed to the import helper, mtf.range and
        self._embedding_and_softmax_vars (presumably a mtf.Mesh).

    Returns:
      a pair (logits, loss). logits is the decoder output multiplied by the
      softmax variable. loss is the weighted, label-smoothed cross entropy
      against the unshifted targets, plus every entry that the layer stacks
      left in the `losses` list they were given.
    """
    features = copy.copy(features)
    hparams = self._hparams
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
        tf.logging.info("targets = %s", targets)
        targets = tf.squeeze(targets, [2, 3])

    def pad_to_max_length(x):
        """Pad axis 1 of `x` up to hparams.max_length and pin its shape."""
        extra_length = hparams.max_length - tf.shape(x)[1]
        x = tf.pad(x, [[0, 0], [0, extra_length]])
        # The reshape fixes the shape to [batch_size, max_length].
        x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
        return x

    targets = pad_to_max_length(targets)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
        if key in features:
            features[key] = pad_to_max_length(features[key])

    # The decoder consumes the shifted targets; the loss below is computed
    # against the unshifted ones.
    shifted_targets = common_layers.shift_right_2d(targets)

    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    shifted_targets = self._import_to_batch_by_length(
        shifted_targets, "shifted_targets", mesh, hparams)

    if "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = self._import_to_batch_by_length(
            features["targets_segmentation"], "targets_segmentation",
            mesh, hparams)
        targets_position = self._import_to_batch_by_length(
            features["targets_position"], "targets_position",
            mesh, hparams)
        decoder_self_attention_mask = (
            mtf_layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype) +
            mtf_layers.attention_mask_same_segment(
                targets_segmentation, dtype=self.activation_dtype))
    else:
        targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)

    def layer_prepostprocess_dropout(x):
        """Apply dropout with a noise shape of [batch_dim, model_dim]."""
        return mtf.dropout(
            x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

    # Passed to both layer stacks as `losses` and summed into the final loss.
    extra_losses = []
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)

    if self.has_input:
        inputs = tf.to_int32(features["inputs"])
        # Mirror the targets handling: only drop axes 2 and 3 when they can
        # exist. An unknown static rank still squeezes, which keeps the
        # behavior of the previous unconditional squeeze.
        inputs_rank = inputs.get_shape().ndims
        if inputs_rank is None or inputs_rank > 2:
            inputs = tf.squeeze(inputs, [2, 3])
        inputs = pad_to_max_length(inputs)
        inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
        if "inputs_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            inputs_segmentation = self._import_to_batch_by_length(
                features["inputs_segmentation"], "inputs_segmentation",
                mesh, hparams)
            inputs_position = self._import_to_batch_by_length(
                features["inputs_position"], "inputs_position",
                mesh, hparams)
            encoder_self_attention_mask = (
                mtf_layers.attention_mask_same_segment(
                    inputs_segmentation, dtype=self.activation_dtype))
            # NOTE(review): targets_segmentation is only bound when features
            # contains "targets_segmentation"; packed inputs combined with
            # unpacked targets would raise NameError here.
            encoder_decoder_attention_mask = (
                mtf_layers.attention_mask_same_segment(
                    targets_segmentation, inputs_segmentation,
                    dtype=self.activation_dtype))
        else:
            inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            encoder_self_attention_mask = (
                mtf_layers.attention_mask_ignore_padding(
                    inputs, dtype=self.activation_dtype))
            # Without packing, the same mask is reused for encoder-decoder
            # attention.
            encoder_decoder_attention_mask = encoder_self_attention_mask

        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.gather(positional_embedding_var, inputs_position,
                        self.max_length_dim))
        x = layer_prepostprocess_dropout(x)
        with tf.variable_scope("encoder"):
            x = self._layer_stack(
                x,
                hparams.num_encoder_layers,
                self_attention_mask=encoder_self_attention_mask,
                losses=extra_losses)
        # Rename the length dimension so the encoder output uses the memory
        # length dimension instead of the decoder's length dimension.
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)
    else:
        encoder_output = None
        encoder_decoder_attention_mask = None

    # DECODER
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)

    with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.num_decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    logits = mtf.matmul(x, softmax_var)

    # Label smoothing: every class receives off_value and the true class
    # receives an additional (1 - label_smoothing).
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value,
        off_value=off_value, dtype=self.activation_dtype)
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    # Presumably weights are zero where the target id is zero (padding), so
    # those positions contribute nothing to the loss - confirm against
    # mtf_layers.weights_nonzero.
    weights = mtf_layers.weights_nonzero(
        targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for extra_loss in extra_losses:
        loss += extra_loss
    return logits, loss