Example 1
  def test_unspecified_output_dim_fails(self):
    input_tensor = keras.Input(shape=(32,))
    layer = einsum_dense.EinsumDense(equation="ab,bc->cd", output_shape=64)
    with self.assertRaisesRegexp(
        ValueError, ".*Dimension 'd' was specified in the output 'cd' but has "
        "no corresponding dim.*"):
      _ = layer(input_tensor)
Example 2
  def test_unspecified_bias_dim_fails(self):
    input_tensor = keras.Input(shape=(32,))
    layer = einsum_dense.EinsumDense(
        equation="ab,bc->ac", output_shape=64, bias_axes="y")
    with self.assertRaisesRegexp(
        ValueError, ".*is not a part of the output specification.*"):
      _ = layer(input_tensor)
Example 3
    def _make_output_dense(self, free_dims, common_kwargs, name=None):
        """Builds the output projection matrix.

        Args:
            free_dims: Number of free dimensions for einsum equation building.
            common_kwargs: Common keyword arguments for einsum layer.
            name: The name for the projection layer.

        Returns:
            Projection layer.
        """
        if self._output_shape:
            if not isinstance(self._output_shape, collections.abc.Sized):
                output_shape = [self._output_shape]
            else:
                output_shape = self._output_shape
        else:
            output_shape = [self._query_shape[-1]]
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims, bound_dims=2, output_dims=len(output_shape))
        return einsum_dense.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(output_rank - 1, output_shape),
            bias_axes=bias_axes if self._use_bias else None,
            name=name,
            **common_kwargs)
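The _build_proj_equation and _get_output_shape helpers used above are not part of this snippet. Below is a minimal sketch of what they plausibly do, inferred from how their results are consumed here and in Example 7 below; the exact implementation is an assumption.

import string

_CHR_IDX = string.ascii_lowercase


def _build_proj_equation(free_dims, bound_dims, output_dims):
    # Build an einsum equation such as "abc,cde->abde": free input dims are
    # copied to the output, bound dims are contracted with the kernel, and the
    # new output dims (which also carry the bias) are appended at the end.
    input_str, kernel_str, output_str, bias_axes = "", "", "", ""
    offset = 0
    for i in range(free_dims):
        char = _CHR_IDX[i + offset]
        input_str += char
        output_str += char
    offset += free_dims
    for i in range(bound_dims):
        char = _CHR_IDX[i + offset]
        input_str += char
        kernel_str += char
    offset += bound_dims
    for i in range(output_dims):
        char = _CHR_IDX[i + offset]
        kernel_str += char
        output_str += char
        bias_axes += char
    equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
    return equation, bias_axes, len(output_str)


def _get_output_shape(output_rank, known_last_dims):
    # Leading free dimensions stay unknown (None); only the trailing
    # projection dimensions are fixed.
    return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)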
Example 4
  def test_incompatible_input_output_shape_fails(self):
    input_tensor = keras.Input(shape=(32, 64))
    layer = einsum_dense.EinsumDense(
        equation="abc,cd->abd", output_shape=(10, 96))
    with self.assertRaisesRegexp(
        ValueError, ".*Input shape and output shape do not match at shared "
        "dimension 'b'.*"):
      _ = layer(input_tensor)
Example 5
def einsum_multihead_attention(i,
                               q,
                               k,
                               v,
                               h,
                               n_a,
                               reg,
                               dropout,
                               seqlen,
                               mask=None):
    """Multi-head attention block built from EinsumDense projections.

    Args:
        i: Index used to name the projection layers.
        q, k, v: Query, key and value tensors of shape [batch, seq, features].
        h: Number of attention heads; each head has n_a // h units.
        n_a: Total attention (model) dimension of the output projection.
        reg: Kernel regularizer applied to every projection.
        dropout: Dropout rate passed to the attention computation.
        seqlen: Sequence length (unused in this snippet).
        mask: Optional attention mask.
    """
    dim = n_a // h
    Wq = einsum_dense.EinsumDense("abc,cde->abde",
                                  kernel_regularizer=reg,
                                  output_shape=[None, h, dim],
                                  bias_axes="de",
                                  name="dense_q_%d" % i)
    Wk = einsum_dense.EinsumDense("abc,cde->abde",
                                  kernel_regularizer=reg,
                                  output_shape=[None, h, dim],
                                  bias_axes="de",
                                  name="dense_k_%d" % i)
    Wv = einsum_dense.EinsumDense("abc,cde->abde",
                                  kernel_regularizer=reg,
                                  output_shape=[None, h, dim],
                                  bias_axes="de",
                                  name="dense_v_%d" % i)
    Wo = einsum_dense.EinsumDense("abcd,cde->abe",
                                  kernel_regularizer=reg,
                                  output_shape=[None, n_a],
                                  bias_axes="e",
                                  name="dense_o_%d" % i)

    Q = Wq(q)
    K = Wk(k)
    V = Wv(v)

    C, attn_factor = einsum_attn(i, Q, K, V, dropout, dim, mask)

    return Wo(C)
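The einsum_attn helper called above is not shown in this example. Here is a minimal sketch of what it could look like, assuming standard scaled dot-product attention over the [batch, seq, heads, head_dim] projections produced by Wq, Wk and Wv; the signature mirrors the call site, and the body is an assumption.

import tensorflow as tf


def einsum_attn(i, Q, K, V, dropout, dim, mask=None):
    # Attention scores over heads: Q is [batch, q_len, heads, dim] and K is
    # [batch, k_len, heads, dim]; scores become [batch, heads, q_len, k_len].
    scores = tf.einsum("aecd,abcd->acbe", K, Q) / (dim ** 0.5)
    if mask is not None:
        # Assumes mask is broadcastable to the scores shape, with 1 = keep.
        scores += (1.0 - tf.cast(mask, scores.dtype)) * -1e9
    attn_factor = tf.nn.softmax(scores, axis=-1)
    attn_factor = tf.keras.layers.Dropout(
        dropout, name="attn_dropout_%d" % i)(attn_factor)
    # Weighted sum of values, back to [batch, q_len, heads, dim] so that the
    # "abcd,cde->abe" output projection (Wo) can consume it directly.
    C = tf.einsum("acbe,aecd->abcd", attn_factor, V)
    return C, attn_factor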
Example 6
  def test_layer_creation(self, equation, bias_axes, input_shape, output_shape,
                          expected_weight_shape, expected_bias_shape,
                          expected_output_shape):
    # Keras elides the 0-dimension of the input shape when constructing inputs.
    non_batch_input_shape = list(input_shape)[1:]

    input_tensor = keras.Input(shape=non_batch_input_shape)
    layer = einsum_dense.EinsumDense(
        equation=equation, output_shape=output_shape, bias_axes=bias_axes)
    output_tensor = layer(input_tensor)

    self.assertAllEqual(expected_weight_shape, layer.kernel.shape.as_list())
    if expected_bias_shape is None:
      self.assertIsNone(layer.bias)
    else:
      self.assertAllEqual(expected_bias_shape, layer.bias.shape.as_list())
    self.assertAllEqual(expected_output_shape, output_tensor.shape.as_list())
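As an illustration only (not taken from the actual parameterized test data), the simplest 2D case would pair parameters along these lines:

  # Hypothetical parameter set for the plain dense projection "ab,bc->ac":
  # a 32-unit input projected to 64 units, with no bias term.
  equation = "ab,bc->ac"
  bias_axes = None
  input_shape = (None, 32)            # leading batch dimension, elided by keras.Input
  output_shape = 64
  expected_weight_shape = (32, 64)
  expected_bias_shape = None
  expected_output_shape = (None, 64)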
Example 7
    def _build_from_signature(self, query, value, key=None):
        """Builds layers and variables.

        Once the method is called, self._built_from_signature will be set to
        True.

        Args:
            query: query tensor or TensorShape.
            value: value tensor or TensorShape.
            key: key tensor or TensorShape.
        """
        self._built_from_signature = True
        if hasattr(query, "shape"):
            query_shape = tensor_shape.TensorShape(query.shape)
        else:
            query_shape = query
        if hasattr(value, "shape"):
            value_shape = tensor_shape.TensorShape(value.shape)
        else:
            value_shape = value
        if key is None:
            key_shape = value_shape
        elif hasattr(key, "shape"):
            key_shape = tensor_shape.TensorShape(key.shape)
        else:
            key_shape = key

        common_kwargs = dict(kernel_initializer=self._kernel_initializer,
                             bias_initializer=self._bias_initializer,
                             kernel_regularizer=self._kernel_regularizer,
                             bias_regularizer=self._bias_regularizer,
                             activity_regularizer=self._activity_regularizer,
                             kernel_constraint=self._kernel_constraint,
                             bias_constraint=self._bias_constraint)
        # Any setup work performed only once should happen in an `init_scope`
        # to avoid creating symbolic Tensors that will later pollute any eager
        # operations.
        with tf_utils.maybe_init_scope(self):
            free_dims = query_shape.rank - 1
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                free_dims, bound_dims=1, output_dims=2)
            self._query_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="query",
                **common_kwargs)
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                key_shape.rank - 1, bound_dims=1, output_dims=2)
            self._key_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="key",
                **common_kwargs)
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                value_shape.rank - 1, bound_dims=1, output_dims=2)
            self._value_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._value_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="value",
                **common_kwargs)

            # Builds the attention computations for multi-head dot product
            # attention. These computations could be wrapped into the keras
            # attention layer once it supports multi-head einsum computations.
            self._build_attention(output_rank)
            if self._output_shape:
                if not isinstance(self._output_shape, collections.abc.Sized):
                    output_shape = [self._output_shape]
                else:
                    output_shape = self._output_shape
            else:
                output_shape = [query_shape[-1]]
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                free_dims, bound_dims=2, output_dims=len(output_shape))
            self._output_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(output_rank - 1, output_shape),
                bias_axes=bias_axes if self._use_bias else None,
                name="attention_output",
                **common_kwargs)
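For a rank-3 query (free_dims = 2), the generated equations reduce to the hand-written ones in Example 5. Assuming _build_proj_equation behaves as sketched after Example 3:

# Hypothetical check of the generated projection equations for a rank-3 query.
eq_qkv, bias_qkv, rank_qkv = _build_proj_equation(2, bound_dims=1, output_dims=2)
# eq_qkv == "abc,cde->abde", bias_qkv == "de", rank_qkv == 4
eq_out, bias_out, rank_out = _build_proj_equation(2, bound_dims=2, output_dims=1)
# eq_out == "abcd,cde->abe", bias_out == "e", rank_out == 3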
Example 8
  def test_unspecified_weight_dim_fails(self):
    input_tensor = keras.Input(shape=(32,))
    layer = einsum_dense.EinsumDense(equation="ab,zd->ad", output_shape=64)
    with self.assertRaisesRegexp(
        ValueError, ".*Weight dimension 'z' did not have a match "):
      _ = layer(input_tensor)
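For contrast with the failure cases above, a minimal valid construction (reusing the same keras and einsum_dense imports as these tests) might look like this:

  # The output dimension 'c' is produced by the kernel and the shared
  # dimension 'b' matches between input and kernel, so the layer builds.
  input_tensor = keras.Input(shape=(32,))
  layer = einsum_dense.EinsumDense(
      equation="ab,bc->ac", output_shape=64, bias_axes="c")
  output_tensor = layer(input_tensor)  # kernel (32, 64), bias (64,), output (None, 64)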