def get_candidate_actions(self, vocab):
    """Collect every action that may be applied to this state.

    Args:
        vocab: Vocabulary exposing ``rel_size``, ``_id2ac`` (the known
            actions) and ``ROOT`` (the label used for POP_ROOT).

    Returns:
        A list of ``Action`` objects in the order: NO_ACTION, SHIFT,
        ARC_LEFT (by label id), ARC_RIGHT (by label id), POP_ROOT.
    """
    candidates = []

    if self.is_end():
        candidates.append(Action(CODE.NO_ACTION, -1))

    if self.allow_shift():
        candidates.append(Action(CODE.SHIFT, -1))

    if self.allow_arc_left():
        # Keep only the labelled arcs that exist in the action table.
        left_arcs = (
            Action(CODE.ARC_LEFT, rel_id) for rel_id in range(vocab.rel_size)
        )
        candidates.extend(arc for arc in left_arcs if arc in vocab._id2ac)

    if self.allow_arc_right():
        right_arcs = (
            Action(CODE.ARC_RIGHT, rel_id) for rel_id in range(vocab.rel_size)
        )
        candidates.extend(arc for arc in right_arcs if arc in vocab._id2ac)

    if self.allow_pop_root():
        candidates.append(Action(CODE.POP_ROOT, vocab.ROOT))

    return candidates
def clear(self):
    """Reset this state to its initial, empty configuration.

    The position counters go back to zero, the gold/start flags are set
    again and the previous action becomes NO_ACTION.  ``done_mark`` is
    called last because it reads the freshly reset counters.
    """
    self._is_start = True
    self._is_gold = True
    self._pre_action = Action(CODE.NO_ACTION, -1)
    self._stack_size = 0
    self._next_index = 0
    self._word_size = 0
    self.done_mark()
def __init__(self):
    """Allocate fixed-size buffers and put the state in its start setup."""
    # Pre-sized per-position buffers (max_length slots each).
    self._stack = [-3] * max_length
    self._rel = [-3] * max_length
    self._head = [-3] * max_length
    self._have_parent = [-1] * max_length

    # Counters: filled stack slots, next buffer position, sentence length.
    self._stack_size = 0
    self._next_index = 0
    self._word_size = 0

    # Bookkeeping flags and attached objects.
    self._is_start = True
    self._is_gold = True
    self._inst = None
    self._atom_feat = AtomFeat()
    self._pre_action = Action(CODE.NO_ACTION, -1)
def __init__(self):
    """Allocate a stack of ``max_length`` nodes and reset all counters."""
    # One distinct Node per stack slot.
    self._stack = [Node() for _ in range(max_length)]

    self._stack_size = 0
    self._edu_size = 0
    self._next_index = 0
    self._word_size = 0

    self._is_start = True
    self._is_gold = True

    self._inst = None
    self._pre_state = None
    self._atom_feat = AtomFeat()
    self._pre_action = Action(CODE.NO_ACTION)
def create_action_table(self, all_actions):
    """Build the action tables and the per-kind boolean masks.

    NO_ACTION is appended to ``self._id2ac`` first, followed by every
    action seen in ``all_actions`` ordered by decreasing frequency.
    ``self._ac2id`` is the reverse mapping.  One boolean mask of length
    ``self.ac_size`` is then created per action kind, and the table is
    printed.

    Args:
        all_actions: Iterable of action sequences (each an iterable of
            hashable action objects).
    """
    self._id2ac.append(Action(CODE.NO_ACTION, -1))

    frequencies = Counter(
        action for sequence in all_actions for action in sequence)
    self._id2ac.extend(action for action, _ in frequencies.most_common())

    # On duplicates the later index wins, as with dict(zip(...)).
    self._ac2id = {action: idx for idx, action in enumerate(self._id2ac)}
    if len(self._ac2id) != len(self._id2ac):
        print("serious bug: actions dumplicated, please check!")
    print("action num: ", len(self._ac2id))
    print("action: ", end=' ')

    self.mask_shift = np.array([False] * self.ac_size)
    self.mask_arc_left = np.array([False] * self.ac_size)
    self.mask_arc_right = np.array([False] * self.ac_size)
    self.mask_arc_label = np.array([False] * self.ac_size)
    self.mask_pop_root = np.array([False] * self.ac_size)
    self.mask_no_action = np.array([False] * self.ac_size)

    for position, action in enumerate(self._id2ac):
        # An action may satisfy several predicates; each is checked.
        if action.is_shift():
            self.mask_shift[position] = True
        if action.is_arc_left():
            self.mask_arc_left[position] = True
        if action.is_arc_right():
            self.mask_arc_right[position] = True
        if action.is_arc_label():
            self.mask_arc_label[position] = True
        if action.is_finish():
            self.mask_pop_root[position] = True
        if action.is_none():
            self.mask_no_action[position] = True
        print(action.str(self), end=', ')
    print()
