Example 1
def test_last_dim_masked_softmax_with_2_dim():
    tensor = torch.FloatTensor([
            [2, 3, 1, 0, 0],
            [4, 1, 0, 0, 0],
            [1, 5, 2, 4, 1],
        ])
    mask = f.get_mask_from_tokens({"word": tensor}).float()

    result = f.last_dim_masked_softmax(tensor, mask)
    assert result.argmax(dim=-1).equal(torch.LongTensor([1, 0, 1]))
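The helper f.last_dim_masked_softmax is not shown in this listing. A minimal sketch consistent with this test follows; the function name and signature come from the test, but the body is an assumption, not the library's actual code:

import torch
import torch.nn.functional as F

def last_dim_masked_softmax(tensor, mask):
    # Assumed sketch: push padded (mask == 0) positions toward -inf so they
    # receive ~zero probability, then softmax over the last dimension.
    masked_logits = tensor + (1.0 - mask) * -1e7
    return F.softmax(masked_logits, dim=-1)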
Example 2
def test_get_mask_from_tokens_with_2_dim():
    tokens = {
        "word" : torch.LongTensor([
            [1, 1, 1, 0, 0],
            [1, 1, 0, 0, 0],
            [1, 1, 1, 1, 1],
        ]),
    }

    mask = f.get_mask_from_tokens(tokens)
    print(mask)
    assert mask.equal(tokens["word"])
Example 3
def test_get_mask_from_tokens_with_3_dim():
    tokens = {
        "char" : torch.LongTensor([
            [[4, 2], [3, 6], [0, 0]],
            [[5, 1], [0, 0], [0, 0]],
            [[1, 3], [2, 4], [3, 6]],
        ]),
    }

    mask = f.get_mask_from_tokens(tokens)
    expect_tensor = torch.LongTensor([
        [1, 1, 0],
        [1, 0, 0],
        [1, 1, 1],
    ])
    assert mask.equal(expect_tensor)
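Examples 2 and 3 exercise f.get_mask_from_tokens on word-level (2-D) and character-level (3-D) inputs. A minimal sketch that satisfies both assertions; the behavior is inferred from the tests, and the real implementation may differ:

import torch

def get_mask_from_tokens(tokens):
    # Build the mask from the first available token tensor.
    tensor = next(iter(tokens.values()))
    if tensor.dim() == 3:
        # Character-level ids [B, L, C]: a position is valid if any char id is non-zero.
        return (tensor.sum(dim=-1) > 0).long()
    # Word-level ids [B, L]: non-zero ids mark valid positions.
    return (tensor > 0).long()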
Example 4
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                 Loss is not calculated when no labels are given (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: representing unnormalized log probabilities of the span start position.
            - end_logits: representing unnormalized log probabilities of the span end position.
            - best_span: the string from the original passage that the model thinks is the best answer to the question.
            - answer_idx: the question id mapped to the answer
            - loss: A scalar loss to be optimised.
        """

        context = features["context"]
        question = features["question"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        context_seq_config = f.get_sorted_seq_config(context)
        query_seq_config = f.get_sorted_seq_config(question)

        # Embedding Layer (Char + Word -> Contextual)
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        context_mask = f.get_mask_from_tokens(context).float()
        query_mask = f.get_mask_from_tokens(question).float()

        B, C_L = context_embed.size(0), context_embed.size(1)

        context_embed = self.context_highway(context_embed)
        query_embed = self.query_highway(query_embed)

        context_encoded = f.forward_rnn_with_pack(self.context_contextual_rnn,
                                                  context_embed,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        query_encoded = f.forward_rnn_with_pack(self.query_contextual_rnn,
                                                query_embed, query_seq_config)
        query_encoded = self.dropout(query_encoded)

        # Attention Flow Layer
        attention_context_query = self.attention(context_encoded, context_mask,
                                                 query_encoded, query_mask)

        # Modeling Layer
        modeled_context = f.forward_rnn_with_pack(self.modeling_rnn,
                                                  attention_context_query,
                                                  context_seq_config)
        modeled_context = self.dropout(modeled_context)

        M_D = modeled_context.size(-1)

        # Output Layer
        span_start_input = self.dropout(
            torch.cat([attention_context_query, modeled_context],
                      dim=-1))  # (B, C_L, 10d)
        span_start_logits = self.span_start_linear(span_start_input).squeeze(
            -1)  # (B, C_L)
        span_start_probs = f.masked_softmax(span_start_logits, context_mask)

        span_start_representation = f.weighted_sum(attention=span_start_probs,
                                                   matrix=modeled_context)
        tiled_span_start_representation = span_start_representation.unsqueeze(
            1).expand(B, C_L, M_D)

        span_end_representation = torch.cat(
            [
                attention_context_query,
                modeled_context,
                tiled_span_start_representation,
                modeled_context * tiled_span_start_representation,
            ],
            dim=-1,
        )
        encoded_span_end = f.forward_rnn_with_pack(self.output_end_rnn,
                                                   span_end_representation,
                                                   context_seq_config)
        encoded_span_end = self.dropout(encoded_span_end)

        span_end_input = self.dropout(
            torch.cat([attention_context_query, encoded_span_end], dim=-1))
        span_end_logits = self.span_end_linear(span_end_input).squeeze(-1)

        # Masked Value
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        # No_Answer Bias
        bias = self.bias.expand(B, 1)
        span_start_logits = torch.cat([span_start_logits, bias], dim=-1)
        span_end_logits = torch.cat([span_end_logits, bias], dim=-1)

        output_dict = {
            "start_logits":
            span_start_logits,
            "end_logits":
            span_end_logits,
            "best_span":
            self.get_best_span(
                span_start_logits[:, :-1],
                span_end_logits[:, :-1],
                answer_maxlen=self.answer_maxlen,  # except no_answer bias
            ),
        }

        if labels:
            answer_idx = labels["answer_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]
            answerable = labels["answerable"]

            # No_Answer Case
            C_L = context_mask.size(1)
            answer_start_idx = answer_start_idx.masked_fill(
                answerable.eq(0), C_L)
            answer_end_idx = answer_end_idx.masked_fill(answerable.eq(0), C_L)

            output_dict["answer_idx"] = answer_idx

            # Loss
            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(
                0)  # NOTE: unsqueeze so DataParallel can concat scalar losses

        return output_dict
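f.weighted_sum is used above to pool modeled_context with the start-probability distribution. A plausible sketch, assumed from the call site rather than taken from the library:

import torch

def weighted_sum(attention, matrix):
    # attention: [B, L] probabilities; matrix: [B, L, D].
    # Returns the attention-weighted sum of the rows of matrix: [B, D].
    return torch.bmm(attention.unsqueeze(1), matrix).squeeze(1)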
Example 5
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                 Loss is not calculated when no labels are given (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: representing unnormalized log probabilities of the span start position.
            - end_logits: representing unnormalized log probabilities of the span end position.
            - best_span: the string from the original passage that the model thinks is the best answer to the question.
            - data_idx: the question id mapped to the answer
            - loss: A scalar loss to be optimised.
        """

        context = features["context"]  # aka paragraph
        question = features["question"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        context_seq_config = f.get_sorted_seq_config(context)
        query_seq_config = f.get_sorted_seq_config(question)

        # Embedding
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        context_mask = f.get_mask_from_tokens(context).float()
        query_mask = f.get_mask_from_tokens(question).float()

        context_embed = self.dropout(context_embed)
        query_embed = self.dropout(query_embed)

        # RNN (LSTM)
        context_encoded = f.forward_rnn_with_pack(self.paragraph_rnn,
                                                  context_embed,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        query_encoded = f.forward_rnn_with_pack(
            self.query_rnn, query_embed, query_seq_config)  # (B, Q_L, H*2)
        query_encoded = self.dropout(query_encoded)

        query_attention = self.query_att(query_encoded, query_mask)  # (B, Q_L)
        query_att_sum = f.weighted_sum(query_attention,
                                       query_encoded)  # (B, H*2)

        span_start_logits = self.start_attn(context_encoded, query_att_sum,
                                            context_mask)
        span_end_logits = self.end_attn(context_encoded, query_att_sum,
                                        context_mask)

        # Masked Value
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        output_dict = {
            "start_logits":
            span_start_logits,
            "end_logits":
            span_end_logits,
            "best_span":
            self.get_best_span(span_start_logits,
                               span_end_logits,
                               answer_maxlen=self.answer_maxlen),
        }

        if labels:
            data_idx = labels["data_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]

            output_dict["data_idx"] = data_idx

            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(0)

        return output_dict
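f.add_masked_value appears in every span model in this listing; from its call sites it replaces padded positions with a large negative value before softmax/argmax. A minimal sketch under that assumption, not the library's verbatim code:

def add_masked_value(tensor, mask, value=-1e7):
    # Keep entries where mask == 1; overwrite padded positions with `value`
    # so a downstream softmax/argmax effectively ignores them.
    return tensor * mask + (1 - mask) * value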
Example 6
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
            {"sequence": [0, 3, 4, 1]}

        * Kwargs:
            labels: label dictionary like below.
            {"class_idx": 2, "data_idx": 0}
             Loss is not calculated when no labels are given (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - sequence_embed: embedding vector of the sequence
            - class_logits: representing unnormalized log probabilities of the class.

            - class_idx: target class idx
            - data_idx: data idx
            - loss: a scalar loss to be optimized
        """

        sequence = features["sequence"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        sequence_config = f.get_sorted_seq_config(sequence)

        token_embed = self.token_embedder(sequence)

        token_encodings = f.forward_rnn_with_pack(
            self.encoder, token_embed, sequence_config
        )  # [B, L, encoding_rnn_hidden_dim]

        attention = self.A(token_encodings).transpose(1, 2)  # [B, num_attention_heads, L]

        sequence_mask = f.get_mask_from_tokens(sequence).float()  # [B, L]
        sequence_mask = sequence_mask.unsqueeze(1).expand_as(attention)
        attention = F.softmax(f.add_masked_value(attention, sequence_mask) + 1e-13, dim=2)

        attended_encodings = torch.bmm(
            attention, token_encodings
        )  # [B, num_attention_heads, sequence_embed_dim]
        sequence_embed = self.fully_connected(
            attended_encodings.view(attended_encodings.size(0), -1)
        )  # [B, sequence_embed_dim]

        class_logits = self.classifier(sequence_embed)  # [B, num_classes]

        output_dict = {"sequence_embed": sequence_embed, "class_logits": class_logits}

        if labels:
            class_idx = labels["class_idx"]
            data_idx = labels["data_idx"]

            output_dict["class_idx"] = class_idx
            output_dict["data_idx"] = data_idx

            # Loss
            loss = self.criterion(class_logits, class_idx)
            loss += self.penalty(attention)
            output_dict["loss"] = loss.unsqueeze(0)  # NOTE: unsqueeze so DataParallel can concat scalar losses

        return output_dict
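self.penalty regularizes the multi-head attention matrix. In the structured self-attentive embedding formulation (Lin et al., 2017) this is the Frobenius-norm penalty ||A Aᵀ − I||²; a sketch under the assumption that self.penalty follows that paper:

import torch

def attention_penalty(attention, coefficient=1.0):
    # attention: [B, num_heads, L]. Penalize overlap between heads so each
    # head attends to different positions.
    num_heads = attention.size(1)
    identity = torch.eye(num_heads, device=attention.device)
    gram = torch.bmm(attention, attention.transpose(1, 2))  # [B, heads, heads]
    return coefficient * ((gram - identity) ** 2).sum(dim=(1, 2)).mean()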
Example 7
    def forward(self, features, labels=None):
        """
        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                 Loss is not calculated when no labels are given (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - start_logits: representing unnormalized log probabilities of the span start position.
            - end_logits: representing unnormalized log probabilities of the span end position.
            - best_span: the string from the original passage that the model thinks is the best answer to the question.
            - answer_idx: the question id mapped to the answer
            - loss: A scalar loss to be optimised.
        """

        context = features["context"]
        question = features["question"]

        # Sorted Sequence config (seq_lengths, perm_idx, unperm_idx) for RNN pack_forward
        context_seq_config = f.get_sorted_seq_config(context)
        query_seq_config = f.get_sorted_seq_config(question)

        # Embedding
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        context_mask = f.get_mask_from_tokens(context).float()  # B X 1 X C_L
        query_mask = f.get_mask_from_tokens(question).float()  # B X 1 X Q_L

        # Pre-process
        context_embed = self.dropout(context_embed)
        context_encoded = f.forward_rnn_with_pack(self.context_preprocess_rnn,
                                                  context_embed,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        query_embed = self.dropout(query_embed)
        query_encoded = f.forward_rnn_with_pack(self.query_preprocess_rnn,
                                                query_embed, query_seq_config)
        query_encoded = self.dropout(query_encoded)

        # Attention -> Projection
        context_attnded = self.bi_attention(context_encoded, context_mask,
                                            query_encoded, query_mask)
        context_attnded = self.activation_fn(
            self.attn_linear(context_attnded))  # B X C_L X dim*2

        # Residual Self-Attention
        context_attnded = self.dropout(context_attnded)
        context_encoded = f.forward_rnn_with_pack(self.modeling_rnn,
                                                  context_attnded,
                                                  context_seq_config)
        context_encoded = self.dropout(context_encoded)

        context_self_attnded = self.self_attention(
            context_encoded, context_mask)  # B X C_L X dim*2
        context_final = self.dropout(context_attnded +
                                     context_self_attnded)  # B X C_L X dim*2

        # Prediction
        span_start_input = f.forward_rnn_with_pack(
            self.span_start_rnn, context_final,
            context_seq_config)  # B X C_L X dim*2
        span_start_input = self.dropout(span_start_input)
        span_start_logits = self.span_start_linear(span_start_input).squeeze(
            -1)  # B X C_L

        span_end_input = torch.cat([span_start_input, context_final],
                                   dim=-1)  # B X C_L X dim*4
        span_end_input = f.forward_rnn_with_pack(
            self.span_end_rnn, span_end_input,
            context_seq_config)  # B X C_L X dim*2
        span_end_input = self.dropout(span_end_input)
        span_end_logits = self.span_end_linear(span_end_input).squeeze(
            -1)  # B X C_L

        # Masked Value
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        output_dict = {
            "start_logits":
            span_start_logits,
            "end_logits":
            span_end_logits,
            "best_span":
            self.get_best_span(span_start_logits,
                               span_end_logits,
                               answer_maxlen=self.answer_maxlen),
        }

        if labels:
            answer_idx = labels["answer_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]

            output_dict["answer_idx"] = answer_idx

            # Loss
            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(
                0)  # NOTE: unsqueeze so DataParallel can concat scalar losses

        return output_dict
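f.forward_rnn_with_pack shows up in most examples here; the comment on its config ("seq_lengths, perm_idx, unperm_idx") suggests a sort-pack-unsort wrapper around PyTorch's packed sequences. A sketch under that assumption (the tuple layout is assumed, not confirmed by the source):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def forward_rnn_with_pack(rnn, inputs, seq_config):
    seq_lengths, perm_idx, unperm_idx = seq_config  # assumed tuple layout
    # Sort by length, pack so the RNN skips padding, then unpack.
    packed = pack_padded_sequence(inputs[perm_idx], seq_lengths.cpu(),
                                  batch_first=True)
    outputs, _ = rnn(packed)
    outputs, _ = pad_packed_sequence(outputs, batch_first=True)
    return outputs[unperm_idx]  # restore the original batch order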
Example 8
    def forward(self, features, labels=None):
        """
            * Args:
                features: feature dictionary like below.
                    {"feature_name1": {
                         "token_name1": tensor,
                         "token_name2": tensor},
                     "feature_name2": ...}

            * Kwargs:
                labels: label dictionary like below.
                    {"label_name1": tensor,
                     "label_name2": tensor}
                     Loss is not calculated when no labels are given (inference/predict mode).

            * Returns: output_dict (dict) consisting of
                - start_logits: representing unnormalized log probabilities of the span start position.
                - end_logits: representing unnormalized log probabilities of the span end position.
                - best_span: the string from the original passage that the model thinks is the best answer to the question.
                - data_idx: the question id mapped to the answer
                - loss: A scalar loss to be optimised.
        """

        context = features["context"]
        question = features["question"]

        # 1. Input Embedding Layer
        query_params = {"frequent_word": {"frequent_tuning": True}}
        context_embed, query_embed = self.token_embedder(
            context,
            question,
            query_params=query_params,
            query_align=self.aligned_query_embedding)

        context_mask = f.get_mask_from_tokens(context).float()
        query_mask = f.get_mask_from_tokens(question).float()

        context_embed = self.context_highway(context_embed)
        context_embed = self.dropout(context_embed)
        context_embed = self.context_embed_pointwise_conv(context_embed)

        query_embed = self.query_highway(query_embed)
        query_embed = self.dropout(query_embed)
        query_embed = self.query_embed_pointwise_conv(query_embed)

        # 2. Embedding Encoder Layer
        for encoder_block in self.embed_encoder_blocks:
            context_embed = encoder_block(context_embed)
            query_embed = encoder_block(query_embed)
        context, query = context_embed, query_embed

        # 3. Context-Query Attention Layer
        context_query_attention = self.co_attention(context, query,
                                                    context_mask, query_mask)

        # Projection (memory issue)
        context_query_attention = self.pointwise_conv(context_query_attention)
        context_query_attention = self.dropout(context_query_attention)

        # 4. Model Encoder Layer
        model_encoder_block_inputs = context_query_attention

        # Stacked Model Encoder Block
        stacked_model_encoder_blocks = []
        for _ in range(3):
            for model_encoder_block in self.model_encoder_blocks:
                output = model_encoder_block(model_encoder_block_inputs,
                                             context_mask)
                model_encoder_block_inputs = output

            stacked_model_encoder_blocks.append(output)

        # 5. Output Layer
        span_start_inputs = torch.cat(
            [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[1]],
            dim=-1)
        span_start_inputs = self.dropout(span_start_inputs)
        span_start_logits = self.span_start_linear(span_start_inputs).squeeze(
            -1)

        span_end_inputs = torch.cat(
            [stacked_model_encoder_blocks[0], stacked_model_encoder_blocks[2]],
            dim=-1)
        span_end_inputs = self.dropout(span_end_inputs)
        span_end_logits = self.span_end_linear(span_end_inputs).squeeze(-1)

        # Masked Value
        span_start_logits = f.add_masked_value(span_start_logits,
                                               context_mask,
                                               value=-1e7)
        span_end_logits = f.add_masked_value(span_end_logits,
                                             context_mask,
                                             value=-1e7)

        output_dict = {
            "start_logits": span_start_logits,
            "end_logits": span_end_logits,
            "best_span": self.get_best_span(span_start_logits,
                                            span_end_logits),
        }

        if labels:
            data_idx = labels["data_idx"]
            answer_start_idx = labels["answer_start_idx"]
            answer_end_idx = labels["answer_end_idx"]

            output_dict["data_idx"] = data_idx

            # Loss
            loss = self.criterion(span_start_logits, answer_start_idx)
            loss += self.criterion(span_end_logits, answer_end_idx)
            output_dict["loss"] = loss.unsqueeze(
                0)  # NOTE: unsqueeze so DataParallel can concat scalar losses

        return output_dict
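get_best_span has to pick a (start, end) pair jointly rather than take two independent argmaxes, since the best end must not precede the start. A naive O(L²) sketch of that decoding over indices; the real method is presumably vectorized and, per the docstrings, also maps indices back to the passage string:

import torch

def get_best_span(span_start_logits, span_end_logits, answer_maxlen=None):
    # For each batch item, maximize start_logit + end_logit subject to
    # start <= end (and optionally end - start < answer_maxlen).
    B, L = span_start_logits.size()
    best_spans = torch.zeros(B, 2, dtype=torch.long)
    for b in range(B):
        best_score = float("-inf")
        for start in range(L):
            max_end = L if answer_maxlen is None else min(L, start + answer_maxlen)
            for end in range(start, max_end):
                score = span_start_logits[b, start] + span_end_logits[b, end]
                if score > best_score:
                    best_score = score
                    best_spans[b] = torch.tensor([start, end])
    return best_spans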
Example 9
    def forward(self, features, labels=None):
        column = features["column"]
        question = features["question"]

        column_embed = self.token_embedder(column)
        question_embed = self.token_embedder(question)

        B, C_L = column_embed.size(0), column_embed.size(1)

        column_indexed = column[next(iter(column))]
        column_name_mask = column_indexed.gt(0).float()  # NOTE: hard-coded padding index 0
        column_lengths = utils.get_column_lengths(column_embed,
                                                  column_name_mask)
        column_mask = column_lengths.view(B,
                                          C_L).gt(0).float()  # NOTE: hard-coded
        question_mask = f.get_mask_from_tokens(question).float()

        agg_logits = self.agg_predictor(question_embed, question_mask)
        sel_logits = self.sel_predictor(question_embed, question_mask,
                                        column_embed, column_name_mask,
                                        column_mask)

        conds_col_idx, conds_val_pos = None, None
        if labels:
            data_idx = labels["data_idx"]
            ground_truths = self._dataset.get_ground_truths(data_idx)

            conds_col_idx = [
                ground_truth["conds_col"] for ground_truth in ground_truths
            ]
            conds_val_pos = [
                ground_truth["conds_val_pos"] for ground_truth in ground_truths
            ]

        conds_logits = self.conds_predictor(
            question_embed,
            question_mask,
            column_embed,
            column_name_mask,
            column_mask,
            conds_col_idx,
            conds_val_pos,
        )

        # Move logits from GPU to CPU
        agg_logits = agg_logits.cpu()
        sel_logits = sel_logits.cpu()
        conds_logits = [logits.cpu() for logits in conds_logits]

        output_dict = {
            "agg_logits": agg_logits,
            "sel_logits": sel_logits,
            "conds_logits": conds_logits,
        }

        if labels:
            data_idx = labels["data_idx"]
            output_dict["data_idx"] = data_idx

            ground_truths = self._dataset.get_ground_truths(data_idx)

            # Aggregator, Select Column
            target_agg_idx = torch.LongTensor(
                [ground_truth["agg_idx"] for ground_truth in ground_truths])
            target_sel_idx = torch.LongTensor(
                [ground_truth["sel_idx"] for ground_truth in ground_truths])

            loss = 0
            loss += self.cross_entropy(agg_logits, target_agg_idx)
            loss += self.cross_entropy(sel_logits, target_sel_idx)

            conds_num_logits, conds_column_logits, conds_op_logits, conds_value_logits = (
                conds_logits)

            # Conditions
            # 1. The number of conditions
            target_conds_num = torch.LongTensor(
                [ground_truth["conds_num"] for ground_truth in ground_truths])
            target_conds_column = [
                ground_truth["conds_col"] for ground_truth in ground_truths
            ]

            loss += self.cross_entropy(conds_num_logits, target_conds_num)

            # 2. Columns of conditions
            B = conds_column_logits.size(0)

            target_conds_columns = np.zeros(list(conds_column_logits.size()),
                                            dtype=np.float32)
            for i in range(B):
                target_conds_column_idx = target_conds_column[i]
                if len(target_conds_column_idx) == 0:
                    continue
                target_conds_columns[i][target_conds_column_idx] = 1
            target_conds_columns = torch.from_numpy(target_conds_columns)
            conds_column_probs = torch.sigmoid(conds_column_logits)

            bce_loss = -torch.mean(self.conds_column_loss_alpha *
                                   (target_conds_columns *
                                    torch.log(conds_column_probs + 1e-10)) +
                                   (1 - target_conds_columns) *
                                   torch.log(1 - conds_column_probs + 1e-10))
            loss += bce_loss

            # 3. Operator of conditions
            conds_op_loss = 0
            for i in range(B):
                target_conds_op = ground_truths[i]["conds_op"]
                if len(target_conds_op) == 0:
                    continue

                target_conds_op = torch.from_numpy(np.array(target_conds_op))
                logits_conds_op = conds_op_logits[i, :len(target_conds_op)]

                target_op_count = len(target_conds_op)
                conds_op_loss += (
                    self.cross_entropy(logits_conds_op, target_conds_op) /
                    target_op_count)
            loss += conds_op_loss

            # 4. Value of conditions
            conds_val_pos = [
                ground_truth["conds_val_pos"] for ground_truth in ground_truths
            ]

            conds_value_loss = 0
            for i in range(B):
                for j in range(len(conds_val_pos[i])):
                    cond_val_pos = conds_val_pos[i][j]
                    if len(cond_val_pos) == 1:
                        continue

                    target_cond_val_pos = torch.from_numpy(
                        np.array(cond_val_pos[1:]))  # index 0: START_TOKEN
                    logits_cond_val_pos = conds_value_logits[
                        i, j, :len(cond_val_pos) - 1]

                    conds_value_loss += self.cross_entropy(
                        logits_cond_val_pos, target_cond_val_pos) / len(
                            conds_val_pos[i])

            loss += conds_value_loss / B

            output_dict["loss"] = loss.unsqueeze(0)

        return output_dict
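The hand-rolled BCE over conds_column_probs weights the positive term by conds_column_loss_alpha. If that alpha is a scalar, the same quantity (up to the 1e-10 epsilon) can be computed with PyTorch's built-in; this is a hedged alternative with stand-in values, not the model's actual code:

import torch
import torch.nn.functional as F

alpha = torch.tensor([2.0])              # hypothetical positive-class weight
logits = torch.randn(4, 6)               # stand-in for conds_column_logits
targets = torch.randint(0, 2, (4, 6)).float()
# -mean(alpha * t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x)))
bce_loss = F.binary_cross_entropy_with_logits(logits, targets, pos_weight=alpha)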
Example 10
    def forward(self,
                context,
                query,
                context_params={},
                query_params={},
                query_align=False):
        """
        * Args:
            context: context inputs (eg. {"token_name1": tensor, "token_name2": tensor, ...})
            query: query inputs (eg. {"token_name1": tensor, "token_name2": tensor, ...})

        * Kwargs:
            context_params: custom context parameters
            query_params: custom query parameters
            query_align: f_align(p_i) = sum_j(a_ij * E(q_j)), where the attention score a_ij
                captures the similarity between p_i and each question word q_j.
                These features add soft alignments between similar but non-identical
                words (e.g., car and vehicle). Applied only to 'context_embed'.
        """

        if set(self.token_names) != set(context.keys()):
            raise ValueError(
                f"Mismatched token_names. inputs: {context.keys()}, embeddings: {self.token_names}"
            )

        context_tokens, query_tokens = {}, {}
        for token_name, context_tensors in context.items():
            embedding = getattr(self, token_name)

            context_tokens[token_name] = embedding(
                context_tensors, **context_params.get(token_name, {}))
            if token_name in query:
                query_tokens[token_name] = embedding(
                    query[token_name], **query_params.get(token_name, {}))

        # query_align_embedding
        if query_align:
            common_context = self._filter(context_tokens, exclusive=False)
            embedded_common_context = torch.cat(list(common_context.values()),
                                                dim=-1)
            exclusive_context = self._filter(context_tokens, exclusive=True)

            embedded_exclusive_context = None
            if exclusive_context != {}:
                embedded_exclusive_context = torch.cat(list(
                    exclusive_context.values()),
                                                       dim=-1)

            query_mask = f.get_mask_from_tokens(query_tokens)
            embedded_query = torch.cat(list(query_tokens.values()), dim=-1)

            embedded_aligned_query = self.align_attention(
                embedded_common_context, embedded_query, query_mask)

            # Merge context embedded
            embedded_context = [
                embedded_common_context, embedded_aligned_query
            ]
            if embedded_exclusive_context is not None:
                embedded_context.append(embedded_exclusive_context)

            context_output = torch.cat(embedded_context, dim=-1)
            query_output = embedded_query
        else:
            context_output = torch.cat(list(context_tokens.values()), dim=-1)
            query_output = torch.cat(list(query_tokens.values()), dim=-1)

        return context_output, query_output
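The align_attention call realizes the f_align feature from the docstring: each context position attends over the query words and takes a weighted sum of their embeddings. A sketch using raw dot-product scores; the actual module likely applies a learned projection first, as in DrQA-style aligned embeddings:

import torch
import torch.nn.functional as F

def align_attention(context_embed, query_embed, query_mask):
    # context_embed: [B, C_L, D], query_embed: [B, Q_L, D], query_mask: [B, Q_L]
    scores = torch.bmm(context_embed, query_embed.transpose(1, 2))  # [B, C_L, Q_L]
    scores = scores + (1.0 - query_mask.float()).unsqueeze(1) * -1e7
    attention = F.softmax(scores, dim=-1)     # a_ij over query words
    return torch.bmm(attention, query_embed)  # f_align(p_i) = sum_j a_ij * E(q_j)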