Example #1
# The imports and class wrapper below are added so the snippet is
# self-contained; only the constructor appeared in the original example,
# and the class name is assumed.
import torch

from allennlp.modules.matrix_attention import MatrixAttention


class BiModalAttention(torch.nn.Module):
    def __init__(
        self,
        hidden_size1: int,
        hidden_size2: int,
        combined_hidden_size: int,
        num_attention_heads: int,
        dropout1: float = 0.0,
        dropout2: float = 0.0,
        scoring_func1: str = "scaled_dot_product",
        scoring_func2: str = "scaled_dot_product",
    ):
        super().__init__()
        if combined_hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The combined hidden size (%d) is not a multiple of the number "
                "of attention heads (%d)" %
                (combined_hidden_size, num_attention_heads))

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = combined_hidden_size // num_attention_heads

        # This is basically the `combined_hidden_size`, since we already ensure
        # that `combined_hidden_size` is divisible by `num_attention_heads`.
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # First modality: project its hidden states into the shared
        # multi-head attention space (queries, keys, and values).

        self.query1 = torch.nn.Linear(hidden_size1, self.all_head_size)
        self.key1 = torch.nn.Linear(hidden_size1, self.all_head_size)
        self.value1 = torch.nn.Linear(hidden_size1, self.all_head_size)

        # Similarity function used to score query/key pairs, looked up by name
        # from the MatrixAttention registry.
        self.scoring_func1 = scoring_func1
        self.attn1 = MatrixAttention.by_name(self.scoring_func1)()
        self.dropout1 = torch.nn.Dropout(dropout1)

        # Second modality: the same projections, sized for its own hidden size.

        self.query2 = torch.nn.Linear(hidden_size2, self.all_head_size)
        self.key2 = torch.nn.Linear(hidden_size2, self.all_head_size)
        self.value2 = torch.nn.Linear(hidden_size2, self.all_head_size)

        self.scoring_func2 = scoring_func2
        self.attn2 = MatrixAttention.by_name(self.scoring_func2)()
        self.dropout2 = torch.nn.Dropout(dropout2)
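
All of these projections map into `all_head_size`, which the forward pass (not shown here) reshapes into `num_attention_heads` separate heads. Below is a minimal sketch of that head-splitting arithmetic on dummy tensors; all sizes and names are illustrative and not taken from the example above.

import torch

batch, seq1, hidden_size1 = 2, 5, 1024                    # illustrative sizes
combined_hidden_size, num_attention_heads = 1536, 12
head_size = combined_hidden_size // num_attention_heads   # 128

# Project the first modality into the shared multi-head space, like `query1`.
query1 = torch.nn.Linear(hidden_size1, combined_hidden_size)
x1 = torch.randn(batch, seq1, hidden_size1)
q = query1(x1)                                             # (2, 5, 1536)

# Split the last dimension into heads: (batch, heads, seq, head_size).
q = q.view(batch, seq1, num_attention_heads, head_size).permute(0, 2, 1, 3)
print(q.shape)                                             # torch.Size([2, 12, 5, 128])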
Example #2
# The imports and class wrapper below are added so the snippet is
# self-contained; only the constructor appeared in the original example,
# and the class name is assumed.
from typing import Optional

import torch

from allennlp.common.checks import ConfigurationError
from allennlp.modules.matrix_attention import MatrixAttention


class SelfAttention(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int = 512,
        attention_head_size: int = 64,
        num_attention_heads: int = 8,
        scoring_func: str = "scaled_dot_product",
        output_linear: bool = False,
        dropout: float = 0.0,
        bias: bool = True,
        normalize_weights: bool = False,
        is_decoder: bool = False,
        is_cross_attention: bool = False,
        relative_attention_num_buckets: Optional[int] = None,
    ):

        super().__init__()

        if hidden_size % num_attention_heads != 0:
            raise ConfigurationError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))

        if is_cross_attention and not is_decoder:
            raise ConfigurationError(
                "The attention layer can be a cross-attention layer only "
                "if it is within a decoder.")

        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Query/key/value projections from the input hidden size to the
        # concatenated size of all attention heads.
        self.query = torch.nn.Linear(hidden_size,
                                     self.all_head_size,
                                     bias=bias)
        self.key = torch.nn.Linear(hidden_size, self.all_head_size, bias=bias)
        self.value = torch.nn.Linear(hidden_size,
                                     self.all_head_size,
                                     bias=bias)

        # Output projection layer, used by models such as DistilBERT and T5.
        if output_linear:
            self.output = torch.nn.Linear(self.all_head_size,
                                          hidden_size,
                                          bias=bias)

        self.scoring_func = scoring_func
        self.attn = MatrixAttention.by_name(self.scoring_func)()

        self.relative_attention_num_buckets = relative_attention_num_buckets

        if self.relative_attention_num_buckets is not None:
            # T5-style relative position bias: one learned scalar per
            # (relative-position bucket, attention head).
            self.relative_attention_bias = torch.nn.Embedding(
                self.relative_attention_num_buckets, self.num_attention_heads)

        # The dropout probability is stored directly; how it is applied is up
        # to the forward pass, which is not part of this snippet.
        self.dropout = dropout

        self.is_decoder = is_decoder
        self.is_cross_attention = is_cross_attention

        if normalize_weights:
            # `_normalize` is defined on the full class in the original source
            # and is not included in this snippet.
            self._normalize()
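
When `relative_attention_num_buckets` is given, the embedding above holds one learned scalar per (relative-position bucket, attention head). Below is a minimal sketch of how such a table can be turned into a per-head bias added to attention scores; the bucket indices are dummy values, since the real bucketing function is not part of the snippet, and all sizes are illustrative.

import torch

num_buckets, num_heads, query_len, key_len = 32, 8, 4, 4   # illustrative sizes

relative_attention_bias = torch.nn.Embedding(num_buckets, num_heads)

# Dummy bucket index per (query, key) position pair; a real model derives
# these from relative positions with a bucketing function (not shown).
buckets = torch.randint(0, num_buckets, (query_len, key_len))

bias = relative_attention_bias(buckets)          # (query_len, key_len, num_heads)
bias = bias.permute(2, 0, 1).unsqueeze(0)        # (1, num_heads, query_len, key_len)

scores = torch.randn(1, num_heads, query_len, key_len)
scores = scores + bias                           # bias is added before the softmax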