def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run MST decoding on every instance of a batch and collect heads and tags.

    Each element of ``batch_energy`` is reduced over its leading axis with
    ``max`` to obtain a 2-D score matrix plus the index of the winning entry
    per cell (presumably the label axis -- confirm against the caller).
    ``decode_mst`` is then run on that matrix without labels, and the tag of
    every selected edge is read back from the arg-max indices.

    Parameters
    ----------
    batch_energy : torch.Tensor
        Batched energy tensor; iterated along its first dimension.
    lengths : torch.Tensor
        One length per instance, passed through to ``decode_mst``.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The stacked heads and the stacked head tags, one row per instance.
    """
    heads = []
    head_tags = []
    for energy, length in zip(batch_energy.detach().cpu(), lengths):
        scores, tag_ids = energy.max(dim=0)
        # Although we need to include the root node so that the MST includes it,
        # we do not want any word to be the parent of the root node.
        # Here, we enforce this by setting the scores for all word -> ROOT
        # edges to be 0.
        scores[0, :] = 0
        # Decode the heads. Because we modify the scores to prevent
        # adding in word -> ROOT edges, we need to find the labels ourselves.
        instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False)

        # Find the labels which correspond to the edges in the max spanning tree.
        instance_head_tags = [
            tag_ids[parent, child].item()
            for child, parent in enumerate(instance_heads)
        ]
        # We don't care what the head or tag is for the root token, but by default it's
        # not necessarily the same in the batched vs unbatched case, which is annoying.
        # Here we'll just set them to zero.
        instance_heads[0] = 0
        instance_head_tags[0] = 0
        # Record this instance's result; without these appends the final
        # numpy.stack calls receive empty lists and raise.
        heads.append(instance_heads)
        head_tags.append(instance_head_tags)
    return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags))
Esempio n. 2
def run_mst_decoding(
        batch_energy: torch.Tensor,
        lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run MST decoding on every instance of a batch and collect heads and tags.

    Each element of ``batch_energy`` is reduced over its leading axis with
    ``max`` to obtain a 2-D score matrix and the arg-max index per cell; row 0
    of the score matrix is zeroed before ``decode_mst`` is called without
    labels.

    NOTE(review): the ``decode_mst`` call and the ``return`` statement were
    truncated in the original snippet; they are completed here to mirror the
    sibling ``_run_mst_decoding`` implementations. Confirm against the
    upstream source.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The stacked heads and the stacked head tags, one row per instance.
    """
    heads = []
    head_tags = []
    for energy, length in zip(batch_energy.detach().cpu(), lengths):
        scores, tag_ids = energy.max(dim=0)
        scores[0, :] = 0
        instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False)

        # Find the labels which correspond to the edges in the max spanning tree.
        # This variant deliberately indexes tag_ids as [child, parent]
        # (the previous form was tag_ids[parent, child]).
        instance_head_tags = [
            tag_ids[child, parent].item()
            for child, parent in enumerate(instance_heads)
        ]
        # We don't care what the head or tag is for the root token, but by default it's
        # not necessarily the same in the batched vs unbatched case, which is annoying.
        # Here we'll just set them to zero.
        instance_heads[0] = 0
        instance_head_tags[0] = 0
        # Record this instance's result so the final stack has something to stack.
        heads.append(instance_heads)
        head_tags.append(instance_head_tags)

    return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(
        numpy.stack(head_tags))
