コード例 #1
0
ファイル: test_functional.py プロジェクト: zzozzolev/claf
def test_get_sorted_seq_config():
    """Sorting config for a padded batch: lengths, sort permutation and its inverse."""
    # Padded batch of token ids (0 = padding); the true lengths are 3, 2 and 5.
    token_ids = torch.LongTensor(
        [
            [2, 3, 1, 0, 0],
            [4, 1, 0, 0, 0],
            [1, 5, 2, 4, 1],
        ]
    )

    config = f.get_sorted_seq_config({"word": token_ids})

    expected = {
        "seq_lengths": [5, 3, 2],  # lengths in descending order
        "perm_idx": [2, 0, 1],  # original row placed at each sorted position
        "unperm_idx": [1, 2, 0],  # sorted position of each original row
    }
    for key, want in expected.items():
        assert config[key].tolist() == want
コード例 #2
0
ファイル: test_functional.py プロジェクト: zzozzolev/claf
def test_forward_rnn_with_pack():
    """Packed RNN forward must yield zeros at padded steps, in original row order."""
    # Padded batch of token ids (0 = padding); the true lengths are 3, 2 and 5.
    token_ids = torch.LongTensor(
        [
            [2, 3, 1, 0, 0],
            [4, 1, 0, 0, 0],
            [1, 5, 2, 4, 1],
        ]
    )
    # Random 10x10 embedding table -> embedded input of shape (3, 5, 10).
    embedding_matrix = torch.rand(10, 10)
    embedded = torch.nn.functional.embedding(token_ids, embedding_matrix)

    seq_config = f.get_sorted_seq_config({"word": token_ids})

    rnn = torch.nn.GRU(input_size=10, hidden_size=1, bidirectional=False, batch_first=True)
    outputs = f.forward_rnn_with_pack(rnn, embedded, seq_config)

    # Row 0 is padded from step 3 on and row 1 from step 2 on. Checking those
    # exact rows also verifies the outputs were un-permuted back to input order.
    for row, first_pad_step in ((0, 3), (1, 2)):
        for step in range(first_pad_step, 5):
            assert outputs[row][step] == 0
