def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run maximum-spanning-tree decoding independently for every sentence.

    Each per-sentence slice of ``batch_energy`` is reduced with ``max`` over its
    first dimension (presumably the label dimension -- TODO confirm at call
    sites), which yields an arc score grid read as ``scores[head, child]`` and
    the best tag id of every arc.

    Args:
        batch_energy: arc/label energies for the whole batch.
        lengths: number of nodes (ROOT at index 0 included) to decode per sentence.

    Returns:
        Stacked head indices and stacked head-tag ids, one row per sentence.
        Both are forced to 0 at the ROOT position (index 0).
    """
    all_heads = []
    all_tags = []
    for sentence_energy, sentence_length in zip(batch_energy.detach().cpu(), lengths):
        # Best label (and its score) for every possible head -> child arc.
        arc_scores, best_tags = sentence_energy.max(dim=0)
        # Row 0 holds every ROOT -> word arc. Zeroing those scores keeps the
        # decoder from freely hanging several words directly off ROOT
        # (assuming the remaining energies are non-negative -- confirm).
        arc_scores[0, :] = 0
        # The label dimension was collapsed above, so decode an unlabeled tree
        # and read each chosen arc's tag back from the [head, child] grid.
        tree_heads, _ = decode_mst(arc_scores.numpy(), sentence_length, has_labels=False)
        tree_tags = [best_tags[head, dependent].item() for dependent, head in enumerate(tree_heads)]
        # ROOT's own head/tag are meaningless and are not necessarily the same in
        # batched vs. unbatched decoding, so normalise both to 0.
        tree_heads[0] = 0
        tree_tags[0] = 0
        all_heads.append(tree_heads)
        all_tags.append(tree_tags)
    return torch.from_numpy(numpy.stack(all_heads)), torch.from_numpy(numpy.stack(all_tags))
Esempio n. 2
0
def run_mst_decoding(
        batch_energy: torch.Tensor,
        lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decode a maximum spanning tree (heads and head tags) for each sentence.

    Each per-sentence slice of ``batch_energy`` is reduced with ``max`` over its
    first dimension (presumably the label dimension -- TODO confirm), giving an
    arc score grid ``scores[head, child]`` plus the best tag id of every arc.
    The unlabeled scores are decoded with ``decode_mst`` and each chosen arc's
    tag is read back from the same ``[head, child]`` grid.

    Args:
        batch_energy: arc/label energies for the whole batch.
        lengths: number of nodes (ROOT at index 0 included) to decode per sentence.

    Returns:
        Stacked head indices and stacked head-tag ids, one row per sentence;
        the ROOT position (index 0) is set to 0 in both.
    """
    heads = []
    head_tags = []
    for energy, length in zip(batch_energy.detach().cpu(), lengths):
        scores, tag_ids = energy.max(dim=0)
        # Row 0 holds every ROOT -> word arc. Zeroing those scores keeps the
        # decoder from freely hanging several words directly off ROOT
        # (assuming the remaining energies are non-negative -- confirm).
        scores[0, :] = 0
        # The label dimension was collapsed above, so decode without labels and
        # recover the tags of the chosen arcs ourselves.
        instance_heads, _ = decode_mst(scores.numpy(),
                                       length,
                                       has_labels=False)

        # Tags must be indexed exactly like the scores handed to decode_mst:
        # [parent, child]. Indexing [child, parent] returns the tag of the
        # reverse arc, which is generally a different (wrong) label.
        instance_head_tags = [tag_ids[parent, child].item()
                              for child, parent in enumerate(instance_heads)]
        # ROOT's own head/tag are meaningless and are not necessarily the same
        # in batched vs. unbatched decoding, so normalise both to zero.
        instance_heads[0] = 0
        instance_head_tags[0] = 0
        heads.append(instance_heads)
        head_tags.append(instance_head_tags)

    return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(
        numpy.stack(head_tags))
