Esempio n. 1
0
    def forward(self, question_embed, question_mask, column_embed,
                column_name_mask, column_mask):
        """Score every column against the question.

        * Args:
            question_embed: question token embeddings, fed to ``self.question_rnn``.
            question_mask: question token mask (0 = padding). Masks the
                column-to-question attention scores, or is passed to
                ``self.seq_attn`` when column attention is disabled.
            column_embed: column-name embeddings, encoded by ``utils.encode_column``.
            column_name_mask: mask over column-name tokens for ``utils.encode_column``.
            column_mask: column mask (0 = padded column), applied to the logits.

        * Returns:
            logits: one score per column. Padded columns are pushed to ``-1e7``
                by ``f.add_masked_value``.
        """
        # Column Encoder
        encoded_column = utils.encode_column(column_embed, column_name_mask,
                                             self.column_rnn)
        encoded_question, _ = self.question_rnn(question_embed)

        if self.column_attention:
            # Per-column attention scores over question tokens.
            attn_matrix = torch.bmm(
                encoded_column,
                self.linear_attn(encoded_question).transpose(1, 2))
            attn_matrix = f.add_masked_value(attn_matrix,
                                             question_mask.unsqueeze(1),
                                             value=-1e7)
            attn_matrix = F.softmax(attn_matrix, dim=-1)
            attn_question = (encoded_question.unsqueeze(1) *
                             attn_matrix.unsqueeze(3)).sum(2)
        else:
            # One question summary shared by all columns. The extra axis lets
            # it broadcast against encoded_column below.
            attn_matrix = self.seq_attn(encoded_question, question_mask)
            attn_question = f.weighted_sum(attn_matrix, encoded_question)
            attn_question = attn_question.unsqueeze(1)

        # Squeeze only the trailing size-1 score axis. This assumes self.mlp
        # ends in a 1-unit output (TODO confirm). A bare .squeeze() would also
        # drop a size-1 column axis. With a single column, the (B,) logits
        # would then broadcast against column_mask (presumably (B, 1)) into a
        # (B, B) tensor.
        logits = self.mlp(
            self.linear_question(attn_question) +
            self.linear_column(encoded_column)).squeeze(-1)
        logits = f.add_masked_value(logits, column_mask, value=-1e7)
        return logits
Esempio n. 2
0
    def forward(self,
                context_embed,
                question_embed,
                context_mask=None,
                question_mask=None):
        """Bi-directional attention between a context and a question.

        * Args:
            context_embed: context tensor ``C`` of shape (B, C_L, D).
            question_embed: question tensor ``Q`` of shape (B, Q_L, D).

        * Kwargs:
            context_mask: optional (B, C_L) mask (0 = padding). Padded context
                positions are pushed to -1e7 before the query->context softmax.
            question_mask: optional (B, Q_L) mask (0 = padding). Padded
                question positions are pushed to -1e7 before the
                context->query softmax.

        * Returns:
            out: (B, C_L, 4 * D) tensor ``[C; A; C * A; C * B]``. ``A`` is the
                context->query vector and ``B`` the query->context vector.
        """
        C, Q = context_embed, question_embed
        batch_size, C_L, Q_L, D = C.size(0), C.size(1), Q.size(1), Q.size(2)

        # Broadcast both sides to (B, C_L, Q_L, D) as zero-copy views. The old
        # code allocated a dense zeros tensor only to read its shape.
        pair_shape = (batch_size, C_L, Q_L, D)
        C_ = C.unsqueeze(2).expand(*pair_shape)
        Q_ = Q.unsqueeze(1).expand(*pair_shape)
        C_Q = torch.mul(C_, Q_)

        S = self.W_0(torch.cat([C_, Q_, C_Q], 3)).squeeze(3)  # (B, C_L, Q_L)

        S_question = S
        if question_mask is not None:
            S_question = f.add_masked_value(S_question,
                                            question_mask.unsqueeze(1),
                                            value=-1e7)
        S_q = F.softmax(S_question, 2)  # (B, C_L, Q_L)

        S_context = S.transpose(1, 2)
        if context_mask is not None:
            S_context = f.add_masked_value(S_context,
                                           context_mask.unsqueeze(1),
                                           value=-1e7)
        S_c = F.softmax(S_context, 2)  # (B, Q_L, C_L)

        A = torch.bmm(S_q, Q)  # context2query (B, C_L, D)
        # (B, C_L, Q_L) @ (B, Q_L, C_L) @ (B, C_L, D) -> (B, C_L, D).
        # Named B_ so the batch size is not shadowed.
        B_ = torch.bmm(S_q, S_c).bmm(C)  # query2context (B, C_L, D)
        out = torch.cat([C, A, C * A, C * B_], dim=-1)
        return out