def create_action_table(self, all_actions):
    """Build ``self._id2ac`` / ``self._ac2id`` from observed actions.

    NO_ACTION is appended first; the remaining actions follow in order of
    decreasing frequency across ``all_actions``.  The resulting table is
    printed to stdout.

    Args:
        all_actions: Iterable of action sequences (each an iterable of
            hashable action objects).
    """
    self._id2ac.append(Action(CODE.NO_ACTION, -1))

    frequencies = Counter(
        action for sequence in all_actions for action in sequence)
    self._id2ac.extend(action for action, _ in frequencies.most_common())

    # On duplicates the later index wins, as with dict(zip(...)).
    self._ac2id = {action: idx for idx, action in enumerate(self._id2ac)}
    if len(self._ac2id) != len(self._id2ac):
        print("serious bug: actions dumplicated, please check!")

    print("action num: ", len(self._ac2id))
    print("action: ", end=' ')
    for action in self._id2ac:
        print(action.str(self), end=', ')
    print()
def get_gold_action(self, vocab):
    """Return the oracle action for this state given the gold tree.

    Args:
        vocab: Vocabulary providing ``ROOT`` and the ``_rel2id`` mapping.

    Returns:
        An ``Action`` set to SHIFT, ARC_LEFT, ARC_RIGHT or POP_ROOT.  It
        stays NO_ACTION only if the stack size is negative.
    """
    oracle = Action(CODE.NO_ACTION, -1)
    depth = self._stack_size

    if depth == 0:
        oracle.set(CODE.SHIFT, -1)
        return oracle

    if depth == 1:
        if self._next_index == self._word_size:
            oracle.set(CODE.POP_ROOT, vocab.ROOT)
        else:
            oracle.set(CODE.SHIFT, -1)
        return oracle

    if depth < 2:
        # Negative stack size: no branch applies, keep NO_ACTION.
        return oracle

    top0 = self._stack[depth - 1]
    top1 = self._stack[depth - 2]
    assert top0 < self._word_size and top1 < self._word_size

    gold_heads = self._inst.heads
    if top0 == gold_heads[top1]:
        # top1 <- top0
        oracle.set(CODE.ARC_LEFT, vocab._rel2id[self._inst.rels[top1]])
    elif top1 == gold_heads[top0]:
        # top1 -> top0, but only once top0 has no child left in the buffer.
        pending_child = any(
            self._inst.heads[pos] == top0
            for pos in range(self._next_index, self._word_size))
        if pending_child:
            oracle.set(CODE.SHIFT, -1)
        else:
            oracle.set(CODE.ARC_RIGHT, vocab._rel2id[self._inst.rels[top0]])
    else:
        # No arc between the two top items: shift.
        oracle.set(CODE.SHIFT, -1)
    return oracle
