def get_reduce_candidate(self, error_cstate, gold_tree, candidate_actions):
    """Append the REDUCE action(s) with minimal nuclear/label loss.

    Every REDUCE over nuclearity 'NN', 'NS', 'SN' and every label in
    ``self.action_label_alpha.alphas`` is considered, except actions whose
    string ``self.gold_action_alpha.word2id`` maps to the 'PAD' id
    (presumably: actions absent from the gold action vocabulary -- confirm
    against ``word2id``).

    As soon as an action with loss 0 is found, only that action is appended
    and the list is returned. Otherwise every action tying for the minimal
    loss is appended, in enumeration order.

    Args:
        error_cstate: current parser state; must hold at least two stack nodes.
        gold_tree: gold structure passed through to ``self.nuclear_label_loss``.
        candidate_actions: list that is extended in place.

    Returns:
        ``candidate_actions`` (the same list object).

    Raises:
        AssertionError: if the stack has fewer than two nodes, or if no
            REDUCE action survives the PAD filter.
    """
    assert (error_cstate.stack_size >= 2)
    # Loop-invariant: look the PAD id up once instead of per candidate.
    pad_id = self.gold_action_alpha.alpha2id['PAD']
    scored_actions = []  # (CAction, loss) pairs
    for nuclear in ['NN', 'NS', 'SN']:
        for label in self.action_label_alpha.alphas:
            ac = CAction(CAction.REDUCE, nuclear, label)
            if self.gold_action_alpha.word2id(ac.get_str()) == pad_id:
                continue
            loss = self.nuclear_label_loss(ac, error_cstate, gold_tree)
            if loss == 0:
                # A zero-loss reduce cannot be beaten; stop searching.
                candidate_actions.append(ac)
                return candidate_actions
            scored_actions.append((ac, loss))
    assert (len(scored_actions) > 0)
    min_loss = min(loss for _, loss in scored_actions)
    candidate_actions.extend(
        ac for ac, loss in scored_actions if loss == min_loss)
    return candidate_actions
def clear(self):
    """Reset this state's counters, history link, last action and features.

    ``self.stack`` is not touched; only the bookkeeping attributes are
    reinitialised.
    """
    # stack depth, number of EDUs, and index of the next EDU to shift
    self.stack_size = self.edu_size = self.next_index = 0
    # no predecessor yet, and this is a start state
    self.pre_state, self.is_start = None, True
    self.pre_action = CAction('', '', '')  # empty action placeholder
    self.atom_feat = AtomFeat()  # fresh feature holder
def __init__(self):
    """Create an empty state with ``MAX_LENGTH`` preallocated stack nodes."""
    self.stack = [CNode() for _ in range(MAX_LENGTH)]  # list of CNode slots
    # stack depth, number of EDUs, and index of the next EDU to shift
    self.stack_size = self.edu_size = self.next_index = 0
    # no predecessor yet, and this is a start state
    self.pre_state, self.is_start = None, True
    self.pre_action = CAction('', '', '')  # empty action placeholder
    self.atom_feat = AtomFeat()  # fresh feature holder
def get_action(self, id_selected_action):
    """Decode a gold-action vocabulary id into a ``CAction``.

    The vocabulary word for ``id_selected_action`` is split on '_' into
    ``<TYPE>_<nuclear>_<label>``, where ``<TYPE>`` is one of SHIFT, REDUCE,
    POPROOT or NOACTION and is translated to the short codes 'SH', 'RD',
    'PR' and '' respectively.

    Raises:
        KeyError: if ``<TYPE>`` is not one of the four known names.
        IndexError: if the word has fewer than three '_'-separated fields.
    """
    mapper = {
        'SHIFT': 'SH',
        'REDUCE': 'RD',
        'POPROOT': 'PR',
        'NOACTION': ''
    }
    # maxsplit=2 keeps any '_' inside the label; a plain split('_') would
    # silently keep only the first fragment of such a label.
    fields = self.gold_action_alpha.id2word(id_selected_action).split('_', 2)
    return CAction(mapper[fields[0]], fields[1], fields[2])