Esempio n. 3
0
    def decode_then_output(
        self,
        encoded_used_column,
        encoded_question,
        question_mask,
        decoder_input,
        decoder_hidden=None,
    ):
        """Run the decoder once and score question tokens for each used column.

        * Args:
            encoded_used_column: encoded used columns. Only ``size(0)`` (the
                batch size) is read directly; the tensor is also projected by
                ``self.linear_column``.
            encoded_question: encoded question, projected by
                ``self.linear_question``.
            question_mask: question token mask (0 = padding). Two singleton
                axes are inserted after the batch axis before masking.
            decoder_input: reshaped to (B * column_maxlen, -1, token_maxlen)
                and fed to ``self.decoder``.

        * Kwargs:
            decoder_hidden: decoder state, passed straight to ``self.decoder``.

        * Returns:
            (logits, decoder_hidden): the masked logits and the decoder's new state.
        """
        batch = encoded_used_column.size(0)

        flat_input = decoder_input.view(batch * self.column_maxlen, -1,
                                        self.token_maxlen)
        step_output, decoder_hidden = self.decoder(flat_input, decoder_hidden)

        # Restore the (B, column_maxlen, steps, model_dim) layout, then add a
        # singleton axis at position 3 so the projection can broadcast
        # against the question projection.
        step_output = step_output.contiguous().view(batch, self.column_maxlen,
                                                    -1, self.model_dim)
        step_output = step_output.unsqueeze(3)

        combined = (self.linear_column(encoded_used_column) +
                    self.linear_conds(step_output) +
                    self.linear_question(encoded_question))
        # NOTE(review): a bare .squeeze() drops *every* size-1 axis (batch of
        # one, a single decode step, ...), not just the score axis. Confirm
        # callers rely on this before narrowing it to squeeze(-1).
        logits = self.mlp(combined).squeeze()

        broadcast_mask = question_mask.unsqueeze(1).unsqueeze(1)
        logits = f.add_masked_value(logits, broadcast_mask, value=-1e7)
        return logits, decoder_hidden
Esempio n. 4
0
    def forward(self, question_embed, question_mask, column_embed,
                column_name_mask, col_idx):
        """Score the columns selected by ``col_idx`` against the question.

        * Args:
            question_embed: question token embeddings, fed to ``self.question_rnn``.
            question_mask: question token mask (0 = padding). Masks the
                column-to-question attention scores, or is passed to
                ``self.seq_attn`` when column attention is disabled.
            column_embed: column-name embeddings, encoded by ``utils.encode_column``.
            column_name_mask: mask over column-name tokens for ``utils.encode_column``.
            col_idx: indices of the used columns, gathered by
                ``utils.filter_used_column`` with ``padding_count=self.column_maxlen``.

        * Returns:
            Unmasked ``self.mlp`` scores for the used columns.
        """
        # Column Encoder
        encoded_column = utils.encode_column(column_embed, column_name_mask,
                                             self.column_rnn)
        encoded_used_column = utils.filter_used_column(
            encoded_column, col_idx, padding_count=self.column_maxlen)

        encoded_question, _ = self.question_rnn(question_embed)
        if self.column_attention:
            # Assumes encoded_used_column is (B, column_maxlen, D) (TODO
            # confirm). Then the matmul yields (B, column_maxlen, Q_L, 1), and
            # only that trailing unit axis is squeezed. A bare .squeeze() would
            # also drop a column or question axis of length 1 and misalign the
            # mask and softmax below.
            attn_matrix = torch.matmul(
                self.linear_attn(encoded_question).unsqueeze(1),
                encoded_used_column.unsqueeze(3)).squeeze(-1)
            attn_matrix = f.add_masked_value(attn_matrix,
                                             question_mask.unsqueeze(1),
                                             value=-1e7)
            attn_matrix = F.softmax(attn_matrix, dim=-1)
            attn_question = (encoded_question.unsqueeze(1) *
                             attn_matrix.unsqueeze(3)).sum(2)
        else:
            # One question summary shared by every used column.
            attn_matrix = self.seq_attn(encoded_question, question_mask)
            attn_question = f.weighted_sum(attn_matrix, encoded_question)
            attn_question = attn_question.unsqueeze(1)

        # NOTE(review): this bare .squeeze() also removes the batch axis when
        # B == 1. It is kept because that shape reaches callers unmasked.
        return self.mlp(
            self.linear_question(attn_question) +
            self.linear_column(encoded_used_column)).squeeze()
