Пример #1
0
class TransitionParser(Model):
    """Transition-based graph parser built on three stack RNNs.

    Parser state lives in three ``StackRnn`` modules: ``buffer`` (nodes still to be
    processed), ``stack`` (nodes being attached) and ``action_stack`` (history of the
    actions taken). At every step one action id from the ``'actions'`` vocabulary
    namespace is executed. Action names start with one of SHIFT, REDUCE, NODE,
    REMOTE-NODE, LEFT-EDGE, RIGHT-EDGE, LEFT-REMOTE, RIGHT-REMOTE, SWAP or FINISH;
    node/edge actions carry a ``:<label>`` suffix that becomes the edge label.
    Decoding is greedy. When gold actions are supplied they are executed instead of the
    prediction, and the model is scored on them. Predictions are converted to MRP
    graphs with ``ucca_trans_outputs_into_mrp``.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 hidden_dim: int,
                 action_dim: int,
                 ratio_dim: int,
                 num_layers: int,
                 word_dim: int = 0,
                 text_field_embedder: TextFieldEmbedder = None,
                 mces_metric: Metric = None,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 lemma_text_field_embedder: TextFieldEmbedder = None,
                 pos_tag_embedding: Embedding = None,
                 deprel_embedding: Embedding = None,
                 bios_embedding: Embedding = None,
                 lexcat_embedding: Embedding = None,
                 ss_embedding: Embedding = None,
                 ss2_embedding: Embedding = None,
                 action_embedding: Embedding = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None
                 ) -> None:
        """Build the parser.

        Parameters
        ----------
        vocab : Vocabulary
            Must contain an ``'actions'`` namespace; its size fixes the action output layer.
        hidden_dim : int
            Hidden size of the three ``StackRnn`` modules and of the parser-state layer.
        action_dim : int
            Size of the action embeddings fed to the action-history ``StackRnn``.
        ratio_dim : int
            Width reserved for the "concept nodes per token" feature in several layers.
            ``_greedy_decode`` builds that feature as a one-element tensor, so this must be 1.
        num_layers : int
            Number of layers of each ``StackRnn``.
        word_dim : int
            Contribution of ``text_field_embedder`` to the node dimension (counted only when
            the embedder is given). It should equal the embedder's output size, because the
            concatenated embeddings are the buffer inputs.
        text_field_embedder : TextFieldEmbedder
            Embedder for the ``tokens`` field.
        mces_metric : Metric
            Optional. Called as ``metric(predicted_mrps, gold_mrps)`` during evaluation and
            queried by ``get_metrics``.
        recurrent_dropout_probability, layer_dropout_probability, same_dropout_mask_per_instance
            Passed unchanged to every ``StackRnn``.
        input_dropout : float
            Dropout applied to the concatenated input embeddings.
        lemma_text_field_embedder : TextFieldEmbedder
            Stored but not used by this class.
        pos_tag_embedding, deprel_embedding, bios_embedding, lexcat_embedding, ss_embedding, ss2_embedding
            Optional embeddings for the matching ``forward`` fields. Each adds its
            ``output_dim`` to the node dimension.
        action_embedding : Embedding
            If ``None``, a non-trainable ``Embedding(num_actions, action_dim)`` is created.
        initializer : InitializerApplicator
            Applied to the whole model at the end of construction.
        regularizer : RegularizerApplicator, optional
            Passed to ``Model``.
        """
        super(TransitionParser, self).__init__(vocab, regularizer)

        # Edge-count bookkeeping for primary and remote edges. The methods below neither
        # read nor update these; they are kept for compatibility (possibly read elsewhere).
        self._primary_labeled_correct = 0
        self._primary_unlabeled_correct = 0
        self._primary_total_edges_predicted = 0
        self._primary_total_edges_actual = 0
        self._primary_exact_labeled_correct = 0
        self._primary_exact_unlabeled_correct = 0

        self._remote_labeled_correct = 0
        self._remote_unlabeled_correct = 0
        self._remote_total_edges_predicted = 0
        self._remote_total_edges_actual = 0
        self._remote_exact_labeled_correct = 0
        self._remote_exact_unlabeled_correct = 0

        self._total_sentences = 0

        self.num_actions = vocab.get_vocab_size('actions')
        self.text_field_embedder = text_field_embedder
        self.lemma_text_field_embedder = lemma_text_field_embedder
        self._pos_tag_embedding = pos_tag_embedding
        self._deprel_embedding = deprel_embedding
        self._bios_embedding = bios_embedding
        self._lexcat_embedding = lexcat_embedding
        self._ss_embedding = ss_embedding
        self._ss2_embedding = ss2_embedding
        self._mces_metric = mces_metric

        # The node (buffer/stack input) dimension is the sum of all configured embedders.
        node_dim = 0
        if self.text_field_embedder is not None:
            node_dim += word_dim
        for embedding in (pos_tag_embedding, deprel_embedding, bios_embedding,
                          lexcat_embedding, ss_embedding, ss2_embedding):
            if embedding is not None:
                node_dim += embedding.output_dim
        self.node_dim = node_dim
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.ratio_dim = ratio_dim
        self.action_dim = action_dim

        self.action_embedding = action_embedding

        if action_embedding is None:
            self.action_embedding = Embedding(num_embeddings=self.num_actions,
                                              embedding_dim=self.action_dim,
                                              trainable=False)

        # Composition: head, modifier, action-history, buffer and stack states
        # (5 * hidden) plus the ratio feature -> a new node representation.
        self.p_comp = torch.nn.Linear(self.hidden_dim * 5 + self.ratio_dim, node_dim)
        # Parser state (buffer, stack, action history + ratio) -> hidden.
        self.p_s2h = torch.nn.Linear(self.hidden_dim * 3 + self.ratio_dim, self.hidden_dim)
        # Hidden (+ ratio) -> one score per action id.
        self.p_act = torch.nn.Linear(self.hidden_dim + self.ratio_dim, self.num_actions)

        # Stack output (+ ratio) -> input vector of a newly created concept node.
        self.update_concept_node = torch.nn.Linear(self.hidden_dim + self.ratio_dim, node_dim)

        # Learned placeholders used when the buffer / action history / stack is empty,
        # and the input pushed for the root node at the bottom of every stack.
        self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(node_dim))
        self.pempty_action_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))
        self.pempty_stack_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))

        self._input_dropout = Dropout(input_dropout)

        self.buffer = StackRnn(input_size=node_dim,
                               hidden_size=self.hidden_dim,
                               num_layers=num_layers,
                               recurrent_dropout_probability=recurrent_dropout_probability,
                               layer_dropout_probability=layer_dropout_probability,
                               same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.stack = StackRnn(input_size=node_dim,
                              hidden_size=self.hidden_dim,
                              num_layers=num_layers,
                              recurrent_dropout_probability=recurrent_dropout_probability,
                              layer_dropout_probability=layer_dropout_probability,
                              same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.action_stack = StackRnn(input_size=self.action_dim,
                                     hidden_size=self.hidden_dim,
                                     num_layers=num_layers,
                                     recurrent_dropout_probability=recurrent_dropout_probability,
                                     layer_dropout_probability=layer_dropout_probability,
                                     same_dropout_mask_per_instance=same_dropout_mask_per_instance)
        initializer(self)

    def expand_arc_with_descendants(self, arc_indices, total_node_num, len_tokens):
        """Rename every non-token arc endpoint by the tokens it dominates.

        Parameters
        ----------
        arc_indices : sequence of ``(child, head, label)`` tuples
            Ids ``< len_tokens`` are tokens; larger ids are the root and concept nodes.
            Every id must lie in ``range(total_node_num)``, otherwise ``KeyError`` is raised.
        total_node_num : int
            Number of nodes (tokens + root + concept nodes).
        len_tokens : int
            Number of token nodes.

        Returns
        -------
        list of ``(child, head, label)`` tuples
            One tuple per input arc, in input order. A token endpoint keeps its int id.
            Any other endpoint becomes the ``'-'``-joined ascending list of the token ids
            below it (an empty string when it dominates no token).

        Nodes with a directed cycle on or below them are expanded only to their direct
        token children, because they never become ready in the bottom-up traversal.
        """
        nodes = range(total_node_num)
        # Record one entry per arc in both directions, so that parallel arcs between the
        # same pair of nodes are counted consistently below.
        children = {node: [] for node in nodes}
        parents = {node: [] for node in nodes}
        for arc in arc_indices:
            parents[arc[0]].append(arc[1])
            children[arc[1]].append(arc[0])

        # Bottom-up topological order (Kahn's algorithm over child -> head arcs). A node is
        # ready once all of its child arcs have been resolved. Counting arcs rather than
        # distinct children means a head with parallel arcs to one child is still released.
        unresolved = {node: len(children[node]) for node in nodes}
        ready = [node for node in nodes if unresolved[node] == 0]
        bottom_up_order = []
        while ready:
            node = ready.pop()
            bottom_up_order.append(node)
            for head in parents[node]:
                unresolved[head] -= 1
                if unresolved[head] == 0:
                    ready.append(head)

        # Children are visited before their heads, so each child's set is complete when it
        # is merged. Nodes never released keep only their direct children.
        descendants = {node: set(children[node]) for node in nodes}
        for node in bottom_up_order:
            for child in children[node]:
                descendants[node] |= descendants[child]

        def endpoint(node):
            # Tokens keep their id; other nodes are named by the tokens they dominate.
            if node < len_tokens:
                return node
            covered = sorted(d for d in descendants[node] if 0 <= d < len_tokens)
            return '-'.join(str(tok) for tok in covered)

        return [(endpoint(arc[0]), endpoint(arc[1]), arc[2]) for arc in arc_indices]

    def _greedy_decode(self,
                       batch_size: int,
                       sent_len: List[int],
                       embedded_text_input: torch.Tensor,
                       oracle_actions: Optional[List[List[int]]] = None,
                       ) -> Dict[str, Any]:
        """Run the transition system over a batch, greedily or following an oracle.

        Parameters
        ----------
        batch_size : int
            Number of sentences.
        sent_len : List[int]
            Token count of each sentence.
        embedded_text_input : torch.Tensor
            ``embedded_text_input[i][j]`` is the buffer input for token ``j`` of sentence
            ``i`` (presumably shaped ``(batch, max_len, node_dim)`` -- TODO confirm).
        oracle_actions : Optional[List[List[int]]]
            Gold action ids per sentence. When given, the gold action is executed at every
            step and its log-probability is scored, no edges are recorded, and the
            step-limit safeguard is disabled. The lists are read, not modified.

        Returns
        -------
        Dict[str, Any]
            ``'loss'``: negative mean log-probability of the executed actions at steps
            with more than one valid action (a NaN tensor when there were none).
            ``'losses'``: the per-sentence lists of those log-probabilities.
            ``'total_node_num'``: tokens + root + concept nodes per sentence.
            ``'edge_list'`` (only without an oracle): ``(modifier, head, label)`` tuples.
            ``'action_sequence'``: executed action names.
            ``'top_node'``: ``[root_id]`` per sentence, where the root id is the sentence length.
            ``'concept_node'``: ids of the root and of every generated concept node.
        """
        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        device = self.pempty_action_emb.device

        # Map every action family to the ids of all vocabulary entries whose name starts
        # with it (labelled variants included).
        action_vocab = self.vocab.get_token_to_index_vocabulary('actions')
        action_id = {
            family: [index for name, index in action_vocab.items() if name.startswith(family)]
            for family in ("SHIFT", "REDUCE", "NODE", "REMOTE-NODE", "LEFT-EDGE", "RIGHT-EDGE",
                           "LEFT-REMOTE", "RIGHT-REMOTE", "SWAP", "FINISH")
        }
        node_actions = set(action_id["NODE"] + action_id["REMOTE-NODE"])
        left_edge_actions = set(action_id["LEFT-EDGE"] + action_id["LEFT-REMOTE"])
        right_edge_actions = set(action_id["RIGHT-EDGE"] + action_id["RIGHT-REMOTE"])
        structural_actions = node_actions | left_edge_actions | right_edge_actions
        reduce_actions = set(action_id["REDUCE"])
        shift_actions = set(action_id["SHIFT"])
        swap_actions = set(action_id["SWAP"])
        finish_actions = set(action_id["FINISH"])

        # Only decisions with more than one valid option are scored (see the loss below).
        losses = [[] for _ in range(batch_size)]
        edge_list = [[] for _ in range(batch_size)]
        total_node_num = [0 for _ in range(batch_size)]
        action_list = [[] for _ in range(batch_size)]

        # Push tokens in reverse order (last token first), so token 0 is pushed last.
        for token_idx in range(max(sent_len, default=0)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    position = sent_len[sent_idx] - 1 - token_idx
                    self.buffer.push(sent_idx,
                                     input=embedded_text_input[sent_idx][position],
                                     extra={'token': position})

        # Every stack starts with the root node, whose id is the sentence length.
        for sent_idx in range(batch_size):
            self.stack.push(sent_idx,
                            input=self.proot_stack_emb,
                            extra={'token': sent_len[sent_idx]})
        ret_top_node = [[sent_len[sent_idx]] for sent_idx in range(batch_size)]
        # Ids of the root plus every concept node created so far, per sentence.
        concept_node = [[sent_len[sent_idx]] for sent_idx in range(batch_size)]

        oracle_position = [0] * batch_size
        finished = [False] * batch_size
        action_sequence_length = [0] * batch_size

        trans_not_fin = True
        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):
                n_tokens = sent_len[sent_idx]
                concepts = concept_node[sent_idx]
                # Safeguard for free decoding: stop working on a sentence once it has
                # produced too many concept nodes or actions (it is left unfinished).
                if oracle_actions is None and (len(concepts) > 50 * n_tokens
                                               or action_sequence_length[sent_idx] > 50 * n_tokens):
                    continue
                total_node_num[sent_idx] = n_tokens + len(concepts)
                if finished[sent_idx]:
                    continue
                trans_not_fin = True

                # Collect the actions permitted by the current buffer/stack sizes.
                # The order matters: it is the logit order and breaks argmax ties.
                buffer_len = self.buffer.get_len(sent_idx)
                stack_len = self.stack.get_len(sent_idx)
                valid_actions = []
                if buffer_len == 0:
                    valid_actions += action_id['FINISH']
                if buffer_len > 0:
                    valid_actions += action_id['SHIFT']
                if stack_len > 0:
                    valid_actions += action_id['REDUCE']
                    valid_actions += action_id['NODE']
                    valid_actions += action_id['REMOTE-NODE']
                if stack_len > 1:
                    valid_actions += action_id['SWAP']
                    valid_actions += action_id['LEFT-EDGE']
                    valid_actions += action_id['RIGHT-EDGE']
                    valid_actions += action_id['LEFT-REMOTE']
                    valid_actions += action_id['RIGHT-REMOTE']

                log_probs = None
                action = valid_actions[0]
                # Concept nodes created so far per token, fed to several layers.
                ratio_factor = torch.tensor([len(concepts) / (1.0 * n_tokens)], device=device)

                if len(valid_actions) > 1:
                    stack_emb = self.stack.get_output(sent_idx)
                    buffer_emb = (self.pempty_buffer_emb if buffer_len == 0
                                  else self.buffer.get_output(sent_idx))
                    action_emb = (self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0
                                  else self.action_stack.get_output(sent_idx))

                    parser_state = torch.cat([buffer_emb, stack_emb, action_emb, ratio_factor])
                    hidden = torch.cat([torch.tanh(self.p_s2h(parser_state)), ratio_factor])

                    # Score only the valid actions and normalise over them.
                    valid_index = torch.tensor(valid_actions, dtype=torch.long, device=hidden.device)
                    log_probs = torch.log_softmax(self.p_act(hidden)[valid_index], dim=0)
                    valid_action_tbl = {a: i for i, a in enumerate(valid_actions)}
                    action = valid_actions[torch.max(log_probs, 0)[1].item()]

                if oracle_actions is not None:
                    action = oracle_actions[sent_idx][oracle_position[sent_idx]]
                    oracle_position[sent_idx] += 1

                action_str = self.vocab.get_token_from_index(action, namespace='actions')
                self.action_stack.push(sent_idx,
                                       input=self.action_embedding(
                                           torch.tensor(action, device=embedded_text_input.device)),
                                       extra={'token': action_str})
                action_list[sent_idx].append(action_str)

                if log_probs is not None:
                    loss = log_probs[valid_action_tbl[action]]
                    if not torch.isnan(loss):
                        losses[sent_idx].append(loss)

                # NODE / REMOTE-NODE: create a concept node from the stack top and push it
                # onto the buffer.
                if action in node_actions:
                    concept_node_token = len(concepts) + n_tokens
                    concepts.append(concept_node_token)
                    stack_emb = torch.cat([self.stack.get_output(sent_idx), ratio_factor])
                    self.buffer.push(sent_idx,
                                     input=torch.tanh(self.update_concept_node(stack_emb)),
                                     extra={'token': concept_node_token})
                    total_node_num[sent_idx] = n_tokens + len(concepts)

                # Node/edge actions: record the edge and replace the head's representation
                # with a composition of head, modifier and the current parser state.
                if action in structural_actions:
                    if action in node_actions:
                        head = self.buffer.get_stack(sent_idx)[-1]
                        modifier = self.stack.get_stack(sent_idx)[-1]
                    elif action in left_edge_actions:
                        head = self.stack.get_stack(sent_idx)[-1]
                        modifier = self.stack.get_stack(sent_idx)[-2]
                    else:
                        head = self.stack.get_stack(sent_idx)[-2]
                        modifier = self.stack.get_stack(sent_idx)[-1]

                    head_rep, head_tok = head['stack_rnn_output'], head['token']
                    mod_rep, mod_tok = modifier['stack_rnn_output'], modifier['token']

                    if oracle_actions is None:
                        edge_list[sent_idx].append((mod_tok,
                                                    head_tok,
                                                    action_str.split(':', maxsplit=1)[1]))

                    # Lengths are re-read here: a NODE action has just changed the buffer.
                    action_emb = (self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0
                                  else self.action_stack.get_output(sent_idx))
                    stack_emb = (self.pempty_stack_emb if self.stack.get_len(sent_idx) == 0
                                 else self.stack.get_output(sent_idx))
                    buffer_emb = (self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0
                                  else self.buffer.get_output(sent_idx))

                    comp_rep = torch.cat([head_rep, mod_rep, action_emb, buffer_emb, stack_emb, ratio_factor])
                    comp_rep = torch.tanh(self.p_comp(comp_rep))

                    if action in node_actions:
                        # The new concept node (buffer top) is replaced by its composition.
                        self.buffer.pop(sent_idx)
                        self.buffer.push(sent_idx, input=comp_rep, extra={'token': head_tok})
                    elif action in left_edge_actions:
                        # The head sits on top of the stack; replace it in place.
                        self.stack.pop(sent_idx)
                        self.stack.push(sent_idx, input=comp_rep, extra={'token': head_tok})
                    else:
                        # RIGHT-*: the head is second from the top. Replace it and re-push
                        # the modifier with its original input so it stays on top.
                        stack_0_rep = self.stack.get_stack(sent_idx)[-1]['stack_rnn_input']
                        self.stack.pop(sent_idx)
                        self.stack.pop(sent_idx)
                        self.stack.push(sent_idx, input=comp_rep, extra={'token': head_tok})
                        self.stack.push(sent_idx, input=stack_0_rep, extra={'token': mod_tok})

                # Remaining state updates.
                if action in reduce_actions:
                    self.stack.pop(sent_idx)
                elif action in shift_actions:
                    buffer_top = self.buffer.pop(sent_idx)
                    self.stack.push(sent_idx,
                                    input=buffer_top['stack_rnn_input'],
                                    extra={'token': buffer_top['token']})
                elif action in swap_actions:
                    # pop_penult presumably removes the second-from-top stack entry;
                    # it is moved back onto the buffer.
                    stack_penult = self.stack.pop_penult(sent_idx)
                    self.buffer.push(sent_idx,
                                     input=stack_penult['stack_rnn_input'],
                                     extra={'token': stack_penult['token']})
                elif action in finish_actions:
                    finished[sent_idx] = True

                action_sequence_length[sent_idx] += 1

        # Categorical cross-entropy averaged over all scored decisions in the batch.
        non_empty_losses = [torch.sum(torch.stack(cur_loss)) for cur_loss in losses if cur_loss]
        if non_empty_losses:
            num_scored = sum(len(cur_loss) for cur_loss in losses)
            _loss = -torch.sum(torch.stack(non_empty_losses)) / num_scored
        else:
            _loss = torch.tensor([float('NaN')], device=device)

        ret = {
            'loss': _loss,
            'losses': losses,
            'total_node_num': total_node_num,
        }
        if oracle_actions is None:
            ret['edge_list'] = edge_list
        ret['action_sequence'] = action_list
        ret['top_node'] = ret_top_node
        ret['concept_node'] = concept_node
        return ret

    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                metadata: List[Dict[str, Any]],
                gold_actions: Dict[str, torch.LongTensor] = None,
                lemmas: Dict[str, torch.LongTensor] = None,
                pos_tags: torch.LongTensor = None,
                arc_tags: torch.LongTensor = None,
                deprels: torch.LongTensor = None,
                bios: torch.LongTensor = None,
                lexcat: torch.LongTensor = None,
                ss: torch.LongTensor = None,
                ss2: torch.LongTensor = None,
                ) -> Dict[str, torch.LongTensor]:
        """Embed a batch and decode it.

        ``gold_actions`` only acts as a flag: the gold action names are read from
        ``metadata[i]['gold_actions']``. ``lemmas`` and ``arc_tags`` are accepted but unused.

        In training mode the batch is decoded (following the gold actions when they are
        given), and ``{'loss': loss}`` is returned.

        Otherwise decoding is greedy under ``torch.no_grad()``. The result holds
        ``tokens``, ``loss``, ``edge_list``, ``meta_info``, ``top_node``, ``concept_node``
        and ``tokens_range``. If gold actions are given and an MCES metric is configured,
        the metric is updated with the predicted and gold (``metadata[i]['gold_mrps']``)
        MRP graphs.

        Raises
        ------
        ValueError
            If none of the supplied fields has a configured embedder.
        """
        batch_size = len(metadata)
        sent_len = [len(d['tokens']) for d in metadata]
        meta_info = [d['meta_info'] for d in metadata]

        oracle_actions = None
        if gold_actions is not None:
            oracle_actions = [[self.vocab.get_token_index(s, namespace='actions') for s in d['gold_actions']]
                              for d in metadata]

        embeds = [embedder(field) for field, embedder in ((tokens, self.text_field_embedder),
                                                          (pos_tags, self._pos_tag_embedding),
                                                          (deprels, self._deprel_embedding),
                                                          (bios, self._bios_embedding),
                                                          (lexcat, self._lexcat_embedding),
                                                          (ss, self._ss_embedding),
                                                          (ss2, self._ss2_embedding))
                  if field is not None and embedder is not None]
        if not embeds:
            raise ValueError("TransitionParser.forward: no input field has a configured embedder")
        embedded_text_input = self._input_dropout(torch.cat(embeds, -1))

        if self.training:
            ret_train = self._greedy_decode(batch_size=batch_size,
                                            sent_len=sent_len,
                                            embedded_text_input=embedded_text_input,
                                            oracle_actions=oracle_actions)
            return {'loss': ret_train['loss']}

        training_mode = self.training
        self.eval()
        try:
            with torch.no_grad():
                ret_eval = self._greedy_decode(batch_size=batch_size,
                                               sent_len=sent_len,
                                               embedded_text_input=embedded_text_input)
        finally:
            self.train(training_mode)

        output_dict = {
            'tokens': [d['tokens'] for d in metadata],
            'loss': ret_eval['loss'],
            'edge_list': ret_eval['edge_list'],
            'meta_info': meta_info,
            'top_node': ret_eval['top_node'],
            'concept_node': ret_eval['concept_node'],
            'tokens_range': [d['tokens_range'] for d in metadata]
        }

        # Evaluation against gold graphs; skipped when no metric is configured, since
        # mces_metric is optional.
        if gold_actions is not None and self._mces_metric is not None:
            gold_mrps = [x["gold_mrps"] for x in metadata]
            predicted_mrps = []

            for sent_idx in range(batch_size):
                # Degenerate outputs (more than 5 edges per token) are not converted.
                # NOTE(review): predicted_mrps can then be shorter than gold_mrps -- confirm
                # the metric matches graphs by id rather than by position.
                if len(output_dict['edge_list'][sent_idx]) <= 5 * len(output_dict['tokens'][sent_idx]):
                    predicted_mrps.append(ucca_trans_outputs_into_mrp({
                        'tokens': output_dict['tokens'][sent_idx],
                        'edge_list': output_dict['edge_list'][sent_idx],
                        'meta_info': output_dict['meta_info'][sent_idx],
                        'top_node': output_dict['top_node'][sent_idx],
                        'concept_node': output_dict['concept_node'][sent_idx],
                        'tokens_range': output_dict['tokens_range'][sent_idx],
                    }))

            self._mces_metric(predicted_mrps, gold_mrps)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the MCES metric values (only outside training and when a metric is set)."""
        all_metrics: Dict[str, float] = {}
        if self._mces_metric is not None and not self.training:
            all_metrics.update(self._mces_metric.get_metric(reset=reset))
        return all_metrics
