Example #1
def dense_relu_dense(x,
                     hidden_channels,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    stddev = (hidden_channels.size * io_channels.size) ** -0.25
    io = mtf.Dimension("io", 2)
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    wi, wo = mtf.unstack(w, io)
    h = mtf.relu(mtf.einsum([x, wi]))
    if dropout != 0.0:
      h = mtf.dropout(h, 1.0 - dropout,
                      noise_shape=h.shape - dropout_broadcast_dims)
    return mtf.einsum([h, wo])
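
A minimal usage sketch for the function above; the mesh and dimension names (my_mesh, batch, length, model, hidden) are hypothetical stand-ins, and mtf.zeros is used only to get a placeholder input:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf  # the snippet uses TF1-style variable scopes

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

batch = mtf.Dimension("batch", 8)
length = mtf.Dimension("length", 128)
model = mtf.Dimension("model", 512)     # becomes io_channels (last dim of x)
hidden = mtf.Dimension("hidden", 2048)  # hidden_channels

x = mtf.zeros(mesh, mtf.Shape([batch, length, model]))
# Output has the same shape as x; listing `length` in
# dropout_broadcast_dims shares one dropout mask across all positions.
y = dense_relu_dense(x, hidden, dropout=0.1, dropout_broadcast_dims=[length])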
Example #2
def dot_product_attention(q,
                          k,
                          v,
                          mask,
                          dropout=0.0,
                          dropout_broadcast_dims=None):
    """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k]. Typically leading dimensions
      are [batch, heads].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
      match with q.
    mask: mask Tensor (see attention_mask())
    dropout: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  length_kv = k.shape.dims[-2]
  logits_shape = mtf.Shape(q.shape.dims[:-1] + [length_kv])
  logits = mtf.einsum([q, k], logits_shape)
  if mask is not None:
    logits += mask
  weights = mtf.softmax(logits, length_kv)
  if dropout != 0.0:
    weights = mtf.dropout(weights, 1.0 - dropout,
                          noise_shape=weights.shape - dropout_broadcast_dims)
  depth_v = v.shape.dims[-1]
  outputs_shape = mtf.Shape(q.shape.dims[:-1] + [depth_v])
  outputs = mtf.einsum([weights, v], outputs_shape)
  return outputs
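
A rough usage sketch, assuming hypothetical mesh and dimension names; passing mask=None skips masking, and broadcasting dropout over length_q reuses one mask per query position:

import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

batch = mtf.Dimension("batch", 2)
heads = mtf.Dimension("heads", 4)
length_q = mtf.Dimension("length_q", 16)
length_kv = mtf.Dimension("length_kv", 16)
depth_k = mtf.Dimension("depth_k", 64)
depth_v = mtf.Dimension("depth_v", 64)

q = mtf.zeros(mesh, mtf.Shape([batch, heads, length_q, depth_k]))
k = mtf.zeros(mesh, mtf.Shape([batch, heads, length_kv, depth_k]))
v = mtf.zeros(mesh, mtf.Shape([batch, heads, length_kv, depth_v]))

out = dot_product_attention(q, k, v, mask=None,
                            dropout=0.1, dropout_broadcast_dims=[length_q])
# out has shape [batch, heads, length_q, depth_v]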
Example #3
def layer_prepostprocess_dropout(x):
  # `hparams` and `self` are captured from the enclosing model scope.
  return mtf.dropout(
      x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
      noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))
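
A standalone sketch of the same call, with hypothetical stand-ins for the names the closure captures (batch_dim, model_dim, and the hparams dropout rate):

import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

batch_dim = mtf.Dimension("batch", 8)
length_dim = mtf.Dimension("length", 128)
model_dim = mtf.Dimension("d_model", 512)
dropout_rate = 0.1  # stand-in for hparams.layer_prepostprocess_dropout

x = mtf.zeros(mesh, mtf.Shape([batch_dim, length_dim, model_dim]))
# Because noise_shape omits length_dim, one dropout mask is sampled per
# (batch, model) pair and broadcast along the whole sequence.
y = mtf.dropout(x, keep_prob=1.0 - dropout_rate,
                noise_shape=mtf.Shape([batch_dim, model_dim]))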