예제 #1
0
    def get_reduce_candidate(self, error_cstate, gold_tree, candidate_actions):
        """Append the lowest-loss REDUCE actions to ``candidate_actions``.

        Every (nuclearity, label) pair is turned into a REDUCE action. Pairs
        whose action string maps to the 'PAD' id in ``self.gold_action_alpha``
        are skipped (presumably these are strings absent from the action
        vocabulary -- confirm against ``word2id``). The remaining actions are
        scored with ``self.nuclear_label_loss``.

        The first action with a loss of exactly 0 is appended and the method
        returns immediately. If no action has zero loss, every action tied for
        the minimum loss is appended, in enumeration order.

        Args:
            error_cstate: state being scored; needs ``stack_size >= 2``.
            gold_tree: gold reference forwarded to ``nuclear_label_loss``.
            candidate_actions: list that is extended in place.

        Returns:
            The same ``candidate_actions`` list that was passed in.

        Raises:
            AssertionError: if the stack holds fewer than two items, or if no
                (nuclearity, label) pair produced a scorable action.
        """
        assert (error_cstate.stack_size >= 2)
        # The PAD id does not change while enumerating, so look it up once
        # instead of on every inner iteration.
        pad_id = self.gold_action_alpha.alpha2id['PAD']
        scored = []  # each element is a tuple (CAction, loss)
        for nuclear in ['NN', 'NS', 'SN']:
            for label in self.action_label_alpha.alphas:
                ac = CAction(CAction.REDUCE, nuclear, label)
                if self.gold_action_alpha.word2id(ac.get_str()) == pad_id:
                    continue
                loss = self.nuclear_label_loss(ac, error_cstate, gold_tree)
                scored.append((ac, loss))
                if loss == 0:
                    # A perfect action short-circuits the search.
                    candidate_actions.append(ac)
                    return candidate_actions
        assert (len(scored) > 0)
        min_loss = min(loss for _, loss in scored)
        candidate_actions.extend(
            ac for ac, loss in scored if loss == min_loss)
        return candidate_actions
예제 #2
0
 def clear(self):
     """Reset the counters, history links and features to initial values.

     Only the attributes assigned below are reset; ``self.stack`` is not
     touched by this method.
     """
     # The three integer counters all restart from zero.
     self.stack_size = self.edu_size = self.next_index = 0
     # No predecessor state, and a blank placeholder action.
     self.pre_state = None
     self.pre_action = CAction('', '', '')
     self.is_start = True
     self.atom_feat = AtomFeat()
예제 #3
0
 def __init__(self):
     """Create a start state with a pre-allocated stack of blank nodes."""
     # Fixed-capacity stack: MAX_LENGTH independent CNode instances.
     self.stack = [CNode() for _ in range(MAX_LENGTH)]
     # The three integer counters all begin at zero.
     self.stack_size = self.edu_size = self.next_index = 0
     # No predecessor state, and a blank placeholder action.
     self.pre_state = None
     self.pre_action = CAction('', '', '')
     self.is_start = True
     self.atom_feat = AtomFeat()
예제 #4
0
 def get_action(self, id_selected_action):
     """Convert an action id into a ``CAction``.

     The id is resolved to its string through
     ``self.gold_action_alpha.id2word``. That string is read as three
     underscore-separated fields: action name, nuclearity and label. The
     action name is translated to its short code through ``mapper``.

     Args:
         id_selected_action: id understood by ``gold_action_alpha.id2word``.

     Returns:
         A new ``CAction`` built from the three fields.

     Raises:
         KeyError: if the action name is not one of the ``mapper`` keys.
         IndexError: if the string has fewer than three fields.
     """
     mapper = {
         'SHIFT': 'SH',
         'REDUCE': 'RD',
         'POPROOT': 'PR',
         'NOACTION': ''
     }
     # Split on the first two underscores only, so a label that itself
     # contains '_' is kept whole instead of being cut at its first '_'.
     fields = self.gold_action_alpha.id2word(id_selected_action).split(
         '_', 2)
     return CAction(mapper[fields[0]], fields[1], fields[2])
