Esempio n. 1
0
    def get_candidate_actions(self, vocab):
        """Return every action that is allowed from the current state.

        Each guard is checked independently (there is no ``elif``), so several
        kinds of action can be returned for the same state, in this order:

        * ``NO_ACTION`` when ``self.is_end()``;
        * ``SHIFT`` when ``self.allow_shift()``;
        * one ``ARC_LEFT`` per relation id in ``range(vocab.rel_size)`` when
          ``self.allow_arc_left()``, restricted to actions present in
          ``vocab._id2ac``;
        * the same for ``ARC_RIGHT`` when ``self.allow_arc_right()``;
        * ``POP_ROOT`` labelled with ``vocab.ROOT`` when
          ``self.allow_pop_root()``.

        Args:
            vocab: vocabulary exposing ``rel_size``, ``ROOT`` and ``_id2ac``.

        Returns:
            A list of ``Action`` objects.
        """

        def registered(code):
            # Keep only the (code, relation) actions the vocabulary knows about.
            return [
                act
                for act in (Action(code, rel) for rel in range(vocab.rel_size))
                if act in vocab._id2ac
            ]

        allowed = []
        if self.is_end():
            allowed.append(Action(CODE.NO_ACTION, -1))
        if self.allow_shift():
            allowed.append(Action(CODE.SHIFT, -1))
        if self.allow_arc_left():
            allowed.extend(registered(CODE.ARC_LEFT))
        if self.allow_arc_right():
            allowed.extend(registered(CODE.ARC_RIGHT))
        if self.allow_pop_root():
            allowed.append(Action(CODE.POP_ROOT, vocab.ROOT))
        return allowed
Esempio n. 2
0
 def create_action_table(self, all_actions):
     """Register the action inventory and build per-kind boolean masks.

     ``Action(CODE.NO_ACTION, -1)`` is appended to ``self._id2ac`` first.
     Every distinct action in ``all_actions`` follows, ordered by descending
     frequency. Ties keep first-seen order, as ``Counter.most_common`` does.
     ``self._ac2id`` maps each action back to its position.

     Six numpy boolean arrays of length ``self.ac_size`` are then built:
     ``mask_shift``, ``mask_arc_left``, ``mask_arc_right``,
     ``mask_arc_label``, ``mask_pop_root`` (driven by ``is_finish``) and
     ``mask_no_action``. Entry ``i`` is True when ``self._id2ac[i]``
     satisfies the matching ``Action`` predicate. A summary and every action
     are printed to stdout.

     Args:
         all_actions: iterable of action sequences. Presumably one gold
             sequence per training instance; confirm with the caller.
     """
     self._id2ac.append(Action(CODE.NO_ACTION, -1))
     tally = Counter(act for seq in all_actions for act in seq)
     self._id2ac.extend(act for act, _freq in tally.most_common())
     # Later duplicates overwrite earlier keys, so a size mismatch means the
     # same action occurs more than once in ``_id2ac``.
     self._ac2id = {act: pos for pos, act in enumerate(self._id2ac)}
     if len(self._ac2id) != len(self._id2ac):
         print("serious bug: actions dumplicated, please check!")
     print("action num: ", len(self._ac2id))
     print("action: ", end=' ')

     def blank_mask():
         return np.array([False] * self.ac_size)

     self.mask_shift = blank_mask()
     self.mask_arc_left = blank_mask()
     self.mask_arc_right = blank_mask()
     self.mask_arc_label = blank_mask()
     self.mask_pop_root = blank_mask()
     self.mask_no_action = blank_mask()

     # (Action predicate name, mask to flag when the predicate holds).
     predicates = (
         ("is_shift", self.mask_shift),
         ("is_arc_left", self.mask_arc_left),
         ("is_arc_right", self.mask_arc_right),
         ("is_arc_label", self.mask_arc_label),
         ("is_finish", self.mask_pop_root),
         ("is_none", self.mask_no_action),
     )
     for pos, act in enumerate(self._id2ac):
         for method_name, mask in predicates:
             if getattr(act, method_name)():
                 mask[pos] = True
         print(act.str(self), end=', ')
     print()
