Example #1
File: sqlnet.py Project: zzozzolev/claf
    def forward(self, question_embed, question_mask, column_embed,
                column_name_mask, col_idx):
        B, C_L, N_L, embed_D = list(column_embed.size())

        # Column Encoder
        encoded_column = utils.encode_column(column_embed, column_name_mask,
                                             self.column_rnn)
        encoded_used_column = utils.filter_used_column(
            encoded_column, col_idx, padding_count=self.column_maxlen)

        encoded_question, _ = self.question_rnn(question_embed)
        if self.column_attention:
            # (B, 1, Q_L, D) x (B, C_L, D, 1) -> (B, C_L, Q_L)
            attn_matrix = torch.matmul(
                self.linear_attn(encoded_question).unsqueeze(1),
                encoded_used_column.unsqueeze(3)).squeeze()
            attn_matrix = f.add_masked_value(attn_matrix,
                                             question_mask.unsqueeze(1),
                                             value=-1e7)
            attn_matrix = F.softmax(attn_matrix, dim=-1)
            attn_question = (encoded_question.unsqueeze(1) *
                             attn_matrix.unsqueeze(3)).sum(2)
        else:
            attn_matrix = self.seq_attn(encoded_question, question_mask)
            attn_question = f.weighted_sum(attn_matrix, encoded_question)
            attn_question = attn_question.unsqueeze(1)

        return self.mlp(
            self.linear_question(attn_question) +
            self.linear_column(encoded_used_column)).squeeze()
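The column-attention branch above masks the scores with f.add_masked_value before the softmax. The claf implementation is not shown on this page; a minimal sketch of the idea, assuming a float mask with 1.0 at real tokens and 0.0 at padding, could look like this:

import torch

def add_masked_value(tensor, mask, value=-1e7):
    # Sketch only (not the claf source): push padded positions toward `value`
    # so a following softmax assigns them near-zero probability.
    # `mask` must be broadcastable to `tensor`; 1.0 = keep, 0.0 = pad.
    return tensor * mask + (1.0 - mask) * value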
Example #2
File: sqlnet.py Project: zzozzolev/claf
    def forward(self, question_embed, question_mask, column_embed,
                column_name_mask, column_mask):
        B, C_L, N_L, embed_D = list(column_embed.size())

        # Column Encoder
        encoded_column = utils.encode_column(column_embed, column_name_mask,
                                             self.column_rnn)
        encoded_question, _ = self.question_rnn(question_embed)

        if self.column_attention:
            # (B, C_L, D) x (B, D, Q_L) -> (B, C_L, Q_L)
            attn_matrix = torch.bmm(
                encoded_column,
                self.linear_attn(encoded_question).transpose(1, 2))
            attn_matrix = f.add_masked_value(attn_matrix,
                                             question_mask.unsqueeze(1),
                                             value=-1e7)
            attn_matrix = F.softmax(attn_matrix, dim=-1)
            attn_question = (encoded_question.unsqueeze(1) *
                             attn_matrix.unsqueeze(3)).sum(2)
        else:
            attn_matrix = self.seq_attn(encoded_question, question_mask)
            attn_question = f.weighted_sum(attn_matrix, encoded_question)
            attn_question = attn_question.unsqueeze(1)

        logits = self.mlp(
            self.linear_question(attn_question) +
            self.linear_column(encoded_column)).squeeze()
        logits = f.add_masked_value(logits, column_mask, value=-1e7)
        return logits
Example #3
    def _x2key(self, S, key, key_mask):
        if self.self_attn:
            # Softmax variant with a learned bias term added to the
            # normalizer, so the attention weights can sum to less than 1.
            bias = torch.exp(self.bias)
            S = torch.exp(S)
            attention = S / (S.sum(dim=-1, keepdim=True).expand(S.size()) +
                             bias.expand(S.size()))
        else:
            attention = F.softmax(S, dim=-1)  # (B, C_L, Q_L)

        x2key = f.weighted_sum(attention=attention, matrix=key)  # (B, C_L, 2d)
        return x2key
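f.weighted_sum appears throughout these examples, called both with (B, L) attention over a (B, L, D) matrix and with (B, C_L, Q_L) attention over a (B, Q_L, D) matrix. A batched-matmul sketch that covers both call patterns (an assumption, not necessarily the claf implementation):

import torch

def weighted_sum(attention, matrix):
    # attention (..., L),    matrix (..., L, D) -> (..., D)
    # attention (..., M, L), matrix (..., L, D) -> (..., M, D)
    if attention.dim() == matrix.dim() - 1:
        return torch.matmul(attention.unsqueeze(-2), matrix).squeeze(-2)
    return torch.matmul(attention, matrix)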
Example #4
File: sqlnet.py Project: zzozzolev/claf
    def forward(self, question_embed, question_mask, column_embed,
                column_name_mask, column_mask):
        B, C_L, N_L, embed_D = list(column_embed.size())

        encoded_column = utils.encode_column(column_embed, column_name_mask,
                                             self.column_rnn)
        attn_column = self.column_seq_attn(encoded_column, column_mask)
        out_column = f.weighted_sum(attn_column, encoded_column)

        # Build the question LSTM's initial hidden/cell states from the
        # attended column summary; view/transpose reshapes each to
        # (column_maxlen, B, model_dim // 2).
        question_rnn_hidden_state = (
            self.column_to_hidden_state(out_column).view(
                B, self.column_maxlen,
                self.model_dim // 2).transpose(0, 1).contiguous())
        question_rnn_cell_state = (self.column_to_cell_state(out_column).view(
            B, self.column_maxlen,
            self.model_dim // 2).transpose(0, 1).contiguous())

        encoded_question, _ = self.question_rnn(
            question_embed,
            (question_rnn_hidden_state, question_rnn_cell_state))
        attn_question = self.question_seq_attn(encoded_question, question_mask)
        out_question = f.weighted_sum(attn_question, encoded_question)
        return self.mlp(out_question)
Example #5
    def _key2x(self, S, x, x_mask):
        attention = f.masked_softmax(S, x_mask)  # (B, C_L)
        key2x = f.weighted_sum(attention=attention, matrix=x)
        return key2x.unsqueeze(1).expand(x.size())  # (B, C_L, 2d)
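f.masked_softmax is used here and in Example #9 to normalize only over positions the mask marks as valid. A minimal sketch under the same mask convention (an assumption, not the claf source):

import torch
import torch.nn.functional as F

def masked_softmax(logits, mask):
    # Sketch: softmax over the last dim, ignoring padded (mask == 0) positions.
    logits = logits.masked_fill(mask == 0, -1e7)
    return F.softmax(logits, dim=-1)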
Example #6
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "toekn_name2": tensor},
                 "feature_name2": ...}

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                 Loss is not calculated when no labels are given (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: representing unnormalized log probabilities of the span start position.
            - end_logits: representing unnormalized log probabilities of the span end position.
            - best_span: the string from the original passage that the model thinks is the best answer to the question.
            - answer_idx: the question id, mapping with answer
            - loss: A scalar loss to be optimised.
        """

        context = features["context"]
        question = features["question"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        context_seq_config = f.get_sorted_seq_config(context)
        query_seq_config = f.get_sorted_seq_config(question)

        # Embedding Layer (Char + Word -> Contextual)
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        context_mask = f.get_mask_from_tokens(context).float()
        query_mask = f.get_mask_from_tokens(question).float()

        B, C_L = context_embed.size(0), context_embed.size(1)

        context_embed = self.context_highway(context_embed)
        query_embed = self.query_highway(query_embed)

        context_encoded = f.forward_rnn_with_pack(self.context_contextual_rnn,
                                                  context_embed,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        query_encoded = f.forward_rnn_with_pack(self.query_contextual_rnn,
                                                query_embed, query_seq_config)
        query_encoded = self.dropout(query_encoded)

        # Attention Flow Layer
        attention_context_query = self.attention(context_encoded, context_mask,
                                                 query_encoded, query_mask)

        # Modeling Layer
        modeled_context = f.forward_rnn_with_pack(self.modeling_rnn,
                                                  attention_context_query,
                                                  context_seq_config)
        modeled_context = self.dropout(modeled_context)

        M_D = modeled_context.size(-1)

        # Output Layer
        span_start_input = self.dropout(
            torch.cat([attention_context_query, modeled_context],
                      dim=-1))  # (B, C_L, 10d)
        span_start_logits = self.span_start_linear(span_start_input).squeeze(
            -1)  # (B, C_L)
        span_start_probs = f.masked_softmax(span_start_logits, context_mask)

        span_start_representation = f.weighted_sum(attention=span_start_probs,
                                                   matrix=modeled_context)
        tiled_span_start_representation = span_start_representation.unsqueeze(
            1).expand(B, C_L, M_D)

        span_end_representation = torch.cat(
            [
                attention_context_query,
                modeled_context,
                tiled_span_start_representation,
                modeled_context * tiled_span_start_representation,
            ],
            dim=-1,
        )
        encoded_span_end = f.forward_rnn_with_pack(self.output_end_rnn,
                                                   span_end_representation,
                                                   context_seq_config)
        encoded_span_end = self.dropout(encoded_span_end)

        span_end_input = self.dropout(
            torch.cat([attention_context_query, encoded_span_end], dim=-1))
        span_end_logits = self.span_end_linear(span_end_input).squeeze(-1)

        # Masked Value
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        # No_Answer Bias
        bias = self.bias.expand(B, 1)
        span_start_logits = torch.cat([span_start_logits, bias], dim=-1)
        span_end_logits = torch.cat([span_end_logits, bias], dim=-1)

        output_dict = {
            "start_logits": span_start_logits,
            "end_logits": span_end_logits,
            "best_span": self.get_best_span(
                span_start_logits[:, :-1],
                span_end_logits[:, :-1],
                answer_maxlen=self.answer_maxlen,  # except no_answer bias
            ),
        }

        if labels:
            answer_idx = labels["answer_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]
            answerable = labels["answerable"]

            # No_Answer Case
            C_L = context_mask.size(1)
            answer_start_idx = answer_start_idx.masked_fill(
                answerable.eq(0), C_L)
            answer_end_idx = answer_end_idx.masked_fill(answerable.eq(0), C_L)

            output_dict["answer_idx"] = answer_idx

            # Loss
            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(
                0)  # NOTE: DataParallel concat Error

        return output_dict
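self.get_best_span is referenced but not defined in this example. A common decoding scheme, sketched here under the assumption that the best span maximizes start_logits[i] + end_logits[j] over i <= j with an optional length cap (name and behavior are hypothetical, not the claf source):

import torch

def get_best_span(start_logits, end_logits, answer_maxlen=None):
    # Sketch: score every (start, end) pair and keep the best valid one.
    B, L = start_logits.size()
    scores = start_logits.unsqueeze(2) + end_logits.unsqueeze(1)  # (B, L, L)
    valid = torch.triu(torch.ones(L, L, dtype=torch.bool,
                                  device=scores.device))  # keep start <= end
    if answer_maxlen is not None:  # cap span length at answer_maxlen tokens
        valid &= ~torch.triu(torch.ones(L, L, dtype=torch.bool,
                                        device=scores.device),
                             diagonal=answer_maxlen)
    scores = scores.masked_fill(~valid, float("-inf"))
    best = scores.view(B, -1).argmax(dim=-1)
    start = torch.div(best, L, rounding_mode="floor")
    return torch.stack([start, best % L], dim=-1)  # (B, 2) start/end indices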
Example #7
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "toekn_name2": tensor},
                 "feature_name2": ...}

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                 Loss is not calculated when no labels are given (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: representing unnormalized log probabilities of the span start position.
            - end_logits: representing unnormalized log probabilities of the span end position.
            - best_span: the string from the original passage that the model thinks is the best answer to the question.
            - data_idx: the question id, mapping with answer
            - loss: A scalar loss to be optimised.
        """

        context = features["context"]  # aka paragraph
        question = features["question"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        context_seq_config = f.get_sorted_seq_config(context)
        query_seq_config = f.get_sorted_seq_config(question)

        # Embedding
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        context_mask = f.get_mask_from_tokens(context).float()
        query_mask = f.get_mask_from_tokens(question).float()

        context_embed = self.dropout(context_embed)
        query_embed = self.dropout(query_embed)

        # RNN (LSTM)
        context_encoded = f.forward_rnn_with_pack(self.paragraph_rnn,
                                                  context_embed,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        query_encoded = f.forward_rnn_with_pack(
            self.query_rnn, query_embed, query_seq_config)  # (B, Q_L, H*2)
        query_encoded = self.dropout(query_encoded)

        query_attention = self.query_att(query_encoded, query_mask)  # (B, Q_L)
        query_att_sum = f.weighted_sum(query_attention,
                                       query_encoded)  # (B, H*2)

        span_start_logits = self.start_attn(context_encoded, query_att_sum,
                                            context_mask)
        span_end_logits = self.end_attn(context_encoded, query_att_sum,
                                        context_mask)

        # Masked Value
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        output_dict = {
            "start_logits": span_start_logits,
            "end_logits": span_end_logits,
            "best_span": self.get_best_span(span_start_logits,
                                            span_end_logits,
                                            answer_maxlen=self.answer_maxlen),
        }

        if labels:
            data_idx = labels["data_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]

            output_dict["data_idx"] = data_idx

            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(0)

        return output_dict
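self.start_attn and self.end_attn take the encoded context, the query summary, and the context mask, and return per-position logits, which matches DrQA-style bilinear sequence attention. A hypothetical stand-in with that interface (the real claf module may differ):

import torch.nn as nn

class BilinearSeqAttn(nn.Module):
    # Hypothetical sketch: score each context position against the query
    # summary as x_i^T W y. Masking is left to the caller, which applies
    # f.add_masked_value to the returned logits.
    def __init__(self, x_dim, y_dim):
        super().__init__()
        self.linear = nn.Linear(y_dim, x_dim)

    def forward(self, x, y, x_mask):
        # x: (B, C_L, x_dim), y: (B, y_dim) -> logits (B, C_L)
        Wy = self.linear(y)
        return x.bmm(Wy.unsqueeze(2)).squeeze(2)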
Example #8
File: sqlnet.py Project: zzozzolev/claf
    def forward(self, question_embed, question_mask):
        encoded_question, _ = self.question_rnn(question_embed)
        attn_matrix = self.seq_attn(encoded_question, question_mask)
        attn_question = f.weighted_sum(attn_matrix, encoded_question)
        logits = self.mlp(attn_question)
        return logits
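self.seq_attn is defined elsewhere in sqlnet.py; from its use here it maps (B, L, D) encoded states and a (B, L) mask to (B, L) attention weights. A hypothetical stand-in with that interface (not the claf module):

import torch
import torch.nn as nn

class SeqAttention(nn.Module):
    # Hypothetical sketch: score each timestep with a linear layer,
    # then normalize with a masked softmax.
    def __init__(self, input_dim):
        super().__init__()
        self.scorer = nn.Linear(input_dim, 1)

    def forward(self, x, mask):
        scores = self.scorer(x).squeeze(-1)           # (B, L)
        scores = scores.masked_fill(mask == 0, -1e7)  # drop padded positions
        return torch.softmax(scores, dim=-1)          # attention weights (B, L)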
Example #9
    def _query2context(self, S, c, c_mask):
        attention = f.masked_softmax(S, c_mask)  # (B, C_L)
        q2c = f.weighted_sum(attention=attention, matrix=c)

        return q2c.unsqueeze(1).expand(c.size())  # (B, C_L, 2d)
Example #10
    def _context2query(self, S, q, q_mask):
        attention = f.last_dim_masked_softmax(S, q_mask)  # (B, C_L, Q_L)
        c2q = f.weighted_sum(attention=attention, matrix=q)  # (B, C_L, 2d)

        return c2q
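f.last_dim_masked_softmax normalizes the (B, C_L, Q_L) scores over the query dimension only. A minimal sketch, assuming q_mask is (B, Q_L) with 1.0 at real tokens (an assumption, not the claf source):

import torch
import torch.nn.functional as F

def last_dim_masked_softmax(logits, mask):
    # Sketch: softmax over the last dim of (B, C_L, Q_L) logits,
    # masking padded query positions given by the (B, Q_L) `mask`.
    logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e7)
    return F.softmax(logits, dim=-1)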