Пример #2
0
class TransitionParser(Model):
    """Transition-based graph parser with auxiliary frame / POS / node-label taggers.

    The parser keeps four ``StackRnn`` structures (buffer, stack, deque and
    action history). At every step it scores the currently valid transitions
    from their summaries and takes either the argmax or the oracle action.
    Three ``SimpleTagger`` heads, built on the same text field embedder,
    predict frames, POS tags and node labels.
    """

    # Decoded graphs with more than this many edges per token are not
    # converted into MRP graphs for the MCES metric.
    _MAX_EDGES_PER_TOKEN = 5

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 word_dim: int,
                 hidden_dim: int,
                 action_dim: int,
                 num_layers: int,
                 mces_metric: Metric = None,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 lemma_text_field_embedder: TextFieldEmbedder = None,
                 pos_tag_embedding: Embedding = None,
                 action_embedding: Embedding = None,
                 frame_tagger_encoder: Seq2SeqEncoder = None,
                 pos_tagger_encoder: Seq2SeqEncoder = None,
                 node_label_tagger_encoder: Seq2SeqEncoder = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Build the parser state encoders, the action classifier and the taggers.

        ``word_dim`` must match the output size of ``text_field_embedder``
        (token embeddings are pushed directly onto ``word_dim``-sized
        StackRnns). When ``action_embedding`` is omitted a frozen, randomly
        initialised embedding of size ``action_dim`` is created.
        ``lemma_text_field_embedder`` is accepted for configuration
        compatibility but is not used by this model.
        """
        # Zero-argument super(): the module-level name ``TransitionParser``
        # can be rebound by a later class definition, in which case
        # ``super(TransitionParser, self)`` would refer to the wrong class.
        super().__init__(vocab, regularizer)

        self._unlabeled_correct = 0
        self._labeled_correct = 0
        self._total_edges_predicted = 0
        self._total_edges_actual = 0
        self._exact_unlabeled_correct = 0
        self._exact_labeled_correct = 0
        self._total_sentences = 0

        self.num_actions = vocab.get_vocab_size('actions')
        self.text_field_embedder = text_field_embedder
        self.pos_tag_embedding = pos_tag_embedding
        self._mces_metric = mces_metric

        self.action_embedding = action_embedding

        if action_embedding is None:
            self.action_embedding = Embedding(num_embeddings=self.num_actions,
                                              embedding_dim=action_dim,
                                              trainable=False)
        # syntactic composition (kept so existing checkpoints still load)
        self.p_comp = torch.nn.Linear(hidden_dim * 4, word_dim)
        # parser state (buffer, stack, action history, deque) to hidden
        self.p_s2h = torch.nn.Linear(hidden_dim * 4, hidden_dim)
        # hidden to action scores
        self.p_act = torch.nn.Linear(hidden_dim, self.num_actions)

        # Learned stand-ins used when a structure is empty, plus the root
        # element pushed onto every stack at the start of decoding.
        self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(hidden_dim))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(word_dim))
        self.pempty_action_emb = torch.nn.Parameter(torch.randn(hidden_dim))
        self.pempty_deque_emb = torch.nn.Parameter(torch.randn(hidden_dim))

        self._input_dropout = Dropout(input_dropout)

        self.frame_tagger_encoder = frame_tagger_encoder
        self.pos_tagger_encoder = pos_tagger_encoder
        self.node_label_tagger_encoder = node_label_tagger_encoder

        self.buffer = StackRnn(
            input_size=word_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.stack = StackRnn(
            input_size=word_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.deque = StackRnn(
            input_size=word_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.action_stack = StackRnn(
            input_size=action_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            recurrent_dropout_probability=recurrent_dropout_probability,
            layer_dropout_probability=layer_dropout_probability,
            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.frame_tagger = SimpleTagger(
            vocab=vocab,
            text_field_embedder=text_field_embedder,
            encoder=self.frame_tagger_encoder,
            label_namespace='frame')

        self.pos_tagger = SimpleTagger(vocab=vocab,
                                       text_field_embedder=text_field_embedder,
                                       encoder=self.pos_tagger_encoder,
                                       label_namespace='pos_tag')

        self.node_label_tagger = SimpleTagger(
            vocab=vocab,
            text_field_embedder=text_field_embedder,
            encoder=self.node_label_tagger_encoder,
            label_namespace='node_label')

        initializer(self)

    def _greedy_decode(
            self,
            batch_size: int,
            sent_len: List[int],
            embedded_text_input: torch.Tensor,
            oracle_actions: Optional[List[List[int]]] = None
    ) -> Dict[str, Any]:
        """Run the transition system over a batch, greedily or following an oracle.

        :param batch_size: number of sentences in the batch.
        :param sent_len: number of tokens in each sentence.
        :param embedded_text_input: token embeddings, indexed as
            ``embedded_text_input[sent_idx][token_idx]`` (presumably
            ``(batch, max_len, word_dim)`` -- TODO confirm).
        :param oracle_actions: optional per-sentence lists of action indices.
            When given, the oracle action is executed instead of the model's
            argmax. NOTE: these lists are consumed in place via ``pop(0)``.
        :returns: dict with ``'loss'`` (negative mean log-probability of the
            executed action over all modelled decisions; 0 if there were none),
            ``'losses'`` (per-sentence lists of those log-probabilities) and,
            only when ``oracle_actions`` is None, ``'edge_list'`` with
            per-sentence ``(modifier_token, head_token, label)`` triples.
        """
        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.deque.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        device = embedded_text_input.device

        # We keep track of all the losses accumulated during parsing. If a
        # decision is unambiguous because it is the only valid one given the
        # parser state, it is not modelled; only ambiguous decisions are.
        losses = [[] for _ in range(batch_size)]
        edge_list = [[] for _ in range(batch_size)]

        # Push the tokens onto the buffer in reverse order, so the last element
        # pushed is the sentence's first token. Token ids are 1-based; id 0 is
        # the root pushed onto the stack below.
        for token_idx in range(max(sent_len)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    position = sent_len[sent_idx] - 1 - token_idx
                    self.buffer.push(
                        sent_idx,
                        input=embedded_text_input[sent_idx][position],
                        extra={'token': position + 1})

        # init stack using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.stack.push(sent_idx,
                            input=self.proot_stack_emb,
                            extra={'token': 0})

        action_id = {
            action_: [
                self.vocab.get_token_index(a, namespace='actions') for a in
                self.vocab.get_token_to_index_vocabulary('actions').keys()
                if a.startswith(action_)
            ]
            for action_ in ["LR", "LP", "RS", "RP", "NS", "NR", "NP"]
        }
        # Reverse lookup: action index -> transition kind. The prefixes are
        # distinct two-letter codes, so each index maps to at most one kind.
        action_kind = {idx: kind for kind, ids in action_id.items() for idx in ids}

        # compute probability of each of the actions and choose an action
        # either from the oracle or if there is no oracle, based on the model
        trans_not_fin = True
        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):
                # A sentence is finished once its buffer is empty.
                if self.buffer.get_len(sent_idx) == 0:
                    continue
                trans_not_fin = True

                # Given the buffer and stack, collect the valid actions. The
                # buffer is non-empty here, so only the stack size matters.
                # The order of this list is significant: it fixes the logit
                # indexing and argmax tie-breaking.
                stack_has_word = self.stack.get_len(sent_idx) > 1
                valid_actions = []
                if stack_has_word:
                    valid_actions += action_id['LR']
                    valid_actions += action_id['LP']
                    valid_actions += action_id['RP']
                valid_actions += action_id['NS']
                valid_actions += action_id['RS']  # ROOT,NULL
                if stack_has_word:
                    valid_actions += action_id['NR']
                    valid_actions += action_id['NP']

                log_probs = None
                valid_action_tbl = None
                action = valid_actions[0]
                if len(valid_actions) > 1:
                    stack_emb = self.stack.get_output(sent_idx)
                    buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                        else self.buffer.get_output(sent_idx)

                    action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                        else self.action_stack.get_output(sent_idx)

                    deque_emb = self.pempty_deque_emb if self.deque.get_len(sent_idx) == 0 \
                        else self.deque.get_output(sent_idx)

                    p_t = torch.cat(
                        [buffer_emb, stack_emb, action_emb, deque_emb])
                    h = torch.tanh(self.p_s2h(p_t))
                    logits = self.p_act(h)[torch.tensor(valid_actions,
                                                        dtype=torch.long,
                                                        device=h.device)]
                    valid_action_tbl = {
                        a: i
                        for i, a in enumerate(valid_actions)
                    }
                    log_probs = torch.log_softmax(logits, dim=0)

                    action_idx = torch.max(log_probs, 0)[1].item()
                    action = valid_actions[action_idx]

                if oracle_actions is not None:
                    action = oracle_actions[sent_idx].pop(0)

                action_token = self.vocab.get_token_from_index(
                    action, namespace='actions')

                # push action into action_stack
                self.action_stack.push(
                    sent_idx,
                    input=self.action_embedding(
                        torch.tensor(action, device=device)),
                    extra={'token': action_token})

                if log_probs is not None:
                    # append the action-specific loss
                    losses[sent_idx].append(
                        log_probs[valid_action_tbl[action]])

                kind = action_kind.get(action)
                if kind in ('LR', 'LP', 'RS', 'RP'):
                    # figure out which is the head and which is the modifier
                    if kind in ('RS', 'RP'):
                        head = self.stack.get_stack(sent_idx)[-1]
                        modifier = self.buffer.get_stack(sent_idx)[-1]
                    else:
                        head = self.buffer.get_stack(sent_idx)[-1]
                        modifier = self.stack.get_stack(sent_idx)[-1]

                    if oracle_actions is None:
                        # Edge label is everything after the first ':'.
                        edge_list[sent_idx].append(
                            (modifier['token'], head['token'],
                             action_token.split(':', maxsplit=1)[1]))

                # Execute the action to update the parser state
                if kind in ('LR', 'NR'):
                    # reduce
                    self.stack.pop(sent_idx)
                elif kind in ('LP', 'NP', 'RP'):
                    # pass: move the stack top onto the deque
                    stack_top = self.stack.pop(sent_idx)
                    self.deque.push(sent_idx,
                                    input=stack_top['stack_rnn_input'],
                                    extra={'token': stack_top['token']})
                elif kind in ('RS', 'NS'):
                    # shift: restore the deque onto the stack, then move the
                    # buffer front onto the stack
                    while self.deque.get_len(sent_idx) > 0:
                        deque_top = self.deque.pop(sent_idx)
                        self.stack.push(
                            sent_idx,
                            input=deque_top['stack_rnn_input'],
                            extra={'token': deque_top['token']})

                    buffer_top = self.buffer.pop(sent_idx)
                    self.stack.push(sent_idx,
                                    input=buffer_top['stack_rnn_input'],
                                    extra={'token': buffer_top['token']})

        num_decisions = sum(len(cur_loss) for cur_loss in losses)
        if num_decisions == 0:
            # No ambiguous decision was modelled (e.g. only empty sentences):
            # torch.stack([]) would raise and the mean would divide by zero.
            _loss = torch.zeros((), device=device)
        else:
            _loss = -torch.sum(
                torch.stack([torch.sum(torch.stack(cur_loss))
                             for cur_loss in losses if len(cur_loss) > 0])) / num_decisions
        ret = {
            'loss': _loss,
            'losses': losses,
        }
        if oracle_actions is None:
            ret['edge_list'] = edge_list
        return ret

    # Returns an expression of the loss for the sequence of actions.
    # (that is, the oracle_actions if present or the predicted sequence otherwise)
    def forward(
        self,
        tokens: Dict[str, torch.LongTensor],
        metadata: List[Dict[str, Any]],
        gold_actions: Dict[str, torch.LongTensor] = None,
        lemmas: Dict[str, torch.LongTensor] = None,
        mrp_pos_tags: torch.LongTensor = None,
        frame: torch.LongTensor = None,
        pos_tag: torch.LongTensor = None,
        node_label: torch.LongTensor = None,
        arc_tags: torch.LongTensor = None,
    ) -> Dict[str, torch.LongTensor]:
        """Parse a batch and return either the training loss or decoded predictions.

        In training mode the parser follows the gold action strings from
        ``metadata`` (when ``gold_actions`` is given). The result contains only
        ``'loss'``: the parser loss plus the three tagger losses.

        Otherwise the batch is decoded greedily under ``torch.no_grad``. The
        result holds tokens, edges, tagger tags and a loss. The tagger losses
        are added only if all three taggers returned one (presumably when all
        three gold tag tensors were supplied). When ``gold_actions`` is given
        and an MCES metric is configured, the predictions are also scored.

        ``lemmas``, ``mrp_pos_tags`` and ``arc_tags`` are accepted but unused.
        """
        batch_size = len(metadata)
        sent_len = [len(d['tokens']) for d in metadata]
        meta_info = [d['meta_info'] for d in metadata]

        oracle_actions = None
        if gold_actions is not None:
            # Gold actions are read from metadata as strings and mapped to
            # vocabulary indices (fresh lists, since decoding consumes them).
            oracle_actions = [[
                self.vocab.get_token_index(s, namespace='actions')
                for s in d['gold_actions']
            ] for d in metadata]

        embedded_text_input = self.text_field_embedder(tokens)
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.training:
            ret_train = self._greedy_decode(
                batch_size=batch_size,
                sent_len=sent_len,
                embedded_text_input=embedded_text_input,
                oracle_actions=oracle_actions)

            # Only the tagger losses are needed here, so the (discarded)
            # decode step is skipped.
            frame_loss = self.frame_tagger(tokens=tokens, tags=frame)['loss']
            pos_loss = self.pos_tagger(tokens=tokens, tags=pos_tag)['loss']
            node_label_loss = self.node_label_tagger(
                tokens=tokens, tags=node_label)['loss']

            _loss = ret_train['loss'] + frame_loss + pos_loss + node_label_loss
            return {'loss': _loss}

        training_mode = self.training
        self.eval()
        with torch.no_grad():
            ret_eval = self._greedy_decode(
                batch_size=batch_size,
                sent_len=sent_len,
                embedded_text_input=embedded_text_input)
            # ``tags=None`` is SimpleTagger's default, so missing gold tags
            # simply produce predictions without a loss.
            frame_tagger_eval_outputs = self.frame_tagger.decode(
                self.frame_tagger(tokens=tokens, tags=frame))
            pos_tagger_eval_outputs = self.pos_tagger.decode(
                self.pos_tagger(tokens=tokens, tags=pos_tag))
            node_label_tagger_eval_outputs = self.node_label_tagger.decode(
                self.node_label_tagger(tokens=tokens, tags=node_label))

        self.train(training_mode)

        edge_list = ret_eval['edge_list']

        # Add the tagger losses only when every tagger produced one;
        # otherwise summing would raise KeyError.
        tagger_outputs = (frame_tagger_eval_outputs,
                          pos_tagger_eval_outputs,
                          node_label_tagger_eval_outputs)
        _loss = ret_eval['loss']
        if all('loss' in out for out in tagger_outputs):
            _loss = _loss + \
                frame_tagger_eval_outputs['loss'] + \
                pos_tagger_eval_outputs['loss'] + \
                node_label_tagger_eval_outputs['loss']

        # prediction-mode
        output_dict = {
            'tokens': [d['tokens'] for d in metadata],
            'edge_list': edge_list,
            'meta_info': meta_info,
            'tokens_range': [d['tokens_range'] for d in metadata],
            'frame': frame_tagger_eval_outputs["tags"],
            'pos_tag': pos_tagger_eval_outputs["tags"],
            'node_label': node_label_tagger_eval_outputs["tags"],
            'loss': _loss
        }

        # Compute the MRP score when gold actions exist and a metric is
        # configured (mces_metric defaults to None).
        if gold_actions is not None and self._mces_metric is not None:
            gold_mrps = [x["gold_mrps"] for x in metadata]
            predicted_mrps = []

            for sent_idx in range(batch_size):
                # NOTE(review): over-long predictions are skipped, so
                # predicted_mrps may be shorter than gold_mrps -- confirm the
                # metric does not rely on positional alignment.
                if len(edge_list[sent_idx]) <= \
                        self._MAX_EDGES_PER_TOKEN * sent_len[sent_idx]:
                    predicted_mrps.append(
                        sdp_trans_outputs_into_mrp({
                            'tokens': output_dict['tokens'][sent_idx],
                            'edge_list': output_dict['edge_list'][sent_idx],
                            'meta_info': output_dict['meta_info'][sent_idx],
                            'frame': output_dict['frame'][sent_idx],
                            'pos_tag': output_dict['pos_tag'][sent_idx],
                            "node_label": output_dict['node_label'][sent_idx],
                            'tokens_range': output_dict['tokens_range'][sent_idx],
                        }))

            self._mces_metric(predicted_mrps, gold_mrps)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return MCES metric values (only outside training, when a metric is set)."""
        all_metrics: Dict[str, float] = {}
        if self._mces_metric is not None and not self.training:
            all_metrics.update(self._mces_metric.get_metric(reset=reset))
        return all_metrics
