Example #1
 def attention_internal(self, context, x, m, q, k, v, memory_length, bias):
     p = mtf.einsum([q, k], reduced_dims=[self.key_dim])
     logits = self.talking_heads(
         context,
         p,
         "logits",
         self.key_heads_dims,
         self.softmax_heads_dims,
         dynamic_projections_from=(
             ([x] if "x2l" in self.dynamic_projections else []) +
             ([m] if "m2l" in self.dynamic_projections else [])))
     if bias is not None:
         logits += bias
     h = mtf.softmax(logits, memory_length)
     weights = self.talking_heads(
         context,
         h,
         "weights",
         self.softmax_heads_dims,
         self.value_heads_dims,
         dynamic_projections_from=(
             ([x] if "x2w" in self.dynamic_projections else []) +
             ([m] if "m2w" in self.dynamic_projections else [])))
     # TODO(noam): make dropout_broadcast_dims configurable
     dropout_broadcast_dims = [context.length_dim]
     weights = mtf.dropout(weights,
                           rate=self.dropout_rate if context.train else 0.0,
                           noise_shape=weights.shape -
                           dropout_broadcast_dims)
     u = mtf.einsum([weights, v], reduced_dims=[memory_length])
     return self.compute_y(context, u)
Example #2
    def testOptimizeLayoutRepetition(self):
        x1 = mtf.zeros(self.mesh, "a:10,b:5")
        x2 = mtf.zeros(self.mesh, "b:5,c:20")
        for _ in six.moves.xrange(100):
            mtf.einsum([x1, x2], "a:10,c:20")
        optimizer = self.get_layout_optimizer()

        self.assertGreaterEqual(
            len(list(optimizer._graph.get_all_operation_names())), 50)
        self.assertLessEqual(len(optimizer._model.Proto().variables), 50)

        # Same checks as in testOptimizeLayout.
        layout = optimizer.solve()
        self.assertEqual(layout, "a:m2;c:m1")
        layout_value = optimizer.evaluate_layout(layout)
        self.assertLessEqual(layout_value,
                             optimizer.evaluate_layout("a:m1;b:m2"))
        self.assertLessEqual(layout_value,
                             optimizer.evaluate_layout("a:m1;c:m2"))
        self.assertLessEqual(layout_value,
                             optimizer.evaluate_layout("b:m1;a:m2"))
        self.assertLessEqual(layout_value,
                             optimizer.evaluate_layout("b:m1;c:m2"))
        self.assertLessEqual(layout_value,
                             optimizer.evaluate_layout("c:m1;b:m2"))
        self.assertEqual(layout_value, optimizer.evaluate_layout("c:m1;a:m2"))
Example #3
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """

    # image is a tf.Tensor with shape [batch, 28*28] and dtype tf.float32
    # labels is a tf.Tensor with shape [batch] and dtype tf.int32
    batch_dim = mtf.Dimension("batch", 100)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    hidden_dim = mtf.Dimension("hidden", 1024)
    classes_dim = mtf.Dimension("classes", 10)
    images = mtf.import_tf_tensor(mesh,
                                  tf.reshape(image, [100, 28, 28]),
                                  shape=[batch_dim, rows_dim, cols_dim])
    labels = mtf.import_tf_tensor(mesh, labels, [batch_dim])
    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
    w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])
    # einsum is a generalization of matrix multiplication (see numpy.einsum)
    hidden = mtf.relu(
        mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]))
    logits = mtf.einsum([hidden, w2], output_shape=[batch_dim, classes_dim])
    loss = mtf.reduce_mean(
        mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim))

    return logits, loss
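The comment above notes that einsum generalizes matrix multiplication. As a rough illustration (not part of the example), the same contraction can be written in plain NumPy, where the named dimensions "rows" and "cols" are the ones being reduced:

import numpy as np

images = np.random.rand(100, 28, 28)           # [batch, rows, cols]
w1 = np.random.rand(28, 28, 1024)              # [rows, cols, hidden]
w2 = np.random.rand(1024, 10)                  # [hidden, classes]
hidden = np.maximum(np.einsum("brc,rch->bh", images, w1), 0)  # relu
logits = np.einsum("bh,hc->bc", hidden, w2)    # [batch, classes]
print(logits.shape)                            # (100, 10)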
Example #4
 def testOptimizeLayoutTiebreak(self):
     x1 = mtf.zeros(self.mesh, "a:10,b:5")
     x2 = mtf.zeros(self.mesh, "b:5,c:20")
     mtf.einsum([x1, x2], "a:10,c:20")
     # Rewrite mesh_shape to have a dummy dimension.
     self.mesh_shape = mtf.convert_to_shape("m1:4,m2:2,m3:1")
     optimizer = self.get_layout_optimizer()
     layout = optimizer.solve()
     self.assertEqual(layout, "a:m2;b:m3;c:m1")
Example #5
 def hidden_to_logits(self, hidden):
     hidden *= self._output_dim.size**-0.5
     if self._is_factorized:
         tmp = mtf.einsum([hidden, self._factor2],
                          reduced_dims=[self._output_dim])
         return mtf.einsum([tmp, self._factor1],
                           reduced_dims=[self._inner_dim])
     else:
         return mtf.einsum([hidden, self._embedding_weights],
                           reduced_dims=[self._output_dim])
Example #6
 def attention_internal(self, context, q, m, memory_length, bias):
   logits = mtf.einsum([q, m], reduced_dims=[context.model.model_dim])
   if bias is not None:
     logits += bias
   weights = mtf.softmax(logits, memory_length)
   # TODO(noam): make dropout_broadcast_dims configurable
   dropout_broadcast_dims = [context.length_dim]
   weights = mtf.dropout(
       weights, rate=self.dropout_rate if context.train else 0.0,
       noise_shape=weights.shape - dropout_broadcast_dims)
   u = mtf.einsum([weights, m], reduced_dims=[memory_length])
   return self.compute_y(context, u)
Example #7
def mnist_model(image, labels, mesh, hs_t):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh
    hs_t: a mtf.Tensor with shape [batch, hidden_1]
  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
    hs_t: an updated mtf.Tensor
  """
    input_num = 28
    timesteps_num = 28
    classes_num = 10

    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    input_dim = mtf.Dimension("input", input_num)
    timesteps_dim = mtf.Dimension("timesteps", timesteps_num)
    classes_dim = mtf.Dimension("classes", classes_num)
    hidden_dim_1 = mtf.Dimension("hidden_1", FLAGS.hidden_size)
    hidden_dim_2 = mtf.Dimension("hidden_2", FLAGS.hidden_size)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image,
                                              [FLAGS.batch_size, 28, 28]),
                             [batch_dim, timesteps_dim, input_dim])
    y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                             [batch_dim])
    hs_t = mtf.import_tf_tensor(mesh, hs_t, [batch_dim, hidden_dim_1])

    Wxh = mtf.get_variable(mesh, "Wxh", [input_dim, hidden_dim_2])
    Whh = mtf.get_variable(mesh, "Whh", [hidden_dim_1, hidden_dim_2])
    Why = mtf.get_variable(mesh, "Why", [hidden_dim_2, classes_dim])
    bh = mtf.get_variable(mesh, "bh", [hidden_dim_2])
    by = mtf.get_variable(mesh, "by", [classes_dim])

    x_list = mtf.unstack(x, timesteps_dim)

    for xs_t in x_list:
        hs_t = mtf.tanh(
            mtf.einsum([xs_t, Wxh], [batch_dim, hidden_dim_2]) +
            mtf.einsum([hs_t, Whh], [batch_dim, hidden_dim_2]) + bh)
        logits = mtf.einsum([hs_t, Why], [batch_dim, classes_dim]) + by

    if labels is None:
        loss = None
    else:
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss, hs_t
Example #8
def linear_attention(q, k, v):
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
    q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")

    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.softmax(k, seq_dim)

    context = mtf.einsum([k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out])
    attn = mtf.einsum([q, context], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
    return attn
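As a sanity sketch (shapes and the single-head simplification are assumptions, not part of the example), the reordering above can be written in NumPy: contracting k with v first produces a small [dim_in, dim_out] context per batch element, so the full [seq, seq] attention matrix is never materialized:

import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    return np.exp(x) / np.exp(x).sum(axis=axis, keepdims=True)

b, n, d_in, d_out = 2, 16, 8, 8
q = softmax(np.random.rand(b, n, d_in), axis=-1)   # softmax over features
k = softmax(np.random.rand(b, n, d_in), axis=1)    # softmax over sequence
v = np.random.rand(b, n, d_out)

context = np.einsum("bnd,bne->bde", k, v)          # [b, d_in, d_out]
attn = np.einsum("bnd,bde->bne", q, context)       # [b, n, d_out]
print(attn.shape)                                  # (2, 16, 8)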
Example #9
    def call(self, context, x, losses=None):
        """Call the layer."""
        wq, wk, wv, wo = mtf.layers.multihead_attention_params(
            context.mesh, self.heads_dim, context.model_dim, self.kv_dim,
            context.variable_dtype)
        memory_length = mtf.Dimension("memory_length", context.length_dim.size)
        q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            m = x
        else:
            m = mtf.rename_dimension(x, context.length_dim.name,
                                     "memory_length")
        k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
        v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            old_k, old_v = context.get_states(2)
            one_hot = mtf.one_hot(context.position,
                                  memory_length,
                                  dtype=context.activation_dtype)
            inv_one_hot = 1.0 - one_hot
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot
        if context.mode == "incremental" or context.mode == "first_part":
            context.record_new_states([k, v])
        masks = []
        if context.autoregressive:
            masks.append(
                mtf.cast(
                    mtf.less(
                        context.position,
                        mtf.range(context.mesh, memory_length,
                                  dtype=tf.int32)), context.activation_dtype) *
                -1e9)
        if (context.sequence_id is not None
                and isinstance(context.sequence_id, mtf.Tensor)
                and context.length_dim in context.sequence_id.shape):
            masks.append(
                mtf.cast(
                    mtf.not_equal(
                        context.sequence_id,
                        mtf.layers.rename_length_to_memory_length(
                            context.sequence_id)), context.activation_dtype) *
                -1e9)
        mask = mtf.add_n(masks) if masks else None

        o = mtf.layers.dot_product_attention_v2(
            q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
            self.dropout_rate if context.train else 0.0, [context.length_dim])
        return mtf.einsum([o, wo],
                          x.shape,
                          reduced_dims=[self.heads_dim, self.kv_dim])
Example #10
def attention(q,
              k,
              v,
              memory_length_dim,
              key_dim,
              value_dim,
              bias=None,
              dropout_rate=0.0,
              dropout_broadcast_dims=None,
              extra_logit=None):
    """Dot-product attention - doesn't use positional dimensions.

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
    logits = mtf.einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        logits += bias
    weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
    if dropout_rate != 0.0:
        weights = mtf.dropout(weights,
                              1.0 - dropout_rate,
                              noise_shape=weights.shape -
                              dropout_broadcast_dims)
    outputs_shape = q.shape - key_dim + value_dim
    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
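A hypothetical usage sketch for the attention() function above; the dimension names and sizes here are illustrative assumptions, following the typical shapes described in the docstring:

import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch = mtf.Dimension("batch", 2)
heads = mtf.Dimension("heads", 4)
length = mtf.Dimension("length", 8)
memory_length = mtf.Dimension("memory_length", 8)
key_dim = mtf.Dimension("d_k", 16)
value_dim = mtf.Dimension("d_v", 16)

q = mtf.zeros(mesh, [batch, heads, length, key_dim])
k = mtf.zeros(mesh, [batch, heads, memory_length, key_dim])
v = mtf.zeros(mesh, [batch, heads, memory_length, value_dim])
out = attention(q, k, v, memory_length, key_dim, value_dim)
# out.shape == q.shape - key_dim + value_dim == [batch, heads, length, d_v]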
Example #11
    def testLayout(self):
        # Construct a Mesh TensorFlow graph and mesh.
        mtf_graph = mtf.Graph()
        mesh = mtf.Mesh(mtf_graph, "my_mesh")
        x = mtf.zeros(mesh, "a:10,b:5")
        y = mtf.zeros(mesh, "b:5,c:20")
        z = mtf.einsum([x, y], "a:10,c:20")

        # Decide on a mesh shape.
        mesh_shape = mtf.convert_to_shape("m1:4,m2:2")

        # Compute a layout based on the graph and mesh.
        # Note that knowing the identity of the outputs is important to the
        # optimization since they cannot be freed.
        layout = mtf.auto_mtf.layout(mtf_graph, mesh_shape, [z])

        a_dim = mtf.convert_to_dimension(("a", 10))
        b_dim = mtf.convert_to_dimension(("b", 5))
        c_dim = mtf.convert_to_dimension(("c", 20))

        self.assertEqual(
            layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
        self.assertIsNone(
            layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
        self.assertEqual(
            layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)
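To make the test concrete, here is a hedged sketch (not part of the test) of how such a computed layout is typically handed to a mesh implementation; the PlacementMeshImpl path and the all-CPU device list are assumptions for illustration:

devices = ["/cpu:0"] * mesh_shape.size
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mesh_shape, layout, devices)
lowering = mtf.Lowering(mtf_graph, {mesh: mesh_impl})
tf_z = lowering.export_to_tf_tensor(z)  # a tf.Tensor with shape [10, 20]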
Example #12
 def ids_to_embedding(self, ids):
     if self._is_factorized:
         tmp = mtf.gather(self._factor1, ids, self._vocab_dim)
         return mtf.einsum([tmp, self._factor2],
                           reduced_dims=[self._inner_dim])
     else:
         return mtf.gather(self._embedding_weights, ids, self._vocab_dim)
Example #13
  def testLayoutAndMeshShape(self):
    # Same as previous test, but don't specify a 4x2 mesh.
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, "my_mesh")
    x = mtf.zeros(mesh, "a:10,b:5")
    y = mtf.zeros(mesh, "b:5,c:20")
    z = mtf.einsum([x, y], "a:10,c:20")

    layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(mtf_graph, 8, [z])

    a_dim = mtf.convert_to_dimension(("a", 10))
    b_dim = mtf.convert_to_dimension(("b", 5))
    c_dim = mtf.convert_to_dimension(("c", 20))

    self.assertEqual(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
    self.assertEqual(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)

    self.assertCountEqual(mesh_shape.dims,
                          [mtf.Dimension("mesh_0", 4),
                           mtf.Dimension("mesh_1", 2)])

    layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(
        mtf_graph, 8, [z], 1)

    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape))
    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape))

    self.assertCountEqual(mesh_shape.dims, [mtf.Dimension("mesh_0", 8)])
Example #14
def mnist_model(image, labels, mesh):
    """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image,
                                              [FLAGS.batch_size, 28, 28]),
                             [batch_dim, rows_dim, cols_dim])
    y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                             [batch_dim])

    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, classes_dim])
    b1 = mtf.get_variable(mesh, "b1", [classes_dim])

    logits = mtf.relu(mtf.einsum([x, w1], [batch_dim, classes_dim]) + b1)

    if labels is None:
        loss = None
    else:
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Example #15
    def get_indices(self, keys: mtf.Tensor,
                    query: mtf.Tensor) -> Tuple[mtf.Tensor, mtf.Tensor]:
        """Generate score and indices for the query."""
        score_shape = mtf.Shape(query.shape.dims[:-1] + keys.shape.dims[2:3])
        scores = mtf.einsum([query, keys],
                            output_shape=score_shape)  # [b, l, h, 2, n_keys]
        knn_dim = mtf.Dimension("knn", self.knn)
        scores, indices = mtf.top_k(scores, score_shape.dims[-1],
                                    knn_dim)  # [b, l, h, 2, knn]

        # Computes the top cartesian products and their indices
        knn_square_dim = mtf.Dimension("knn_square_dim", self.knn**2)
        scores1, scores2 = mtf.unstack(scores, scores.shape.dims[-2])
        scores2 = mtf.rename_dimension(scores2, "knn", "knn2")
        out_shape = mtf.Shape(scores1.shape.dims + scores2.shape.dims[-1:])
        all_scores = mtf.add(scores1, scores2, output_shape=out_shape)
        all_scores = mtf.replace_dimensions(all_scores, out_shape[-2:],
                                            knn_square_dim)

        indices1, indices2 = mtf.unstack(indices, indices.shape.dims[-2])
        indices1 = mtf.multiply(indices1, self.n_keys)
        indices2 = mtf.rename_dimension(indices2, "knn", "knn2")
        all_indices = mtf.add(indices1, indices2, output_shape=out_shape)
        all_indices = mtf.replace_dimensions(all_indices, out_shape[-2:],
                                             knn_square_dim)

        scores, best_indices = mtf.top_k(all_scores, all_scores.shape.dims[-1],
                                         knn_dim)
        return scores, mtf.gather(all_indices, best_indices, knn_square_dim)
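A rough NumPy sketch of the product-key trick used above (toy sizes, assumed values): top-knn scores are taken over each key half independently, and only the knn*knn Cartesian sums are re-ranked, instead of searching over all n_keys**2 combined keys:

import numpy as np

knn, n_keys = 2, 4
scores1, indices1 = np.array([3.0, 1.0]), np.array([1, 3])   # first half
scores2, indices2 = np.array([2.0, 0.5]), np.array([0, 2])   # second half

all_scores = scores1[:, None] + scores2[None, :]             # [knn, knn]
all_indices = indices1[:, None] * n_keys + indices2[None, :]

top = np.argsort(all_scores.ravel())[::-1][:knn]             # best knn pairs
best_scores = all_scores.ravel()[top]                        # [5.0, 3.5]
best_indices = all_indices.ravel()[top]                      # [4, 6]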
Example #16
def causal_linear_attention(q, k, v, epsilon=1e-6):
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
    q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")

    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.exp(k)

    cumulative_k = mtf.cumsum(k, seq_dim)
    context = mtf.einsum([k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out])
    cumulative_context = mtf.cumsum(context, seq_dim)

    cumulative_context /= (cumulative_k + epsilon)
    attn = mtf.einsum([q, cumulative_context], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
    return attn
Example #17
 def call(self, context, x, losses=None):
     """Call the layer."""
     params = mtf.layers.multihead_attention_params(context.mesh,
                                                    self.heads_dim,
                                                    context.model_dim,
                                                    self.kv_dim,
                                                    context.variable_dtype)
     if context.mode == "incremental":
         prev_k, prev_v = context.get_states(2)
         y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
             x, prev_k, prev_v, context.position, params=params)
         context.record_new_states([new_k, new_v])
         return y
     else:
         kv = []
         y = mtf.layers.masked_local_attention_1d(x,
                                                  self.kv_dim,
                                                  self.heads_dim,
                                                  self.window_size,
                                                  params=params,
                                                  return_kv=kv)
         if context.mode == "first_part":
             k = kv[0]
             v = kv[1]
             window_dim = mtf.Dimension("window", self.window_size)
             mesh = k.mesh
             window_pos = mtf.range(mesh, window_dim, tf.int32)
             pos = mtf.range(mesh, context.length_dim, tf.int32)
             select_recent = mtf.cast(
                 mtf.equal(window_pos, mtf.mod(pos, self.window_size)),
                 k.dtype)
             select_recent *= mtf.cast(
                 mtf.less(pos, context.initial_position), k.dtype)
             select_recent *= mtf.cast(
                 mtf.greater_equal(
                     pos, context.initial_position - self.window_size),
                 k.dtype)
             state_shape = k.shape.dims[:-2] + [window_dim, self.kv_dim]
             k_state = mtf.einsum([k, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             v_state = mtf.einsum([v, select_recent],
                                  output_shape=state_shape,
                                  reduced_dims=[context.length_dim])
             context.new_states.extend([k_state, v_state])
         return y
Example #18
  def testOptimizeLayout(self):
    x1 = mtf.zeros(self.mesh, "a:10,b:5")
    x2 = mtf.zeros(self.mesh, "b:5,c:20")
    mtf.einsum([x1, x2], "a:10,c:20")
    optimizer = self.get_layout_optimizer()

    # Cut dimensions to make them equally sized.
    layout = optimizer.solve()
    self.assertEqual(layout, "a:m2;c:m1")

    # This optimal layout should have the lowest value.
    layout_value = optimizer.evaluate_layout(layout)
    self.assertLessEqual(layout_value, optimizer.evaluate_layout("a:m1;b:m2"))
    self.assertLessEqual(layout_value, optimizer.evaluate_layout("a:m1;c:m2"))
    self.assertLessEqual(layout_value, optimizer.evaluate_layout("b:m1;a:m2"))
    self.assertLessEqual(layout_value, optimizer.evaluate_layout("b:m1;c:m2"))
    self.assertLessEqual(layout_value, optimizer.evaluate_layout("c:m1;b:m2"))
    self.assertEqual(layout_value, optimizer.evaluate_layout("c:m1;a:m2"))
Example #19
 def talking_heads(
     self, context, inp, name, input_heads_dims, output_heads_dims,
     dynamic_projections_from=None):
   shared_dims = [d for d in input_heads_dims if d in output_heads_dims]
   reduced_dims = [d for d in input_heads_dims if d not in output_heads_dims]
   new_dims = [d for d in output_heads_dims if d not in input_heads_dims]
   if not (reduced_dims or new_dims):
     # Output dimensions are same as input dimensions.  Return the input
     return inp
   elif dynamic_projections_from:
     # There are one or more dynamic talking-heads-projections
     with tf.variable_scope(name):
       # static projection - this is the same as the static projection in the
       # "else" case below.  We create the weight matrix with get_variable
       # instead of calling mtf.layers.dense() so that we can fold the
       # static projection into one of the dynamic projections.
       static_p_initializer = mtf.layers.VarianceScalingInitializer()(
           reduced_dims, new_dims)
       static_p_shape = (
           context.model.ensemble_dims + shared_dims + reduced_dims + new_dims)
       static_p = mtf.get_variable(inp.mesh,
                                   "kernel",
                                   static_p_shape,
                                   initializer=static_p_initializer,
                                   dtype=context.variable_dtype)
       ps = []
       for i, dp_from in enumerate(dynamic_projections_from):
         kernel_initializer = mtf.layers.VarianceScalingInitializer(
             self.dynamic_projections_init_scale
             / mtf.Shape(reduced_dims).size)
         ps.append(
             mtf.layers.dense(
                 dp_from, reduced_dims=[context.model.model_dim],
                 new_dims=shared_dims + reduced_dims + new_dims,
                 use_bias=False, activation=None,
                 variable_dtype=context.variable_dtype,
                 name="%s_dynamic_%d" % (name, i),
                 expert_dims=context.model.ensemble_dims,
                 kernel_initializer=kernel_initializer))
        # Fold the static projection into one of the dynamic projections.
       # Mathematically, we could add all the dynamic projections together
       #   here, but it would create a very large tensor which contained
       #   both the query-length and memory-length dimensions, and would
       #   probably be slower in practice.
       ps[0] += static_p
       return mtf.add_n(
           [mtf.einsum([inp, p], reduced_dims=reduced_dims) for p in ps])
   else:
     # No dynamic projections.  Static talking-heads projection only
     return mtf.layers.dense(
         inp, reduced_dims=reduced_dims,
         new_dims=new_dims,
         use_bias=False, activation=None,
         variable_dtype=context.variable_dtype,
         name=name, expert_dims=context.model.ensemble_dims + shared_dims)
Example #20
    def compute_loss(self, decoder: transformer.Unitransformer,
                     hidden: mtf.Tensor, targets: mtf.Tensor,
                     context: transformer.Context) -> mtf.Tensor:
        """Returns the loss without computing a softmax over the entire vocab."""
        loss = 0
        tail_cluster_masks = []
        for cluster in self._tail_clusters:
            cluster_mask = cluster.get_cluster_mask(targets)
            tail_cluster_masks.append(cluster_mask)

            if cluster.length_projection_factor == 1:
                targets_in_cluster = mtf.where(cluster_mask, targets, 0)
                hidden_in_cluster = mtf.where(cluster_mask, hidden, 0)
            else:
                # TODO(mmatena): Unfold the batch dim to get a super long sequence dim
                # to reduce the risk of overflowing the projection.
                proj_to_cluster_len = cluster.get_project_to_cluster_length(
                    cluster_mask, dtype=targets.dtype)
                targets_in_cluster = mtf.einsum(
                    [proj_to_cluster_len, targets],
                    reduced_dims=[targets.shape.get_dim_by_name("length")])
                hidden_in_cluster = mtf.einsum(
                    [mtf.cast(proj_to_cluster_len, hidden.dtype), hidden],
                    reduced_dims=[hidden.shape.get_dim_by_name("length")])

            loss += cluster.compute_loss(decoder, hidden_in_cluster,
                                         targets_in_cluster, context)

        tail_clusters_dim = mtf.Dimension("tail_clusters",
                                          len(tail_cluster_masks))
        tail_node_targets = mtf.reduce_sum(
            mtf.stack([(self._head_cluster.end_token_id + i) *
                       mtf.cast(mask, targets.dtype)
                       for i, mask in enumerate(tail_cluster_masks)],
                      tail_clusters_dim.name),
            reduced_dim=tail_clusters_dim)
        head_targets = mtf.where(mtf.cast(tail_node_targets, tf.bool),
                                 tail_node_targets, targets)
        loss += self._head_cluster.compute_loss(decoder, hidden, head_targets,
                                                context)

        return loss
Example #21
 def call(self, context, x, losses=None):
     """Call the layer."""
     memory_input_dim = context.encoder_output.shape[-1]
     if memory_input_dim != context.model_dim:
         raise NotImplementedError(
             "TODO(noam): support different model_dim in encoder and decoder."
         )
     wq, wk, wv, wo = mtf.layers.multihead_attention_params(
         context.mesh, self.heads_dim, context.model_dim, self.kv_dim,
         context.variable_dtype)
     q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
     if context.mode == "incremental":
         k, v, memory_length = context.get_constant_state()
     else:
         m = context.encoder_output
         memory_length, = [
             d for d in m.shape.dims if d.name == "memory_length"
         ]
         k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
         v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
         if context.mode == "first_part":
             context.record_constant_state((k, v, memory_length))
     if context.encoder_sequence_id and context.sequence_id:
         mask = mtf.cast(
             mtf.not_equal(context.sequence_id,
                           context.encoder_sequence_id),
             context.activation_dtype) * -1e9
     else:
         mask = None
     o = mtf.layers.dot_product_attention_v2(
         q,
         k,
         v,
         memory_length,
         self.kv_dim,
         self.kv_dim,
         mask,
         dropout=self.dropout_rate if context.train else 0.0,
         dropout_broadcast_dims=[context.length_dim])
     return mtf.einsum([o, wo],
                       x.shape,
                       reduced_dims=[self.heads_dim, self.kv_dim])
Example #22
    def hidden_to_logits(self, hidden: mtf.Tensor,
                         context: transformer.Context) -> mtf.Tensor:
        """Function called by mtf transformer to get the logits.

    Note that we are taking the log of a mixture of softmaxes. The logits will
    then go through a softmax. This could potentially run into numerical
    stability issues. If that happens, try setting the activation_dtype to
    float32.

    Args:
      hidden: hidden model states of the final decoder layer.
      context: the context used for the call to the
        transformer.

    Returns:
      The logits.
    """
        del context
        hidden *= self._output_dim.size**-0.5

        component_prior_logits = mtf.einsum([hidden, self._mixture_weights],
                                            reduced_dims=[self._output_dim])

        component_contexts = mtf.einsum(
            [
                mtf.rename_dimension(hidden, self._output_dim.name,
                                     self._copy_output_dim.name),
                self._context_weights,
            ],
            reduced_dims=[self._copy_output_dim])
        component_contexts = mtf.tanh(component_contexts)
        component_logits = mtf.einsum(
            [component_contexts, self._embedding_weights],
            reduced_dims=[self._output_dim])

        component_prior_logits = mtf.log_softmax(
            component_prior_logits, reduced_dim=self._components_dim)
        component_logits = mtf.log_softmax(component_logits,
                                           reduced_dim=self._vocab_dim)

        logits = component_prior_logits + component_logits
        logits = mtf.reduce_logsumexp(logits, reduced_dim=self._components_dim)
        return logits
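A tiny NumPy check (toy shapes, assumed values) of the identity relied on above: adding the log prior to each component's log-softmax and taking logsumexp over the components gives the log of the mixture probability:

import numpy as np

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def logsumexp(x, axis):
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

components, vocab = 3, 5
prior_logits = np.random.rand(components)
component_logits = np.random.rand(components, vocab)

log_prior = log_softmax(prior_logits)                    # [components]
log_probs = log_softmax(component_logits, axis=-1)       # [components, vocab]
mixture_log = logsumexp(log_prior[:, None] + log_probs, axis=0)
mixture = (np.exp(log_prior)[:, None] * np.exp(log_probs)).sum(0)
assert np.allclose(np.exp(mixture_log), mixture)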
Example #23
  def get_log_softmax_prefix(self, log_softmax, end_index):
    """Returns first end_index entries in log_softmax along the vocab dim."""
    prefix_dim = mtf.Dimension(self._vocab_dim.name, end_index)

    indices = mtf.mtf_range(
        log_softmax.mesh, dim=self._vocab_dim, dtype=tf.int32)
    prefix_indices = mtf.where(mtf.less(indices, end_index), indices, -1)
    projection = mtf.one_hot(
        prefix_indices, prefix_dim, dtype=log_softmax.dtype)

    return mtf.einsum([log_softmax, projection], reduced_dims=[self._vocab_dim])
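A small NumPy sketch (assumed toy sizes) of the projection trick above: positions outside the prefix are mapped to -1, which one-hots to an all-zero row (the behavior the code relies on), so the einsum with the projection matrix is equivalent to slicing the first end_index vocab entries:

import numpy as np

vocab, end_index = 6, 3
log_softmax = np.random.rand(2, vocab)                     # [batch, vocab]
indices = np.arange(vocab)
prefix_indices = np.where(indices < end_index, indices, -1)

# one_hot(-1) is all zeros, so out-of-prefix rows contribute nothing
projection = np.zeros((vocab, end_index))
valid = prefix_indices >= 0
projection[indices[valid], prefix_indices[valid]] = 1.0

prefix = np.einsum("bv,vp->bp", log_softmax, projection)   # [batch, end_index]
assert np.allclose(prefix, log_softmax[:, :end_index])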
Example #24
def wide(x, mask, float16=None):
    x = mtf.einsum([x, mask],
                   output_shape=[x.shape.dims[0], x.shape.dims[-1]],
                   name='wide_mul')
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    if float16:
        wide_b = np.array(0, dtype=np.float16)
    else:
        wide_b = np.array(0, dtype=np.float32)

    x = mtf.add(x, wide_b, name="wide_sum")
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    return x
Example #25
def deep(x, mask, float16=None):
    x = mtf.einsum([x, mask], output_shape=x.shape.dims, name='deep_mul')
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))

    # Following MindSpore's approach, compute the dense layers below in fp16.
    x = mtf.cast(x, dtype=tf.float16)

    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim0', size=1024),
                         name="deep-dense-0",
                         reduced_dims=x.shape.dims[-2:],
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim1', size=512),
                         name="deep-dense-1",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim2', size=256),
                         name="deep-dense-2",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim3', size=128),
                         name="deep-dense-3",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    x = mtf.layers.dense(x,
                         mtf.Dimension(name='dense_dim4', size=1),
                         name="deep-dense-4",
                         variable_dtype=mtf.VariableDType(
                             tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))
    if float16:
        pass
    else:
        x = mtf.cast(x, dtype=tf.float32)
    return x
Example #26
def lpt_init_single(lr_field, a0, kvec_lr, halo_size, lr_shape, hr_shape, part_shape, antialias=True, order=1, post_filtering=True, cosmology=Planck15):
  a = a0
  batch_dim = lr_field.shape[0]
  lnc = lr_shape[-1].size

  # Create particles on the high resolution grid
  mstate = mesh_ops.mtf_indices(lr_field.mesh, shape=part_shape, dtype=tf.float32)
  X = mtf.einsum([mtf.ones(lr_field.mesh, [batch_dim]), mstate], output_shape=[batch_dim] + mstate.shape[:])


  k_dims_lr = [d.shape[0] for d in kvec_lr]
  k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]

  lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)

  grad_kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)

  # Reorder the low-res FFTs, which were transposed (to y, z, x order)
  grad_kfield_lr = [grad_kfield_lr[2], grad_kfield_lr[0], grad_kfield_lr[1]]


  displacement = []
  for f in grad_kfield_lr:
    f = mesh_utils.c2r3d(f, lr_shape[-3:])
    f = mtf.slicewise(lambda x:tf.expand_dims(tf.expand_dims(tf.expand_dims(x, axis=1),axis=1),axis=1),
                      [f],
                      output_dtype=tf.float32,
                      output_shape=mtf.Shape(hr_shape[0:4]+[
                        mtf.Dimension('sx_block', lnc//hr_shape[1].size),
                        mtf.Dimension('sy_block', lnc//hr_shape[2].size),
                        mtf.Dimension('sz_block', lnc//hr_shape[3].size)]),
                      name='my_reshape',
                      splittable_dims=lr_shape[:-1]+hr_shape[1:4]+part_shape[1:3])

    for block_size_dim in hr_shape[-3:]:
      f = mtf.pad(f, [halo_size, halo_size], block_size_dim.name)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
      f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, halo_size)
    d =  mesh_utils.cic_readout(f, X, halo_size)
    displacement.append(d)
  # Readout to particle positions
  displacement = mtf.stack(displacement, "ndim", axis=4)

  pt = PerturbationGrowth(cosmology, a=[a], a_normalize=1.0)
  DX = pt.D1(a) * displacement
  P = (a ** 2 * pt.f1(a) * pt.E(a)) * DX
  F = (a ** 2 * pt.E(a) * pt.gf(a) / pt.D1(a)) * DX
  # TODO: Implement 2nd order LPT

  # Moves the particles according to displacement
  X = X + DX

  return X, P, F
Example #27
    def _get_decoder_inputs(self, context):
        """Computes the inputs to the decoder when using transparent attention.

    We must cache on the context in order to ensure that we are not replicating
    variables when the layer's call function is called in different tf variable
    scopes.

    Args:
      context: a Context

    Returns:
      a list containing `self.num_decoder_modules` of tensors with shape
        [<batch_dims>, length_dim, output_vocab_dim]
    """
        if hasattr(context, "decoder_layers_per_module"):
            return context.decoder_layers_per_module

        encoder_layer_outputs = [
            mtf.layers.rename_length_to_memory_length(output)
            for output in context.encoder_layer_outputs
        ]

        layers_per_module = self.layers_per_encoder_module
        encoder_module_outputs_dim = mtf.Dimension(
            "encoder_module_outputs", size=self.encoder_num_modules + 1)
        decoder_module_inputs_dim = mtf.Dimension(
            "decoder_module_inputs", size=self.decoder_num_modules)
        encoder_module_outputs = mtf.stack(
            [encoder_layer_outputs[0]] +
            encoder_layer_outputs[layers_per_module::layers_per_module],
            dim_name="encoder_module_outputs")
        w = mtf.get_variable(
            context.mesh,
            "w",
            mtf.Shape([encoder_module_outputs_dim, decoder_module_inputs_dim]),
            initializer=tf.random_normal_initializer(
                stddev=(encoder_module_outputs_dim.size *
                        decoder_module_inputs_dim.size)**-0.5),
            dtype=context.variable_dtype)
        if context.train and self.dropout_rate != 0.0:
            w = mtf.dropout(w, 1.0 - self.dropout_rate)
        s = mtf.softmax(w, reduced_dim=encoder_module_outputs_dim)
        z = mtf.einsum([s, encoder_module_outputs],
                       reduced_dims=[encoder_module_outputs_dim])
        input_per_decoder = mtf.split(
            z,
            split_dim=decoder_module_inputs_dim,
            num_or_size_splits=decoder_module_inputs_dim.size)
        context.decoder_layers_per_module = [
            mtf.reshape(inpt, z.shape.dims[1:]) for inpt in input_per_decoder
        ]
        return context.decoder_layers_per_module
Example #28
    def compute_q(self, query_antecedent):
        """Compute query Tensor q.

    Args:
      query_antecedent: a Tensor with dimensions
         {query_input_dim} + other_dims
    Returns:
      a Tensor with dimensions
         query_heads_dims + {key_dim} + other_dims
    """
        ret = mtf.einsum([query_antecedent, self.wq],
                         reduced_dims=[self.query_input_dim])
        if self.combine_dims:
            ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.q_dims)
        return ret
Example #29
    def compute_v(self, memory_antecedent):
        """Compute value Tensor v.

    Args:
      memory_antecedent: a Tensor with dimensions
        {memory_input_dim} + other_dims
    Returns:
      a Tensor with dimensions
        memory_heads_dims + {value_dim} + other_dims
    """
        if self.shared_kv:
            raise ValueError("compute_v cannot be called with shared_kv")
        ret = mtf.einsum([memory_antecedent, self.wv],
                         reduced_dims=[self.memory_input_dim])
        if self.combine_dims:
            ret = mtf.replace_dimensions(ret, ret.shape.dims[-1], self.v_dims)
        return ret
Example #30
    def call(self, context, x: mtf.Tensor) -> mtf.Tensor:
        """Call the layer."""
        # Initialize Memory Keys and Values
        n_key_dim = mtf.Dimension("n_keys", self.n_keys)
        n_value_dim = mtf.Dimension("n_values", self.n_values)
        key_dim = mtf.Dimension("key", self.key_size // 2)
        value_dim = x.shape.dims[-1]
        head_dim = mtf.Dimension("n_heads", self.n_heads)
        product_dim = mtf.Dimension("product_key", 2)
        keys = mtf.get_variable(
            context.mesh,
            name="keys",
            shape=mtf.Shape([head_dim, product_dim, n_key_dim, key_dim]),
            dtype=context.variable_dtype)
        values = mtf.layers.embedding_weights(
            context.mesh,
            vocab_dim=n_value_dim,
            output_dim=value_dim,
            variable_dtype=context.variable_dtype,
            name="values")

        # Compute query
        new_dims = [head_dim, product_dim, key_dim]
        reduce_dims = x.shape.dims[-1:]
        query = mtf.layers.dense(x,
                                 new_dims,
                                 reduced_dims=reduce_dims,
                                 activation=None,
                                 use_bias=True,
                                 variable_dtype=context.variable_dtype,
                                 name="query")  # [b, l, h, 2, k]

        # Note: We use layer norm instead of batch norm to normalize queries.
        # The main advantage is that layer norm works well with the codebase
        # whereas the implementation of batch norm requires handling of tf ops.
        query = mtf.layers.layer_norm(query, query.shape.dims[-1])

        # Retrieve indices and scores
        scores, indices = self.get_indices(keys, query)  # [b, l, h, k]
        scores = mtf.softmax(scores, reduced_dim=scores.shape.dims[-1])
        top_values = mtf.gather(values, indices,
                                n_value_dim)  # [b, l, h, k, v]
        out_values = mtf.einsum(
            [top_values, scores],
            reduced_dims=scores.shape.dims[-2:])  # [b, l, v]
        return out_values
Example #31
  def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
      inputs = features["inputs"]
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      encdec_tensors = []
      for layer_num, layer_type in enumerate(hparams.decoder_layers):
        if layer_type == "enc_att":
          with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
            q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                mesh, self.heads_dim, self.model_dim,
                self.kv_dim, self.master_dtype, self.slice_dtype,
                self.activation_dtype)
            k = mtf.einsum(
                [encoder_output, k_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
            v = mtf.einsum(
                [encoder_output, v_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
          encdec_tensors.append((q_var, o_var, k, v))
        else:
          encdec_tensors.append(None)
      partial_targets = None
    elif hparams.transformer_type == "decoder":
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)
    else:
      raise ValueError(
          "hparams.model_type = %s not yet supported"
          % hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    if hparams.beam_size == 1:
      ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [self.heads_dim,
                                  local_attention_window, self.kv_dim])
    else:
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [beam_dim, self.heads_dim,
                                  local_attention_window, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_states = []
    for layer in hparams.decoder_layers:
      if layer == "att":
        initial_states.extend(
            [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
      elif layer == "local_att":
        initial_states.extend(
            [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_states = self._layer_stack(
            x,
            hparams.decoder_layers,
            encdec_attention_mask=encoder_attention_mask,
            step_num=step_num,
            encdec_tensors=encdec_tensors,
            states=states)
      logits = mtf.matmul(x, softmax_var)
      return logits, new_states

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf.beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if hparams.transformer_type == "encdec":
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf.beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu,
          dtype=self.activation_dtype)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)