Esempio n. 3
0
    def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode one maximum spanning tree per sentence and read off its arc tags.

        For every sentence the energy is collapsed over its first dimension
        (presumably labels -- TODO confirm) with ``max``. Arcs out of ROOT are
        zeroed, and an unlabeled tree over the first ``length`` nodes is decoded
        with ``decode_mst``. Then the best tag of each chosen ``head -> child``
        arc is collected.

        Returns:
            (heads, head_tags) as tensors stacked over the batch; the ROOT slot
            (index 0) of both is 0.
        """
        decoded_heads = []
        decoded_tags = []
        per_sentence = zip(batch_energy.detach().cpu(), lengths)
        for energy, length in per_sentence:
            best_scores, best_labels = energy.max(dim=0)

            # Scores are indexed [head, child]; row 0 is every ROOT -> word arc.
            # Zeroing them discourages attaching many words to ROOT when the
            # other energies are positive (assumption -- confirm).
            best_scores[0, :] = 0

            tree, _ = decode_mst(best_scores.numpy(), length, has_labels=False)

            labels = []
            for child in range(len(tree)):
                labels.append(best_labels[tree[child], child].item())

            # ROOT's head and tag carry no meaning and are not necessarily
            # stable between batched and unbatched runs; pin them to 0.
            tree[0] = 0
            labels[0] = 0

            decoded_heads.append(tree)
            decoded_tags.append(labels)

        stacked_heads = torch.from_numpy(numpy.stack(decoded_heads))
        stacked_tags = torch.from_numpy(numpy.stack(decoded_tags))
        return stacked_heads, stacked_tags
Esempio n. 4
0
def run_cle(matrix):
    """Decode an unlabeled maximum spanning tree over a square score matrix.

    The matrix is moved to the CPU and converted to NumPy. Then
    ``cle.decode_mst`` runs over all of its rows, and only the first element
    of the decoder's result (the head array) is returned.
    """
    dense = matrix.cpu().numpy()
    decoded = cle.decode_mst(dense, length=dense.shape[0], has_labels=False)
    return decoded[0]
Esempio n. 5
0
    def _run_mst_decoding(batch_energy, lengths):
        """Decode one dependency tree per sentence, re-decoding multi-rooted trees.

        For each sentence, the energy's first dimension (presumably labels --
        TODO confirm) is collapsed with ``max``. The resulting
        ``scores[head, child]`` grid is decoded without labels, and each chosen
        arc's label is read back from the argmax label ids. If more than one of
        the first ``length`` positions is attached to ROOT (index 0), the
        scores go through ``DeepTreeParser._enforce_root`` and the tree is
        decoded again.

        Returns:
            (edge_heads, edge_labels) tensors stacked over the batch. ROOT's own
            head entry is left exactly as ``decode_mst`` returned it.
        """

        def _decode_tree(arc_scores, length, label_ids):
            """Unlabeled MST decode plus the best label of every chosen arc."""
            heads, _ = decode_mst(arc_scores.numpy(), length, has_labels=False)
            labels = [label_ids[parent, child].item()
                      for child, parent in enumerate(heads)]
            return heads, labels

        edge_heads = []
        edge_labels = []

        for energy, length in zip(batch_energy.detach().cpu(), lengths):
            # Labels are decoded separately from heads so that a single root can
            # be enforced on the unlabeled scores if needed.
            scores, label_ids = energy.max(dim=0)
            instance_heads, instance_head_labels = _decode_tree(
                scores, length, label_ids)

            # More than one word attached to ROOT -> constrain and decode again.
            root_children = sum(1 for head in instance_heads[:length] if head == 0)
            if root_children > 1:
                constrained = DeepTreeParser._enforce_root(
                    scores.unsqueeze(0)).squeeze(0)
                instance_heads, instance_head_labels = _decode_tree(
                    constrained, length, label_ids)

            edge_heads.append(instance_heads)
            edge_labels.append(instance_head_labels)

        return torch.from_numpy(np.stack(edge_heads)), torch.from_numpy(
            np.stack(edge_labels))
Esempio n. 6
0
    def test_mst(self):
        """Sanity-check decode_mst on random inputs and its rank validation.

        Covers the unlabeled (2-D energy) and labeled (3-D energy) cases. The
        decoded heads must contain no cycle, and in the labeled case each
        non-root token's type must equal the argmax label of its chosen arc.
        An energy whose rank disagrees with ``has_labels`` must raise
        ConfigurationError.
        """
        # First, test some random cases as sanity checks.
        # No label case
        energy = numpy.random.rand(5, 5)
        heads, _ = decode_mst(energy, 5, has_labels=False)
        assert not _find_cycle(heads, 5, [True] * 5)[0]

        # Labeled case
        energy = numpy.random.rand(3, 5, 5)
        heads, types = decode_mst(energy, 5)
        assert not _find_cycle(heads, 5, [True] * 5)[0]

        # label_id_matrix[parent, child] is the best label of arc parent -> child.
        label_id_matrix = energy.argmax(axis=0)
        for child, parent in enumerate(heads):
            # Index 0 is the symbolic root; its type need not be an argmax label.
            if child == 0:
                continue
            assert types[child] == label_id_matrix[parent, child]

        # Wrong-rank energies must be rejected. Inputs are built outside the
        # `raises` blocks so that only the decode_mst call can satisfy them.
        unlabeled_energy = numpy.random.rand(5, 5)
        with pytest.raises(ConfigurationError):
            decode_mst(unlabeled_energy, 5, has_labels=True)

        labeled_energy = numpy.random.rand(3, 5, 5)
        with pytest.raises(ConfigurationError):
            decode_mst(labeled_energy, 5, has_labels=False)
Esempio n. 7
0
    def test_mst(self):
        """Check decode_mst on random energies and its dimension checks.

        Unlabeled and labeled random inputs must both decode to acyclic heads.
        In the labeled case every non-root type must be the argmax label of
        its arc. Passing an energy of the wrong rank for ``has_labels`` must
        raise ConfigurationError.
        """
        # First, test some random cases as sanity checks.
        # No label case
        energy = numpy.random.rand(5, 5)
        heads, _ = decode_mst(energy, 5, has_labels=False)
        assert not _find_cycle(heads, 5, [True] * 5)[0]

        # Labeled case
        energy = numpy.random.rand(3, 5, 5)
        heads, types = decode_mst(energy, 5)
        assert not _find_cycle(heads, 5, [True] * 5)[0]

        # Best label per arc, indexed [parent, child].
        label_id_matrix = energy.argmax(axis=0)
        for child, parent in enumerate(heads):
            # The symbolic root (index 0) need not carry an argmax type.
            if child == 0:
                continue
            assert types[child] == label_id_matrix[parent, child]

        # Wrong dimensions must throw. Only the decode_mst call sits inside
        # each `raises` block, so nothing else can trigger the expectation.
        two_dim_energy = numpy.random.rand(5, 5)
        with pytest.raises(ConfigurationError):
            decode_mst(two_dim_energy, 5, has_labels=True)

        three_dim_energy = numpy.random.rand(3, 5, 5)
        with pytest.raises(ConfigurationError):
            decode_mst(three_dim_energy, 5, has_labels=False)
Esempio n. 8
0
    def forward(
        self,
        x: Union[torch.Tensor, List[torch.Tensor]],
        mask: Optional[torch.BoolTensor] = None,
        labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        sample_weights: Optional[Union[torch.Tensor,
                                       List[torch.Tensor]]] = None
    ) -> Dict[str, torch.Tensor]:
        """Score every (dependent, head) pair and predict a head for each word.

        Args:
            x: token encodings with ROOT at position 0; presumably shaped
                (batch, seq_len, dim) -- TODO confirm against callers.
            mask: per-word mask that does NOT cover ROOT (decode lengths are
                ``mask.sum(1) + 1``). When omitted, every word of every
                sentence is treated as present.
            labels: gold heads; when given, "loss" and "cycle_loss" are added.
            sample_weights: per-sentence loss weights; defaults to all ones.

        Returns:
            Dict with "prediction" (predicted heads, ROOT position dropped),
            "probability" (the raw, un-normalised arc score matrix) and, when
            labels were given, "loss" and "cycle_loss".
        """
        if mask is None:
            # The mask is summed over dim 1 and its size(0) is used as the
            # batch size below, so it must be (batch, words) excluding ROOT.
            # The former default, x.new_ones(x.size()[-1]), was 1-D over the
            # feature dimension and broke both uses.
            mask = torch.ones(x.size(0), x.size(1) - 1,
                              dtype=torch.bool, device=x.device)

        head_arc_emb = self.head_projection_layer(x)
        dep_arc_emb = self.dependency_projection_layer(x)
        # x[b, d, h]: rows come from the dependency projection, columns from
        # the head projection, i.e. the score of token h heading token d.
        x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))

        if self.training:
            pred = x.argmax(-1)
        else:
            pred = []
            # +1 adds ROOT, which the mask does not cover.
            lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1
            for idx, length in enumerate(lengths):
                probs = x[idx, :].softmax(dim=-1).cpu().numpy()

                # probs is [dependent, head], so column 0 holds every
                # ROOT -> word arc; penalise them so few (ideally one) words
                # attach to ROOT. -1 rather than 0 matters when softmax drives
                # all non-ROOT heads of a word to exactly 0.0: a 0 penalty
                # would then tie and could still yield many ROOT -> word edges.
                probs[:, 0] = -1
                # decode_mst reads energy[head, dependent], hence the transpose.
                heads, _ = chu_liu_edmonds.decode_mst(probs.T,
                                                      length=length,
                                                      has_labels=False)
                # ROOT has no real head; normalise its entry to 0.
                heads[0] = 0
                pred.append(heads)
            pred = torch.from_numpy(np.stack(pred)).to(x.device)

        output = {"prediction": pred[:, 1:], "probability": x}

        if labels is not None:
            if sample_weights is None:
                sample_weights = labels.new_ones([mask.size(0)])
            output["loss"], output["cycle_loss"] = self._loss(
                x, labels, mask, sample_weights)

        return output
    def test_mst_respects_no_outgoing_root_edges_constraint(self):
        """The model's decoder must not attach both words to ROOT on this input.

        energy[i, j] is the score that i is the head of j (heads point to their
        children). Plain decode_mst attaches both words directly to ROOT, which
        is a valid MST but not a usable dependency parse. In a typical parse only
        one word has ROOT as its head, and the model's decoding should enforce
        that.
        """
        # Row 0 (ROOT as head) holds the two largest off-diagonal scores, so an
        # unconstrained MST hangs both words directly off ROOT.
        energy = torch.Tensor([[0, 9, 5], [2, 0, 4], [3, 1, 0]])
        length = torch.LongTensor([3])

        plain_heads, _ = decode_mst(energy.numpy(), length.item(), has_labels=False)
        assert list(plain_heads) == [-1, 0, 0]

        single_batch = energy.view(1, 1, 3, 3)
        model_heads, _ = self.model._run_mst_decoding(single_batch, length)
        assert model_heads.tolist()[0] == [0, 0, 1]
    def test_mst_respects_no_outgoing_root_edges_constraint(self):
        """Model MST decoding must yield a single ROOT child where plain MST gives two.

        The energy grid follows energy[i, j] = "score that i is the head of j".
        It is built so that the unconstrained maximum spanning tree gives ROOT
        two children; the model's ``_run_mst_decoding`` should avoid that.
        """
        arc_energy = torch.Tensor([
            [0, 9, 5],
            [2, 0, 4],
            [3, 1, 0],
        ])
        sentence_length = torch.LongTensor([3])

        # Correct as an MST, but undesirable for dependency parsing.
        raw_heads = decode_mst(arc_energy.numpy(), sentence_length.item(), has_labels=False)[0]
        assert list(raw_heads) == [-1, 0, 0]

        # With the model's decoding, only word 1 hangs off ROOT.
        decoded = self.model._run_mst_decoding(arc_energy.view(1, 1, 3, 3), sentence_length) # pylint: disable=protected-access
        assert decoded[0].tolist()[0] == [0, 0, 1]
Esempio n. 11
0
 def test_mst_finds_maximum_spanning_tree(self):
     """decode_mst picks the highest-scoring tree on a single-label 3x3 energy."""
     # One label over three nodes: energy[0, i, j] = 3 * i + j + 1.
     energy = torch.arange(1, 10).view(1, 3, 3)
     expected_heads = [-1, 2, 0]
     heads, _ = decode_mst(energy.numpy(), 3)
     assert list(heads) == expected_heads
Esempio n. 12
0
 def test_mst_finds_maximum_spanning_tree(self):
     """The labeled decode of a fixed 1x3x3 energy grid returns heads [-1, 2, 0]."""
     grid = torch.arange(1, 10).view(1, 3, 3).numpy()
     decoded = decode_mst(grid, 3)
     assert list(decoded[0]) == [-1, 2, 0]
    def mst_decode(self,
                   arc_alpha: float = 1.0) -> Tuple[List[int], List[int]]:
        """
        Decode a dependency tree over ``self.words`` with a maximum spanning tree.

        For every candidate arc i -> j (j >= 1: ROOT at index 0 is never a child)
        the best combination of a span T1 rooted at i and a span T2 rooted at j
        is scored:

            S[i][j] = max_{T1, T2}{Score_span(T1) + Score_span(T2) + arc_alpha * Score_link(T1, T2), T1.r==i, T2.r==j}

        Score_link is the sum of ``log(v + EPS)`` over six lookups. The first
        three are the child span's parent-root/start/end scores at T1's
        root/start/end. The other three are the parent span's child-root/
        start/end scores at T2's root/start/end. The max is taken in log space
        and exponentiated before MST decoding.
        Arcs with no supporting span pair get score 0, below every supported arc.

        Args:
            arc_alpha: weight of the link term relative to the span scores.

        Returns:
            (heads, tags) for tokens 1..n-1 (ROOT excluded): each token's head
            index and the tag id of its chosen arc.
        """
        if self.span_lst is None:
            self.get_spans_infos()

        # Group span indices by root token so arc i -> j only considers spans
        # rooted at i (parent side) and spans rooted at j (child side).
        root2span_idxs = defaultdict(list)
        for idx, (root, _, _) in enumerate(self.span_lst):
            root2span_idxs[root].append(idx)

        def _arc_log_score(parent_idx: int, child_idx: int) -> float:
            """Log-space score of linking span ``parent_idx`` to span ``child_idx``."""
            parent_root, parent_start, parent_end = self.span_lst[parent_idx]
            child_root, child_start, child_end = self.span_lst[child_idx]
            # Log space for numerical stability; EPS guards against log(0).
            link = (math.log(self.parent_arc_score_lst[child_idx]
                             [parent_root] + EPS) +
                    math.log(self.parent_start_score_lst[child_idx]
                             [parent_start] + EPS) +
                    math.log(self.parent_end_score_lst[child_idx]
                             [parent_end] + EPS) +
                    math.log(self.child_score_lst[parent_idx]
                             [child_root] + EPS) +
                    math.log(self.child_start_score_lst[parent_idx]
                             [child_start] + EPS) +
                    math.log(self.child_end_score_lst[parent_idx]
                             [child_end] + EPS))
            return (self.span_root_score_lst[parent_idx] +
                    self.span_root_score_lst[child_idx] + arc_alpha * link)

        seq_len = len(self.words)

        # scores[i, j] = "Score that i is the head of j"
        # tag_ids[i, j] = "best label of arc i -> j"
        # Column 0 is never filled: ROOT cannot be a child.
        scores = torch.zeros([seq_len, seq_len])
        tag_ids = torch.zeros([seq_len, seq_len], dtype=torch.long)
        for i in range(seq_len):
            for j in range(1, seq_len):
                max_score = -math.inf
                max_tag_id = None
                for parent_idx in root2span_idxs[i]:
                    parent_root = self.span_lst[parent_idx][0]
                    for child_idx in root2span_idxs[j]:
                        s = _arc_log_score(parent_idx, child_idx)
                        if s > max_score:
                            max_score = s
                            max_tag_id = self.parent_tags_idxs_lst[child_idx][parent_root]

                if max_tag_id is None:
                    warnings.warn(
                        f"no valid arc between {i} and {j} for {self.words} "
                        f"with spans: {root2span_idxs[i] + root2span_idxs[j]}")
                    # Keep the unsupported arc at 0. Previously it became
                    # exp(0.0) == 1.0, which can outrank every supported arc.
                    scores[i][j] = 0.0
                    tag_ids[i][j] = 0
                    continue

                scores[i][j] = math.exp(max_score)
                tag_ids[i][j] = max_tag_id

        # Decode unlabeled heads; tags are then read from tag_ids for the
        # chosen arcs, indexed [parent, child] like the scores.
        instance_heads, _ = decode_mst(scores.numpy(),
                                       seq_len,
                                       has_labels=False)

        # ROOT (index 0) is dropped from the result, so only tokens 1.. matter.
        instance_head_tags = [tag_ids[instance_heads[child], child].item()
                              for child in range(1, len(instance_heads))]
        return instance_heads.tolist()[1:], instance_head_tags
def get_parse_from_attention_matrix(sentence, attention_matrix):
    """Build a ``Sentence`` whose arcs come from an MST over an attention matrix.

    ``decode_mst`` runs unlabeled over all ``len(sentence)`` nodes. Token ``i``
    (0-based) becomes ``Word(token, i + 1, head + 1)``, so both indices are
    shifted to 1-based. The decoder treats node 0 as the root; its head is
    presumably -1, which maps to 0 here.
    """
    # Third positional argument is has_labels=False.
    governors, _ = decode_mst(attention_matrix, len(sentence), False)
    words = []
    for position, (token, governor) in enumerate(zip(sentence, governors)):
        words.append(Word(token, position + 1, governor + 1))
    return Sentence(words)
Esempio n. 15
0
               0.10544933293378421, 0.06696645233053426
           ],
           [
               0.33278193213153157, 0.24182637172052882,
               0.24348700794789607, 0.0, 0.3074931499229424,
               0.34207021338827026, 0.38077567455395994
           ],
           [
               0.12199173439186563, 0.0856421256888373, 0.0856836501258372,
               0.21524582195485076, 0.0, 0.2090616298352853,
               0.1473791651283505
           ],
           [
               0.14103502167319285, 0.1383250368740608, 0.13255458167128512,
               0.16840110553265578, 0.25125240895112677, 0.0,
               0.16677963148603495
           ],
           [
               0.08937150749085589, 0.0601478789935935, 0.07209526206160943,
               0.12177543294853871, 0.12950957963820142,
               0.14592798150362726, 0.0
           ]]
 # Print the output of three tree-decoding routines on the same random 7x7 matrix.
 # NOTE(review): this rebinds `scores`; if the literal matrix ending just above
 # was also bound to `scores`, it is discarded here -- confirm which input is intended.
 scores = np.random.random((7, 7))
 # Local `_mst` helper (presumably a Chu-Liu-Edmonds implementation -- confirm).
 print(_mst(np.array(scores)))
 # Eisner decoder; `parse_proj` presumably returns a projective parse -- confirm.
 decoder = Eisner()
 print(decoder.parse_proj(np.array(scores)))
 # decode_mst over all 7 nodes, without labels.
 print(decode_mst(np.array(scores), length=len(scores), has_labels=False))
 # Earlier experiment on a fixed 3x3 matrix, kept for reference:
 # scores = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
 # decoder = Eisner()
 # scores = np.array(scores)
 # best_arcs, root_pred = decoder.parse_proj_no_root(scores)
Esempio n. 16
0
 def test_mst_finds_maximum_spanning_tree(self):
     """decode_mst must return the best tree for a fixed three-node energy."""
     energy_grid = torch.arange(1, 10).view(1, 3, 3).numpy()
     decoded_heads = decode_mst(energy_grid, 3)[0]
     assert list(decoded_heads) == [-1, 2, 0]