Code example #1
File: viterbi.py Project: tadejmagajna/flair
    def _format_targets(self, targets: torch.Tensor, lengths: torch.IntTensor):
        """
        Formats targets into matrix indices.
        CRF scores contain, per sentence and per token, a (tagset_size x tagset_size) matrix holding the emission
            score for token j plus the transition score from the previous token i. That is, if we read the rows as
            "to tag" and the columns as "from tag", the cell [10, 5] contains the emission score for tag 10 plus the
            transition score from previous tag 5, and can be addressed directly through the 1-dim index
            (10 + tagset_size * 5) = 70 if our tagset consists of 12 tags.

        :param targets: targets as in tag dictionary
        :param lengths: lengths of sentences in batch
        """
        targets_per_sentence = []

        targets_list = targets.tolist()
        for cut in lengths:
            targets_per_sentence.append(targets_list[:cut])
            targets_list = targets_list[cut:]

        for t in targets_per_sentence:
            t += [self.tag_dictionary.get_idx_for_item(STOP_TAG)] * (int(lengths.max().item()) - len(t))

        matrix_indices = list(
            map(
                lambda s: [self.tag_dictionary.get_idx_for_item(START_TAG) + (s[0] * self.tagset_size)]
                + [s[i] + (s[i + 1] * self.tagset_size) for i in range(0, len(s) - 1)],
                targets_per_sentence,
            )
        )

        return targets_per_sentence, matrix_indices
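
The flattening rule described in the docstring can be checked in isolation. Below is a minimal, self-contained sketch (not part of the flair project; the toy tag indices and tagset size are assumptions for illustration) that reproduces the index computation used for matrix_indices.

import torch

# Toy setup (assumption): a 12-tag tagset with START at index 0, matching the
# docstring's example of tagset_size = 12.
tagset_size = 12
START = 0

targets = torch.tensor([3, 7, 10])  # gold tag indices for one sentence
s = targets.tolist()

# Same rule as in _format_targets: "to tag" + tagset_size * "from tag".
matrix_indices = [START + s[0] * tagset_size] + [
    s[i] + s[i + 1] * tagset_size for i in range(len(s) - 1)
]
print(matrix_indices)  # [36, 87, 127]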
Code example #2
def audio_to_class_idxs(audio: torch.IntTensor, n_classes):
    "Convert audio [-128, 127] to class indices [0, 255]."
    assert audio.min() >= -n_classes // 2, audio.min()
    assert audio.max() <= n_classes // 2 - 1, audio.max()
    return (audio + n_classes // 2).long()
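
A quick usage sketch, assuming signed 8-bit audio and n_classes = 256 (the case the docstring's ranges refer to):

import torch

audio = torch.tensor([-128, -1, 0, 127], dtype=torch.int32)  # signed 8-bit sample values
class_idxs = audio_to_class_idxs(audio, n_classes=256)
print(class_idxs)  # tensor([  0, 127, 128, 255])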
Code example #3
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()
        premise_sequence_lengths = get_lengths_from_binary_sequence_mask(
            premise_mask)
        hypothesis_sequence_lengths = get_lengths_from_binary_sequence_mask(
            hypothesis_mask)

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_sequence_lengths)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_sequence_lengths)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise,
                                                   projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(
            -1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            if label.dim() == 2:
                _, label = label.max(-1)
            loss = self._loss(label_logits, label.view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
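
The attend step in this forward pass is a masked softmax over the last dimension followed by a weighted sum. The sketch below reproduces that mechanic with plain torch on random tensors (the shapes are illustrative assumptions; the model itself uses AllenNLP's last_dim_softmax and weighted_sum helpers):

import torch

batch_size, premise_len, hypothesis_len, embedding_dim = 2, 4, 5, 8

similarity = torch.randn(batch_size, premise_len, hypothesis_len)
hypothesis_mask = torch.ones(batch_size, hypothesis_len)
hypothesis_mask[1, 3:] = 0                      # pretend the second hypothesis is shorter

# Masked softmax over the hypothesis dimension: padded positions get zero weight.
masked_similarity = similarity.masked_fill(hypothesis_mask.unsqueeze(1) == 0, float("-inf"))
p2h_attention = torch.softmax(masked_similarity, dim=-1)
# Shape: (batch_size, premise_length, hypothesis_length)

embedded_hypothesis = torch.randn(batch_size, hypothesis_len, embedding_dim)
attended_hypothesis = torch.bmm(p2h_attention, embedded_hypothesis)
# Shape: (batch_size, premise_length, embedding_dim), as in the forward pass above.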