class State:
    """One configuration of a stack-based transition dependency parser.

    A state holds a stack of word indices, the index of the next buffer
    word, and per-word head / relation / has-parent arrays.  Transitions
    (`shift`, `arc_left`, `arc_right`, `pop_root`) never modify ``self``;
    they write the successor configuration into a caller-supplied
    ``next_state`` and link it back through ``_pre_state``.
    """

    def __init__(self):
        # Fixed-size buffers; -3 / -1 are the "never written" fill values.
        self._stack = [-3] * max_length
        self._stack_size = 0
        self._rel = [-3] * max_length
        self._head = [-3] * max_length
        self._have_parent = [-1] * max_length
        self._next_index = 0
        self._word_size = 0
        self._is_start = True
        self._is_gold = True
        self._inst = None
        self._atom_feat = AtomFeat()
        self._pre_state = None
        self._pre_action = Action(CODE.NO_ACTION, -1)

    def ready(self, sentence, vocab):
        """Attach a sentence (wrapped in an ``Instance``) to this state."""
        self._inst = Instance(sentence, vocab)
        self._word_size = len(self._inst.words)

    def clear(self):
        """Reset counters, flags and history to the initial configuration."""
        self._next_index = 0
        self._stack_size = 0
        self._word_size = 0
        self._is_gold = True
        self._is_start = True
        self._pre_action = Action(CODE.NO_ACTION, -1)
        self._pre_state = None
        self.done_mark()

    def done_mark(self):
        """Write the -2 sentinel just past the used part of each buffer."""
        self._stack[self._stack_size] = -2
        self._head[self._next_index] = -2
        self._rel[self._next_index] = -2
        self._have_parent[self._next_index] = -2

    def allow_shift(self):
        """Return True while there are buffer words left to shift."""
        return self._next_index < self._word_size

    def allow_arc_left(self):
        """Return True when the stack holds at least two items."""
        return self._stack_size > 1

    def allow_arc_right(self):
        """Return True when the stack holds at least two items."""
        return self._stack_size > 1

    def allow_pop_root(self):
        """Return True when one item is left and the buffer is empty."""
        return self._stack_size == 1 and self._next_index == self._word_size

    def shift(self, next_state):
        """Push the next buffer word onto the stack of ``next_state``."""
        assert self._next_index < self._word_size
        next_state._next_index = self._next_index + 1
        next_state._stack_size = self._stack_size + 1
        self.copy_state(next_state)
        next_state._stack[next_state._stack_size - 1] = self._next_index
        next_state._have_parent[self._next_index] = 0
        next_state.done_mark()
        next_state._pre_action.set(CODE.SHIFT, -1)

    def arc_left(self, next_state, dep):
        """Attach the second stack item to the top one with label ``dep``.

        The dependent (second item) is removed; the top item stays.
        """
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size - 1
        self.copy_state(next_state)
        top0 = self._stack[self._stack_size - 1]
        top1 = self._stack[self._stack_size - 2]
        next_state._stack[next_state._stack_size - 1] = top0
        next_state._head[top1] = top0
        next_state._have_parent[top1] = 1
        next_state._rel[top1] = dep
        next_state.done_mark()
        next_state._pre_action.set(CODE.ARC_LEFT, dep)

    def arc_right(self, next_state, dep):
        """Attach the top stack item to the second one with label ``dep``.

        The dependent (top item) is removed; the copied stack already has
        the second item in the new top slot, so no stack write is needed.
        """
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size - 1
        self.copy_state(next_state)
        top0 = self._stack[self._stack_size - 1]
        top1 = self._stack[self._stack_size - 2]
        next_state._head[top0] = top1
        next_state._have_parent[top0] = 1
        next_state.done_mark()
        next_state._rel[top0] = dep
        next_state._pre_action.set(CODE.ARC_RIGHT, dep)

    def pop_root(self, next_state, dep):
        """Pop the last stack item and mark it as root (head -1)."""
        assert self._stack_size == 1 and self._next_index == self._word_size
        next_state._next_index = self._word_size
        next_state._stack_size = 0
        self.copy_state(next_state)
        top0 = self._stack[self._stack_size - 1]
        next_state._head[top0] = -1
        next_state._have_parent[top0] = 1
        next_state._rel[top0] = dep
        next_state.done_mark()
        next_state._pre_action.set(CODE.POP_ROOT, dep)

    def move(self, next_state, action):
        """Apply ``action`` to this state, writing into ``next_state``.

        An action of an unknown kind only prints a message; the flags on
        ``next_state`` are still cleared in that case.
        """
        next_state._is_start = False
        next_state._is_gold = False
        if action.is_shift():
            self.shift(next_state)
        elif action.is_arc_left():
            self.arc_left(next_state, action.label)
        elif action.is_arc_right():
            self.arc_right(next_state, action.label)
        elif action.is_finish():
            self.pop_root(next_state, action.label)
        else:
            print(" error state ")

    def get_candidate_actions(self, vocab):
        """Return the list of actions applicable to this state.

        Labelled arcs are included only if they occur in ``vocab._id2ac``.
        """
        candidate_actions = []
        if self.is_end():
            candidate_actions.append(Action(CODE.NO_ACTION, -1))
        if self.allow_shift():
            candidate_actions.append(Action(CODE.SHIFT, -1))
        if self.allow_arc_left():
            for idx in range(vocab.rel_size):
                ac = Action(CODE.ARC_LEFT, idx)
                if ac in vocab._id2ac:
                    candidate_actions.append(ac)
        if self.allow_arc_right():
            for idx in range(vocab.rel_size):
                ac = Action(CODE.ARC_RIGHT, idx)
                if ac in vocab._id2ac:
                    candidate_actions.append(ac)
        if self.allow_pop_root():
            candidate_actions.append(Action(CODE.POP_ROOT, vocab.ROOT))
        return candidate_actions

    def copy_state(self, next_state):
        """Copy the used prefix of every buffer into ``next_state``.

        Also links ``next_state`` back to this state.  Slicing already
        yields new lists and the elements are plain ints, so no deep copy
        is required.
        """
        next_state._pre_state = self
        next_state._inst = self._inst
        next_state._word_size = self._word_size
        stack_end = self._stack_size
        word_end = self._next_index
        next_state._stack[0:stack_end] = self._stack[0:stack_end]
        next_state._rel[0:word_end] = self._rel[0:word_end]
        next_state._head[0:word_end] = self._head[0:word_end]
        next_state._have_parent[0:word_end] = self._have_parent[0:word_end]

    def is_end(self):
        """Return True if the action that produced this state was final."""
        return bool(self._pre_action.is_finish())

    def get_gold_action(self, vocab):
        """Return the oracle action for this state given the gold tree.

        Returns an ``Action`` set to SHIFT, ARC_LEFT, ARC_RIGHT or
        POP_ROOT according to ``self._inst.heads`` / ``self._inst.rels``.
        """
        gold_action = Action(CODE.NO_ACTION, -1)
        if self._stack_size == 0:
            gold_action.set(CODE.SHIFT, -1)
        elif self._stack_size == 1:
            if self._next_index == self._word_size:
                gold_action.set(CODE.POP_ROOT, vocab.ROOT)
            else:
                gold_action.set(CODE.SHIFT, -1)
        elif self._stack_size > 1:
            top0 = self._stack[self._stack_size - 1]
            top1 = self._stack[self._stack_size - 2]
            assert top0 < self._word_size and top1 < self._word_size
            if top0 == self._inst.heads[top1]:
                # top1 <- top0
                gold_action.set(CODE.ARC_LEFT,
                                vocab._rel2id[self._inst.rels[top1]])
            elif top1 == self._inst.heads[top0]:
                # top1 -> top0, delayed while top0 still has a child in
                # the buffer.
                have_right_child = any(
                    self._inst.heads[idx] == top0
                    for idx in range(self._next_index, self._word_size))
                if have_right_child:
                    gold_action.set(CODE.SHIFT, -1)
                else:
                    gold_action.set(CODE.ARC_RIGHT,
                                    vocab._rel2id[self._inst.rels[top0]])
            else:
                # No arc possible between the two top items.
                gold_action.set(CODE.SHIFT, -1)
        return gold_action

    def get_result(self, vocab):
        """Return the parsed tree as a list of ``Dependency`` entries.

        The first entry is the artificial root; word ids and heads are
        shifted by one relative to the internal indices.
        """
        result = [
            Dependency(0, vocab._root_form, vocab._root, 0, vocab._root)
        ]
        for idx in range(self._word_size):
            assert self._have_parent[idx] == 1
            relation = vocab.id2rel(self._rel[idx])
            head = self._head[idx]
            word = self._inst.words[idx]
            tag = self._inst.tags[idx]
            result.append(Dependency(idx + 1, word, tag, head + 1, relation))
        return result

    def prepare_atom_feat(self, encoder_output, offset, bucket, vocab):
        """Gather encoder vectors for the top three stack items and q0.

        Stores the positions in ``self._atom_feat`` (-1 means "absent")
        and sets ``self.hidden_state`` to the concatenation, along dim 1,
        of the four selected vectors; ``bucket`` stands in for absent
        positions.

        Args:
            encoder_output: Indexable as ``encoder_output[offset][pos]``.
            offset: Index of this sentence inside ``encoder_output``.
            bucket: Placeholder tensor used for absent positions.
            vocab: Unused here.
        """
        feat = self._atom_feat
        depth = self._stack_size
        feat.s0 = self._stack[depth - 1] if depth > 0 else -1
        feat.s1 = self._stack[depth - 2] if depth > 1 else -1
        feat.s2 = self._stack[depth - 3] if depth > 2 else -1
        if 0 <= self._next_index < self._word_size:
            feat.q0 = self._next_index
        else:
            # Reset like s0/s1/s2 so a value left over from an earlier use
            # of this AtomFeat cannot select a wrong encoder row.
            feat.q0 = -1

        hidden = tuple(
            bucket if pos == -1 else encoder_output[offset][pos].unsqueeze(0)
            for pos in (feat.s0, feat.s1, feat.s2, feat.q0))
        self.hidden_state = torch.cat(hidden, 1)
