Example #1
    def _maybe_build(self, inputs):
        # Check input assumptions set before layer building, e.g. input rank.
        if not self.built:
            input_spec.assert_input_compatibility(self.input_spec, inputs,
                                                  self.name)
            input_list = nest.flatten(inputs)

            input_shapes = None
            if all(hasattr(x, 'shape') for x in input_list):
                input_shapes = nest.map_structure(lambda x: x.shape, inputs)
            # Only call `build` if the user has manually overridden the build method.
            if not hasattr(self.build, '_is_default'):
                # Any setup work performed only once should happen in an `init_scope`
                # to avoid creating symbolic Tensors that will later pollute any eager
                # operations.
                with tf_utils.maybe_init_scope(self):
                    self.build(input_shapes)
            # We must set self.built since user defined build functions are not
            # constrained to set self.built.
            self.built = True
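
This lazy-build hook is what lets a layer defer weight creation until the shape of its first inputs is known. A minimal sketch of the effect, using a hypothetical MyDense layer (assumes TensorFlow 2.x imported as tf; not part of the example above):

import tensorflow as tf

class MyDense(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        # Runs once, with the shape inferred from the first inputs.
        self.kernel = self.add_weight(
            "kernel", shape=(int(input_shape[-1]), self.units))

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

layer = MyDense(4)
print(layer.built)              # False: no weights created yet
_ = layer(tf.ones((2, 3)))      # first call triggers _maybe_build -> build
print(layer.built)              # True, and layer.kernel now has shape (3, 4)
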
Example #2
    def _build_from_signature(self, query, value, key=None):
        """Builds layers and variables.

    Once the method is called, self._built_from_signature will be set to True.

    Args:
      query: query tensor or TensorShape.
      value: value tensor or TensorShape.
      key: key tensor or TensorShape.
    """
        self._built_from_signature = True
        if hasattr(query, "shape"):
            query_shape = tensor_shape.TensorShape(query.shape)
        else:
            query_shape = query
        if hasattr(value, "shape"):
            value_shape = tensor_shape.TensorShape(value.shape)
        else:
            value_shape = value
        if key is None:
            key_shape = value_shape
        elif hasattr(key, "shape"):
            key_shape = tensor_shape.TensorShape(key.shape)
        else:
            key_shape = key

        common_kwargs = dict(kernel_initializer=self._kernel_initializer,
                             bias_initializer=self._bias_initializer,
                             kernel_regularizer=self._kernel_regularizer,
                             bias_regularizer=self._bias_regularizer,
                             activity_regularizer=self._activity_regularizer,
                             kernel_constraint=self._kernel_constraint,
                             bias_constraint=self._bias_constraint)
        # Any setup work performed only once should happen in an `init_scope`
        # to avoid creating symbolic Tensors that will later pollute any eager
        # operations.
        with tf_utils.maybe_init_scope(self):
            free_dims = query_shape.rank - 1
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                free_dims, bound_dims=1, output_dims=2)
            self._query_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="query",
                **common_kwargs)
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                key_shape.rank - 1, bound_dims=1, output_dims=2)
            self._key_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="key",
                **common_kwargs)
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                value_shape.rank - 1, bound_dims=1, output_dims=2)
            self._value_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._value_dim]),
                bias_axes=bias_axes if self._use_bias else None,
                name="value",
                **common_kwargs)

            # Builds the attention computations for multi-head dot product attention.
            # These computations could be wrapped into the Keras attention layer once
            # it supports multi-head einsum computations.
            self._build_attention(output_rank)
            if self._output_shape:
                if not isinstance(self._output_shape, collections.abc.Sized):
                    output_shape = [self._output_shape]
                else:
                    output_shape = self._output_shape
            else:
                output_shape = [query_shape[-1]]
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                free_dims, bound_dims=2, output_dims=len(output_shape))
            self._output_dense = einsum_dense.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(output_rank - 1, output_shape),
                bias_axes=bias_axes if self._use_bias else None,
                name="attention_output",
                **common_kwargs)
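
For intuition, with a rank-3 query of shape (batch, seq_len, dim), free_dims is 2 and _build_proj_equation(2, bound_dims=1, output_dims=2) produces a projection equation of the form "abc,cde->abde" with bias axes "de". A rough NumPy sketch of that per-head projection (shapes are illustrative, not taken from the code above):

import numpy as np

batch, seq_len, dim = 2, 5, 8
num_heads, key_dim = 3, 16      # illustrative head configuration

query = np.random.rand(batch, seq_len, dim)
kernel = np.random.rand(dim, num_heads, key_dim)

# Same contraction the query EinsumDense performs: the last query axis is
# bound against the kernel and expanded into (num_heads, key_dim).
projected = np.einsum("abc,cde->abde", query, kernel)
print(projected.shape)  # (2, 5, 3, 16): (batch, seq_len, num_heads, key_dim)
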
    def __init__(self,
                 num_heads,
                 key_dim,
                 local_scope,
                 num_timesteps,
                 num_features,
                 value_dim=None,
                 dropout=0.0):
        super(MixedMultiHeadAttention, self).__init__()

        assert num_heads % 4 == 0

        self._num_mixed_heads = int(num_heads / 4)
        self._key_dim = key_dim
        self._local_scope = local_scope
        self._num_timesteps = num_timesteps
        self._num_features = num_features
        self._value_dim = value_dim if value_dim else key_dim
        self._dropout = dropout

        self._query_dense = EinsumDense(equation="abe,cdef->acdbf",
                                        output_shape=(4, self._num_mixed_heads,
                                                      self._num_timesteps,
                                                      self._key_dim),
                                        bias_axes="f")

        self._key_dense = EinsumDense(equation="abe,cdef->acdfb",
                                      output_shape=(4, self._num_mixed_heads,
                                                    self._key_dim,
                                                    self._num_timesteps),
                                      bias_axes="f")

        self._value_dense = EinsumDense(equation="abe,cdef->acdbf",
                                        output_shape=(4, self._num_mixed_heads,
                                                      self._num_timesteps,
                                                      self._value_dim),
                                        bias_axes="f")

        self._softmax = Softmax()
        self._dropout_layer = Dropout(self._dropout)

        self._output_dense = EinsumDense(equation="abc,cd->abd",
                                         output_shape=(self._num_timesteps,
                                                       self._num_features),
                                         bias_axes="d")

        with tf_utils.maybe_init_scope(self):
            # Four attention masks over (timesteps x timesteps):
            # g: no masking; l: positions outside a +/- local_scope window are
            # masked; f: earlier positions are masked; b: later positions are masked.
            g = np.zeros((self._num_timesteps, self._num_timesteps))
            l = np.zeros((self._num_timesteps, self._num_timesteps))
            f = np.zeros((self._num_timesteps, self._num_timesteps))
            b = np.zeros((self._num_timesteps, self._num_timesteps))

            for i in range(self._num_timesteps):
                for j in range(self._num_timesteps):
                    if i - self._local_scope > j or j > i + self._local_scope:
                        l[i, j] = -np.inf
                    if i > j:
                        f[i, j] = -np.inf
                    if i < j:
                        b[i, j] = -np.inf

            # Stack into shape (1, 4, 1, timesteps, timesteps) so the masks
            # broadcast over the batch and head dimensions.
            m = np.stack([g, l, f, b])
            m = tf.convert_to_tensor(m, tf.float16)
            m = tf.expand_dims(m, 1)
            m = tf.expand_dims(m, 0)

            self._masks = m
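
The four masks built above encode different attention scopes: g leaves every position visible, l hides positions outside a window of +/- local_scope, f hides earlier positions (each step attends only forward in time) and b hides later positions (attends only backward). A standalone sketch of the same construction for num_timesteps=4 and local_scope=1:

import numpy as np

num_timesteps, local_scope = 4, 1
g = np.zeros((num_timesteps, num_timesteps))
l = np.zeros((num_timesteps, num_timesteps))
f = np.zeros((num_timesteps, num_timesteps))
b = np.zeros((num_timesteps, num_timesteps))

for i in range(num_timesteps):
    for j in range(num_timesteps):
        if i - local_scope > j or j > i + local_scope:
            l[i, j] = -np.inf   # outside the local window
        if i > j:
            f[i, j] = -np.inf   # earlier timestep: hidden from the forward heads
        if i < j:
            b[i, j] = -np.inf   # later timestep: hidden from the backward heads

print(l)  # band matrix: only |i - j| <= 1 is left at 0
print(f)  # zeros on and above the diagonal
print(b)  # zeros on and below the diagonal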