def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
  """Feed-forward block: projection to hidden, ReLU, projection back.

  Both projection matrices are slices of a single stacked variable named
  "kernel". The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor; its last dimension is used as the input/output channels.
    hidden_channels: a mtf.Dimension - channels in the hidden layer.
    dropout: an optional float; dropout rate applied to the hidden
      activations. A value of 0.0 skips dropout entirely.
    dropout_broadcast_dims: an optional list of mtf.Dimension, subtracted from
      the hidden activations' shape to form the dropout noise shape.
    name: an optional string, used as the variable scope name.

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    channels = x.shape.dims[-1]
    # Initializer scale: (hidden size * channel size) ** -0.25.
    init_stddev = (hidden_channels.size * channels.size) ** -0.25

    # One variable holds both matrices, stacked along a size-2 "io" dimension.
    stack_dim = mtf.Dimension("io", 2)
    kernel_shape = mtf.Shape([stack_dim, channels, hidden_channels])
    kernel = mtf.get_variable(
        x.mesh,
        "kernel",
        kernel_shape,
        initializer=tf.random_normal_initializer(stddev=init_stddev),
        activation_dtype=x.dtype)
    w_in, w_out = mtf.unstack(kernel, stack_dim)

    hidden = mtf.relu(mtf.einsum([x, w_in]))
    if dropout != 0.0:
      keep_prob = 1.0 - dropout
      noise_shape = hidden.shape - dropout_broadcast_dims
      hidden = mtf.dropout(hidden, keep_prob, noise_shape=noise_shape)
    return mtf.einsum([hidden, w_out])
def dot_product_attention(q,
                          k,
                          v,
                          mask,
                          dropout=0.0,
                          dropout_broadcast_dims=None):
  """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k]. Typically leading
      dimensions are [batch, heads].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
      match with q.
    mask: mask Tensor (see attention_mask()), added to the logits before the
      softmax; None skips masking.
    dropout: a float; dropout rate applied to the attention weights. A value
      of 0.0 skips dropout entirely.
    dropout_broadcast_dims: an optional list of mtf.Dimension, subtracted from
      the weights' shape to form the dropout noise shape.

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  length_kv = k.shape.dims[-2]
  # All dims of q except the trailing depth_k; shared by logits and outputs.
  q_leading_dims = q.shape.dims[:-1]

  # Shapes are built with mtf.Shape, the constructor used by the other
  # layers in this file, rather than the mtf.TensorShape spelling.
  logits_shape = mtf.Shape(q_leading_dims + [length_kv])
  logits = mtf.einsum([q, k], logits_shape)
  if mask is not None:
    logits += mask

  # Normalize over the key/value length dimension.
  weights = mtf.softmax(logits, length_kv)
  if dropout != 0.0:
    weights = mtf.dropout(
        weights, 1.0 - dropout,
        noise_shape=weights.shape - dropout_broadcast_dims)

  depth_v = v.shape.dims[-1]
  outputs_shape = mtf.Shape(q_leading_dims + [depth_v])
  return mtf.einsum([weights, v], outputs_shape)
def layer_prepostprocess_dropout(x):
  """Apply the layer pre/post-process dropout to x.

  `hparams` and `self` are not parameters of this function; they are free
  variables resolved outside it (presumably an enclosing method - confirm
  against the surrounding code).

  Args:
    x: the tensor handed to mtf.dropout.

  Returns:
    The result of mtf.dropout on x, with keep probability
    1.0 - hparams.layer_prepostprocess_dropout and a noise shape made of
    self.batch_dim and self.model_dim.
  """
  keep_probability = 1.0 - hparams.layer_prepostprocess_dropout
  dropout_noise_shape = mtf.Shape([self.batch_dim, self.model_dim])
  return mtf.dropout(
      x, keep_prob=keep_probability, noise_shape=dropout_noise_shape)