Esempio n. 3
    def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode heads and head tags for each instance in the batch via MST.

        For every instance the energy is max-reduced over its leading axis,
        giving a 2-D score matrix and the arg-max index per cell. The heads
        come from ``decode_mst`` on the (modified) score matrix, and each
        edge's tag is looked up afterwards from the arg-max indices.

        Parameters
        ----------
        batch_energy : torch.Tensor
            Batched energy tensor; iterated along its first dimension.
        lengths : torch.Tensor
            One length per instance, passed through to ``decode_mst``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            The stacked heads and the stacked head tags, one row per instance.
        """
        heads = []
        head_tags = []
        for energy, length in zip(batch_energy.detach().cpu(), lengths):
            scores, tag_ids = energy.max(dim=0)
            # Although we need to include the root node so that the MST includes it,
            # we do not want any word to be the parent of the root node.
            # Here, we enforce this by setting the scores for all word -> ROOT
            # edges to be 0.
            scores[0, :] = 0
            # Decode the heads. Because we modify the scores to prevent
            # adding in word -> ROOT edges, we need to find the labels ourselves.
            instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False)

            # Find the labels which correspond to the edges in the max spanning tree.
            instance_head_tags = [
                tag_ids[parent, child].item()
                for child, parent in enumerate(instance_heads)
            ]
            # We don't care what the head or tag is for the root token, but by default it's
            # not necessarily the same in the batched vs unbatched case, which is annoying.
            # Here we'll just set them to zero.
            instance_heads[0] = 0
            instance_head_tags[0] = 0
            # Previously nothing was ever added to these lists, so the
            # numpy.stack calls below always failed on empty input.
            heads.append(instance_heads)
            head_tags.append(instance_head_tags)
        return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags))
Esempio n. 4
def run_cle(matrix):
    """Decode an unlabeled spanning tree from a square score tensor.

    The tensor is moved to the CPU and converted to a NumPy array, then handed
    to ``cle.decode_mst`` with ``length`` set to the size of its first
    dimension and ``has_labels=False``.

    Returns the first element of the ``decode_mst`` result (the decoded heads,
    judging by how the other callers unpack it).
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    dense = matrix.detach().cpu().numpy()
    node_count = dense.shape[0]
    decoded = cle.decode_mst(dense, length=node_count, has_labels=False)

    return decoded[0]