Esempio n. 3
0
 def get_gold_action(self, vocab):
     """Return the oracle (gold) action for the current configuration.

     Let ``s0`` be the top stack item and ``s1`` the one below it.

     * Empty stack: SHIFT.
     * One item: POP_ROOT (labelled ``vocab.ROOT``) once
       ``_next_index == _word_size``, otherwise SHIFT.
     * Two or more items:
         - ARC_LEFT with ``s1``'s relation id when ``s0`` is the gold head
           of ``s1``;
         - ARC_RIGHT with ``s0``'s relation id when ``s1`` is the gold head
           of ``s0`` and no position in ``[_next_index, _word_size)`` has
           ``s0`` as its head;
         - SHIFT otherwise.

     A negative ``_stack_size`` matches no rule, so the returned action stays
     ``Action(CODE.NO_ACTION, -1)``.

     Raises:
         AssertionError: if either top item is not below ``_word_size``.
         KeyError: if a gold relation is missing from ``vocab._rel2id``.
     """
     gold = Action(CODE.NO_ACTION, -1)
     depth = self._stack_size

     if depth == 0 or depth == 1:
         if depth == 1 and self._next_index == self._word_size:
             gold.set(CODE.POP_ROOT, vocab.ROOT)
         else:
             gold.set(CODE.SHIFT, -1)
         return gold
     if depth <= 1:
         # Only reachable for a negative size: no rule applies.
         return gold

     s0 = self._stack[depth - 1]
     s1 = self._stack[depth - 2]
     assert s0 < self._word_size and s1 < self._word_size
     heads = self._inst.heads

     if s0 == heads[s1]:
         # s1 <- s0
         gold.set(CODE.ARC_LEFT, vocab._rel2id[self._inst.rels[s1]])
     elif s1 == heads[s0] and not any(
             heads[pos] == s0
             for pos in range(self._next_index, self._word_size)):
         # s1 -> s0, and s0 has no dependent left in the buffer.
         gold.set(CODE.ARC_RIGHT, vocab._rel2id[self._inst.rels[s0]])
     else:
         # No arc possible yet, or s0 still awaits a right dependent.
         gold.set(CODE.SHIFT, -1)
     return gold
Esempio n. 4
0
    def clear(self):
        """Reset the cursor, counters, flags and previous action.

        Stack contents are not touched; only ``_stack_size`` is zeroed.
        ``done_mark()`` is called last, after everything is reset.
        """
        self._next_index = self._stack_size = self._word_size = 0
        self._is_gold = self._is_start = True
        self._pre_action = Action(CODE.NO_ACTION, -1)
        self.done_mark()
Esempio n. 5
0
 def __init__(self):
     """Create an empty state with ``max_length``-slot buffers.

     ``_stack``, ``_rel`` and ``_head`` are pre-filled with -3 and
     ``_have_parent`` with -1. These are sentinel values whose meaning is
     defined by the code that reads them (TODO confirm).
     """
     capacity = max_length
     self._stack = [-3] * capacity
     self._stack_size = 0
     self._rel = [-3] * capacity
     self._head = [-3] * capacity
     self._have_parent = [-1] * capacity
     self._next_index = self._word_size = 0
     self._is_start = self._is_gold = True
     self._inst = None
     self._atom_feat = AtomFeat()
     self._pre_action = Action(CODE.NO_ACTION, -1)
Esempio n. 6
0
 def __init__(self):
     """Create an empty state whose stack holds ``max_length`` fresh Nodes."""
     # A comprehension gives each slot its own Node instance.
     self._stack = [Node() for _ in range(max_length)]
     self._stack_size = self._edu_size = self._next_index = self._word_size = 0
     self._is_start = self._is_gold = True
     self._inst = self._pre_state = None
     self._atom_feat = AtomFeat()
     self._pre_action = Action(CODE.NO_ACTION)
