Exemple #1
0
    def get_candidate_actions(self, vocab):
        """Return the list of actions that are legal in this state.

        Candidates are appended in a fixed order: NO_ACTION (when
        ``is_end()`` holds), SHIFT, every labelled ARC_LEFT, every labelled
        ARC_RIGHT and finally POP_ROOT.  A labelled arc is only offered if
        the resulting action is contained in ``vocab._id2ac``.

        Args:
            vocab: provides ``rel_size``, ``_id2ac`` and ``ROOT``.

        Returns:
            A new list of ``Action`` objects (possibly empty).
        """
        legal = []

        if self.is_end():
            legal.append(Action(CODE.NO_ACTION, -1))

        if self.allow_shift():
            legal.append(Action(CODE.SHIFT, -1))

        if self.allow_arc_left():
            # One candidate per relation id, filtered by the action table.
            left_arcs = (Action(CODE.ARC_LEFT, rel)
                         for rel in range(0, vocab.rel_size))
            legal.extend(arc for arc in left_arcs if arc in vocab._id2ac)

        if self.allow_arc_right():
            right_arcs = (Action(CODE.ARC_RIGHT, rel)
                          for rel in range(0, vocab.rel_size))
            legal.extend(arc for arc in right_arcs if arc in vocab._id2ac)

        if self.allow_pop_root():
            legal.append(Action(CODE.POP_ROOT, vocab.ROOT))

        return legal
Exemple #2
0
    def clear(self):
        """Reset the counters, flags and previous action of this state.

        All three counters go back to zero, both flags go back to ``True``
        and the previous action becomes a fresh NO_ACTION.  ``done_mark()``
        is invoked last, after the counters have been reset.
        """
        self._next_index = self._stack_size = self._word_size = 0
        self._is_gold = self._is_start = True
        self._pre_action = Action(CODE.NO_ACTION, -1)
        self.done_mark()
Exemple #3
0
 def __init__(self):
     """Create an empty state whose buffers hold ``max_length`` entries."""
     # Cursors and flags.
     self._stack_size = self._next_index = self._word_size = 0
     self._is_start = self._is_gold = True
     self._inst = None

     # Pre-sized buffers; -3 (and -1 for ``_have_parent``) are the
     # initial fill values.
     self._stack = [-3] * max_length
     self._rel = [-3] * max_length
     self._head = [-3] * max_length
     self._have_parent = [-1] * max_length

     self._atom_feat = AtomFeat()
     self._pre_action = Action(CODE.NO_ACTION, -1)
Exemple #4
0
 def __init__(self):
     """Create an empty state with a stack of ``max_length`` fresh nodes."""
     # Every stack slot gets its own Node instance up front.
     self._stack = [Node() for _ in range(max_length)]

     self._stack_size = self._edu_size = 0
     self._next_index = self._word_size = 0
     self._is_start = self._is_gold = True
     self._inst = self._pre_state = None

     self._atom_feat = AtomFeat()
     self._pre_action = Action(CODE.NO_ACTION)
Exemple #5
0
 def create_action_table(self, all_actions):
     """Build the action tables and one boolean mask per action type.

     ``_id2ac`` receives a NO_ACTION entry followed by every distinct
     action found in ``all_actions``, in ``Counter.most_common()`` order.
     ``_ac2id`` is the inverse mapping.  Each ``mask_*`` array has
     ``self.ac_size`` entries and is ``True`` at the ids of the actions
     for which the matching ``Action`` predicate returns true.  A summary
     of the table is printed to stdout.

     Args:
         all_actions: iterable of action sequences.
     """
     self._id2ac.append(Action(CODE.NO_ACTION, -1))

     frequencies = Counter(
         ac for actions in all_actions for ac in actions)
     self._id2ac.extend(ac for ac, _ in frequencies.most_common())

     self._ac2id = {ac: idx for idx, ac in enumerate(self._id2ac)}
     if len(self._ac2id) != len(self._id2ac):
         # Two entries of _id2ac compared equal and collapsed into one key.
         print("serious bug: actions dumplicated, please check!")
     print("action num: ", len(self._ac2id))
     print("action: ", end=' ')

     def blank_mask():
         return np.array([False] * self.ac_size)

     self.mask_shift = blank_mask()
     self.mask_arc_left = blank_mask()
     self.mask_arc_right = blank_mask()
     self.mask_arc_label = blank_mask()
     self.mask_pop_root = blank_mask()
     self.mask_no_action = blank_mask()

     # Action predicate name -> mask that records it.
     recorders = (
         ("is_shift", self.mask_shift),
         ("is_arc_left", self.mask_arc_left),
         ("is_arc_right", self.mask_arc_right),
         ("is_arc_label", self.mask_arc_label),
         ("is_finish", self.mask_pop_root),
         ("is_none", self.mask_no_action),
     )
     for idx, ac in enumerate(self._id2ac):
         for predicate, mask in recorders:
             if getattr(ac, predicate)():
                 mask[idx] = True
         print(ac.str(self), end=', ')
     print()