def get_oracle(self, error_cstate, gold_tree):
    """Pick one oracle action for ``error_cstate``.

    Candidates are built as follows:

    * fewer than two stack nodes: POP_ROOT when every EDU has been shifted,
      otherwise SHIFT;
    * all EDUs shifted (and at least two stack nodes): the best REDUCE
      actions from ``self.get_reduce_candidate``;
    * otherwise ``self.shift_loss`` is compared with ``self.reduce_loss``:
      SHIFT if strictly cheaper, else the best REDUCE actions, plus SHIFT
      when the two losses tie.

    One candidate is then drawn uniformly at random.

    Returns:
        The chosen ``CAction``.
    """
    candidate_actions = []
    ac = CAction('', '', '')
    buffer_empty = error_cstate.next_index == error_cstate.edu_size
    if error_cstate.stack_size < 2:
        if buffer_empty:
            ac.set(CAction.POP_ROOT, '', '')
        else:
            ac.set(CAction.SHIFT, '', '')
        candidate_actions.append(ac)
    elif buffer_empty:
        # Nothing left to shift: only a reduce is possible.
        ac.set(CAction.REDUCE, '', '')
    else:
        shift_loss = self.shift_loss(error_cstate, gold_tree)
        reduce_loss = self.reduce_loss(error_cstate, gold_tree)
        if shift_loss < reduce_loss:
            ac.set(CAction.SHIFT, '', '')
            candidate_actions.append(ac)
        else:
            ac.set(CAction.REDUCE, '', '')
            if shift_loss == reduce_loss:
                # On a tie, shifting is as good as the best reduce.
                candidate_actions.append(CAction(CAction.SHIFT, '', ''))
    if ac.is_reduce():
        # Expand the generic REDUCE into concrete nuclear/label actions.
        candidate_actions = self.get_reduce_candidate(
            error_cstate, gold_tree, candidate_actions)
    # random.random() (not random.choice) keeps the RNG stream unchanged so
    # seeded runs remain reproducible.
    rand_index = int(random.random() * len(candidate_actions))
    return candidate_actions[rand_index]
def parse_tree(self, string_tree, sent_types):
    """Convert a bracketed tree string into EDUs, gold actions and subtrees.

    ``string_tree`` is a sequence of tokens separated by single spaces.
    Every node opens with ``( <relation> <tag>`` and closes with ``)``.
    Leaf nodes use tag ``t`` and carry two extra tokens, the start and end
    indices of the EDU (converted with ``int``); their relation must be
    ``leaf``. Internal nodes use tag ``l`` (left child nucleus), ``r``
    (right child nucleus) or ``c`` (both children nuclei).

    Args:
        string_tree: the serialized tree.
        sent_types: sequence of indexable items ``(start, end, type, ...)``;
            an EDU takes the ``type`` of the first item whose
            ``[start, end]`` range contains the EDU's own range.

    Returns:
        ``(edus, gold_actions, result)``: the EDUs in left-to-right order;
        the gold transition sequence (one SHIFT per leaf and one REDUCE per
        internal node, in closing-bracket order, followed by POP_ROOT); and a
        ``CResult`` whose ``subtrees`` holds the left and right child of
        every REDUCE.

    Raises:
        ValueError: if a token other than ``(`` or ``)`` appears where a
            node boundary is expected.
    """
    edus = []
    gold_actions = []
    subtree_stack = []  # (first_edu, last_edu) of each completed subtree
    op_stack = []
    relation_stack = []
    action_stack = []
    result = CResult()
    start = ''
    end = ''
    # strip() first: leading/trailing whitespace (e.g. a newline read with
    # the line) would otherwise produce a token that no branch consumes, and
    # the loop would never advance.
    buffers = string_tree.strip().split(' ')
    step = 0
    while step < len(buffers):
        token = buffers[step]
        if token == '(':
            op_stack.append(token)
            relation_stack.append(buffers[step + 1])
            action_stack.append(buffers[step + 2])
            if buffers[step + 2] == 't':
                # Leaf: the next two tokens are the EDU's start/end indices.
                start = buffers[step + 3]
                end = buffers[step + 4]
                step += 2
            step += 3
        elif token == ')':
            action = action_stack[-1]
            if action == 't':
                edu = EDU(int(start), int(end))
                for span in sent_types:
                    if edu.start_index >= span[0] and edu.end_index <= span[1]:
                        edu.etype = span[2]
                        break
                edu_index = len(edus)
                subtree_stack.append((edu_index, edu_index))
                edus.append(edu)
                ac = CAction(CA_SHIFT, '', relation_stack[-1])
                assert (relation_stack[-1] == 'leaf')
                gold_actions.append(ac)
            elif action in ('l', 'r', 'c'):
                if action == 'l':
                    nuclear = CA_NS
                elif action == 'r':
                    nuclear = CA_SN
                else:
                    nuclear = CA_NN
                ac = CAction(CA_REDUCE, nuclear, relation_stack[-1])
                gold_actions.append(ac)
                subtree_size = len(subtree_stack)
                assert (subtree_size >= 2)
                right_tree_index = subtree_stack[subtree_size - 1]
                left_tree_index = subtree_stack[subtree_size - 2]
                right_tree = SubTree()
                right_tree.edu_start = right_tree_index[0]
                right_tree.edu_end = right_tree_index[1]
                left_tree = SubTree()
                left_tree.edu_start = left_tree_index[0]
                left_tree.edu_end = left_tree_index[1]
                if action == 'l':
                    left_tree.nuclear, right_tree.nuclear = NUCLEAR, SATELLITE
                    left_tree.relation, right_tree.relation = SPAN, ac.label
                elif action == 'r':
                    left_tree.nuclear, right_tree.nuclear = SATELLITE, NUCLEAR
                    left_tree.relation, right_tree.relation = ac.label, SPAN
                else:
                    left_tree.nuclear, right_tree.nuclear = NUCLEAR, NUCLEAR
                    left_tree.relation, right_tree.relation = ac.label, ac.label
                result.subtrees.append(left_tree)
                result.subtrees.append(right_tree)
                # The two children must cover adjacent EDU ranges.
                assert (left_tree_index[1] + 1 == right_tree_index[0])
                # Merge: the left entry now extends to the right child's end.
                subtree_stack[subtree_size - 2] = modTupByIndex(
                    left_tree_index, 1, right_tree_index[1])
                subtree_stack.pop()
            action_stack.pop()
            relation_stack.pop()
            op_stack.pop()
            step += 1
        else:
            raise ValueError('unexpected token %r at position %d in tree string'
                             % (token, step))
    gold_actions.append(CAction(CA_POP_ROOT, '', ''))
    # Every '(' must have been closed by a matching ')'.
    assert (len(op_stack) == 0 and len(relation_stack) == 0
            and len(action_stack) == 0)
    return edus, gold_actions, result