Esempio n. 5
0
    def _scaled_dot_product(self, query, key, value, mask=None):
        K_D = query.size(-1)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(K_D)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # [B, #H, C_L, D]
            scores = f.add_masked_value(scores, mask, value=-1e7)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, value)
Esempio n. 6
0
    def forward(self, context, context_mask, query, query_mask):
        """Attention-flow layer combining context->query and query->context.

        * Args:
            context: context encodings ``c``.
            context_mask: context mask (0 = padding), passed to the
                query->context step.
            query: query encodings ``q``.
            query_mask: query mask (0 = padding). Masks the similarity matrix
                for the query->context step and is passed to the
                context->query step.

        * Returns:
            ``[c; c2q; c * c2q; c * q2c]`` concatenated on the last axis
            (annotated as (B, C_L, 8d) in the original).
        """
        similarity = self._make_similiarity_matrix(context, query)  # (B, C_L, Q_L)
        masked_similarity = f.add_masked_value(similarity,
                                               query_mask.unsqueeze(1),
                                               value=-1e7)

        # context -> query attention receives the unmasked similarity matrix.
        c2q = self._context2query(similarity, query, query_mask)

        # query -> context uses each context position's best masked match.
        best_match = masked_similarity.max(dim=-1)[0]
        q2c = self._query2context(best_match, context, context_mask)

        # [h; u~; h o u~; h o h~]
        return torch.cat((context, c2q, context * c2q, context * q2c), dim=-1)
Esempio n. 7
0
def test_add_masked_value():
    """f.add_masked_value fills mask==0 cells with ``value`` and keeps the rest."""
    a = torch.rand(3, 5)
    a_mask = torch.FloatTensor([
        [1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1],
    ])

    tensor = f.add_masked_value(a, a_mask, value=100)

    # Masked positions take the fill value...
    assert tensor[0][3] == 100
    assert tensor[0][4] == 100
    assert tensor[1][2] == 100
    assert tensor[1][3] == 100
    assert tensor[1][4] == 100

    # ...and everything else must come back unchanged. The whole result must
    # equal `a` with only the zero-mask cells replaced.
    assert tensor.shape == a.shape
    assert torch.equal(tensor, a.masked_fill(a_mask == 0, 100))
Esempio n. 8
0
    def forward(self, x, x_mask, key, key_mask):
        """Trilinear attention of ``x`` over ``key``, optionally as self-attention.

        * Args:
            x: query-side sequence.
            x_mask: mask for ``x`` (0 = padding).
            key: key-side sequence.
            key_mask: mask for ``key`` (0 = padding).

        * Returns:
            With ``self.self_attn``: ``[x; x2key; x * x2key]``.
            Otherwise: ``[x; x2key; x * x2key; x * key2x]``.
            In both cases the parts are concatenated on the last axis.
        """
        S = self._trilinear(x, key)
        # 1 wherever a (x_i, key_j) pair must be suppressed. Those scores are
        # shifted by -1e7 below.
        joint_mask = 1 - self._compute_attention_mask(x_mask, key_mask)

        if not self.self_attn:
            x2key = self._x2key(S + joint_mask * (-1e7), key, key_mask)

            key_masked_S = f.add_masked_value(S, key_mask.unsqueeze(1), value=-1e7)
            key2x = self._key2x(key_masked_S.max(dim=-1)[0], x, x_mask)
            return torch.cat((x, x2key, x * x2key, x * key2x), dim=-1)

        # Self-attention: also apply self.diag_mask, cropped to the current
        # length (presumably hiding each position from itself). The sum is
        # clamped back into [0, 1].
        seq_length = x.size(1)
        diag_mask = self.diag_mask.narrow(0, 0, seq_length).narrow(1, 0, seq_length)
        combined_mask = torch.clamp(diag_mask + joint_mask, 0, 1)
        x2key = self._x2key(S + combined_mask * (-1e7), key, key_mask)
        return torch.cat((x, x2key, x * x2key), dim=-1)