コード例 #3
0
ファイル: bidaf_no_answer.py プロジェクト: seongl/claf
    def forward(self, features, labels=None):
        """
        Run BiDAF over context/question and append a no-answer bias column.

        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}
                The "context" and "question" entries are used.

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                When truthy, "answer_idx", "answer_start_idx", "answer_end_idx"
                and "answerable" are read and a loss is computed. Otherwise
                (inference/predict mode) no loss is calculated.

        * Returns: output_dict (dict) consisting of
            - start_logits: unnormalized span-start scores; the last column is
              the no-answer bias (``self.bias``).
            - end_logits: unnormalized span-end scores; the last column is the
              no-answer bias (``self.bias``).
            - best_span: result of ``self.get_best_span`` on the logits with the
              no-answer column removed (called with ``answer_maxlen=self.answer_maxlen``).
            - answer_idx: the question id, mapping with answer (only with labels)
            - loss: summed start/end loss with a leading dim added (only with labels)
        """
        context_tokens = features["context"]
        question_tokens = features["question"]

        # (seq_lengths, perm_idx, unperm_idx) used for packed RNN forwarding
        context_config = f.get_sorted_seq_config(context_tokens)
        question_config = f.get_sorted_seq_config(question_tokens)

        # Embedding layer (char + word -> contextual)
        context_vectors, question_vectors = self.token_embedder(
            context_tokens,
            question_tokens,
            query_params={"frequent_word": {"frequent_tuning": True}},
            query_align=self.aligned_query_embedding,
        )

        context_mask = f.get_mask_from_tokens(context_tokens).float()
        question_mask = f.get_mask_from_tokens(question_tokens).float()

        batch_size = context_vectors.size(0)
        context_len = context_vectors.size(1)

        context_vectors = self.context_highway(context_vectors)
        question_vectors = self.query_highway(question_vectors)

        # Contextual encoders: packed RNN, then dropout (context first, then question)
        context_hidden = self.dropout(
            f.forward_rnn_with_pack(self.context_contextual_rnn, context_vectors, context_config))
        question_hidden = self.dropout(
            f.forward_rnn_with_pack(self.query_contextual_rnn, question_vectors, question_config))

        # Attention flow layer
        attended = self.attention(context_hidden, context_mask, question_hidden, question_mask)

        # Modeling layer
        modeled = self.dropout(
            f.forward_rnn_with_pack(self.modeling_rnn, attended, context_config))
        model_dim = modeled.size(-1)

        # Output layer: span start
        start_features = self.dropout(torch.cat([attended, modeled], dim=-1))
        start_scores = self.span_start_linear(start_features).squeeze(-1)  # (B, C_L)
        start_probs = f.masked_softmax(start_scores, context_mask)

        # Summarize the modeled context under the start distribution, then tile
        # that summary across every context position.
        start_summary = f.weighted_sum(attention=start_probs, matrix=modeled)
        start_summary = start_summary.unsqueeze(1).expand(batch_size, context_len, model_dim)

        # Output layer: span end
        end_features = torch.cat(
            [attended, modeled, start_summary, modeled * start_summary],
            dim=-1,
        )
        end_hidden = self.dropout(
            f.forward_rnn_with_pack(self.output_end_rnn, end_features, context_config))
        end_scores = self.span_end_linear(
            self.dropout(torch.cat([attended, end_hidden], dim=-1))).squeeze(-1)

        # Apply the context mask with a large negative value (-1e7), presumably
        # so padded positions are never chosen.
        start_scores = f.add_masked_value(start_scores, context_mask, value=-1e7)
        end_scores = f.add_masked_value(end_scores, context_mask, value=-1e7)

        # Append the no-answer bias as one extra trailing column.
        no_answer_column = self.bias.expand(batch_size, 1)
        start_scores = torch.cat([start_scores, no_answer_column], dim=-1)
        end_scores = torch.cat([end_scores, no_answer_column], dim=-1)

        output_dict = {
            "start_logits": start_scores,
            "end_logits": end_scores,
            # Span search covers real context positions only, so the trailing
            # no-answer column is sliced off.
            "best_span": self.get_best_span(
                start_scores[:, :-1],
                end_scores[:, :-1],
                answer_maxlen=self.answer_maxlen,
            ),
        }

        if not labels:
            return output_dict

        answer_idx = labels["answer_idx"]
        start_target = labels["answer_start_idx"]
        end_target = labels["answer_end_idx"]
        answerable = labels["answerable"]

        # No-answer case: unanswerable examples target the appended bias column,
        # whose index equals the number of context positions.
        no_answer_index = context_mask.size(1)
        unanswerable = answerable.eq(0)
        start_target = start_target.masked_fill(unanswerable, no_answer_index)
        end_target = end_target.masked_fill(unanswerable, no_answer_index)

        output_dict["answer_idx"] = answer_idx

        loss = self.criterion(start_scores, start_target)
        loss += self.criterion(end_scores, end_target)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error
        return output_dict