예제 #5
0
 def get_oracle(self, error_cstate, gold_tree):
     """Return one oracle action for ``error_cstate``, drawn at random.

     Candidate selection:
       * fewer than two stack items: POP_ROOT when ``next_index`` equals
         ``edu_size``, otherwise SHIFT;
       * ``next_index`` equals ``edu_size`` (with two or more stack items):
         REDUCE;
       * otherwise ``shift_loss`` and ``reduce_loss`` are compared; the
         cheaper move wins, and on a tie both SHIFT and REDUCE are kept.
     Whenever REDUCE is chosen, the concrete labelled REDUCE actions come
     from ``get_reduce_candidate``. The result is picked from the candidate
     list with ``random.random()``.
     """
     pool = []
     decided = CAction('', '', '')
     queue_exhausted = error_cstate.next_index == error_cstate.edu_size

     if error_cstate.stack_size < 2:
         # Nothing to reduce: the move is forced.
         forced = CAction.POP_ROOT if queue_exhausted else CAction.SHIFT
         decided.set(forced, '', '')
         pool.append(decided)
     elif queue_exhausted:
         decided.set(CAction.REDUCE, '', '')
     else:
         shift_loss = self.shift_loss(error_cstate, gold_tree)
         reduce_loss = self.reduce_loss(error_cstate, gold_tree)
         if shift_loss < reduce_loss:
             decided.set(CAction.SHIFT, '', '')
             pool.append(decided)
         elif shift_loss >= reduce_loss:
             decided.set(CAction.REDUCE, '', '')
             if shift_loss == reduce_loss:
                 # Tie: SHIFT stays available next to the REDUCE options.
                 pool.append(CAction(CAction.SHIFT, '', ''))

     if decided.is_reduce():
         pool = self.get_reduce_candidate(error_cstate, gold_tree, pool)

     # Uniform pick over the candidates.
     return pool[int(random.random() * len(pool))]