Esempio n. 9
0
    def forward(self, features, labels=None):
        """
        Span-extraction forward pass with a learned no-answer bias.

        The pipeline follows the section comments below: embedding -> highway
        -> contextual RNNs -> attention flow -> modeling RNN -> span start /
        span end heads. An extra "no answer" logit column is then appended.

        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}
                Only features["context"] and features["question"] are read.

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                Reads "answer_idx", "answer_start_idx", "answer_end_idx" and
                "answerable". No loss is computed when labels are missing
                (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: unnormalized span-start scores over the context
              positions, plus one trailing no-answer column.
            - end_logits: unnormalized span-end scores, same layout.
            - best_span: output of self.get_best_span on the logits with the
              no-answer column removed.
            - answer_idx: the question id, mapping with answer (labels only).
            - loss: start + end self.criterion loss, unsqueezed to 1-D
              (labels only).
        """

        context = features["context"]
        question = features["question"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        context_seq_config = f.get_sorted_seq_config(context)
        query_seq_config = f.get_sorted_seq_config(question)

        # Embedding Layer (Char + Word -> Contextual)
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        # Float token masks. Presumably 1 = real token, 0 = padding, which is
        # the convention f.add_masked_value expects.
        context_mask = f.get_mask_from_tokens(context).float()
        query_mask = f.get_mask_from_tokens(question).float()

        B, C_L = context_embed.size(0), context_embed.size(1)

        context_embed = self.context_highway(context_embed)
        query_embed = self.query_highway(query_embed)

        context_encoded = f.forward_rnn_with_pack(self.context_contextual_rnn,
                                                  context_embed,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        query_encoded = f.forward_rnn_with_pack(self.query_contextual_rnn,
                                                query_embed, query_seq_config)
        query_encoded = self.dropout(query_encoded)

        # Attention Flow Layer
        attention_context_query = self.attention(context_encoded, context_mask,
                                                 query_encoded, query_mask)

        # Modeling Layer
        modeled_context = f.forward_rnn_with_pack(self.modeling_rnn,
                                                  attention_context_query,
                                                  context_seq_config)
        modeled_context = self.dropout(modeled_context)

        M_D = modeled_context.size(-1)

        # Output Layer
        span_start_input = self.dropout(
            torch.cat([attention_context_query, modeled_context],
                      dim=-1))  # (B, C_L, 10d)
        span_start_logits = self.span_start_linear(span_start_input).squeeze(
            -1)  # (B, C_L)
        span_start_probs = f.masked_softmax(span_start_logits, context_mask)

        # Summarize the context under the start distribution, then tile that
        # summary over every context position for the end head.
        span_start_representation = f.weighted_sum(attention=span_start_probs,
                                                   matrix=modeled_context)
        tiled_span_start_representation = span_start_representation.unsqueeze(
            1).expand(B, C_L, M_D)

        span_end_representation = torch.cat(
            [
                attention_context_query,
                modeled_context,
                tiled_span_start_representation,
                modeled_context * tiled_span_start_representation,
            ],
            dim=-1,
        )
        encoded_span_end = f.forward_rnn_with_pack(self.output_end_rnn,
                                                   span_end_representation,
                                                   context_seq_config)
        encoded_span_end = self.dropout(encoded_span_end)

        span_end_input = self.dropout(
            torch.cat([attention_context_query, encoded_span_end], dim=-1))
        span_end_logits = self.span_end_linear(span_end_input).squeeze(-1)

        # Masked Value: padded context positions are pushed to -1e7.
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        # No_Answer Bias: self.bias is broadcast to (B, 1) and appended as one
        # extra logit column at index C_L, meaning "no answer".
        bias = self.bias.expand(B, 1)
        span_start_logits = torch.cat([span_start_logits, bias], dim=-1)
        span_end_logits = torch.cat([span_end_logits, bias], dim=-1)

        output_dict = {
            "start_logits":
            span_start_logits,
            "end_logits":
            span_end_logits,
            # [:, :-1] drops the appended no-answer column before decoding.
            "best_span":
            self.get_best_span(
                span_start_logits[:, :-1],
                span_end_logits[:, :-1],
                answer_maxlen=self.answer_maxlen,
            ),
        }

        if labels:
            answer_idx = labels["answer_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]
            answerable = labels["answerable"]

            # No-Answer case: unanswerable examples get both targets set to
            # index C_L, i.e. the appended bias column. C_L is re-read from
            # context_mask; presumably it equals context_embed.size(1) above.
            C_L = context_mask.size(1)
            answer_start_idx = answer_start_idx.masked_fill(
                answerable.eq(0), C_L)
            answer_end_idx = answer_end_idx.masked_fill(answerable.eq(0), C_L)

            output_dict["answer_idx"] = answer_idx

            # Loss
            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(
                0)  # NOTE: 1-D loss avoids a DataParallel concat error (0-dim gather)

        return output_dict
Esempio n. 10
0
    def forward(self, features, labels=None):
        """
        Span-extraction forward pass. The context and question are encoded by
        separate RNNs, and the question is pooled into one attention-weighted
        vector that conditions the start and end scorers.

        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}
                Only features["context"] and features["question"] are read.

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                Reads "data_idx", "answer_start_idx" and "answer_end_idx".
                No loss is computed when labels are missing
                (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: unnormalized span-start scores; padded context
              positions are pushed to -1e7.
            - end_logits: unnormalized span-end scores, masked the same way.
            - best_span: output of self.get_best_span on the two logit tensors.
            - data_idx: the question id, mapping with answer (labels only).
            - loss: start + end self.criterion loss, unsqueezed to 1-D
              (labels only).
        """

        context = features["context"]  # aka paragraph
        question = features["question"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        context_seq_config = f.get_sorted_seq_config(context)
        query_seq_config = f.get_sorted_seq_config(question)

        # Embedding
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        # Float token masks. Presumably 1 = real token, 0 = padding.
        context_mask = f.get_mask_from_tokens(context).float()
        query_mask = f.get_mask_from_tokens(question).float()

        context_embed = self.dropout(context_embed)
        query_embed = self.dropout(query_embed)

        # RNN (LSTM)
        context_encoded = f.forward_rnn_with_pack(self.paragraph_rnn,
                                                  context_embed,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        query_encoded = f.forward_rnn_with_pack(
            self.query_rnn, query_embed, query_seq_config)  # (B, Q_L, H*2)
        query_encoded = self.dropout(query_encoded)

        # Pool the question into a single vector using self-attention weights.
        query_attention = self.query_att(query_encoded, query_mask)  # (B, Q_L)
        query_att_sum = f.weighted_sum(query_attention,
                                       query_encoded)  # (B, H*2)

        span_start_logits = self.start_attn(context_encoded, query_att_sum,
                                            context_mask)
        span_end_logits = self.end_attn(context_encoded, query_att_sum,
                                        context_mask)

        # Masked Value: padded context positions are pushed to -1e7.
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        output_dict = {
            "start_logits":
            span_start_logits,
            "end_logits":
            span_end_logits,
            "best_span":
            self.get_best_span(span_start_logits,
                               span_end_logits,
                               answer_maxlen=self.answer_maxlen),
        }

        if labels:
            data_idx = labels["data_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]

            output_dict["data_idx"] = data_idx

            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            # 1-D loss, matching the other models' DataParallel workaround.
            output_dict["loss"] = loss.unsqueeze(0)

        return output_dict
Esempio n. 11
0
    def forward(self, features, labels=None):
        """
        Multi-head self-attentive sequence classifier.

        * Args:
            features: feature dictionary like below.
            {"sequence": [0, 3, 4, 1]}

        * Kwargs:
            labels: label dictionary like below.
            {"class_idx": 2, "data_idx": 0}
             No loss is computed when labels are missing (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - sequence_embed: embedding vector of the sequence
            - class_logits: representing unnormalized log probabilities of the class.

            - class_idx: target class idx (labels only)
            - data_idx: data idx (labels only)
            - loss: self.criterion loss plus self.penalty(attention), as a
              1-D tensor (labels only)
        """

        sequence = features["sequence"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        sequence_config = f.get_sorted_seq_config(sequence)

        token_embed = self.token_embedder(sequence)

        token_encodings = f.forward_rnn_with_pack(
            self.encoder, token_embed, sequence_config
        )  # [B, L, encoding_rnn_hidden_dim]

        attention = self.A(token_encodings).transpose(1, 2)  # [B, num_attention_heads, L]

        # One padding mask, shared by every attention head.
        sequence_mask = f.get_mask_from_tokens(sequence).float()  # [B, L]
        sequence_mask = sequence_mask.unsqueeze(1).expand_as(attention)
        # Mask padded positions, then normalize each head over L. The former
        # "+ 1e-13" was dropped. It added the same constant to every logit,
        # and softmax is invariant to that, so it never acted as an epsilon
        # guard.
        attention = F.softmax(f.add_masked_value(attention, sequence_mask), dim=2)

        attended_encodings = torch.bmm(
            attention, token_encodings
        )  # [B, num_attention_heads, sequence_embed_dim]
        # Flatten all heads into one vector per example before projecting.
        sequence_embed = self.fully_connected(
            attended_encodings.view(attended_encodings.size(0), -1)
        )  # [B, sequence_embed_dim]

        class_logits = self.classifier(sequence_embed)  # [B, num_classes]

        output_dict = {"sequence_embed": sequence_embed, "class_logits": class_logits}

        if labels:
            class_idx = labels["class_idx"]
            data_idx = labels["data_idx"]

            output_dict["class_idx"] = class_idx
            output_dict["data_idx"] = data_idx

            # Loss: classification term plus the attention penalty.
            loss = self.criterion(class_logits, class_idx)
            loss += self.penalty(attention)
            output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error

        return output_dict
Esempio n. 12
0
    def forward(self, features, labels=None):
        """
        Span-extraction forward pass: bi-attention, projection, a residual
        self-attention stage, then separate start and end RNN heads.

        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}
                Only features["context"] and features["question"] are read.

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                Reads "answer_idx", "answer_start_idx" and "answer_end_idx".
                No loss is computed when labels are missing
                (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: unnormalized span-start scores (B X C_L); padded
              context positions are pushed to -1e7.
            - end_logits: unnormalized span-end scores (B X C_L), masked the
              same way.
            - best_span: output of self.get_best_span on the two logit tensors.
            - answer_idx: the question id, mapping with answer (labels only).
            - loss: start + end self.criterion loss, unsqueezed to 1-D
              (labels only).
        """

        context = features["context"]
        question = features["question"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        context_seq_config = f.get_sorted_seq_config(context)
        query_seq_config = f.get_sorted_seq_config(question)

        # Embedding
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        # Presumably (B, C_L) / (B, Q_L) with 1 = token, 0 = padding. The
        # context mask is applied directly to the (B, C_L) span logits below,
        # so it cannot be B X 1 X C_L.
        context_mask = f.get_mask_from_tokens(context).float()  # B X C_L (see note above)
        query_mask = f.get_mask_from_tokens(question).float()  # B X Q_L (presumably)

        # Pre-process
        context_embed = self.dropout(context_embed)
        context_encoded = f.forward_rnn_with_pack(self.context_preprocess_rnn,
                                                  context_embed,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        query_embed = self.dropout(query_embed)
        query_encoded = f.forward_rnn_with_pack(self.query_preprocess_rnn,
                                                query_embed, query_seq_config)
        query_encoded = self.dropout(query_encoded)

        # Attention -> Projection
        context_attnded = self.bi_attention(context_encoded, context_mask,
                                            query_encoded, query_mask)
        context_attnded = self.activation_fn(
            self.attn_linear(context_attnded))  # B X C_L X dim*2

        # Residual Self-Attention: the self-attention output is added back
        # onto the projected bi-attention output.
        context_attnded = self.dropout(context_attnded)
        context_encoded = f.forward_rnn_with_pack(self.modeling_rnn,
                                                  context_attnded,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        context_self_attnded = self.self_attention(
            context_encoded, context_mask)  # B X C_L X dim*2
        context_final = self.dropout(context_attnded +
                                     context_self_attnded)  # B X C_L X dim*2

        # Prediction
        span_start_input = f.forward_rnn_with_pack(
            self.span_start_rnn, context_final,
            context_seq_config)  # B X C_L X dim*2
        span_start_input = self.dropout(span_start_input)
        span_start_logits = self.span_start_linear(span_start_input).squeeze(
            -1)  # B X C_L

        # The end head also sees the start head's RNN features.
        span_end_input = torch.cat([span_start_input, context_final],
                                   dim=-1)  # B X C_L X dim*4
        span_end_input = f.forward_rnn_with_pack(
            self.span_end_rnn, span_end_input,
            context_seq_config)  # B X C_L X dim*2
        span_end_input = self.dropout(span_end_input)
        span_end_logits = self.span_end_linear(span_end_input).squeeze(
            -1)  # B X C_L

        # Masked Value: padded context positions are pushed to -1e7.
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        output_dict = {
            "start_logits":
            span_start_logits,
            "end_logits":
            span_end_logits,
            "best_span":
            self.get_best_span(span_start_logits,
                               span_end_logits,
                               answer_maxlen=self.answer_maxlen),
        }

        if labels:
            answer_idx = labels["answer_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]

            output_dict["answer_idx"] = answer_idx

            # Loss
            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(
                0)  # NOTE: 1-D loss avoids a DataParallel concat error (0-dim gather)

        return output_dict
Esempio n. 13
0
    def forward(self, features, labels=None):
        """
        Span-extraction forward pass built from encoder blocks, following the
        numbered layers below:
        1. input embedding
        2. embedding encoder
        3. context-query attention
        4. stacked model encoder (three passes)
        5. output layer

        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}
                Only features["context"] and features["question"] are read.

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                Reads "data_idx", "answer_start_idx" and "answer_end_idx".
                No loss is computed when labels are missing
                (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: unnormalized span-start scores; padded context
              positions are pushed to -1e7.
            - end_logits: unnormalized span-end scores, masked the same way.
            - best_span: output of self.get_best_span on the two logit tensors
              (called without answer_maxlen, unlike the sibling models).
            - data_idx: the question id, mapping with answer (labels only).
            - loss: start + end self.criterion loss, unsqueezed to 1-D
              (labels only).
        """

        context = features["context"]
        question = features["question"]

        # 1. Input Embedding Layer
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        # Float token masks. Presumably 1 = real token, 0 = padding.
        context_mask = f.get_mask_from_tokens(context).float()
        query_mask = f.get_mask_from_tokens(question).float()

        context_embed = self.context_highway(context_embed)
        context_embed = self.dropout(context_embed)
        context_embed = self.context_embed_pointwise_conv(context_embed)

        query_embed = self.query_highway(query_embed)
        query_embed = self.dropout(query_embed)
        query_embed = self.query_embed_pointwise_conv(query_embed)

        # 2. Embedding Encoder Layer
        # The same encoder_block instance is applied to both context and query,
        # so their weights are shared. No mask is passed here.
        # NOTE(review): if self.embed_encoder_blocks is empty, `query` is never
        # bound and `context` is still the raw feature dict, so step 3 fails.
        for encoder_block in self.embed_encoder_blocks:
            context = encoder_block(context_embed)
            context_embed = context

            query = encoder_block(query_embed)
            query_embed = query

        # 3. Context-Query Attention Layer
        context_query_attention = self.co_attention(context, query,
                                                    context_mask, query_mask)

        # Projection (memory issue)
        context_query_attention = self.pointwise_conv(context_query_attention)
        context_query_attention = self.dropout(context_query_attention)

        # 4. Model Encoder Layer
        model_encoder_block_inputs = context_query_attention

        # Stacked Model Encoder Block: the same self.model_encoder_blocks run
        # three times in sequence (shared weights). The output of each pass is
        # kept as stacked_model_encoder_blocks[0..2].
        stacked_model_encoder_blocks = []
        for i in range(3):
            for _, model_encoder_block in enumerate(self.model_encoder_blocks):
                output = model_encoder_block(model_encoder_block_inputs,
                                             context_mask)
                model_encoder_block_inputs = output

            stacked_model_encoder_blocks.append(output)

        # 5. Output Layer: start uses passes [0; 1], end uses passes [0; 2].
        span_start_inputs = torch.cat(
            [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[1]],
            dim=-1)
        span_start_inputs = self.dropout(span_start_inputs)
        span_start_logits = self.span_start_linear(span_start_inputs).squeeze(
            -1)

        span_end_inputs = torch.cat(
            [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[2]],
            dim=-1)
        span_end_inputs = self.dropout(span_end_inputs)
        span_end_logits = self.span_end_linear(span_end_inputs).squeeze(-1)

        # Masked Value: padded context positions are pushed to -1e7.
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        output_dict = {
            "start_logits": span_start_logits,
            "end_logits": span_end_logits,
            "best_span": self.get_best_span(span_start_logits,
                                            span_end_logits),
        }

        if labels:
            data_idx = labels["data_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]

            output_dict["data_idx"] = data_idx

            # Loss
            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(
                0)  # NOTE: 1-D loss avoids a DataParallel concat error (0-dim gather)

        return output_dict