def __init__(self, num_tags, labels: Vocabulary, *args) -> None:
    super().__init__(list(labels), *args)
    # Configure the CRF to ignore padded label positions in the loss.
    self.crf = CRF(
        num_tags=num_tags,
        ignore_index=labels.get_pad_index(Padding.DEFAULT_LABEL_PAD_IDX),
        default_label_pad_index=Padding.DEFAULT_LABEL_PAD_IDX,
    )
Example #2
    def test_crf_forward(self, num_tags, seq_lens):

        crf_model = CRF(
            num_tags,
            ignore_index=Padding.WORD_LABEL_PAD_IDX,
            default_label_pad_index=Padding.DEFAULT_LABEL_PAD_IDX,
        )

        total_manual_loss = 0

        padded_inputs = []
        padded_targets = []

        max_num_words = max(seq_lens)

        for seq_len in seq_lens:

            target_tokens = np.random.randint(1, num_tags, size=(1, seq_len))
            padded_targets.append(
                np.concatenate(
                    [target_tokens,
                     np.zeros((1, max_num_words - seq_len))],
                    axis=1))

            input_emission = np.random.rand(seq_len, num_tags)
            padded_inputs.append(
                np.concatenate(
                    [
                        input_emission,
                        np.zeros((max_num_words - seq_len, num_tags))
                    ],
                    axis=0,
                ))

            manual_loss = self._compute_loss_manual(
                input_emission,
                num_tags,
                target_tokens.reshape(-1),
                crf_model.get_transitions().tolist(),
            )
            crf_loss = crf_model(
                torch.tensor(input_emission, dtype=torch.float).unsqueeze(0),
                torch.tensor(target_tokens),
            )

        # The loss returned by the CRF model for each input should equal the
        # manually computed loss (see the brute-force sketch after this test).
            self.assertAlmostEqual(manual_loss, -1 * crf_loss.item(), places=4)
            total_manual_loss += manual_loss

        # The loss returned by the CRF model for a batched input should equal
        # the average of the manually calculated losses.
        batched_crf_loss = crf_model(
            torch.tensor(padded_inputs, dtype=torch.float),
            torch.tensor(padded_targets, dtype=torch.long).squeeze(1),
        )
        self.assertAlmostEqual(total_manual_loss / len(seq_lens),
                               -1 * batched_crf_loss.item(),
                               places=4)
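
The helper _compute_loss_manual is defined elsewhere in the test class and is not shown here. As a rough illustration of what it computes, the hypothetical sketch below brute-forces the CRF negative log-likelihood by scoring every possible tag sequence. It assumes the transition matrix carries virtual start/end tags at indices num_tags and num_tags + 1 (one common CRF layout); the actual helper and transition layout may differ.

import itertools

import numpy as np


def compute_crf_nll_brute_force(emissions, num_tags, tags, transitions):
    # Hypothetical stand-in for _compute_loss_manual; not the real helper.
    # Assumes virtual start/end tags at indices num_tags and num_tags + 1.
    start_tag, end_tag = num_tags, num_tags + 1

    def path_score(seq):
        # Emission plus transition scores along one tag sequence, including
        # the transitions from the start tag and to the end tag.
        score = transitions[start_tag][seq[0]] + transitions[seq[-1]][end_tag]
        for i, tag in enumerate(seq):
            score += emissions[i][tag]
            if i > 0:
                score += transitions[seq[i - 1]][tag]
        return score

    gold_score = path_score(list(tags))
    # Log-partition over all tag sequences; only feasible for tiny inputs.
    all_scores = [
        path_score(seq)
        for seq in itertools.product(range(num_tags), repeat=len(tags))
    ]
    log_partition = np.logaddexp.reduce(all_scores)
    # Negative log-likelihood of the gold tag sequence.
    return log_partition - gold_score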
Example #3
    def test_crf_decode_torchscript(self, num_tags, seq_lens):
        crf_model = CRF(
            num_tags,
            ignore_index=Padding.WORD_LABEL_PAD_IDX,
            default_label_pad_index=Padding.DEFAULT_LABEL_PAD_IDX,
        )
        crf_model.eval()
        scripted_crf_model = torch.jit.script(crf_model)

        max_num_words = max(seq_lens)
        padded_inputs = []
        for seq_len in seq_lens:
            input_emission = np.random.rand(seq_len, num_tags)
            padded_inputs.append(
                np.concatenate(
                    [
                        input_emission,
                        np.zeros((max_num_words - seq_len, num_tags))
                    ],
                    axis=0,
                ))
            crf_decode = crf_model.decode(
                torch.tensor(input_emission, dtype=torch.float).unsqueeze(0),
                torch.tensor([seq_len]),
            )

            scripted_crf_decode = scripted_crf_model.decode(
                torch.tensor(input_emission, dtype=torch.float).unsqueeze(0),
                torch.tensor([seq_len]),
            )

            self.assertTrue(torch.allclose(crf_decode, scripted_crf_decode))

        batched_emissions = torch.tensor(padded_inputs, dtype=torch.float)
        batched_seq_lens = torch.tensor(seq_lens)
        crf_batch_decode = crf_model.decode(batched_emissions,
                                            batched_seq_lens)
        scripted_crf_batch_decode = scripted_crf_model.decode(
            batched_emissions, batched_seq_lens)
        self.assertTrue(
            torch.allclose(crf_batch_decode, scripted_crf_batch_decode))
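
Once the model is scripted, it can also be round-tripped through TorchScript serialization. A minimal sketch continuing from the test's local variables (the file path is arbitrary):

torch.jit.save(scripted_crf_model, "/tmp/crf_model.pt")
loaded_crf_model = torch.jit.load("/tmp/crf_model.pt")
# The reloaded model should decode identically to the in-memory one.
assert torch.allclose(
    loaded_crf_model.decode(batched_emissions, batched_seq_lens),
    crf_batch_decode,
)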
Example #4
def __init__(self, num_tags, *args) -> None:
    super().__init__(*args)
    self.crf = CRF(num_tags)
Example #5
class CRFOutputLayer(OutputLayerBase):
    """
    Output layer for word tagging models that use a Conditional Random Field.

    Args:
        num_tags (int): Total number of possible word tags.

    Attributes:
        num_tags: Total number of possible word tags.

    """
    @classmethod
    def from_config(cls, config: OutputLayerBase.Config, metadata: FieldMeta):
        return cls(metadata.vocab_size, metadata.vocab.itos)

    def __init__(self, num_tags, *args) -> None:
        super().__init__(*args)
        self.crf = CRF(num_tags)

    def get_loss(
        self,
        logit: torch.Tensor,
        target: torch.Tensor,
        context: Dict[str, Any],
        reduce=True,
    ):
        """Compute word tagging loss by using CRF.

        Args:
            logit (torch.Tensor): Logit returned by
                :class:`~pytext.models.WordTaggingModel`.
            target (torch.Tensor): True word tags/targets.
            context (Dict[str, Any]): Context is a dictionary of items
                that's passed as additional metadata by the
                :class:`~pytext.data.JointModelDataHandler`. Defaults to None.
            reduce (bool): Whether to reduce loss over the batch. Defaults to True.

        Returns:
            torch.Tensor: Word tagging loss (negative CRF log-likelihood);
            a scalar averaged over the batch if `reduce` is True.

        """
        loss = -1 * self.crf(logit, target, reduce=False)
        return loss.mean() if reduce else loss

    def get_pred(
        self,
        logit: torch.Tensor,
        target: Optional[torch.Tensor] = None,
        context: Optional[Dict[str, Any]] = None,
    ):
        """Compute and return prediction and scores from the model.

        Prediction is computed using CRF decoding.

        Scores are softmax scores over the model logits, where the logits are
        first rearranged so that the decoded word tag holds the highest value
        at each position. This is done because with CRF decoding, the
        highest-valued word tag for a given word may not be part of the
        decoded tag sequence; rearranging the logits makes argmax consistent
        with the CRF decode (a sketch of this rearrangement appears after the
        class).

        Args:
            logit (torch.Tensor): Logit returned by
                :class:`~pytext.models.WordTaggingModel`.
            target (torch.Tensor): Not applicable. Defaults to None.
            context (Optional[Dict[str, Any]]): Context is a dictionary of items
                that's passed as additional metadata by the
                :class:`~pytext.data.JointModelDataHandler`. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Model prediction and scores.

        """
        if not context:
            raise MissingValueError("Expected non-None context but got None.")
        pred = self.crf.decode(logit, context[SEQ_LENS])
        logit = _rearrange_output(logit, pred)
        return pred, F.log_softmax(logit, 2)

    def export_to_caffe2(
        self,
        workspace: core.workspace,
        init_net: core.Net,
        predict_net: core.Net,
        model_out: torch.Tensor,
        output_name: str,
    ) -> List[core.BlobReference]:
        """
        Exports the CRF output layer to Caffe2.
        See `OutputLayerBase.export_to_caffe2()` for details.
        """
        output_score = self.crf.export_to_caffe2(workspace, init_net,
                                                 predict_net, output_name)
        probability_out = predict_net.Softmax(output_score,
                                              axis=model_out.dim() - 1)
        return OutputLayerUtils.gen_additional_blobs(predict_net,
                                                     probability_out,
                                                     model_out, output_name,
                                                     self.target_names)
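
The _rearrange_output function used in get_pred is imported from elsewhere in PyText. Conceptually, it swaps each position's maximum logit with the logit of the CRF-decoded tag, so that an argmax over the rearranged logits reproduces the CRF decode. A minimal sketch of that swap, not necessarily PyText's actual implementation:

def rearrange_output_sketch(logit: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    # logit: (batch, seq_len, num_tags); pred: (batch, seq_len) int64 tags
    # decoded by the CRF.
    max_values, max_indices = torch.max(logit, dim=2, keepdim=True)
    pred_indices = pred.unsqueeze(2)
    pred_values = torch.gather(logit, 2, pred_indices)
    # Place the max value at the predicted tag's slot...
    rearranged = logit.scatter(2, pred_indices, max_values)
    # ...and the predicted tag's old value where the max used to be.
    rearranged.scatter_(2, max_indices, pred_values)
    return rearranged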