Example 1
  def test_unspecified_output_dim_fails(self):
    input_tensor = keras.Input(shape=(32,))
    layer = einsum_dense.EinsumDense(equation="ab,bc->cd", output_shape=64)
    with self.assertRaisesRegex(
        ValueError, ".*Dimension 'd' was specified in the output 'cd' but has "
        "no corresponding dim.*"):
      _ = layer(input_tensor)
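Outside a test harness, the same validation surfaces as a plain ValueError the first time the layer is built. A minimal sketch, assuming the public tf.keras.layers.EinsumDense (the stable alias for the internal einsum_dense module used above; older TF releases expose it as tf.keras.layers.experimental.EinsumDense):

import tensorflow as tf
from tensorflow.keras import layers

# 'd' appears in the output spec "cd" but in neither input spec, so the
# layer cannot infer a weight shape and raises at build time.
layer = layers.EinsumDense("ab,bc->cd", output_shape=64)
try:
    layer(tf.keras.Input(shape=(32,)))
except ValueError as e:
    print(e)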
Example 2
    def _make_output_dense(self, free_dims, common_kwargs, name=None):
        """Builds the output projection matrix.

    Args:
      free_dims: Number of free dimensions for einsum equation building.
      common_kwargs: Common keyword arguments for einsum layer.
      name: Name for the projection layer.

    Returns:
      Projection layer.
    """
        if self._output_shape:
            if not isinstance(self._output_shape, collections.abc.Sized):
                output_shape = [self._output_shape]
            else:
                output_shape = self._output_shape
        else:
            output_shape = [self._query_shape[-1]]
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims, bound_dims=2, output_dims=len(output_shape))
        return einsum_dense.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(output_rank - 1, output_shape),
            bias_axes=bias_axes if self._use_bias else None,
            name=name,
            **common_kwargs)
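For intuition, here is a hedged sketch of the kind of layer this method produces for a rank-3 query (free_dims=2): the two bound dimensions (heads and value_dim) are contracted back into a single output dimension. The explicit equation, shapes, and bias axes below are illustrative assumptions, not the literal return values of _build_proj_equation:

import tensorflow as tf
from tensorflow.keras import layers

# Project [batch, seq, heads, value_dim] -> [batch, seq, output_dim].
# a, b are free dims; c, d are the two bound dims; e is the output dim.
output_dense = layers.EinsumDense(
    "abcd,cde->abe",
    output_shape=(None, 64),  # None lets the sequence length stay dynamic
    bias_axes="e")
x = tf.random.normal((2, 5, 8, 16))  # [batch, seq, heads, value_dim]
print(output_dense(x).shape)         # (2, 5, 64)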
Example 3
  def test_unspecified_bias_dim_fails(self):
    input_tensor = keras.Input(shape=(32,))
    layer = einsum_dense.EinsumDense(
        equation="ab,bc->ac", output_shape=64, bias_axes="y")
    with self.assertRaisesRegex(
        ValueError, ".*is not a part of the output specification.*"):
      _ = layer(input_tensor)
Example 4
  def test_incompatible_input_output_shape_fails(self):
    input_tensor = keras.Input(shape=(32, 64))
    layer = einsum_dense.EinsumDense(
        equation="abc,cd->abd", output_shape=(10, 96))
    with self.assertRaisesRegex(
        ValueError, ".*Input shape and output shape do not match at shared "
        "dimension 'b'.*"):
      _ = layer(input_tensor)
Example 5
  def test_layer_creation(self, equation, bias_axes, input_shape, output_shape,
                          expected_weight_shape, expected_bias_shape,
                          expected_output_shape):
    # Keras elides the 0-dimension of the input shape when constructing inputs.
    non_batch_input_shape = list(input_shape)[1:]

    input_tensor = keras.Input(shape=non_batch_input_shape)
    layer = einsum_dense.EinsumDense(
        equation=equation, output_shape=output_shape, bias_axes=bias_axes)
    output_tensor = layer(input_tensor)

    self.assertAllEqual(expected_weight_shape, layer.kernel.shape.as_list())
    if expected_bias_shape is None:
      self.assertIsNone(layer.bias)
    else:
      self.assertAllEqual(expected_bias_shape, layer.bias.shape.as_list())
    self.assertAllEqual(expected_output_shape, output_tensor.shape.as_list())
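To make the parameterization concrete, here is one assumed parameter set worked through by hand: the minimal equation "ab,bc->ac" with a bias on "c" reduces EinsumDense to an ordinary Dense layer, so the expected kernel, bias, and output shapes follow directly from the specs:

import tensorflow as tf
from tensorflow.keras import layers

layer = layers.EinsumDense("ab,bc->ac", output_shape=64, bias_axes="c")
output = layer(tf.keras.Input(shape=(32,)))
print(layer.kernel.shape)  # (32, 64) -- from the weight spec 'bc'
print(layer.bias.shape)    # (64,)    -- one bias per 'c' unit
print(output.shape)        # (None, 64)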
Example 6
    def _build_from_signature(self, query, value, key=None):
        """Builds layers and variables.

    Once the method is called, self._built_from_signature will be set to True.

    Args:
      query: Query tensor or TensorShape.
      value: Value tensor or TensorShape.
      key: Key tensor or TensorShape.
    """
        self._built_from_signature = True
        if hasattr(query, "shape"):
            self._query_shape = tf.TensorShape(query.shape)
        else:
            self._query_shape = tf.TensorShape(query)
        if hasattr(value, "shape"):
            self._value_shape = tf.TensorShape(value.shape)
        else:
            self._value_shape = tf.TensorShape(value)
        if key is None:
            self._key_shape = self._value_shape
        elif hasattr(key, "shape"):
            self._key_shape = tf.TensorShape(key.shape)
        else:
            self._key_shape = tf.TensorShape(key)

        common_kwargs = dict(kernel_initializer=self._kernel_initializer,
                             bias_initializer=self._bias_initializer,
                             kernel_regularizer=self._kernel_regularizer,
                             bias_regularizer=self._bias_regularizer,
                             activity_regularizer=self._activity_regularizer,
                             kernel_constraint=self._kernel_constraint,
                             bias_constraint=self._bias_constraint)
        # Any setup work performed only once should happen in an `init_scope`
        # to avoid creating symbolic Tensors that will later pollute any eager
        # operations.
        with tf_utils.maybe_init_scope(self):
            free_dims = self._query_shape.rank - 1
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                free_dims, bound_dims=1, output_dims=2)
            self._query_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="query",
                **common_kwargs)
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                self._key_shape.rank - 1, bound_dims=1, output_dims=2)
            self._key_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="key",
                **common_kwargs)
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                self._value_shape.rank - 1, bound_dims=1, output_dims=2)
            self._value_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._value_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="value",
                **common_kwargs)

            # Builds the attention computations for multi-head dot product
            # attention. These computations could be wrapped into the keras
            # attention layer once it supports multi-head einsum computations.
            self._build_attention(output_rank)
            self._output_dense = self._make_output_dense(
                free_dims, common_kwargs, "attention_output")
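For a rank-3 query [batch, seq, dim], free_dims is 2 and each projection fans the feature dimension out into (num_heads, key_dim). A minimal sketch of the query projection built above, with the equation and bias axes written out as assumptions standing in for what _build_proj_equation returns:

import tensorflow as tf
from tensorflow.keras import layers

num_heads, key_dim = 8, 16
query_dense = layers.EinsumDense(
    "abc,cde->abde",  # a, b free; c bound; d, e are (num_heads, key_dim)
    output_shape=(None, num_heads, key_dim),
    bias_axes="de")
q = tf.random.normal((2, 5, 32))  # [batch, seq, dim]
print(query_dense(q).shape)       # (2, 5, 8, 16)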
Example 7
  def test_unspecified_weight_dim_fails(self):
    input_tensor = keras.Input(shape=(32,))
    layer = einsum_dense.EinsumDense(equation="ab,zd->ad", output_shape=64)
    with self.assertRaisesRegex(ValueError,
                                ".*Weight dimension 'z' did not have a match "):
      _ = layer(input_tensor)