def parseTree(self, tree_str):
    """Parse a bracketed tree string into EDUs, gold actions and subtrees.

    ``tree_str`` is split on single spaces.  Every "(" token is followed by
    a relation token and an action code; a leaf has the form
    ``( leaf t <start> <end> )``.  Closing a leaf appends an ``EDU`` and a
    SHIFT action; closing an inner node (action code 'l', 'r' or 'c')
    appends a REDUCE action and two ``SubTree`` entries.  A POP_ROOT action
    is appended at the end, followed by consistency assertions.

    Populates ``self.EDUs``, ``self.gold_actions`` and
    ``self.result.subtrees``; reads ``self.sent_types``,
    ``self.total_words``, ``self.total_tags`` and ``self.words``.

    NOTE(review): the statement nesting below was reconstructed from a
    whitespace-collapsed source; confirm against the original file.
    """
    buffer = tree_str.strip().split(" ")
    buffer_size = len(buffer)
    step = 0
    subtree_stack = []  # [first_edu_index, last_edu_index] per open subtree
    op_stack = []
    relation_stack = []
    action_stack = []
    while True:
        assert step <= buffer_size
        if step == buffer_size:
            break
        if buffer[step] == "(":
            op_stack.append(buffer[step])
            relation_stack.append(buffer[step + 1])
            action_stack.append(buffer[step + 2])
            if buffer[step + 1] == 'leaf' and buffer[step + 2] == 't':
                # Leaf: remember its span until the matching ")" is seen.
                start = int(buffer[step + 3])
                end = int(buffer[step + 4])
                step += 2  # skip the two span tokens
            step += 3  # skip "(", relation and action code
        elif buffer[step] == ")":
            action = action_stack[-1]
            if action == 't':
                # Use the first sent_types entry whose range contains the
                # leaf span; if none matches, nothing is recorded.
                for sent_type in self.sent_types:
                    assert len(sent_type) == 3
                    if start >= sent_type[0] and end <= sent_type[1]:
                        e = EDU(start, end, sent_type[2])
                        edu_start = len(self.EDUs)
                        edu_end = len(self.EDUs)
                        subtree_stack.append([edu_start, edu_end])
                        self.EDUs.append(e)
                        assert relation_stack[-1] == "leaf"
                        ac = Action(CODE.SHIFT, -1, -1, relation_stack[-1])
                        self.gold_actions.append(ac)
                        break
            elif action == 'l' or action == 'r' or action == 'c':
                # Action code -> nuclearity: l = NS, r = SN, c = NN.
                if action == 'l':
                    nuclear = NUCLEAR.NS
                if action == 'r':
                    nuclear = NUCLEAR.SN
                if action == 'c':
                    nuclear = NUCLEAR.NN
                code = CODE.REDUCE
                ac = Action(code, nuclear, -1, relation_stack[-1])
                self.gold_actions.append(ac)
                assert len(subtree_stack) >= 2
                l_index = subtree_stack[-2]
                r_index = subtree_stack[-1]
                # The two children must cover adjacent EDU ranges.
                assert l_index[1] + 1 == r_index[0]
                left_subtree = SubTree(nullkey, nullkey, l_index[0], l_index[1])
                right_subtree = SubTree(nullkey, nullkey, r_index[0], r_index[1])
                if action == "l":
                    # NS: left is the nucleus, right carries the relation.
                    left_subtree.nuclear = nuclear_str
                    left_subtree.relation = span_str
                    right_subtree.nuclear = satellite_str
                    right_subtree.relation = ac.label_str
                if action == "r":
                    # SN: right is the nucleus, left carries the relation.
                    left_subtree.nuclear = satellite_str
                    left_subtree.relation = ac.label_str
                    right_subtree.nuclear = nuclear_str
                    right_subtree.relation = span_str
                if action == "c":
                    # NN: both are nuclei and both carry the relation.
                    left_subtree.nuclear = nuclear_str
                    left_subtree.relation = ac.label_str
                    right_subtree.nuclear = nuclear_str
                    right_subtree.relation = ac.label_str
                self.result.subtrees.append(left_subtree)
                self.result.subtrees.append(right_subtree)
                # Merge: the left entry now spans both children.
                l_index[1] = r_index[1]
                subtree_stack.pop()
            relation_stack.pop()
            op_stack.pop()
            action_stack.pop()
            step += 1
    ac = Action(CODE.POP_ROOT)
    self.gold_actions.append(ac)
    assert len(subtree_stack) == 1
    root = subtree_stack[0]
    assert root[0] == 0 and root[1] == len(self.EDUs) - 1
    subtree_stack.pop()  # pop root
    # Check that every stack is empty.
    assert op_stack == [] and relation_stack == [] and action_stack == [] and subtree_stack == []
    # Check that EDU spans are in range, non-empty and contiguous.
    for idx in range(len(self.EDUs)):
        edu = self.EDUs[idx]
        assert edu.start >= 0 and edu.end < len(self.total_words)
        assert edu.start <= edu.end
        if idx < len(self.EDUs) - 1:
            assert edu.end + 1 == self.EDUs[idx + 1].start
    # Fill each EDU's words and tags, skipping positions tagged nullkey.
    sum = 0
    for edu in self.EDUs:
        for idx in range(edu.start, edu.end + 1):
            if self.total_tags[idx] != nullkey:
                edu.words.append(self.total_words[idx])
                edu.tags.append(self.total_tags[idx])
        sum += len(edu.words)
    assert sum == len(self.words)
    # Check that every subtree received a relation and a nuclearity.
    for subtree in self.result.subtrees:
        assert subtree.relation != nullkey and subtree.nuclear != nullkey