コード例 #4
0
    def forward(self, features, labels=None):
        """
        Score answer-span start/end positions over a paragraph for a question.

        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}
                The "context" (paragraph) and "question" entries are used.

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                When truthy, "data_idx", "answer_start_idx" and "answer_end_idx"
                are read and a loss is computed. Otherwise (inference/predict
                mode) no loss is calculated.

        * Returns: output_dict (dict) consisting of
            - start_logits: unnormalized span-start scores after
              ``f.add_masked_value`` with the context mask (value -1e7).
            - end_logits: unnormalized span-end scores, masked the same way.
            - best_span: result of ``self.get_best_span`` on the two logits
              (called with ``answer_maxlen=self.answer_maxlen``).
            - data_idx: the question id, mapping with answer (only with labels)
            - loss: summed start/end loss with a leading dim added (only with labels)
        """
        paragraph = features["context"]
        question = features["question"]

        # (seq_lengths, perm_idx, unperm_idx) used for packed RNN forwarding
        paragraph_config = f.get_sorted_seq_config(paragraph)
        question_config = f.get_sorted_seq_config(question)

        # Embedding
        paragraph_vectors, question_vectors = self.token_embedder(
            paragraph,
            question,
            query_params={"frequent_word": {"frequent_tuning": True}},
            query_align=self.aligned_query_embedding,
        )

        paragraph_mask = f.get_mask_from_tokens(paragraph).float()
        question_mask = f.get_mask_from_tokens(question).float()

        paragraph_vectors = self.dropout(paragraph_vectors)
        question_vectors = self.dropout(question_vectors)

        # RNN encoders, each output followed by dropout (paragraph first)
        paragraph_hidden = self.dropout(
            f.forward_rnn_with_pack(self.paragraph_rnn, paragraph_vectors, paragraph_config))
        question_hidden = self.dropout(
            f.forward_rnn_with_pack(self.query_rnn, question_vectors, question_config))  # (B, Q_L, H*2)

        # Collapse the question into a single vector using attention weights.
        question_weights = self.query_att(question_hidden, question_mask)  # (B, Q_L)
        question_summary = f.weighted_sum(question_weights, question_hidden)  # (B, H*2)

        start_scores = self.start_attn(paragraph_hidden, question_summary, paragraph_mask)
        end_scores = self.end_attn(paragraph_hidden, question_summary, paragraph_mask)

        # Masked value
        start_scores = f.add_masked_value(start_scores, paragraph_mask, value=-1e7)
        end_scores = f.add_masked_value(end_scores, paragraph_mask, value=-1e7)

        output_dict = {
            "start_logits": start_scores,
            "end_logits": end_scores,
            "best_span": self.get_best_span(
                start_scores, end_scores, answer_maxlen=self.answer_maxlen),
        }

        if not labels:
            return output_dict

        data_idx = labels["data_idx"]
        start_target = labels["answer_start_idx"]
        end_target = labels["answer_end_idx"]

        output_dict["data_idx"] = data_idx

        loss = self.criterion(start_scores, start_target)
        loss += self.criterion(end_scores, end_target)
        output_dict["loss"] = loss.unsqueeze(0)
        return output_dict
コード例 #5
0
    def forward(self, features, labels=None):
        """
        Classify a token sequence using multi-head structured self-attention.

        * Args:
            features: feature dictionary like below.
            {"sequence": [0, 3, 4, 1]}

        * Kwargs:
            labels: label dictionary like below.
            {"class_idx": 2, "data_idx": 0}
             When falsy, no loss is calculated (inference/predict mode).

        * Returns: output_dict (dict) consisting of
            - sequence_embed: embedding vector of the sequence
            - class_logits: representing unnormalized log probabilities of the class.

            - class_idx: target class idx (only with labels)
            - data_idx: data idx (only with labels)
            - loss: classification loss plus ``self.penalty`` on the attention
              weights, with a leading dim added (only with labels)
        """
        tokens = features["sequence"]

        # (seq_lengths, perm_idx, unperm_idx) used for packed RNN forwarding
        seq_config = f.get_sorted_seq_config(tokens)

        embedded = self.token_embedder(tokens)
        encodings = f.forward_rnn_with_pack(
            self.encoder, embedded, seq_config)  # [B, L, encoding_rnn_hidden_dim]

        # One attention distribution per head over sequence positions.
        head_scores = self.A(encodings).transpose(1, 2)  # [B, num_attention_heads, L]

        # [B, L] token mask, broadcast to one copy per attention head.
        head_mask = f.get_mask_from_tokens(tokens).float().unsqueeze(1).expand_as(head_scores)
        # NOTE(review): softmax is invariant to adding a constant to every logit,
        # so the 1e-13 term looks like a no-op; it is kept for exact parity.
        attention_weights = F.softmax(
            f.add_masked_value(head_scores, head_mask) + 1e-13, dim=2)

        per_head = torch.bmm(
            attention_weights, encodings)  # [B, num_attention_heads, encoding_rnn_hidden_dim]
        batch_size = per_head.size(0)
        sequence_embed = self.fully_connected(
            per_head.view(batch_size, -1))  # [B, sequence_embed_dim]

        class_logits = self.classifier(sequence_embed)  # [B, num_classes]

        output_dict = {"sequence_embed": sequence_embed, "class_logits": class_logits}
        if not labels:
            return output_dict

        class_idx = labels["class_idx"]
        data_idx = labels["data_idx"]
        output_dict["class_idx"] = class_idx
        output_dict["data_idx"] = data_idx

        # Loss: classification term plus the attention penalty term
        loss = self.criterion(class_logits, class_idx)
        loss += self.penalty(attention_weights)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error
        return output_dict