Exemple #6
0
 def create_action_table(self, all_actions):
     """Build the ``_id2ac`` / ``_ac2id`` action tables and print them.

     ``_id2ac`` receives a NO_ACTION entry followed by every distinct
     action found in ``all_actions``, in ``Counter.most_common()`` order;
     ``_ac2id`` maps each action back to its position.

     Args:
         all_actions: iterable of action sequences.
     """
     self._id2ac.append(Action(CODE.NO_ACTION, -1))

     frequencies = Counter(
         ac for actions in all_actions for ac in actions)
     self._id2ac.extend(ac for ac, _ in frequencies.most_common())

     self._ac2id = {ac: idx for idx, ac in enumerate(self._id2ac)}
     if len(self._ac2id) != len(self._id2ac):
         # Two entries of _id2ac compared equal and collapsed into one key.
         print("serious bug: actions dumplicated, please check!")

     print("action num: ", len(self._ac2id))
     print("action: ", end=' ')
     for known_action in self._id2ac:
         print(known_action.str(self), end=', ')
     print()
Exemple #7
0
 def get_gold_action(self, vocab):
     """Return the gold (oracle) action for this state.

     Decision rules, driven by the stack height and the gold heads in
     ``self._inst``:
       * empty stack -> SHIFT;
       * one item -> POP_ROOT when the buffer is exhausted, else SHIFT;
       * two or more items -> ARC_LEFT if the top is the gold head of the
         item below it; ARC_RIGHT if the item below is the gold head of
         the top and no word still in the buffer has the top as head;
         SHIFT otherwise.

     Args:
         vocab: provides ``ROOT`` and the ``_rel2id`` mapping.

     Returns:
         An ``Action``; it stays NO_ACTION if no rule applies.
     """
     gold = Action(CODE.NO_ACTION, -1)
     depth = self._stack_size

     if depth == 0:
         gold.set(CODE.SHIFT, -1)
         return gold

     if depth == 1:
         if self._next_index == self._word_size:
             gold.set(CODE.POP_ROOT, vocab.ROOT)
         else:
             gold.set(CODE.SHIFT, -1)
         return gold

     if not depth > 1:
         return gold

     s0 = self._stack[depth - 1]
     s1 = self._stack[depth - 2]
     assert s0 < self._word_size and s1 < self._word_size

     if s0 == self._inst.heads[s1]:  # s1 <- s0
         gold.set(CODE.ARC_LEFT, vocab._rel2id[self._inst.rels[s1]])
     elif s1 == self._inst.heads[s0]:  # s1 -> s0
         # s0 must stay on the stack while one of its right children is
         # still waiting in the buffer.
         child_pending = any(
             self._inst.heads[idx] == s0
             for idx in range(self._next_index, self._word_size))
         if child_pending:
             gold.set(CODE.SHIFT, -1)
         else:
             gold.set(CODE.ARC_RIGHT, vocab._rel2id[self._inst.rels[s0]])
     else:  # no arc possible between the two top items
         gold.set(CODE.SHIFT, -1)
     return gold
