def test_unspecified_output_dim_fails(self):
  """An output dim with no matching input/weight dim must raise on call."""
  inputs = keras.Input(shape=(32,))
  dense = einsum_dense.EinsumDense(equation="ab,bc->cd", output_shape=64)
  expected_error = (
      ".*Dimension 'd' was specified in the output 'cd' but has "
      "no corresponding dim.*")
  with self.assertRaisesRegex(ValueError, expected_error):
    _ = dense(inputs)
def _make_output_dense(self, free_dims, common_kwargs, name=None):
  """Builds the output projection matrix.

  Args:
    free_dims: Number of free dimensions for einsum equation building.
    common_kwargs: Common keyword arguments for einsum layer.
    name: Name for the projection layer.

  Returns:
    Projection layer.
  """
  # Resolve the target output shape: fall back to the query's last dim when
  # none was configured, and wrap a bare scalar dim in a list.
  if not self._output_shape:
    output_shape = [self._query_shape[-1]]
  elif isinstance(self._output_shape, collections.abc.Sized):
    output_shape = self._output_shape
  else:
    output_shape = [self._output_shape]
  equation, bias_axes, rank = _build_proj_equation(
      free_dims, bound_dims=2, output_dims=len(output_shape))
  return einsum_dense.EinsumDense(
      equation,
      output_shape=_get_output_shape(rank - 1, output_shape),
      bias_axes=bias_axes if self._use_bias else None,
      name=name,
      **common_kwargs)
def test_unspecified_bias_dim_fails(self):
  """A bias axis absent from the output spec must raise on call."""
  inputs = keras.Input(shape=(32,))
  dense = einsum_dense.EinsumDense(
      equation="ab,bc->ac", output_shape=64, bias_axes="y")
  with self.assertRaisesRegex(
      ValueError, ".*is not a part of the output specification.*"):
    _ = dense(inputs)
def test_incompatible_input_output_shape_fails(self):
  """A shared dim that disagrees between input and output must raise."""
  inputs = keras.Input(shape=(32, 64))
  dense = einsum_dense.EinsumDense(
      equation="abc,cd->abd", output_shape=(10, 96))
  expected_error = (
      ".*Input shape and output shape do not match at shared "
      "dimension 'b'.*")
  with self.assertRaisesRegex(ValueError, expected_error):
    _ = dense(inputs)
def test_layer_creation(self, equation, bias_axes, input_shape, output_shape,
                        expected_weight_shape, expected_bias_shape,
                        expected_output_shape):
  """Verifies kernel/bias/output shapes for one parameterized configuration."""
  # Keras elides the 0-dimension of the input shape when constructing inputs.
  inputs = keras.Input(shape=list(input_shape)[1:])
  dense = einsum_dense.EinsumDense(
      equation=equation, output_shape=output_shape, bias_axes=bias_axes)
  outputs = dense(inputs)

  self.assertAllEqual(expected_weight_shape, dense.kernel.shape.as_list())
  if expected_bias_shape is None:
    self.assertIsNone(dense.bias)
  else:
    self.assertAllEqual(expected_bias_shape, dense.bias.shape.as_list())
  self.assertAllEqual(expected_output_shape, outputs.shape.as_list())
def _build_from_signature(self, query, value, key=None):
  """Builds layers and variables.

  Once the method is called, self._built_from_signature will be set to True.

  Args:
    query: Query tensor or TensorShape.
    value: Value tensor or TensorShape.
    key: Key tensor or TensorShape.
  """
  self._built_from_signature = True
  # Each argument may be a concrete tensor (carrying `.shape`) or a raw
  # shape-like value; normalize all three to `tf.TensorShape`.
  if hasattr(query, "shape"):
    self._query_shape = tf.TensorShape(query.shape)
  else:
    self._query_shape = tf.TensorShape(query)
  if hasattr(value, "shape"):
    self._value_shape = tf.TensorShape(value.shape)
  else:
    self._value_shape = tf.TensorShape(value)
  if key is None:
    # With no explicit key, the key projection reuses the value's shape.
    self._key_shape = self._value_shape
  elif hasattr(key, "shape"):
    self._key_shape = tf.TensorShape(key.shape)
  else:
    self._key_shape = tf.TensorShape(key)
  # Weight-related settings shared by every EinsumDense sublayer built below.
  common_kwargs = dict(kernel_initializer=self._kernel_initializer,
                       bias_initializer=self._bias_initializer,
                       kernel_regularizer=self._kernel_regularizer,
                       bias_regularizer=self._bias_regularizer,
                       activity_regularizer=self._activity_regularizer,
                       kernel_constraint=self._kernel_constraint,
                       bias_constraint=self._bias_constraint)
  # Any setup work performed only once should happen in an `init_scope`
  # to avoid creating symbolic Tensors that will later pollute any eager
  # operations.
  with tf_utils.maybe_init_scope(self):
    free_dims = self._query_shape.rank - 1
    # Query projection: maps the query's last dim to (num_heads, key_dim).
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=1, output_dims=2)
    self._query_dense = einsum_dense.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="query",
        **common_kwargs)
    # Key projection: maps the key's last dim to (num_heads, key_dim).
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        self._key_shape.rank - 1, bound_dims=1, output_dims=2)
    self._key_dense = einsum_dense.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="key",
        **common_kwargs)
    # Value projection: maps the value's last dim to (num_heads, value_dim).
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        self._value_shape.rank - 1, bound_dims=1, output_dims=2)
    self._value_dense = einsum_dense.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._value_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="value",
        **common_kwargs)

    # Builds the attention computations for multi-head dot product attention.
    # These computations could be wrapped into the keras attention layer once
    # it supports multi-head einsum computations.
    self._build_attention(output_rank)
    self._output_dense = self._make_output_dense(
        free_dims, common_kwargs, "attention_output")
def test_unspecified_weight_dim_fails(self):
  """A weight dim with no match in input or output must raise on call."""
  inputs = keras.Input(shape=(32,))
  dense = einsum_dense.EinsumDense(equation="ab,zd->ad", output_shape=64)
  expected_error = ".*Weight dimension 'z' did not have a match "
  with self.assertRaisesRegex(ValueError, expected_error):
    _ = dense(inputs)