def test_unspecified_output_dim_fails(self):
  input_tensor = keras.Input(shape=(32,))
  layer = einsum_dense.EinsumDense(equation="ab,bc->cd", output_shape=64)
  with self.assertRaisesRegexp(
      ValueError,
      ".*Dimension 'd' was specified in the output 'cd' but has "
      "no corresponding dim.*"):
    _ = layer(input_tensor)
def test_unspecified_bias_dim_fails(self):
  input_tensor = keras.Input(shape=(32,))
  layer = einsum_dense.EinsumDense(
      equation="ab,bc->ac", output_shape=64, bias_axes="y")
  with self.assertRaisesRegexp(
      ValueError, ".*is not a part of the output specification.*"):
    _ = layer(input_tensor)
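# For contrast with the two failure cases above, a hypothetical well-formed
# configuration (an illustrative sketch, not a test from the original suite):
# every output dimension ('a', 'c') comes from the inputs, and the bias axis
# 'c' appears in the output spec, so construction succeeds.
def test_valid_equation_succeeds_example(self):
  input_tensor = keras.Input(shape=(32,))
  layer = einsum_dense.EinsumDense(
      equation="ab,bc->ac", output_shape=64, bias_axes="c")
  output_tensor = layer(input_tensor)  # kernel (32, 64), bias (64,)
  self.assertAllEqual([None, 64], output_tensor.shape.as_list())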
def _make_output_dense(self, free_dims, common_kwargs, name=None):
  """Builds the output projection matrix.

  Args:
    free_dims: Number of free dimensions for einsum equation building.
    common_kwargs: Common keyword arguments for einsum layer.
    name: The name for the projection layer.

  Returns:
    Projection layer.
  """
  if self._output_shape:
    if not isinstance(self._output_shape, collections.abc.Sized):
      output_shape = [self._output_shape]
    else:
      output_shape = self._output_shape
  else:
    output_shape = [self._query_shape[-1]]
  einsum_equation, bias_axes, output_rank = _build_proj_equation(
      free_dims, bound_dims=2, output_dims=len(output_shape))
  return einsum_dense.EinsumDense(
      einsum_equation,
      output_shape=_get_output_shape(output_rank - 1, output_shape),
      bias_axes=bias_axes if self._use_bias else None,
      name=name,
      **common_kwargs)
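# A minimal sketch of the `_get_output_shape` helper used above, inferred from
# its call sites (an assumption about the real helper, not a quote of it): it
# left-pads the known trailing dimensions with `None` placeholders until the
# shape has `output_rank` entries.
def _get_output_shape_sketch(output_rank, known_last_dims):
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)

# e.g. _get_output_shape_sketch(3, [8, 64]) -> [None, 8, 64]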
def test_incompatible_input_output_shape_fails(self):
  input_tensor = keras.Input(shape=(32, 64))
  layer = einsum_dense.EinsumDense(
      equation="abc,cd->abd", output_shape=(10, 96))
  with self.assertRaisesRegexp(
      ValueError,
      ".*Input shape and output shape do not match at shared "
      "dimension 'b'.*"):
    _ = layer(input_tensor)
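# The failure above occurs because output_shape[0] == 10 conflicts with the
# shared free dimension 'b', which the input fixes at 32. A hypothetical
# corrected construction keeps that dimension consistent:
layer = einsum_dense.EinsumDense(equation="abc,cd->abd", output_shape=(32, 96))
_ = layer(keras.Input(shape=(32, 64)))  # output shape (None, 32, 96)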
def einsum_multihead_attention(i, q, k, v, h, n_a, reg, dropout, seqlen,
                               mask=None):
  dim = n_a // h  # Per-head feature size.
  # Project q/k/v from [batch, seq, n_a] to [batch, seq, h, dim], each in a
  # single einsum.
  Wq = einsum_dense.EinsumDense("abc,cde->abde",
                                kernel_regularizer=reg,
                                output_shape=[None, h, dim],
                                bias_axes="de",
                                name="dense_q_%d" % i)
  Wk = einsum_dense.EinsumDense("abc,cde->abde",
                                kernel_regularizer=reg,
                                output_shape=[None, h, dim],
                                bias_axes="de",
                                name="dense_k_%d" % i)
  Wv = einsum_dense.EinsumDense("abc,cde->abde",
                                kernel_regularizer=reg,
                                output_shape=[None, h, dim],
                                bias_axes="de",
                                name="dense_v_%d" % i)
  # Project the attended context from [batch, seq, h, dim] back to
  # [batch, seq, n_a].
  Wo = einsum_dense.EinsumDense("abcd,cde->abe",
                                kernel_regularizer=reg,
                                output_shape=[None, n_a],
                                bias_axes="e",
                                name="dense_o_%d" % i)
  Q = Wq(q)
  K = Wk(k)
  V = Wv(v)
  C, attn_factor = einsum_attn(i, Q, K, V, dropout, dim, mask)
  return Wo(C)
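# `einsum_attn` is not defined in this snippet. A minimal sketch of what it
# plausibly computes (scaled dot-product attention over the per-head
# projections; an assumption, not the original implementation). Q, K, V are
# [batch, seq, heads, dim], and the returned context matches Wo's "abcd"
# input layout; the mask is assumed broadcastable to the logits' shape.
import math

import tensorflow as tf

def einsum_attn_sketch(i, Q, K, V, dropout, dim, mask=None):
  # Attention logits [batch, heads, q_len, k_len], scaled by sqrt(dim).
  logits = tf.einsum("aecd,afcd->acef", Q, K) / math.sqrt(dim)
  if mask is not None:
    logits += (1.0 - tf.cast(mask, logits.dtype)) * -1e9
  attn = tf.nn.softmax(logits, axis=-1)
  attn = tf.keras.layers.Dropout(dropout, name="attn_drop_%d" % i)(attn)
  # Weighted sum of values back to [batch, q_len, heads, dim].
  context = tf.einsum("acef,afcd->aecd", attn, V)
  return context, attn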
def test_layer_creation(self, equation, bias_axes, input_shape, output_shape,
                        expected_weight_shape, expected_bias_shape,
                        expected_output_shape):
  # Keras elides the 0-dimension of the input shape when constructing inputs.
  non_batch_input_shape = list(input_shape)[1:]

  input_tensor = keras.Input(shape=non_batch_input_shape)
  layer = einsum_dense.EinsumDense(
      equation=equation, output_shape=output_shape, bias_axes=bias_axes)
  output_tensor = layer(input_tensor)

  self.assertAllEqual(expected_weight_shape, layer.kernel.shape.as_list())
  if expected_bias_shape is None:
    self.assertIsNone(layer.bias)
  else:
    self.assertAllEqual(expected_bias_shape, layer.bias.shape.as_list())
  self.assertAllEqual(expected_output_shape, output_tensor.shape.as_list())
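# test_layer_creation is driven by parameterized data. A hypothetical
# parameter tuple (values chosen for illustration, not taken from the
# original test data) showing how the arguments line up: for "ab,bc->ac"
# with a (None, 32) input and bias_axes "c", the layer builds a (32, 64)
# kernel and a (64,) bias, and the output is (None, 64).
_EXAMPLE_CASE = ("ab,bc->ac", "c", (None, 32), 64, (32, 64), (64,), (None, 64))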
def _build_from_signature(self, query, value, key=None):
  """Builds layers and variables.

  Once the method is called, self._built_from_signature will be set to
  True.

  Args:
    query: Query tensor or TensorShape.
    value: Value tensor or TensorShape.
    key: Key tensor or TensorShape.
  """
  self._built_from_signature = True
  if hasattr(query, "shape"):
    query_shape = tensor_shape.TensorShape(query.shape)
  else:
    query_shape = query
  if hasattr(value, "shape"):
    value_shape = tensor_shape.TensorShape(value.shape)
  else:
    value_shape = value
  if key is None:
    key_shape = value_shape
  elif hasattr(key, "shape"):
    key_shape = tensor_shape.TensorShape(key.shape)
  else:
    key_shape = key

  common_kwargs = dict(
      kernel_initializer=self._kernel_initializer,
      bias_initializer=self._bias_initializer,
      kernel_regularizer=self._kernel_regularizer,
      bias_regularizer=self._bias_regularizer,
      activity_regularizer=self._activity_regularizer,
      kernel_constraint=self._kernel_constraint,
      bias_constraint=self._bias_constraint)

  # Any setup work performed only once should happen in an `init_scope`
  # to avoid creating symbolic Tensors that will later pollute any eager
  # operations.
  with tf_utils.maybe_init_scope(self):
    free_dims = query_shape.rank - 1
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=1, output_dims=2)
    self._query_dense = einsum_dense.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="query",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        key_shape.rank - 1, bound_dims=1, output_dims=2)
    self._key_dense = einsum_dense.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="key",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        value_shape.rank - 1, bound_dims=1, output_dims=2)
    self._value_dense = einsum_dense.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._value_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="value",
        **common_kwargs)

    # Builds the attention computations for multi-head dot product
    # attention. These computations could be wrapped into the keras
    # attention layer once it supports multi-head einsum computations.
    self._build_attention(output_rank)

    if self._output_shape:
      if not isinstance(self._output_shape, collections.abc.Sized):
        output_shape = [self._output_shape]
      else:
        output_shape = self._output_shape
    else:
      output_shape = [query_shape[-1]]
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=2, output_dims=len(output_shape))
    self._output_dense = einsum_dense.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1, output_shape),
        bias_axes=bias_axes if self._use_bias else None,
        name="attention_output",
        **common_kwargs)
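# A minimal sketch of how `_build_proj_equation` could generate the equations
# used above, inferred from its call sites (an assumption about the helper,
# not a quote of the real one): it assigns consecutive letters to the free
# (batch/sequence) dims, the bound (contracted) dims, and the new output dims.
import string

def _build_proj_equation_sketch(free_dims, bound_dims, output_dims):
  letters = string.ascii_lowercase
  input_str = letters[:free_dims + bound_dims]
  kernel_str = letters[free_dims:free_dims + bound_dims + output_dims]
  new_dims = letters[free_dims + bound_dims:
                     free_dims + bound_dims + output_dims]
  output_str = letters[:free_dims] + new_dims
  equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
  return equation, new_dims, len(output_str)

# e.g. _build_proj_equation_sketch(2, 1, 2) -> ("abc,cde->abde", "de", 4),
# matching the query/key/value projections built above, and
# _build_proj_equation_sketch(2, 2, 1) -> ("abcd,cde->abe", "e", 3),
# matching the output projection.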
def test_unspecified_weight_dim_fails(self):
  input_tensor = keras.Input(shape=(32,))
  layer = einsum_dense.EinsumDense(equation="ab,zd->ad", output_shape=64)
  with self.assertRaisesRegexp(
      ValueError, ".*Weight dimension 'z' did not have a match "):
    _ = layer(input_tensor)
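# A hypothetical fix for the case above: the weight spec may only use
# dimensions that appear in the input or the output, so renaming 'z' to 'b'
# gives a valid contraction over the shared dimension.
layer = einsum_dense.EinsumDense(equation="ab,bd->ad", output_shape=64)
_ = layer(keras.Input(shape=(32,)))  # kernel (32, 64), output (None, 64)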