Exemple #8
0
class State:
    """One configuration of a transition-based dependency parser.

    A state owns pre-sized buffers (``max_length`` entries each) for the
    stack of word indices and for the head / relation / has-parent
    annotation of every word.  A transition never modifies the state it is
    called on; it writes the successor configuration into ``next_state``.

    Values found in the buffers:
        -3   initial fill of ``_stack``, ``_rel`` and ``_head``;
        -1   initial fill of ``_have_parent``; also the head assigned to
             the root word by :meth:`pop_root`;
        -2   end marker written by :meth:`done_mark`.
    """

    def __init__(self):
        self._stack = [-3] * max_length
        self._stack_size = 0
        self._rel = [-3] * max_length
        self._head = [-3] * max_length
        self._have_parent = [-1] * max_length
        self._next_index = 0
        self._word_size = 0
        self._is_start = True
        self._is_gold = True
        self._inst = None
        self._atom_feat = AtomFeat()
        self._pre_state = None
        self._pre_action = Action(CODE.NO_ACTION, -1)

    def ready(self, sentence, vocab):
        """Attach a sentence (wrapped in an ``Instance``) to this state."""
        self._inst = Instance(sentence, vocab)
        self._word_size = len(self._inst.words)

    def clear(self):
        """Reset counters, flags and history so the state can be reused."""
        self._next_index = 0
        self._stack_size = 0
        self._word_size = 0
        self._is_gold = True
        self._is_start = True
        self._pre_action = Action(CODE.NO_ACTION, -1)
        self._pre_state = None

        self.done_mark()

    def done_mark(self):
        """Write the -2 end marker just past the used part of each buffer."""
        self._stack[self._stack_size] = -2
        self._head[self._next_index] = -2
        self._rel[self._next_index] = -2
        self._have_parent[self._next_index] = -2

    def allow_shift(self):
        """Return True while at least one word is left in the buffer."""
        return self._next_index < self._word_size

    def allow_arc_left(self):
        """Return True if the stack holds at least two items."""
        return self._stack_size > 1

    def allow_arc_right(self):
        """Return True if the stack holds at least two items."""
        return self._stack_size > 1

    def allow_pop_root(self):
        """Return True if one item is on the stack and the buffer is empty."""
        return (self._stack_size == 1
                and self._next_index == self._word_size)

    def shift(self, next_state):
        """Push the next buffer word onto the stack of ``next_state``."""
        assert self._next_index < self._word_size
        next_state._next_index = self._next_index + 1
        next_state._stack_size = self._stack_size + 1
        self.copy_state(next_state)
        next_state._stack[next_state._stack_size - 1] = self._next_index
        next_state._have_parent[self._next_index] = 0
        next_state.done_mark()
        next_state._pre_action.set(CODE.SHIFT, -1)

    def arc_left(self, next_state, dep):
        """Attach the second stack item to the top one with relation ``dep``.

        The dependent (second item) is removed from the stack; the top item
        stays as the new top.
        """
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size - 1
        self.copy_state(next_state)
        top0 = self._stack[self._stack_size - 1]
        top1 = self._stack[self._stack_size - 2]
        next_state._stack[next_state._stack_size - 1] = top0
        next_state._head[top1] = top0
        next_state._have_parent[top1] = 1
        next_state._rel[top1] = dep
        next_state.done_mark()
        next_state._pre_action.set(CODE.ARC_LEFT, dep)

    def arc_right(self, next_state, dep):
        """Attach the top stack item to the second one with relation ``dep``.

        The dependent (top item) is removed; the second item, already at
        the right position after ``copy_state``, becomes the new top.
        """
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size - 1
        self.copy_state(next_state)
        top0 = self._stack[self._stack_size - 1]
        top1 = self._stack[self._stack_size - 2]
        next_state._head[top0] = top1
        next_state._have_parent[top0] = 1
        next_state.done_mark()
        next_state._rel[top0] = dep
        next_state._pre_action.set(CODE.ARC_RIGHT, dep)

    def pop_root(self, next_state, dep):
        """Pop the last stack item and mark it as root (head ``-1``)."""
        assert self._stack_size == 1 and self._next_index == self._word_size
        next_state._next_index = self._word_size
        next_state._stack_size = 0
        self.copy_state(next_state)
        top0 = self._stack[self._stack_size - 1]
        next_state._head[top0] = -1
        next_state._have_parent[top0] = 1
        next_state._rel[top0] = dep
        next_state.done_mark()
        next_state._pre_action.set(CODE.POP_ROOT, dep)

    def move(self, next_state, action):
        """Apply ``action`` to this state, writing the result to ``next_state``.

        ``next_state`` is always flagged as neither start nor gold.  An
        action of an unknown kind only prints a message.
        """
        next_state._is_start = False
        next_state._is_gold = False
        if action.is_shift():
            self.shift(next_state)
        elif action.is_arc_left():
            self.arc_left(next_state, action.label)
        elif action.is_arc_right():
            self.arc_right(next_state, action.label)
        elif action.is_finish():
            self.pop_root(next_state, action.label)
        else:
            print(" error state ")

    def get_candidate_actions(self, vocab):
        """Return the list of actions that are legal in this state.

        Order: NO_ACTION (after the finishing action), SHIFT, labelled
        ARC_LEFT actions, labelled ARC_RIGHT actions, POP_ROOT.  Labelled
        arcs are only offered when present in ``vocab._id2ac``.
        """
        candidate_actions = []
        if self.is_end():
            candidate_actions.append(Action(CODE.NO_ACTION, -1))

        if self.allow_shift():
            candidate_actions.append(Action(CODE.SHIFT, -1))

        if self.allow_arc_left():
            for idx in range(0, vocab.rel_size):
                ac = Action(CODE.ARC_LEFT, idx)
                if ac in vocab._id2ac:
                    candidate_actions.append(ac)

        if self.allow_arc_right():
            for idx in range(0, vocab.rel_size):
                ac = Action(CODE.ARC_RIGHT, idx)
                if ac in vocab._id2ac:
                    candidate_actions.append(ac)

        if self.allow_pop_root():
            candidate_actions.append(Action(CODE.POP_ROOT, vocab.ROOT))
        return candidate_actions

    def copy_state(self, next_state):
        """Copy the used prefix of every buffer into ``next_state``.

        Also links ``next_state`` back to this state and shares the
        instance.  A slice already yields a new list, so assigning it
        cannot alias this state's buffers; no deep copy is required.
        """
        next_state._pre_state = self
        next_state._inst = self._inst
        next_state._word_size = self._word_size
        next_state._stack[0:self._stack_size] = \
            self._stack[0:self._stack_size]
        next_state._rel[0:self._next_index] = \
            self._rel[0:self._next_index]
        next_state._head[0:self._next_index] = \
            self._head[0:self._next_index]
        next_state._have_parent[0:self._next_index] = \
            self._have_parent[0:self._next_index]

    def is_end(self):
        """Return True if the previous action was the finishing one."""
        return bool(self._pre_action.is_finish())

    def get_gold_action(self, vocab):
        """Return the gold (oracle) action for this state.

        Empty stack -> SHIFT.  One item -> POP_ROOT when the buffer is
        exhausted, else SHIFT.  Two or more items -> ARC_LEFT if the top
        is the gold head of the item below; ARC_RIGHT if the item below is
        the gold head of the top and the top has no right child left in
        the buffer; SHIFT otherwise.
        """
        gold_action = Action(CODE.NO_ACTION, -1)
        if self._stack_size == 0:
            gold_action.set(CODE.SHIFT, -1)
        elif self._stack_size == 1:
            if self._next_index == self._word_size:
                gold_action.set(CODE.POP_ROOT, vocab.ROOT)
            else:
                gold_action.set(CODE.SHIFT, -1)
        elif self._stack_size > 1:  # arc
            top0 = self._stack[self._stack_size - 1]
            top1 = self._stack[self._stack_size - 2]
            assert top0 < self._word_size and top1 < self._word_size
            if top0 == self._inst.heads[top1]:  # top1 <- top0
                gold_action.set(CODE.ARC_LEFT,
                                vocab._rel2id[self._inst.rels[top1]])
            elif top1 == self._inst.heads[top0]:  # top1 -> top0
                # top0 must stay on the stack while one of its right
                # children is still waiting in the buffer.
                have_right_child = any(
                    self._inst.heads[idx] == top0
                    for idx in range(self._next_index, self._word_size))
                if have_right_child:
                    gold_action.set(CODE.SHIFT, -1)
                else:
                    gold_action.set(CODE.ARC_RIGHT,
                                    vocab._rel2id[self._inst.rels[top0]])
            else:  # can not arc
                gold_action.set(CODE.SHIFT, -1)
        return gold_action

    def get_result(self, vocab):
        """Return the parse as a list of ``Dependency`` items.

        The first item is the artificial root entry; word ``idx`` becomes
        entry ``idx + 1`` and heads are shifted by one accordingly.
        Asserts that every word has been attached.
        """
        result = []
        result.append(
            Dependency(0, vocab._root_form, vocab._root, 0, vocab._root))
        for idx in range(0, self._word_size):
            assert self._have_parent[idx] == 1
            relation = vocab.id2rel(self._rel[idx])
            head = self._head[idx]
            word = self._inst.words[idx]
            tag = self._inst.tags[idx]
            result.append(Dependency(idx + 1, word, tag, head + 1, relation))
        return result

    def prepare_atom_feat(self, encoder_output, offset, bucket, vocab):
        """Refresh ``_atom_feat`` and build ``self.hidden_state``.

        ``s0``/``s1``/``s2`` are the three topmost stack items and ``q0``
        is the next buffer word; each is ``-1`` when absent.  The encoder
        rows of the four positions (``bucket`` for absent ones) are
        concatenated along dimension 1.

        Args:
            encoder_output: indexed as ``encoder_output[offset][position]``.
            offset: index of this sentence inside ``encoder_output``.
            bucket: placeholder used for absent positions; presumably the
                same shape as one unsqueezed encoder row -- TODO confirm.
            vocab: unused here.
        """
        feat = self._atom_feat
        depth = self._stack_size

        feat.s0 = self._stack[depth - 1] if depth > 0 else -1
        feat.s1 = self._stack[depth - 2] if depth > 1 else -1
        feat.s2 = self._stack[depth - 3] if depth > 2 else -1
        # q0 has to be reset explicitly: a state that is reused after
        # clear() would otherwise keep the q0 of an earlier configuration
        # once the buffer is exhausted.
        if 0 <= self._next_index < self._word_size:
            feat.q0 = self._next_index
        else:
            feat.q0 = -1

        def hidden_of(position):
            if position == -1:
                return bucket
            return encoder_output[offset][position].unsqueeze(0)

        self.hidden_state = torch.cat(
            (hidden_of(feat.s0), hidden_of(feat.s1), hidden_of(feat.s2),
             hidden_of(feat.q0)), 1)
    def parseTree(self, tree_str):
        """Parse a bracketed discourse tree into EDUs, actions and subtrees.

        ``tree_str`` is a sequence of tokens separated by single spaces.
        Every node opens with ``( <relation> <action>`` and closes with
        ``)``.  A leaf (relation ``leaf``, action ``t``) additionally
        carries ``<start> <end>`` indices right after the action token.
        The action tokens ``l``, ``r`` and ``c`` stand for NS, SN and NN
        reductions respectively.

        Side effects: appends to ``self.EDUs``, ``self.gold_actions`` and
        ``self.result.subtrees``, and fills ``words`` / ``tags`` of every
        EDU from ``self.total_words`` / ``self.total_tags`` (entries whose
        tag equals ``nullkey`` are skipped).

        Raises:
            ValueError: if a token that should open or close a node is
                neither ``(`` nor ``)``.  Without this check the scan
                position never advances and the loop does not terminate.
            AssertionError: if one of the structural consistency checks
                fails.
        """
        buffer = tree_str.strip().split(" ")
        buffer_size = len(buffer)
        step = 0
        subtree_stack = []  # [first_edu, last_edu] index pairs
        op_stack = []
        relation_stack = []
        action_stack = []
        while step < buffer_size:
            token = buffer[step]
            if token == "(":
                op_stack.append(token)
                relation_stack.append(buffer[step + 1])
                action_stack.append(buffer[step + 2])
                if buffer[step + 1] == 'leaf' and buffer[step + 2] == 't':
                    # A leaf carries its span; skip the two extra tokens.
                    start = int(buffer[step + 3])
                    end = int(buffer[step + 4])
                    step += 2
                step += 3
            elif token == ")":
                action = action_stack[-1]
                if action == 't':
                    # The EDU takes the type of the first sentence span
                    # that fully contains it.
                    for sent_type in self.sent_types:
                        assert len(sent_type) == 3
                        if start >= sent_type[0] and end <= sent_type[1]:
                            e = EDU(start, end, sent_type[2])
                            edu_start = len(self.EDUs)
                            edu_end = len(self.EDUs)
                            subtree_stack.append([edu_start, edu_end])
                            self.EDUs.append(e)
                            assert relation_stack[-1] == "leaf"
                            ac = Action(CODE.SHIFT, -1, -1, relation_stack[-1])
                            self.gold_actions.append(ac)
                            break
                elif action == 'l' or action == 'r' or action == 'c':
                    if action == 'l':
                        nuclear = NUCLEAR.NS
                    elif action == 'r':
                        nuclear = NUCLEAR.SN
                    else:
                        nuclear = NUCLEAR.NN
                    ac = Action(CODE.REDUCE, nuclear, -1, relation_stack[-1])
                    self.gold_actions.append(ac)

                    assert len(subtree_stack) >= 2
                    l_index = subtree_stack[-2]
                    r_index = subtree_stack[-1]
                    # The two spans being merged must be adjacent.
                    assert l_index[1] + 1 == r_index[0]
                    left_subtree = SubTree(nullkey, nullkey, l_index[0],
                                           l_index[1])
                    right_subtree = SubTree(nullkey, nullkey, r_index[0],
                                            r_index[1])

                    if action == "l":  # NS
                        left_subtree.nuclear = nuclear_str
                        left_subtree.relation = span_str
                        right_subtree.nuclear = satellite_str
                        right_subtree.relation = ac.label_str
                    elif action == "r":  # SN
                        left_subtree.nuclear = satellite_str
                        left_subtree.relation = ac.label_str
                        right_subtree.nuclear = nuclear_str
                        right_subtree.relation = span_str
                    else:  # NN
                        left_subtree.nuclear = nuclear_str
                        left_subtree.relation = ac.label_str
                        right_subtree.nuclear = nuclear_str
                        right_subtree.relation = ac.label_str
                    self.result.subtrees.append(left_subtree)
                    self.result.subtrees.append(right_subtree)
                    # The left entry now stands for the merged span.
                    l_index[1] = r_index[1]
                    subtree_stack.pop()

                relation_stack.pop()
                op_stack.pop()
                action_stack.pop()

                step += 1
            else:
                raise ValueError(
                    "unexpected token %r at position %d in tree string"
                    % (token, step))
        ac = Action(CODE.POP_ROOT)
        self.gold_actions.append(ac)
        assert len(subtree_stack) == 1
        root = subtree_stack[0]
        assert root[0] == 0 and root[1] == len(self.EDUs) - 1
        subtree_stack.pop()  # pop root

        #### check stack, all stack empty
        assert op_stack == [] and relation_stack == [] and action_stack == [] and subtree_stack == []
        #### check edu index
        for idx, edu in enumerate(self.EDUs):
            assert edu.start >= 0 and edu.end < len(self.total_words)
            assert edu.start <= edu.end
            if idx < len(self.EDUs) - 1:
                assert edu.end + 1 == self.EDUs[idx + 1].start
        #### initialize edu word and tag
        word_total = 0
        for edu in self.EDUs:
            for idx in range(edu.start, edu.end + 1):
                if self.total_tags[idx] != nullkey:
                    edu.words.append(self.total_words[idx])
                    edu.tags.append(self.total_tags[idx])
            word_total += len(edu.words)
        assert word_total == len(self.words)
        #### check subtree
        for subtree in self.result.subtrees:
            assert subtree.relation != nullkey and subtree.nuclear != nullkey