예제 #6
0
    def parse_tree(self, string_tree, sent_types):
        """Parse a bracketed tree string into EDUs, gold actions and subtrees.

        Token layout, as consumed by the loop below (tokens are separated by
        single spaces):
          * ``( <relation> t <start> <end> )`` is a leaf; its relation must
            be 'leaf'. It yields one EDU and one shift action.
          * ``( <relation> <l|r|c> ... )`` is an internal node. When its
            closing bracket is reached, the two most recently completed
            subtrees are merged and one reduce action is emitted
            (l -> CA_NS, r -> CA_SN, c -> CA_NN).

        Args:
            string_tree: the serialized tree.
            sent_types: sequence of indexable items. For each new EDU the
                first item whose ``[0]..[1]`` range encloses the EDU's
                ``start_index..end_index`` supplies ``edu.etype`` from
                ``item[2]``.

        Returns:
            A tuple ``(edus, gold_actions, result)``: the EDU list, the action
            sequence terminated by a pop-root action, and a ``CResult`` whose
            ``subtrees`` holds a (left, right) pair per reduce.

        Raises:
            ValueError: if a token where a bracket is expected is neither
                '(' nor ')'. Previously this case looped forever.
            AssertionError: if a leaf relation is not 'leaf', a reduce has
                fewer than two subtrees, the merged subtrees are not
                adjacent, or brackets are left open at the end.
            IndexError: if the string ends in the middle of a node header.
        """
        # return value
        edus = []
        gold_actions = []

        # (first_edu, last_edu) index pairs of the subtrees completed so far
        subtree_stack = []
        op_stack = []
        relation_stack = []
        action_stack = []
        result = CResult()

        step = 0
        start = ''
        end = ''

        buffers = string_tree.split(' ')
        while step < len(buffers):
            token = buffers[step]
            if token == '(':
                op_stack.append(token)
                relation_stack.append(buffers[step + 1])
                action_stack.append(buffers[step + 2])
                if buffers[step + 2] == 't':
                    # A leaf header carries two extra tokens: start and end.
                    start = buffers[step + 3]
                    end = buffers[step + 4]
                    step += 2
                step += 3
            elif token == ')':
                action = action_stack[-1]  # stack.top
                if action == 't':
                    edu = EDU(int(start), int(end))
                    for sent in sent_types:
                        if (edu.start_index >= sent[0]
                                and edu.end_index <= sent[1]):
                            edu.etype = sent[2]
                            break
                    # A leaf spans exactly one EDU: the one about to be added.
                    leaf_index = len(edus)
                    subtree_stack.append((leaf_index, leaf_index))
                    edus.append(edu)
                    ac = CAction(CA_SHIFT, '', relation_stack[-1])
                    assert (relation_stack[-1] == 'leaf')
                    gold_actions.append(ac)
                elif action in ('l', 'r', 'c'):
                    nuclear = ''
                    if action == 'l':
                        nuclear = CA_NS
                    elif action == 'r':
                        nuclear = CA_SN
                    elif action == 'c':
                        nuclear = CA_NN
                    ac = CAction(CA_REDUCE, nuclear, relation_stack[-1])
                    gold_actions.append(ac)

                    subtree_size = len(subtree_stack)
                    assert (subtree_size >= 2)

                    right_tree_index = subtree_stack[subtree_size - 1]
                    left_tree_index = subtree_stack[subtree_size - 2]

                    right_tree = SubTree()
                    right_tree.edu_start = right_tree_index[0]
                    right_tree.edu_end = right_tree_index[1]

                    left_tree = SubTree()
                    left_tree.edu_start = left_tree_index[0]
                    left_tree.edu_end = left_tree_index[1]

                    # The nucleus side gets SPAN as its relation and the
                    # satellite side gets the label; for 'c' both get the label.
                    if action == 'l':
                        left_tree.nuclear = NUCLEAR
                        right_tree.nuclear = SATELLITE
                        left_tree.relation = SPAN
                        right_tree.relation = ac.label
                    elif action == 'r':
                        left_tree.nuclear = SATELLITE
                        right_tree.nuclear = NUCLEAR
                        left_tree.relation = ac.label
                        right_tree.relation = SPAN
                    elif action == 'c':
                        left_tree.nuclear = NUCLEAR
                        right_tree.nuclear = NUCLEAR
                        left_tree.relation = ac.label
                        right_tree.relation = ac.label

                    result.subtrees.append(left_tree)
                    result.subtrees.append(right_tree)

                    # Merge: the left entry grows to cover the right one.
                    assert (left_tree_index[1] + 1 == right_tree_index[0])
                    subtree_stack[subtree_size - 2] = modTupByIndex(
                        left_tree_index, 1, right_tree_index[1])
                    subtree_stack.pop()

                action_stack.pop()
                relation_stack.pop()
                op_stack.pop()
                step += 1
            else:
                # Without this branch `step` would never advance and the loop
                # would spin forever (e.g. on the empty token produced by a
                # double or trailing space).
                raise ValueError(
                    'Unexpected token {!r} at position {} in tree string'
                    .format(token, step))

        ac = CAction(CA_POP_ROOT, '', '')
        gold_actions.append(ac)
        # Check stack
        assert (len(op_stack) == 0 and len(relation_stack) == 0
                and len(action_stack) == 0)
        return edus, gold_actions, result
