class TransitionParser(Model):
    """Transition-based parser that builds a graph (tokens + generated concept
    nodes + remote edges) with stack/buffer/action StackRNNs.

    The action inventory (SHIFT / REDUCE / NODE / REMOTE-NODE / LEFT-EDGE /
    RIGHT-EDGE / LEFT-REMOTE / RIGHT-REMOTE / SWAP / FINISH) and the
    ``ucca_trans_outputs_into_mrp`` post-processing suggest this targets
    UCCA-style MRP graphs.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 hidden_dim: int,
                 action_dim: int,
                 ratio_dim: int,
                 num_layers: int,
                 word_dim: int = 0,
                 text_field_embedder: TextFieldEmbedder = None,
                 mces_metric: Metric = None,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 lemma_text_field_embedder: TextFieldEmbedder = None,
                 pos_tag_embedding: Embedding = None,
                 deprel_embedding: Embedding = None,
                 bios_embedding: Embedding = None,
                 lexcat_embedding: Embedding = None,
                 ss_embedding: Embedding = None,
                 ss2_embedding: Embedding = None,
                 action_embedding: Embedding = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None
                 ) -> None:
        super(TransitionParser, self).__init__(vocab, regularizer)

        # Evaluation counters for primary/remote edges (kept for external
        # metric accumulation; not updated inside this class as visible here).
        self._primary_labeled_correct = 0
        self._primary_unlabeled_correct = 0
        self._primary_total_edges_predicted = 0
        self._primary_total_edges_actual = 0
        self._primary_exact_labeled_correct = 0
        self._primary_exact_unlabeled_correct = 0

        self._remote_labeled_correct = 0
        self._remote_unlabeled_correct = 0
        self._remote_total_edges_predicted = 0
        self._remote_total_edges_actual = 0
        self._remote_exact_labeled_correct = 0
        self._remote_exact_unlabeled_correct = 0

        self._total_sentences = 0

        self.num_actions = vocab.get_vocab_size('actions')
        self.text_field_embedder = text_field_embedder
        self.lemma_text_field_embedder = lemma_text_field_embedder
        self._pos_tag_embedding = pos_tag_embedding
        self._deprel_embedding = deprel_embedding
        self._bios_embedding = bios_embedding
        self._lexcat_embedding = lexcat_embedding
        self._ss_embedding = ss_embedding
        self._ss2_embedding = ss2_embedding
        self._mces_metric = mces_metric

        # Node representation size = word embedding + every optional tag
        # embedding that was actually supplied.
        node_dim = 0
        if self.text_field_embedder:
            node_dim += word_dim
        for embedding in (pos_tag_embedding, deprel_embedding, bios_embedding,
                          lexcat_embedding, ss_embedding, ss2_embedding):
            if embedding:
                node_dim += embedding.output_dim
        self.node_dim = node_dim
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.ratio_dim = ratio_dim
        self.action_dim = action_dim

        self.action_embedding = action_embedding
        if action_embedding is None:
            # Fall back to a frozen random action embedding.
            self.action_embedding = Embedding(num_embeddings=self.num_actions,
                                              embedding_dim=self.action_dim,
                                              trainable=False)

        # syntactic composition: (head, mod, action, buffer, stack) + ratio -> node
        self.p_comp = torch.nn.Linear(self.hidden_dim * 5 + self.ratio_dim, node_dim)
        # parser state (buffer, stack, action) + ratio -> hidden
        self.p_s2h = torch.nn.Linear(self.hidden_dim * 3 + self.ratio_dim, self.hidden_dim)
        # hidden + ratio -> action logits
        self.p_act = torch.nn.Linear(self.hidden_dim + self.ratio_dim, self.num_actions)
        # stack summary + ratio -> embedding for a freshly created concept node
        self.update_concept_node = torch.nn.Linear(self.hidden_dim + self.ratio_dim, node_dim)

        # Learned placeholders for empty buffer/stack/action-history states
        # and for the artificial root pushed onto the stack.
        self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(node_dim))
        self.pempty_action_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))
        self.pempty_stack_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))

        self._input_dropout = Dropout(input_dropout)

        self.buffer = StackRnn(input_size=node_dim,
                               hidden_size=self.hidden_dim,
                               num_layers=num_layers,
                               recurrent_dropout_probability=recurrent_dropout_probability,
                               layer_dropout_probability=layer_dropout_probability,
                               same_dropout_mask_per_instance=same_dropout_mask_per_instance)
        self.stack = StackRnn(input_size=node_dim,
                              hidden_size=self.hidden_dim,
                              num_layers=num_layers,
                              recurrent_dropout_probability=recurrent_dropout_probability,
                              layer_dropout_probability=layer_dropout_probability,
                              same_dropout_mask_per_instance=same_dropout_mask_per_instance)
        self.action_stack = StackRnn(input_size=self.action_dim,
                                     hidden_size=self.hidden_dim,
                                     num_layers=num_layers,
                                     recurrent_dropout_probability=recurrent_dropout_probability,
                                     layer_dropout_probability=layer_dropout_probability,
                                     same_dropout_mask_per_instance=same_dropout_mask_per_instance)
        initializer(self)

    def expand_arc_with_descendants(self, arc_indices, total_node_num, len_tokens):
        """Replace concept-node endpoints of each arc by the sorted, '-'-joined
        token ids of their token descendants.

        ``arc_indices`` is a list of (child, head, label) triples over node ids;
        ids ``>= len_tokens`` are concept (non-token) nodes. Returns a list of
        (child_repr, head_repr, label) where token endpoints stay as ints and
        concept endpoints become strings like ``"3-4-7"``.
        """
        # step 1: build the dependency graph.
        # graph[node] = {"in_degree": #heads-as-child?, "head_list": [(head, label), ...]}
        graph = {}
        for token_idx in range(total_node_num):
            graph[token_idx] = {"in_degree": 0, "head_list": []}
        for arc in arc_indices:
            graph[arc[0]]["head_list"].append((arc[1], arc[2]))
            graph[arc[1]]["in_degree"] += 1

        # top_down_graph[head] lists the children of `head`;
        # step2_top_down_graph is a working copy consumed during the
        # topological sort below.
        top_down_graph = [[] for i in range(total_node_num)]
        step2_top_down_graph = [[] for i in range(total_node_num)]

        topological_stack = []
        for i in range(total_node_num):
            if graph[i]["in_degree"] == 0:
                topological_stack.append(i)
            for head_tuple_of_point_i in graph[i]["head_list"]:
                head = head_tuple_of_point_i[0]
                top_down_graph[head].append(i)
                step2_top_down_graph[head].append(i)

        # step 2: Kahn-style topological order (nodes with no remaining
        # incoming arcs are popped first).
        topological_order = []
        while len(topological_stack) != 0:
            stack_0_node = topological_stack.pop()
            topological_order.append(stack_0_node)
            for i in graph:
                if stack_0_node in step2_top_down_graph[i]:
                    step2_top_down_graph[i].remove(stack_0_node)
                    graph[i]["in_degree"] -= 1
                    if (graph[i]["in_degree"] == 0
                            and i not in topological_stack
                            and i not in topological_order):
                        topological_stack.append(i)

        # step 3: propagate descendants bottom-up along the topological order.
        expand_node_dict = {}
        for node_idx in range(total_node_num):
            expand_node_dict[node_idx] = top_down_graph[node_idx][:]
        for node_idx in topological_order:
            if len(expand_node_dict[node_idx]) == 0:  # leaf: no children
                continue
            expand_childs = expand_node_dict[node_idx][:]
            for child in expand_node_dict[node_idx]:
                expand_childs += expand_node_dict[child]
            expand_node_dict[node_idx] = expand_childs

        # step 4: deduplicate and keep only real token ids (< len_tokens).
        token_filter = set(range(len_tokens))
        for node_idx in expand_node_dict:
            expand_node_dict[node_idx] = set(expand_node_dict[node_idx]) & token_filter

        # step 5: rewrite each arc endpoint using its token-descendant set.
        arc_descendants = []
        for arc_info in arc_indices:
            arc_info_0 = arc_info[0] if arc_info[0] < len_tokens else \
                '-'.join([str(i) for i in sorted(expand_node_dict[arc_info[0]])])
            arc_info_1 = arc_info[1] if arc_info[1] < len_tokens else \
                '-'.join([str(i) for i in sorted(expand_node_dict[arc_info[1]])])
            arc_descendants.append((arc_info_0, arc_info_1, arc_info[2]))
        return arc_descendants

    def _greedy_decode(self,
                       batch_size: int,
                       sent_len: List[int],
                       embedded_text_input: torch.Tensor,
                       oracle_actions: Optional[List[List[int]]] = None,
                       ) -> Dict[str, Any]:
        """Run the transition system greedily over a batch.

        With ``oracle_actions`` (training) the oracle sequence is executed and
        per-action log-likelihood losses are accumulated; without it
        (prediction) actions are chosen greedily and edge/node outputs are
        collected. Returns a dict with at least ``loss`` and ``losses``.
        """
        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        # We will keep track of all the losses we accumulate during parsing.
        # If some decision is unambiguous because it's the only thing valid given
        # the parser state, we will not model it. We only model what is ambiguous.
        losses = [[] for _ in range(batch_size)]
        ratio_factor_losses = [[] for _ in range(batch_size)]
        edge_list = [[] for _ in range(batch_size)]
        total_node_num = [0 for _ in range(batch_size)]
        action_list = [[] for _ in range(batch_size)]
        ret_top_node = [[] for _ in range(batch_size)]
        ret_concept_node = [[] for _ in range(batch_size)]

        # Push the tokens onto the buffer (in reverse order, so that the first
        # token ends up on top).
        for token_idx in range(max(sent_len)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    self.buffer.push(sent_idx,
                                     input=embedded_text_input[sent_idx][sent_len[sent_idx] - 1 - token_idx],
                                     extra={'token': sent_len[sent_idx] - token_idx - 1})

        # Initialize each stack with the artificial root embedding; the root's
        # node id is sent_len (one past the last token).
        for sent_idx in range(batch_size):
            self.stack.push(sent_idx,
                            input=self.proot_stack_emb,
                            extra={'token': sent_len[sent_idx]})
            ret_top_node[sent_idx] = [sent_len[sent_idx]]

        # Map each action family prefix to all matching vocab action ids
        # (e.g. 'LEFT-EDGE' -> every 'LEFT-EDGE:<label>' id).
        action_id = {
            action_: [self.vocab.get_token_index(a, namespace='actions')
                      for a in self.vocab.get_token_to_index_vocabulary('actions').keys()
                      if a.startswith(action_)]
            for action_ in ["SHIFT", "REDUCE", "NODE", "REMOTE-NODE", "LEFT-EDGE",
                            "RIGHT-EDGE", "LEFT-REMOTE", "RIGHT-REMOTE", "SWAP", "FINISH"]
        }

        # Compute probability of each of the actions and choose an action,
        # either from the oracle or, if there is no oracle, from the model.
        trans_not_fin = True
        action_tag_for_terminate = [False] * batch_size
        action_sequence_length = [0] * batch_size

        concept_node = {}
        for sent_idx in range(batch_size):
            concept_node[sent_idx] = [sent_len[sent_idx]]

        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):
                # Safety valve in prediction mode: stop expanding runaway
                # sentences (too many nodes or actions).
                if (len(concept_node[sent_idx]) > 50 * sent_len[sent_idx]
                        or action_sequence_length[sent_idx] > 50 * sent_len[sent_idx]) \
                        and oracle_actions is None:
                    continue

                total_node_num[sent_idx] = sent_len[sent_idx] + len(concept_node[sent_idx])

                if not action_tag_for_terminate[sent_idx]:
                    trans_not_fin = True
                    valid_actions = []

                    # Given the buffer and stack, conclude the valid action list.
                    if self.buffer.get_len(sent_idx) == 0:
                        valid_actions += action_id['FINISH']
                    if self.buffer.get_len(sent_idx) > 0:
                        valid_actions += action_id['SHIFT']
                    try:
                        if self.stack.get_len(sent_idx) > 0:
                            valid_actions += action_id['REDUCE']
                            valid_actions += action_id['NODE']
                            valid_actions += action_id['REMOTE-NODE']
                    except Exception:
                        # NOTE(review): the original used a bare `except: pass`;
                        # narrowed to Exception. It is unclear what get_len could
                        # raise here -- the guard looks vestigial. TODO confirm.
                        pass
                    if self.stack.get_len(sent_idx) > 1:
                        valid_actions += action_id['SWAP']
                        valid_actions += action_id['LEFT-EDGE']
                        valid_actions += action_id['RIGHT-EDGE']
                        valid_actions += action_id['LEFT-REMOTE']
                        valid_actions += action_id['RIGHT-REMOTE']

                    log_probs = None
                    action = valid_actions[0]
                    # Ratio of generated concept nodes to sentence length,
                    # appended to every state representation.
                    ratio_factor = torch.tensor([len(concept_node[sent_idx]) / (1.0 * sent_len[sent_idx])],
                                                device=self.pempty_action_emb.device)

                    if len(valid_actions) > 1:
                        stack_emb = self.stack.get_output(sent_idx)
                        buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                            else self.buffer.get_output(sent_idx)
                        action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                            else self.action_stack.get_output(sent_idx)

                        p_t = torch.cat([buffer_emb, stack_emb, action_emb])
                        p_t = torch.cat([p_t, ratio_factor])
                        h = torch.tanh(self.p_s2h(p_t))
                        h = torch.cat([h, ratio_factor])

                        # Score only the valid actions, then softmax over them.
                        logits = self.p_act(h)[torch.tensor(valid_actions, dtype=torch.long, device=h.device)]
                        valid_action_tbl = {a: i for i, a in enumerate(valid_actions)}
                        log_probs = torch.log_softmax(logits, dim=0)
                        action_idx = torch.max(log_probs, 0)[1].item()
                        action = valid_actions[action_idx]

                    if oracle_actions is not None:
                        # Teacher forcing: execute the gold action regardless
                        # of the model's argmax.
                        action = oracle_actions[sent_idx].pop(0)

                    # Push the action into the action history stack.
                    self.action_stack.push(sent_idx,
                                           input=self.action_embedding(
                                               torch.tensor(action, device=embedded_text_input.device)),
                                           extra={
                                               'token': self.vocab.get_token_from_index(action, namespace='actions')})
                    action_list[sent_idx].append(self.vocab.get_token_from_index(action, namespace='actions'))

                    if log_probs is not None:
                        # Append the action-specific loss (NaN-guarded).
                        loss = log_probs[valid_action_tbl[action]]
                        if not torch.isnan(loss):
                            losses[sent_idx].append(loss)

                    # Generate a concept node (recursive way): the new node is
                    # pushed onto the buffer so it can later take children.
                    if action in action_id["NODE"] + action_id["REMOTE-NODE"]:
                        concept_node_token = len(concept_node[sent_idx]) + sent_len[sent_idx]
                        concept_node[sent_idx].append(concept_node_token)

                        stack_emb = self.stack.get_output(sent_idx)
                        stack_emb = torch.cat([stack_emb, ratio_factor])
                        comp_rep = torch.tanh(self.update_concept_node(stack_emb))
                        node_input = comp_rep

                        self.buffer.push(sent_idx,
                                         input=node_input,
                                         extra={'token': concept_node_token})
                        total_node_num[sent_idx] = sent_len[sent_idx] + len(concept_node[sent_idx])

                    # Edge-producing actions: pick head/modifier, record the
                    # edge (in prediction mode) and recompose representations.
                    if action in action_id["NODE"] + action_id["REMOTE-NODE"] + action_id["LEFT-EDGE"] \
                            + action_id["RIGHT-EDGE"] + action_id["LEFT-REMOTE"] + action_id["RIGHT-REMOTE"]:
                        if action in action_id["NODE"] + action_id["REMOTE-NODE"]:
                            head = self.buffer.get_stack(sent_idx)[-1]
                            modifier = self.stack.get_stack(sent_idx)[-1]
                        elif action in action_id["LEFT-EDGE"] + action_id["LEFT-REMOTE"]:
                            head = self.stack.get_stack(sent_idx)[-1]
                            modifier = self.stack.get_stack(sent_idx)[-2]
                        else:
                            head = self.stack.get_stack(sent_idx)[-2]
                            modifier = self.stack.get_stack(sent_idx)[-1]

                        (head_rep, head_tok) = (head['stack_rnn_output'], head['token'])
                        (mod_rep, mod_tok) = (modifier['stack_rnn_output'], modifier['token'])

                        if oracle_actions is None:
                            # Edge label is the part after the action prefix,
                            # e.g. 'LEFT-EDGE:A' -> 'A'.
                            edge_list[sent_idx].append((mod_tok,
                                                        head_tok,
                                                        self.vocab.get_token_from_index(action, namespace='actions')
                                                        .split(':', maxsplit=1)[1]))

                        # Compute the composed representation.
                        action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                            else self.action_stack.get_output(sent_idx)
                        stack_emb = self.pempty_stack_emb if self.stack.get_len(sent_idx) == 0 \
                            else self.stack.get_output(sent_idx)
                        buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                            else self.buffer.get_output(sent_idx)

                        comp_rep = torch.cat([head_rep, mod_rep, action_emb, buffer_emb, stack_emb, ratio_factor])
                        comp_rep = torch.tanh(self.p_comp(comp_rep))

                        if action in action_id["NODE"] + action_id["REMOTE-NODE"]:
                            # Replace the buffer-top concept node with its
                            # composed representation.
                            self.buffer.pop(sent_idx)
                            self.buffer.push(sent_idx,
                                             input=comp_rep,
                                             extra={'token': head_tok})
                        elif action in action_id["LEFT-EDGE"] + action_id["LEFT-REMOTE"]:
                            self.stack.pop(sent_idx)
                            self.stack.push(sent_idx,
                                            input=comp_rep,
                                            extra={'token': head_tok})
                        # RIGHT-EDGE or RIGHT-REMOTE: update the head (second
                        # on stack) and re-push the modifier on top.
                        else:
                            stack_0_rep = self.stack.get_stack(sent_idx)[-1]['stack_rnn_input']
                            self.stack.pop(sent_idx)
                            self.stack.pop(sent_idx)
                            self.stack.push(sent_idx,
                                            input=comp_rep,
                                            extra={'token': head_tok})
                            self.stack.push(sent_idx,
                                            input=stack_0_rep,
                                            extra={'token': mod_tok})

                    # Execute the action to update the parser state.
                    if action in action_id["REDUCE"]:
                        self.stack.pop(sent_idx)
                    elif action in action_id["SHIFT"]:
                        buffer_top = self.buffer.pop(sent_idx)
                        self.stack.push(sent_idx,
                                        input=buffer_top['stack_rnn_input'],
                                        extra={'token': buffer_top['token']})
                    elif action in action_id["SWAP"]:
                        # Move the second stack element back to the buffer.
                        stack_penult = self.stack.pop_penult(sent_idx)
                        self.buffer.push(sent_idx,
                                         input=stack_penult['stack_rnn_input'],
                                         extra={'token': stack_penult['token']})
                    elif action in action_id["FINISH"]:
                        action_tag_for_terminate[sent_idx] = True
                        ratio_factor_losses[sent_idx] = ratio_factor

                    action_sequence_length[sent_idx] += 1

        # Categorical cross-entropy, averaged over all modeled decisions.
        non_empty_losses = [torch.sum(torch.stack(cur_loss)) for cur_loss in losses if len(cur_loss) > 0]
        _loss_CCE = -torch.sum(torch.stack(non_empty_losses)) / sum([len(cur_loss) for cur_loss in losses]) \
            if non_empty_losses else torch.Tensor([float('NaN')])
        _loss = _loss_CCE
        ret = {
            'loss': _loss,
            'losses': losses,
        }

        # Extract the concept node list in batch mode.
        for sent_idx in range(batch_size):
            ret_concept_node[sent_idx] = concept_node[sent_idx]
        ret["total_node_num"] = total_node_num

        if oracle_actions is None:
            # Prediction-only outputs.
            ret['edge_list'] = edge_list
            ret['action_sequence'] = action_list
            ret['top_node'] = ret_top_node
            ret["concept_node"] = ret_concept_node
        return ret

    # Returns an expression of the loss for the sequence of actions
    # (that is, the oracle_actions if present, or the predicted sequence otherwise).
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                metadata: List[Dict[str, Any]],
                gold_actions: Dict[str, torch.LongTensor] = None,
                lemmas: Dict[str, torch.LongTensor] = None,
                pos_tags: torch.LongTensor = None,
                arc_tags: torch.LongTensor = None,
                deprels: torch.LongTensor = None,
                bios: torch.LongTensor = None,
                lexcat: torch.LongTensor = None,
                ss: torch.LongTensor = None,
                ss2: torch.LongTensor = None,
                ) -> Dict[str, torch.LongTensor]:
        """Embed the batch, then run ``_greedy_decode``.

        Training mode returns only ``{'loss': ...}``; otherwise a full
        prediction dict is returned and, when gold actions are available, the
        MCES metric is updated with MRP-converted predictions.
        """
        batch_size = len(metadata)
        sent_len = [len(d['tokens']) for d in metadata]
        meta_info = [d['meta_info'] for d in metadata]

        oracle_actions = None
        if gold_actions is not None:
            oracle_actions = [d['gold_actions'] for d in metadata]
            oracle_actions = [[self.vocab.get_token_index(s, namespace='actions') for s in l]
                              for l in oracle_actions]

        # Concatenate whichever feature embedders were configured and whose
        # inputs are present in this batch. NOTE(review): lemmas /
        # lemma_text_field_embedder are accepted but never embedded here --
        # presumably intentional, verify against the config.
        embeds = [embedder(field) for field, embedder in ((tokens, self.text_field_embedder),
                                                          (pos_tags, self._pos_tag_embedding),
                                                          (deprels, self._deprel_embedding),
                                                          (bios, self._bios_embedding),
                                                          (lexcat, self._lexcat_embedding),
                                                          (ss, self._ss_embedding),
                                                          (ss2, self._ss2_embedding))
                  if field is not None and embedder is not None]
        embedded_text_input = torch.cat(embeds, -1) if len(embeds) > 1 else embeds[0]
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.training:
            ret_train = self._greedy_decode(batch_size=batch_size,
                                            sent_len=sent_len,
                                            embedded_text_input=embedded_text_input,
                                            oracle_actions=oracle_actions)
            _loss = ret_train['loss']
            output_dict = {'loss': _loss}
            return output_dict

        # Evaluation/prediction: decode without the oracle and without grad,
        # temporarily forcing eval mode.
        training_mode = self.training
        self.eval()
        with torch.no_grad():
            ret_eval = self._greedy_decode(batch_size=batch_size,
                                           sent_len=sent_len,
                                           embedded_text_input=embedded_text_input)
        self.train(training_mode)

        edge_list = ret_eval['edge_list']
        top_node_list = ret_eval['top_node']
        _loss = ret_eval['loss']

        # prediction-mode
        output_dict = {
            'tokens': [d['tokens'] for d in metadata],
            'loss': _loss,
            'edge_list': edge_list,
            'meta_info': meta_info,
            'top_node': top_node_list,
            'concept_node': ret_eval['concept_node'],
            'tokens_range': [d['tokens_range'] for d in metadata]
        }

        # Validation with gold: convert predictions to MRP graphs and score.
        if gold_actions is not None:
            gold_mrps = [x["gold_mrps"] for x in metadata]
            predicted_mrps = []

            for sent_idx in range(batch_size):
                # Skip degenerate parses with implausibly many edges.
                if len(output_dict['edge_list'][sent_idx]) <= 5 * len(output_dict['tokens'][sent_idx]):
                    predicted_mrps.append(ucca_trans_outputs_into_mrp({
                        'tokens': output_dict['tokens'][sent_idx],
                        'edge_list': output_dict['edge_list'][sent_idx],
                        'meta_info': output_dict['meta_info'][sent_idx],
                        'top_node': output_dict['top_node'][sent_idx],
                        'concept_node': output_dict['concept_node'][sent_idx],
                        'tokens_range': output_dict['tokens_range'][sent_idx],
                    }))

            self._mces_metric(predicted_mrps, gold_mrps)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Expose the MCES metric values (evaluation mode only)."""
        all_metrics: Dict[str, float] = {}
        if self._mces_metric is not None and not self.training:
            all_metrics.update(self._mces_metric.get_metric(reset=reset))
        return all_metrics