Exemple #10
0
class State:
    """Parser configuration for the two-step arc / label transition system.

    An ARC_LEFT or ARC_RIGHT action only records the chosen direction; the
    following ARC_LABEL action performs the attachment and the reduction.
    Buffers hold ``max_length`` entries; -3 / -1 are initial fill values
    and -2 is the end marker written by :meth:`done_mark`.
    """

    def __init__(self):
        # Cursors and flags.
        self._stack_size = 0
        self._next_index = 0
        self._word_size = 0
        self._is_start = True
        self._is_gold = True
        self._inst = None
        # Pre-sized buffers.
        self._stack = [-3] * max_length
        self._rel = [-3] * max_length
        self._head = [-3] * max_length
        self._have_parent = [-1] * max_length
        self._atom_feat = AtomFeat()
        self._pre_action = Action(CODE.NO_ACTION, -1)

    def ready(self, sentence, vocab):
        """Attach a sentence (wrapped in an ``Instance``) to this state."""
        self._inst = Instance(sentence, vocab)
        self._word_size = len(self._inst.words)

    def clear(self):
        """Reset counters, flags and the previous action."""
        self._next_index = self._stack_size = self._word_size = 0
        self._is_gold = self._is_start = True
        self._pre_action = Action(CODE.NO_ACTION, -1)
        self.done_mark()

    def done_mark(self):
        """Write the -2 end marker just past the used part of each buffer."""
        self._stack[self._stack_size] = -2
        for column in (self._head, self._rel, self._have_parent):
            column[self._next_index] = -2

    def allow_shift(self):
        """True while at least one word is left in the buffer."""
        return self._next_index < self._word_size

    def allow_arc_left(self):
        """True if the stack holds at least two items."""
        return self._stack_size > 1

    def allow_arc_right(self):
        """True if the stack holds at least two items."""
        return self._stack_size > 1

    def allow_pop_root(self):
        """True if one item is on the stack and the buffer is exhausted."""
        return (self._stack_size == 1
                and self._next_index == self._word_size)

    def allow_arc_label(self):
        """True if the previous action chose an arc direction."""
        previous = self._pre_action
        return bool(previous.is_arc_left() or previous.is_arc_right())

    def shift(self, next_state):
        """Push the next buffer word onto the stack of ``next_state``."""
        assert self._next_index < self._word_size
        next_state._next_index = self._next_index + 1
        next_state._stack_size = self._stack_size + 1
        self.copy_state(next_state)
        top_slot = next_state._stack_size - 1
        next_state._stack[top_slot] = self._next_index
        next_state._have_parent[self._next_index] = 0
        next_state.done_mark()
        next_state._pre_action.set(CODE.SHIFT, -1)

    def arc_left(self, next_state):
        """Record a left-arc decision; stack and buffer stay unchanged."""
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size
        self.copy_state(next_state)
        next_state.done_mark()
        next_state._pre_action.set(CODE.ARC_LEFT, -1)

    def arc_right(self, next_state):
        """Record a right-arc decision; stack and buffer stay unchanged."""
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size
        self.copy_state(next_state)
        next_state.done_mark()
        next_state._pre_action.set(CODE.ARC_RIGHT, -1)

    def arc_label(self, next_state, dep):
        """Perform the attachment chosen by the previous arc action.

        After a left arc the second stack item becomes a dependent of the
        top item; otherwise the top item becomes a dependent of the second
        one.  The dependent leaves the stack and gets relation ``dep``.
        """
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size - 1
        self.copy_state(next_state)
        s0 = self._stack[self._stack_size - 1]
        s1 = self._stack[self._stack_size - 2]
        if self._pre_action.is_arc_left():
            # The head (s0) has to be moved down into the new top slot.
            next_state._stack[next_state._stack_size - 1] = s0
            child, parent = s1, s0
        else:
            child, parent = s0, s1
        next_state._head[child] = parent
        next_state._have_parent[child] = 1
        next_state._rel[child] = dep
        next_state.done_mark()
        next_state._pre_action.set(CODE.ARC_LABEL, dep)

    def pop_root(self, next_state, dep):
        """Pop the last stack item and mark it as root (head ``-1``)."""
        assert self._stack_size == 1 and self._next_index == self._word_size
        next_state._next_index = self._word_size
        next_state._stack_size = 0
        self.copy_state(next_state)
        root_word = self._stack[self._stack_size - 1]
        next_state._head[root_word] = -1
        next_state._have_parent[root_word] = 1
        next_state._rel[root_word] = dep
        next_state.done_mark()
        next_state._pre_action.set(CODE.POP_ROOT, dep)

    def move(self, next_state, action):
        """Apply ``action``, writing the successor into ``next_state``.

        ``next_state`` is always flagged as neither start nor gold.  An
        action of an unknown kind only prints a message.
        """
        next_state._is_start = False
        next_state._is_gold = False
        if action.is_shift():
            self.shift(next_state)
            return
        if action.is_arc_left():
            self.arc_left(next_state)
            return
        if action.is_arc_right():
            self.arc_right(next_state)
            return
        if action.is_arc_label():
            self.arc_label(next_state, action.label)
            return
        if action.is_finish():
            self.pop_root(next_state, action.label)
            return
        print(" error state ")

    def get_candidate_actions(self, vocab):
        """Return a boolean array that is ``True`` for forbidden action ids.

        Right after an arc decision only the ARC_LABEL actions are
        permitted.  Otherwise the permitted set is the union of the masks
        of all action types whose precondition holds.
        """
        permitted = np.array([False] * vocab.ac_size)

        if self.allow_arc_label():
            return ~(permitted | vocab.mask_arc_label)

        # (precondition, name of the vocab mask) in evaluation order.
        rules = (
            (self.allow_arc_left, "mask_arc_left"),
            (self.allow_arc_right, "mask_arc_right"),
            (self.is_end, "mask_no_action"),
            (self.allow_shift, "mask_shift"),
            (self.allow_pop_root, "mask_pop_root"),
        )
        for holds, mask_name in rules:
            if holds():
                permitted = permitted | getattr(vocab, mask_name)
        return ~permitted

    def copy_state(self, next_state):
        """Copy the used prefix of every buffer into ``next_state``."""
        used_stack = slice(0, self._stack_size)
        used_words = slice(0, self._next_index)
        next_state._inst = self._inst
        next_state._word_size = self._word_size
        next_state._stack[used_stack] = self._stack[used_stack]
        next_state._rel[used_words] = self._rel[used_words]
        next_state._head[used_words] = self._head[used_words]
        next_state._have_parent[used_words] = self._have_parent[used_words]

    def is_end(self):
        """True if the previous action was the finishing one."""
        return bool(self._pre_action.is_finish())

    def get_gold_action(self, vocab):
        """Return the gold (oracle) action for this state.

        Empty stack -> SHIFT.  One item -> POP_ROOT when the buffer is
        exhausted, else SHIFT.  Directly after an arc decision -> the
        ARC_LABEL carrying the gold relation of the dependent.  Otherwise,
        with two or more items, an unlabelled ARC_LEFT / ARC_RIGHT or
        SHIFT is chosen from the gold heads.
        """
        gold = Action(CODE.NO_ACTION, -1)
        depth = self._stack_size

        if depth == 0:
            gold.set(CODE.SHIFT, -1)
            return gold

        if depth == 1:
            if self._next_index == self._word_size:
                gold.set(CODE.POP_ROOT, vocab.ROOT)
            else:
                gold.set(CODE.SHIFT, -1)
            return gold

        previous = self._pre_action
        if previous.is_arc_left() or previous.is_arc_right():  # arc label
            assert self._stack_size > 1
            s0 = self._stack[self._stack_size - 1]
            s1 = self._stack[self._stack_size - 2]
            if previous.is_arc_left():
                gold.set(CODE.ARC_LABEL, vocab._rel2id[self._inst.rels[s1]])
            elif previous.is_arc_right():
                gold.set(CODE.ARC_LABEL, vocab._rel2id[self._inst.rels[s0]])
            return gold

        if not depth > 1:
            return gold

        # arc direction
        s0 = self._stack[depth - 1]
        s1 = self._stack[depth - 2]
        assert s0 < self._word_size and s1 < self._word_size
        if s0 == self._inst.heads[s1]:  # s1 <- s0
            gold.set(CODE.ARC_LEFT, -1)
        elif s1 == self._inst.heads[s0]:  # s1 -> s0
            # s0 must stay on the stack while one of its right children
            # is still waiting in the buffer.
            child_pending = any(
                self._inst.heads[idx] == s0
                for idx in range(self._next_index, self._word_size))
            if child_pending:
                gold.set(CODE.SHIFT, -1)
            else:
                gold.set(CODE.ARC_RIGHT, -1)
        else:  # no arc possible between the two top items
            gold.set(CODE.SHIFT, -1)
        return gold

    def get_result(self, vocab):
        """Return the parse as a list of ``Dependency`` items.

        The first item is the artificial root entry; word ``idx`` becomes
        entry ``idx + 1`` and heads are shifted by one accordingly.
        Asserts that every word has been attached.
        """
        dependencies = [
            Dependency(0, vocab._root_form, vocab._root, 0, vocab._root)
        ]
        for idx in range(0, self._word_size):
            assert self._have_parent[idx] == 1
            relation = vocab.id2rel(self._rel[idx])
            head = self._head[idx]
            word = self._inst.words[idx]
            tag = self._inst.tags[idx]
            dependencies.append(
                Dependency(idx + 1, word, tag, head + 1, relation))
        return dependencies

    def prepare_index(self):
        """Refresh ``_atom_feat`` and return ``_atom_feat.index()``.

        ``s0``/``s1``/``s2`` are the three topmost stack items and ``q0``
        is the next buffer word; an absent position is encoded as
        ``_word_size``.  ``arc`` tells whether the previous action chose
        an arc direction.
        """
        feat = self._atom_feat
        depth = self._stack_size
        absent = self._word_size

        feat.s0 = self._stack[depth - 1] if depth > 0 else absent
        feat.s1 = self._stack[depth - 2] if depth > 1 else absent
        feat.s2 = self._stack[depth - 3] if depth > 2 else absent
        if self._next_index >= 0 and self._next_index < self._word_size:
            feat.q0 = self._next_index
        else:
            feat.q0 = absent

        if self._pre_action.is_arc_left() or self._pre_action.is_arc_right():
            feat.arc = True
        else:
            feat.arc = False

        return feat.index()