class State:
    """Parser configuration where an arc is built in two steps.

    ARC_LEFT / ARC_RIGHT only record the chosen direction in
    ``_pre_action``; the following ARC_LABEL action performs the actual
    attachment with a relation label.  Transitions write the successor
    configuration into a caller-supplied ``next_state``.
    """

    def __init__(self):
        # Pre-sized per-position buffers.
        self._stack = [-3] * max_length
        self._rel = [-3] * max_length
        self._head = [-3] * max_length
        self._have_parent = [-1] * max_length
        # Counters.
        self._stack_size = 0
        self._next_index = 0
        self._word_size = 0
        # Flags and attached objects.
        self._is_start = True
        self._is_gold = True
        self._inst = None
        self._atom_feat = AtomFeat()
        self._pre_action = Action(CODE.NO_ACTION, -1)

    def ready(self, sentence, vocab):
        """Wrap ``sentence`` in an ``Instance`` and record its length."""
        instance = Instance(sentence, vocab)
        self._inst = instance
        self._word_size = len(instance.words)

    def clear(self):
        """Reset counters, flags and the previous action."""
        self._is_start = True
        self._is_gold = True
        self._pre_action = Action(CODE.NO_ACTION, -1)
        self._stack_size = 0
        self._next_index = 0
        self._word_size = 0
        self.done_mark()

    def done_mark(self):
        """Write the -2 sentinel just past the used part of each buffer."""
        self._stack[self._stack_size] = -2
        boundary = self._next_index
        self._head[boundary] = -2
        self._rel[boundary] = -2
        self._have_parent[boundary] = -2

    def allow_shift(self):
        """True while buffer words remain."""
        return self._next_index < self._word_size

    def allow_arc_left(self):
        """True when at least two items are on the stack."""
        return self._stack_size > 1

    def allow_arc_right(self):
        """True when at least two items are on the stack."""
        return self._stack_size > 1

    def allow_pop_root(self):
        """True when one stack item is left and the buffer is empty."""
        return self._stack_size == 1 and self._next_index == self._word_size

    def allow_arc_label(self):
        """True when the previous action chose an arc direction."""
        previous = self._pre_action
        return bool(previous.is_arc_left() or previous.is_arc_right())

    def shift(self, next_state):
        """Push the next buffer word onto the stack of ``next_state``."""
        assert self._next_index < self._word_size
        shifted = self._next_index
        next_state._next_index = shifted + 1
        next_state._stack_size = self._stack_size + 1
        self.copy_state(next_state)
        next_state._stack[next_state._stack_size - 1] = shifted
        next_state._have_parent[shifted] = 0
        next_state.done_mark()
        next_state._pre_action.set(CODE.SHIFT, -1)

    def arc_left(self, next_state):
        """Record an unlabelled ARC_LEFT; the stack is left unchanged."""
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size
        self.copy_state(next_state)
        next_state.done_mark()
        next_state._pre_action.set(CODE.ARC_LEFT, -1)

    def arc_right(self, next_state):
        """Record an unlabelled ARC_RIGHT; the stack is left unchanged."""
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size
        self.copy_state(next_state)
        next_state.done_mark()
        next_state._pre_action.set(CODE.ARC_RIGHT, -1)

    def arc_label(self, next_state, dep):
        """Complete the pending arc with relation ``dep``.

        The direction comes from this state's previous action: after
        ARC_LEFT the second stack item becomes a child of the top item;
        otherwise the top item becomes a child of the second one.
        """
        assert self._stack_size > 1
        next_state._next_index = self._next_index
        next_state._stack_size = self._stack_size - 1
        self.copy_state(next_state)
        top0 = self._stack[self._stack_size - 1]
        top1 = self._stack[self._stack_size - 2]
        if self._pre_action.is_arc_left():
            # The surviving head (top0) moves into the new top slot.
            next_state._stack[next_state._stack_size - 1] = top0
            child, parent = top1, top0
        else:
            child, parent = top0, top1
        next_state._head[child] = parent
        next_state._have_parent[child] = 1
        next_state._rel[child] = dep
        next_state.done_mark()
        next_state._pre_action.set(CODE.ARC_LABEL, dep)

    def pop_root(self, next_state, dep):
        """Pop the last stack item and mark it as root (head -1)."""
        assert self._stack_size == 1 and self._next_index == self._word_size
        next_state._next_index = self._word_size
        next_state._stack_size = 0
        self.copy_state(next_state)
        root = self._stack[self._stack_size - 1]
        next_state._head[root] = -1
        next_state._have_parent[root] = 1
        next_state._rel[root] = dep
        next_state.done_mark()
        next_state._pre_action.set(CODE.POP_ROOT, dep)

    def move(self, next_state, action):
        """Apply ``action`` to this state, writing into ``next_state``.

        An action of an unknown kind only prints a message.
        """
        next_state._is_start = False
        next_state._is_gold = False
        if action.is_shift():
            self.shift(next_state)
            return
        if action.is_arc_left():
            self.arc_left(next_state)
            return
        if action.is_arc_right():
            self.arc_right(next_state)
            return
        if action.is_arc_label():
            self.arc_label(next_state, action.label)
            return
        if action.is_finish():
            self.pop_root(next_state, action.label)
            return
        print(" error state ")

    def get_candidate_actions(self, vocab):
        """Return a boolean array that is True for disallowed actions.

        The allowed-action masks from ``vocab`` are OR-ed together and the
        result is inverted.  When an arc label is pending, only the
        ARC_LABEL mask is considered.
        """
        allowed = np.array([False] * vocab.ac_size)
        if self.allow_arc_label():
            allowed = allowed | vocab.mask_arc_label
            return ~allowed
        if self.allow_arc_left():
            allowed = allowed | vocab.mask_arc_left
        if self.allow_arc_right():
            allowed = allowed | vocab.mask_arc_right
        if self.is_end():
            allowed = allowed | vocab.mask_no_action
        if self.allow_shift():
            allowed = allowed | vocab.mask_shift
        if self.allow_pop_root():
            allowed = allowed | vocab.mask_pop_root
        return ~allowed

    def copy_state(self, next_state):
        """Copy the used prefix of every buffer into ``next_state``."""
        stack_end = self._stack_size
        word_end = self._next_index
        next_state._inst = self._inst
        next_state._word_size = self._word_size
        next_state._stack[0:stack_end] = self._stack[0:stack_end]
        next_state._rel[0:word_end] = self._rel[0:word_end]
        next_state._head[0:word_end] = self._head[0:word_end]
        next_state._have_parent[0:word_end] = self._have_parent[0:word_end]

    def is_end(self):
        """True if the action that produced this state was final."""
        return bool(self._pre_action.is_finish())

    def get_gold_action(self, vocab):
        """Return the oracle action for this state given the gold tree.

        After an unlabelled arc the oracle is the ARC_LABEL carrying the
        gold relation of the dependent; otherwise it is SHIFT, ARC_LEFT,
        ARC_RIGHT or POP_ROOT.
        """
        oracle = Action(CODE.NO_ACTION, -1)
        depth = self._stack_size

        if depth == 0:
            oracle.set(CODE.SHIFT, -1)
            return oracle

        if depth == 1:
            if self._next_index == self._word_size:
                oracle.set(CODE.POP_ROOT, vocab.ROOT)
            else:
                oracle.set(CODE.SHIFT, -1)
            return oracle

        if self._pre_action.is_arc_left() or self._pre_action.is_arc_right():
            # A direction was chosen; now emit the label of the dependent.
            assert self._stack_size > 1
            top0 = self._stack[self._stack_size - 1]
            top1 = self._stack[self._stack_size - 2]
            if self._pre_action.is_arc_left():
                oracle.set(CODE.ARC_LABEL,
                           vocab._rel2id[self._inst.rels[top1]])
            elif self._pre_action.is_arc_right():
                oracle.set(CODE.ARC_LABEL,
                           vocab._rel2id[self._inst.rels[top0]])
            return oracle

        if depth > 1:
            top0 = self._stack[depth - 1]
            top1 = self._stack[depth - 2]
            assert top0 < self._word_size and top1 < self._word_size
            if top0 == self._inst.heads[top1]:
                # top1 <- top0
                oracle.set(CODE.ARC_LEFT, -1)
            elif top1 == self._inst.heads[top0]:
                # top1 -> top0, delayed while top0 has a child in the buffer.
                pending_child = any(
                    self._inst.heads[pos] == top0
                    for pos in range(self._next_index, self._word_size))
                if pending_child:
                    oracle.set(CODE.SHIFT, -1)
                else:
                    oracle.set(CODE.ARC_RIGHT, -1)
            else:
                # No arc possible between the two top items.
                oracle.set(CODE.SHIFT, -1)
        return oracle

    def get_result(self, vocab):
        """Return the parsed tree as a list of ``Dependency`` entries.

        The first entry is the artificial root; word ids and heads are
        shifted by one relative to the internal indices.
        """
        dependencies = [
            Dependency(0, vocab._root_form, vocab._root, 0, vocab._root)
        ]
        for position in range(self._word_size):
            assert self._have_parent[position] == 1
            relation = vocab.id2rel(self._rel[position])
            head = self._head[position]
            word = self._inst.words[position]
            tag = self._inst.tags[position]
            dependencies.append(
                Dependency(position + 1, word, tag, head + 1, relation))
        return dependencies

    def prepare_index(self):
        """Fill ``_atom_feat`` with feature positions and return its index.

        Absent stack items and an exhausted buffer are encoded as
        ``self._word_size``.  ``arc`` is True when a label is pending.
        """
        feat = self._atom_feat
        depth = self._stack_size
        absent = self._word_size

        feat.s0 = self._stack[depth - 1] if depth > 0 else absent
        feat.s1 = self._stack[depth - 2] if depth > 1 else absent
        feat.s2 = self._stack[depth - 3] if depth > 2 else absent

        if 0 <= self._next_index < self._word_size:
            feat.q0 = self._next_index
        else:
            feat.q0 = absent

        if self._pre_action.is_arc_left() or self._pre_action.is_arc_right():
            feat.arc = True
        else:
            feat.arc = False
        return self._atom_feat.index()
