def _run_mst_decoding(
        batch_energy: torch.Tensor,
        lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run ``decode_mst`` on every instance of a batch and recover the tag of each chosen edge.

    Parameters
    ----------
    batch_energy : ``torch.Tensor``, required.
        Batched energies. Each instance is reduced over its first dimension and the
        result is indexed as a 2D ``[parent, child]`` matrix, so this is presumably of
        shape (batch_size, num_labels, sequence_length, sequence_length) -- TODO confirm.
    lengths : ``torch.Tensor``, required.
        One length per instance; each is forwarded unchanged to ``decode_mst``.

    Returns
    -------
    heads : ``torch.Tensor``
        The stacked per-instance heads returned by ``decode_mst``, with position 0
        overwritten by 0.
    head_tags : ``torch.Tensor``
        The stacked per-instance tag ids of the selected ``(parent, child)`` edges,
        with position 0 overwritten by 0.
    """
    batch_heads = []
    batch_head_tags = []
    cpu_energy = batch_energy.detach().cpu()
    for instance_energy, instance_length in zip(cpu_energy, lengths):
        # Collapse the first dimension: keep the best score per edge and remember
        # which index along that dimension produced it.
        edge_scores, best_tag_ids = instance_energy.max(dim=0)

        # Zero every score in row 0, i.e. every edge whose parent index is 0
        # (the first axis is the parent, as in ``best_tag_ids[parent, child]`` below).
        # NOTE(review): the previous comment described these as "word -> ROOT" edges,
        # but the indexing suggests they are ROOT -> word edges -- confirm which
        # constraint is intended.
        edge_scores[0, :] = 0

        # The scores were modified above, so decoding is done without labels and
        # the tags are looked up separately afterwards.
        predicted_heads, _ = decode_mst(edge_scores.numpy(),
                                        instance_length,
                                        has_labels=False)

        # Tag of each edge selected by the decoder.
        predicted_tags = [best_tag_ids[parent, child].item()
                          for child, parent in enumerate(predicted_heads)]

        # The head and tag of the first (root) position are not meaningful and are
        # not necessarily the same in the batched vs unbatched case; pin both to
        # zero so the output is consistent.
        predicted_heads[0] = 0
        predicted_tags[0] = 0

        batch_heads.append(predicted_heads)
        batch_head_tags.append(predicted_tags)

    stacked_heads = torch.from_numpy(numpy.stack(batch_heads))
    stacked_head_tags = torch.from_numpy(numpy.stack(batch_head_tags))
    return stacked_heads, stacked_head_tags
 def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Apply ``decode_mst`` to each instance of ``batch_energy`` and stack the results.

     ``batch_energy`` is detached, moved to the CPU and converted to numpy before being
     split along its first dimension; each slice is paired with the corresponding entry
     of ``lengths``. Returns the stacked heads and stacked head tags as two tensors.
     """
     energies = batch_energy.detach().cpu().numpy()
     decoded = [decode_mst(instance_energy, instance_length)
                for instance_energy, instance_length in zip(energies, lengths)]
     all_heads = [instance_heads for instance_heads, _ in decoded]
     all_head_tags = [instance_tags for _, instance_tags in decoded]
     return torch.from_numpy(numpy.stack(all_heads)), torch.from_numpy(numpy.stack(all_head_tags))
Exemple #3
0
    def test_mst(self):
        """Sanity-check ``decode_mst`` on random energies and on wrongly shaped input."""
        # Unlabeled case: the decoded heads of a random 5x5 energy must be cycle free.
        unlabeled_energy = numpy.random.rand(5, 5)
        unlabeled_heads, _ = decode_mst(unlabeled_energy, 5, has_labels=False)
        unlabeled_has_cycle = _find_cycle(unlabeled_heads, 5, [True] * 5)[0]
        assert not unlabeled_has_cycle

        # Labeled case: same check with a leading label dimension of size 3.
        labeled_energy = numpy.random.rand(3, 5, 5)
        labeled_heads, labels = decode_mst(labeled_energy, 5)
        labeled_has_cycle = _find_cycle(labeled_heads, 5, [True] * 5)[0]
        assert not labeled_has_cycle

        # Every decoded label must be the argmax over the label dimension for its arc.
        # Index 0 is the symbolic head token, which won't necessarily have an
        # argmax type, so it is left out of the comparison.
        best_label_per_arc = labeled_energy.argmax(axis=0)
        for child, parent in enumerate(labeled_heads[1:], start=1):
            assert labels[child] == best_label_per_arc[parent, child]

        # A 2D energy combined with has_labels=True must be rejected.
        two_dim_energy = numpy.random.rand(5, 5)
        with pytest.raises(ConfigurationError):
            decode_mst(two_dim_energy, 5, has_labels=True)

        # A 3D energy combined with has_labels=False must be rejected.
        three_dim_energy = numpy.random.rand(3, 5, 5)
        with pytest.raises(ConfigurationError):
            decode_mst(three_dim_energy, 5, has_labels=False)
    def test_mst(self):
        """Exercise ``decode_mst`` with random inputs, then with mismatched dimensions."""
        # --- No labels: decoding a random square energy must not produce a cycle.
        plain_energy = numpy.random.rand(5, 5)
        plain_heads, _ = decode_mst(plain_energy, 5, has_labels=False)
        cycle_found = _find_cycle(plain_heads, 5, [True] * 5)[0]
        assert not cycle_found

        # --- With labels: three label slices stacked on the first axis.
        tagged_energy = numpy.random.rand(3, 5, 5)
        tagged_heads, tagged_types = decode_mst(tagged_energy, 5)
        cycle_found = _find_cycle(tagged_heads, 5, [True] * 5)[0]
        assert not cycle_found

        # The returned types have to agree with the argmax label of each chosen arc.
        # Position 0 is the symbolic head token and won't necessarily have an
        # argmax type, so the comparison starts at position 1.
        argmax_labels = tagged_energy.argmax(axis=0)
        for child in range(1, len(tagged_heads)):
            parent = tagged_heads[child]
            assert tagged_types[child] == argmax_labels[parent, child]

        # --- Wrong dimensions must raise a ConfigurationError.
        unlabeled_shape = numpy.random.rand(5, 5)
        with pytest.raises(ConfigurationError):
            decode_mst(unlabeled_shape, 5, has_labels=True)

        labeled_shape = numpy.random.rand(3, 5, 5)
        with pytest.raises(ConfigurationError):
            decode_mst(labeled_shape, 5, has_labels=False)
    def test_mst_respects_no_outgoing_root_edges_constraint(self):
        """Compare raw ``decode_mst`` output with the model's constrained decoding."""
        # energy[i, j] is the "score that i is the head of j", i.e. heads point to
        # their children. The values are chosen so that the unconstrained tree gives
        # ROOT (index 0) two children, whereas a typical dependency parse has only
        # one word whose head is ROOT.
        edge_scores = [[0, 9, 5],
                       [2, 0, 4],
                       [3, 1, 0]]
        energy = torch.Tensor(edge_scores)
        length = torch.LongTensor([3])

        raw_heads, _ = decode_mst(energy.numpy(), length.item(), has_labels=False)
        # Correct as a spanning tree, but not desirable for dependency parsing.
        assert list(raw_heads) == [-1, 0, 0]

        # Decoding through the model should enforce the constraint.
        batched_energy = energy.view(1, 1, 3, 3)
        heads_model, _ = self.model._run_mst_decoding(batched_energy, length)  # pylint: disable=protected-access
        assert heads_model.tolist()[0] == [0, 0, 1]
    def _mst_decode(self, head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decodes heads and head tags for a batch by building a per-instance energy
        tensor and handing it to ``decode_mst``. The energy of an arc with a given
        tag is the product of the softmax-normalised arc score and the
        softmax-normalised tag score for that pair of positions.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) holding
            the unnormalised arc scores.
        mask : ``torch.Tensor``, required.
            Summed over dimension 1 to obtain the instance lengths, and positions
            where it is 0 receive a large negative arc score. Presumably of shape
            (batch_size, sequence_length) with 1 for real tokens -- TODO confirm.

        Returns
        -------
        heads : ``torch.Tensor``
            The heads returned by ``decode_mst`` for every instance, stacked along
            the batch dimension.
        head_tags : ``torch.Tensor``
            The head tags returned by ``decode_mst`` for every instance, stacked
            along the batch dimension.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size()

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        # Broadcast both representations to one entry per (position, position) pair
        # so the bilinear layer can score every pair of positions.
        pairwise_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim]
        pairwise_heads = head_tag_representation.unsqueeze(2).expand(*pairwise_shape).contiguous()
        pairwise_children = child_tag_representation.unsqueeze(1).expand(*pairwise_shape).contiguous()
        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(pairwise_heads, pairwise_children)

        # The log_softmax runs over the tag dimension; pairs involving padding are
        # not removed here because the arc scores below already penalise them.
        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        tag_log_probs = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2)

        # Push the arc score of every pair that touches a masked-out position
        # towards a very large negative value.
        minus_inf = -1e8
        padding_penalty = (1 - mask.float()) * minus_inf
        masked_arcs = attended_arcs + padding_penalty.unsqueeze(2) + padding_penalty.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        arc_log_probs = F.log_softmax(masked_arcs, dim=2).transpose(1, 2)
        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        batch_energy = torch.exp(arc_log_probs.unsqueeze(1) + tag_log_probs)

        instance_energies = batch_energy.detach().cpu().numpy()
        decoded = [decode_mst(instance_energy, instance_length)
                   for instance_energy, instance_length in zip(instance_energies, lengths)]
        all_heads = [instance_heads for instance_heads, _ in decoded]
        all_head_tags = [instance_tags for _, instance_tags in decoded]
        return torch.from_numpy(numpy.stack(all_heads)), torch.from_numpy(numpy.stack(all_head_tags))
    def _mst_decode(self,
                    head_tag_representation: torch.Tensor,
                    child_tag_representation: torch.Tensor,
                    attended_arcs: torch.Tensor,
                    mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Builds, for every instance in the batch, an energy tensor combining arc
        probabilities with tag probabilities, and decodes heads and head tags from
        it with ``decode_mst``.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            which will be used to generate predictions for the dependency tags
            for the given arcs.
        attended_arcs : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, sequence_length) of
            unnormalised arc scores.
        mask : ``torch.Tensor``, required.
            Its sum over dimension 1 gives the length passed to ``decode_mst`` for
            each instance; positions where it is 0 are penalised in the arc scores.
            Presumably of shape (batch_size, sequence_length) -- TODO confirm.

        Returns
        -------
        heads : ``torch.Tensor``
            Per-instance heads from ``decode_mst``, stacked into a single tensor.
        head_tags : ``torch.Tensor``
            Per-instance head tags from ``decode_mst``, stacked into a single tensor.
        """
        batch_size, sequence_length, tag_representation_dim = head_tag_representation.size()

        lengths = mask.data.sum(dim=1).long().cpu().numpy()

        # One tag representation per ordered pair of positions, for both roles.
        target_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim]
        heads_for_pairs = head_tag_representation.unsqueeze(2)
        heads_for_pairs = heads_for_pairs.expand(*target_shape).contiguous()
        children_for_pairs = child_tag_representation.unsqueeze(1)
        children_for_pairs = children_for_pairs.expand(*target_shape).contiguous()

        # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
        pairwise_head_logits = self.tag_bilinear(heads_for_pairs, children_for_pairs)

        # Normalise over the tag dimension, then move that dimension to position 1.
        # Pairs that include padding are handled through the arc scores below.
        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        log_tag_distribution = F.log_softmax(pairwise_head_logits, dim=3)
        log_tag_distribution = log_tag_distribution.permute(0, 3, 1, 2)

        # Add a large negative number to every arc that involves a position whose
        # mask value is 0.
        minus_inf = -1e8
        penalty = (1 - mask.float()) * minus_inf
        penalised_arcs = attended_arcs + penalty.unsqueeze(2) + penalty.unsqueeze(1)

        # Shape (batch_size, sequence_length, sequence_length)
        log_arc_distribution = F.log_softmax(penalised_arcs, dim=2).transpose(1, 2)
        # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
        batch_energy = torch.exp(log_arc_distribution.unsqueeze(1) + log_tag_distribution)

        collected_heads = []
        collected_head_tags = []
        numpy_energy = batch_energy.detach().cpu().numpy()
        for instance_energy, instance_length in zip(numpy_energy, lengths):
            instance_heads, instance_head_tags = decode_mst(instance_energy, instance_length)
            collected_heads.append(instance_heads)
            collected_head_tags.append(instance_head_tags)

        stacked_heads = torch.from_numpy(numpy.stack(collected_heads))
        stacked_head_tags = torch.from_numpy(numpy.stack(collected_head_tags))
        return stacked_heads, stacked_head_tags
Exemple #8
0
 def test_mst_finds_maximum_spanning_tree(self):
     """Check ``decode_mst`` on a fixed 3x3 energy whose entries are 1..9 in row-major order."""
     # torch.range is deprecated; torch.arange with an exclusive end of 10.0 yields
     # the same nine floating point values 1..9 in the default dtype.
     energy = torch.arange(1.0, 10.0).view(1, 3, 3)
     heads, _ = decode_mst(energy.numpy(), 3)
     assert list(heads) == [-1, 2, 0]