class TransitionParser(Model): def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, word_dim: int, hidden_dim: int, action_dim: int, ratio_dim: int, num_layers: int, recurrent_dropout_probability: float = 0.0, layer_dropout_probability: float = 0.0, same_dropout_mask_per_instance: bool = True, input_dropout: float = 0.0, output_null_nodes: bool = True, max_heads: int = None, max_swaps_per_node: int = 3, fix_unconnected_egraph: bool = True, validate_every_n_instances: int = None, action_embedding: Embedding = None, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None ) -> None: super(TransitionParser, self).__init__(vocab, regularizer) self._total_validation_instances = 0 self.num_actions = vocab.get_vocab_size('actions') self.text_field_embedder = text_field_embedder self.output_null_nodes = output_null_nodes self.max_heads = max_heads self.max_swaps_per_node = max_swaps_per_node self._fix_unconnected_egraph = fix_unconnected_egraph self.num_validation_instances = validate_every_n_instances self._xud_score = XUDScore(collapse=self.output_null_nodes) self.word_dim = word_dim self.hidden_dim = hidden_dim self.ratio_dim = ratio_dim self.action_dim = action_dim self.action_embedding = action_embedding if action_embedding is None: self.action_embedding = Embedding(num_embeddings=self.num_actions, embedding_dim=action_dim, trainable=False) # syntactic composition self.p_comp = torch.nn.Linear(self.hidden_dim * 5 + self.ratio_dim, self.word_dim) # parser state to hidden self.p_s2h = torch.nn.Linear(self.hidden_dim * 3 + self.ratio_dim, self.hidden_dim) # hidden to action self.p_act = torch.nn.Linear(self.hidden_dim + self.ratio_dim, self.num_actions) self.update_null_node = torch.nn.Linear(self.hidden_dim + self.ratio_dim, self.word_dim) self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(self.hidden_dim)) self.proot_stack_emb = torch.nn.Parameter(torch.randn(self.word_dim)) 
self.pempty_action_emb = torch.nn.Parameter(torch.randn(self.hidden_dim)) self.pempty_stack_emb = torch.nn.Parameter(torch.randn(self.hidden_dim)) self._input_dropout = Dropout(input_dropout) self.buffer = StackRnn(input_size=self.word_dim, hidden_size=self.hidden_dim, num_layers=num_layers, recurrent_dropout_probability=recurrent_dropout_probability, layer_dropout_probability=layer_dropout_probability, same_dropout_mask_per_instance=same_dropout_mask_per_instance) self.stack = StackRnn(input_size=self.word_dim, hidden_size=self.hidden_dim, num_layers=num_layers, recurrent_dropout_probability=recurrent_dropout_probability, layer_dropout_probability=layer_dropout_probability, same_dropout_mask_per_instance=same_dropout_mask_per_instance) self.action_stack = StackRnn(input_size=self.action_dim, hidden_size=self.hidden_dim, num_layers=num_layers, recurrent_dropout_probability=recurrent_dropout_probability, layer_dropout_probability=layer_dropout_probability, same_dropout_mask_per_instance=same_dropout_mask_per_instance) initializer(self) def _greedy_decode(self, batch_size: int, sent_len: List[int], embedded_text_input: torch.Tensor, oracle_actions: Optional[List[List[str]]] = None, ) -> Dict[str, Any]: self.buffer.reset_stack(batch_size) self.stack.reset_stack(batch_size) self.action_stack.reset_stack(batch_size) # We will keep track of all the losses we accumulate during parsing. # If some decision is unambiguous because it's the only thing valid given # the parser state, we will not model it. We only model what is ambiguous. 
losses = [[] for _ in range(batch_size)] ratio_factor_losses = [[] for _ in range(batch_size)] edge_list = [[] for _ in range(batch_size)] total_node_num = [0 for _ in range(batch_size)] action_list = [[] for _ in range(batch_size)] ret_node = [[] for _ in range(batch_size)] root_id = [[] for _ in range(batch_size)] num_of_generated_node= [[] for _ in range(batch_size)] generated_order = [{} for _ in range(batch_size)] head_count = [{} for _ in range(batch_size)] # keep track of the number of heads swap_count = [{} for _ in range(batch_size)] # keep track of the number of swaps reachable = [{} for _ in range(batch_size)] # set of node ids reachable from each node (by id) # push the tokens onto the buffer (tokens is in reverse order) for token_idx in range(max(sent_len)): for sent_idx in range(batch_size): if sent_len[sent_idx] > token_idx: try: self.buffer.push(sent_idx, input=embedded_text_input[sent_idx][sent_len[sent_idx] - 1 - token_idx], extra={'token': sent_len[sent_idx] - token_idx - 1}) except IndexError: raise IndexError(f"{sent_idx} {batch_size} {token_idx} {sent_len[sent_idx]} {embedded_text_input[sent_idx].dim()}") # init stack using proot_emb, considering batch for sent_idx in range(batch_size): root_id[sent_idx] = sent_len[sent_idx] reachable[sent_idx][root_id[sent_idx]] = {root_id[sent_idx]} # only root is reachable from root so far generated_order[sent_idx][root_id[sent_idx]] = 0 self.stack.push(sent_idx, input=self.proot_stack_emb, extra={'token': root_id[sent_idx]}) # compute probability of each of the actions and choose an action # either from the oracle or if there is no oracle, based on the model trans_not_fin = True action_tag_for_terminate = [False] * batch_size action_sequence_length = [0] * batch_size null_node = {} for sent_idx in range(batch_size): null_node[sent_idx] = [] while trans_not_fin: trans_not_fin = False for sent_idx in range(batch_size): if (action_sequence_length[sent_idx] > 10000 * sent_len[sent_idx]) and oracle_actions is 
None: raise RuntimeError(f"Too many actions for a sentence {sent_idx}. actions: {action_list}") total_node_num[sent_idx] = sent_len[sent_idx] + len(null_node[sent_idx]) if not action_tag_for_terminate[sent_idx]: trans_not_fin = True valid_actions = self.calc_valid_actions(edge_list, generated_order, head_count, swap_count, null_node, root_id, sent_idx, sent_len) log_probs = None action = valid_actions[0] action_idx = self.vocab.get_token_index(action, namespace='actions') ratio_factor = torch.tensor([len(null_node[sent_idx]) / (1.0 * sent_len[sent_idx])], device=self.pempty_action_emb.device) if len(valid_actions) > 1: action, action_idx, log_probs, valid_action_tbl = self.predict_action(ratio_factor, sent_idx, valid_actions) if oracle_actions is not None: action = oracle_actions[sent_idx].pop(0) action_idx = self.vocab.get_token_index(action, namespace='actions') # push action into action_stack self.action_stack.push(sent_idx, input=self.action_embedding( torch.tensor(action_idx, device=embedded_text_input.device)), extra={ 'token': action}) action_list[sent_idx].append(action) #print(f'Sent ID: {sent_idx}, action {action}') try: UNK_ID = self.vocab.get_token_index('@@UNKNOWN@@') if log_probs is not None and not (UNK_ID and action_idx == UNK_ID): loss = log_probs[valid_action_tbl[action_idx]] if not torch.isnan(loss): losses[sent_idx].append(loss) except KeyError: raise KeyError(f'action: {action}, valid actions: {valid_actions}') self.exec_action(action, action_sequence_length, action_tag_for_terminate, edge_list, generated_order, head_count, swap_count, null_node, num_of_generated_node, oracle_actions, ratio_factor, ratio_factor_losses, sent_idx, sent_len, total_node_num, reachable) if oracle_actions is None and self._fix_unconnected_egraph: # Fix edge_list if we are not training. 
If training, it's empty (no output) for sent_idx in range(batch_size): self.fix_unconnected_egraph(edge_list[sent_idx], reachable[sent_idx], root_id[sent_idx], generated_order[sent_idx]) # categorical cross-entropy _loss_CCE = -torch.sum( torch.stack([torch.sum(torch.stack(cur_loss)) for cur_loss in losses if len(cur_loss) > 0])) / \ sum([len(cur_loss) for cur_loss in losses]) _loss = _loss_CCE ret = { 'loss': _loss, 'losses': losses, } # extract null node list in batchmode for sent_idx in range(batch_size): ret_node[sent_idx] = null_node[sent_idx] ret["total_node_num"] = total_node_num if oracle_actions is None: ret['edge_list'] = edge_list ret['action_sequence'] = action_list ret["null_node"] = ret_node return ret def calc_valid_actions(self, edge_list, generated_order, head_count, swap_count, null_node, root_id, sent_idx, sent_len): valid_actions = [] # given the buffer and stack, conclude the valid action list if self.buffer.get_len(sent_idx) == 0 and self.stack.get_len(sent_idx) == 1: valid_actions += ['FINISH'] if self.buffer.get_len(sent_idx) > 0: valid_actions += ['SHIFT'] if self.stack.get_len(sent_idx) > 0: s0 = self.stack.get_stack(sent_idx)[-1]['token'] if s0 != root_id[sent_idx] and head_count[sent_idx][s0] > 0: valid_actions += ['REDUCE-0'] if self.output_null_nodes and \ len(null_node[sent_idx]) < sent_len[sent_idx]: # Max number of null nodes is the number of words # valid_actions += ['NODE'] # Support legacy models with "NODE:*" actions that also create an edge: node_possible_actions = [a for a in self.vocab.get_token_to_index_vocabulary('actions').keys() if a.startswith('NODE') and (":" not in a or s0 == root_id[sent_idx] or a.split('NODE:')[1] != "root")] # if len(node_possible_actions) > 1: # logger.warning(f"Possible node actions: {node_possible_actions}. 
" # f"NODE:* actions are deprecated - train a new model!") valid_actions += node_possible_actions if self.stack.get_len(sent_idx) > 1: s1 = self.stack.get_stack(sent_idx)[-2]['token'] if s1 != root_id[sent_idx] and \ generated_order[sent_idx][s0] > generated_order[sent_idx][s1] and \ swap_count[sent_idx][s1] < self.max_swaps_per_node: valid_actions += ['SWAP'] if s1 != root_id[sent_idx] and head_count[sent_idx][s1] > 0: valid_actions += ['REDUCE-1'] # Hacky code to verify that we do not draw the same edge with the same label twice labels_left_edge = ['root'] # should not be in vocab anyway actually but be safe labels_right_edge = [] for mod_tok, head_tok, label in edge_list[sent_idx]: if (mod_tok, head_tok) == (s1, s0): labels_left_edge.append(label) if (mod_tok, head_tok) == (s0, s1): labels_right_edge.append(label) if s1 != root_id[sent_idx] and (not self.max_heads or head_count[sent_idx][s1] < self.max_heads): left_edge_possible_actions = \ [a for a in self.vocab.get_token_to_index_vocabulary('actions').keys() if a.startswith('LEFT-EDGE') and a.split('LEFT-EDGE:')[1] not in labels_left_edge] valid_actions += left_edge_possible_actions if not self.max_heads or head_count[sent_idx][s0] < self.max_heads: if s1 == root_id[sent_idx]: right_edge_possible_actions = [] if "root" in labels_right_edge else ['RIGHT-EDGE:root'] else: # hack to disable root labels_right_edge += ['root'] right_edge_possible_actions = \ [a for a in self.vocab.get_token_to_index_vocabulary('actions').keys() if a.startswith('RIGHT-EDGE') and a.split('RIGHT-EDGE:')[1] not in labels_right_edge] valid_actions += right_edge_possible_actions # remove unknown actions: vocab_actions = self.vocab.get_token_to_index_vocabulary('actions').keys() valid_actions = [valid_action for valid_action in valid_actions if valid_action in vocab_actions] if not valid_actions: valid_actions = ["FINISH"] return valid_actions def predict_action(self, ratio_factor, sent_idx, valid_actions): stack_emb = 
def exec_action(self, action, action_sequence_length, action_tag_for_terminate, edge_list, generated_order,
                head_count, swap_count, null_node, num_of_generated_node, oracle_actions, ratio_factor,
                ratio_factor_losses, sent_idx, sent_len, total_node_num, reachable: List[Dict[int, Set[int]]]):
    """Apply one transition `action` to the parser state of sentence `sent_idx`.

    Mutates the per-sentence bookkeeping structures in place (all of them are
    indexed by `sent_idx`): `edge_list`, `generated_order`, `head_count`,
    `swap_count`, `null_node`, `num_of_generated_node`, `reachable`,
    `total_node_num`, `action_tag_for_terminate`, `ratio_factor_losses` and
    `action_sequence_length`, as well as the stack/buffer StackRNNs on `self`.

    `reachable[sent_idx]` maps a token id to the set of token ids reachable
    from it via predicted edges (each token starts reachable only from itself).
    NOTE(review): `ratio_factor` appears to be a per-sentence feature tensor
    concatenated into every composed representation — confirm its shape
    against the caller.
    """
    # "NODE" actions create a new null (empty) node, numbered after all
    # surface tokens and previously created null nodes.
    if action.startswith("NODE"):
        null_node_token = len(null_node[sent_idx]) + sent_len[sent_idx] + 1
        null_node[sent_idx].append(null_node_token)
        head_count[sent_idx][null_node_token] = swap_count[sent_idx][null_node_token] = 0
        # A fresh node can reach only itself until edges are added.
        reachable[sent_idx][null_node_token] = {null_node_token}
        stack_emb = self.stack.get_output(sent_idx)
        stack_emb = torch.cat([stack_emb, ratio_factor])
        comp_rep = torch.tanh(self.update_null_node(stack_emb))
        node_input = comp_rep
        # The new null node goes onto the buffer so it can later be SHIFTed.
        self.buffer.push(sent_idx,
                         input=node_input,
                         extra={'token': null_node_token})
        total_node_num[sent_idx] = sent_len[sent_idx] + len(null_node[sent_idx])
    # Support legacy models with "NODE:*" actions that also create an edge:
    # "NODE:<label>" falls through here and behaves like an edge action too.
    if action.startswith("NODE:") or action.startswith("LEFT-EDGE") or action.startswith("RIGHT-EDGE"):
        if action.startswith("NODE:"):
            logger.warning(f"Took action {action}")
            # Legacy combined action: the freshly pushed null node (buffer
            # top) is the modifier, stack top is the head.
            modifier = self.buffer.get_stack(sent_idx)[-1]
            head = self.stack.get_stack(sent_idx)[-1]
        elif action.startswith("LEFT-EDGE"):
            # Arc from stack top (head) to the element below it (modifier).
            head = self.stack.get_stack(sent_idx)[-1]
            modifier = self.stack.get_stack(sent_idx)[-2]
        else:
            # RIGHT-EDGE: arc from second stack element (head) to stack top.
            head = self.stack.get_stack(sent_idx)[-2]
            modifier = self.stack.get_stack(sent_idx)[-1]
        (head_rep, head_tok) = (head['stack_rnn_output'], head['token'])
        (mod_rep, mod_tok) = (modifier['stack_rnn_output'], modifier['token'])
        # Record the predicted edge only at inference time; during oracle
        # training the gold edges are already known.
        if oracle_actions is None:
            edge_list[sent_idx].append((mod_tok,
                                        head_tok,
                                        action.split(':', maxsplit=1)[1]))
        # Propagate reachability: everything that could reach the head can
        # now also reach everything reachable from the modifier.
        reachable_from_mod = reachable[sent_idx][mod_tok]
        for tok, reachable_for_tok_set in reachable[sent_idx].items():
            if head_tok in reachable_for_tok_set:
                reachable_for_tok_set |= reachable_from_mod
        # Empty-structure placeholder embeddings when stacks are empty.
        action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
            else self.action_stack.get_output(sent_idx)
        stack_emb = self.pempty_stack_emb if self.stack.get_len(sent_idx) == 0 \
            else self.stack.get_output(sent_idx)
        buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
            else self.buffer.get_output(sent_idx)
        # Compose a new representation for the head from head/modifier and
        # the current parser-state summaries.
        comp_rep = torch.cat([head_rep, mod_rep, action_emb, buffer_emb, stack_emb, ratio_factor])
        comp_rep = torch.tanh(self.p_comp(comp_rep))
        if action.startswith("NODE:"):
            # Replace buffer top (the null node) with its composed form.
            self.buffer.pop(sent_idx)
            self.buffer.push(sent_idx,
                             input=comp_rep,
                             extra={'token': mod_tok})
        elif action.startswith("LEFT-EDGE"):
            # Replace stack top (the head) with its composed form.
            self.stack.pop(sent_idx)
            self.stack.push(sent_idx,
                            input=comp_rep,
                            extra={'token': head_tok})
        # RIGHT-EDGE: recompose the head (second element) and re-push the
        # modifier on top with its original input representation.
        else:
            stack_0_rep = self.stack.get_stack(sent_idx)[-1]['stack_rnn_input']
            self.stack.pop(sent_idx)
            self.stack.pop(sent_idx)
            self.stack.push(sent_idx,
                            input=comp_rep,
                            extra={'token': head_tok})
            self.stack.push(sent_idx,
                            input=stack_0_rep,
                            extra={'token': mod_tok})
        head_count[sent_idx][mod_tok] += 1
    # Execute the remaining (non-edge) actions to update the parser state.
    elif action == "REDUCE-0":
        # Discard stack top.
        self.stack.pop(sent_idx)
    elif action == "REDUCE-1":
        # Discard the element below the stack top, keeping the top.
        stack_0 = self.stack.pop(sent_idx)
        self.stack.pop(sent_idx)
        self.stack.push(sent_idx,
                        input=stack_0['stack_rnn_input'],
                        extra={'token': stack_0['token']})
    elif action == "SHIFT":
        # Move buffer front to the stack; register the token the first time
        # it is shifted.
        buffer = self.buffer.pop(sent_idx)
        self.stack.push(sent_idx,
                        input=buffer['stack_rnn_input'],
                        extra={'token': buffer['token']})
        s0 = self.stack.get_stack(sent_idx)[-1]['token']
        if s0 not in generated_order[sent_idx]:
            num_of_generated_node[sent_idx] = len(generated_order[sent_idx])
            generated_order[sent_idx][s0] = num_of_generated_node[sent_idx]
            head_count[sent_idx][s0] = swap_count[sent_idx][s0] = 0
            reachable[sent_idx][s0] = {s0}
    elif action == "SWAP":
        # Move the second stack element back to the buffer (reordering).
        stack_penult = self.stack.pop_penult(sent_idx)
        self.buffer.push(sent_idx,
                         input=stack_penult['stack_rnn_input'],
                         extra={'token': stack_penult['token']})
        swap_count[sent_idx][stack_penult['token']] += 1
    elif action == "FINISH":
        action_tag_for_terminate[sent_idx] = True
        ratio_factor_losses[sent_idx] = ratio_factor
    action_sequence_length[sent_idx] += 1