Esempio n. 7
0
 def create_action_table(self, all_actions):
     """Register the action inventory and print it.

     ``Action(CODE.NO_ACTION, -1)`` is appended to ``self._id2ac`` first.
     Every distinct action in ``all_actions`` follows, ordered by descending
     frequency. Ties keep first-seen order, as ``Counter.most_common`` does.
     ``self._ac2id`` maps each action back to its position. A warning is
     printed if the two tables differ in size, i.e. ``_id2ac`` holds
     duplicates.

     Args:
         all_actions: iterable of action sequences.
     """
     self._id2ac.append(Action(CODE.NO_ACTION, -1))
     tally = Counter(act for seq in all_actions for act in seq)
     self._id2ac.extend(act for act, _freq in tally.most_common())
     self._ac2id = {act: pos for pos, act in enumerate(self._id2ac)}
     if len(self._ac2id) != len(self._id2ac):
         print("serious bug: actions dumplicated, please check!")
     print("action num: ", len(self._ac2id))
     print("action: ", end=' ')
     for act in self._id2ac:
         print(act.str(self), end=', ')
     print()
Esempio n. 8
0
    def parseTree(self, tree_str):
        """Parse a bracketed tree string into EDUs, gold actions and subtrees.

        Input format:

        * ``tree_str`` is stripped and split on single spaces.
        * A node is opened by the tokens ``( <relation> <action>``.
        * When the relation is ``leaf`` and the action is ``t``, two more
          tokens follow: the start and end word indices of the leaf.
        * A ``)`` token closes the innermost open node.
        * Internal nodes use the action ``l`` (NS), ``r`` (SN) or ``c`` (NN).

        Nothing is returned; results are accumulated on ``self``:

        * ``self.EDUs``: one ``EDU`` per leaf. It is built from the first
          entry of ``self.sent_types`` whose ``[0]..[1]`` range contains the
          leaf span, and that entry's ``[2]`` element is passed as the third
          ``EDU`` argument.
        * ``self.gold_actions``: a SHIFT per recorded leaf and a REDUCE per
          internal node, in the order their nodes are closed, then a final
          POP_ROOT.
        * ``self.result.subtrees``: a left and a right ``SubTree`` per
          internal node, labelled with nuclearity and relation.
        * Each EDU's ``words``/``tags``: copied from ``self.total_words`` /
          ``self.total_tags``, skipping positions whose tag is ``nullkey``.

        Structural checks use ``assert``, so they are disabled under
        ``python -O``.

        NOTE(review): a token that is neither ``(`` nor ``)`` does not advance
        ``step``, so the ``while`` loop never terminates on such input. One
        source is the empty string produced by two consecutive spaces, since
        ``split(" ")`` is used.
        """
        buffer = tree_str.strip().split(" ")
        buffer_size = len(buffer)
        step = 0
        # Each entry is a mutable [first_edu, last_edu] index pair.
        subtree_stack = []  # edu index
        op_stack = []
        relation_stack = []
        action_stack = []
        while True:
            assert step <= buffer_size
            if step == buffer_size:
                break
            if buffer[step] == "(":
                # Open a node: remember its relation and action tokens.
                op_stack.append(buffer[step])
                relation_stack.append(buffer[step + 1])
                action_stack.append(buffer[step + 2])
                if buffer[step + 1] == 'leaf' and buffer[step + 2] == 't':
                    # Leaf nodes carry their word span in the next two tokens.
                    start = int(buffer[step + 3])
                    end = int(buffer[step + 4])
                    step += 2
                step += 3
            elif buffer[step] == ")":
                action = action_stack[-1]
                if action == 't':
                    # NOTE(review): `start`/`end` are only assigned for
                    # ('leaf', 't') nodes. For another relation with action
                    # 't' they are stale, or unbound if no leaf was seen yet.
                    # If no sent_types entry contains the span, no EDU or
                    # SHIFT is recorded for this leaf.
                    for sent_type in self.sent_types:
                        assert len(sent_type) == 3
                        if start >= sent_type[0] and end <= sent_type[1]:
                            e = EDU(start, end, sent_type[2])
                            # Index the new EDU will receive (computed before append).
                            edu_start = len(self.EDUs)
                            edu_end = len(self.EDUs)
                            subtree_stack.append([edu_start, edu_end])
                            self.EDUs.append(e)
                            assert relation_stack[-1] == "leaf"
                            ac = Action(CODE.SHIFT, -1, -1, relation_stack[-1])
                            self.gold_actions.append(ac)
                            break
                elif action == 'l' or action == 'r' or action == 'c':
                    # Map the node's action to its nuclearity.
                    if action == 'l':
                        nuclear = NUCLEAR.NS
                    if action == 'r':
                        nuclear = NUCLEAR.SN
                    if action == 'c':
                        nuclear = NUCLEAR.NN
                    code = CODE.REDUCE
                    ac = Action(code, nuclear, -1, relation_stack[-1])
                    self.gold_actions.append(ac)

                    # Merge the two topmost spans, which must be adjacent.
                    assert len(subtree_stack) >= 2
                    l_index = subtree_stack[-2]
                    r_index = subtree_stack[-1]
                    assert l_index[1] + 1 == r_index[0]
                    left_subtree = SubTree(nullkey, nullkey, l_index[0],
                                           l_index[1])
                    right_subtree = SubTree(nullkey, nullkey, r_index[0],
                                            r_index[1])

                    # The nucleus side of a mononuclear relation gets span_str;
                    # a satellite (or both sides for NN) gets the action label.
                    if action == "l":  #NS
                        left_subtree.nuclear = nuclear_str
                        left_subtree.relation = span_str
                        right_subtree.nuclear = satellite_str
                        right_subtree.relation = ac.label_str
                    if action == "r":  #SN
                        left_subtree.nuclear = satellite_str
                        left_subtree.relation = ac.label_str
                        right_subtree.nuclear = nuclear_str
                        right_subtree.relation = span_str
                    if action == "c":  #NN
                        left_subtree.nuclear = nuclear_str
                        left_subtree.relation = ac.label_str
                        right_subtree.nuclear = nuclear_str
                        right_subtree.relation = ac.label_str
                    self.result.subtrees.append(left_subtree)
                    self.result.subtrees.append(right_subtree)
                    # Extend the left entry in place to cover both spans,
                    # then drop the right entry.
                    l_index[1] = r_index[1]
                    subtree_stack.pop()

                # Close the node.
                relation_stack.pop()
                op_stack.pop()
                action_stack.pop()

                step += 1
        ac = Action(CODE.POP_ROOT)
        self.gold_actions.append(ac)
        # The whole tree must have reduced to one span covering every EDU.
        assert len(subtree_stack) == 1
        root = subtree_stack[0]
        assert root[0] == 0 and root[1] == len(self.EDUs) - 1
        subtree_stack.pop()  # pop root

        #### check stack, all stack empty
        assert op_stack == [] and relation_stack == [] and action_stack == [] and subtree_stack == []
        #### check edu index: in bounds, non-empty, and contiguous
        for idx in range(len(self.EDUs)):
            edu = self.EDUs[idx]
            assert edu.start >= 0 and edu.end < len(self.total_words)
            assert edu.start <= edu.end
            if idx < len(self.EDUs) - 1:
                assert edu.end + 1 == self.EDUs[idx + 1].start
        #### initialize edu word and tag (positions tagged nullkey are skipped)
        # NOTE: `sum` shadows the builtin within this method.
        sum = 0
        for edu in self.EDUs:
            for idx in range(edu.start, edu.end + 1):
                if self.total_tags[idx] != nullkey:
                    edu.words.append(self.total_words[idx])
                    edu.tags.append(self.total_tags[idx])
            sum += len(edu.words)
        assert sum == len(self.words)
        #### check subtree: every subtree must have been labelled
        for subtree in self.result.subtrees:
            assert subtree.relation != nullkey and subtree.nuclear != nullkey