Example #1
from typing import Dict

import torch
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from allennlp.nn.util import get_text_field_mask

# Note: seq2vec_seq_aggregate is a project-level helper that pools an encoded
# sequence into a single vector according to `aggregation_type`.


def embed_encode_and_aggregate_list_text_field_with_feats_only(
        texts_list: Dict[str, torch.LongTensor],
        text_field_embedder,
        embeddings_dropout,
        encoder: Seq2SeqEncoder,
        aggregation_type,
        token_features=None,
        init_hidden_states=None):
    embedded_texts = text_field_embedder(texts_list)
    embedded_texts = embeddings_dropout(embedded_texts)

    if token_features is not None:
        # Concatenate extra per-token features onto the token embeddings.
        embedded_texts = torch.cat([embedded_texts, token_features], dim=-1)

    bs, ch_cnt, ch_tkn_cnt, d = tuple(embedded_texts.shape)

    embedded_texts_flattened = embedded_texts.view([bs * ch_cnt, ch_tkn_cnt, -1])
    # masks

    texts_mask_dim_3 = get_text_field_mask(texts_list, num_wrapping_dims=1).float()
    texts_mask_flattened = texts_mask_dim_3.view([-1, ch_tkn_cnt])

    # context encoding
    multiple_texts_init_states = None
    if init_hidden_states is not None:
        if init_hidden_states.shape[0] == bs and init_hidden_states.shape[1] != ch_cnt:
            if init_hidden_states.shape[1] != encoder.get_output_dim():
                raise ValueError("The shape of init_hidden_states is {0} but is expected to be {1} or {2}".format(str(init_hidden_states.shape),
                                                                            str([bs, encoder.get_output_dim()]),
                                                                            str([bs, ch_cnt, encoder.get_output_dim()])))
            # only a 2D tensor was passed, which is the default output of the question encoder
            multiple_texts_init_states = init_hidden_states.unsqueeze(1).expand([bs, ch_cnt, encoder.get_output_dim()]).contiguous()

            # reshape to match the flattened token batch
            multiple_texts_init_states = multiple_texts_init_states.view([bs * ch_cnt, encoder.get_output_dim()])
        else:
            multiple_texts_init_states = init_hidden_states.view([bs * ch_cnt, encoder.get_output_dim()])

    encoded_texts_flattened = encoder(embedded_texts_flattened, texts_mask_flattened, hidden_state=multiple_texts_init_states)

    aggregated_choice_flattened = seq2vec_seq_aggregate(encoded_texts_flattened, texts_mask_flattened,
                                                        aggregation_type,
                                                        encoder.is_bidirectional(),
                                                        1)  # bs*ch X d

    aggregated_choice_flattened_reshaped = aggregated_choice_flattened.view([bs, ch_cnt, -1])
    return aggregated_choice_flattened_reshaped
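A minimal usage sketch of the helper above, assuming AllenNLP 0.x-style text-field tensors, a toy BasicTextFieldEmbedder/LSTM encoder, and that the project's seq2vec_seq_aggregate helper accepts "max" as an aggregation type; every module choice and tensor shape below is illustrative, not taken from the original code.

# Illustrative sketch only: the embedder, encoder and fake token ids are assumptions.
import torch
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

embedder = BasicTextFieldEmbedder({"tokens": Embedding(num_embeddings=100, embedding_dim=16)})
encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(16, 32, batch_first=True, bidirectional=True))
dropout = torch.nn.Dropout(0.2)

# Fake token ids: a batch of 2 questions, 3 choices each, 5 tokens per choice.
texts_list = {"tokens": torch.randint(1, 100, (2, 3, 5))}

aggregated = embed_encode_and_aggregate_list_text_field_with_feats_only(
    texts_list, embedder, dropout, encoder, aggregation_type="max")
print(aggregated.shape)  # expected: (2, 3, 64) -> batch x choices x encoder output dim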
Example #2
from typing import Dict

import torch
from allennlp.nn.util import get_text_field_mask, get_final_encoder_states


def embedd_encode_and_aggregate_text_field(question: Dict[str, torch.LongTensor],
                                           text_field_embedder,
                                           embeddings_dropout,
                                           encoder,
                                           aggregation_type,
                                           get_last_states=False):
    embedded_question = text_field_embedder(question)
    question_mask = get_text_field_mask(question).float()
    embedded_question = embeddings_dropout(embedded_question)

    encoded_question = encoder(embedded_question, question_mask)

    # aggregate sequences to a single item
    encoded_question_aggregated = seq2vec_seq_aggregate(encoded_question, question_mask, aggregation_type,
                                                        encoder.is_bidirectional(), 1)  # bs X d

    last_hidden_states = None
    if get_last_states:
        last_hidden_states = get_final_encoder_states(encoded_question, question_mask, encoder.is_bidirectional())

    return encoded_question_aggregated, last_hidden_states
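In the same spirit, a hedged sketch of calling the single-field variant, reusing the hypothetical embedder, dropout and encoder from the sketch above; the class-method fragment that follows shows the same aggregation pattern inside a full premise/hypothesis entailment model.

# Hypothetical call; `question` is an AllenNLP 0.x-style text-field tensor dict.
question = {"tokens": torch.randint(1, 100, (2, 7))}  # batch of 2 questions, 7 tokens each
aggregated, last_states = embedd_encode_and_aggregate_text_field(
    question, embedder, dropout, encoder, aggregation_type="max", get_last_states=True)
print(aggregated.shape)   # expected: (2, 64), one pooled vector per question
print(last_states.shape)  # expected: (2, 64), final forward/backward states concatenated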
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_premise = self._embeddings_dropout(embedded_premise)

        embedded_hypothesis = self._text_field_embedder(hypothesis)
        embedded_hypothesis = self._embeddings_dropout(embedded_hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)

        embedded_premise = seq2vec_seq_aggregate(
            embedded_premise, premise_mask, self._premise_aggregate,
            self._premise_encoder.is_bidirectional(), 1)

        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        embedded_hypothesis = seq2vec_seq_aggregate(
            embedded_hypothesis, hypothesis_mask, self._hypothesis_aggregate,
            self._hypothesis_encoder.is_bidirectional(), 1)

        # Standard entailment match features: [u, v, |u - v|, u * v].
        aggregate_input = torch.cat([
            embedded_premise,
            embedded_hypothesis,
            torch.abs(embedded_hypothesis - embedded_premise),
            embedded_premise * embedded_hypothesis,
        ], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            labels = label.long().view(-1)
            loss = self._loss(label_logits, labels)
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
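Finally, a hedged illustration of how the returned dictionary might be consumed at inference time; the model variable, its inputs, and the "labels" vocabulary namespace are assumptions, not part of the original snippet.

# Hypothetical post-processing of forward()'s output dictionary.
output = model(premise_tensors, hypothesis_tensors)   # no gold label at inference time
predicted_ids = output["label_probs"].argmax(dim=-1)  # shape: (batch_size,)
predicted_labels = [model.vocab.get_token_from_index(int(i), namespace="labels")
                    for i in predicted_ids]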