class TransitionParser(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 word_dim: int,
                 hidden_dim: int,
                 action_dim: int,
                 concept_label_dim: int,
                 num_layers: int,
                 mces_metric: Metric = None,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 lemma_text_field_embedder: TextFieldEmbedder = None,
                 pos_tag_embedding: Embedding = None,
                 action_embedding: Embedding = None,
                 concept_label_embedding: Embedding = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Create the parser's embeddings, projection layers and StackRnn states.

        Action and concept-label embeddings fall back to frozen
        (``trainable=False``) embeddings, of sizes ``action_dim`` and
        ``concept_label_dim``, when none are supplied. All submodules and
        parameters are created in a fixed order before ``initializer`` is
        applied to the whole model.
        """
        super().__init__(vocab, regularizer)

        # Evaluation counters, all starting at zero.
        for counter_name in ('_unlabeled_correct',
                             '_labeled_correct',
                             '_total_edges_predicted',
                             '_total_edges_actual',
                             '_exact_unlabeled_correct',
                             '_exact_labeled_correct',
                             '_total_sentences'):
            setattr(self, counter_name, 0)

        self.num_actions = vocab.get_vocab_size('actions')
        self.num_concept_label = vocab.get_vocab_size('concept_label')
        self.text_field_embedder = text_field_embedder
        self.lemma_text_field_embedder = lemma_text_field_embedder
        self._pos_tag_embedding = pos_tag_embedding
        self._mces_metric = mces_metric

        self.word_dim, self.hidden_dim = word_dim, hidden_dim
        self.action_dim, self.concept_label_dim = action_dim, concept_label_dim
        self.action_embedding = action_embedding
        self.concept_label_embedding = concept_label_embedding

        # Frozen fallbacks; the concept-label one is built first.
        if concept_label_embedding is None:
            self.concept_label_embedding = Embedding(
                num_embeddings=self.num_concept_label,
                embedding_dim=concept_label_dim,
                trainable=False)
        if action_embedding is None:
            self.action_embedding = Embedding(num_embeddings=self.num_actions,
                                              embedding_dim=action_dim,
                                              trainable=False)

        word, hidden, label = word_dim, hidden_dim, concept_label_dim

        # composition of six hidden-sized vectors into a word-sized vector
        self.p_comp = torch.nn.Linear(hidden * 6, word)
        # parser state (four hidden-sized summaries) -> hidden
        self.p_s2h = torch.nn.Linear(hidden * 4, hidden)
        # hidden -> one score per action
        self.p_act = torch.nn.Linear(hidden, self.num_actions)

        # projections producing word-sized concept node representations
        self.start_concept_node = torch.nn.Linear(hidden + label, word)
        self.end_concept_node = torch.nn.Linear(hidden * 2 + label, word)

        # learned stand-ins for empty structures and for the stack root
        self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(hidden))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(word))
        self.pempty_action_emb = torch.nn.Parameter(torch.randn(hidden))
        self.pempty_stack_emb = torch.nn.Parameter(torch.randn(hidden))
        self.pempty_deque_emb = torch.nn.Parameter(torch.randn(hidden))

        self._input_dropout = Dropout(input_dropout)

        def build_stack_rnn(input_size):
            # All four parser structures share everything except input size.
            return StackRnn(
                input_size=input_size,
                hidden_size=hidden,
                num_layers=num_layers,
                recurrent_dropout_probability=recurrent_dropout_probability,
                layer_dropout_probability=layer_dropout_probability,
                same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.buffer = build_stack_rnn(word)
        self.stack = build_stack_rnn(word)
        self.deque = build_stack_rnn(word)
        self.action_stack = build_stack_rnn(action_dim)
        initializer(self)

    def _greedy_decode(
        self,
        batch_size: int,
        sent_len: List[int],
        embedded_text_input: torch.Tensor,
        oracle_actions: Optional[List[List[int]]] = None,
    ) -> Dict[str, Any]:

        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.deque.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        # We will keep track of all the losses we accumulate during parsing.
        # If some decision is unambiguous because it's the only thing valid given
        # the parser state, we will not model it. We only model what is ambiguous.
        losses = [[] for _ in range(batch_size)]
        ratio_factor_losses = [[] for _ in range(batch_size)]
        edge_list = [[] for _ in range(batch_size)]
        total_node_num = [0 for _ in range(batch_size)]
        action_list = [[] for _ in range(batch_size)]
        ret_top_node = [[] for _ in range(batch_size)]
        ret_concept_node = [[] for _ in range(batch_size)]

        broken_indices = set()

        # push the tokens onto the buffer (tokens is in reverse order)
        for token_idx in range(max(sent_len)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    if sent_len[
                            sent_idx] - 1 - token_idx < embedded_text_input.shape[
                                1]:
                        self.buffer.push(sent_idx,
                                         input=embedded_text_input[sent_idx]
                                         [sent_len[sent_idx] - 1 - token_idx],
                                         extra={
                                             'token':
                                             sent_len[sent_idx] - token_idx - 1
                                         })
                    else:
                        broken_indices.add(sent_idx)
        if len(broken_indices) > 0:
            print('Broken indices: ' + str(broken_indices))

        # init stack using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.stack.push(sent_idx,
                            input=self.proot_stack_emb,
                            extra={'token': 'protection_symbol'})

        action_id = {
            action_: [
                self.vocab.get_token_index(a, namespace='actions') for a in
                self.vocab.get_token_to_index_vocabulary('actions').keys()
                if a.startswith(action_)
            ]
            for action_ in [
                "SHIFT", "REDUCE", "LEFT-EDGE", "RIGHT-EDGE", "SELF-EDGE",
                "DROP", "TOP", "PASS", "START", "END", "FINISH"
            ]
        }

        # compute probability of each of the actions and choose an action
        # either from the oracle or if there is no oracle, based on the model
        trans_not_fin = True

        action_tag_for_terminate = [False] * batch_size
        action_sequence_length = [0] * batch_size

        concept_node = {}
        for sent_idx in range(batch_size):
            concept_node[sent_idx] = {}

        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):

                if (len(concept_node[sent_idx]) > 50 * sent_len[sent_idx]
                        or action_sequence_length[sent_idx] >
                        50 * sent_len[sent_idx]) and oracle_actions is None:
                    continue
                total_node_num[sent_idx] = sent_len[sent_idx] + len(
                    concept_node[sent_idx])
                # if self.buffer.get_len(sent_idx) != 0:

                if not (self.buffer.get_len(sent_idx) == 0
                        and self.stack.get_len(sent_idx) == 1):
                    trans_not_fin = True
                    valid_actions = []
                    # given the buffer and stack, conclude the valid action list
                    if self.buffer.get_len(sent_idx) == 0:
                        valid_actions += action_id['FINISH']

                    if self.buffer.get_len(sent_idx) > 0:
                        valid_actions += action_id['SHIFT']
                        valid_actions += action_id['DROP']
                        valid_actions += action_id['TOP']

                        buffer_token = self.buffer.get_stack(
                            sent_idx)[-1]["token"]
                        if buffer_token < sent_len[sent_idx]:
                            valid_actions += action_id['START']

                    if self.buffer.get_len(
                            sent_idx) > 0 and self.stack.get_len(sent_idx) > 1:
                        valid_actions += action_id['LEFT-EDGE']
                        valid_actions += action_id['RIGHT-EDGE']
                        valid_actions += action_id['SELF-EDGE']

                    if self.stack.get_len(sent_idx) > 1:
                        valid_actions += action_id['REDUCE']
                        valid_actions += action_id['PASS']

                    if self.buffer.get_len(
                            sent_idx) > 0 and self.stack.get_len(sent_idx) > 1:
                        concept_node_token = self.stack.get_stack(
                            sent_idx)[-1]['token']
                        concept_alignment_end_token = self.buffer.get_stack(
                            sent_idx)[-1]["token"]
                        if not (concept_node_token
                                not in concept_node[sent_idx]
                                or concept_alignment_end_token >=
                                sent_len[sent_idx]):
                            valid_actions += action_id['END']

                    log_probs = None
                    action = valid_actions[0]

                    if len(valid_actions) > 1:
                        stack_emb = self.stack.get_output(sent_idx)
                        buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                            else self.buffer.get_output(sent_idx)

                        action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                            else self.action_stack.get_output(sent_idx)

                        deque_emb = self.pempty_deque_emb if self.deque.get_len(sent_idx) == 0 \
                            else self.deque.get_output(sent_idx)

                        p_t = torch.cat(
                            (buffer_emb, stack_emb, action_emb, deque_emb))
                        h = torch.tanh(self.p_s2h(p_t))
                        logits = self.p_act(h)[torch.tensor(valid_actions,
                                                            dtype=torch.long,
                                                            device=h.device)]
                        valid_action_tbl = {
                            a: i
                            for i, a in enumerate(valid_actions)
                        }
                        log_probs = torch.log_softmax(logits, dim=0)

                        action_idx = torch.max(log_probs, 0)[1].item()
                        action = valid_actions[action_idx]

                    if oracle_actions is not None:
                        action = oracle_actions[sent_idx].pop(0)

                    if log_probs is not None:
                        # append the action-specific loss
                        losses[sent_idx].append(
                            log_probs[valid_action_tbl[action]])

                    # generate concept node, push it into buffer and align it with the second item in buffer
                    if action in action_id["START"]:

                        # get concept label and corresponding embedding
                        concept_node_token = len(
                            concept_node[sent_idx]) + sent_len[sent_idx]
                        stack_emb = self.stack.get_output(sent_idx)
                        concept_node_label_token = self.vocab.get_token_from_index(action, namespace='actions') \
                            .split('#SPLIT_TAG#', maxsplit=1)[1]
                        concept_node_label = self.vocab.get_token_index(
                            concept_node_label_token,
                            namespace='concept_label')
                        concept_node_label_emb = self.concept_label_embedding(
                            torch.tensor(concept_node_label,
                                         device=embedded_text_input.device))

                        # init
                        concept_alignment_begin = self.buffer.get_stack(
                            sent_idx)[-1]["token"]
                        concept_node[sent_idx][concept_node_token] = {"label": concept_node_label_token, \
                                                                      "start": concept_alignment_begin}

                        # insert comp_rep into buffer
                        comp_rep = torch.tanh(
                            self.start_concept_node(
                                torch.cat(
                                    (stack_emb, concept_node_label_emb))))
                        self.buffer.push(sent_idx,
                                         input=comp_rep,
                                         extra={'token': concept_node_token})

                        # update total_node_num for early-stopping
                        total_node_num[sent_idx] = sent_len[sent_idx] + len(
                            concept_node[sent_idx])

                    # predice the span end of the node in S0
                    elif action in action_id["END"]:

                        # get label embedding of concept node
                        concept_node_token = self.stack.get_stack(
                            sent_idx)[-1]['token']
                        concept_alignment_end_token = self.buffer.get_stack(
                            sent_idx)[-1]["token"]

                        # if concept_node_token not in concept_node[sent_idx] or concept_alignment_end_token>=sent_len[sent_idx]:
                        #     continue

                        concept_node_label_token = concept_node[sent_idx][
                            concept_node_token]["label"]
                        concept_node_label = self.vocab.get_token_index(
                            concept_node_label_token,
                            namespace='concept_label')
                        concept_node_label_emb = self.concept_label_embedding(
                            torch.tensor(concept_node_label,
                                         device=embedded_text_input.device))

                        # update concept info via inserting the span end of concept node
                        concept_node[sent_idx][concept_node_token][
                            "end"] = concept_alignment_end_token

                        # update node representation using a)begin compositioned embedding b)end embedding
                        stack_emb = self.stack.get_output(sent_idx)
                        buffer_emb = self.buffer.get_output(sent_idx)

                        comp_rep = torch.tanh(
                            self.end_concept_node(
                                torch.cat((stack_emb, buffer_emb,
                                           concept_node_label_emb))))
                        self.stack.pop(sent_idx)
                        self.stack.push(sent_idx,
                                        input=comp_rep,
                                        extra={'token': concept_node_token})

                    elif action in action_id["LEFT-EDGE"] + action_id[
                            "RIGHT-EDGE"] + action_id["SELF-EDGE"]:

                        if action in action_id["LEFT-EDGE"]:
                            head = self.buffer.get_stack(sent_idx)[-1]
                            modifier = self.stack.get_stack(sent_idx)[-1]
                        elif action in action_id["RIGHT-EDGE"]:
                            head = self.stack.get_stack(sent_idx)[-1]
                            modifier = self.buffer.get_stack(sent_idx)[-1]
                        else:
                            head = self.stack.get_stack(sent_idx)[-1]
                            modifier = self.stack.get_stack(sent_idx)[-1]

                        (head_rep, head_tok) = (head['stack_rnn_output'],
                                                head['token'])
                        (mod_rep, mod_tok) = (modifier['stack_rnn_output'],
                                              modifier['token'])

                        edge_list[sent_idx].append(
                            (mod_tok, head_tok,
                             self.vocab.get_token_from_index(
                                 action,
                                 namespace='actions').split('#SPLIT_TAG#',
                                                            maxsplit=1)[1]))

                        # compute composed representation
                        action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                            else self.action_stack.get_output(sent_idx)

                        stack_emb = self.pempty_stack_emb if self.stack.get_len(sent_idx) == 0 \
                            else self.stack.get_output(sent_idx)

                        buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                            else self.buffer.get_output(sent_idx)

                        deque_emb = self.pempty_deque_emb if self.deque.get_len(sent_idx) == 0 \
                            else self.deque.get_output(sent_idx)

                        comp_rep = torch.cat(
                            (head_rep, mod_rep, action_emb, buffer_emb,
                             stack_emb, deque_emb))
                        comp_rep = torch.tanh(self.p_comp(comp_rep))

                        if action in action_id["LEFT-EDGE"]:
                            self.buffer.pop(sent_idx)
                            self.buffer.push(sent_idx,
                                             input=comp_rep,
                                             extra={'token': head_tok})

                        elif action in action_id["RIGHT-EDGE"] + action_id[
                                "SELF-EDGE"]:
                            self.stack.pop(sent_idx)
                            self.stack.push(sent_idx,
                                            input=comp_rep,
                                            extra={'token': head_tok})

                    elif action in action_id["REDUCE"]:
                        self.stack.pop(sent_idx)

                    elif action in action_id["TOP"]:
                        ret_top_node[sent_idx] = self.buffer.get_stack(
                            sent_idx)[-1]["token"]

                    elif action in action_id["DROP"]:
                        self.buffer.pop(sent_idx)
                        while self.deque.get_len(sent_idx) > 0:
                            deque_top = self.deque.pop(sent_idx)
                            self.stack.push(
                                sent_idx,
                                input=deque_top['stack_rnn_input'],
                                extra={'token': deque_top['token']})

                    elif action in action_id["PASS"]:
                        stack_top = self.stack.pop(sent_idx)
                        self.deque.push(sent_idx,
                                        input=stack_top['stack_rnn_input'],
                                        extra={'token': stack_top['token']})

                    elif action in action_id["SHIFT"]:
                        while self.deque.get_len(sent_idx) > 0:
                            deque_top = self.deque.pop(sent_idx)
                            self.stack.push(
                                sent_idx,
                                input=deque_top['stack_rnn_input'],
                                extra={'token': deque_top['token']})

                        buffer_top = self.buffer.pop(sent_idx)
                        self.stack.push(sent_idx,
                                        input=buffer_top['stack_rnn_input'],
                                        extra={'token': buffer_top['token']})

                    # push action into action_stack
                    self.action_stack.push(
                        sent_idx,
                        input=self.action_embedding(
                            torch.tensor(action,
                                         device=embedded_text_input.device)),
                        extra={
                            'token':
                            self.vocab.get_token_from_index(
                                action, namespace='actions')
                        })

                    action_list[sent_idx].append(
                        self.vocab.get_token_from_index(action,
                                                        namespace='actions'))

                    action_sequence_length[sent_idx] += 1

        # categorical cross-entropy
        _loss_CCE = -torch.sum(
            torch.stack([torch.sum(torch.stack(cur_loss)) for cur_loss in losses if len(cur_loss) > 0])) / \
                    sum([len(cur_loss) for cur_loss in losses])

        _loss = _loss_CCE

        ret = {
            'loss': _loss,
            'losses': losses,
        }

        # extract concept node list in batchmode
        for sent_idx in range(batch_size):
            ret_concept_node[sent_idx] = concept_node[sent_idx]

        for idx in broken_indices:
            total_node_num[idx] = 0
            edge_list[idx] = []
            action_list[idx] = []
            ret_top_node[idx] = 0
            ret_concept_node[idx] = {}

        ret["total_node_num"] = total_node_num
        ret['edge_list'] = edge_list
        ret['action_sequence'] = action_list
        ret['top_node'] = ret_top_node
        ret["concept_node"] = ret_concept_node

        return ret

    # Returns an expression of the loss for the sequence of actions.
    # (that is, the oracle_actions if present or the predicted sequence otherwise)
    def forward(
        self,
        tokens: Dict[str, torch.LongTensor],
        metadata: List[Dict[str, Any]],
        gold_actions: Optional[Dict[str, torch.LongTensor]] = None,
        lemmas: Optional[Dict[str, torch.LongTensor]] = None,
        pos_tags: Optional[torch.LongTensor] = None,
        arc_tags: Optional[torch.LongTensor] = None,
        concept_label: Optional[torch.LongTensor] = None,
    ) -> Dict[str, Any]:
        """Embed a batch and run the transition parser over it.

        In training mode the gold action strings from ``metadata`` drive the
        decoder and only ``{'loss': ...}`` is returned. Otherwise the parser
        decodes greedily under ``torch.no_grad()`` and returns its predictions.
        When gold actions are supplied and an MCES metric is configured, the
        predictions are converted with ``eds_trans_outputs_into_mrp`` and scored
        against ``metadata[i]["gold_mrps"]``.

        Args:
            tokens: indexed token tensors passed to ``self.text_field_embedder``.
            metadata: one dict per sentence with ``tokens`` and ``meta_info``;
                ``tokens_range`` is also read outside training, and
                ``gold_actions`` / ``gold_mrps`` when ``gold_actions`` is given.
            gold_actions: only tested against ``None``; it signals that
                ``metadata`` carries gold action strings.
            lemmas, pos_tags, arc_tags, concept_label: accepted for interface
                compatibility; not used here.

        Returns:
            ``{'loss'}`` in training mode. Otherwise a dict with ``tokens``,
            ``loss`` (computed from the model's own greedy choices, since no
            oracle is passed), ``edge_list``, ``meta_info``, ``top_node``,
            ``concept_node`` and ``tokens_range``.
        """
        batch_size = len(metadata)
        meta_tokens = [d['tokens'] for d in metadata]
        sent_len = [len(sent_tokens) for sent_tokens in meta_tokens]
        meta_info = [d['meta_info'] for d in metadata]

        oracle_actions = None
        if gold_actions is not None:
            # Fresh lists on every call: _greedy_decode consumes them with pop(0).
            oracle_actions = [[self.vocab.get_token_index(s, namespace='actions')
                               for s in d['gold_actions']]
                              for d in metadata]

        embedded_text_input = self._input_dropout(self.text_field_embedder(tokens))

        if self.training:
            ret_train = self._greedy_decode(
                batch_size=batch_size,
                sent_len=sent_len,
                embedded_text_input=embedded_text_input,
                oracle_actions=oracle_actions)
            return {'loss': ret_train['loss']}

        # Greedy decoding without the oracle; restore the previous mode even if
        # decoding raises.
        training_mode = self.training
        self.eval()
        try:
            with torch.no_grad():
                ret_eval = self._greedy_decode(
                    batch_size=batch_size,
                    sent_len=sent_len,
                    embedded_text_input=embedded_text_input)
        finally:
            self.train(training_mode)

        output_dict = {
            'tokens': meta_tokens,
            'loss': ret_eval['loss'],
            'edge_list': ret_eval['edge_list'],
            'meta_info': meta_info,
            'top_node': ret_eval['top_node'],
            'concept_node': ret_eval['concept_node'],
            'tokens_range': [d['tokens_range'] for d in metadata],
        }

        # Score against gold graphs only when a metric is configured; get_metrics
        # applies the same None check, so a missing metric must not crash here.
        if gold_actions is not None and self._mces_metric is not None:
            gold_mrps = [d["gold_mrps"] for d in metadata]
            predicted_mrps = [
                eds_trans_outputs_into_mrp({
                    key: output_dict[key][sent_idx]
                    for key in ('tokens', 'edge_list', 'meta_info', 'top_node',
                                'concept_node', 'tokens_range')
                })
                for sent_idx in range(batch_size)
            ]
            self._mces_metric(predicted_mrps, gold_mrps)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Report the MCES metric values, or nothing when they are unavailable.

        Values are only reported outside training mode and only when a metric
        object has been configured.

        Args:
            reset: forwarded unchanged to the metric's ``get_metric``.
        """
        metric = self._mces_metric
        collected: Dict[str, float] = {}
        if metric is None or self.training:
            return collected
        collected.update(metric.get_metric(reset=reset))
        return collected
Пример #4
0
class TransitionParser(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 word_dim: int,
                 hidden_dim: int,
                 action_dim: int,
                 ratio_dim: int,
                 num_layers: int,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 output_null_nodes: bool = True,
                 max_heads: Optional[int] = None,
                 max_swaps_per_node: int = 3,
                 fix_unconnected_egraph: bool = True,
                 validate_every_n_instances: Optional[int] = None,
                 action_embedding: Optional[Embedding] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None
                 ) -> None:
        """Build a transition-based parser whose buffer, stack and action history
        are each encoded by a ``StackRnn``.

        Args:
            vocab: vocabulary; its ``actions`` namespace defines the action set.
            text_field_embedder: token embedder (stored as ``text_field_embedder``).
            word_dim: input size of the buffer/stack StackRnns and output size of
                ``p_comp`` and ``update_null_node``.
            hidden_dim: hidden size of all three StackRnns.
            action_dim: embedding size of actions; input size of ``action_stack``.
            ratio_dim: extra input width added to ``p_comp``, ``p_s2h``, ``p_act``
                and ``update_null_node``. It must match the width of the ratio
                feature concatenated at prediction time (a 1-element tensor in
                ``_greedy_decode``), i.e. presumably 1 — confirm against configs.
            num_layers, recurrent_dropout_probability, layer_dropout_probability,
            same_dropout_mask_per_instance: forwarded to every StackRnn.
            input_dropout: rate of the ``Dropout`` module stored as ``_input_dropout``.
            output_null_nodes: enables NODE actions; also passed as ``collapse``
                to ``XUDScore``.
            max_heads: cap on heads per node for edge actions; ``None``/0 = no cap.
            max_swaps_per_node: SWAP is allowed only while a node's swap count is
                below this value.
            fix_unconnected_egraph: whether predicted graphs are repaired after
                decoding without an oracle.
            validate_every_n_instances: stored as ``num_validation_instances``
                (its use is not visible here).
            action_embedding: optional action embedding; defaults to a
                non-trainable ``Embedding`` of size ``action_dim``.
            initializer: applied to the whole model as the last step.
                NOTE(review): the default instance is created once at definition
                time and shared between models — presumably harmless, confirm.
            regularizer: forwarded to ``Model``.
        """

        super(TransitionParser, self).__init__(vocab, regularizer)

        self._total_validation_instances = 0

        self.num_actions = vocab.get_vocab_size('actions')
        self.text_field_embedder = text_field_embedder
        self.output_null_nodes = output_null_nodes
        self.max_heads = max_heads
        self.max_swaps_per_node = max_swaps_per_node
        self._fix_unconnected_egraph = fix_unconnected_egraph
        self.num_validation_instances = validate_every_n_instances
        self._xud_score = XUDScore(collapse=self.output_null_nodes)


        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.ratio_dim = ratio_dim
        self.action_dim = action_dim

        self.action_embedding = action_embedding

        # Fallback action embedding is frozen (trainable=False).
        if action_embedding is None:
            self.action_embedding = Embedding(num_embeddings=self.num_actions,
                                              embedding_dim=action_dim,
                                              trainable=False)


        # NOTE: statement order below fixes module registration and random
        # parameter initialization order; keep it stable.
        # syntactic composition
        self.p_comp = torch.nn.Linear(self.hidden_dim * 5 + self.ratio_dim, self.word_dim)
        # parser state to hidden
        self.p_s2h = torch.nn.Linear(self.hidden_dim * 3 + self.ratio_dim, self.hidden_dim)
        # hidden to action
        self.p_act = torch.nn.Linear(self.hidden_dim + self.ratio_dim, self.num_actions)

        # representation of a newly generated null node
        self.update_null_node = torch.nn.Linear(self.hidden_dim + self.ratio_dim, self.word_dim)

        # learned placeholders for empty structures and the root token
        self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(self.word_dim))
        self.pempty_action_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))
        self.pempty_stack_emb = torch.nn.Parameter(torch.randn(self.hidden_dim))

        self._input_dropout = Dropout(input_dropout)

        self.buffer = StackRnn(input_size=self.word_dim,
                        hidden_size=self.hidden_dim,
                        num_layers=num_layers,
                        recurrent_dropout_probability=recurrent_dropout_probability,
                        layer_dropout_probability=layer_dropout_probability,
                        same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.stack = StackRnn(input_size=self.word_dim,
                        hidden_size=self.hidden_dim,
                        num_layers=num_layers,
                        recurrent_dropout_probability=recurrent_dropout_probability,
                        layer_dropout_probability=layer_dropout_probability,
                        same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.action_stack = StackRnn(input_size=self.action_dim,
                        hidden_size=self.hidden_dim,
                        num_layers=num_layers,
                        recurrent_dropout_probability=recurrent_dropout_probability,
                        layer_dropout_probability=layer_dropout_probability,
                        same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        # must run after every submodule/parameter above has been created
        initializer(self)

    def _greedy_decode(self,
                       batch_size: int,
                       sent_len: List[int],
                       embedded_text_input: torch.Tensor,
                       oracle_actions: Optional[List[List[str]]] = None,
                       ) -> Dict[str, Any]:
        """Run the transition system over a batch, one action per sentence per step.

        At each step the valid actions are computed; if there is more than one,
        the model scores them and greedily picks the best. When
        ``oracle_actions`` is given, the action actually executed is popped from
        the front of the sentence's oracle list (consumed in place), and the
        model's log-probability of it is collected for the loss.

        Args:
            batch_size: number of sentences in the batch.
            sent_len: number of tokens per sentence.
            embedded_text_input: token representations indexed as
                ``embedded_text_input[sent_idx][token_idx]`` — presumably
                (batch, seq_len, word_dim); TODO confirm.
            oracle_actions: optional gold action strings per sentence.

        Returns:
            Dict with ``loss`` (mean negative log-probability over scored steps,
            or a zero when no step was scored), ``losses`` (per-sentence lists of
            scored log-probabilities), ``total_node_num``, ``action_sequence``,
            ``null_node`` and, only without an oracle, ``edge_list``.

        Raises:
            IndexError: if a token position is missing from ``embedded_text_input``.
            RuntimeError: if a sentence exceeds 10000 actions per token (no oracle).
            KeyError: if the executed action was not among the scored actions.
        """
        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        # Per-sentence bookkeeping. Only ambiguous decisions (more than one valid
        # action) are scored, so a sentence's ``losses`` entry may stay empty.
        losses = [[] for _ in range(batch_size)]
        ratio_factor_losses = [[] for _ in range(batch_size)]
        edge_list = [[] for _ in range(batch_size)]
        total_node_num = [0 for _ in range(batch_size)]
        action_list = [[] for _ in range(batch_size)]
        root_id = [[] for _ in range(batch_size)]
        num_of_generated_node = [[] for _ in range(batch_size)]
        generated_order = [{} for _ in range(batch_size)]
        head_count = [{} for _ in range(batch_size)]  # number of heads per node
        swap_count = [{} for _ in range(batch_size)]  # number of swaps per node
        reachable = [{} for _ in range(batch_size)]  # node id -> ids reachable from it
        null_node = {sent_idx: [] for sent_idx in range(batch_size)}

        # Fill the buffers in reverse token order, so token 0 is pushed last.
        for token_idx in range(max(sent_len, default=0)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    position = sent_len[sent_idx] - 1 - token_idx
                    try:
                        self.buffer.push(sent_idx,
                                         input=embedded_text_input[sent_idx][position],
                                         extra={'token': position})
                    except IndexError as err:
                        raise IndexError(f"{sent_idx} {batch_size} {token_idx} {sent_len[sent_idx]} "
                                         f"{embedded_text_input[sent_idx].dim()}") from err

        # Each stack starts with a root entry whose id is the sentence length.
        for sent_idx in range(batch_size):
            root = sent_len[sent_idx]
            root_id[sent_idx] = root
            reachable[sent_idx][root] = {root}  # only the root is reachable from the root so far
            generated_order[sent_idx][root] = 0
            self.stack.push(sent_idx,
                            input=self.proot_stack_emb,
                            extra={'token': root})

        action_tag_for_terminate = [False] * batch_size
        action_sequence_length = [0] * batch_size

        # Loop-invariant, so looked up once.
        # NOTE(review): this uses the default ('tokens') namespace while action_idx
        # comes from 'actions'; kept unchanged — confirm it is the intended UNK id.
        unk_id = self.vocab.get_token_index('@@UNKNOWN@@')

        trans_not_fin = True
        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):
                if (action_sequence_length[sent_idx] > 10000 * sent_len[sent_idx]
                        and oracle_actions is None):
                    raise RuntimeError(f"Too many actions for a sentence {sent_idx}. actions: {action_list}")
                total_node_num[sent_idx] = sent_len[sent_idx] + len(null_node[sent_idx])
                if action_tag_for_terminate[sent_idx]:
                    continue
                trans_not_fin = True

                valid_actions = self.calc_valid_actions(edge_list, generated_order, head_count, swap_count,
                                                        null_node, root_id, sent_idx, sent_len)
                log_probs = None
                valid_action_tbl = None
                action = valid_actions[0]
                action_idx = self.vocab.get_token_index(action, namespace='actions')
                # Null nodes generated so far relative to sentence length; an extra model input.
                ratio_factor = torch.tensor([len(null_node[sent_idx]) / (1.0 * sent_len[sent_idx])],
                                            device=self.pempty_action_emb.device)

                if len(valid_actions) > 1:
                    action, action_idx, log_probs, valid_action_tbl = self.predict_action(
                        ratio_factor, sent_idx, valid_actions)

                if oracle_actions is not None:
                    action = oracle_actions[sent_idx].pop(0)
                    action_idx = self.vocab.get_token_index(action, namespace='actions')

                # push action into action_stack
                self.action_stack.push(sent_idx,
                                       input=self.action_embedding(
                                           torch.tensor(action_idx, device=embedded_text_input.device)),
                                       extra={'token': action})
                action_list[sent_idx].append(action)

                # Score the executed action unless the step was unambiguous or it is UNK.
                if log_probs is not None and not (unk_id and action_idx == unk_id):
                    try:
                        loss = log_probs[valid_action_tbl[action_idx]]
                    except KeyError as err:
                        raise KeyError(f'action: {action}, valid actions: {valid_actions}') from err
                    if not torch.isnan(loss):
                        losses[sent_idx].append(loss)

                self.exec_action(action, action_sequence_length, action_tag_for_terminate, edge_list,
                                 generated_order, head_count, swap_count, null_node, num_of_generated_node,
                                 oracle_actions, ratio_factor, ratio_factor_losses, sent_idx, sent_len,
                                 total_node_num, reachable)

        if oracle_actions is None and self._fix_unconnected_egraph:
            # Repair predicted graphs only; edges are not recorded under the oracle.
            for sent_idx in range(batch_size):
                self.fix_unconnected_egraph(edge_list[sent_idx], reachable[sent_idx], root_id[sent_idx],
                                            generated_order[sent_idx])

        num_scored = sum(len(cur_loss) for cur_loss in losses)
        if num_scored:
            # categorical cross-entropy: mean negative log-probability of scored steps
            _loss_CCE = -torch.sum(torch.stack([torch.sum(torch.stack(cur_loss))
                                                for cur_loss in losses if cur_loss])) / num_scored
        else:
            # Nothing was scored (e.g. every step had exactly one valid action), and
            # torch.stack([]) would raise. Use a zero tied to a parameter so that
            # backward() still works during training.
            _loss_CCE = self.pempty_action_emb.sum() * 0.0

        ret = {
            'loss': _loss_CCE,
            'losses': losses,
            "total_node_num": total_node_num,
        }
        if oracle_actions is None:
            ret['edge_list'] = edge_list
        ret['action_sequence'] = action_list
        ret["null_node"] = [null_node[sent_idx] for sent_idx in range(batch_size)]

        return ret

    def calc_valid_actions(self, edge_list, generated_order, head_count, swap_count, null_node, root_id, sent_idx, sent_len):
        """Return the actions allowed by the current state of sentence ``sent_idx``.

        Rules:
        * FINISH when the buffer is empty and exactly one item is on the stack.
        * SHIFT when the buffer is not empty.
        * REDUCE-0 / REDUCE-1 when s0 / s1 is not the root and already has a head.
        * NODE* actions while ``output_null_nodes`` is set and there are fewer
          null nodes than words.
        * SWAP when s1 is not the root, s0 was generated after s1, and s1 has
          been swapped fewer than ``max_swaps_per_node`` times.
        * LEFT-EDGE:* (s0 heads s1) and RIGHT-EDGE:* (s1 heads s0), excluding
          labels already drawn for that pair. LEFT-EDGE:root is never offered;
          when s1 is the root only RIGHT-EDGE:root is offered. A truthy
          ``max_heads`` caps the modifier's number of heads.

        Actions missing from the ``actions`` vocabulary are dropped. If nothing
        remains, ``["FINISH"]`` is returned.

        Args:
            edge_list: per-sentence lists of ``(modifier, head, label)`` edges.
            generated_order, head_count, swap_count: per-sentence dicts keyed by
                node id.
            null_node: per-sentence lists of generated null-node ids.
            root_id: per-sentence root node id.
            sent_idx: sentence being processed.
            sent_len: per-sentence number of words.
        """
        # Fetch the action vocabulary and structure sizes once per call.
        action_vocab = self.vocab.get_token_to_index_vocabulary('actions')
        buffer_len = self.buffer.get_len(sent_idx)
        stack_len = self.stack.get_len(sent_idx)
        root = root_id[sent_idx]

        valid_actions = []
        if buffer_len == 0 and stack_len == 1:
            valid_actions.append('FINISH')
        if buffer_len > 0:
            valid_actions.append('SHIFT')

        if stack_len > 0:
            stack_items = self.stack.get_stack(sent_idx)
            s0 = stack_items[-1]['token']
            if s0 != root and head_count[sent_idx][s0] > 0:
                valid_actions.append('REDUCE-0')
            # At most as many null nodes as words. Legacy "NODE:<label>" actions also
            # create an edge; "NODE:root" is only offered when s0 is the root.
            if self.output_null_nodes and len(null_node[sent_idx]) < sent_len[sent_idx]:
                valid_actions += [a for a in action_vocab
                                  if a.startswith('NODE')
                                  and (":" not in a or s0 == root or a.split('NODE:')[1] != "root")]

            if stack_len > 1:
                s1 = stack_items[-2]['token']
                if (s1 != root
                        and generated_order[sent_idx][s0] > generated_order[sent_idx][s1]
                        and swap_count[sent_idx][s1] < self.max_swaps_per_node):
                    valid_actions.append('SWAP')
                if s1 != root and head_count[sent_idx][s1] > 0:
                    valid_actions.append('REDUCE-1')

                # Labels already drawn between s0 and s1, so no labelled edge is added
                # twice. "root" is pre-excluded for LEFT-EDGE.
                labels_left_edge = {'root'}
                labels_right_edge = set()
                for mod_tok, head_tok, label in edge_list[sent_idx]:
                    if (mod_tok, head_tok) == (s1, s0):
                        labels_left_edge.add(label)
                    if (mod_tok, head_tok) == (s0, s1):
                        labels_right_edge.add(label)

                # LEFT-EDGE: s0 becomes a head of s1.
                if s1 != root and (not self.max_heads or head_count[sent_idx][s1] < self.max_heads):
                    valid_actions += [a for a in action_vocab
                                      if a.startswith('LEFT-EDGE')
                                      and a.split('LEFT-EDGE:')[1] not in labels_left_edge]

                # RIGHT-EDGE: s1 becomes a head of s0; the root only takes "root".
                if not self.max_heads or head_count[sent_idx][s0] < self.max_heads:
                    if s1 == root:
                        if 'root' not in labels_right_edge:
                            valid_actions.append('RIGHT-EDGE:root')
                    else:
                        forbidden = labels_right_edge | {'root'}
                        valid_actions += [a for a in action_vocab
                                          if a.startswith('RIGHT-EDGE')
                                          and a.split('RIGHT-EDGE:')[1] not in forbidden]

        # Remove actions unknown to the vocabulary (e.g. FINISH/SHIFT/SWAP if absent).
        valid_actions = [a for a in valid_actions if a in action_vocab]
        return valid_actions or ["FINISH"]

    def predict_action(self, ratio_factor, sent_idx, valid_actions):
        """Greedily choose the highest-scoring action among ``valid_actions``.

        The parser state is assembled by concatenating the following, in order:
        the buffer summary, the stack summary, the action-history summary and
        ``ratio_factor``. Learned placeholders stand in for an empty buffer or
        an empty action history.

        The state is passed through ``p_s2h`` and tanh, ``ratio_factor`` is
        appended again, and ``p_act`` scores every action. Only the entries of
        the valid actions are kept and log-softmax normalised.

        Returns:
            A 4-tuple:
            - the chosen action string;
            - its ``actions`` vocabulary index;
            - log-probabilities over ``valid_actions`` (in that order);
            - a dict mapping vocabulary index to position in those log-probabilities.
        """
        stack_summary = self.stack.get_output(sent_idx)
        if self.buffer.get_len(sent_idx) == 0:
            buffer_summary = self.pempty_buffer_emb
        else:
            buffer_summary = self.buffer.get_output(sent_idx)
        if self.action_stack.get_len(sent_idx) == 0:
            history_summary = self.pempty_action_emb
        else:
            history_summary = self.action_stack.get_output(sent_idx)

        state = torch.cat([buffer_summary, stack_summary, history_summary, ratio_factor])
        hidden = torch.cat([torch.tanh(self.p_s2h(state)), ratio_factor])

        candidate_ids = [self.vocab.get_token_index(name, namespace='actions') for name in valid_actions]
        selector = torch.tensor(candidate_ids, dtype=torch.long, device=hidden.device)
        candidate_logits = self.p_act(hidden)[selector]
        log_probs = torch.log_softmax(candidate_logits, dim=0)

        best_id = candidate_ids[torch.max(log_probs, 0)[1].item()]
        position_of = {vocab_id: pos for pos, vocab_id in enumerate(candidate_ids)}
        return (self.vocab.get_token_from_index(best_id, namespace='actions'),
                best_id, log_probs, position_of)

    def exec_action(self, action, action_sequence_length, action_tag_for_terminate, edge_list, generated_order,
                    head_count, swap_count, null_node, num_of_generated_node, oracle_actions, ratio_factor,
                    ratio_factor_losses, sent_idx, sent_len, total_node_num, reachable: List[Dict[int, Set[int]]]):
        """Apply one transition ``action`` to the parser state of sentence ``sent_idx``.

        Updates ``self.stack`` / ``self.buffer`` and the per-sentence bookkeeping arguments in
        place; returns nothing. Every list/dict argument except ``oracle_actions`` is indexed by
        ``sent_idx``.

        Supported actions:
            ``NODE`` (any action starting with it): create a null node and push it onto the buffer.
            ``NODE:<label>`` (legacy), ``LEFT-EDGE:<label>``, ``RIGHT-EDGE:<label>``: add an edge
                between two state elements and replace the head's representation with a
                composed one.
            ``REDUCE-0``: pop the stack top.
            ``REDUCE-1``: remove the element directly below the stack top.
            ``SHIFT``: move the buffer top onto the stack, registering the token the first time.
            ``SWAP``: move the element below the stack top back onto the buffer.
            ``FINISH``: mark the sentence as terminated.

        Args:
            action: action string from the ``'actions'`` vocabulary namespace.
            action_sequence_length: per-sentence action counter, incremented on every call.
            action_tag_for_terminate: per-sentence "finished" flags, set by ``FINISH``.
            edge_list: per-sentence ``(modifier, head, label)`` edges; only extended when
                ``oracle_actions`` is None.
            generated_order: per-sentence dict token -> order in which it first reached the stack.
            head_count: per-sentence dict token -> number of heads attached to it.
            swap_count: per-sentence dict token -> number of times ``SWAP`` moved it back.
            null_node: per-sentence list of created null-node token ids.
            num_of_generated_node: per-sentence slot set to the order index of a newly shifted token.
            oracle_actions: gold action sequences, or None when decoding without an oracle.
            ratio_factor: tensor concatenated to the inputs of the null-node and composition
                layers; stored into ``ratio_factor_losses`` on ``FINISH``.
            ratio_factor_losses: per-sentence slot receiving ``ratio_factor`` on ``FINISH``.
            sent_idx: index of the sentence in the batch.
            sent_len: per-sentence word counts; null-node ids are allocated after the word ids.
            total_node_num: per-sentence words + null nodes, refreshed when a null node is created.
            reachable: per-sentence dict token -> set of tokens reachable from it (itself included).
        """
        # generate null node, recursive way.
        # Matches plain "NODE" as well as legacy "NODE:*" actions; the latter also create an edge below.
        if action.startswith("NODE"):  # Support legacy models with "NODE:*" actions that also create an edge
            # Null-node ids continue after the sentence's word ids.
            null_node_token = len(null_node[sent_idx]) + sent_len[sent_idx] + 1
            null_node[sent_idx].append(null_node_token)
            head_count[sent_idx][null_node_token] = swap_count[sent_idx][null_node_token] = 0
            reachable[sent_idx][null_node_token] = {null_node_token}

            # The null node's representation is derived from the current stack summary.
            stack_emb = self.stack.get_output(sent_idx)

            stack_emb = torch.cat([stack_emb, ratio_factor])
            comp_rep = torch.tanh(self.update_null_node(stack_emb))

            node_input = comp_rep

            self.buffer.push(sent_idx,
                             input=node_input,
                             extra={'token': null_node_token})

            total_node_num[sent_idx] = sent_len[sent_idx] + len(null_node[sent_idx])
        # Edge-creating actions. Legacy "NODE:*" actions also create an edge.
        if action.startswith("NODE:") or action.startswith("LEFT-EDGE") or action.startswith("RIGHT-EDGE"):
            if action.startswith("NODE:"):
                # Legacy: the null node just pushed (buffer top) becomes a dependent of the stack top.
                logger.warning(f"Took action {action}")
                modifier = self.buffer.get_stack(sent_idx)[-1]
                head = self.stack.get_stack(sent_idx)[-1]
            elif action.startswith("LEFT-EDGE"):
                head = self.stack.get_stack(sent_idx)[-1]
                modifier = self.stack.get_stack(sent_idx)[-2]
            else:
                head = self.stack.get_stack(sent_idx)[-2]
                modifier = self.stack.get_stack(sent_idx)[-1]

            (head_rep, head_tok) = (head['stack_rnn_output'], head['token'])
            (mod_rep, mod_tok) = (modifier['stack_rnn_output'], modifier['token'])

            # Predicted edges are only recorded when decoding without an oracle; the label is the
            # part of the action string after the first ':'.
            if oracle_actions is None:
                edge_list[sent_idx].append((mod_tok,
                                            head_tok,
                                            action.split(':', maxsplit=1)[1]))
            # propagate reachability: every token that reaches the head now also reaches
            # everything reachable from the modifier (tracked in oracle mode too).
            reachable_from_mod = reachable[sent_idx][mod_tok]
            for tok, reachable_for_tok_set in reachable[sent_idx].items():
                if head_tok in reachable_for_tok_set:
                    reachable_for_tok_set |= reachable_from_mod

            action_emb = self.pempty_action_emb if self.action_stack.get_len(sent_idx) == 0 \
                else self.action_stack.get_output(sent_idx)

            stack_emb = self.pempty_stack_emb if self.stack.get_len(sent_idx) == 0 \
                else self.stack.get_output(sent_idx)

            buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                else self.buffer.get_output(sent_idx)

            # compute the composed representation of (head, modifier) given the current state summaries
            comp_rep = torch.cat([head_rep, mod_rep, action_emb, buffer_emb, stack_emb, ratio_factor])
            comp_rep = torch.tanh(self.p_comp(comp_rep))

            if action.startswith("NODE:"):
                # Legacy: replace the buffer top with the composed representation, keeping its token id.
                self.buffer.pop(sent_idx)
                self.buffer.push(sent_idx,
                                 input=comp_rep,
                                 extra={'token': mod_tok})
            elif action.startswith("LEFT-EDGE"):
                # Head is the stack top: replace it by the composed representation.
                self.stack.pop(sent_idx)
                self.stack.push(sent_idx,
                                input=comp_rep,
                                extra={'token': head_tok})
            # RIGHT-EDGE
            else:
                # Head sits below the modifier: pop both, push the composed head, then re-push
                # the modifier using its original RNN input.
                stack_0_rep = self.stack.get_stack(sent_idx)[-1]['stack_rnn_input']
                self.stack.pop(sent_idx)
                self.stack.pop(sent_idx)

                self.stack.push(sent_idx,
                                input=comp_rep,
                                extra={'token': head_tok})

                self.stack.push(sent_idx,
                                input=stack_0_rep,
                                extra={'token': mod_tok})
            head_count[sent_idx][mod_tok] += 1

        # Execute the action to update the parser state
        elif action == "REDUCE-0":
            self.stack.pop(sent_idx)

        elif action == "REDUCE-1":
            # Drop the element below the top while keeping the top in place.
            stack_0 = self.stack.pop(sent_idx)
            self.stack.pop(sent_idx)
            self.stack.push(sent_idx,
                            input=stack_0['stack_rnn_input'],
                            extra={'token': stack_0['token']})

        elif action == "SHIFT":
            buffer = self.buffer.pop(sent_idx)
            self.stack.push(sent_idx,
                            input=buffer['stack_rnn_input'],
                            extra={'token': buffer['token']})
            s0 = self.stack.get_stack(sent_idx)[-1]['token']
            # First time this token reaches the stack: record its order and initialise its counters.
            if s0 not in generated_order[sent_idx]:
                num_of_generated_node[sent_idx] = len(generated_order[sent_idx])
                generated_order[sent_idx][s0] = num_of_generated_node[sent_idx]
                head_count[sent_idx][s0] = swap_count[sent_idx][s0] = 0
                reachable[sent_idx][s0] = {s0}

        elif action == "SWAP":
            # Move the element below the stack top back onto the buffer.
            stack_penult = self.stack.pop_penult(sent_idx)
            self.buffer.push(sent_idx,
                             input=stack_penult['stack_rnn_input'],
                             extra={'token': stack_penult['token']})
            swap_count[sent_idx][stack_penult['token']] += 1

        elif action == "FINISH":
            action_tag_for_terminate[sent_idx] = True
            ratio_factor_losses[sent_idx] = ratio_factor
        action_sequence_length[sent_idx] += 1

    def fix_unconnected_egraph(self, edge_list, reachable, root_id, generated_order):
        """Attach every node unreachable from ``root_id`` so the enhanced graph becomes connected.

        Postprocessing that avoids the ``validate.py --level 2`` error "L2 Enhanced
        unconnected-egraph". REDUCE-{0,1} require a head, so a node can only end up unreachable
        from the root if it sits in a cycle.

        Orphans are processed greedily, the one with the most descendants first. Each orphan is
        attached to the root's first dependent with an ``"orphan"`` edge. If the root has no
        dependent yet, the first orphan becomes that dependent through a ``"root"`` edge.

        Args:
            edge_list: list of ``(modifier, head, label)`` tuples; extended in place.
            reachable: maps each token to the set of tokens reachable from it (itself included).
                Only read; it is not updated for the edges added here.
            root_id: token id of the root node.
            generated_order: iterated for the token ids to check (a dict keyed by token).
        """
        # The root's first dependent becomes the head of every orphan.
        anchor = next((mod for mod, head, _ in edge_list if head == root_id), None)

        remaining = set(generated_order).difference(reachable[root_id])
        while remaining:
            # The orphan with the most descendants is the top of its disconnected component.
            orphan = max(remaining, key=lambda t: len(reachable[t]))
            if anchor is not None:
                edge_list.append((orphan, anchor, "orphan"))
            else:
                edge_list.append((orphan, root_id, "root"))
                anchor = orphan
            # Everything below this orphan (the rest of its cycle included) is connected now.
            remaining.difference_update(reachable[orphan])

    # Returns an expression of the loss for the sequence of actions.
    # (that is, the oracle_actions if present or the predicted sequence otherwise)
    def forward(self,
                words: Dict[str, torch.LongTensor],
                metadata: List[Dict[str, Any]],
                gold_actions: Dict[str, torch.LongTensor] = None,
                ) -> Dict[str, torch.LongTensor]:
        """Run the parser on one batch.

        In training mode, decodes with the oracle sequences taken from ``metadata`` (when
        ``gold_actions`` is given) and returns ``{'loss': ...}``.

        Otherwise, decodes greedily under ``torch.no_grad()``. It returns the predicted
        ``edge_list`` / ``null_node`` and ``loss`` together with the per-token annotation,
        ``multiwords``, ``sent_id`` and ``text`` copied from ``metadata``. When ``gold_actions``
        is given, predicted and gold CoNLL-U lines are also fed to ``self._xud_score``.

        Args:
            words: indexed token tensors, passed to ``self.text_field_embedder``.
            metadata: one dict per sentence. Keys read: ``'words'``; ``'gold_actions'`` when gold
                actions are given; ``'annotation'``, ``'multiwords'``, ``'sent_id'`` and
                ``'text'`` in evaluation mode.
            gold_actions: only checked for presence; the oracle sequences themselves are
                deep-copied from ``metadata``.

        Raises:
            IndexError: during training, re-raised with the batch's words as the message
                (the original error is kept as ``__cause__``).
        """
        batch_size = len(metadata)
        sent_len = [len(d['words']) for d in metadata]

        # Default to "no oracle": previously this name was unbound when gold_actions was None,
        # so training without gold actions crashed with UnboundLocalError.
        oracle_actions = None
        if gold_actions is not None:
            # Copy so decoding can presumably consume the sequences without touching metadata.
            oracle_actions = deepcopy([d['gold_actions'] for d in metadata])

        embedded_text_input = self.text_field_embedder(words)
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.training:
            try:
                ret_train = self._greedy_decode(batch_size=batch_size,
                                                sent_len=sent_len,
                                                embedded_text_input=embedded_text_input,
                                                oracle_actions=oracle_actions)
            except IndexError as err:
                # Surface the offending sentences while keeping the original traceback.
                raise IndexError(f"{[d['words'] for d in metadata]}") from err

            return {'loss': ret_train['loss']}

        # Count validation instances so get_metrics() can tell when a full pass is complete;
        # restart the count once the expected total has been reached.
        if self.num_validation_instances and (not self._total_validation_instances
                                              or self._total_validation_instances == self.num_validation_instances):
            self._total_validation_instances = batch_size
        else:
            self._total_validation_instances += batch_size

        training_mode = self.training
        self.eval()
        try:
            with torch.no_grad():
                ret_eval = self._greedy_decode(batch_size=batch_size,
                                               sent_len=sent_len,
                                               embedded_text_input=embedded_text_input)
        finally:
            # Restore the caller's mode even if decoding fails.
            self.train(training_mode)

        # prediction-mode
        output_dict = {
            'edge_list': ret_eval['edge_list'],
            'null_node': ret_eval['null_node'],
            'loss': ret_eval['loss'],
        }

        annotation_keys = ["id", "form", "lemma", "upostag", "xpostag", "feats", "head",
                           "deprel", "misc"]
        for k in annotation_keys:
            output_dict[k] = [[token_metadata[k] for token_metadata in sentence_metadata['annotation']]
                              for sentence_metadata in metadata]
        for k in ("multiwords", "sent_id", "text"):
            output_dict[k] = [sentence_metadata[k] for sentence_metadata in metadata]

        # validation mode: score against gold when gold actions exist
        if gold_actions is not None:
            conllu_keys = annotation_keys + ["edge_list", "null_node", "multiwords", "text", "sent_id"]
            predicted_graphs = []
            gold_graphs_conllu = []

            for sent_idx in range(batch_size):
                predicted_graphs.append(eud_trans_outputs_into_conllu(
                    {k: output_dict[k][sent_idx] for k in conllu_keys},
                    self.output_null_nodes))
                # Sentence-level header (only non-empty sent_id/text), followed by the gold tokens.
                gold_annotation = [{key: output_dict[key][sent_idx] for key in ("sent_id", "text")
                                    if output_dict[key][sent_idx]}]
                gold_annotation.extend(metadata[sent_idx]['annotation'])
                gold_graphs_conllu.append(annotation_to_conllu(gold_annotation, self.output_null_nodes))

            predicted_graphs_conllu = [line for lines in predicted_graphs for line in lines]
            gold_graphs_conllu = [line for lines in gold_graphs_conllu for line in lines]

            self._xud_score(predicted_graphs_conllu,
                            gold_graphs_conllu)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the XUD scores, or an empty dict when they should not be reported.

        Nothing is reported while training or when no ``_xud_score`` is configured.

        If ``num_validation_instances`` is set, scores are only reported once exactly that many
        validation instances have been seen, and are reset according to ``reset``. If it is
        unset, scores are reported on every call and always reset.
        """
        metrics: Dict[str, float] = {}
        if self._xud_score is None or self.training:
            return metrics

        expected_total = self.num_validation_instances
        if not expected_total:
            metrics.update(self._xud_score.get_metric(reset=True))
        elif self._total_validation_instances == expected_total:
            metrics.update(self._xud_score.get_metric(reset=reset))
        return metrics
Пример #5
0
class TransitionParserAmr(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 word_dim: int,
                 hidden_dim: int,
                 action_dim: int,
                 entity_dim: int,
                 rel_dim: int,
                 num_layers: int,
                 recurrent_dropout_probability: float = 0.0,
                 layer_dropout_probability: float = 0.0,
                 same_dropout_mask_per_instance: bool = True,
                 input_dropout: float = 0.0,
                 dropout: float = 0.0,
                 pos_tag_embedding: Embedding = None,
                 action_text_field_embedder: TextFieldEmbedder = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 eval_on_training: bool = True,
                 sep_act_type_para: bool = False) -> None:
        """Build the AMR transition parser: embeddings, composition layers, scorers and stack RNNs.

        Modules are created in a fixed order, so random initialisation and parameter
        registration stay deterministic. ``initializer`` is applied to the finished model.

        Args:
            vocab: vocabulary with ``actions``, ``newnodes``, ``relations`` and ``entities``
                namespaces.
            word_dim: input size of the buffer/stack/deque RNNs and of their sentinel embeddings.
            hidden_dim: hidden size of every stack RNN and input unit of the scorers.
            action_dim: size of action embeddings (input size of the action-history RNN).
            entity_dim: size of entity-label embeddings.
            rel_dim: size of relation embeddings.
            num_layers: number of layers of every stack RNN.
            dropout: dropout used inside the composition feed-forward layers.
            pos_tag_embedding: optional; its output dim is added to ``word_dim`` to get the
                size of composed node representations.
            sep_act_type_para: if True, score the action type first (``action_type_scorer``),
                then the concrete action with one linear scorer per type; otherwise use a single
                scorer over all actions.
            eval_on_training: stored as-is on the instance.
        """
        super(TransitionParserAmr, self).__init__(vocab, regularizer)

        self._smatch_scorer = SmatchScorer()
        self._mces_scorer = MCESScore(output_type='gscprf', cores=0, trace=0)
        self.eval_on_training = eval_on_training
        self.sep_act_type_para = sep_act_type_para
        # Caps used when deciding whether NEWNODE is allowed.
        self._NEWNODE_TYPE_MAX = 20
        self._NODE_LEN_RATIO_MAX = 20.0
        self._ACTION_TYPE = ['SHIFT', 'CONFIRM', 'REDUCE', 'MERGE', 'ENTITY',
                             'NEWNODE', 'DROP', 'CACHE', 'LEFT', 'RIGHT']
        self._ACTION_TYPE_IDX = {name: position for position, name in enumerate(self._ACTION_TYPE)}

        self.text_field_embedder = text_field_embedder
        self.pos_tag_embedding = pos_tag_embedding
        self.action_text_field_embedder = action_text_field_embedder

        # NOTE(review): composed nodes are node_dim-sized while the buffer/stack/deque RNNs take
        # word_dim inputs; these only agree when pos_tag_embedding is None (or word_dim already
        # includes it) — confirm against how composed nodes are pushed.
        node_dim = word_dim + (pos_tag_embedding.output_dim if pos_tag_embedding else 0)

        self.num_actions = vocab.get_vocab_size('actions')
        self.num_newnodes = vocab.get_vocab_size('newnodes')
        self.num_relations = vocab.get_vocab_size('relations')
        self.num_entities = vocab.get_vocab_size('entities')

        self.action_embedding = Embedding(num_embeddings=self.num_actions, embedding_dim=action_dim)
        self.newnode_embedding = Embedding(num_embeddings=self.num_newnodes, embedding_dim=node_dim)
        self.rel_embedding = Embedding(num_embeddings=self.num_relations, embedding_dim=rel_dim)
        self.entity_embedding = Embedding(num_embeddings=self.num_entities, embedding_dim=entity_dim)

        def relu_layer(in_dim, out_dim):
            # One-layer ReLU projection shared by every composition step.
            return FeedForward(input_dim=in_dim,
                               num_layers=1,
                               hidden_dims=out_dim,
                               activations=Activation.by_name('relu')(),
                               dropout=dropout)

        # (stack, buffer, action_stack, deque) -> parser state
        self.merge = relu_layer(hidden_dim * 4, hidden_dim)
        # (parent, rel, child) -> new parent / new child
        self.merge_parent = relu_layer(hidden_dim * 2 + rel_dim, node_dim)
        self.merge_child = relu_layer(hidden_dim * 2 + rel_dim, node_dim)
        # (A, B) -> AB
        self.merge_token = relu_layer(hidden_dim * 2, node_dim)
        # (AB, entity_label) -> X
        self.merge_entity = relu_layer(hidden_dim + entity_dim, node_dim)

        if sep_act_type_para:
            self.action_type_scorer = torch.nn.Linear(hidden_dim, len(self._ACTION_TYPE))
            action_names = list(self.vocab.get_token_to_index_vocabulary('actions').keys())
            # Plain list for indexed access; each scorer is also registered as `scorer_<TYPE>`.
            self.scorers = []
            for action_type in self._ACTION_TYPE:
                type_size = sum(1 for name in action_names if name.startswith(action_type))
                type_scorer = torch.nn.Linear(hidden_dim, type_size)
                setattr(self, f'scorer_{action_type}', type_scorer)
                self.scorers.append(type_scorer)
        else:
            self.scorer = torch.nn.Linear(hidden_dim, self.num_actions)

        # X -> confirm(X)
        self.confirm_layer = relu_layer(hidden_dim, node_dim)

        # Learned sentinel inputs for the bottoms of the buffer, stack, action stack and deque.
        self.pempty_buffer_emb = torch.nn.Parameter(torch.randn(word_dim))
        self.proot_stack_emb = torch.nn.Parameter(torch.randn(word_dim))
        self.proot_action_emb = torch.nn.Parameter(torch.randn(action_dim))
        self.proot_deque_emb = torch.nn.Parameter(torch.randn(word_dim))

        self._input_dropout = Dropout(input_dropout)

        def stack_rnn(in_size):
            # All parser-state RNNs share hidden size, depth and dropout settings.
            return StackRnn(input_size=in_size,
                            hidden_size=hidden_dim,
                            num_layers=num_layers,
                            recurrent_dropout_probability=recurrent_dropout_probability,
                            layer_dropout_probability=layer_dropout_probability,
                            same_dropout_mask_per_instance=same_dropout_mask_per_instance)

        self.buffer = stack_rnn(word_dim)
        self.stack = stack_rnn(word_dim)
        self.deque = stack_rnn(word_dim)
        self.action_stack = stack_rnn(action_dim)
        initializer(self)

    def _greedy_decode(
            self,
            batch_size: int,
            sent_len: List[int],
            embedded_text_input: torch.Tensor,
            metadata: List[Dict[str, Any]],
            oracle_actions: Optional[List[List[int]]] = None
    ) -> Dict[str, Any]:

        self.buffer.reset_stack(batch_size)
        self.stack.reset_stack(batch_size)
        self.deque.reset_stack(batch_size)
        self.action_stack.reset_stack(batch_size)

        # We will keep track of all the losses we accumulate during parsing.
        # If some decision is unambiguous because it's the only thing valid given
        # the parser state, we will not model it. We only model what is ambiguous.
        losses = [[] for _ in range(batch_size)]
        node_labels = [[] for _ in range(batch_size)]
        node_types = [[] for _ in range(batch_size)]
        existing_edges = [{} for _ in range(batch_size)]
        id_cnt = [sent_len[i] + 1 for i in range(batch_size)]
        action_strs = [[] for _ in range(batch_size)]
        origin_tokens = [metadata[i]['tokens'] for i in range(batch_size)]

        for sent_idx in range(batch_size):
            node_labels[sent_idx].append('@@ROOT@@')
            node_types[sent_idx].append('ROOT')
            for i in range(sent_len[sent_idx]):
                node_labels[sent_idx].append(origin_tokens[sent_idx][i])
                node_types[sent_idx].append('TokenNode')
        # push the tokens onto the buffer (tokens is in reverse order)
        for sent_idx in range(batch_size):
            self.buffer.push(sent_idx,
                             input=self.pempty_buffer_emb,
                             extra={
                                 'token': 0,
                                 'type': -1
                             })
        for token_idx in range(max(sent_len)):
            for sent_idx in range(batch_size):
                if sent_len[sent_idx] > token_idx:
                    self.buffer.push(
                        sent_idx,
                        input=embedded_text_input[sent_idx][sent_len[sent_idx]
                                                            - 1 - token_idx],
                        extra={
                            'token': sent_len[sent_idx] - token_idx,
                            'type': 0
                        })

        # init stack using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.stack.push(sent_idx,
                            input=self.proot_stack_emb,
                            extra={
                                'token': 0,
                                'type': -1
                            })

        # init deque using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.deque.push(sent_idx,
                            input=self.proot_deque_emb,
                            extra={
                                'token': 0,
                                'type': -1
                            })

        action_id = {
            action_: [
                self.vocab.get_token_index(a, namespace='actions') for a in
                self.vocab.get_token_to_index_vocabulary('actions').keys()
                if a.startswith(action_)
            ]
            for action_ in self._ACTION_TYPE
        }
        action_idx_to_param_idx = {}
        for a in self._ACTION_TYPE:
            for (i, x) in enumerate(action_id[a]):
                action_idx_to_param_idx[x] = i
        origin_token_to_confirm_action = {}
        for a in action_id['CONFIRM']:
            t = self.vocab.get_token_from_index(
                a, namespace='actions').split('@@:@@')[1]
            if t not in origin_token_to_confirm_action:
                origin_token_to_confirm_action[t] = []
            origin_token_to_confirm_action[t].append(a)
        # init stack using proot_emb, considering batch
        for sent_idx in range(batch_size):
            self.action_stack.push(sent_idx,
                                   input=self.proot_action_emb,
                                   extra={
                                       'token': 0,
                                       'type': -1
                                   })

        # compute probability of each of the actions and choose an action
        # either from the oracle or if there is no oracle, based on the model
        trans_not_fin = True
        while trans_not_fin:
            trans_not_fin = False
            for sent_idx in range(batch_size):
                try:
                    valid_action_types = set()
                    valid_actions = []

                    # given the buffer and stack, conclude the valid action list
                    if self.buffer.get_len(sent_idx) > 1 \
                            and self.buffer.get_stack(sent_idx)[-1]['type'] > 1:
                        valid_actions += action_id['SHIFT']
                        valid_action_types.add(self._ACTION_TYPE_IDX['SHIFT'])
                    if self.buffer.get_len(sent_idx) > 1 \
                            and self.buffer.get_stack(sent_idx)[-1]['type'] < 2:
                        tk = node_labels[sent_idx][self.buffer.get_stack(
                            sent_idx)[-1]['token']]
                        if self.buffer.get_stack(sent_idx)[-1]['type'] == 1:
                            tk = tk.replace('@@_@@', '_')
                        if tk in origin_token_to_confirm_action:
                            valid_actions += origin_token_to_confirm_action[tk]
                            valid_action_types.add(
                                self._ACTION_TYPE_IDX['CONFIRM'])
                    if self.buffer.get_len(sent_idx) > 2 \
                            and self.buffer.get_stack(sent_idx)[-1]['type'] < 2 \
                            and self.buffer.get_stack(sent_idx)[-2]['type'] == 0:
                        valid_actions += action_id['MERGE']
                        valid_action_types.add(self._ACTION_TYPE_IDX['MERGE'])
                    if self.buffer.get_len(sent_idx) > 1 \
                            and self.buffer.get_stack(sent_idx)[-1]['type'] < 2:
                        valid_actions += action_id['ENTITY']
                        valid_action_types.add(self._ACTION_TYPE_IDX['ENTITY'])
                    if self.stack.get_len(sent_idx) > 1 \
                            and self.stack.get_stack(sent_idx)[-1]['type'] > 1:
                        valid_actions += action_id['REDUCE']
                        valid_action_types.add(self._ACTION_TYPE_IDX['REDUCE'])
                    if self.buffer.get_len(sent_idx) > 1 \
                            and self.buffer.get_stack(sent_idx)[-1]['type'] == 0:
                        valid_actions += action_id['DROP']
                        valid_action_types.add(self._ACTION_TYPE_IDX['DROP'])
                    if self.buffer.get_len(sent_idx) > 1 \
                            and self.stack.get_len(sent_idx) > 1:
                        valid_actions += action_id['CACHE']
                        valid_action_types.add(self._ACTION_TYPE_IDX['CACHE'])
                    if self.buffer.get_len(sent_idx) > 1 \
                            and self.buffer.get_stack(sent_idx)[-1]['type'] > 1 \
                            and self.buffer.get_stack(sent_idx)[-1]['type'] < self._NEWNODE_TYPE_MAX \
                            and id_cnt[sent_idx] / sent_len[sent_idx] < self._NODE_LEN_RATIO_MAX:
                        valid_actions += action_id['NEWNODE']
                        valid_action_types.add(
                            self._ACTION_TYPE_IDX['NEWNODE'])
                    if self.stack.get_len(sent_idx) > 1 \
                            and self.stack.get_stack(sent_idx)[-1]['type'] > 1 \
                            and self.buffer.get_len(sent_idx) > 0 \
                            and self.buffer.get_stack(sent_idx)[-1]['type'] > 1:
                        u = self.stack.get_stack(sent_idx)[-1]['token']
                        v = self.buffer.get_stack(sent_idx)[-1]['token']
                        for aid in ['LEFT', 'RIGHT']:
                            u, v = v, u
                            for a in action_id[aid]:
                                rel = self.vocab.get_token_from_index(
                                    a, namespace='actions').split('@@:@@')[1]
                                if (u not in existing_edges[sent_idx]
                                    or (rel, v) not in existing_edges[sent_idx][u]) \
                                        and \
                                        (True):  # allowing different arcs with same tag coming from one node?
                                    if u in existing_edges[sent_idx]:
                                        v_cnt = 0
                                        for eg in existing_edges[sent_idx][u]:
                                            if eg[1] == v:
                                                v_cnt += 1
                                        if v_cnt >= 3:
                                            continue
                                    valid_actions.append(a)
                                    valid_action_types.add(
                                        self._ACTION_TYPE_IDX[aid])
                    if self.stack.get_len(sent_idx) > 1 \
                            and self.stack.get_stack(sent_idx)[-1]['type'] > 1 \
                            and self.buffer.get_len(sent_idx) == 1 \
                            and self.buffer.get_stack(sent_idx)[-1]['type'] == -1:
                        v = self.stack.get_stack(sent_idx)[-1]['token']
                        u = self.buffer.get_stack(sent_idx)[-1]['token']
                        rel = '_ROOT_'
                        if (u not in existing_edges[sent_idx] or
                            (rel, v) not in existing_edges[sent_idx][u]):
                            valid_actions.append(
                                self.vocab.get_token_index(
                                    'LEFT@@:@@_ROOT_', namespace='actions'))
                            valid_action_types.add(
                                self._ACTION_TYPE_IDX['LEFT'])
                    valid_action_types = list(valid_action_types)
                    assert ((len(valid_actions) == 0) == (
                        self.stack.get_len(sent_idx) == 1
                        and self.buffer.get_len(sent_idx) == 1))
                    if len(valid_actions) == 0:
                        continue
                    trans_not_fin = True

                    if oracle_actions is not None:
                        valid_action_types = list(range(len(
                            self._ACTION_TYPE)))

                    if self.sep_act_type_para:
                        action_type = valid_action_types[0]
                        action_type_log_probs = None
                        h = None
                        if len(valid_action_types) > 1:
                            stack_emb = self.stack.get_output(sent_idx)
                            buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                                else self.buffer.get_output(sent_idx)
                            action_emb = self.action_stack.get_output(sent_idx)
                            deque_emb = self.deque.get_output(sent_idx)

                            p_t = torch.cat(
                                [buffer_emb, stack_emb, action_emb, deque_emb])
                            h = self.merge(p_t)

                            logits = self.action_type_scorer(h)[torch.tensor(
                                valid_action_types,
                                dtype=torch.long,
                                device=h.device)]
                            valid_action_type_tbl = {
                                a: i
                                for i, a in enumerate(valid_action_types)
                            }
                            action_type_log_probs = torch.log_softmax(logits,
                                                                      dim=0)

                            action_type_idx = torch.max(
                                action_type_log_probs, 0)[1].item()
                            action_type = valid_action_types[action_type_idx]

                        if oracle_actions is not None:
                            action_type = -1
                            for a in range(len(self._ACTION_TYPE)):
                                if oracle_actions[sent_idx][0] in action_id[
                                        self._ACTION_TYPE[a]]:
                                    action_type = a
                            assert (action_type >= 0
                                    and action_type < len(self._ACTION_TYPE))

                        if action_type_log_probs is not None:
                            losses[sent_idx].append(action_type_log_probs[
                                valid_action_type_tbl[action_type]])
                        else:
                            losses[sent_idx].append(
                                torch.tensor(
                                    0.0,
                                    dtype=self.action_type_scorer.weight.dtype,
                                    device=self.action_type_scorer.weight.
                                    device))

                        valid_actions = [
                            x for x in valid_actions
                            if x in action_id[self._ACTION_TYPE[action_type]]
                        ]

                        log_probs = None
                        action = valid_actions[0]
                        if len(valid_actions) > 1:
                            if h is None:
                                stack_emb = self.stack.get_output(sent_idx)
                                buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                                    else self.buffer.get_output(sent_idx)
                                action_emb = self.action_stack.get_output(
                                    sent_idx)
                                deque_emb = self.deque.get_output(sent_idx)

                                p_t = torch.cat([
                                    buffer_emb, stack_emb, action_emb,
                                    deque_emb
                                ])
                                h = self.merge(p_t)

                            valid_actions_param = [
                                action_idx_to_param_idx[x]
                                for x in valid_actions
                            ]
                            logits = self.scorers[action_type](h)[torch.tensor(
                                valid_actions_param,
                                dtype=torch.long,
                                device=h.device)]
                            valid_action_tbl = {
                                a: i
                                for i, a in enumerate(valid_actions_param)
                            }
                            log_probs = torch.log_softmax(logits, dim=0)

                            action_idx = torch.max(log_probs, 0)[1].item()
                            action = valid_actions[action_idx]

                        if oracle_actions is not None:
                            action = oracle_actions[sent_idx].pop(0)

                        if log_probs is not None:
                            # append the action-specific loss
                            losses[sent_idx].append(log_probs[valid_action_tbl[
                                action_idx_to_param_idx[action]]])
                        else:
                            losses[sent_idx].append(
                                torch.tensor(
                                    0.0,
                                    dtype=self.action_type_scorer.weight.dtype,
                                    device=self.action_type_scorer.weight.
                                    device))
                    else:
                        log_probs = None
                        action = valid_actions[0]
                        if len(valid_actions) > 1:
                            stack_emb = self.stack.get_output(sent_idx)
                            buffer_emb = self.pempty_buffer_emb if self.buffer.get_len(sent_idx) == 0 \
                                else self.buffer.get_output(sent_idx)
                            action_emb = self.action_stack.get_output(sent_idx)
                            deque_emb = self.deque.get_output(sent_idx)

                            p_t = torch.cat(
                                [buffer_emb, stack_emb, action_emb, deque_emb])
                            h = self.merge(p_t)

                            logits = self.scorer(h)[torch.tensor(
                                valid_actions,
                                dtype=torch.long,
                                device=h.device)]
                            valid_action_tbl = {
                                a: i
                                for i, a in enumerate(valid_actions)
                            }
                            log_probs = torch.log_softmax(logits, dim=0)

                            action_idx = torch.max(log_probs, 0)[1].item()
                            action = valid_actions[action_idx]

                        if oracle_actions is not None:
                            action = oracle_actions[sent_idx].pop(0)

                        if log_probs is not None:
                            # append the action-specific loss
                            losses[sent_idx].append(
                                log_probs[valid_action_tbl[action]])
                        else:
                            losses[sent_idx].append(
                                torch.tensor(0.0,
                                             dtype=self.scorer.weight.dtype,
                                             device=self.scorer.weight.device))

                    # push action into action_stack
                    self.action_stack.push(
                        sent_idx,
                        input=self.action_embedding(
                            torch.tensor(action,
                                         device=embedded_text_input.device)),
                        extra={
                            'token':
                            self.vocab.get_token_from_index(
                                action, namespace='actions')
                        })

                    action_str = self.vocab.get_token_from_index(
                        action, namespace='actions')

                    action_strs[sent_idx].append(action_str)

                    if action in action_id['SHIFT']:
                        while self.deque.get_len(sent_idx) > 1:
                            e = self.deque.pop(sent_idx)
                            self.stack.push(
                                sent_idx,
                                input=e['stack_rnn_input'],
                                extra={
                                    k: e[k]
                                    for k in e.keys()
                                    if not k.startswith('stack_rnn_')
                                })
                        e = self.buffer.pop(sent_idx)
                        self.stack.push(sent_idx,
                                        input=e['stack_rnn_input'],
                                        extra={
                                            k: e[k]
                                            for k in e.keys()
                                            if not k.startswith('stack_rnn_')
                                        })
                    elif action in action_id['CONFIRM']:
                        e = self.buffer.pop(sent_idx)
                        concept = self.confirm_layer(e['stack_rnn_output'])
                        self.buffer.push(sent_idx,
                                         input=concept,
                                         extra={
                                             'token': id_cnt[sent_idx],
                                             'type': 2
                                         })
                        node_labels[sent_idx].append(
                            action_str.split('@@:@@')[-1])
                        node_types[sent_idx].append('ConceptNode')
                        id_cnt[sent_idx] += 1
                    elif action in action_id['REDUCE']:
                        self.stack.pop(sent_idx)
                    elif action in action_id['MERGE']:
                        token_a = self.buffer.pop(sent_idx)
                        token_b = self.buffer.pop(sent_idx)
                        token_ab = self.merge_token(
                            torch.cat([
                                token_a['stack_rnn_output'],
                                token_b['stack_rnn_output']
                            ]))
                        token_id = token_a['token']
                        if token_a['type'] == 0:
                            node_labels[sent_idx].append(
                                node_labels[sent_idx][token_a['token']])
                            node_types[sent_idx].append('EntityNode')
                            token_id = id_cnt[sent_idx]
                            id_cnt[sent_idx] += 1
                        node_labels[sent_idx][
                            token_id] += '@@_@@' + node_labels[sent_idx][
                                token_b['token']]
                        self.buffer.push(sent_idx,
                                         input=token_ab,
                                         extra={
                                             'token': token_id,
                                             'type': 1
                                         })
                    elif action in action_id['ENTITY']:
                        entity_name = action_str.split('@@:@@')[-1]
                        buffer_top_id = self.buffer.get_stack(
                            sent_idx)[-1]['token']
                        entity = self.entity_embedding(
                            torch.tensor(self.vocab.get_token_index(
                                action_str.split('@@:@@')[1],
                                namespace='entities'),
                                         device=embedded_text_input.device))
                        e = self.buffer.pop(sent_idx)
                        entity = self.merge_entity(
                            torch.cat([e['stack_rnn_output'], entity]))
                        self.buffer.push(sent_idx,
                                         input=entity,
                                         extra={
                                             'token': id_cnt[sent_idx],
                                             'type': 2
                                         })
                        node_labels[sent_idx].append(
                            action_str.split('@@:@@')[-1])
                        node_types[sent_idx].append('ConceptNode')
                        id_cnt[sent_idx] += 1
                        if entity_name == 'date-entity':
                            datestr = ' '.join(
                                map(
                                    unquote, node_labels[sent_idx]
                                    [buffer_top_id].split('@@_@@')))
                            entry, flags = parse_date(datestr)
                            date_node_id = id_cnt[sent_idx] - 1
                            for relation, flag in zip(['year', 'month', 'day'],
                                                      flags):
                                if flag:
                                    value = getattr(entry, relation)
                                    node_labels[sent_idx].append(str(value))
                                    node_types[sent_idx].append(
                                        'AttributeNode')
                                    id_cnt[sent_idx] += 1
                                    if date_node_id not in existing_edges[
                                            sent_idx]:
                                        existing_edges[sent_idx][
                                            date_node_id] = []
                                    existing_edges[sent_idx][
                                        date_node_id].append(
                                            (relation, id_cnt[sent_idx] - 1))
                        if entity_name not in [
                                'date-entity', 'capitalism', '2',
                                'contrast-01', '1', 'compare-01'
                        ]:
                            if entity_name != 'name':
                                node_labels[sent_idx].append('name')
                                node_types[sent_idx].append('ConceptNode')
                                id_cnt[sent_idx] += 1
                                if id_cnt[sent_idx] - 2 not in existing_edges[
                                        sent_idx]:
                                    existing_edges[sent_idx][id_cnt[sent_idx] -
                                                             2] = []
                                existing_edges[sent_idx][id_cnt[sent_idx] -
                                                         2].append((
                                                             'name',
                                                             id_cnt[sent_idx] -
                                                             1))
                            name_node_id = id_cnt[sent_idx] - 1
                            for (i, opi) in enumerate(
                                    map(
                                        unquote, node_labels[sent_idx]
                                        [buffer_top_id].split('@@_@@'))):
                                node_labels[sent_idx].append(opi)
                                node_types[sent_idx].append('AttributeNode')
                                id_cnt[sent_idx] += 1
                                if name_node_id not in existing_edges[
                                        sent_idx]:
                                    existing_edges[sent_idx][name_node_id] = []
                                existing_edges[sent_idx][name_node_id].append(
                                    (f"op{i + 1}", id_cnt[sent_idx] - 1))
                    elif action in action_id['NEWNODE']:
                        node = self.newnode_embedding(
                            torch.tensor(self.vocab.get_token_index(
                                action_str.split('@@:@@')[1],
                                namespace='newnodes'),
                                         device=embedded_text_input.device))
                        self.buffer.get_stack(
                            sent_idx)[-1]['type'] += self._NEWNODE_TYPE_MAX
                        self.buffer.push(
                            sent_idx,
                            input=node,
                            extra={
                                'token':
                                id_cnt[sent_idx],
                                'type':
                                self.buffer.get_stack(sent_idx)[-1]['type'] -
                                self._NEWNODE_TYPE_MAX + 1
                            })
                        node_labels[sent_idx].append(
                            action_str.split('@@:@@')[-1])
                        node_types[sent_idx].append('ConceptNode')
                        id_cnt[sent_idx] += 1
                    elif action in action_id['DROP']:
                        self.buffer.pop(sent_idx)
                    elif action in action_id['CACHE']:
                        e = self.stack.pop(sent_idx)
                        self.deque.push(sent_idx,
                                        input=e['stack_rnn_input'],
                                        extra={
                                            k: e[k]
                                            for k in e.keys()
                                            if not k.startswith('stack_rnn_')
                                        })
                    elif action in action_id['LEFT']:
                        parent = self.buffer.pop(sent_idx)
                        child = self.stack.pop(sent_idx)
                        rel = self.rel_embedding(
                            torch.tensor(self.vocab.get_token_index(
                                action_str.split('@@:@@')[1],
                                namespace='relations'),
                                         device=embedded_text_input.device))
                        parent_rep = self.merge_parent(
                            torch.cat([
                                parent['stack_rnn_output'], rel,
                                child['stack_rnn_output']
                            ]))
                        child_rep = self.merge_child(
                            torch.cat([
                                parent['stack_rnn_output'], rel,
                                child['stack_rnn_output']
                            ]))
                        self.buffer.push(sent_idx,
                                         input=parent_rep,
                                         extra={
                                             k: parent[k]
                                             for k in parent.keys()
                                             if not k.startswith('stack_rnn_')
                                         })
                        self.stack.push(sent_idx,
                                        input=child_rep,
                                        extra={
                                            k: child[k]
                                            for k in child.keys()
                                            if not k.startswith('stack_rnn_')
                                        })
                        if parent['token'] not in existing_edges[sent_idx]:
                            existing_edges[sent_idx][parent['token']] = []
                        existing_edges[sent_idx][parent['token']].append(
                            (action_str.split('@@:@@')[1], child['token']))
                    elif action in action_id['RIGHT']:
                        child = self.buffer.pop(sent_idx)
                        parent = self.stack.pop(sent_idx)
                        rel = self.rel_embedding(
                            torch.tensor(self.vocab.get_token_index(
                                action_str.split('@@:@@')[1],
                                namespace='relations'),
                                         device=embedded_text_input.device))
                        parent_rep = self.merge_parent(
                            torch.cat([
                                parent['stack_rnn_output'], rel,
                                child['stack_rnn_output']
                            ]))
                        child_rep = self.merge_child(
                            torch.cat([
                                parent['stack_rnn_output'], rel,
                                child['stack_rnn_output']
                            ]))
                        self.buffer.push(sent_idx,
                                         input=child_rep,
                                         extra={
                                             k: child[k]
                                             for k in child.keys()
                                             if not k.startswith('stack_rnn_')
                                         })
                        self.stack.push(sent_idx,
                                        input=parent_rep,
                                        extra={
                                            k: parent[k]
                                            for k in parent.keys()
                                            if not k.startswith('stack_rnn_')
                                        })
                        if parent['token'] not in existing_edges[sent_idx]:
                            existing_edges[sent_idx][parent['token']] = []
                        existing_edges[sent_idx][parent['token']].append(
                            (action_str.split('@@:@@')[1], child['token']))
                    else:
                        raise ValueError(f'Illegal action: \"{action}\"')
                except BaseException as e:
                    print(e)
                    print(metadata[sent_idx]['id'])
                    raise e

        _loss = -torch.sum(
            torch.stack([torch.sum(torch.stack(cur_loss)) for cur_loss in losses if len(cur_loss) > 0])) / \
                sum([len(cur_loss) for cur_loss in losses])
        ret = {
            'loss': _loss,
            'losses': losses,
        }
        if oracle_actions is None:
            ret['existing_edges'] = existing_edges
            ret['node_labels'] = node_labels
            ret['node_types'] = node_types
            ret['id_cnt'] = id_cnt
            ret['action_strs'] = action_strs
        return ret

    # Returns an output dict whose 'loss' is computed over a sequence of actions:
    # the oracle actions when training with gold actions, the greedily predicted
    # sequence otherwise.
    def forward(
        self,
        tokens: Dict[str, torch.LongTensor],
        metadata: List[Dict[str, Any]],
        gold_actions: Dict[str, torch.LongTensor] = None,
        gold_newnodes: Dict[str, torch.LongTensor] = None,
        gold_entities: Dict[str, torch.LongTensor] = None,
        gold_relations: Dict[str, torch.LongTensor] = None,
        lemmas: Dict[str, torch.LongTensor] = None,
        pos_tags: torch.LongTensor = None,
        arc_tags: torch.LongTensor = None,
    ) -> Dict[str, torch.LongTensor]:
        """Parse a batch of sentences and compute the transition loss.

        In training mode the batch is first decoded by ``_greedy_decode``
        following the oracle action sequences taken from
        ``metadata[i]['gold_actions']`` (only when ``gold_actions`` is given);
        if ``self.eval_on_training`` is false, only that loss is returned.
        Otherwise (and always outside training) the batch is greedily decoded
        in eval mode without gradients, each predicted graph is converted with
        ``extract_mrp_dict``, and the predictions are scored against
        ``metadata[i]['mrp']`` (MCES scorer) or ``metadata[i]['amr']``
        (smatch scorer) when those keys are present in the first sentence's
        metadata.

        ``gold_actions`` only acts as a flag here: the action strings are read
        from ``metadata``. ``gold_newnodes``, ``gold_entities``,
        ``gold_relations``, ``lemmas`` and ``arc_tags`` are accepted for
        interface compatibility but are not read by this method.

        Returns:
            A dict with ``'loss'``. When the eval decode ran it also holds
            ``'tokens'``, ``'sent_len'``, ``'metadata'``, ``'oracle_actions'``
            (if the metadata carries ``'gold_actions'``) and, outside
            training, the decoded ``'existing_edges'``, ``'node_labels'``,
            ``'node_types'`` and ``'id_cnt'``.
        """
        batch_size = len(metadata)
        sent_len = [len(d['tokens']) for d in metadata]

        oracle_actions = None
        if gold_actions is not None:
            # Build fresh per-sentence lists: _greedy_decode pops from them.
            oracle_actions = [[
                self.vocab.get_token_index(s, namespace='actions')
                for s in d['gold_actions']
            ] for d in metadata]

        training_mode = self.training
        embedded_text_input = self._embed_text_input(tokens, pos_tags)

        ret_train = None
        if training_mode:
            ret_train = self._greedy_decode(
                batch_size=batch_size,
                sent_len=sent_len,
                embedded_text_input=embedded_text_input,
                metadata=metadata,
                oracle_actions=oracle_actions)
            if not self.eval_on_training:
                return {'loss': ret_train['loss']}

        self.eval()
        try:
            with torch.no_grad():
                if training_mode:
                    # The embeddings above were produced in training mode
                    # (input dropout active); recompute them in eval mode so
                    # the evaluation decode sees deterministic inputs.
                    embedded_text_input = self._embed_text_input(
                        tokens, pos_tags)
                ret_eval = self._greedy_decode(
                    batch_size=batch_size,
                    sent_len=sent_len,
                    metadata=metadata,
                    embedded_text_input=embedded_text_input)
        finally:
            # Restore the caller's mode even if decoding raises.
            self.train(training_mode)

        existing_edges = ret_eval['existing_edges']
        id_cnt = ret_eval['id_cnt']
        node_labels = ret_eval['node_labels']
        node_types = ret_eval['node_types']
        _loss = ret_train['loss'] if training_mode else ret_eval['loss']

        amr_dicts = [
            extract_mrp_dict(
                existing_edges=existing_edges[sent_idx],
                sent_len=sent_len[sent_idx],
                id_cnt=id_cnt[sent_idx],
                node_labels=node_labels[sent_idx],
                node_types=node_types[sent_idx],
                metadata=metadata[sent_idx])
            for sent_idx in range(batch_size)
        ]

        # Only collect the gold annotations for the scorer actually used.
        if 'mrp' in metadata[0]:
            golds_mrp = [d['mrp'] for d in metadata]
            self._mces_scorer(predictions=amr_dicts, golds=golds_mrp)
        elif 'amr' in metadata[0]:
            golds_amr = [d['amr'] for d in metadata]
            amr_strs = [
                mrp_dict_to_amr_str(amr_dict, all_nodes=True)
                for amr_dict in amr_dicts
            ]
            for amr_str, gold_amr in zip(amr_strs, golds_amr):
                self._smatch_scorer.update(amr_str, gold_amr)

        output_dict = {
            'tokens': [d['tokens'] for d in metadata],
            'loss': _loss,
            'sent_len': sent_len,
            'metadata': metadata
        }
        if 'gold_actions' in metadata[0]:
            output_dict['oracle_actions'] = [
                d['gold_actions'] for d in metadata
            ]
        if not training_mode:
            output_dict['existing_edges'] = existing_edges
            output_dict['node_labels'] = node_labels
            output_dict['node_types'] = node_types
            output_dict['id_cnt'] = id_cnt

        return output_dict

    def _embed_text_input(
            self,
            tokens: Dict[str, torch.LongTensor],
            pos_tags: Optional[torch.LongTensor] = None) -> torch.Tensor:
        """Embed ``tokens``, append POS-tag embeddings when both the tags and
        ``self.pos_tag_embedding`` are available, then apply
        ``self._input_dropout``."""
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        return self._input_dropout(embedded_text_input)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Report the parsing metrics accumulated so far.

        When ``self._smatch_scorer.total_gold_num`` is positive, the smatch
        precision / recall / F1 are returned under the keys ``'P'``, ``'R'``
        and ``'F1'``. Otherwise the result of the MCES scorer's
        ``get_metric`` is returned. ``reset`` is applied only to the scorer
        whose metrics are reported.
        """
        smatch = self._smatch_scorer
        if not smatch.total_gold_num > 0:
            return self._mces_scorer.get_metric(reset=reset)

        precision, recall, f_score = smatch.get_prf()
        if reset:
            smatch.reset()
        return {'P': precision, 'R': recall, 'F1': f_score}