Example no. 1
    def __init__(self, rnn_dim, linear_dim, dropout=0.2, weight_init=True):
        super(SelfAttention, self).__init__()

        self.self_attention = attention.DocQAAttention(
            rnn_dim, linear_dim, self_attn=True, weight_init=weight_init
        )
        self.self_attn_Linear = nn.Linear(rnn_dim * 6, linear_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.activation_fn = F.relu

        if weight_init:
            initializer.weight(self.self_attn_Linear)
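
The input width of self_attn_Linear, rnn_dim * 6, corresponds to three bidirectional (2 * rnn_dim) tensors concatenated together. Below is a minimal, self-contained sketch of that shape bookkeeping; the tensor names and the exact concatenation ([x, attended x, x * attended x]) are assumptions for illustration, only the dimensions come from the snippet.

import torch
import torch.nn as nn

rnn_dim, linear_dim, batch, seq_len = 100, 200, 4, 30

# Stand-ins for the encoded passage and its self-attended version (assumed names).
encoded = torch.randn(batch, seq_len, 2 * rnn_dim)
attended = torch.randn(batch, seq_len, 2 * rnn_dim)

# Three 2 * rnn_dim tensors concatenated -> 6 * rnn_dim, matching self_attn_Linear.
fused = torch.cat([encoded, attended, encoded * attended], dim=-1)

self_attn_linear = nn.Linear(rnn_dim * 6, linear_dim)
out = torch.relu(self_attn_linear(fused))
print(out.shape)  # torch.Size([4, 30, 200])
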
Example no. 2
    def __init__(self, rnn_dim, linear_dim, self_attn=False, weight_init=True):
        super(DocQAAttention, self).__init__()
        self.self_attn = self_attn

        self.input_w = nn.Linear(2 * rnn_dim, 1, bias=False)
        self.key_w = nn.Linear(2 * rnn_dim, 1, bias=False)

        self.dot_w = nn.Parameter(torch.randn(1, 1, rnn_dim * 2))
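        # xavier_uniform_ (next line) re-initializes dot_w in place, overwriting the randn values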
        torch.nn.init.xavier_uniform_(self.dot_w)

        self.bias = nn.Parameter(torch.FloatTensor([[1]]))
        self.diag_mask = nn.Parameter(
            torch.eye(2000))  # NOTE: (hard-code) max_sequence_length

        if weight_init:
            initializer.weight(self.input_w)
            initializer.weight(self.key_w)
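
These parameters match a BiDAF/DocQA-style trilinear similarity: a context projection (input_w), a query projection (key_w), a learned element-wise term (dot_w), and a scalar bias, with diag_mask presumably used to block each position's attention to itself when self_attn=True. The forward pass is not part of the snippet, so the sketch below is an assumed illustration of how such weights are typically combined, not the project's actual code.

import torch
import torch.nn as nn

rnn_dim, batch, c_len, q_len = 100, 2, 7, 5
context = torch.randn(batch, c_len, 2 * rnn_dim)
query = torch.randn(batch, q_len, 2 * rnn_dim)

input_w = nn.Linear(2 * rnn_dim, 1, bias=False)
key_w = nn.Linear(2 * rnn_dim, 1, bias=False)
dot_w = nn.Parameter(torch.randn(1, 1, 2 * rnn_dim))
bias = nn.Parameter(torch.ones(1, 1))

# Trilinear similarity S with S[b, i, j] = w_c.c_i + w_q.q_j + w_dot.(c_i * q_j) + bias
similarity = (
    input_w(context)                                      # (batch, c_len, 1)
    + key_w(query).transpose(1, 2)                        # (batch, 1, q_len)
    + torch.bmm(context * dot_w, query.transpose(1, 2))   # (batch, c_len, q_len)
    + bias
)
print(similarity.shape)  # torch.Size([2, 7, 5])
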
Example no. 3
    def __init__(
        self,
        token_embedder,
        aligned_query_embedding=False,
        answer_maxlen=17,
        rnn_dim=100,
        linear_dim=200,
        preprocess_rnn_num_layer=1,
        modeling_rnn_num_layer=2,
        predict_rnn_num_layer=1,
        dropout=0.2,
        weight_init=True,
    ):
        super(DocQA, self).__init__(token_embedder)

        self.aligned_query_embedding = aligned_query_embedding
        self.answer_maxlen = answer_maxlen
        self.token_embedder = token_embedder
        self.dropout = nn.Dropout(p=dropout)

        context_embed_dim, query_embed_dim = token_embedder.get_embed_dim()
        if self.aligned_query_embedding:
            context_embed_dim += query_embed_dim

        if context_embed_dim != query_embed_dim:
            self.context_preprocess_rnn = nn.GRU(
                input_size=context_embed_dim,
                hidden_size=rnn_dim,
                bidirectional=True,
                num_layers=preprocess_rnn_num_layer,
                batch_first=True,
            )
            self.query_preprocess_rnn = nn.GRU(
                input_size=query_embed_dim,
                hidden_size=rnn_dim,
                bidirectional=True,
                num_layers=preprocess_rnn_num_layer,
                batch_first=True,
            )
        else:
            preprocess_rnn = nn.GRU(
                input_size=context_embed_dim,
                hidden_size=rnn_dim,
                bidirectional=True,
                num_layers=preprocess_rnn_num_layer,
                batch_first=True,
            )

            self.context_preprocess_rnn = preprocess_rnn
            self.query_preprocess_rnn = preprocess_rnn

        self.bi_attention = attention.DocQAAttention(rnn_dim, linear_dim)
        self.attn_linear = nn.Linear(rnn_dim * 8, linear_dim)

        self.modeling_rnn = nn.GRU(
            input_size=linear_dim,
            hidden_size=rnn_dim,
            num_layers=modeling_rnn_num_layer,
            bidirectional=True,
            batch_first=True,
        )
        self.self_attention = SelfAttention(
            rnn_dim, linear_dim, weight_init=weight_init
        )

        self.span_start_rnn = nn.GRU(
            input_size=linear_dim,
            hidden_size=rnn_dim,
            bidirectional=True,
            num_layers=predict_rnn_num_layer,
            batch_first=True,
        )
        self.span_start_linear = nn.Linear(rnn_dim * 2, 1)

        self.span_end_rnn = nn.GRU(
            input_size=linear_dim + rnn_dim * 2,
            hidden_size=rnn_dim,
            bidirectional=True,
            num_layers=predict_rnn_num_layer,
            batch_first=True,
        )
        self.span_end_linear = nn.Linear(rnn_dim * 2, 1)

        self.activation_fn = F.relu
        self.criterion = nn.CrossEntropyLoss()

        if weight_init:
            modules = [
                self.context_preprocess_rnn,
                self.query_preprocess_rnn,
                self.modeling_rnn,
                self.attn_linear,
                self.span_start_rnn,
                self.span_start_linear,
                self.span_end_rnn,
                self.span_end_linear,
            ]
            initializer.weight(modules)
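
Two input sizes above encode concatenations that are easy to miss: attn_linear takes rnn_dim * 8 (four bidirectional tensors of width 2 * rnn_dim, the usual BiDAF-style fusion of the passage with context-to-query and query-to-context attention), and span_end_rnn takes linear_dim + rnn_dim * 2 (presumably the self-attended passage concatenated with the span-start GRU states). The sketch below only traces those shapes with dummy tensors; the tensor names and the exact fusion are assumptions, while the layer sizes are taken from the snippet.

import torch
import torch.nn as nn

rnn_dim, linear_dim, batch, c_len = 100, 200, 2, 30

# After the bidirectional preprocess GRU, each passage position is 2 * rnn_dim wide.
c = torch.randn(batch, c_len, 2 * rnn_dim)
c2q = torch.randn(batch, c_len, 2 * rnn_dim)  # context-to-query attention (assumed)
q2c = torch.randn(batch, c_len, 2 * rnn_dim)  # query-to-context attention, tiled (assumed)

# Four 2 * rnn_dim tensors concatenated -> 8 * rnn_dim, matching attn_linear.
fused = torch.cat([c, c2q, c * c2q, c * q2c], dim=-1)
attn_linear = nn.Linear(rnn_dim * 8, linear_dim)
modeled_in = torch.relu(attn_linear(fused))  # (batch, c_len, linear_dim)

# Span-end input: self-attended passage (linear_dim) + span-start GRU states (2 * rnn_dim).
self_attended = torch.randn(batch, c_len, linear_dim)
span_start_states = torch.randn(batch, c_len, 2 * rnn_dim)
span_end_in = torch.cat([self_attended, span_start_states], dim=-1)

span_end_rnn = nn.GRU(
    input_size=linear_dim + rnn_dim * 2,
    hidden_size=rnn_dim,
    bidirectional=True,
    batch_first=True,
)
span_end_states, _ = span_end_rnn(span_end_in)
print(modeled_in.shape, span_end_states.shape)  # (2, 30, 200) (2, 30, 200)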