Exemple #11
0
class State:
    """Configuration of a shift-reduce parser over EDUs.

    The stack holds ``Node`` objects (``max_length`` of them are created
    up front); each valid node covers the EDU range
    ``edu_start``..``edu_end``.  A transition writes the successor
    configuration into ``next_state`` and links it back through
    ``_pre_state``, which :meth:`get_result` later walks.
    """

    def __init__(self):
        # Every stack slot gets its own Node instance up front.
        self._stack = [Node() for _ in range(max_length)]
        self._stack_size = 0
        self._edu_size = 0
        self._next_index = 0
        self._word_size = 0
        self._is_start = True
        self._is_gold = True
        self._inst = None
        self._pre_state = None
        self._atom_feat = AtomFeat()
        self._pre_action = Action(CODE.NO_ACTION)

    def ready(self, doc):
        """Attach a document; its number of EDUs becomes ``_edu_size``."""
        self._inst = doc
        self._edu_size = len(self._inst.EDUs)

    def clear(self):
        """Reset cursors, flags, history and the attached document."""
        self._next_index = self._stack_size = 0
        self._inst = None
        self._is_gold = self._is_start = True
        self._pre_state = None
        self._pre_action = Action(CODE.NO_ACTION)
        self.done_mark()

    def done_mark(self):
        """Clear the node just above the used part of the stack."""
        first_free = self._stack[self._stack_size]
        first_free.clear()

    def allow_shift(self):
        """True while at least one EDU is left in the buffer."""
        return not self._next_index >= self._edu_size

    def allow_pop_root(self):
        """True if one node is on the stack and the buffer is exhausted."""
        return (self._stack_size == 1
                and self._next_index == self._edu_size)

    def allow_reduce(self):
        """True if the stack holds at least two nodes."""
        return self._stack_size >= 2

    def shift(self, next_state):
        """Push the next EDU as a single-EDU node onto ``next_state``."""
        assert self._next_index < self._edu_size
        next_state._stack_size = self._stack_size + 1
        next_state._next_index = self._next_index + 1
        self.copy_state(next_state)
        pushed = next_state._stack[next_state._stack_size - 1]
        pushed.clear()
        pushed.is_validate = True
        pushed.edu_start = self._next_index
        pushed.edu_end = self._next_index
        next_state.done_mark()
        next_state._pre_action.set(CODE.SHIFT)

    def reduce(self, next_state, nuclear, label):
        """Merge the two topmost nodes of ``next_state`` into one.

        The lower node is extended to cover both spans and receives
        ``nuclear`` and ``label``; the upper node is cleared.
        """
        next_state._stack_size = self._stack_size - 1
        next_state._next_index = self._next_index
        self.copy_state(next_state)
        upper = next_state._stack[self._stack_size - 1]
        lower = next_state._stack[self._stack_size - 2]
        assert upper.is_validate == True and lower.is_validate == True
        # The two spans being merged must be adjacent.
        assert upper.edu_start == lower.edu_end + 1
        lower.edu_end = upper.edu_end
        lower.nuclear = nuclear
        lower.label = label
        upper.clear()
        next_state.done_mark()
        next_state._pre_action.set(CODE.REDUCE, nuclear=nuclear, label=label)

    def pop_root(self, next_state):
        """Pop the last node, which must span the whole document."""
        assert  self._stack_size == 1 and self._next_index == self._edu_size
        next_state._next_index = self._edu_size
        next_state._stack_size = 0
        self.copy_state(next_state)
        root = next_state._stack[self._stack_size - 1]
        assert root.is_validate == True
        assert root.edu_start == 0 and root.edu_end + 1 == len(self._inst.EDUs)
        root.clear()
        next_state.done_mark()
        next_state._pre_action.set(CODE.POP_ROOT)

    def move(self, next_state, action):
        """Apply ``action``, writing the successor into ``next_state``.

        ``next_state`` is always flagged as neither start nor gold.  An
        action of an unknown kind only prints a message.
        """
        next_state._is_start = False
        next_state._is_gold = False
        if action.is_shift():
            self.shift(next_state)
            return
        if action.is_reduce():
            self.reduce(next_state, action.nuclear, action.label)
            return
        if action.is_finish():
            self.pop_root(next_state)
            return
        print(" error state ")

    def get_candidate_actions(self, vocab):
        """Return a boolean array that is ``True`` for forbidden action ids.

        The permitted set is the union of the vocab masks of all action
        types whose precondition holds in this state.
        """
        permitted = np.array([False]*vocab.ac_size)
        # (precondition, name of the vocab mask) in evaluation order.
        rules = (
            (self.allow_reduce, "mask_reduce"),
            (self.is_end, "mask_no_action"),
            (self.allow_shift, "mask_shift"),
            (self.allow_pop_root, "mask_pop_root"),
        )
        for holds, mask_name in rules:
            if holds():
                permitted = permitted | getattr(vocab, mask_name)
        return ~permitted

    def copy_state(self, next_state):
        """Copy the used stack prefix into ``next_state`` and link it back.

        The nodes are deep-copied because transitions modify them in place.
        """
        used = slice(0, self._stack_size)
        next_state._stack[used] = deepcopy(self._stack[used])
        next_state._edu_size = self._edu_size
        next_state._inst = self._inst
        next_state._pre_state = self

    def is_end(self):
        """True if the previous action was the finishing one."""
        return bool(self._pre_action.is_finish())

    def get_result(self, vocab):
        """Rebuild the predicted subtrees by walking back through history.

        Starting from this state, the chain of ``_pre_state`` links is
        followed until the predecessor is the start state.  Every REDUCE on
        the way contributes a left and a right ``SubTree``; they are
        inserted at the front so that the final list is in action order.
        """
        result = Result()
        current = self
        while not current._pre_state._is_start:
            taken = current._pre_action
            before = current._pre_state
            if taken.is_reduce():
                assert before._stack_size >= 2
                right_node = before._stack[before._stack_size - 1]
                left_node = before._stack[before._stack_size - 2]

                left_subtree = SubTree()
                left_subtree.edu_start = left_node.edu_start
                left_subtree.edu_end = left_node.edu_end

                right_subtree = SubTree()
                right_subtree.edu_start = right_node.edu_start
                right_subtree.edu_end = right_node.edu_end

                # A satellite carries the relation; its nucleus gets
                # span_str.  With two nuclei both carry the relation.
                if taken.nuclear == NUCLEAR.NN:
                    relation = vocab._id2rel[taken.label]
                    left_subtree.nuclear = nuclear_str
                    left_subtree.relation = relation
                    right_subtree.nuclear = nuclear_str
                    right_subtree.relation = vocab._id2rel[taken.label]
                elif taken.nuclear == NUCLEAR.SN:
                    left_subtree.nuclear = satellite_str
                    left_subtree.relation = vocab._id2rel[taken.label]
                    right_subtree.nuclear = nuclear_str
                    right_subtree.relation = span_str
                elif taken.nuclear == NUCLEAR.NS:
                    left_subtree.nuclear = nuclear_str
                    left_subtree.relation = span_str
                    right_subtree.nuclear = satellite_str
                    right_subtree.relation = vocab._id2rel[taken.label]

                result.subtrees.insert(0, right_subtree)
                result.subtrees.insert(0, left_subtree)
            current = current._pre_state
        return result

    def prepare_index(self):
        """Refresh ``_atom_feat`` and return it.

        ``s0``/``s1``/``s2`` become the three topmost stack nodes and
        ``q0`` the index of the next buffer EDU; absent ones are ``None``.
        """
        feat = self._atom_feat
        depth = self._stack_size

        feat.s0 = self._stack[depth - 1] if depth > 0 else None
        feat.s1 = self._stack[depth - 2] if depth > 1 else None
        feat.s2 = self._stack[depth - 3] if depth > 2 else None
        if self._next_index >= 0 and self._next_index < self._edu_size:
            feat.q0 = self._next_index
        else:
            feat.q0 = None

        return self._atom_feat