class TransitionParser(Model):
    """Greedy transition-based graph parser with concept-node generation.

    The parser state consists of a buffer, a stack and an action history,
    each encoded by a ``StackRnn``.  At every step the set of currently
    valid actions is derived from the buffer/stack sizes and, when more than
    one action is valid, scored by a small feed-forward network.  The action
    inventory (matched by prefix in the ``'actions'`` vocabulary namespace)
    is SHIFT, REDUCE, NODE, REMOTE-NODE, LEFT-EDGE, RIGHT-EDGE, LEFT-REMOTE,
    RIGHT-REMOTE, SWAP and FINISH.

    A one-element "ratio factor" (number of concept nodes divided by sentence
    length) is concatenated to several layer inputs.
    NOTE(review): because that tensor always has one element, the linear
    layer sizes presumably only line up when ``ratio_dim == 1`` -- confirm
    against the training configuration.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 hidden_dim: int,
                 action_dim: int,
                 ratio_dim: int,
                 num_layers: int,
                 word_dim: int = 0,
                 text_field_embedder: TextFieldEmbedder = None,
                 mces_metric: Metric = None,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 lemma_text_field_embedder: TextFieldEmbedder = None,
                 pos_tag_embedding: Embedding = None,
                 deprel_embedding: Embedding = None,
                 bios_embedding: Embedding = None,
                 lexcat_embedding: Embedding = None,
                 ss_embedding: Embedding = None,
                 ss2_embedding: Embedding = None,
                 action_embedding: Embedding = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None
                 ) -> None:
        """Build the embeddings, scoring layers and stack RNNs.

        The per-token input size (``node_dim``) is ``word_dim`` (when a text
        field embedder is given) plus the ``output_dim`` of every optional
        feature embedding that is supplied.  When ``action_embedding`` is
        omitted, a non-trainable embedding of size ``action_dim`` is created.
        ``initializer`` is applied to the fully constructed model.
        """
        super(TransitionParser, self).__init__(vocab, regularizer)

        # Bookkeeping counters; they are only initialised here.
        self._primary_labeled_correct = 0
        self._primary_unlabeled_correct = 0
        self._primary_total_edges_predicted = 0
        self._primary_total_edges_actual = 0
        self._primary_exact_labeled_correct = 0
        self._primary_exact_unlabeled_correct = 0

        self._remote_labeled_correct = 0
        self._remote_unlabeled_correct = 0
        self._remote_total_edges_predicted = 0
        self._remote_total_edges_actual = 0
        self._remote_exact_labeled_correct = 0
        self._remote_exact_unlabeled_correct = 0

        self._total_sentences = 0

        self.num_actions = vocab.get_vocab_size('actions')
        self.text_field_embedder = text_field_embedder
        self.lemma_text_field_embedder = lemma_text_field_embedder
        self._pos_tag_embedding = pos_tag_embedding
        self._deprel_embedding = deprel_embedding
        self._bios_embedding = bios_embedding
        self._lexcat_embedding = lexcat_embedding
        self._ss_embedding = ss_embedding
        self._ss2_embedding = ss2_embedding
        self._mces_metric = mces_metric

        # Use the same "is not None" test as forward() so that the size
        # computed here always matches what forward() concatenates.
        node_dim = 0
        if self.text_field_embedder is not None:
            node_dim += word_dim
        for embedding in (pos_tag_embedding, deprel_embedding, bios_embedding,
                          lexcat_embedding, ss_embedding, ss2_embedding):
            if embedding is not None:
                node_dim += embedding.output_dim

        self.node_dim = node_dim
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.ratio_dim = ratio_dim
        self.action_dim = action_dim

        self.action_embedding = action_embedding
        if action_embedding is None:
            self.action_embedding = Embedding(num_embeddings=self.num_actions,
                                              embedding_dim=self.action_dim,
                                              trainable=False)

        # syntactic composition: head, modifier, action, buffer, stack (+ ratio)
        self.p_comp = torch.nn.Linear(self.hidden_dim * 5 + self.ratio_dim, node_dim)
        # parser state to hidden: buffer, stack, action (+ ratio)
        self.p_s2h = torch.nn.Linear(self.hidden_dim * 3 + self.ratio_dim, self.hidden_dim)
        # hidden to action
        self.p_act = torch.nn.Linear(self.hidden_dim + self.ratio_dim, self.num_actions)
        # stack summary (+ ratio) to the input vector of a new concept node
        self.update_concept_node = torch.nn.Linear(self.hidden_dim + self.ratio_dim, node_dim)

        # Learned placeholders for empty structures and for the root node.
        self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(node_dim))
        self.pempty_action_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))
        self.pempty_stack_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))

        self._input_dropout = Dropout(input_dropout)

        self.buffer = StackRnn(input_size=node_dim,
                               hidden_size=self.hidden_dim,
                               num_layers=num_layers,
                               recurrent_dropout_probability=recurrent_dropout_probability,
                               layer_dropout_probability=layer_dropout_probability,
                               same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.stack = StackRnn(input_size=node_dim,
                              hidden_size=self.hidden_dim,
                              num_layers=num_layers,
                              recurrent_dropout_probability=recurrent_dropout_probability,
                              layer_dropout_probability=layer_dropout_probability,
                              same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.action_stack = StackRnn(input_size=self.action_dim,
                                     hidden_size=self.hidden_dim,
                                     num_layers=num_layers,
                                     recurrent_dropout_probability=recurrent_dropout_probability,
                                     layer_dropout_probability=layer_dropout_probability,
                                     same_dropout_mask_per_instance=same_dropout_mask_per_instance)
        initializer(self)

    def expand_arc_with_descendants(self, arc_indices, total_node_num, len_tokens):
        """Rewrite arc endpoints that are concept nodes as their token yield.

        ``arc_indices`` is an iterable of ``(child, head, label)`` triples over
        node ids ``0 .. total_node_num - 1``.  Ids below ``len_tokens`` are
        kept as they are; any other endpoint is replaced by a string made of
        the sorted token ids it (transitively) dominates, joined with ``'-'``.

        Returns a list of ``(child, head, label)`` triples in input order.
        """
        # Step 1: build the graph.
        # key: node id; "head_list": [(head id, label), ...];
        # "in_degree": number of arcs on which the node is the head.
        graph = {node: {"in_degree": 0, "head_list": []} for node in range(total_node_num)}
        for arc in arc_indices:
            graph[arc[0]]["head_list"].append((arc[1], arc[2]))
            graph[arc[1]]["in_degree"] += 1

        # head -> children.  The second copy is consumed while ordering.
        top_down_graph = [[] for _ in range(total_node_num)]
        pending_children = [[] for _ in range(total_node_num)]

        topological_stack = []
        for node in range(total_node_num):
            if graph[node]["in_degree"] == 0:
                topological_stack.append(node)
            for head_tuple in graph[node]["head_list"]:
                head = head_tuple[0]
                top_down_graph[head].append(node)
                pending_children[head].append(node)

        # Step 2: topological order, children before their heads.
        topological_order = []
        while topological_stack:
            current = topological_stack.pop()
            topological_order.append(current)

            for head in graph:
                if current in pending_children[head]:
                    pending_children[head].remove(current)
                    graph[head]["in_degree"] -= 1
                    if graph[head]["in_degree"] == 0 and \
                            head not in topological_stack and \
                            head not in topological_order:
                        topological_stack.append(head)

        # Step 3: accumulate descendants following the topological order.
        expand_node_dict = {node: top_down_graph[node][:] for node in range(total_node_num)}
        for node in topological_order:
            if not expand_node_dict[node]:  # no children
                continue
            expanded = expand_node_dict[node][:]
            for child in expand_node_dict[node]:
                expanded += expand_node_dict[child]
            expand_node_dict[node] = expanded

        # Step 4: drop duplicates and everything that is not a token.
        token_filter = set(range(len_tokens))
        for node in expand_node_dict:
            expand_node_dict[node] = set(expand_node_dict[node]) & token_filter

        # Step 5: rewrite the arcs.
        def _endpoint(node):
            if node < len_tokens:
                return node
            return '-'.join([str(i) for i in sorted(expand_node_dict[node])])

        return [(_endpoint(arc_info[0]), _endpoint(arc_info[1]), arc_info[2])
                for arc_info in arc_indices]

    def _greedy_decode(self,
                       batch_size: int,
                       sent_len: List[int],
                       embedded_text_input: torch.Tensor,
                       oracle_actions: Optional[List[List[int]]] = None,
                       ) -> Dict[str, Any]:
        """Run the transition system over a batch, one action per sentence per round.

        With ``oracle_actions`` the given action ids are followed (each list is
        consumed from the front) and the log-probabilities of those actions are
        collected; without them the highest-scoring valid action is taken and
        the predicted edges are recorded.

        Returns a dict with ``'loss'`` (mean negative log-likelihood over all
        scored decisions, or a NaN tensor when nothing was scored),
        ``'losses'``, ``'total_node_num'``, ``'action_sequence'``,
        ``'top_node'``, ``'concept_node'`` and -- only when no oracle is given --
        ``'edge_list'`` holding ``(modifier, head, label)`` triples.
        """
        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        # We will keep track of all the losses we accumulate during parsing.
        # If some decision is unambiguous because it's the only thing valid given
        # the parser state, we will not model it. We only model what is ambiguous.
        losses = [[] for _ in range(batch_size)]
        edge_list = [[] for _ in range(batch_size)]
        total_node_num = [0 for _ in range(batch_size)]
        action_list = [[] for _ in range(batch_size)]
        ret_top_node = [[] for _ in range(batch_size)]

        # push the tokens onto the buffer (last token first, so that the
        # first token ends up on top)
        for token_idx in range(max(sent_len)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    position = sent_len[sent_idx] - 1 - token_idx
                    self.buffer.push(sent_idx,
                                     input=embedded_text_input[sent_idx][position],
                                     extra={'token': position})

        # init stack using proot_emb, considering batch; the root gets the id
        # right after the last token
        for sent_idx in range(batch_size):
            self.stack.push(sent_idx,
                            input=self.proot_stack_emb,
                            extra={'token': sent_len[sent_idx]})
            ret_top_node[sent_idx] = [sent_len[sent_idx]]

        action_names = self.vocab.get_token_to_index_vocabulary('actions').keys()
        action_id = {
            action_: [self.vocab.get_token_index(a, namespace='actions')
                      for a in action_names if a.startswith(action_)]
            for action_ in ["SHIFT", "REDUCE", "NODE", "REMOTE-NODE", "LEFT-EDGE", "RIGHT-EDGE",
                            "LEFT-REMOTE", "RIGHT-REMOTE", "SWAP", "FINISH"]
        }

        # Membership sets, built once instead of concatenating lists per step.
        node_actions = set(action_id["NODE"] + action_id["REMOTE-NODE"])
        left_actions = set(action_id["LEFT-EDGE"] + action_id["LEFT-REMOTE"])
        right_actions = set(action_id["RIGHT-EDGE"] + action_id["RIGHT-REMOTE"])
        arc_actions = node_actions | left_actions | right_actions
        reduce_actions = set(action_id["REDUCE"])
        shift_actions = set(action_id["SHIFT"])
        swap_actions = set(action_id["SWAP"])
        finish_actions = set(action_id["FINISH"])

        # compute probability of each of the actions and choose an action
        # either from the oracle or if there is no oracle, based on the model
        trans_not_fin = True

        action_tag_for_terminate = [False] * batch_size
        action_sequence_length = [0] * batch_size

        # concept node ids per sentence; the root is the first one
        concept_node = {sent_idx: [sent_len[sent_idx]] for sent_idx in range(batch_size)}

        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):
                n_tokens = sent_len[sent_idx]

                # Runaway guard for free decoding: abandon a sentence whose
                # node count or action count exceeds 50x its length.
                if oracle_actions is None and (
                        len(concept_node[sent_idx]) > 50 * n_tokens
                        or action_sequence_length[sent_idx] > 50 * n_tokens):
                    continue

                total_node_num[sent_idx] = n_tokens + len(concept_node[sent_idx])

                if action_tag_for_terminate[sent_idx]:
                    continue
                trans_not_fin = True

                # given the buffer and stack, conclude the valid action list
                valid_actions = []
                buffer_len = self.buffer.get_len(sent_idx)
                stack_len = self.stack.get_len(sent_idx)

                if buffer_len == 0:
                    valid_actions += action_id['FINISH']

                if buffer_len > 0:
                    valid_actions += action_id['SHIFT']

                if stack_len > 0:
                    valid_actions += action_id['REDUCE']
                    valid_actions += action_id['NODE']
                    valid_actions += action_id['REMOTE-NODE']

                if stack_len > 1:
                    valid_actions += action_id['SWAP']
                    valid_actions += action_id['LEFT-EDGE']
                    valid_actions += action_id['RIGHT-EDGE']
                    valid_actions += action_id['LEFT-REMOTE']
                    valid_actions += action_id['RIGHT-REMOTE']

                log_probs = None
                action = valid_actions[0]
                ratio_factor = torch.tensor([len(concept_node[sent_idx]) / (1.0 * n_tokens)],
                                            device=self.pempty_action_emb.device)

                if len(valid_actions) > 1:
                    stack_emb = self.stack.get_output(sent_idx)
                    buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                        else self.buffer.get_output(sent_idx)
                    action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                        else self.action_stack.get_output(sent_idx)

                    p_t = torch.cat([buffer_emb, stack_emb, action_emb])
                    p_t = torch.cat([p_t, ratio_factor])
                    h = torch.tanh(self.p_s2h(p_t))
                    h = torch.cat([h, ratio_factor])
                    # Only the valid actions take part in the softmax.
                    logits = self.p_act(h)[torch.tensor(valid_actions, dtype=torch.long, device=h.device)]
                    valid_action_tbl = {a: i for i, a in enumerate(valid_actions)}
                    log_probs = torch.log_softmax(logits, dim=0)

                    action_idx = torch.max(log_probs, 0)[1].item()
                    action = valid_actions[action_idx]

                if oracle_actions is not None:
                    action = oracle_actions[sent_idx].pop(0)

                action_name = self.vocab.get_token_from_index(action, namespace='actions')

                # push action into action_stack
                self.action_stack.push(sent_idx,
                                       input=self.action_embedding(
                                           torch.tensor(action, device=embedded_text_input.device)),
                                       extra={'token': action_name})
                action_list[sent_idx].append(action_name)

                if log_probs is not None:
                    # append the action-specific loss
                    loss = log_probs[valid_action_tbl[action]]
                    if not torch.isnan(loss):
                        losses[sent_idx].append(loss)

                # generate concept node: a new node is pushed onto the buffer
                if action in node_actions:
                    concept_node_token = len(concept_node[sent_idx]) + n_tokens
                    concept_node[sent_idx].append(concept_node_token)

                    stack_emb = self.stack.get_output(sent_idx)
                    stack_emb = torch.cat([stack_emb, ratio_factor])
                    node_input = torch.tanh(self.update_concept_node(stack_emb))

                    self.buffer.push(sent_idx,
                                     input=node_input,
                                     extra={'token': concept_node_token})

                    total_node_num[sent_idx] = n_tokens + len(concept_node[sent_idx])

                if action in arc_actions:
                    # figure out which is the head and which is the modifier
                    if action in node_actions:
                        head = self.buffer.get_stack(sent_idx)[-1]
                        modifier = self.stack.get_stack(sent_idx)[-1]
                    elif action in left_actions:
                        head = self.stack.get_stack(sent_idx)[-1]
                        modifier = self.stack.get_stack(sent_idx)[-2]
                    else:
                        head = self.stack.get_stack(sent_idx)[-2]
                        modifier = self.stack.get_stack(sent_idx)[-1]

                    head_rep, head_tok = head['stack_rnn_output'], head['token']
                    mod_rep, mod_tok = modifier['stack_rnn_output'], modifier['token']

                    if oracle_actions is None:
                        # the edge label is whatever follows the first ':'
                        edge_list[sent_idx].append((mod_tok,
                                                    head_tok,
                                                    action_name.split(':', maxsplit=1)[1]))

                    # compute composed representation
                    action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                        else self.action_stack.get_output(sent_idx)
                    stack_emb = self.pempty_stack_emb if self.stack.get_len(sent_idx) == 0 \
                        else self.stack.get_output(sent_idx)
                    buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                        else self.buffer.get_output(sent_idx)

                    comp_rep = torch.cat([head_rep, mod_rep, action_emb, buffer_emb, stack_emb, ratio_factor])
                    comp_rep = torch.tanh(self.p_comp(comp_rep))

                    # Replace the head's entry by the composed representation.
                    if action in node_actions:
                        self.buffer.pop(sent_idx)
                        self.buffer.push(sent_idx,
                                         input=comp_rep,
                                         extra={'token': head_tok})
                    elif action in left_actions:
                        self.stack.pop(sent_idx)
                        self.stack.push(sent_idx,
                                        input=comp_rep,
                                        extra={'token': head_tok})
                    # RIGHT-EDGE or RIGHT-REMOTE
                    else:
                        # the head is the second stack item: take both off and
                        # put the modifier back on top afterwards
                        stack_0_rep = self.stack.get_stack(sent_idx)[-1]['stack_rnn_input']
                        self.stack.pop(sent_idx)
                        self.stack.pop(sent_idx)

                        self.stack.push(sent_idx,
                                        input=comp_rep,
                                        extra={'token': head_tok})
                        self.stack.push(sent_idx,
                                        input=stack_0_rep,
                                        extra={'token': mod_tok})

                # Execute the action to update the parser state
                if action in reduce_actions:
                    self.stack.pop(sent_idx)
                elif action in shift_actions:
                    buffer_top = self.buffer.pop(sent_idx)
                    self.stack.push(sent_idx,
                                    input=buffer_top['stack_rnn_input'],
                                    extra={'token': buffer_top['token']})
                elif action in swap_actions:
                    stack_penult = self.stack.pop_penult(sent_idx)
                    self.buffer.push(sent_idx,
                                     input=stack_penult['stack_rnn_input'],
                                     extra={'token': stack_penult['token']})
                elif action in finish_actions:
                    action_tag_for_terminate[sent_idx] = True

                action_sequence_length[sent_idx] += 1

        # categorical cross-entropy, averaged over every scored decision
        non_empty_losses = [torch.sum(torch.stack(cur_loss)) for cur_loss in losses if len(cur_loss) > 0]
        if non_empty_losses:
            _loss = -torch.sum(torch.stack(non_empty_losses)) / sum([len(cur_loss) for cur_loss in losses])
        else:
            _loss = torch.Tensor([float('NaN')])

        ret = {
            'loss': _loss,
            'losses': losses,
        }

        ret["total_node_num"] = total_node_num

        if oracle_actions is None:
            ret['edge_list'] = edge_list
        ret['action_sequence'] = action_list
        ret['top_node'] = ret_top_node
        # concept node lists, in batch order
        ret["concept_node"] = [concept_node[sent_idx] for sent_idx in range(batch_size)]

        return ret

    # Returns an expression of the loss for the sequence of actions.
    # (that is, the oracle_actions if present or the predicted sequence otherwise)
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                metadata: List[Dict[str, Any]],
                gold_actions: Dict[str, torch.LongTensor] = None,
                lemmas: Dict[str, torch.LongTensor] = None,
                pos_tags: torch.LongTensor = None,
                arc_tags: torch.LongTensor = None,
                deprels: torch.LongTensor = None,
                bios: torch.LongTensor = None,
                lexcat: torch.LongTensor = None,
                ss: torch.LongTensor = None,
                ss2: torch.LongTensor = None,
                ) -> Dict[str, torch.LongTensor]:
        """Compute the training loss or, outside training, decode the batch.

        Every metadata entry must provide ``'tokens'`` and ``'meta_info'``;
        ``'gold_actions'`` is read when ``gold_actions`` is given, and
        ``'tokens_range'`` (plus ``'gold_mrps'`` with gold actions) outside
        training.  ``lemmas`` and ``arc_tags`` are accepted but not used here.

        In training mode only ``{'loss': ...}`` is returned.  Otherwise the
        batch is decoded without oracle and the predictions are returned; if
        gold actions are present and a metric is configured, the predictions
        are converted with ``ucca_trans_outputs_into_mrp`` and scored.
        """
        batch_size = len(metadata)
        sent_len = [len(d['tokens']) for d in metadata]
        meta_info = [d['meta_info'] for d in metadata]

        oracle_actions = None
        if gold_actions is not None:
            oracle_actions = [[self.vocab.get_token_index(s, namespace='actions') for s in d['gold_actions']]
                              for d in metadata]

        # Concatenate every feature for which both the input field and its
        # embedder are available.
        feature_sources = ((tokens, self.text_field_embedder),
                           (pos_tags, self._pos_tag_embedding),
                           (deprels, self._deprel_embedding),
                           (bios, self._bios_embedding),
                           (lexcat, self._lexcat_embedding),
                           (ss, self._ss_embedding),
                           (ss2, self._ss2_embedding))
        embeds = [embedder(field) for field, embedder in feature_sources
                  if field is not None and embedder is not None]
        embedded_text_input = torch.cat(embeds, -1) if len(embeds) > 1 else embeds[0]
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.training:
            ret_train = self._greedy_decode(batch_size=batch_size,
                                            sent_len=sent_len,
                                            embedded_text_input=embedded_text_input,
                                            oracle_actions=oracle_actions)
            return {'loss': ret_train['loss']}

        training_mode = self.training

        self.eval()
        try:
            with torch.no_grad():
                ret_eval = self._greedy_decode(batch_size=batch_size,
                                               sent_len=sent_len,
                                               embedded_text_input=embedded_text_input)
        finally:
            # restore the previous mode even if decoding fails
            self.train(training_mode)

        # prediction-mode
        output_dict = {
            'tokens': [d['tokens'] for d in metadata],
            'loss': ret_eval['loss'],
            'edge_list': ret_eval['edge_list'],
            'meta_info': meta_info,
            'top_node': ret_eval['top_node'],
            'concept_node': ret_eval['concept_node'],
            'tokens_range': [d['tokens_range'] for d in metadata]
        }

        # Score against gold graphs.  The metric is optional (defaults to
        # None), so skip scoring rather than calling None.
        if gold_actions is not None and self._mces_metric is not None:
            gold_mrps = [x["gold_mrps"] for x in metadata]
            predicted_mrps = []

            for sent_idx in range(batch_size):
                # Sentences with more than 5 edges per token are left out.
                # NOTE(review): gold_mrps still contains those sentences, so
                # the two lists can differ in length -- confirm the metric
                # aligns them by id.
                if len(output_dict['edge_list'][sent_idx]) <= 5 * len(output_dict['tokens'][sent_idx]):
                    predicted_mrps.append(ucca_trans_outputs_into_mrp({
                        'tokens': output_dict['tokens'][sent_idx],
                        'edge_list': output_dict['edge_list'][sent_idx],
                        'meta_info': output_dict['meta_info'][sent_idx],
                        'top_node': output_dict['top_node'][sent_idx],
                        'concept_node': output_dict['concept_node'][sent_idx],
                        'tokens_range': output_dict['tokens_range'][sent_idx],
                    }))

            self._mces_metric(predicted_mrps, gold_mrps)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the MCES metric values; empty while training or without a metric."""
        all_metrics: Dict[str, float] = {}
        if self._mces_metric is not None and not self.training:
            all_metrics.update(self._mces_metric.get_metric(reset=reset))
        return all_metrics