def fix_unconnected_egraph(self, edge_list, reachable, root_id, generated_order):
    """Attach nodes unreachable from the root so the graph becomes connected.

    Postprocessing step that avoids L2 "unconnected-egraph" errors from the
    UD validator (validate.py --level 2). Because REDUCE-{0,1} requires a
    head, unreachable nodes can only arise from cycles; every such node is
    attached under the root's dependent (or under the root itself when no
    dependent exists yet). `edge_list` is extended in place.
    """
    # The root's existing dependent, if any, becomes the attachment point.
    attach_point = next((mod for mod, head, _ in edge_list if head == root_id), None)
    # Generated nodes that cannot be reached from the root.
    unreachable = set(generated_order) - reachable[root_id]
    while unreachable:
        # Prefer the node with the most descendants — a "root" of its cycle.
        candidate = max(unreachable, key=lambda t: len(reachable[t]))
        if attach_point is None:
            # No predicate (dependent of root) found yet: promote this node.
            edge_list.append((candidate, root_id, "root"))
            attach_point = candidate
        else:
            edge_list.append((candidate, attach_point, "orphan"))
        # Everything reachable from the candidate is now connected too.
        unreachable -= reachable[candidate]
def forward(self,
            words: Dict[str, torch.LongTensor],
            metadata: List[Dict[str, Any]],
            gold_actions: Dict[str, torch.LongTensor] = None,
            ) -> Dict[str, torch.LongTensor]:
    """Compute the training loss, or decode and assemble predictions.

    Parameters
    ----------
    words : indexed text fields for the batch, fed to the text field embedder.
    metadata : one dict per sentence; must contain 'words' and (in evaluation
        mode) 'annotation', 'multiwords', 'sent_id' and 'text'. When
        `gold_actions` is given, each entry must also contain 'gold_actions'.
    gold_actions : oracle action tensors; presence switches on oracle-driven
        training and, in evaluation mode, updates the XUD metric.

    Returns
    -------
    In training mode: {'loss': ...}. In evaluation mode: 'edge_list',
    'null_node', 'loss', the per-token CoNLL-U columns and the sentence-level
    'multiwords'/'sent_id'/'text' lists.
    """
    batch_size = len(metadata)
    sent_len = [len(d['words']) for d in metadata]
    # Bug fix: always bind oracle_actions — previously, training with no
    # gold_actions raised NameError at the _greedy_decode call below.
    oracle_actions = None
    if gold_actions is not None:
        oracle_actions = deepcopy([d['gold_actions'] for d in metadata])

    embedded_text_input = self.text_field_embedder(words)
    embedded_text_input = self._input_dropout(embedded_text_input)

    if self.training:
        try:
            ret_train = self._greedy_decode(batch_size=batch_size,
                                            sent_len=sent_len,
                                            embedded_text_input=embedded_text_input,
                                            oracle_actions=oracle_actions)
        except IndexError as err:
            # Surface the offending sentences for debugging, chaining the
            # original traceback instead of discarding it.
            raise IndexError(f"{[d['words'] for d in metadata]}") from err
        return {'loss': ret_train['loss']}

    # Evaluation / prediction mode.
    # Track how many validation instances have been seen so get_metrics()
    # knows when a full validation pass completed; restart the counter at the
    # beginning of each pass.
    if self.num_validation_instances and (not self._total_validation_instances
                                          or self._total_validation_instances == self.num_validation_instances):
        self._total_validation_instances = batch_size
    else:
        self._total_validation_instances += batch_size

    training_mode = self.training  # False on this branch; restored below
    self.eval()
    with torch.no_grad():
        ret_eval = self._greedy_decode(batch_size=batch_size,
                                       sent_len=sent_len,
                                       embedded_text_input=embedded_text_input)
    self.train(training_mode)

    edge_list = ret_eval['edge_list']
    null_node = ret_eval['null_node']
    _loss = ret_eval['loss']

    output_dict = {
        'edge_list': edge_list,
        'null_node': null_node,
        'loss': _loss
    }

    # Per-token CoNLL-U columns, one list per sentence.
    for k in ["id", "form", "lemma", "upostag", "xpostag", "feats", "head", "deprel", "misc"]:
        output_dict[k] = [[token_metadata[k] for token_metadata in sentence_metadata['annotation']]
                          for sentence_metadata in metadata]
    # Sentence-level metadata.
    for k in ["multiwords", "sent_id", "text"]:
        output_dict[k] = [sentence_metadata[k] for sentence_metadata in metadata]

    # Validation mode: when gold actions exist, score predictions against gold.
    if gold_actions is not None:
        predicted_graphs = []
        gold_graphs_conllu = []
        for sent_idx in range(batch_size):
            predicted_graphs.append(eud_trans_outputs_into_conllu({
                k: output_dict[k][sent_idx] for k in
                ["id", "form", "lemma", "upostag", "xpostag", "feats", "head", "deprel", "misc",
                 "edge_list", "null_node", "multiwords", "text", "sent_id"]
            }, self.output_null_nodes))
            # Sentence-level comments first (only non-empty ones), then the
            # gold token annotations.
            gold_annotation = [{key: output_dict[key][sent_idx]
                                for key in ("sent_id", "text") if output_dict[key][sent_idx]}]
            gold_annotation.extend(metadata[sent_idx]['annotation'])
            gold_graphs_conllu.append(annotation_to_conllu(gold_annotation, self.output_null_nodes))
        # Flatten per-sentence line lists into single CoNLL-U streams.
        predicted_graphs_conllu = [line for lines in predicted_graphs for line in lines]
        gold_graphs_conllu = [line for lines in gold_graphs_conllu for line in lines]
        self._xud_score(predicted_graphs_conllu, gold_graphs_conllu)
    return output_dict
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    """Return XUD evaluation metrics, or an empty dict when not applicable.

    Metrics are only reported outside training and when an XUD scorer is
    configured. When a fixed number of validation instances is known, metrics
    are emitted only once the full pass is complete; otherwise they are
    emitted every call with the scorer reset each time.
    """
    metrics: Dict[str, float] = {}
    # Guard: no scorer configured, or we are mid-training.
    if self._xud_score is None or self.training:
        return metrics
    if self.num_validation_instances:
        # Report only at the end of a complete validation pass.
        if self._total_validation_instances == self.num_validation_instances:
            metrics.update(self._xud_score.get_metric(reset=reset))
    else:
        # Unknown pass length: always report, resetting the scorer.
        metrics.update(self._xud_score.get_metric(reset=True))
    return metrics