예제 #7
0
class CState(object):
    """One configuration of a shift-reduce parser.

    A state owns a fixed-capacity stack of ``CNode`` objects, counters for the
    stack height and the next input index, a link to the state it was derived
    from, and the action that produced it. Transitions (``shift``, ``reduce``,
    ``pop_root``, ``move``) never modify ``self``'s stack; they write the
    successor configuration into a caller-supplied ``cstate``.
    """

    def __init__(self):
        # Fixed-capacity stack of MAX_LENGTH independent nodes.
        self.stack = [CNode() for _ in range(MAX_LENGTH)]  #list of CNode
        # All remaining attributes start from the same values clear() sets.
        self.clear()

    def clear(self):
        """Reset counters, history and features; the stack is left as is."""
        self.stack_size = 0  #int
        self.edu_size = 0  #int
        self.next_index = 0  #int
        self.pre_state = None  #CState
        self.pre_action = CAction('', '', '')  #CAction
        self.is_start = True
        self.atom_feat = AtomFeat()  #AtomFeat

    def ready(self, edu_size):
        """Record the number of input units for this state."""
        self.edu_size = edu_size

    def is_end(self):
        """Return True when the action that produced this state is a finish."""
        return bool(self.pre_action.is_finish())

    def copy_state(self, cstate):
        """Give ``cstate`` a deep copy of this stack and link it back here."""
        cstate.stack = copy.deepcopy(self.stack)
        cstate.edu_size = self.edu_size
        cstate.pre_state = self

    def done_mark(self):
        """Clear the stack slot just above the current top."""
        self.stack[self.stack_size].clear()

    def shift(self, cstate):
        """Write into ``cstate`` the result of shifting the next input unit."""
        cstate.stack_size = self.stack_size + 1
        cstate.next_index = self.next_index + 1
        self.copy_state(cstate)
        # The new top node covers exactly the unit that was shifted.
        top = cstate.stack[cstate.stack_size - 1]
        top.clear()
        top.is_validate = True
        top.edu_start = self.next_index
        top.edu_end = self.next_index

        cstate.pre_action.set('SH', '', '')
        cstate.done_mark()

    def reduce(self, cstate, nuclear, label):
        """Write into ``cstate`` the result of merging the top two nodes.

        The second-from-top node is extended to cover the top node and is
        tagged with ``nuclear`` and ``label``; the top node is cleared.

        Raises:
            AssertionError: if the two top nodes are not adjacent spans or
                either is not marked valid. The successor is left
                half-built in that case and must not be used.
        """
        cstate.stack_size = self.stack_size - 1
        cstate.next_index = self.next_index
        self.copy_state(cstate)
        top0 = cstate.stack[self.stack_size - 1]
        top1 = cstate.stack[self.stack_size - 2]
        # Fail loudly on an inconsistent stack rather than dropping into an
        # interactive debugger and then continuing with corrupt spans.
        assert (top0.edu_start == top1.edu_end + 1)
        assert (top0.is_validate and top1.is_validate)
        # top0/top1 are the nodes stored in cstate.stack (a deep copy), so the
        # edits below update the successor only.
        top1.edu_end = top0.edu_end
        top1.nuclear = nuclear
        top1.label = label
        top0.clear()

        cstate.pre_action.set('RD', nuclear, label)
        cstate.done_mark()

    def pop_root(self, cstate):
        """Write into ``cstate`` the result of popping the single root node.

        Raises:
            AssertionError: unless exactly one node is on the stack, all input
                has been consumed, and that node starts at unit 0 and is
                marked valid.
        """
        assert self.stack_size == 1 and self.next_index == self.edu_size
        cstate.stack_size = 0
        cstate.next_index = self.edu_size
        self.copy_state(cstate)
        top0 = cstate.stack[self.stack_size - 1]
        # assert(top0.edu_start == 0 and top0.edu_end + 1 == self.edu_size)
        assert (top0.edu_start == 0)
        assert (top0.is_validate)
        top0.clear()

        cstate.pre_action.set('PR', '', '')
        cstate.done_mark()

    #cstate = CState
    #ac = CAction
    def move(self, cstate, ac):
        """Apply ``ac`` to this state, filling and returning ``cstate``.

        Raises:
            Exception: if ``ac`` is not a shift, reduce or finish action.
        """
        cstate.is_start = False
        if ac.is_shift():
            self.shift(cstate)
        elif ac.is_reduce():
            self.reduce(cstate, ac.nuclear, ac.label)
        elif ac.is_finish():
            self.pop_root(cstate)
        else:
            raise Exception('Error Action!')
        return cstate

    def get_result(self):
        """Collect the subtrees of every reduce on the path to this state.

        The ``pre_state`` chain is followed backwards until a state is reached
        whose predecessor is a start state (or that has no predecessor). Each
        reduce contributes a (left, right) ``SubTree`` pair, inserted at the
        front so the final list runs from the earliest reduce to the latest.

        Returns:
            A ``CResult`` with ``subtrees`` filled in; empty when this state
            has no predecessor.
        """
        result = CResult()
        state = self
        # The None check makes a freshly created or cleared state yield an
        # empty result instead of failing on `None.is_start`.
        while state.pre_state is not None and not state.pre_state.is_start:
            ac = state.pre_action
            st = state.pre_state
            if ac.is_reduce():
                assert (st.stack_size >= 2)
                right_node = st.stack[st.stack_size - 1]
                left_node = st.stack[st.stack_size - 2]
                left_subtree = SubTree()
                right_subtree = SubTree()

                left_subtree.edu_start = left_node.edu_start
                left_subtree.edu_end = left_node.edu_end

                right_subtree.edu_start = right_node.edu_start
                right_subtree.edu_end = right_node.edu_end

                if ac.nuclear == 'NN':
                    left_subtree.nuclear = NUCLEAR
                    right_subtree.nuclear = NUCLEAR
                    left_subtree.relation = ac.label
                    right_subtree.relation = ac.label
                elif ac.nuclear == 'SN':
                    left_subtree.nuclear = SATELLITE
                    right_subtree.nuclear = NUCLEAR
                    left_subtree.relation = ac.label
                    right_subtree.relation = SPAN
                elif ac.nuclear == 'NS':
                    left_subtree.nuclear = NUCLEAR
                    right_subtree.nuclear = SATELLITE
                    left_subtree.relation = SPAN
                    right_subtree.relation = ac.label

                result.subtrees.insert(0, right_subtree)
                result.subtrees.insert(0, left_subtree)
            state = st
        return result

    def allow_shift(self):
        """Return True unless ``next_index`` has reached ``edu_size``."""
        return self.next_index != self.edu_size

    def allow_reduce(self):
        """Return True when at least two nodes are on the stack."""
        return self.stack_size >= 2

    def allow_pop_root(self):
        """Return True when all input is consumed and one node remains."""
        return self.next_index == self.edu_size and self.stack_size == 1

    def get_candidate_actions(self, vocab):
        """Return the complement of the union of the applicable vocab masks.

        The masks for reduce, no-action, shift and pop-root are OR-ed together
        for every move that is currently applicable, and the result is
        inverted. Entries that are True in the returned array are therefore
        the ones covered by none of the applicable masks (presumably the
        actions to be masked out -- confirm against the caller).
        """
        # A boolean array even when the vocabulary is empty.
        mask = np.zeros(vocab.gold_action_alpha.size(), dtype=bool)
        if self.allow_reduce():
            mask = mask | vocab.mask_reduce
        if self.is_end():
            mask = mask | vocab.mask_no_action
        if self.allow_shift():
            mask = mask | vocab.mask_shift
        if self.allow_pop_root():
            mask = mask | vocab.mask_pop_root
        return ~mask

    def prepare_index(self):
        """Fill ``atom_feat`` with the top three nodes and next input index.

        ``s0``, ``s1`` and ``s2`` are the top, second and third stack nodes,
        and ``q0`` is ``next_index``; each is None when unavailable.

        Returns:
            The updated ``self.atom_feat``.
        """
        feat = self.atom_feat
        size = self.stack_size
        feat.s0 = self.stack[size - 1] if size > 0 else None
        feat.s1 = self.stack[size - 2] if size > 1 else None
        feat.s2 = self.stack[size - 3] if size > 2 else None
        if 0 <= self.next_index < self.edu_size:
            feat.q0 = self.next_index
        else:
            feat.q0 = None
        return self.atom_feat