def _make_output_dense(self, free_dims, common_kwargs, name=None):
    """Builds the output projection matrix.

    Args:
      free_dims: Number of free dimensions for einsum equation building.
      common_kwargs: Common keyword arguments for einsum layer.
      name: Name for the projection layer.

    Returns:
      Projection layer.
    """
    if self._output_shape:
        if not isinstance(self._output_shape, collections.abc.Sized):
            output_shape = [self._output_shape]
        else:
            output_shape = self._output_shape
    else:
        output_shape = [self._query_shape[-1]]
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=2, output_dims=len(output_shape)
    )
    return core.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1, output_shape),
        bias_axes=bias_axes if self._use_bias else None,
        name=name,
        **common_kwargs
    )
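# Illustrative sketch (an assumption about the module-private
# `_build_proj_equation` helper, which is not shown here): for a rank-3
# query of shape (batch, seq, dim), `free_dims` is 2 and the output
# projection contracts the two bound (num_heads, value_dim) axes into a
# single output feature axis, e.g.:
#
#   equation, bias_axes, output_rank = _build_proj_equation(
#       free_dims=2, bound_dims=2, output_dims=1
#   )
#   # equation == "abcd,cde->abe"  ("ab": free batch/sequence axes,
#   # "cd": bound head/value axes, "e": output feature axis)
#   # bias_axes == "e", output_rank == 3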
def _build_from_signature(self, query, value, key=None):
    """Builds layers and variables.

    Once the method is called, self._built_from_signature will be set to
    True.

    Args:
      query: Query tensor or TensorShape.
      value: Value tensor or TensorShape.
      key: Key tensor or TensorShape.
    """
    self._built_from_signature = True
    if hasattr(query, "shape"):
        self._query_shape = tf.TensorShape(query.shape)
    else:
        self._query_shape = tf.TensorShape(query)
    if hasattr(value, "shape"):
        self._value_shape = tf.TensorShape(value.shape)
    else:
        self._value_shape = tf.TensorShape(value)
    if key is None:
        self._key_shape = self._value_shape
    elif hasattr(key, "shape"):
        self._key_shape = tf.TensorShape(key.shape)
    else:
        self._key_shape = tf.TensorShape(key)
    # Any setup work performed only once should happen in an `init_scope`
    # to avoid creating symbolic Tensors that will later pollute any eager
    # operations.
    with tf_utils.maybe_init_scope(self):
        free_dims = self._query_shape.rank - 1
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims, bound_dims=1, output_dims=2
        )
        self._query_dense = core.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="query",
            **self._get_common_kwargs_for_sublayer()
        )
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            self._key_shape.rank - 1, bound_dims=1, output_dims=2
        )
        self._key_dense = core.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="key",
            **self._get_common_kwargs_for_sublayer()
        )
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            self._value_shape.rank - 1, bound_dims=1, output_dims=2
        )
        self._value_dense = core.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._value_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="value",
            **self._get_common_kwargs_for_sublayer()
        )
        # Builds the attention computations for multi-head dot product
        # attention. These computations could be wrapped into the Keras
        # attention layer once it supports multi-head einsum computations.
        self._build_attention(output_rank)
        self._output_dense = self._make_output_dense(
            free_dims,
            self._get_common_kwargs_for_sublayer(),
            "attention_output",
        )
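# Usage sketch (a minimal example, assuming the public
# `tf.keras.layers.MultiHeadAttention` layer that wraps this code): the
# first call on concrete tensors triggers `_build_from_signature`, which
# derives every projection shape from the query/value signatures.
#
#   import tensorflow as tf
#
#   mha = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64)
#   query = tf.random.normal([2, 16, 128])  # (batch, target_seq, dim)
#   value = tf.random.normal([2, 32, 128])  # (batch, source_seq, dim)
#   output = mha(query, value)              # builds sublayers on first call
#   # output.shape == (2, 16, 128): the output feature size defaults to
#   # the query's last dimension because `_output_shape` is unset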