Esempio n. 5
    def _run_mst_decoding(batch_energy, lengths):
        # Per-instance MST decoding with an additional single-root check.
        # NOTE(review): this snippet is truncated in several places (the `for`
        # header, the first `decode_mst(` call, the last inner loop and the
        # `return` are all cut off), so it does not parse as shown.
        edge_heads = []
        edge_labels = []

        for i, (energy,
                length) in enumerate(zip(batch_energy.detach().cpu(),
            # decode heads and labels
            # need to decode labels separately so that we can enforce single root
            # Max over the leading axis: best score per cell plus its arg-max index.
            scores, label_ids = energy.max(dim=0)
            energy = scores

            instance_heads, instance_head_labels = decode_mst(energy.numpy(),
            #instance_heads, instance_head_labels = decode_mst(scores.numpy(), length, has_labels=True)

            ## Find the labels which correspond to the edges in the max spanning tree.
            instance_head_labels = []
            for child, parent in enumerate(instance_heads):
                instance_head_labels.append(label_ids[parent, child].item())

            # check for multiroot
            # True when more than one of the first `length` heads equals 0.
            multi_root = sum(
                [1 if h == 0 else 0 for h in instance_heads[0:length]]) > 1

            if multi_root:
                # Add a leading dimension for `_enforce_root` and remove it afterwards.
                # NOTE(review): `_enforce_root` is defined elsewhere; presumably it
                # rewrites the scores so only one root edge can win -- confirm.
                energy = energy.unsqueeze(0)
                energy = DeepTreeParser._enforce_root(energy)
                energy = energy.squeeze(0)
                instance_heads, instance_head_labels = decode_mst(
                    energy.numpy(), length, has_labels=False)
                #instance_heads, instance_head_labels = decode_mst(scores.numpy(), length, has_labels=True)

                ## Find the labels which correspond to the edges in the max spanning tree.
                instance_head_labels = []
                for child, parent in enumerate(instance_heads):


        # NOTE(review): nothing visible appends to `edge_heads` / `edge_labels`;
        # the appends were presumably lost with the truncated lines above.
        return torch.from_numpy(np.stack(edge_heads)), torch.from_numpy(
Esempio n. 6
    def test_mst(self):
        """Sanity-check ``decode_mst`` on random energies and on bad input shapes.

        Covers the unlabeled case, the labeled case (including agreement of
        the returned types with the per-arc arg-max label), and verifies that
        a rank mismatch with ``has_labels`` raises ``ConfigurationError``.
        """
        # First, test some random cases as sanity checks.
        # No label case
        energy = numpy.random.rand(5, 5)
        heads, types = decode_mst(energy, 5, has_labels=False)
        assert not _find_cycle(heads, 5, [True] * 5)[0]

        # Labeled case
        energy = numpy.random.rand(3, 5, 5)
        heads, types = decode_mst(energy, 5)

        assert not _find_cycle(heads, 5, [True] * 5)[0]
        label_id_matrix = energy.argmax(axis=0)

        # Check that the labels correspond to the
        # argmax of the labels for the arcs.
        for child, parent in enumerate(heads):
            # The first index corresponds to the symbolic
            # head token, which won't necessarily have an
            # argmax type, so it is skipped.
            if child == 0:
                continue
            assert types[child] == label_id_matrix[parent, child]

        # Check wrong dimensions throw errors
        with pytest.raises(ConfigurationError):
            energy = numpy.random.rand(5, 5)
            decode_mst(energy, 5, has_labels=True)

        with pytest.raises(ConfigurationError):
            energy = numpy.random.rand(3, 5, 5)
            decode_mst(energy, 5, has_labels=False)
Esempio n. 7
    def test_mst(self):
        """Exercise ``decode_mst`` with random inputs and with mismatched ranks.

        The unlabeled and labeled results must both be cycle-free, the labeled
        result's types must equal the arg-max label of each chosen arc, and a
        rank that contradicts ``has_labels`` must raise ``ConfigurationError``.
        """
        # First, test some random cases as sanity checks.
        # No label case
        energy = numpy.random.rand(5, 5)
        heads, types = decode_mst(energy, 5, has_labels=False)
        assert not _find_cycle(heads, 5, [True] * 5)[0]

        # Labeled case
        energy = numpy.random.rand(3, 5, 5)
        heads, types = decode_mst(energy, 5)

        assert not _find_cycle(heads, 5, [True] * 5)[0]
        label_id_matrix = energy.argmax(axis=0)

        # Check that the labels correspond to the
        # argmax of the labels for the arcs.
        for child, parent in enumerate(heads):
            # The first index corresponds to the symbolic
            # head token, which won't necessarily have an
            # argmax type; the missing `continue` left this `if` without a body.
            if child == 0:
                continue
            assert types[child] == label_id_matrix[parent, child]

        # Check wrong dimensions throw errors
        with pytest.raises(ConfigurationError):
            energy = numpy.random.rand(5, 5)
            decode_mst(energy, 5, has_labels=True)

        with pytest.raises(ConfigurationError):
            energy = numpy.random.rand(3, 5, 5)
            decode_mst(energy, 5, has_labels=False)
Esempio n. 8
    # Forward pass: projects the input into head and dependent arc embeddings,
    # scores every pair with a batched matrix product, and returns a dict with
    # "prediction" and "probability" (plus "loss" / "cycle_loss" when labels
    # are given).
    # NOTE(review): this snippet is damaged -- there is no `self` parameter even
    # though the body uses `self.`, the branch header above the indented
    # prediction code is missing, `lengths = + 1` has lost its left operand,
    # and the `decode_mst(` call is cut off. It does not parse as shown.
    def forward(
        x: Union[torch.Tensor, List[torch.Tensor]],
        mask: Optional[torch.BoolTensor] = None,
        labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
        sample_weights: Optional[Union[torch.Tensor,
                                       List[torch.Tensor]]] = None
    ) -> Dict[str, torch.Tensor]:
        if mask is None:
            mask = x.new_ones(x.size()[-1])

        head_arc_emb = self.head_projection_layer(x)
        dep_arc_emb = self.dependency_projection_layer(x)
        # Pairwise arc scores: dependent embeddings times transposed head embeddings.
        x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))

            # NOTE(review): the two assignments to `pred` below look like the
            # bodies of two different branches (greedy argmax vs. MST decoding)
            # whose `if`/`else` lines were lost -- confirm against upstream.
            pred = x.argmax(-1)
            pred = []
            # Adding non existing in mask ROOT to lengths
            lengths = + 1
            for idx, length in enumerate(lengths):
                probs = x[idx, :].softmax(dim=-1).cpu().numpy()

                # We do not want any word to be parent of the root node (ROOT, 0).
                # Also setting it to -1 instead of 0 fixes edge case where softmax made all
                # but ROOT prediction to EXACTLY 0.0 and it might cause in many ROOT -> word edges)
                probs[:, 0] = -1
                heads, _ = chu_liu_edmonds.decode_mst(probs.T,
                heads[0] = 0
            pred = torch.from_numpy(np.stack(pred)).to(x.device)

        # The first column of `pred` is dropped from the returned prediction.
        output = {"prediction": pred[:, 1:], "probability": x}

        if labels is not None:
            if sample_weights is None:
                # Default to a weight of one per row of `mask`.
                sample_weights = labels.new_ones([mask.size(0)])
            output["loss"], output["cycle_loss"] = self._loss(
                x, labels, mask, sample_weights)

        return output
    def test_mst_respects_no_outgoing_root_edges_constraint(self):
        """The model's decoder must avoid the two-root tree plain MST returns.

        The score matrix follows the convention
        ``arc_scores[i, j] = "score that i is the head of j"`` and is built so
        that the unconstrained maximum spanning tree attaches both words to
        ROOT, which is undesirable for a dependency parse where only one word
        should have ROOT as its head.
        """
        arc_scores = torch.Tensor([[0, 9, 5], [2, 0, 4], [3, 1, 0]])
        sentence_length = torch.LongTensor([3])

        # Plain decoding: the true MST, with both words headed by ROOT.
        raw_heads, _ = decode_mst(
            arc_scores.numpy(), sentence_length.item(), has_labels=False
        )
        assert list(raw_heads) == [-1, 0, 0]

        # Decoding through the model must enforce the constraint instead.
        batched_scores = arc_scores.view(1, 1, 3, 3)
        model_heads, _ = self.model._run_mst_decoding(batched_scores, sentence_length)
        first_instance = model_heads.tolist()[0]
        assert first_instance == [0, 0, 1]
    def test_mst_respects_no_outgoing_root_edges_constraint(self):
        """Compare unconstrained MST decoding with the model's constrained one.

        ``head_scores[i, j]`` is the score that ``i`` is the head of ``j``.
        The values are chosen so that the best spanning tree gives ROOT two
        children, whereas a typical dependency parse should have exactly one
        word whose head is ROOT.
        """
        rows = [[0, 9, 5],
                [2, 0, 4],
                [3, 1, 0]]
        head_scores = torch.Tensor(rows)
        n_tokens = torch.LongTensor([3])

        # Correct as an MST, but not what we want from a dependency parser.
        mst_heads, _ = decode_mst(head_scores.numpy(), n_tokens.item(), has_labels=False)
        assert list(mst_heads) == [-1, 0, 0]

        # The model-level decoder should enforce the single-root constraint.
        decoded, _ = self.model._run_mst_decoding(  # pylint: disable=protected-access
            head_scores.view(1, 1, 3, 3), n_tokens)
        assert decoded.tolist()[0] == [0, 0, 1]
Esempio n. 11
 def test_mst_finds_maximum_spanning_tree(self):
     """decode_mst on the 1..9 energy grid yields heads [-1, 2, 0]."""
     # A (1, 3, 3) tensor holding 1..9 in row-major order.
     labeled_energy = torch.arange(1, 10).view(1, 3, 3)
     predicted_heads, _ = decode_mst(labeled_energy.numpy(), 3)
     expected_heads = [-1, 2, 0]
     assert list(predicted_heads) == expected_heads
Esempio n. 12
 def test_mst_finds_maximum_spanning_tree(self):
     """The values 1..9 arranged as a (1, 3, 3) energy decode to [-1, 2, 0]."""
     grid = torch.arange(1, 10).view(1, 3, 3)
     as_numpy = grid.numpy()
     tree_heads, _ = decode_mst(as_numpy, 3)  # pylint: disable=protected-access
     assert list(tree_heads) == [-1, 2, 0]
    def mst_decode(self,
                   arc_alpha: float = 1.0) -> Tuple[List[int], List[int]]:
        # NOTE(review): this snippet is damaged. The two lines below are the
        # remains of a docstring whose quotes were lost, several statements
        # (`if self.span_lst is None:`, the first `for`, `if max_tag_id is None:`)
        # have lost their bodies, and the score expression for `s` is
        # fragmentary. The function does not parse as shown.
        do mst decode:
        S[i][j] = max_{T1, T2}{Score_span(T1) + Score_span(T2) + arc_alpha * Score_link(T1, T2), T1.r==i, T2.r==j}

        if self.span_lst is None:

        # Presumably maps each root index to the indices of the spans rooted
        # there; the loop body that fills it is missing -- confirm upstream.
        root2span_idxs = defaultdict(list)
        for idx, (root, start, end) in enumerate(self.span_lst):

        seq_len = len(self.words)

        # scores[i,j] = "Score that i is the head of j"
        # tag_ids[i, j] = "best label of arc i->j"
        scores = torch.zeros([seq_len, seq_len])
        tag_ids = torch.zeros([seq_len, seq_len], dtype=torch.long)
        for i in range(seq_len):
            for j in range(1, seq_len):  # root cannot be child
                # compute energy[i][j] and tag_ids[i][j]
                # Track the best-scoring (parent span, child span) pair and its tag.
                max_score = -math.inf
                max_tag_id = None
                for parent_idx in root2span_idxs[i]:
                    parent_root, parent_start, parent_end = self.span_lst[
                    for child_idx in root2span_idxs[j]:
                        child_root, child_start, child_end = self.span_lst[
                        # for computational stability, we use log prob for comparison
                        s = (self.span_root_score_lst[parent_idx] +
                             self.span_root_score_lst[child_idx] + arc_alpha *
                                       [parent_root] + EPS) +
                                       [parent_start] + EPS) +
                                       [parent_end] + EPS) +
                                       [child_root] + EPS) +
                                       [child_start] + EPS) +
                                       [child_end] + EPS)))
                        t = self.parent_tags_idxs_lst[child_idx][parent_root]
                        if s > max_score:
                            max_score = s
                            max_tag_id = t

                if max_tag_id is None:
                        f"no valid arc between {i} and {j} for {self.words} "
                        f"with spans: {root2span_idxs[i] + root2span_idxs[j]}")
                    # Fallback when no span pair was found: log-score 0.0
                    # (exp gives 1.0 below) and tag id 0.
                    max_score = 0.0
                    max_tag_id = 0

                # Scores were compared in log space; store the exponentiated value.
                scores[i][j] = math.exp(max_score)
                tag_ids[i][j] = max_tag_id

        # Decode the heads. Because we modify the scores to prevent
        # adding in word -> ROOT edges, we need to find the labels ourselves.
        # print(scores)
        instance_heads, _ = decode_mst(scores.numpy(),

        # Find the labels which correspond to the edges in the max spanning tree.
        instance_head_tags = []
        for child, parent in enumerate(instance_heads):
            instance_head_tags.append(tag_ids[parent, child].item())
        # We don't care what the head or tag is for the root token, but by default it's
        # not necessarily the same in the batched vs unbatched case, which is annoying.
        # Here we'll just set them to zero.
        instance_heads[0] = 0
        instance_head_tags[0] = 0
        # Index 0 (the root position) is dropped from both returned lists.
        return instance_heads.tolist()[1:], instance_head_tags[1:]
def get_parse_from_attention_matrix(sentence, attention_matrix):
    """Build a ``Sentence`` from the spanning tree decoded from an attention matrix.

    ``decode_mst`` is run on ``attention_matrix`` with the sentence length and
    without labels. Each token becomes a ``Word`` built from the token itself,
    its 1-based position, and its decoded governor shifted up by one.

    NOTE(review): the original snippet ended before the ``Sentence([`` call was
    closed; only the closing brackets are restored here. Confirm that no
    further arguments to ``Sentence`` were lost.
    """
    governors, _ = decode_mst(attention_matrix, len(sentence), False)
    return Sentence([
        Word(w, i + 1, g + 1)
        for i, (w, g) in enumerate(zip(sentence, governors))
    ])
Esempio n. 15
               0.10544933293378421, 0.06696645233053426
               0.33278193213153157, 0.24182637172052882,
               0.24348700794789607, 0.0, 0.3074931499229424,
               0.34207021338827026, 0.38077567455395994
               0.12199173439186563, 0.0856421256888373, 0.0856836501258372,
               0.21524582195485076, 0.0, 0.2090616298352853,
               0.14103502167319285, 0.1383250368740608, 0.13255458167128512,
               0.16840110553265578, 0.25125240895112677, 0.0,
               0.08937150749085589, 0.0601478789935935, 0.07209526206160943,
               0.12177543294853871, 0.12950957963820142,
               0.14592798150362726, 0.0
 # Demo: decode an unlabeled spanning tree from a random 7x7 score matrix and
 # print the result returned by decode_mst.
 # NOTE(review): the lines preceding this snippet are a truncated numeric
 # literal, so the enclosing scope (and the reason for the one-space indent)
 # is not visible here.
 scores = np.random.random((7, 7))
 # `decoder` is only referenced by the commented-out Eisner code below.
 decoder = Eisner()
 print(decode_mst(np.array(scores), length=len(scores), has_labels=False))
 # scores = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
 # decoder = Eisner()
 # scores = np.array(scores)
 # best_arcs, root_pred = decoder.parse_proj_no_root(scores)
Esempio n. 16
 def test_mst_finds_maximum_spanning_tree(self):
     """Decoding the single-label energy holding 1..9 gives heads [-1, 2, 0]."""
     energy_values = torch.arange(1, 10)
     single_label_energy = energy_values.view(1, 3, 3)
     decoded_heads, _ = decode_mst(  # pylint: disable=protected-access
         single_label_energy.numpy(), 3)
     assert list(decoded_heads) == [-1, 2, 0]