class State:
    """Configuration of a shift/reduce parser over EDUs.

    The stack holds ``Node`` objects describing EDU spans.  Transitions
    write the successor configuration into a caller-supplied
    ``next_state``, which keeps a back-link to its predecessor so the
    derivation can be replayed by `get_result`.
    """

    def __init__(self):
        # One distinct Node per stack slot.
        self._stack = [Node() for _ in range(max_length)]
        self._stack_size = 0
        self._edu_size = 0
        self._next_index = 0
        self._word_size = 0
        self._is_start = True
        self._is_gold = True
        self._inst = None
        self._pre_state = None
        self._atom_feat = AtomFeat()
        self._pre_action = Action(CODE.NO_ACTION)

    def ready(self, doc):
        """Attach ``doc`` and record how many EDUs it has."""
        self._inst = doc
        self._edu_size = len(doc.EDUs)

    def clear(self):
        """Reset counters, flags, history and the attached document."""
        self._inst = None
        self._pre_state = None
        self._pre_action = Action(CODE.NO_ACTION)
        self._is_start = True
        self._is_gold = True
        self._stack_size = 0
        self._next_index = 0
        self.done_mark()

    def done_mark(self):
        """Clear the node sitting just above the top of the stack."""
        self._stack[self._stack_size].clear()

    def allow_shift(self):
        """True while unshifted EDUs remain."""
        return self._next_index < self._edu_size

    def allow_pop_root(self):
        """True when one node is left and every EDU has been shifted."""
        return self._stack_size == 1 and self._next_index == self._edu_size

    def allow_reduce(self):
        """True when at least two nodes are on the stack."""
        return self._stack_size >= 2

    def shift(self, next_state):
        """Push a single-EDU node for the next EDU onto ``next_state``."""
        assert self._next_index < self._edu_size
        next_state._stack_size = self._stack_size + 1
        next_state._next_index = self._next_index + 1
        self.copy_state(next_state)
        pushed = next_state._stack[next_state._stack_size - 1]
        pushed.clear()
        pushed.is_validate = True
        pushed.edu_start = self._next_index
        pushed.edu_end = self._next_index
        next_state.done_mark()
        next_state._pre_action.set(CODE.SHIFT)

    def reduce(self, next_state, nuclear, label):
        """Merge the two top nodes into one in ``next_state``.

        The lower node absorbs the span of the upper one and receives
        ``nuclear`` and ``label``; the upper node is cleared.
        """
        next_state._stack_size = self._stack_size - 1
        next_state._next_index = self._next_index
        self.copy_state(next_state)
        upper = next_state._stack[self._stack_size - 1]
        lower = next_state._stack[self._stack_size - 2]
        assert upper.is_validate == True
        assert lower.is_validate == True
        # The two spans must be adjacent.
        assert upper.edu_start == lower.edu_end + 1
        lower.edu_end = upper.edu_end
        lower.nuclear = nuclear
        lower.label = label
        upper.clear()
        next_state.done_mark()
        next_state._pre_action.set(CODE.REDUCE, nuclear=nuclear, label=label)

    def pop_root(self, next_state):
        """Pop the final node, which must span every EDU."""
        assert self._stack_size == 1 and self._next_index == self._edu_size
        next_state._next_index = self._edu_size
        next_state._stack_size = 0
        self.copy_state(next_state)
        root = next_state._stack[self._stack_size - 1]
        assert root.is_validate == True
        assert root.edu_start == 0 and root.edu_end + 1 == len(self._inst.EDUs)
        root.clear()
        next_state.done_mark()
        next_state._pre_action.set(CODE.POP_ROOT)

    def move(self, next_state, action):
        """Apply ``action`` to this state, writing into ``next_state``.

        An action of an unknown kind only prints a message.
        """
        next_state._is_start = False
        next_state._is_gold = False
        if action.is_shift():
            self.shift(next_state)
            return
        if action.is_reduce():
            self.reduce(next_state, action.nuclear, action.label)
            return
        if action.is_finish():
            self.pop_root(next_state)
            return
        print(" error state ")

    def get_candidate_actions(self, vocab):
        """Return a boolean array that is True for disallowed actions.

        The allowed-action masks from ``vocab`` are OR-ed together and the
        result is inverted.
        """
        allowed = np.array([False]*vocab.ac_size)
        if self.allow_reduce():
            allowed = allowed | vocab.mask_reduce
        if self.is_end():
            allowed = allowed | vocab.mask_no_action
        if self.allow_shift():
            allowed = allowed | vocab.mask_shift
        if self.allow_pop_root():
            allowed = allowed | vocab.mask_pop_root
        return ~allowed

    def copy_state(self, next_state):
        """Deep-copy the used stack prefix into ``next_state`` and link it.

        Nodes are mutable, so they are deep-copied rather than shared.
        """
        used = self._stack_size
        next_state._stack[0:used] = deepcopy(self._stack[0:used])
        next_state._edu_size = self._edu_size
        next_state._inst = self._inst
        next_state._pre_state = self

    def is_end(self):
        """True if the action that produced this state was final."""
        return bool(self._pre_action.is_finish())

    def get_result(self, vocab):
        """Rebuild the subtree list by walking back through the history.

        Starting from this state, each REDUCE action found along the
        ``_pre_state`` chain contributes a left and a right ``SubTree``,
        inserted at the front so the final order follows the derivation.
        The walk stops once the predecessor is the start state.
        """
        result = Result()
        current = self
        while not current._pre_state._is_start:
            action = current._pre_action
            previous = current._pre_state
            if action.is_reduce():
                assert previous._stack_size >= 2
                right_node = previous._stack[previous._stack_size - 1]
                left_node = previous._stack[previous._stack_size - 2]

                left_subtree = SubTree()
                right_subtree = SubTree()
                left_subtree.edu_start = left_node.edu_start
                left_subtree.edu_end = left_node.edu_end
                right_subtree.edu_start = right_node.edu_start
                right_subtree.edu_end = right_node.edu_end

                if action.nuclear == NUCLEAR.NN:
                    # Both sides are nuclei and both carry the relation.
                    left_subtree.nuclear = nuclear_str
                    right_subtree.nuclear = nuclear_str
                    left_subtree.relation = vocab._id2rel[action.label]
                    right_subtree.relation = vocab._id2rel[action.label]
                elif action.nuclear == NUCLEAR.SN:
                    # Left satellite carries the relation.
                    left_subtree.nuclear = satellite_str
                    right_subtree.nuclear = nuclear_str
                    left_subtree.relation = vocab._id2rel[action.label]
                    right_subtree.relation = span_str
                elif action.nuclear == NUCLEAR.NS:
                    # Right satellite carries the relation.
                    left_subtree.nuclear = nuclear_str
                    right_subtree.nuclear = satellite_str
                    left_subtree.relation = span_str
                    right_subtree.relation = vocab._id2rel[action.label]

                # Insert right first so left ends up in front of it.
                result.subtrees.insert(0, right_subtree)
                result.subtrees.insert(0, left_subtree)
            current = current._pre_state
        return result

    def prepare_index(self):
        """Fill ``_atom_feat`` with the top three nodes and q0; return it.

        Absent stack items and an exhausted buffer are stored as ``None``.
        """
        feat = self._atom_feat
        depth = self._stack_size

        feat.s0 = self._stack[depth - 1] if depth > 0 else None
        feat.s1 = self._stack[depth - 2] if depth > 1 else None
        feat.s2 = self._stack[depth - 3] if depth > 2 else None

        if 0 <= self._next_index < self._edu_size:
            feat.q0 = self._next_index
        else:
            feat.q0 = None
        return self._atom_feat