コード例 #6
0
    def forward(self, features, labels=None):
        """
        Score answer spans with bi-attention followed by residual self-attention.

        * Args:
            features: feature dictionary like below.
                {"feature_name1": {
                     "token_name1": tensor,
                     "token_name2": tensor},
                 "feature_name2": ...}
                The "context" and "question" entries are used.

        * Kwargs:
            labels: label dictionary like below.
                {"label_name1": tensor,
                 "label_name2": tensor}
                When truthy, "answer_idx", "answer_start_idx" and
                "answer_end_idx" are read and a loss is computed. Otherwise
                (inference/predict mode) no loss is calculated.

        * Returns: output_dict (dict) consisting of
            - start_logits: unnormalized span-start scores after
              ``f.add_masked_value`` with the context mask (value -1e7).
            - end_logits: unnormalized span-end scores, masked the same way.
            - best_span: result of ``self.get_best_span`` on the two logits
              (called with ``answer_maxlen=self.answer_maxlen``).
            - answer_idx: the question id, mapping with answer (only with labels)
            - loss: summed start/end loss with a leading dim added (only with labels)
        """

        def rnn_with_dropout(rnn, inputs, seq_config):
            # Packed RNN forward pass followed by dropout on its output.
            return self.dropout(f.forward_rnn_with_pack(rnn, inputs, seq_config))

        context = features["context"]
        question = features["question"]

        # (seq_lengths, perm_idx, unperm_idx) used for packed RNN forwarding
        context_config = f.get_sorted_seq_config(context)
        question_config = f.get_sorted_seq_config(question)

        # Embedding
        context_vectors, question_vectors = self.token_embedder(
            context,
            question,
            query_params={"frequent_word": {"frequent_tuning": True}},
            query_align=self.aligned_query_embedding,
        )

        # NOTE(review): masks are presumably (B, C_L) / (B, Q_L). The span logits
        # masked with context_mask below are B X C_L, so B X 1 X C_L is unlikely.
        context_mask = f.get_mask_from_tokens(context).float()
        question_mask = f.get_mask_from_tokens(question).float()

        # Pre-process: input dropout -> packed RNN -> dropout (context, then question)
        context_hidden = rnn_with_dropout(
            self.context_preprocess_rnn, self.dropout(context_vectors), context_config)
        question_hidden = rnn_with_dropout(
            self.query_preprocess_rnn, self.dropout(question_vectors), question_config)

        # Attention -> projection
        attended = self.bi_attention(context_hidden, context_mask,
                                     question_hidden, question_mask)
        attended = self.activation_fn(self.attn_linear(attended))  # B X C_L X dim*2

        # Residual self-attention. The dropped-out projection feeds the modeling
        # RNN and is also the residual branch added back afterwards.
        attended = self.dropout(attended)
        modeled = rnn_with_dropout(self.modeling_rnn, attended, context_config)
        self_attended = self.self_attention(modeled, context_mask)  # B X C_L X dim*2
        fused = self.dropout(attended + self_attended)  # B X C_L X dim*2

        # Prediction: span start
        start_hidden = rnn_with_dropout(
            self.span_start_rnn, fused, context_config)  # B X C_L X dim*2
        start_scores = self.span_start_linear(start_hidden).squeeze(-1)  # B X C_L

        # Prediction: span end, conditioned on the start encoding
        end_inputs = torch.cat([start_hidden, fused], dim=-1)  # B X C_L X dim*4
        end_hidden = rnn_with_dropout(
            self.span_end_rnn, end_inputs, context_config)  # B X C_L X dim*2
        end_scores = self.span_end_linear(end_hidden).squeeze(-1)  # B X C_L

        # Masked value
        start_scores = f.add_masked_value(start_scores, context_mask, value=-1e7)
        end_scores = f.add_masked_value(end_scores, context_mask, value=-1e7)

        output_dict = {
            "start_logits": start_scores,
            "end_logits": end_scores,
            "best_span": self.get_best_span(
                start_scores, end_scores, answer_maxlen=self.answer_maxlen),
        }

        if not labels:
            return output_dict

        answer_idx = labels["answer_idx"]
        start_target = labels["answer_start_idx"]
        end_target = labels["answer_end_idx"]

        output_dict["answer_idx"] = answer_idx

        loss = self.criterion(start_scores, start_target)
        loss += self.criterion(end_scores, end_target)
        output_dict["loss"] = loss.unsqueeze(0)  # NOTE: DataParallel concat Error
        return output_dict