class TransitionParser(Model):
    """Greedy transition-based dependency-graph parser with auxiliary taggers.

    The parser state consists of a buffer, a stack, a deque and an action
    history, each encoded by a ``StackRnn``.  The action inventory (matched by
    prefix in the ``'actions'`` vocabulary namespace) is LR, LP, RS, RP, NS,
    NR and NP.  Three ``SimpleTagger`` heads (label namespaces ``'frame'``,
    ``'pos_tag'`` and ``'node_label'``) share the text field embedder and
    contribute their losses to the total loss.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 word_dim: int,
                 hidden_dim: int,
                 action_dim: int,
                 num_layers: int,
                 mces_metric: Metric = None,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 lemma_text_field_embedder: TextFieldEmbedder = None,
                 pos_tag_embedding: Embedding = None,
                 action_embedding: Embedding = None,
                 frame_tagger_encoder: Seq2SeqEncoder = None,
                 pos_tagger_encoder: Seq2SeqEncoder = None,
                 node_label_tagger_encoder: Seq2SeqEncoder = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Build the scoring layers, stack RNNs and the three taggers.

        When ``action_embedding`` is omitted, a non-trainable embedding of
        size ``action_dim`` is created.  The tagger encoders are passed
        straight to ``SimpleTagger``.  ``initializer`` is applied to the fully
        constructed model.
        """
        super(TransitionParser, self).__init__(vocab, regularizer)

        # Bookkeeping counters; they are only initialised here.
        self._unlabeled_correct = 0
        self._labeled_correct = 0
        self._total_edges_predicted = 0
        self._total_edges_actual = 0
        self._exact_unlabeled_correct = 0
        self._exact_labeled_correct = 0
        self._total_sentences = 0

        self.num_actions = vocab.get_vocab_size('actions')
        self.text_field_embedder = text_field_embedder
        self.pos_tag_embedding = pos_tag_embedding
        self._mces_metric = mces_metric

        self.action_embedding = action_embedding
        if action_embedding is None:
            self.action_embedding = Embedding(num_embeddings=self.num_actions,
                                              embedding_dim=action_dim,
                                              trainable=False)

        # syntactic composition
        self.p_comp = torch.nn.Linear(hidden_dim * 4, word_dim)
        # parser state to hidden: buffer, stack, action, deque
        self.p_s2h = torch.nn.Linear(hidden_dim * 4, hidden_dim)
        # hidden to action
        self.p_act = torch.nn.Linear(hidden_dim, self.num_actions)

        # Learned placeholders for empty structures and for the root node.
        self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(hidden_dim))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(word_dim))
        self.pempty_action_emb = torch.nn.Parameter(torch.randn(hidden_dim))
        self.pempty_deque_emb = torch.nn.Parameter(torch.randn(hidden_dim))

        self._input_dropout = Dropout(input_dropout)

        self.frame_tagger_encoder = frame_tagger_encoder
        self.pos_tagger_encoder = pos_tagger_encoder
        self.node_label_tagger_encoder = node_label_tagger_encoder

        self.buffer = StackRnn(
            input_size=word_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.stack = StackRnn(
            input_size=word_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.deque = StackRnn(
            input_size=word_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.action_stack = StackRnn(
            input_size=action_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.frame_tagger = SimpleTagger(
            vocab=vocab,
            text_field_embedder=text_field_embedder,
            encoder=self.frame_tagger_encoder,
            label_namespace='frame')

        self.pos_tagger = SimpleTagger(vocab=vocab,
                                       text_field_embedder=text_field_embedder,
                                       encoder=self.pos_tagger_encoder,
                                       label_namespace='pos_tag')

        self.node_label_tagger = SimpleTagger(
            vocab=vocab,
            text_field_embedder=text_field_embedder,
            encoder=self.node_label_tagger_encoder,
            label_namespace='node_label')

        initializer(self)

    def _greedy_decode(
            self,
            batch_size: int,
            sent_len: List[int],
            embedded_text_input: torch.Tensor,
            oracle_actions: Optional[List[List[int]]] = None
    ) -> Dict[str, Any]:
        """Run the transition system over a batch until every buffer is empty.

        With ``oracle_actions`` the given action ids are followed (each list is
        consumed from the front) and the log-probabilities of those actions are
        collected; without them the highest-scoring valid action is taken and
        the predicted edges are recorded.

        Returns a dict with ``'loss'`` (mean negative log-likelihood over all
        scored decisions, or a NaN scalar when nothing was scored),
        ``'losses'`` and -- only when no oracle is given -- ``'edge_list'``
        holding ``(modifier, head, label)`` triples.
        """
        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.deque.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        # We will keep track of all the losses we accumulate during parsing.
        # If some decision is unambiguous because it's the only thing valid given
        # the parser state, we will not model it. We only model what is ambiguous.
        losses = [[] for _ in range(batch_size)]
        edge_list = [[] for _ in range(batch_size)]

        # push the tokens onto the buffer (last token first); token ids are
        # 1-based because id 0 is given to the root below
        for token_idx in range(max(sent_len)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    self.buffer.push(
                        sent_idx,
                        input=embedded_text_input[sent_idx][sent_len[sent_idx] - 1 - token_idx],
                        extra={'token': sent_len[sent_idx] - token_idx})

        # init stack using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.stack.push(sent_idx,
                            input=self.proot_stack_emb,
                            extra={'token': 0})

        action_names = self.vocab.get_token_to_index_vocabulary('actions').keys()
        action_id = {
            action_: [
                self.vocab.get_token_index(a, namespace='actions')
                for a in action_names if a.startswith(action_)
            ]
            for action_ in ["LR", "LP", "RS", "RP", "NS", "NR", "NP"]
        }

        # Membership sets, built once instead of scanning several lists per step.
        arc_actions = set(action_id["LR"] + action_id["LP"] + action_id["RS"] + action_id["RP"])
        right_arc_actions = set(action_id["RS"] + action_id["RP"])
        reduce_actions = set(action_id["LR"] + action_id["NR"])
        pass_actions = set(action_id["LP"] + action_id["NP"] + action_id["RP"])
        shift_actions = set(action_id["RS"] + action_id["NS"])

        # compute probability of each of the actions and choose an action
        # either from the oracle or if there is no oracle, based on the model
        trans_not_fin = True
        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):
                # a sentence is finished once its buffer is empty
                if self.buffer.get_len(sent_idx) == 0:
                    continue
                trans_not_fin = True

                # given the buffer and stack, conclude the valid action list
                valid_actions = []
                if self.stack.get_len(sent_idx) > 1 and self.buffer.get_len(sent_idx) > 0:
                    valid_actions += action_id['LR']
                    valid_actions += action_id['LP']
                    valid_actions += action_id['RP']

                if self.buffer.get_len(sent_idx) > 0:
                    valid_actions += action_id['NS']
                    valid_actions += action_id['RS']  # ROOT,NULL

                if self.stack.get_len(sent_idx) > 1:
                    valid_actions += action_id['NR']
                    valid_actions += action_id['NP']

                log_probs = None
                action = valid_actions[0]
                if len(valid_actions) > 1:
                    stack_emb = self.stack.get_output(sent_idx)
                    buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                        else self.buffer.get_output(sent_idx)
                    action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                        else self.action_stack.get_output(sent_idx)
                    deque_emb = self.pempty_deque_emb if self.deque.get_len(sent_idx) == 0 \
                        else self.deque.get_output(sent_idx)

                    p_t = torch.cat([buffer_emb, stack_emb, action_emb, deque_emb])
                    h = torch.tanh(self.p_s2h(p_t))

                    # Only the valid actions take part in the softmax.
                    logits = self.p_act(h)[torch.tensor(valid_actions, dtype=torch.long, device=h.device)]
                    valid_action_tbl = {a: i for i, a in enumerate(valid_actions)}
                    log_probs = torch.log_softmax(logits, dim=0)

                    action_idx = torch.max(log_probs, 0)[1].item()
                    action = valid_actions[action_idx]

                if oracle_actions is not None:
                    action = oracle_actions[sent_idx].pop(0)

                action_name = self.vocab.get_token_from_index(action, namespace='actions')

                # push action into action_stack
                self.action_stack.push(
                    sent_idx,
                    input=self.action_embedding(
                        torch.tensor(action, device=embedded_text_input.device)),
                    extra={'token': action_name})

                if log_probs is not None:
                    # append the action-specific loss
                    losses[sent_idx].append(log_probs[valid_action_tbl[action]])

                if action in arc_actions:
                    # figure out which is the head and which is the modifier
                    if action in right_arc_actions:
                        head = self.stack.get_stack(sent_idx)[-1]
                        modifier = self.buffer.get_stack(sent_idx)[-1]
                    else:
                        head = self.buffer.get_stack(sent_idx)[-1]
                        modifier = self.stack.get_stack(sent_idx)[-1]

                    head_tok = head['token']
                    mod_tok = modifier['token']

                    if oracle_actions is None:
                        # the edge label is whatever follows the first ':'
                        edge_list[sent_idx].append(
                            (mod_tok, head_tok, action_name.split(':', maxsplit=1)[1]))

                # Execute the action to update the parser state
                # reduce
                if action in reduce_actions:
                    self.stack.pop(sent_idx)
                # pass: move the stack top onto the deque
                elif action in pass_actions:
                    stack_top = self.stack.pop(sent_idx)
                    self.deque.push(sent_idx,
                                    input=stack_top['stack_rnn_input'],
                                    extra={'token': stack_top['token']})
                # shift: move the deque back onto the stack, then the buffer top
                elif action in shift_actions:
                    while self.deque.get_len(sent_idx) > 0:
                        deque_top = self.deque.pop(sent_idx)
                        self.stack.push(
                            sent_idx,
                            input=deque_top['stack_rnn_input'],
                            extra={'token': deque_top['token']})

                    buffer_top = self.buffer.pop(sent_idx)
                    self.stack.push(sent_idx,
                                    input=buffer_top['stack_rnn_input'],
                                    extra={'token': buffer_top['token']})

        # Mean negative log-likelihood over every scored decision.
        # torch.stack rejects an empty list and the divisor would be zero, so
        # fall back to a NaN scalar when no decision was scored at all.
        non_empty_losses = [torch.sum(torch.stack(cur_loss)) for cur_loss in losses if len(cur_loss) > 0]
        if non_empty_losses:
            _loss = -torch.sum(torch.stack(non_empty_losses)) / \
                    sum([len(cur_loss) for cur_loss in losses])
        else:
            _loss = torch.tensor(float('nan'), device=self.pempty_action_emb.device)

        ret = {
            'loss': _loss,
            'losses': losses,
        }
        if oracle_actions is None:
            ret['edge_list'] = edge_list
        return ret

    # Returns an expression of the loss for the sequence of actions.
    # (that is, the oracle_actions if present or the predicted sequence otherwise)
    def forward(
            self,
            tokens: Dict[str, torch.LongTensor],
            metadata: List[Dict[str, Any]],
            gold_actions: Dict[str, torch.LongTensor] = None,
            lemmas: Dict[str, torch.LongTensor] = None,
            mrp_pos_tags: torch.LongTensor = None,
            frame: torch.LongTensor = None,
            pos_tag: torch.LongTensor = None,
            node_label: torch.LongTensor = None,
            arc_tags: torch.LongTensor = None,
    ) -> Dict[str, torch.LongTensor]:
        """Compute the joint training loss or, outside training, decode the batch.

        Every metadata entry must provide ``'tokens'`` and ``'meta_info'``;
        ``'gold_actions'`` is read when ``gold_actions`` is given, and
        ``'tokens_range'`` (plus ``'gold_mrps'`` with gold actions) outside
        training.  ``lemmas``, ``mrp_pos_tags`` and ``arc_tags`` are accepted
        but not used here.

        In training mode only ``{'loss': ...}`` is returned: the parser loss
        plus the three tagger losses.  Otherwise the batch is decoded without
        oracle, the taggers are run, and the predictions are returned; the
        tagger losses are added only when all three taggers produced one.  If
        gold actions are present and a metric is configured, the predictions
        are converted with ``sdp_trans_outputs_into_mrp`` and scored.
        """
        batch_size = len(metadata)
        sent_len = [len(d['tokens']) for d in metadata]
        meta_info = [d['meta_info'] for d in metadata]

        oracle_actions = None
        if gold_actions is not None:
            oracle_actions = [[
                self.vocab.get_token_index(s, namespace='actions') for s in d['gold_actions']
            ] for d in metadata]

        embedded_text_input = self.text_field_embedder(tokens)
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.training:
            ret_train = self._greedy_decode(
                batch_size=batch_size,
                sent_len=sent_len,
                embedded_text_input=embedded_text_input,
                oracle_actions=oracle_actions)

            frame_tagger_train_outputs = self.frame_tagger(tokens=tokens, tags=frame)
            frame_tagger_train_outputs = self.frame_tagger.decode(frame_tagger_train_outputs)

            pos_tagger_train_outputs = self.pos_tagger(tokens=tokens, tags=pos_tag)
            pos_tagger_train_outputs = self.pos_tagger.decode(pos_tagger_train_outputs)

            node_label_tagger_train_outputs = self.node_label_tagger(tokens=tokens, tags=node_label)
            node_label_tagger_train_outputs = self.node_label_tagger.decode(node_label_tagger_train_outputs)

            _loss = ret_train['loss'] + \
                    frame_tagger_train_outputs['loss'] + \
                    pos_tagger_train_outputs['loss'] + \
                    node_label_tagger_train_outputs['loss']
            return {'loss': _loss}

        training_mode = self.training

        self.eval()
        try:
            with torch.no_grad():
                ret_eval = self._greedy_decode(
                    batch_size=batch_size,
                    sent_len=sent_len,
                    embedded_text_input=embedded_text_input)

                # Gold tags are passed on only when they are available.
                if frame is not None:
                    frame_tagger_eval_outputs = self.frame_tagger(tokens, tags=frame)
                else:
                    frame_tagger_eval_outputs = self.frame_tagger(tokens)
                frame_tagger_eval_outputs = self.frame_tagger.decode(frame_tagger_eval_outputs)

                if pos_tag is not None:
                    pos_tagger_eval_outputs = self.pos_tagger(tokens, tags=pos_tag)
                else:
                    pos_tagger_eval_outputs = self.pos_tagger(tokens)
                pos_tagger_eval_outputs = self.pos_tagger.decode(pos_tagger_eval_outputs)

                if node_label is not None:
                    node_label_tagger_eval_outputs = self.node_label_tagger(tokens, tags=node_label)
                else:
                    node_label_tagger_eval_outputs = self.node_label_tagger(tokens)
                node_label_tagger_eval_outputs = self.node_label_tagger.decode(
                    node_label_tagger_eval_outputs)
        finally:
            # restore the previous mode even if decoding or tagging fails
            self.train(training_mode)

        edge_list = ret_eval['edge_list']

        # Add the tagger losses only when every tagger produced one; checking
        # just two of them would raise KeyError for a missing third.
        tagger_outputs = (frame_tagger_eval_outputs,
                          pos_tagger_eval_outputs,
                          node_label_tagger_eval_outputs)
        if all('loss' in outputs for outputs in tagger_outputs):
            _loss = ret_eval['loss'] + \
                    frame_tagger_eval_outputs['loss'] + \
                    pos_tagger_eval_outputs['loss'] + \
                    node_label_tagger_eval_outputs['loss']
        else:
            _loss = ret_eval['loss']

        # prediction-mode
        output_dict = {
            'tokens': [d['tokens'] for d in metadata],
            'edge_list': edge_list,
            'meta_info': meta_info,
            'tokens_range': [d['tokens_range'] for d in metadata],
            'frame': frame_tagger_eval_outputs["tags"],
            'pos_tag': pos_tagger_eval_outputs["tags"],
            'node_label': node_label_tagger_eval_outputs["tags"],
            'loss': _loss
        }

        # compute the mrp accuracy when gold actions exist.  The metric is
        # optional (defaults to None), so skip scoring rather than calling None.
        if gold_actions is not None and self._mces_metric is not None:
            gold_mrps = [x["gold_mrps"] for x in metadata]
            predicted_mrps = []

            for sent_idx in range(batch_size):
                # Sentences with more than 5 edges per token are left out.
                # NOTE(review): gold_mrps still contains those sentences, so
                # the two lists can differ in length -- confirm the metric
                # aligns them by id.
                if len(output_dict['edge_list'][sent_idx]) <= 5 * len(output_dict['tokens'][sent_idx]):
                    predicted_mrps.append(
                        sdp_trans_outputs_into_mrp({
                            'tokens': output_dict['tokens'][sent_idx],
                            'edge_list': output_dict['edge_list'][sent_idx],
                            'meta_info': output_dict['meta_info'][sent_idx],
                            'frame': output_dict['frame'][sent_idx],
                            'pos_tag': output_dict['pos_tag'][sent_idx],
                            "node_label": output_dict['node_label'][sent_idx],
                            'tokens_range': output_dict['tokens_range'][sent_idx],
                        }))

            self._mces_metric(predicted_mrps, gold_mrps)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the MCES metric values; empty while training or without a metric."""
        all_metrics: Dict[str, float] = {}
        if self._mces_metric is not None and not self.training:
            all_metrics.update(self._mces_metric.get_metric(reset=reset))
        return all_metrics
class TransitionParser(Model):
    """Greedy transition-based graph parser that can create concept nodes.

    The parser state is held in four ``StackRnn`` instances (``buffer``,
    ``stack``, ``deque`` and ``action_stack``).  At every step the set of
    valid actions is derived from the sizes/contents of those structures, the
    model scores the valid actions, and the chosen action (the oracle action
    when one is supplied, otherwise the arg-max) is applied to the state.

    Action names are matched by prefix (``SHIFT``, ``REDUCE``, ``LEFT-EDGE``,
    ``RIGHT-EDGE``, ``SELF-EDGE``, ``DROP``, ``TOP``, ``PASS``, ``START``,
    ``END``, ``FINISH``); edge and ``START`` actions carry their label after
    the ``#SPLIT_TAG#`` separator.
    """

    # Action-name prefixes recognised by the transition system.
    _ACTION_PREFIXES = ("SHIFT", "REDUCE", "LEFT-EDGE", "RIGHT-EDGE",
                        "SELF-EDGE", "DROP", "TOP", "PASS", "START", "END",
                        "FINISH")

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 word_dim: int,
                 hidden_dim: int,
                 action_dim: int,
                 concept_label_dim: int,
                 num_layers: int,
                 mces_metric: Metric = None,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 lemma_text_field_embedder: TextFieldEmbedder = None,
                 pos_tag_embedding: Embedding = None,
                 action_embedding: Embedding = None,
                 concept_label_embedding: Embedding = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Build the parser's embeddings, projections and stack RNNs.

        Parameters
        ----------
        vocab : vocabulary providing the ``actions`` and ``concept_label``
            namespaces.
        text_field_embedder : embeds the ``tokens`` field; its per-token
            vectors are pushed onto the buffer, so they are expected to have
            size ``word_dim`` (the buffer RNN is built with that input size).
        word_dim, hidden_dim, action_dim, concept_label_dim : sizes of token
            representations, RNN hidden states, action embeddings and concept
            label embeddings.
        num_layers : number of layers of every ``StackRnn``.
        mces_metric : optional metric updated in ``forward`` during
            evaluation when gold actions are available.
        recurrent_dropout_probability, layer_dropout_probability,
        same_dropout_mask_per_instance : forwarded to every ``StackRnn``.
        input_dropout : dropout applied to the embedded text input.
        lemma_text_field_embedder, pos_tag_embedding : stored on the instance;
            not used by the methods of this class shown here.
        action_embedding, concept_label_embedding : optional embeddings; when
            omitted, non-trainable embeddings of the right size are created.
        initializer : applied to the whole module at the end of construction.
        regularizer : forwarded to the base ``Model``.
        """
        super(TransitionParser, self).__init__(vocab, regularizer)

        # Counters kept for compatibility; they are only initialised here.
        self._unlabeled_correct = 0
        self._labeled_correct = 0
        self._total_edges_predicted = 0
        self._total_edges_actual = 0
        self._exact_unlabeled_correct = 0
        self._exact_labeled_correct = 0
        self._total_sentences = 0

        self.num_actions = vocab.get_vocab_size('actions')
        self.num_concept_label = vocab.get_vocab_size('concept_label')
        self.text_field_embedder = text_field_embedder
        self.lemma_text_field_embedder = lemma_text_field_embedder
        self._pos_tag_embedding = pos_tag_embedding
        self._mces_metric = mces_metric

        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.action_dim = action_dim
        self.concept_label_dim = concept_label_dim

        self.action_embedding = action_embedding
        self.concept_label_embedding = concept_label_embedding

        if concept_label_embedding is None:
            self.concept_label_embedding = Embedding(
                num_embeddings=self.num_concept_label,
                embedding_dim=self.concept_label_dim,
                trainable=False)

        if action_embedding is None:
            self.action_embedding = Embedding(
                num_embeddings=self.num_actions,
                embedding_dim=self.action_dim,
                trainable=False)

        # syntactic composition
        self.p_comp = torch.nn.Linear(self.hidden_dim * 6, self.word_dim)
        # parser state to hidden
        self.p_s2h = torch.nn.Linear(self.hidden_dim * 4, self.hidden_dim)
        # hidden to action
        self.p_act = torch.nn.Linear(self.hidden_dim, self.num_actions)

        # projections producing the representation of a concept node when it
        # is created (START) and when its span is closed (END)
        self.start_concept_node = torch.nn.Linear(
            self.hidden_dim + self.concept_label_dim, self.word_dim)
        self.end_concept_node = torch.nn.Linear(
            self.hidden_dim * 2 + self.concept_label_dim, self.word_dim)

        # learned placeholders used when a structure is empty, plus the
        # representation of the artificial root pushed onto the stack
        self.pempty_buffer_emb = torch.nn.Parameter(
            torch.randn(self.hidden_dim))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(self.word_dim))
        self.pempty_action_emb = torch.nn.Parameter(
            torch.randn(self.hidden_dim))
        self.pempty_stack_emb = torch.nn.Parameter(
            torch.randn(self.hidden_dim))
        self.pempty_deque_emb = torch.nn.Parameter(
            torch.randn(self.hidden_dim))

        self._input_dropout = Dropout(input_dropout)

        self.buffer = StackRnn(
            input_size=self.word_dim,
            hidden_size=self.hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.stack = StackRnn(
            input_size=self.word_dim,
            hidden_size=self.hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.deque = StackRnn(
            input_size=self.word_dim,
            hidden_size=self.hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.action_stack = StackRnn(
            input_size=self.action_dim,
            hidden_size=self.hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        initializer(self)

    def _build_action_table(self) -> Dict[str, List[int]]:
        """Group the ids of the ``actions`` namespace by action-name prefix.

        A single pass over the vocabulary; ids inside each group keep the
        vocabulary's iteration order, which matters because the first valid
        action is the fallback when only one action is valid.
        """
        table: Dict[str, List[int]] = {
            prefix: [] for prefix in self._ACTION_PREFIXES
        }
        for name in self.vocab.get_token_to_index_vocabulary('actions').keys():
            for prefix in self._ACTION_PREFIXES:
                if name.startswith(prefix):
                    table[prefix].append(
                        self.vocab.get_token_index(name, namespace='actions'))
        return table

    def _collect_valid_actions(self,
                               sent_idx: int,
                               sentence_length: int,
                               concept_nodes: Dict[Any, Dict[str, Any]],
                               action_id: Dict[str, List[int]]) -> List[int]:
        """Return the action ids that are legal in the current parser state.

        The order of the returned list is significant: it defines both the
        default action (element 0) and the index space of the scored logits.
        """
        buffer_len = self.buffer.get_len(sent_idx)
        stack_len = self.stack.get_len(sent_idx)

        valid_actions: List[int] = []

        if buffer_len == 0:
            valid_actions += action_id['FINISH']

        if buffer_len > 0:
            valid_actions += action_id['SHIFT']
            valid_actions += action_id['DROP']
            valid_actions += action_id['TOP']

            # a concept node may only be started on top of a real token
            buffer_token = self.buffer.get_stack(sent_idx)[-1]["token"]
            if buffer_token < sentence_length:
                valid_actions += action_id['START']

        if buffer_len > 0 and stack_len > 1:
            valid_actions += action_id['LEFT-EDGE']
            valid_actions += action_id['RIGHT-EDGE']
            valid_actions += action_id['SELF-EDGE']

        if stack_len > 1:
            valid_actions += action_id['REDUCE']
            valid_actions += action_id['PASS']

        if buffer_len > 0 and stack_len > 1:
            # END closes the span of the concept node on top of the stack at
            # the real token on top of the buffer
            concept_node_token = self.stack.get_stack(sent_idx)[-1]['token']
            concept_alignment_end_token = self.buffer.get_stack(
                sent_idx)[-1]["token"]
            if concept_node_token in concept_nodes \
                    and concept_alignment_end_token < sentence_length:
                valid_actions += action_id['END']

        return valid_actions

    def _drain_deque_onto_stack(self, sent_idx: int) -> None:
        """Pop every element of the deque and push it back onto the stack."""
        while self.deque.get_len(sent_idx) > 0:
            deque_top = self.deque.pop(sent_idx)
            self.stack.push(sent_idx,
                            input=deque_top['stack_rnn_input'],
                            extra={'token': deque_top['token']})

    def _mean_action_loss(self,
                          losses: List[List[torch.Tensor]]) -> torch.Tensor:
        """Categorical cross-entropy averaged over all scored decisions.

        When no decision was scored at all (every step was unambiguous, or
        every sentence hit the early-stop guard) a zero loss is returned
        instead of failing on ``torch.stack([])`` / a division by zero.
        """
        per_sentence = [
            torch.sum(torch.stack(cur_loss)) for cur_loss in losses
            if len(cur_loss) > 0
        ]
        if not per_sentence:
            # leaf tensor that still supports ``backward()`` while training
            return self.proot_stack_emb.new_zeros(
                (), requires_grad=torch.is_grad_enabled())
        return -torch.sum(torch.stack(per_sentence)) / \
            sum([len(cur_loss) for cur_loss in losses])

    def _greedy_decode(
            self,
            batch_size: int,
            sent_len: List[int],
            embedded_text_input: torch.Tensor,
            oracle_actions: Optional[List[List[int]]] = None,
    ) -> Dict[str, Any]:
        """Run the transition system over a batch, one action per sentence
        per round, until every sentence is finished.

        Parameters
        ----------
        batch_size : number of sentences.
        sent_len : number of tokens of each sentence.
        embedded_text_input : token representations, indexed as
            ``[sentence][token]``.
        oracle_actions : optional gold action ids per sentence.  When given,
            the gold action is executed at every step (and consumed from the
            front of the list) while the model's log-probability of that
            action is accumulated as loss.

        Returns
        -------
        Dict with ``loss``, ``losses``, ``total_node_num``, ``edge_list``,
        ``action_sequence``, ``top_node`` and ``concept_node``.
        """
        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.deque.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        # We will keep track of all the losses we accumulate during parsing.
        # If some decision is unambiguous because it's the only thing valid given
        # the parser state, we will not model it. We only model what is ambiguous.
        losses = [[] for _ in range(batch_size)]
        edge_list = [[] for _ in range(batch_size)]
        total_node_num = [0 for _ in range(batch_size)]
        action_list = [[] for _ in range(batch_size)]
        ret_top_node = [[] for _ in range(batch_size)]
        # sentences with more tokens than embedded positions
        broken_indices = set()

        # push the tokens onto the buffer (tokens is in reverse order)
        for token_idx in range(max(sent_len)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    position = sent_len[sent_idx] - 1 - token_idx
                    if position < embedded_text_input.shape[1]:
                        self.buffer.push(
                            sent_idx,
                            input=embedded_text_input[sent_idx][position],
                            extra={'token': sent_len[sent_idx] - token_idx - 1})
                    else:
                        broken_indices.add(sent_idx)

        if len(broken_indices) > 0:
            print('Broken indices: ' + str(broken_indices))

        # init stack using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.stack.push(sent_idx,
                            input=self.proot_stack_emb,
                            extra={'token': 'protection_symbol'})

        # ordered id lists (used to build valid-action lists) and sets (used
        # for fast membership tests when dispatching on the chosen action)
        action_id = self._build_action_table()
        action_set = {
            prefix: frozenset(ids) for prefix, ids in action_id.items()
        }
        edge_actions = (action_set["LEFT-EDGE"] | action_set["RIGHT-EDGE"]
                        | action_set["SELF-EDGE"])

        device = embedded_text_input.device

        # compute probability of each of the actions and choose an action
        # either from the oracle or if there is no oracle, based on the model
        trans_not_fin = True
        action_sequence_length = [0] * batch_size
        # per sentence: concept node id -> {"label", "start"[, "end"]}
        concept_node = [{} for _ in range(batch_size)]

        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):
                # early stop for run-away predictions (never with an oracle)
                if (len(concept_node[sent_idx]) > 50 * sent_len[sent_idx]
                        or action_sequence_length[sent_idx] >
                        50 * sent_len[sent_idx]) and oracle_actions is None:
                    continue

                total_node_num[sent_idx] = sent_len[sent_idx] + len(
                    concept_node[sent_idx])

                # a sentence is finished once only the root is left on the
                # stack and the buffer is empty
                if self.buffer.get_len(sent_idx) == 0 \
                        and self.stack.get_len(sent_idx) == 1:
                    continue

                trans_not_fin = True

                # given the buffer and stack, conclude the valid action list
                valid_actions = self._collect_valid_actions(
                    sent_idx, sent_len[sent_idx], concept_node[sent_idx],
                    action_id)

                log_probs = None
                action = valid_actions[0]

                if len(valid_actions) > 1:
                    stack_emb = self.stack.get_output(sent_idx)
                    buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                        else self.buffer.get_output(sent_idx)
                    action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                        else self.action_stack.get_output(sent_idx)
                    deque_emb = self.pempty_deque_emb if self.deque.get_len(sent_idx) == 0 \
                        else self.deque.get_output(sent_idx)

                    p_t = torch.cat(
                        (buffer_emb, stack_emb, action_emb, deque_emb))
                    h = torch.tanh(self.p_s2h(p_t))
                    # restrict the soft-max to the valid actions
                    logits = self.p_act(h)[torch.tensor(valid_actions,
                                                        dtype=torch.long,
                                                        device=h.device)]
                    valid_action_tbl = {
                        a: i for i, a in enumerate(valid_actions)
                    }
                    log_probs = torch.log_softmax(logits, dim=0)

                    action_idx = torch.max(log_probs, 0)[1].item()
                    action = valid_actions[action_idx]

                if oracle_actions is not None:
                    action = oracle_actions[sent_idx].pop(0)

                if log_probs is not None:
                    # append the action-specific loss
                    losses[sent_idx].append(
                        log_probs[valid_action_tbl[action]])

                action_name = self.vocab.get_token_from_index(
                    action, namespace='actions')

                # generate concept node, push it into buffer and align it with the second item in buffer
                if action in action_set["START"]:
                    # get concept label and corresponding embedding
                    concept_node_token = len(
                        concept_node[sent_idx]) + sent_len[sent_idx]
                    stack_emb = self.stack.get_output(sent_idx)

                    concept_node_label_token = action_name.split(
                        '#SPLIT_TAG#', maxsplit=1)[1]
                    concept_node_label = self.vocab.get_token_index(
                        concept_node_label_token, namespace='concept_label')
                    concept_node_label_emb = self.concept_label_embedding(
                        torch.tensor(concept_node_label, device=device))

                    # init
                    concept_alignment_begin = self.buffer.get_stack(
                        sent_idx)[-1]["token"]
                    concept_node[sent_idx][concept_node_token] = {
                        "label": concept_node_label_token,
                        "start": concept_alignment_begin
                    }

                    # insert comp_rep into buffer
                    comp_rep = torch.tanh(
                        self.start_concept_node(
                            torch.cat((stack_emb, concept_node_label_emb))))
                    self.buffer.push(sent_idx,
                                     input=comp_rep,
                                     extra={'token': concept_node_token})

                    # update total_node_num for early-stopping
                    total_node_num[sent_idx] = sent_len[sent_idx] + len(
                        concept_node[sent_idx])

                # predict the span end of the node in S0
                elif action in action_set["END"]:
                    # get label embedding of concept node
                    concept_node_token = self.stack.get_stack(
                        sent_idx)[-1]['token']
                    concept_alignment_end_token = self.buffer.get_stack(
                        sent_idx)[-1]["token"]

                    concept_node_label_token = concept_node[sent_idx][
                        concept_node_token]["label"]
                    concept_node_label = self.vocab.get_token_index(
                        concept_node_label_token, namespace='concept_label')
                    concept_node_label_emb = self.concept_label_embedding(
                        torch.tensor(concept_node_label, device=device))

                    # update concept info via inserting the span end of concept node
                    concept_node[sent_idx][concept_node_token][
                        "end"] = concept_alignment_end_token

                    # update node representation using a) the composed
                    # representation built at START b) the end embedding
                    stack_emb = self.stack.get_output(sent_idx)
                    buffer_emb = self.buffer.get_output(sent_idx)

                    comp_rep = torch.tanh(
                        self.end_concept_node(
                            torch.cat((stack_emb, buffer_emb,
                                       concept_node_label_emb))))

                    self.stack.pop(sent_idx)
                    self.stack.push(sent_idx,
                                    input=comp_rep,
                                    extra={'token': concept_node_token})

                elif action in edge_actions:
                    if action in action_set["LEFT-EDGE"]:
                        head = self.buffer.get_stack(sent_idx)[-1]
                        modifier = self.stack.get_stack(sent_idx)[-1]
                    elif action in action_set["RIGHT-EDGE"]:
                        head = self.stack.get_stack(sent_idx)[-1]
                        modifier = self.buffer.get_stack(sent_idx)[-1]
                    else:
                        # SELF-EDGE: the stack top is both head and modifier
                        head = self.stack.get_stack(sent_idx)[-1]
                        modifier = self.stack.get_stack(sent_idx)[-1]

                    (head_rep, head_tok) = (head['stack_rnn_output'],
                                            head['token'])
                    (mod_rep, mod_tok) = (modifier['stack_rnn_output'],
                                          modifier['token'])

                    edge_list[sent_idx].append(
                        (mod_tok, head_tok,
                         action_name.split('#SPLIT_TAG#', maxsplit=1)[1]))

                    # compute composed representation
                    action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                        else self.action_stack.get_output(sent_idx)
                    stack_emb = self.pempty_stack_emb if self.stack.get_len(sent_idx) == 0 \
                        else self.stack.get_output(sent_idx)
                    buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                        else self.buffer.get_output(sent_idx)
                    deque_emb = self.pempty_deque_emb if self.deque.get_len(sent_idx) == 0 \
                        else self.deque.get_output(sent_idx)

                    comp_rep = torch.cat((head_rep, mod_rep, action_emb,
                                          buffer_emb, stack_emb, deque_emb))
                    comp_rep = torch.tanh(self.p_comp(comp_rep))

                    # the head's entry is replaced by the composed
                    # representation in whichever structure holds it
                    if action in action_set["LEFT-EDGE"]:
                        self.buffer.pop(sent_idx)
                        self.buffer.push(sent_idx,
                                         input=comp_rep,
                                         extra={'token': head_tok})
                    else:
                        # RIGHT-EDGE or SELF-EDGE
                        self.stack.pop(sent_idx)
                        self.stack.push(sent_idx,
                                        input=comp_rep,
                                        extra={'token': head_tok})

                elif action in action_set["REDUCE"]:
                    self.stack.pop(sent_idx)

                elif action in action_set["TOP"]:
                    ret_top_node[sent_idx] = self.buffer.get_stack(
                        sent_idx)[-1]["token"]

                elif action in action_set["DROP"]:
                    self.buffer.pop(sent_idx)
                    self._drain_deque_onto_stack(sent_idx)

                elif action in action_set["PASS"]:
                    stack_top = self.stack.pop(sent_idx)
                    self.deque.push(sent_idx,
                                    input=stack_top['stack_rnn_input'],
                                    extra={'token': stack_top['token']})

                elif action in action_set["SHIFT"]:
                    self._drain_deque_onto_stack(sent_idx)
                    buffer_top = self.buffer.pop(sent_idx)
                    self.stack.push(sent_idx,
                                    input=buffer_top['stack_rnn_input'],
                                    extra={'token': buffer_top['token']})

                # push action into action_stack
                self.action_stack.push(
                    sent_idx,
                    input=self.action_embedding(
                        torch.tensor(action, device=device)),
                    extra={'token': action_name})

                action_list[sent_idx].append(action_name)
                action_sequence_length[sent_idx] += 1

        # categorical cross-entropy
        _loss = self._mean_action_loss(losses)

        ret = {
            'loss': _loss,
            'losses': losses,
        }

        # extract concept node list in batchmode
        ret_concept_node = list(concept_node)

        # sentences that could not be fully pushed onto the buffer get
        # empty outputs
        for idx in broken_indices:
            total_node_num[idx] = 0
            edge_list[idx] = []
            action_list[idx] = []
            ret_top_node[idx] = 0
            ret_concept_node[idx] = {}

        ret["total_node_num"] = total_node_num
        ret['edge_list'] = edge_list
        ret['action_sequence'] = action_list
        ret['top_node'] = ret_top_node
        ret["concept_node"] = ret_concept_node

        return ret

    # Returns an expression of the loss for the sequence of actions.
    # (that is, the oracle_actions if present or the predicted sequence otherwise)
    def forward(
            self,
            tokens: Dict[str, torch.LongTensor],
            metadata: List[Dict[str, Any]],
            gold_actions: Dict[str, torch.LongTensor] = None,
            lemmas: Dict[str, torch.LongTensor] = None,
            pos_tags: torch.LongTensor = None,
            arc_tags: torch.LongTensor = None,
            concept_label: torch.LongTensor = None,
    ) -> Dict[str, torch.LongTensor]:
        """Compute the training loss, or decode greedily when evaluating.

        Parameters
        ----------
        tokens : input for ``text_field_embedder``.
        metadata : one dict per sentence; ``tokens``, ``meta_info`` and (when
            evaluating) ``tokens_range`` are read, plus ``gold_actions`` /
            ``gold_mrps`` when ``gold_actions`` is given.
        gold_actions : only its presence is checked; the gold action strings
            themselves are taken from ``metadata``.
        lemmas, pos_tags, arc_tags, concept_label : accepted but unused here.

        Returns
        -------
        ``{'loss': ...}`` in training mode.  Otherwise a dict with the
        predicted ``edge_list``, ``top_node``, ``concept_node``, the ``loss``
        of the predicted sequence and the pass-through fields ``tokens``,
        ``meta_info`` and ``tokens_range``.
        """
        batch_size = len(metadata)
        sent_len = [len(d['tokens']) for d in metadata]
        meta_info = [d['meta_info'] for d in metadata]

        oracle_actions = None
        if gold_actions is not None:
            oracle_actions = [[
                self.vocab.get_token_index(s, namespace='actions')
                for s in sentence_actions
            ] for sentence_actions in (d['gold_actions'] for d in metadata)]

        embedded_text_input = self.text_field_embedder(tokens)
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.training:
            ret_train = self._greedy_decode(
                batch_size=batch_size,
                sent_len=sent_len,
                embedded_text_input=embedded_text_input,
                oracle_actions=oracle_actions)

            return {'loss': ret_train['loss']}

        # evaluation / prediction: decode without the oracle and without
        # tracking gradients
        training_mode = self.training
        self.eval()
        with torch.no_grad():
            ret_eval = self._greedy_decode(
                batch_size=batch_size,
                sent_len=sent_len,
                embedded_text_input=embedded_text_input)
        self.train(training_mode)

        output_dict = {
            'tokens': [d['tokens'] for d in metadata],
            'loss': ret_eval['loss'],
            'edge_list': ret_eval['edge_list'],
            'meta_info': meta_info,
            'top_node': ret_eval['top_node'],
            'concept_node': ret_eval['concept_node'],
            'tokens_range': [d['tokens_range'] for d in metadata]
        }

        # prediction-mode
        # compute the mrp accuracy when gold actions exists; the metric is
        # optional, so skip scoring when none was configured
        if gold_actions is not None and self._mces_metric is not None:
            gold_mrps = [x["gold_mrps"] for x in metadata]
            predicted_mrps = [
                eds_trans_outputs_into_mrp({
                    'tokens': output_dict['tokens'][sent_idx],
                    'edge_list': output_dict['edge_list'][sent_idx],
                    'meta_info': output_dict['meta_info'][sent_idx],
                    'top_node': output_dict['top_node'][sent_idx],
                    'concept_node': output_dict['concept_node'][sent_idx],
                    'tokens_range': output_dict['tokens_range'][sent_idx],
                }) for sent_idx in range(batch_size)
            ]

            self._mces_metric(predicted_mrps, gold_mrps)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the MCES metric values; empty while training or when no
        metric was configured."""
        all_metrics: Dict[str, float] = {}
        if self._mces_metric is not None and not self.training:
            all_metrics.update(self._mces_metric.get_metric(reset=reset))
        return all_metrics
class TransitionParser(Model): def __init__(self, vocab: Vocabulary, text_field_embedder: TextFieldEmbedder, word_dim: int, hidden_dim: int, action_dim: int, ratio_dim: int, num_layers: int, recurrent_dropout_probability: float = 0.0, layer_dropout_probability: float = 0.0, same_dropout_mask_per_instance: bool = True, input_dropout: float = 0.0, output_null_nodes: bool = True, max_heads: int = None, max_swaps_per_node: int = 3, fix_unconnected_egraph: bool = True, validate_every_n_instances: int = None, action_embedding: Embedding = None, initializer: InitializerApplicator = InitializerApplicator(), regularizer: Optional[RegularizerApplicator] = None ) -> None: super(TransitionParser, self).__init__(vocab, regularizer) self._total_validation_instances = 0 self.num_actions = vocab.get_vocab_size('actions') self.text_field_embedder = text_field_embedder self.output_null_nodes = output_null_nodes self.max_heads = max_heads self.max_swaps_per_node = max_swaps_per_node self._fix_unconnected_egraph = fix_unconnected_egraph self.num_validation_instances = validate_every_n_instances self._xud_score = XUDScore(collapse=self.output_null_nodes) self.word_dim = word_dim self.hidden_dim = hidden_dim self.ratio_dim = ratio_dim self.action_dim = action_dim self.action_embedding = action_embedding if action_embedding is None: self.action_embedding = Embedding(num_embeddings=self.num_actions, embedding_dim=action_dim, trainable=False) # syntactic composition self.p_comp = torch.nn.Linear(self.hidden_dim * 5 + self.ratio_dim, self.word_dim) # parser state to hidden self.p_s2h = torch.nn.Linear(self.hidden_dim * 3 + self.ratio_dim, self.hidden_dim) # hidden to action self.p_act = torch.nn.Linear(self.hidden_dim + self.ratio_dim, self.num_actions) self.update_null_node = torch.nn.Linear(self.hidden_dim + self.ratio_dim, self.word_dim) self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(self.hidden_dim)) self.proot_stack_emb = torch.nn.Parameter(torch.randn(self.word_dim)) 
self.pempty_action_emb = torch.nn.Parameter(torch.randn(self.hidden_dim)) self.pempty_stack_emb = torch.nn.Parameter(torch.randn(self.hidden_dim)) self._input_dropout = Dropout(input_dropout) self.buffer = StackRnn(input_size=self.word_dim, hidden_size=self.hidden_dim, num_layers=num_layers, recurrent_dropout_probability=recurrent_dropout_probability, layer_dropout_probability=layer_dropout_probability, same_dropout_mask_per_instance=same_dropout_mask_per_instance) self.stack = StackRnn(input_size=self.word_dim, hidden_size=self.hidden_dim, num_layers=num_layers, recurrent_dropout_probability=recurrent_dropout_probability, layer_dropout_probability=layer_dropout_probability, same_dropout_mask_per_instance=same_dropout_mask_per_instance) self.action_stack = StackRnn(input_size=self.action_dim, hidden_size=self.hidden_dim, num_layers=num_layers, recurrent_dropout_probability=recurrent_dropout_probability, layer_dropout_probability=layer_dropout_probability, same_dropout_mask_per_instance=same_dropout_mask_per_instance) initializer(self) def _greedy_decode(self, batch_size: int, sent_len: List[int], embedded_text_input: torch.Tensor, oracle_actions: Optional[List[List[str]]] = None, ) -> Dict[str, Any]: self.buffer.reset_stack(batch_size) self.stack.reset_stack(batch_size) self.action_stack.reset_stack(batch_size) # We will keep track of all the losses we accumulate during parsing. # If some decision is unambiguous because it's the only thing valid given # the parser state, we will not model it. We only model what is ambiguous. 
losses = [[] for _ in range(batch_size)] ratio_factor_losses = [[] for _ in range(batch_size)] edge_list = [[] for _ in range(batch_size)] total_node_num = [0 for _ in range(batch_size)] action_list = [[] for _ in range(batch_size)] ret_node = [[] for _ in range(batch_size)] root_id = [[] for _ in range(batch_size)] num_of_generated_node= [[] for _ in range(batch_size)] generated_order = [{} for _ in range(batch_size)] head_count = [{} for _ in range(batch_size)] # keep track of the number of heads swap_count = [{} for _ in range(batch_size)] # keep track of the number of swaps reachable = [{} for _ in range(batch_size)] # set of node ids reachable from each node (by id) # push the tokens onto the buffer (tokens is in reverse order) for token_idx in range(max(sent_len)): for sent_idx in range(batch_size): if sent_len[sent_idx] > token_idx: try: self.buffer.push(sent_idx, input=embedded_text_input[sent_idx][sent_len[sent_idx] - 1 - token_idx], extra={'token': sent_len[sent_idx] - token_idx - 1}) except IndexError: raise IndexError(f"{sent_idx} {batch_size} {token_idx} {sent_len[sent_idx]} {embedded_text_input[sent_idx].dim()}") # init stack using proot_emb, considering batch for sent_idx in range(batch_size): root_id[sent_idx] = sent_len[sent_idx] reachable[sent_idx][root_id[sent_idx]] = {root_id[sent_idx]} # only root is reachable from root so far generated_order[sent_idx][root_id[sent_idx]] = 0 self.stack.push(sent_idx, input=self.proot_stack_emb, extra={'token': root_id[sent_idx]}) # compute probability of each of the actions and choose an action # either from the oracle or if there is no oracle, based on the model trans_not_fin = True action_tag_for_terminate = [False] * batch_size action_sequence_length = [0] * batch_size null_node = {} for sent_idx in range(batch_size): null_node[sent_idx] = [] while trans_not_fin: trans_not_fin = False for sent_idx in range(batch_size): if (action_sequence_length[sent_idx] > 10000 * sent_len[sent_idx]) and oracle_actions is 
None: raise RuntimeError(f"Too many actions for a sentence {sent_idx}. actions: {action_list}") total_node_num[sent_idx] = sent_len[sent_idx] + len(null_node[sent_idx]) if not action_tag_for_terminate[sent_idx]: trans_not_fin = True valid_actions = self.calc_valid_actions(edge_list, generated_order, head_count, swap_count, null_node, root_id, sent_idx, sent_len) log_probs = None action = valid_actions[0] action_idx = self.vocab.get_token_index(action, namespace='actions') ratio_factor = torch.tensor([len(null_node[sent_idx]) / (1.0 * sent_len[sent_idx])], device=self.pempty_action_emb.device) if len(valid_actions) > 1: action, action_idx, log_probs, valid_action_tbl = self.predict_action(ratio_factor, sent_idx, valid_actions) if oracle_actions is not None: action = oracle_actions[sent_idx].pop(0) action_idx = self.vocab.get_token_index(action, namespace='actions') # push action into action_stack self.action_stack.push(sent_idx, input=self.action_embedding( torch.tensor(action_idx, device=embedded_text_input.device)), extra={ 'token': action}) action_list[sent_idx].append(action) #print(f'Sent ID: {sent_idx}, action {action}') try: UNK_ID = self.vocab.get_token_index('@@UNKNOWN@@') if log_probs is not None and not (UNK_ID and action_idx == UNK_ID): loss = log_probs[valid_action_tbl[action_idx]] if not torch.isnan(loss): losses[sent_idx].append(loss) except KeyError: raise KeyError(f'action: {action}, valid actions: {valid_actions}') self.exec_action(action, action_sequence_length, action_tag_for_terminate, edge_list, generated_order, head_count, swap_count, null_node, num_of_generated_node, oracle_actions, ratio_factor, ratio_factor_losses, sent_idx, sent_len, total_node_num, reachable) if oracle_actions is None and self._fix_unconnected_egraph: # Fix edge_list if we are not training. 
If training, it's empty (no output) for sent_idx in range(batch_size): self.fix_unconnected_egraph(edge_list[sent_idx], reachable[sent_idx], root_id[sent_idx], generated_order[sent_idx]) # categorical cross-entropy _loss_CCE = -torch.sum( torch.stack([torch.sum(torch.stack(cur_loss)) for cur_loss in losses if len(cur_loss) > 0])) / \ sum([len(cur_loss) for cur_loss in losses]) _loss = _loss_CCE ret = { 'loss': _loss, 'losses': losses, } # extract null node list in batchmode for sent_idx in range(batch_size): ret_node[sent_idx] = null_node[sent_idx] ret["total_node_num"] = total_node_num if oracle_actions is None: ret['edge_list'] = edge_list ret['action_sequence'] = action_list ret["null_node"] = ret_node return ret def calc_valid_actions(self, edge_list, generated_order, head_count, swap_count, null_node, root_id, sent_idx, sent_len): valid_actions = [] # given the buffer and stack, conclude the valid action list if self.buffer.get_len(sent_idx) == 0 and self.stack.get_len(sent_idx) == 1: valid_actions += ['FINISH'] if self.buffer.get_len(sent_idx) > 0: valid_actions += ['SHIFT'] if self.stack.get_len(sent_idx) > 0: s0 = self.stack.get_stack(sent_idx)[-1]['token'] if s0 != root_id[sent_idx] and head_count[sent_idx][s0] > 0: valid_actions += ['REDUCE-0'] if self.output_null_nodes and \ len(null_node[sent_idx]) < sent_len[sent_idx]: # Max number of null nodes is the number of words # valid_actions += ['NODE'] # Support legacy models with "NODE:*" actions that also create an edge: node_possible_actions = [a for a in self.vocab.get_token_to_index_vocabulary('actions').keys() if a.startswith('NODE') and (":" not in a or s0 == root_id[sent_idx] or a.split('NODE:')[1] != "root")] # if len(node_possible_actions) > 1: # logger.warning(f"Possible node actions: {node_possible_actions}. 
" # f"NODE:* actions are deprecated - train a new model!") valid_actions += node_possible_actions if self.stack.get_len(sent_idx) > 1: s1 = self.stack.get_stack(sent_idx)[-2]['token'] if s1 != root_id[sent_idx] and \ generated_order[sent_idx][s0] > generated_order[sent_idx][s1] and \ swap_count[sent_idx][s1] < self.max_swaps_per_node: valid_actions += ['SWAP'] if s1 != root_id[sent_idx] and head_count[sent_idx][s1] > 0: valid_actions += ['REDUCE-1'] # Hacky code to verify that we do not draw the same edge with the same label twice labels_left_edge = ['root'] # should not be in vocab anyway actually but be safe labels_right_edge = [] for mod_tok, head_tok, label in edge_list[sent_idx]: if (mod_tok, head_tok) == (s1, s0): labels_left_edge.append(label) if (mod_tok, head_tok) == (s0, s1): labels_right_edge.append(label) if s1 != root_id[sent_idx] and (not self.max_heads or head_count[sent_idx][s1] < self.max_heads): left_edge_possible_actions = \ [a for a in self.vocab.get_token_to_index_vocabulary('actions').keys() if a.startswith('LEFT-EDGE') and a.split('LEFT-EDGE:')[1] not in labels_left_edge] valid_actions += left_edge_possible_actions if not self.max_heads or head_count[sent_idx][s0] < self.max_heads: if s1 == root_id[sent_idx]: right_edge_possible_actions = [] if "root" in labels_right_edge else ['RIGHT-EDGE:root'] else: # hack to disable root labels_right_edge += ['root'] right_edge_possible_actions = \ [a for a in self.vocab.get_token_to_index_vocabulary('actions').keys() if a.startswith('RIGHT-EDGE') and a.split('RIGHT-EDGE:')[1] not in labels_right_edge] valid_actions += right_edge_possible_actions # remove unknown actions: vocab_actions = self.vocab.get_token_to_index_vocabulary('actions').keys() valid_actions = [valid_action for valid_action in valid_actions if valid_action in vocab_actions] if not valid_actions: valid_actions = ["FINISH"] return valid_actions def predict_action(self, ratio_factor, sent_idx, valid_actions): stack_emb = 
self.stack.get_output(sent_idx) buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \ else self.buffer.get_output(sent_idx) action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \ else self.action_stack.get_output(sent_idx) p_t = torch.cat([buffer_emb, stack_emb, action_emb]) p_t = torch.cat([p_t, ratio_factor]) h = torch.tanh(self.p_s2h(p_t)) h = torch.cat([h, ratio_factor]) valid_action_idx = [self.vocab.get_token_index(a, namespace='actions') for a in valid_actions] logits = self.p_act(h)[torch.tensor(valid_action_idx, dtype=torch.long, device=h.device)] valid_action_tbl = {a: i for i, a in enumerate(valid_action_idx)} log_probs = torch.log_softmax(logits, dim=0) action_idx = torch.max(log_probs, 0)[1].item() action_idx = valid_action_idx[action_idx] action = self.vocab.get_token_from_index(action_idx, namespace='actions') return action, action_idx, log_probs, valid_action_tbl def exec_action(self, action, action_sequence_length, action_tag_for_terminate, edge_list, generated_order, head_count, swap_count, null_node, num_of_generated_node, oracle_actions, ratio_factor, ratio_factor_losses, sent_idx, sent_len, total_node_num, reachable: List[Dict[int, Set[int]]]): # generate null node, recursive way if action.startswith("NODE"): # Support legacy models with "NODE:*" actions that also create an edge null_node_token = len(null_node[sent_idx]) + sent_len[sent_idx] + 1 null_node[sent_idx].append(null_node_token) head_count[sent_idx][null_node_token] = swap_count[sent_idx][null_node_token] = 0 reachable[sent_idx][null_node_token] = {null_node_token} stack_emb = self.stack.get_output(sent_idx) stack_emb = torch.cat([stack_emb, ratio_factor]) comp_rep = torch.tanh(self.update_null_node(stack_emb)) node_input = comp_rep self.buffer.push(sent_idx, input=node_input, extra={'token': null_node_token}) total_node_num[sent_idx] = sent_len[sent_idx] + len(null_node[sent_idx]) # Support legacy models with "NODE:*" actions that also 
create an edge if action.startswith("NODE:") or action.startswith("LEFT-EDGE") or action.startswith("RIGHT-EDGE"): if action.startswith("NODE:"): logger.warning(f"Took action {action}") modifier = self.buffer.get_stack(sent_idx)[-1] head = self.stack.get_stack(sent_idx)[-1] elif action.startswith("LEFT-EDGE"): head = self.stack.get_stack(sent_idx)[-1] modifier = self.stack.get_stack(sent_idx)[-2] else: head = self.stack.get_stack(sent_idx)[-2] modifier = self.stack.get_stack(sent_idx)[-1] (head_rep, head_tok) = (head['stack_rnn_output'], head['token']) (mod_rep, mod_tok) = (modifier['stack_rnn_output'], modifier['token']) if oracle_actions is None: edge_list[sent_idx].append((mod_tok, head_tok, action.split(':', maxsplit=1)[1])) # propagate reachability reachable_from_mod = reachable[sent_idx][mod_tok] for tok, reachable_for_tok_set in reachable[sent_idx].items(): if head_tok in reachable_for_tok_set: reachable_for_tok_set |= reachable_from_mod action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \ else self.action_stack.get_output(sent_idx) stack_emb = self.pempty_stack_emb if self.stack.get_len(sent_idx) == 0 \ else self.stack.get_output(sent_idx) buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \ else self.buffer.get_output(sent_idx) # # compute composed representation comp_rep = torch.cat([head_rep, mod_rep, action_emb, buffer_emb, stack_emb, ratio_factor]) comp_rep = torch.tanh(self.p_comp(comp_rep)) if action.startswith("NODE:"): self.buffer.pop(sent_idx) self.buffer.push(sent_idx, input=comp_rep, extra={'token': mod_tok}) elif action.startswith("LEFT-EDGE"): self.stack.pop(sent_idx) self.stack.push(sent_idx, input=comp_rep, extra={'token': head_tok}) # RIGHT-EDGE else: stack_0_rep = self.stack.get_stack(sent_idx)[-1]['stack_rnn_input'] self.stack.pop(sent_idx) self.stack.pop(sent_idx) self.stack.push(sent_idx, input=comp_rep, extra={'token': head_tok}) self.stack.push(sent_idx, input=stack_0_rep, 
extra={'token': mod_tok}) head_count[sent_idx][mod_tok] += 1 # Execute the action to update the parser state elif action == "REDUCE-0": self.stack.pop(sent_idx) elif action == "REDUCE-1": stack_0 = self.stack.pop(sent_idx) self.stack.pop(sent_idx) self.stack.push(sent_idx, input=stack_0['stack_rnn_input'], extra={'token': stack_0['token']}) elif action == "SHIFT": buffer = self.buffer.pop(sent_idx) self.stack.push(sent_idx, input=buffer['stack_rnn_input'], extra={'token': buffer['token']}) s0 = self.stack.get_stack(sent_idx)[-1]['token'] if s0 not in generated_order[sent_idx]: num_of_generated_node[sent_idx] = len(generated_order[sent_idx]) generated_order[sent_idx][s0] = num_of_generated_node[sent_idx] head_count[sent_idx][s0] = swap_count[sent_idx][s0] = 0 reachable[sent_idx][s0] = {s0} elif action == "SWAP": stack_penult = self.stack.pop_penult(sent_idx) self.buffer.push(sent_idx, input=stack_penult['stack_rnn_input'], extra={'token': stack_penult['token']}) swap_count[sent_idx][stack_penult['token']] += 1 elif action == "FINISH": action_tag_for_terminate[sent_idx] = True ratio_factor_losses[sent_idx] = ratio_factor action_sequence_length[sent_idx] += 1 def fix_unconnected_egraph(self, edge_list, reachable, root_id, generated_order): """ Detect cycles in edge_list, select arbitrary element, attach to root's dependent. Postprocessing to avoid validation errors (validate.py --level 2): L2 Enhanced unconnected-egraph Since REDUCE-{0,1} has a precondition requiring a head, the only way to get an unconnect graph (which means there are nodes unreachable from the root) is if there are cycles. 
""" # Find the dependent of root since it will be the head of all orphans pred = None for mod_tok, head_tok, label in edge_list: if head_tok == root_id: pred = mod_tok break orphans = set(generated_order) - reachable[root_id] # All nodes not reachable from the root while orphans: tok = max(orphans, key=lambda x: len(reachable[x])) # Start from nodes with the most descendents: "roots" if pred is None: # print(f"No predicate (dependent of root) found: {edge_list}", file=sys.stderr) edge_list.append((tok, root_id, "root")) pred = tok else: edge_list.append((tok, pred, "orphan")) orphans -= reachable[tok] # All cycle members were taken care of by doing this # Returns an expression of the loss for the sequence of actions. # (that is, the oracle_actions if present or the predicted sequence otherwise) def forward(self, words: Dict[str, torch.LongTensor], metadata: List[Dict[str, Any]], gold_actions: Dict[str, torch.LongTensor] = None, ) -> Dict[str, torch.LongTensor]: batch_size = len(metadata) sent_len = [len(d['words']) for d in metadata] #oracle_actions = None if gold_actions is not None: oracle_actions = deepcopy([d['gold_actions'] for d in metadata]) embedded_text_input = self.text_field_embedder(words) embedded_text_input = self._input_dropout(embedded_text_input) if self.training: try: ret_train = self._greedy_decode(batch_size=batch_size, sent_len=sent_len, embedded_text_input=embedded_text_input, oracle_actions=oracle_actions) except IndexError: raise IndexError(f"{[d['words'] for d in metadata]}") _loss = ret_train['loss'] output_dict = {'loss': _loss} return output_dict else: #reset if self.num_validation_instances and (not self._total_validation_instances or self._total_validation_instances == self.num_validation_instances): self._total_validation_instances = batch_size else: self._total_validation_instances += batch_size #print(f'{self._total_validation_instances}/{self.num_validation_instances}') training_mode = self.training self.eval() with 
torch.no_grad(): ret_eval = self._greedy_decode(batch_size=batch_size, sent_len=sent_len, embedded_text_input=embedded_text_input) self.train(training_mode) edge_list = ret_eval['edge_list'] null_node = ret_eval['null_node'] _loss = ret_eval['loss'] # prediction-mode output_dict = { 'edge_list': edge_list, 'null_node': null_node, 'loss': _loss } for k in ["id", "form", "lemma", "upostag", "xpostag", "feats", "head", "deprel", "misc"]: output_dict[k] = [[token_metadata[k] for token_metadata in sentence_metadata['annotation']] for sentence_metadata in metadata] output_dict["multiwords"] = [sentence_metadata['multiwords'] for sentence_metadata in metadata] output_dict["sent_id"] = [sentence_metadata['sent_id'] for sentence_metadata in metadata] output_dict["text"] = [sentence_metadata['text'] for sentence_metadata in metadata] # validation mode # compute the accuracy when gold actions exist if gold_actions is not None: predicted_graphs = [] gold_graphs_conllu = [] for sent_idx in range(batch_size): predicted_graphs.append(eud_trans_outputs_into_conllu({ k:output_dict[k][sent_idx] for k in ["id", "form", "lemma", "upostag", "xpostag", "feats", "head", "deprel", "misc", "edge_list", "null_node", "multiwords", "text", "sent_id"] }, self.output_null_nodes)) gold_annotation = [{key: output_dict[key][sent_idx] for key in ("sent_id", "text") if output_dict[key][sent_idx]}] for annotation in metadata[sent_idx]['annotation']: gold_annotation.append(annotation) gold_graphs_conllu.append(annotation_to_conllu(gold_annotation, self.output_null_nodes)) predicted_graphs_conllu = [line for lines in predicted_graphs for line in lines] gold_graphs_conllu = [line for lines in gold_graphs_conllu for line in lines] self._xud_score(predicted_graphs_conllu, gold_graphs_conllu) return output_dict def get_metrics(self, reset: bool = False) -> Dict[str, float]: all_metrics: Dict[str, float] = {} if self._xud_score is not None and not self.training: if self.num_validation_instances and 
self._total_validation_instances == self.num_validation_instances: all_metrics.update(self._xud_score.get_metric(reset=reset)) elif not self.num_validation_instances: all_metrics.update(self._xud_score.get_metric(reset=True)) return all_metrics
class TransitionParserAmr(Model):
    """Transition-based graph parser with a buffer, a stack, a deque and an action history.

    Each of the four memories is a ``StackRnn``. At every step the set of valid actions is
    derived from the memories, one action is chosen (from the oracle if given, otherwise
    greedily) and applied. Action strings are looked up in the ``actions`` vocabulary
    namespace and are grouped into types by prefix (see ``_ACTION_TYPE``); a label, where
    present, follows the separator ``@@:@@``.

    Buffer and stack entries carry a ``token`` (node id) and an integer ``type``:
    ``-1`` for the sentinel pushed at initialisation, ``0`` for an input token, ``1`` for the
    result of MERGE, ``2`` for a node created by CONFIRM or ENTITY. NEWNODE adds
    ``_NEWNODE_TYPE_MAX`` to the type of the current buffer top and pushes a new node whose
    type is the previous top's type plus one.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 word_dim: int,
                 hidden_dim: int,
                 action_dim: int,
                 entity_dim: int,
                 rel_dim: int,
                 num_layers: int,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 dropout: float = 0.0,
                 pos_tag_embedding: Embedding = None,
                 action_text_field_embedder: TextFieldEmbedder = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 eval_on_training: bool = True,
                 sep_act_type_para: bool = False) -> None:
        """Build all sub-modules.

        :param word_dim: input size of the buffer/stack/deque ``StackRnn``s and size of their
            learned start embeddings.
        :param hidden_dim: hidden size of every ``StackRnn`` and of the merged state vector.
        :param action_dim: size of the action embeddings (input of the action ``StackRnn``).
        :param entity_dim: size of the entity-label embeddings.
        :param rel_dim: size of the relation-label embeddings.
        :param dropout: dropout of the feed-forward composition layers.
        :param input_dropout: dropout applied to the embedded input in ``forward``.
        :param eval_on_training: if true, ``forward`` also decodes without the oracle while
            training and updates the metrics.
        :param sep_act_type_para: if true, first score the action type and then the action
            with a type-specific scorer; otherwise a single scorer over all actions is used.
        """
        super(TransitionParserAmr, self).__init__(vocab, regularizer)

        self._smatch_scorer = SmatchScorer()
        self._mces_scorer = MCESScore(output_type='gscprf', cores=0, trace=0)

        self.eval_on_training = eval_on_training
        self.sep_act_type_para = sep_act_type_para
        self._NEWNODE_TYPE_MAX = 20
        self._NODE_LEN_RATIO_MAX = 20.0
        self._ACTION_TYPE = [
            'SHIFT', 'CONFIRM', 'REDUCE', 'MERGE', 'ENTITY', 'NEWNODE', 'DROP', 'CACHE', 'LEFT', 'RIGHT'
        ]
        self._ACTION_TYPE_IDX = {a: i for (i, a) in enumerate(self._ACTION_TYPE)}

        self.text_field_embedder = text_field_embedder
        self.pos_tag_embedding = pos_tag_embedding
        self.action_text_field_embedder = action_text_field_embedder

        node_dim = word_dim
        if pos_tag_embedding:
            node_dim += pos_tag_embedding.output_dim
        # NOTE(review): the composition layers below emit node_dim-sized vectors while the
        # StackRnn memories take word_dim-sized inputs; the two only coincide without a POS
        # embedding -- confirm this is intended.

        self.num_actions = vocab.get_vocab_size('actions')
        self.num_newnodes = vocab.get_vocab_size('newnodes')
        self.num_relations = vocab.get_vocab_size('relations')
        self.num_entities = vocab.get_vocab_size('entities')

        # Sub-modules are created in a fixed order so that seeded initialisation is reproducible.
        self.action_embedding = Embedding(num_embeddings=self.num_actions, embedding_dim=action_dim)
        self.newnode_embedding = Embedding(num_embeddings=self.num_newnodes, embedding_dim=node_dim)
        self.rel_embedding = Embedding(num_embeddings=self.num_relations, embedding_dim=rel_dim)
        self.entity_embedding = Embedding(num_embeddings=self.num_entities, embedding_dim=entity_dim)

        # merge (stack, buffer, action_stack, deque)
        self.merge = FeedForward(input_dim=hidden_dim * 4,
                                 num_layers=1,
                                 hidden_dims=hidden_dim,
                                 activations=Activation.by_name('relu')(),
                                 dropout=dropout)
        # merge (parent, rel, child) -> parent
        self.merge_parent = FeedForward(input_dim=hidden_dim * 2 + rel_dim,
                                        num_layers=1,
                                        hidden_dims=node_dim,
                                        activations=Activation.by_name('relu')(),
                                        dropout=dropout)
        # merge (parent, rel, child) -> child
        self.merge_child = FeedForward(input_dim=hidden_dim * 2 + rel_dim,
                                       num_layers=1,
                                       hidden_dims=node_dim,
                                       activations=Activation.by_name('relu')(),
                                       dropout=dropout)
        # merge (A, B) -> AB
        self.merge_token = FeedForward(input_dim=hidden_dim * 2,
                                       num_layers=1,
                                       hidden_dims=node_dim,
                                       activations=Activation.by_name('relu')(),
                                       dropout=dropout)
        # merge (AB, entity_label) -> X
        self.merge_entity = FeedForward(input_dim=hidden_dim + entity_dim,
                                        num_layers=1,
                                        hidden_dims=node_dim,
                                        activations=Activation.by_name('relu')(),
                                        dropout=dropout)

        # Q / A value scorer
        if sep_act_type_para:
            self.action_type_scorer = torch.nn.Linear(hidden_dim, len(self._ACTION_TYPE))
            all_action_strs = list(self.vocab.get_token_to_index_vocabulary('actions').keys())
            # One output unit per action whose string starts with the type name.
            action_cnt = {
                action_: sum(1 for a in all_action_strs if a.startswith(action_))
                for action_ in self._ACTION_TYPE
            }
            self.scorers = []
            for a in self._ACTION_TYPE:
                m = torch.nn.Linear(hidden_dim, action_cnt[a])
                # Registering under a named attribute makes the layer a proper sub-module;
                # the plain list only provides lookup by action-type index.
                setattr(self, f'scorer_{a}', m)
                self.scorers.append(m)
        else:
            self.scorer = torch.nn.Linear(hidden_dim, self.num_actions)

        # X -> confirm (X)
        self.confirm_layer = FeedForward(input_dim=hidden_dim,
                                         num_layers=1,
                                         hidden_dims=node_dim,
                                         activations=Activation.by_name('relu')(),
                                         dropout=dropout)

        # Learned start-of-memory inputs.
        self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(word_dim))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(word_dim))
        self.proot_action_emb = torch.nn.Parameter(torch.randn(action_dim))
        self.proot_deque_emb = torch.nn.Parameter(torch.randn(word_dim))

        self._input_dropout = Dropout(input_dropout)

        self.buffer = StackRnn(input_size=word_dim,
                               hidden_size=hidden_dim,
                               num_layers=num_layers,
                               recurrent_dropout_probability=recurrent_dropout_probability,
                               layer_dropout_probability=layer_dropout_probability,
                               same_dropout_mask_per_instance=same_dropout_mask_per_instance)
        self.stack = StackRnn(input_size=word_dim,
                              hidden_size=hidden_dim,
                              num_layers=num_layers,
                              recurrent_dropout_probability=recurrent_dropout_probability,
                              layer_dropout_probability=layer_dropout_probability,
                              same_dropout_mask_per_instance=same_dropout_mask_per_instance)
        self.deque = StackRnn(input_size=word_dim,
                              hidden_size=hidden_dim,
                              num_layers=num_layers,
                              recurrent_dropout_probability=recurrent_dropout_probability,
                              layer_dropout_probability=layer_dropout_probability,
                              same_dropout_mask_per_instance=same_dropout_mask_per_instance)
        self.action_stack = StackRnn(input_size=action_dim,
                                     hidden_size=hidden_dim,
                                     num_layers=num_layers,
                                     recurrent_dropout_probability=recurrent_dropout_probability,
                                     layer_dropout_probability=layer_dropout_probability,
                                     same_dropout_mask_per_instance=same_dropout_mask_per_instance)
        initializer(self)

    @staticmethod
    def _payload(entry):
        """Return the fields of a memory entry whose key does not start with ``stack_rnn_``."""
        return {k: entry[k] for k in entry.keys() if not k.startswith('stack_rnn_')}

    def _encode_state(self, sent_idx):
        """Merge the current outputs of buffer, stack, action history and deque into one vector."""
        stack_emb = self.stack.get_output(sent_idx)
        buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
            else self.buffer.get_output(sent_idx)
        action_emb = self.action_stack.get_output(sent_idx)
        deque_emb = self.deque.get_output(sent_idx)
        p_t = torch.cat([buffer_emb, stack_emb, action_emb, deque_emb])
        return self.merge(p_t)

    def _greedy_decode(self,
                       batch_size: int,
                       sent_len: List[int],
                       embedded_text_input: torch.Tensor,
                       metadata: List[Dict[str, Any]],
                       oracle_actions: Optional[List[List[int]]] = None) -> Dict[str, Any]:
        """Parse a batch, following ``oracle_actions`` if given and the greedy choice otherwise.

        :param sent_len: number of input tokens per sentence.
        :param embedded_text_input: indexed as ``[sent_idx][token_idx]`` to obtain one token vector.
        :param metadata: per-sentence dicts; ``tokens`` is read for the node labels and ``id``
            is logged when decoding fails.
        :param oracle_actions: per-sentence lists of action indices. They are consumed
            (popped from the front) while decoding.
        :return: dict with ``loss`` (negated sum of the recorded log-probabilities divided by
            the number of recorded decisions) and ``losses``; without an oracle also
            ``existing_edges``, ``node_labels``, ``node_types``, ``id_cnt`` and ``action_strs``.
        """
        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.deque.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        # We will keep track of all the losses we accumulate during parsing.
        # A decision with a single valid option is recorded as a constant zero.
        losses = [[] for _ in range(batch_size)]

        node_labels = [[] for _ in range(batch_size)]
        node_types = [[] for _ in range(batch_size)]
        existing_edges = [{} for _ in range(batch_size)]
        # Next free node id: 0 is the root, 1..sent_len are the input tokens.
        id_cnt = [sent_len[i] + 1 for i in range(batch_size)]
        action_strs = [[] for _ in range(batch_size)]
        origin_tokens = [metadata[i]['tokens'] for i in range(batch_size)]

        for sent_idx in range(batch_size):
            node_labels[sent_idx].append('@@ROOT@@')
            node_types[sent_idx].append('ROOT')
            for i in range(sent_len[sent_idx]):
                node_labels[sent_idx].append(origin_tokens[sent_idx][i])
                node_types[sent_idx].append('TokenNode')

        # push the tokens onto the buffer (tokens is in reverse order)
        for sent_idx in range(batch_size):
            self.buffer.push(sent_idx, input=self.pempty_buffer_emb, extra={'token': 0, 'type': -1})
        for token_idx in range(max(sent_len)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    self.buffer.push(sent_idx,
                                     input=embedded_text_input[sent_idx][sent_len[sent_idx] - 1 - token_idx],
                                     extra={'token': sent_len[sent_idx] - token_idx, 'type': 0})

        # init stack using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.stack.push(sent_idx, input=self.proot_stack_emb, extra={'token': 0, 'type': -1})

        # init deque using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.deque.push(sent_idx, input=self.proot_deque_emb, extra={'token': 0, 'type': -1})

        # Action indices per action type, selected by string prefix.
        all_action_strs = list(self.vocab.get_token_to_index_vocabulary('actions').keys())
        action_id = {
            action_: [self.vocab.get_token_index(a, namespace='actions')
                      for a in all_action_strs if a.startswith(action_)]
            for action_ in self._ACTION_TYPE
        }

        # Position of each action inside its type-specific scorer, and the type used to
        # execute it (the first matching type in _ACTION_TYPE order).
        action_idx_to_param_idx = {}
        action_kind = {}
        for a in self._ACTION_TYPE:
            for (i, x) in enumerate(action_id[a]):
                action_idx_to_param_idx[x] = i
                action_kind.setdefault(x, a)

        # CONFIRM actions available for a given surface token.
        origin_token_to_confirm_action = {}
        for a in action_id['CONFIRM']:
            t = self.vocab.get_token_from_index(a, namespace='actions').split('@@:@@')[1]
            origin_token_to_confirm_action.setdefault(t, []).append(a)

        # init action history using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.action_stack.push(sent_idx, input=self.proot_action_emb, extra={'token': 0, 'type': -1})

        # compute probability of each of the actions and choose an action
        # either from the oracle or if there is no oracle, based on the model
        trans_not_fin = True
        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):
                try:
                    valid_action_types = set()
                    valid_actions = []

                    buf_len = self.buffer.get_len(sent_idx)
                    buf_top = self.buffer.get_stack(sent_idx)[-1] if buf_len > 0 else None
                    stack_len = self.stack.get_len(sent_idx)
                    stack_top = self.stack.get_stack(sent_idx)[-1] if stack_len > 0 else None

                    # given the buffer and stack, conclude the valid action list
                    if buf_len > 1 and buf_top['type'] > 1:
                        valid_actions += action_id['SHIFT']
                        valid_action_types.add(self._ACTION_TYPE_IDX['SHIFT'])

                    if buf_len > 1 and buf_top['type'] < 2:
                        tk = node_labels[sent_idx][buf_top['token']]
                        if buf_top['type'] == 1:
                            tk = tk.replace('@@_@@', '_')
                        if tk in origin_token_to_confirm_action:
                            valid_actions += origin_token_to_confirm_action[tk]
                            valid_action_types.add(self._ACTION_TYPE_IDX['CONFIRM'])

                    if buf_len > 2 and buf_top['type'] < 2 \
                            and self.buffer.get_stack(sent_idx)[-2]['type'] == 0:
                        valid_actions += action_id['MERGE']
                        valid_action_types.add(self._ACTION_TYPE_IDX['MERGE'])

                    if buf_len > 1 and buf_top['type'] < 2:
                        valid_actions += action_id['ENTITY']
                        valid_action_types.add(self._ACTION_TYPE_IDX['ENTITY'])

                    if stack_len > 1 and stack_top['type'] > 1:
                        valid_actions += action_id['REDUCE']
                        valid_action_types.add(self._ACTION_TYPE_IDX['REDUCE'])

                    if buf_len > 1 and buf_top['type'] == 0:
                        valid_actions += action_id['DROP']
                        valid_action_types.add(self._ACTION_TYPE_IDX['DROP'])

                    if buf_len > 1 and stack_len > 1:
                        valid_actions += action_id['CACHE']
                        valid_action_types.add(self._ACTION_TYPE_IDX['CACHE'])

                    # Cap both the length of a NEWNODE chain and the total node count per token.
                    if buf_len > 1 \
                            and buf_top['type'] > 1 \
                            and buf_top['type'] < self._NEWNODE_TYPE_MAX \
                            and id_cnt[sent_idx] / sent_len[sent_idx] < self._NODE_LEN_RATIO_MAX:
                        valid_actions += action_id['NEWNODE']
                        valid_action_types.add(self._ACTION_TYPE_IDX['NEWNODE'])

                    if stack_len > 1 and stack_top['type'] > 1 \
                            and buf_len > 0 and buf_top['type'] > 1:
                        u = stack_top['token']
                        v = buf_top['token']
                        for aid in ['LEFT', 'RIGHT']:
                            # After the swap u is the parent: buffer top for LEFT, stack top for RIGHT.
                            u, v = v, u
                            for a in action_id[aid]:
                                rel = self.vocab.get_token_from_index(a, namespace='actions').split('@@:@@')[1]
                                if u not in existing_edges[sent_idx] \
                                        or (rel, v) not in existing_edges[sent_idx][u]:
                                    if u in existing_edges[sent_idx]:
                                        # At most three parallel edges between the same pair.
                                        v_cnt = 0
                                        for eg in existing_edges[sent_idx][u]:
                                            if eg[1] == v:
                                                v_cnt += 1
                                        if v_cnt >= 3:
                                            continue
                                    valid_actions.append(a)
                                    valid_action_types.add(self._ACTION_TYPE_IDX[aid])

                    # Only the buffer sentinel is left: the stack top may be attached to the root.
                    if stack_len > 1 and stack_top['type'] > 1 \
                            and buf_len == 1 and buf_top['type'] == -1:
                        v = stack_top['token']
                        u = buf_top['token']
                        rel = '_ROOT_'
                        if u not in existing_edges[sent_idx] or (rel, v) not in existing_edges[sent_idx][u]:
                            valid_actions.append(
                                self.vocab.get_token_index('LEFT@@:@@_ROOT_', namespace='actions'))
                            valid_action_types.add(self._ACTION_TYPE_IDX['LEFT'])

                    valid_action_types = list(valid_action_types)

                    assert ((len(valid_actions) == 0) == (stack_len == 1 and buf_len == 1))

                    if len(valid_actions) == 0:
                        continue

                    trans_not_fin = True

                    if oracle_actions is not None:
                        # With an oracle every action type is scored.
                        valid_action_types = list(range(len(self._ACTION_TYPE)))

                    if self.sep_act_type_para:
                        # Step 1: choose the action type.
                        action_type = valid_action_types[0]
                        action_type_log_probs = None
                        h = None
                        if len(valid_action_types) > 1:
                            h = self._encode_state(sent_idx)
                            logits = self.action_type_scorer(h)[torch.tensor(valid_action_types,
                                                                             dtype=torch.long,
                                                                             device=h.device)]
                            valid_action_type_tbl = {a: i for i, a in enumerate(valid_action_types)}
                            action_type_log_probs = torch.log_softmax(logits, dim=0)
                            action_type_idx = torch.max(action_type_log_probs, 0)[1].item()
                            action_type = valid_action_types[action_type_idx]

                        if oracle_actions is not None:
                            # Use the type of the next oracle action (the last matching type wins).
                            action_type = -1
                            for a in range(len(self._ACTION_TYPE)):
                                if oracle_actions[sent_idx][0] in action_id[self._ACTION_TYPE[a]]:
                                    action_type = a
                            assert (action_type >= 0 and action_type < len(self._ACTION_TYPE))

                        if action_type_log_probs is not None:
                            losses[sent_idx].append(action_type_log_probs[valid_action_type_tbl[action_type]])
                        else:
                            losses[sent_idx].append(
                                torch.tensor(0.0,
                                             dtype=self.action_type_scorer.weight.dtype,
                                             device=self.action_type_scorer.weight.device))

                        # Step 2: choose the action among the valid ones of that type.
                        valid_actions = [x for x in valid_actions
                                         if x in action_id[self._ACTION_TYPE[action_type]]]

                        log_probs = None
                        action = valid_actions[0]
                        if len(valid_actions) > 1:
                            if h is None:
                                h = self._encode_state(sent_idx)
                            valid_actions_param = [action_idx_to_param_idx[x] for x in valid_actions]
                            logits = self.scorers[action_type](h)[torch.tensor(valid_actions_param,
                                                                               dtype=torch.long,
                                                                               device=h.device)]
                            valid_action_tbl = {a: i for i, a in enumerate(valid_actions_param)}
                            log_probs = torch.log_softmax(logits, dim=0)
                            action_idx = torch.max(log_probs, 0)[1].item()
                            action = valid_actions[action_idx]

                        if oracle_actions is not None:
                            action = oracle_actions[sent_idx].pop(0)

                        if log_probs is not None:
                            # append the action-specific loss
                            losses[sent_idx].append(log_probs[valid_action_tbl[action_idx_to_param_idx[action]]])
                        else:
                            losses[sent_idx].append(
                                torch.tensor(0.0,
                                             dtype=self.action_type_scorer.weight.dtype,
                                             device=self.action_type_scorer.weight.device))
                    else:
                        log_probs = None
                        action = valid_actions[0]
                        if len(valid_actions) > 1:
                            h = self._encode_state(sent_idx)
                            logits = self.scorer(h)[torch.tensor(valid_actions, dtype=torch.long, device=h.device)]
                            valid_action_tbl = {a: i for i, a in enumerate(valid_actions)}
                            log_probs = torch.log_softmax(logits, dim=0)
                            action_idx = torch.max(log_probs, 0)[1].item()
                            action = valid_actions[action_idx]

                        if oracle_actions is not None:
                            action = oracle_actions[sent_idx].pop(0)

                        if log_probs is not None:
                            # append the action-specific loss
                            losses[sent_idx].append(log_probs[valid_action_tbl[action]])
                        else:
                            losses[sent_idx].append(
                                torch.tensor(0.0, dtype=self.scorer.weight.dtype, device=self.scorer.weight.device))

                    action_str = self.vocab.get_token_from_index(action, namespace='actions')

                    # push action into action_stack
                    self.action_stack.push(sent_idx,
                                           input=self.action_embedding(
                                               torch.tensor(action, device=embedded_text_input.device)),
                                           extra={'token': action_str})
                    action_strs[sent_idx].append(action_str)

                    kind = action_kind.get(action)

                    if kind == 'SHIFT':
                        # Move everything cached on the deque back to the stack, then shift.
                        while self.deque.get_len(sent_idx) > 1:
                            e = self.deque.pop(sent_idx)
                            self.stack.push(sent_idx, input=e['stack_rnn_input'], extra=self._payload(e))
                        e = self.buffer.pop(sent_idx)
                        self.stack.push(sent_idx, input=e['stack_rnn_input'], extra=self._payload(e))

                    elif kind == 'CONFIRM':
                        e = self.buffer.pop(sent_idx)
                        concept = self.confirm_layer(e['stack_rnn_output'])
                        self.buffer.push(sent_idx, input=concept, extra={'token': id_cnt[sent_idx], 'type': 2})
                        node_labels[sent_idx].append(action_str.split('@@:@@')[-1])
                        node_types[sent_idx].append('ConceptNode')
                        id_cnt[sent_idx] += 1

                    elif kind == 'REDUCE':
                        self.stack.pop(sent_idx)

                    elif kind == 'MERGE':
                        token_a = self.buffer.pop(sent_idx)
                        token_b = self.buffer.pop(sent_idx)
                        token_ab = self.merge_token(
                            torch.cat([token_a['stack_rnn_output'], token_b['stack_rnn_output']]))
                        token_id = token_a['token']
                        if token_a['type'] == 0:
                            # First merge of a plain token: allocate a new entity node for the span.
                            node_labels[sent_idx].append(node_labels[sent_idx][token_a['token']])
                            node_types[sent_idx].append('EntityNode')
                            token_id = id_cnt[sent_idx]
                            id_cnt[sent_idx] += 1
                        node_labels[sent_idx][token_id] += '@@_@@' + node_labels[sent_idx][token_b['token']]
                        self.buffer.push(sent_idx, input=token_ab, extra={'token': token_id, 'type': 1})

                    elif kind == 'ENTITY':
                        entity_name = action_str.split('@@:@@')[-1]
                        buffer_top_id = self.buffer.get_stack(sent_idx)[-1]['token']
                        entity = self.entity_embedding(
                            torch.tensor(self.vocab.get_token_index(action_str.split('@@:@@')[1],
                                                                    namespace='entities'),
                                         device=embedded_text_input.device))
                        e = self.buffer.pop(sent_idx)
                        entity = self.merge_entity(torch.cat([e['stack_rnn_output'], entity]))
                        self.buffer.push(sent_idx, input=entity, extra={'token': id_cnt[sent_idx], 'type': 2})
                        node_labels[sent_idx].append(action_str.split('@@:@@')[-1])
                        node_types[sent_idx].append('ConceptNode')
                        id_cnt[sent_idx] += 1

                        sent_edges = existing_edges[sent_idx]
                        if entity_name == 'date-entity':
                            # Expand the surface string into year/month/day attribute nodes.
                            datestr = ' '.join(map(unquote, node_labels[sent_idx][buffer_top_id].split('@@_@@')))
                            entry, flags = parse_date(datestr)
                            date_node_id = id_cnt[sent_idx] - 1
                            for relation, flag in zip(['year', 'month', 'day'], flags):
                                if flag:
                                    value = getattr(entry, relation)
                                    node_labels[sent_idx].append(str(value))
                                    node_types[sent_idx].append('AttributeNode')
                                    id_cnt[sent_idx] += 1
                                    sent_edges.setdefault(date_node_id, []).append(
                                        (relation, id_cnt[sent_idx] - 1))
                        if entity_name not in ['date-entity', 'capitalism', '2', 'contrast-01', '1', 'compare-01']:
                            if entity_name != 'name':
                                # Insert an intermediate 'name' node below the entity node.
                                node_labels[sent_idx].append('name')
                                node_types[sent_idx].append('ConceptNode')
                                id_cnt[sent_idx] += 1
                                sent_edges.setdefault(id_cnt[sent_idx] - 2, []).append(
                                    ('name', id_cnt[sent_idx] - 1))
                            name_node_id = id_cnt[sent_idx] - 1
                            # One opN attribute per merged surface token.
                            for (i, opi) in enumerate(
                                    map(unquote, node_labels[sent_idx][buffer_top_id].split('@@_@@'))):
                                node_labels[sent_idx].append(opi)
                                node_types[sent_idx].append('AttributeNode')
                                id_cnt[sent_idx] += 1
                                sent_edges.setdefault(name_node_id, []).append(
                                    (f"op{i + 1}", id_cnt[sent_idx] - 1))

                    elif kind == 'NEWNODE':
                        node = self.newnode_embedding(
                            torch.tensor(self.vocab.get_token_index(action_str.split('@@:@@')[1],
                                                                    namespace='newnodes'),
                                         device=embedded_text_input.device))
                        # Mark the current top, then push a node one step further along the chain.
                        self.buffer.get_stack(sent_idx)[-1]['type'] += self._NEWNODE_TYPE_MAX
                        self.buffer.push(sent_idx,
                                         input=node,
                                         extra={
                                             'token': id_cnt[sent_idx],
                                             'type': self.buffer.get_stack(sent_idx)[-1]['type']
                                                     - self._NEWNODE_TYPE_MAX + 1
                                         })
                        node_labels[sent_idx].append(action_str.split('@@:@@')[-1])
                        node_types[sent_idx].append('ConceptNode')
                        id_cnt[sent_idx] += 1

                    elif kind == 'DROP':
                        self.buffer.pop(sent_idx)

                    elif kind == 'CACHE':
                        e = self.stack.pop(sent_idx)
                        self.deque.push(sent_idx, input=e['stack_rnn_input'], extra=self._payload(e))

                    elif kind == 'LEFT':
                        # Buffer top is the parent, stack top the child.
                        parent = self.buffer.pop(sent_idx)
                        child = self.stack.pop(sent_idx)
                        rel = self.rel_embedding(
                            torch.tensor(self.vocab.get_token_index(action_str.split('@@:@@')[1],
                                                                    namespace='relations'),
                                         device=embedded_text_input.device))
                        parent_rep = self.merge_parent(
                            torch.cat([parent['stack_rnn_output'], rel, child['stack_rnn_output']]))
                        child_rep = self.merge_child(
                            torch.cat([parent['stack_rnn_output'], rel, child['stack_rnn_output']]))
                        self.buffer.push(sent_idx, input=parent_rep, extra=self._payload(parent))
                        self.stack.push(sent_idx, input=child_rep, extra=self._payload(child))
                        existing_edges[sent_idx].setdefault(parent['token'], []).append(
                            (action_str.split('@@:@@')[1], child['token']))

                    elif kind == 'RIGHT':
                        # Stack top is the parent, buffer top the child.
                        child = self.buffer.pop(sent_idx)
                        parent = self.stack.pop(sent_idx)
                        rel = self.rel_embedding(
                            torch.tensor(self.vocab.get_token_index(action_str.split('@@:@@')[1],
                                                                    namespace='relations'),
                                         device=embedded_text_input.device))
                        parent_rep = self.merge_parent(
                            torch.cat([parent['stack_rnn_output'], rel, child['stack_rnn_output']]))
                        child_rep = self.merge_child(
                            torch.cat([parent['stack_rnn_output'], rel, child['stack_rnn_output']]))
                        self.buffer.push(sent_idx, input=child_rep, extra=self._payload(child))
                        self.stack.push(sent_idx, input=parent_rep, extra=self._payload(parent))
                        existing_edges[sent_idx].setdefault(parent['token'], []).append(
                            (action_str.split('@@:@@')[1], child['token']))

                    else:
                        raise ValueError(f'Illegal action: \"{action}\"')
                except Exception:
                    # Record which instance failed, then let the original error propagate.
                    # Interpreter-level exits (KeyboardInterrupt, SystemExit) are not intercepted.
                    logger.exception("Decoding failed for instance %s", metadata[sent_idx].get('id'))
                    raise

        num_decisions = sum(len(cur_loss) for cur_loss in losses)
        if num_decisions > 0:
            _loss = -torch.sum(
                torch.stack([torch.sum(torch.stack(cur_loss)) for cur_loss in losses if len(cur_loss) > 0])) / \
                num_decisions
        else:
            # No decision was taken at all (e.g. only empty sentences): torch.stack would fail
            # on the empty list. Return a zero that is still attached to a parameter so that
            # calling backward() on it remains legal.
            _loss = 0.0 * self.proot_action_emb.sum()

        ret = {
            'loss': _loss,
            'losses': losses,
        }
        if oracle_actions is None:
            ret['existing_edges'] = existing_edges
            ret['node_labels'] = node_labels
            ret['node_types'] = node_types
            ret['id_cnt'] = id_cnt
            ret['action_strs'] = action_strs
        return ret

    # Returns an expression of the loss for the sequence of actions.
    # (that is, the oracle_actions if present or the predicted sequence otherwise)
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                metadata: List[Dict[str, Any]],
                gold_actions: Dict[str, torch.LongTensor] = None,
                gold_newnodes: Dict[str, torch.LongTensor] = None,
                gold_entities: Dict[str, torch.LongTensor] = None,
                gold_relations: Dict[str, torch.LongTensor] = None,
                lemmas: Dict[str, torch.LongTensor] = None,
                pos_tags: torch.LongTensor = None,
                arc_tags: torch.LongTensor = None,
                ) -> Dict[str, torch.LongTensor]:
        """Compute the loss and, unless skipped, decode the batch and update the metrics.

        While training, the loss comes from a decode that follows the oracle (when
        ``gold_actions`` is given). If ``eval_on_training`` is false that loss is returned
        immediately. Otherwise the batch is decoded a second time without oracle under
        ``torch.no_grad()``, the predicted graphs are built with ``extract_mrp_dict`` and
        scored against ``metadata[...]['mrp']`` or, failing that, ``metadata[...]['amr']``.

        :param tokens: input for ``self.text_field_embedder``.
        :param metadata: per-sentence dicts; ``tokens`` is required, ``gold_actions``, ``mrp``
            and ``amr`` are read when present.
        :param gold_actions: only tested against ``None``; the oracle action strings are taken
            from ``metadata``.
        :param pos_tags: if given and a POS embedding exists, its embedding is concatenated to
            the token embeddings along the last dimension.
        :param gold_newnodes, gold_entities, gold_relations, lemmas, arc_tags: accepted but not
            used by this method.
        """
        batch_size = len(metadata)
        sent_len = [len(d['tokens']) for d in metadata]

        oracle_actions = None
        if gold_actions is not None:
            # Fresh lists of action indices; _greedy_decode consumes them.
            oracle_actions = [[self.vocab.get_token_index(s, namespace='actions') for s in d['gold_actions']]
                              for d in metadata]

        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.training:
            ret_train = self._greedy_decode(batch_size=batch_size,
                                            sent_len=sent_len,
                                            embedded_text_input=embedded_text_input,
                                            metadata=metadata,
                                            oracle_actions=oracle_actions)
            if not self.eval_on_training:
                return {'loss': ret_train['loss']}

        training_mode = self.training
        self.eval()
        try:
            with torch.no_grad():
                ret_eval = self._greedy_decode(batch_size=batch_size,
                                               sent_len=sent_len,
                                               metadata=metadata,
                                               embedded_text_input=embedded_text_input)
        finally:
            # Restore the previous mode even if decoding raised.
            self.train(training_mode)

        action_strs = ret_eval['action_strs']  # retrieved but not otherwise used here

        existing_edges = ret_eval['existing_edges']
        id_cnt = ret_eval['id_cnt']
        node_labels = ret_eval['node_labels']
        node_types = ret_eval['node_types']
        _loss = ret_train['loss'] if self.training else ret_eval['loss']

        amr_dicts = []
        for sent_idx in range(batch_size):
            amr_dict = extract_mrp_dict(existing_edges=existing_edges[sent_idx],
                                        sent_len=sent_len[sent_idx],
                                        id_cnt=id_cnt[sent_idx],
                                        node_labels=node_labels[sent_idx],
                                        node_types=node_types[sent_idx],
                                        metadata=metadata[sent_idx])
            amr_dicts.append(amr_dict)

        # The first instance decides which kind of gold annotation the batch carries.
        golds_mrp = [d['mrp'] for d in metadata] if 'mrp' in metadata[0] else []
        golds_amr = [d['amr'] for d in metadata] if 'amr' in metadata[0] else []
        if golds_mrp:
            self._mces_scorer(predictions=amr_dicts, golds=golds_mrp)
        elif golds_amr:
            amr_strs = [mrp_dict_to_amr_str(amr_dict, all_nodes=True) for amr_dict in amr_dicts]
            for amr_str, gold_amr in zip(amr_strs, golds_amr):
                self._smatch_scorer.update(amr_str, gold_amr)

        output_dict = {
            'tokens': [d['tokens'] for d in metadata],
            'loss': _loss,
            'sent_len': sent_len,
            'metadata': metadata
        }
        if 'gold_actions' in metadata[0]:
            output_dict['oracle_actions'] = [d['gold_actions'] for d in metadata]
        if not self.training:
            output_dict['existing_edges'] = existing_edges
            output_dict['node_labels'] = node_labels
            output_dict['node_types'] = node_types
            output_dict['id_cnt'] = id_cnt
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return P/R/F1 from the Smatch scorer if it has seen gold data, else the MCES metrics."""
        if self._smatch_scorer.total_gold_num > 0:
            p, r, f1 = self._smatch_scorer.get_prf()
            if reset:
                self._smatch_scorer.reset()
            return {
                'P': p,
                'R': r,
                'F1': f1,
            }
        else:
            return self._mces_scorer.get_metric(reset=reset)