class CState(object):
    """One configuration of a shift-reduce parser.

    The first ``stack_size`` entries of the fixed-size ``stack`` (a list of
    ``CNode``) are live. ``edu_size`` is the number of EDUs in the input and
    ``next_index`` the index of the next EDU to shift. ``pre_state`` and
    ``pre_action`` link back to the state this one was derived from and the
    action that produced it. Transition methods never modify ``self``; they
    fill in a caller-supplied successor state.
    """

    def __init__(self):
        self.stack = [CNode() for _ in range(MAX_LENGTH)]  # preallocated slots
        self.clear()

    def clear(self):
        """Reset counters, history link, last action and features.

        The ``stack`` slots themselves are kept.
        """
        self.stack_size = 0  # number of live stack entries
        self.edu_size = 0  # number of EDUs in the input
        self.next_index = 0  # index of the next EDU to shift
        self.pre_state = None  # CState this one was derived from
        self.pre_action = CAction('', '', '')  # action that produced this state
        self.is_start = True
        self.atom_feat = AtomFeat()

    def ready(self, edu_size):
        """Record how many EDUs are to be parsed."""
        self.edu_size = edu_size

    def is_end(self):
        """Return True if the action that produced this state is a finish action."""
        return bool(self.pre_action.is_finish())

    def copy_state(self, cstate):
        """Deep-copy the stack and EDU count into ``cstate`` and link it to ``self``.

        ``stack_size`` and ``next_index`` are set by the calling transition.
        """
        cstate.stack = copy.deepcopy(self.stack)
        cstate.edu_size = self.edu_size
        cstate.pre_state = self

    def done_mark(self):
        """Clear the first unused stack slot (index ``stack_size``)."""
        self.stack[self.stack_size].clear()

    def shift(self, cstate):
        """Fill ``cstate`` with the result of pushing the next EDU onto the stack."""
        cstate.stack_size = self.stack_size + 1
        cstate.next_index = self.next_index + 1
        self.copy_state(cstate)
        top = cstate.stack[cstate.stack_size - 1]
        top.clear()
        top.is_validate = True
        top.edu_start = self.next_index
        top.edu_end = self.next_index
        cstate.pre_action.set('SH', '', '')
        cstate.done_mark()

    def reduce(self, cstate, nuclear, label):
        """Fill ``cstate`` with the result of merging the top two stack nodes.

        The second node from the top absorbs the top node's span and gets
        ``nuclear`` and ``label``; the top node is cleared.

        Raises:
            AssertionError: if the two nodes do not cover adjacent EDU
                ranges or either is not marked valid.
        """
        cstate.stack_size = self.stack_size - 1
        cstate.next_index = self.next_index
        self.copy_state(cstate)
        top0 = cstate.stack[self.stack_size - 1]
        top1 = cstate.stack[self.stack_size - 2]
        # Fail loudly on a broken invariant instead of pausing in a debugger.
        assert top0.edu_start == top1.edu_end + 1, (
            'reduce: non-adjacent spans (%r, %r) and (%r, %r)'
            % (top1.edu_start, top1.edu_end, top0.edu_start, top0.edu_end))
        assert top0.is_validate and top1.is_validate, (
            'reduce: stack node not marked valid')
        # top0/top1 are the objects stored in cstate.stack, so updating them
        # in place is enough.
        top1.edu_end = top0.edu_end
        top1.nuclear = nuclear
        top1.label = label
        top0.clear()
        cstate.pre_action.set('RD', nuclear, label)
        cstate.done_mark()

    def pop_root(self, cstate):
        """Fill ``cstate`` with the final state after popping the single root node."""
        assert self.stack_size == 1 and self.next_index == self.edu_size
        cstate.stack_size = 0
        cstate.next_index = self.edu_size
        self.copy_state(cstate)
        top0 = cstate.stack[self.stack_size - 1]
        assert (top0.edu_start == 0)
        assert (top0.is_validate)
        top0.clear()  # top0 is the object stored in cstate.stack
        cstate.pre_action.set('PR', '', '')
        cstate.done_mark()

    def move(self, cstate, ac):
        """Apply action ``ac`` to this state, writing the successor into ``cstate``.

        Returns:
            ``cstate``.

        Raises:
            Exception: if ``ac`` is not a shift, reduce or finish action.
        """
        cstate.is_start = False
        if ac.is_shift():
            self.shift(cstate)
        elif ac.is_reduce():
            self.reduce(cstate, ac.nuclear, ac.label)
        elif ac.is_finish():
            self.pop_root(cstate)
        else:
            raise Exception('Error Action!')
        return cstate

    def get_result(self):
        """Rebuild the parse by following ``pre_state`` links backwards.

        Every REDUCE on the path contributes its left and right children as
        ``SubTree`` objects. They are prepended, so ``result.subtrees`` lists
        them in the order the reductions were applied. The walk stops at the
        state whose predecessor is the start state.
        """
        result = CResult()
        state = self
        while not state.pre_state.is_start:
            ac = state.pre_action
            st = state.pre_state
            if ac.is_reduce():
                assert (st.stack_size >= 2)
                right_node = st.stack[st.stack_size - 1]
                left_node = st.stack[st.stack_size - 2]
                left_subtree = SubTree()
                right_subtree = SubTree()
                left_subtree.edu_start = left_node.edu_start
                left_subtree.edu_end = left_node.edu_end
                right_subtree.edu_start = right_node.edu_start
                right_subtree.edu_end = right_node.edu_end
                if ac.nuclear == 'NN':
                    left_subtree.nuclear, right_subtree.nuclear = NUCLEAR, NUCLEAR
                    left_subtree.relation, right_subtree.relation = ac.label, ac.label
                elif ac.nuclear == 'SN':
                    left_subtree.nuclear, right_subtree.nuclear = SATELLITE, NUCLEAR
                    left_subtree.relation, right_subtree.relation = ac.label, SPAN
                elif ac.nuclear == 'NS':
                    left_subtree.nuclear, right_subtree.nuclear = NUCLEAR, SATELLITE
                    left_subtree.relation, right_subtree.relation = SPAN, ac.label
                result.subtrees.insert(0, right_subtree)
                result.subtrees.insert(0, left_subtree)
            state = state.pre_state
        return result

    def allow_shift(self):
        """Return True while unshifted EDUs remain."""
        return self.next_index != self.edu_size

    def allow_reduce(self):
        """Return True when at least two nodes are on the stack."""
        return self.stack_size >= 2

    def allow_pop_root(self):
        """Return True when every EDU is shifted and exactly one node remains."""
        return self.next_index == self.edu_size and self.stack_size == 1

    def get_candidate_actions(self, vocab):
        """Return a boolean array over the gold action vocabulary.

        The ``vocab.mask_*`` arrays of every action group permitted in this
        state are OR-ed together, and the complement is returned.
        NOTE(review): assuming each ``vocab.mask_*`` is True at its own
        actions, True in the result marks actions that are NOT allowed --
        confirm against callers.
        """
        mask = np.zeros(vocab.gold_action_alpha.size(), dtype=bool)
        if self.allow_reduce():
            mask = mask | vocab.mask_reduce
        if self.is_end():
            mask = mask | vocab.mask_no_action
        if self.allow_shift():
            mask = mask | vocab.mask_shift
        if self.allow_pop_root():
            mask = mask | vocab.mask_pop_root
        return ~mask

    def prepare_index(self):
        """Fill ``atom_feat`` with the top three stack nodes and the next EDU index.

        Missing stack positions and an out-of-range ``next_index`` become None.

        Returns:
            ``self.atom_feat``.
        """
        size = self.stack_size
        feat = self.atom_feat
        feat.s0 = self.stack[size - 1] if size > 0 else None
        feat.s1 = self.stack[size - 2] if size > 1 else None
        feat.s2 = self.stack[size - 3] if size > 2 else None
        feat.q0 = self.next_index if 0 <= self.next_index < self.edu_size else None
        return feat