예제 #1
0
    def gold_feed(self, predicate_hidden, input_sense_logits, argument_hidden,
                  input_arc_tag_logits, predicate_representation,
                  embedded_candidate_preds, graph_mask, sense_mask,
                  predicate_mask, soft_tags, soft_index):
        """Score the gold graph and apply one learn-rate-weighted logit update.

        Feeds the gold one-hot arc tags (``soft_tags``) and gold sense
        indicators (``soft_index``) through the graph scorer, turns the
        returned feedback into a softmax-Jacobian-style update of the input
        logits, and decodes probabilities from the updated logits.

        Returns a 6-tuple: updated arc-tag logits, arc-tag probs, updated
        sense logits, sense probs, node scores, edge scores.
        """
        # Clamp away exact 0/1 so the multiplicative update below stays
        # well-behaved, then re-apply the masks.
        soft_tags = soft_tags.clamp(min=1e-5, max=1 - 1e-5) * graph_mask
        soft_index = soft_index.clamp(min=1e-5, max=1 - 1e-5) * sense_mask
        # Learned step size, bounded below.
        learn_rate = self.learn_rate.clamp(min=self.minimum_rate)
        # (batch_size, sequence_length, node_dim)
        predicate_emb = (
            (soft_index.unsqueeze(2).matmul(embedded_candidate_preds)).squeeze(
                2) + predicate_representation) * predicate_mask.unsqueeze(-1)
        score_nodes, score_edges, grad_to_predicate_emb, grad_to_arc_tag_probs = self.graph_scorer(
            predicate_emb, soft_tags * graph_mask, predicate_hidden,
            argument_hidden, graph_mask)

        # Chain rule: node-embedding feedback -> candidate-sense feedback.
        grad_to_sense_probs = embedded_candidate_preds.matmul(
            grad_to_predicate_emb.unsqueeze(-1)).squeeze(-1)

        # Softmax-Jacobian-style transform: g <- p * (g - sum(g * p)), with
        # p the clamped gold distribution.
        grad_to_arc_tag_probs = grad_to_arc_tag_probs + input_arc_tag_logits
        grad_to_arc_tag_probs = grad_to_arc_tag_probs * soft_tags
        grad_to_arc_tag_probs = grad_to_arc_tag_probs - grad_to_arc_tag_probs.sum(
            -1, keepdim=True) * soft_tags

        # Interpolate the masked update into the input logits.
        input_arc_tag_logits = input_arc_tag_logits * (
            1 - learn_rate) + learn_rate * grad_to_arc_tag_probs * graph_mask

        grad_to_sense_probs = grad_to_sense_probs + input_sense_logits
        grad_to_sense_probs = grad_to_sense_probs * soft_index
        grad_to_sense_probs = grad_to_sense_probs - grad_to_sense_probs.sum(
            -1, keepdim=True) * soft_index
        input_sense_logits = input_sense_logits * (
            1 - learn_rate) + learn_rate * grad_to_sense_probs * sense_mask

        arc_tag_logits = input_arc_tag_logits
        sense_logits = input_sense_logits
        if self.training:
            # Gumbel noise centered away from the gold targets.
            arc_tag_logits = arc_tag_logits + self.gumbel_t * (_sample_gumbel(
                arc_tag_logits.size(), out=arc_tag_logits.new()) - soft_tags)
            sense_logits = sense_logits + self.gumbel_t * (_sample_gumbel(
                sense_logits.size(), out=sense_logits.new()) - soft_index)

        arc_tag_probs_soft = torch.nn.functional.softmax(arc_tag_logits,
                                                         dim=-1)

        sense_probs_soft = torch.nn.functional.softmax(sense_logits, dim=-1)

        # NOTE(review): `hard` presumably discretizes the distribution
        # (straight-through estimator) — confirm in its definition.
        arc_tag_probs = hard(
            arc_tag_probs_soft,
            graph_mask) if self.stright_through else arc_tag_probs_soft
        sense_probs = hard(
            sense_probs_soft,
            sense_mask) if self.stright_through else sense_probs_soft

        return input_arc_tag_logits, arc_tag_probs, input_sense_logits, sense_probs, score_nodes, score_edges
예제 #2
0
    def one_iteration(self,
                      predicate_rep,
                      input_sense_logits,
                      argument_rep,
                      input_arc_tag_logits,
                      embedded_candidate_preds,
                      graph_mask,
                      sense_mask,
                      predicate_mask,
                      arc_tag_probs,
                      sense_probs,
                      soft_tags,
                      soft_index,
                      step_size=1,
                      sense_rep=None):
        """One refinement step: re-score the current soft graph and fold the
        scorer's feedback back into the arc-tag and sense logits.

        Returns ``(arc_tag_logits, arc_tag_probs, sense_logits, sense_probs,
        score_nodes, score_edges)``, the scores summed over their last
        dimension (``None`` when the scorer produced none).
        """
        # Expected predicate embedding under the current sense distribution.
        # shape (batch_size, predicates_len, node_dim)
        expected_sense_emb = (sense_probs.unsqueeze(2).matmul(
            embedded_candidate_preds)).squeeze(2)
        if sense_rep is not None:
            # Note: on this branch dropout is applied before masking.
            predicate_emb = self._dropout(
                expected_sense_emb + sense_rep) * predicate_mask.unsqueeze(-1)
        else:
            predicate_emb = self._dropout(
                expected_sense_emb * predicate_mask.unsqueeze(-1))

        (score_nodes, score_edges, grad_to_predicate_emb,
         grad_to_arc_tag_probs) = self.graph_scorer(predicate_emb,
                                                    arc_tag_probs * graph_mask,
                                                    predicate_rep,
                                                    argument_rep, graph_mask)
        grad_to_predicate_emb = self._dropout(grad_to_predicate_emb)

        # Project the node-level feedback onto the candidate-sense axis.
        grad_to_sense_probs = embedded_candidate_preds.matmul(
            grad_to_predicate_emb.unsqueeze(-1)).squeeze(-1)

        arc_tag_logits = input_arc_tag_logits + grad_to_arc_tag_probs
        sense_logits = input_sense_logits + grad_to_sense_probs

        if self.training:
            if self.gumbel_t:
                # Gumbel perturbation for exploration during training.
                arc_tag_logits = arc_tag_logits + self.gumbel_t * _sample_gumbel(
                    arc_tag_logits.size(), out=arc_tag_logits.new())
                sense_logits = sense_logits + self.sense_gumbel_t * _sample_gumbel(
                    sense_logits.size(), out=sense_logits.new())
            if self.subtract_gold and soft_tags is not None:
                # Push logits away from the gold one-hot targets.
                arc_tag_logits = arc_tag_logits - self.subtract_gold * soft_tags
                sense_logits = sense_logits - self.subtract_gold * soft_index

        arc_tag_probs, sense_probs = self.decode(arc_tag_logits, sense_logits,
                                                 graph_mask, sense_mask,
                                                 soft_tags, soft_index,
                                                 arc_tag_probs, sense_probs,
                                                 step_size)

        if score_nodes is not None:
            score_nodes = score_nodes.sum(-1, keepdim=True)
        if score_edges is not None:
            score_edges = score_edges.sum(-1, keepdim=True)
        return arc_tag_logits, arc_tag_probs, sense_logits, sense_probs, score_nodes, score_edges
예제 #3
0
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            dep_tags: torch.LongTensor,
            predicate_candidates: torch.LongTensor = None,
            epoch: int = None,
            predicate_indexes: torch.LongTensor = None,
            sense_indexes: torch.LongTensor = None,
            predicates: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None,
            arc_tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Encode the sentence, score every (argument, predicate) arc-tag pair
        with a capsule-routed bilinear scorer, score candidate senses for each
        predicate, and (when gold labels are given) compute the loss.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``.
        pos_tags : ``torch.LongTensor``, optional, (default = None).
            The output of a ``SequenceLabelField`` containing POS tags.
        predicate_candidates : torch.LongTensor, optional (default = None)
            Candidate sense ids for each predicate.
            Has shape ``(batch_size, predicates_len, batch_max_senses)``;
            0 is padding.
        predicate_indexes : torch.LongTensor
            Token position of each predicate, shape (batch_size,
            predicates_len); padded with -1.
        arc_tags : torch.LongTensor, optional (default = None)
            Gold arc-tag index for every (token, predicate) pair.
            Has shape ``(batch_size, sequence_length, predicates_len)``.

        Returns
        -------
        An output dictionary with arc-tag/sense logits and probabilities and,
        when gold ``arc_tags`` are supplied, the training ``loss``.
        """
        if arc_tags is not None:
            arc_tags = arc_tags.long()

        # shape (batch_size, sequence_length, embedding_dim)
        embedded_text_input = self.text_field_embedder(tokens)

        # shape (batch_size, predicates_len, batch_max_senses, pred_dim)
        embedded_candidate_preds = self._pred_embedding(predicate_candidates)

        # shape (batch_size, predicates_len, batch_max_senses); 0 = padding.
        sense_mask = (predicate_candidates > 0).float()

        # shape (batch_size, predicates_len)
        predicate_indexes = predicate_indexes.long()

        embedded_pos_tags = self._pos_tag_embedding(pos_tags)
        embedded_text_input = torch.cat(
            [embedded_text_input, embedded_pos_tags], -1)

        embedded_text_input = self._input_dropout(embedded_text_input)

        # shape (batch_size, sequence_length)
        mask = get_text_field_mask(tokens)
        batch_size, sequence_length = mask.size()

        float_mask = mask.float()

        # shape (batch_size, predicates_len); -1 marks padded predicates.
        predicate_mask = (predicate_indexes > -1).float()
        # shape (batch_size, sequence_length, predicates_len, 1)
        graph_mask = (predicate_mask.unsqueeze(1) *
                      float_mask.unsqueeze(2)).unsqueeze(-1)

        # shape (batch_size, sequence_length, hidden_dim)
        if isinstance(self.encoder, FeedForward):
            encoded_text = self._dropout(self.encoder(embedded_text_input))
        else:
            encoded_text = self._dropout(
                self.encoder(embedded_text_input, mask))

        # Prepend a zero row so that gathering with index 0 (which padded
        # predicates map to via -1 + 1) yields a zero vector.
        padding_for_predicate = torch.zeros(
            size=[batch_size, 1, encoded_text.size(-1)],
            device=encoded_text.device)
        encoded_text_for_predicate = torch.cat(
            [padding_for_predicate, encoded_text], dim=1)

        index_size = list(predicate_indexes.size()) + [encoded_text.size(-1)]
        effective_predicate_indexes = (predicate_indexes.unsqueeze(-1) +
                                       1).expand(index_size)
        # shape (batch_size, predicates_len, hidden_dim)
        encoded_text_for_predicate = encoded_text_for_predicate.gather(
            dim=1, index=effective_predicate_indexes)

        # shape (batch_size, sequence_length, tag_representation_dim)
        arg_tag_representation = self._dropout(
            self.arg_tag_feedforward(encoded_text))

        # shape (batch_size, predicates_len, tag_representation_dim)
        pred_tag_representation = self._dropout(
            self.pred_tag_feedforward(encoded_text_for_predicate))

        # shape (batch_size, num_tags * capsule_dim, sequence_length, predicates_len)
        arc_tag_logits = self.tag_bilinear(arg_tag_representation,
                                           pred_tag_representation)

        # shape (batch_size, predicates_len, refine_representation_dim)
        predicate_representation = self._dropout(
            self.predicte_feedforward(encoded_text_for_predicate))

        # (batch_size, predicates_len, max_sense)
        sense_logits = embedded_candidate_preds.matmul(
            predicate_representation.unsqueeze(-1)).squeeze(-1)

        # Switch to (batch_size, sequence_length, predicates_len, num_tags, capsule_dim)
        arc_tag_logits = arc_tag_logits.permute(0, 2, 3, 1).view(
            batch_size, sequence_length, -1,
            self.vocab.get_vocab_size("arc_types") + 1, self.capsule_dim)

        # Reduce the capsule dimension: capsule routing (optionally with
        # message passing) or plain averaging as the baseline.
        if self.base_average == False:
            if self.using_global == False:
                arc_tag_logits = self.capsule_net_layer(
                    arc_tag_logits, self.iter_num)
            else:
                arc_tag_logits = self.capsule_net_layer_with_massage_passing(
                    arc_tag_logits, self.iter_num, self.passing_type)
        else:
            arc_tag_logits = torch.mean(arc_tag_logits, -1)  # baseline option

        output_dict = {
            "tokens": [meta["tokens"] for meta in metadata],
        }

        if arc_tags is not None:
            # Gold one-hot arc-tag distribution; +1 shifts the gold index
            # because index 0 is reserved for the "no arc" label.
            soft_tags = torch.zeros(size=arc_tag_logits.size(),
                                    device=arc_tag_logits.device)
            soft_tags.scatter_(3, arc_tags.unsqueeze(3) + 1, 1)
            # BUG FIX: the mask product was previously a discarded expression
            # statement (`soft_tags.scatter_(...) * graph_mask`); assign it so
            # padded cells are actually zeroed.
            soft_tags = soft_tags * graph_mask

            # Gold one-hot sense distribution.
            soft_index = torch.zeros(size=sense_logits.size(),
                                     device=sense_logits.device)
            soft_index.scatter_(2, sense_indexes.unsqueeze(2), 1)
            # BUG FIX: same discarded-expression issue as above.
            soft_index = soft_index * sense_mask

        if self.training:
            # Perturb logits with Gumbel noise during training.
            arc_tag_logits = arc_tag_logits + self.gumbel_t * (_sample_gumbel(
                arc_tag_logits.size(), out=arc_tag_logits.new()))
            sense_logits = sense_logits + self.gumbel_t * (_sample_gumbel(
                sense_logits.size(), out=sense_logits.new()))

        arc_tag_probs, sense_probs, arc_tag_probs_second = self._greedy_decode(
            arc_tag_logits, sense_logits)
        if arc_tags is not None:
            loss = self._construct_loss(arc_tag_logits, arc_tag_probs,
                                        arc_tags, soft_tags, sense_logits,
                                        sense_probs, sense_indexes, soft_index,
                                        graph_mask, sense_mask, predicate_mask,
                                        float_mask, arc_tag_probs_second)
            self._labelled_f1(arc_tag_probs,
                              arc_tags + 1,
                              graph_mask.squeeze(-1),
                              sense_probs,
                              predicate_candidates,
                              predicates,
                              linear_scores=arc_tag_probs * arc_tag_logits,
                              n_iteration=1)

            output_dict["loss"] = loss

        output_dict["arc_tag_probs"] = arc_tag_probs
        output_dict["sense_probs"] = sense_probs
        output_dict["arc_tag_logits"] = arc_tag_logits
        output_dict["sense_logits"] = sense_logits

        return output_dict
예제 #4
0
    def one_iteration(self,
                      predicate_rep,
                      input_sense_logits,
                      argument_rep,
                      input_arc_tag_logits,
                      embedded_candidate_preds,
                      graph_mask,
                      sense_mask,
                      predicate_mask,
                      arc_tag_probs,
                      sense_probs,
                      soft_tags,
                      soft_index,
                      step_size=1,
                      sense_rep=None):
        """One feed-forward refinement step over the current soft graph.

        Builds refined predicate/sense representations and arc-tag logits from
        the current ``arc_tag_probs``/``sense_probs`` (the exact features used
        depend on ``self.graph_type``), adds them residually to the input
        logits, optionally perturbs them during training, and re-decodes.

        Returns ``(arc_tag_logits, arc_tag_probs, sense_logits, sense_probs,
        None, None)``; the trailing ``None``s keep the signature aligned with
        the variants that also return node/edge scores.
        """
        # Per-(predicate, tag) edge mass, summed over argument positions;
        # the [..., 1:] slice drops tag 0 — presumably the "no arc" label,
        # TODO confirm.
        all_edges = (arc_tag_probs * graph_mask).sum(1)
        all_active_edges = all_edges[:, :, 1:]

        # shape (batch_size, predicates_len, node_dim)
        predicate_emb = self._dropout(
            (sense_probs.unsqueeze(2).matmul(embedded_candidate_preds)
             ).squeeze(2) * predicate_mask.unsqueeze(-1))

        # Refine the predicate representation; input features depend on the
        # configured graph type.
        if self.graph_type == 3:
            predicate_representation = self._dropout(
                self.predicte_refiner(
                    torch.cat([predicate_rep, all_active_edges], dim=-1)))
        elif self.graph_type == 2:
            predicate_representation = self._dropout(
                self.predicte_refiner(
                    torch.cat([predicate_emb, predicate_rep], dim=-1)))
        else:
            predicate_representation = self._dropout(
                self.predicte_refiner(
                    torch.cat([predicate_emb, predicate_rep, all_active_edges],
                              dim=-1)))

        # (batch_size, predicates_len, max_sense)
        sense_logits = embedded_candidate_preds.matmul(
            predicate_representation.unsqueeze(-1)).squeeze(-1)

        # For each edge, the tag mass contributed by all *other* arguments of
        # the same predicate.
        all_other_edges = (all_edges.unsqueeze(1) - arc_tag_probs) * graph_mask
        all_other_edges = all_other_edges[:, :, :, 1:]

        encoded_arg_enc = self._arc_tag_arg_enc(argument_rep)

        if self.use_predicate_rep:
            encoded_pred_enc = self._arc_tag_pred_enc(predicate_rep)

        # NOTE(review): "tag_input_date" looks like a typo for
        # "tag_input_data"; kept as-is.
        if self.graph_type == 1:
            tag_input_date = arc_tag_probs
        elif self.graph_type == 2:
            tag_input_date = arc_tag_probs
        else:
            tag_input_date = torch.cat([arc_tag_probs, all_other_edges],
                                       dim=-1)

        tag_enc = self._arc_tag_tags_enc(tag_input_date)

        # Additive composition of edge, argument and (optionally) predicate
        # encodings, broadcast over the full (argument, predicate) grid.
        if self.use_predicate_rep:
            linear_added = tag_enc + encoded_arg_enc.unsqueeze(2).expand_as(
                tag_enc) + encoded_pred_enc.unsqueeze(1).expand_as(tag_enc)
        else:
            linear_added = tag_enc + encoded_arg_enc.unsqueeze(2).expand_as(
                tag_enc)

        if self.graph_type != 2:
            predicate_emb_enc = self._arc_tag_sense_enc(predicate_emb)
            linear_added = linear_added + predicate_emb_enc.unsqueeze(
                1).expand_as(tag_enc)

        arc_tag_logits = self.arc_tag_refiner(
            self._dropout(self.activation(linear_added)))

        # Residual connection back onto the incoming logits.
        sense_logits = sense_logits + input_sense_logits
        arc_tag_logits = arc_tag_logits + input_arc_tag_logits

        if self.training and self.gumbel_t:
            # Gumbel noise for exploration; the sense logits get a 5x larger
            # perturbation.
            arc_tag_logits = arc_tag_logits + self.gumbel_t * _sample_gumbel(
                arc_tag_logits.size(), out=arc_tag_logits.new())
            sense_logits = sense_logits + 5 * self.gumbel_t * _sample_gumbel(
                sense_logits.size(), out=sense_logits.new())
        if self.training and self.subtract_gold and soft_tags is not None:
            # Push logits away from the gold one-hot targets.
            arc_tag_logits = arc_tag_logits + self.subtract_gold * (-soft_tags)
            sense_logits = sense_logits + self.subtract_gold * (-soft_index)

        arc_tag_probs, sense_probs = self.decode(arc_tag_logits, sense_logits,
                                                 graph_mask, sense_mask,
                                                 soft_tags, soft_index,
                                                 arc_tag_probs, sense_probs,
                                                 step_size)

        return arc_tag_logits, arc_tag_probs, sense_logits, sense_probs, None, None
예제 #5
0
    def one_iteration(self,
                      predicate_hidden,
                      input_sense_logits,
                      argument_hidden,
                      input_arc_tag_logits,
                      predicate_representation,
                      embedded_candidate_preds,
                      graph_mask,
                      sense_mask,
                      predicate_mask,
                      soft_tags,
                      soft_index,
                      arc_tag_probs,
                      sense_probs,
                      old_arc_tag_logits=None,
                      old_sense_logits=None,
                      old_score_nodes=None,
                      old_score_edges=None):
        """One scorer-driven refinement step on the *predicted* distributions.

        Mirrors ``gold_feed`` but scores the current ``arc_tag_probs`` /
        ``sense_probs`` instead of the gold structures, and interpolates the
        resulting update into ``old_*_logits`` when those are supplied
        (otherwise into the input logits).

        Returns the updated logits, the decoded probabilities, and the
        node/edge scores.
        """
        # Learned step size, bounded below.
        learn_rate = self.learn_rate.clamp(min=self.minimum_rate)
        #    print ("predicate_representation",predicate_representation.size())
        # Expected predicate embedding under the current sense distribution
        # plus the base representation, zeroed at padded predicates.
        predicate_emb = (
            (sense_probs.unsqueeze(2).matmul(embedded_candidate_preds)
             ).squeeze(2) +
            predicate_representation) * predicate_mask.unsqueeze(-1)
        score_nodes, score_edges, grad_to_predicate_emb, grad_to_arc_tag_probs = self.graph_scorer(
            predicate_emb, arc_tag_probs * graph_mask, predicate_hidden,
            argument_hidden, graph_mask)
        # Chain rule: node-embedding feedback -> candidate-sense feedback.
        grad_to_sense_probs = embedded_candidate_preds.matmul(
            grad_to_predicate_emb.unsqueeze(-1)).squeeze(-1)

        # Softmax-Jacobian-style transform: g <- p * (g - sum(g * p)).
        grad_to_arc_tag_probs = grad_to_arc_tag_probs + input_arc_tag_logits
        grad_to_arc_tag_probs = grad_to_arc_tag_probs * arc_tag_probs
        grad_to_arc_tag_probs = grad_to_arc_tag_probs - grad_to_arc_tag_probs.sum(
            -1, keepdim=True) * arc_tag_probs

        grad_to_sense_probs = grad_to_sense_probs + input_sense_logits
        grad_to_sense_probs = grad_to_sense_probs * sense_probs
        grad_to_sense_probs = grad_to_sense_probs - grad_to_sense_probs.sum(
            -1, keepdim=True) * sense_probs
        # Interpolate the masked update into the previous-iteration logits if
        # available, else into the original input logits.
        if old_arc_tag_logits is not None:
            input_arc_tag_logits = old_arc_tag_logits * (
                1 -
                learn_rate) + learn_rate * grad_to_arc_tag_probs * graph_mask
            input_sense_logits = old_sense_logits * (
                1 - learn_rate) + learn_rate * grad_to_sense_probs * sense_mask
        else:
            input_arc_tag_logits = input_arc_tag_logits * (
                1 -
                learn_rate) + learn_rate * grad_to_arc_tag_probs * graph_mask
            input_sense_logits = input_sense_logits * (
                1 - learn_rate) + learn_rate * grad_to_sense_probs * sense_mask

        arc_tag_logits = input_arc_tag_logits
        sense_logits = input_sense_logits
        if self.denoise and self.training:
            # Gumbel noise centered away from the gold one-hot targets.
            arc_tag_logits = arc_tag_logits + self.gumbel_t * (_sample_gumbel(
                arc_tag_logits.size(), out=arc_tag_logits.new()) - soft_tags)
            sense_logits = sense_logits + self.gumbel_t * (_sample_gumbel(
                sense_logits.size(), out=sense_logits.new()) - soft_index)

        arc_tag_probs_soft = torch.nn.functional.softmax(arc_tag_logits,
                                                         dim=-1)

        sense_probs_soft = torch.nn.functional.softmax(sense_logits, dim=-1)

        # NOTE(review): `hard` presumably discretizes the distribution
        # (straight-through estimator) — confirm in its definition.
        arc_tag_probs = hard(
            arc_tag_probs_soft,
            graph_mask) if self.stright_through else arc_tag_probs_soft
        sense_probs = hard(
            sense_probs_soft,
            sense_mask) if self.stright_through else sense_probs_soft

        return input_arc_tag_logits, arc_tag_probs, input_sense_logits, sense_probs, score_nodes, score_edges
    def forward(
        self,
        predicate_representation,
        extra_representation,
        input_arc_tag_logits,
        input_sense_logits,
        embedded_candidate_preds,
        graph_mask,
        sense_mask,
        predicate_mask,
        soft_tags=None,
        soft_index=None,
        arc_tag_probs_soft=None,
        sense_probs_soft=None,
    ):  # pylint: disable=arguments-differ
        '''Iteratively refine arc-tag and sense logits with the graph scorer.

        Decodes initial probabilities (optionally corrupting the inputs with
        Gumbel noise during training), then runs ``one_iteration`` for a
        configured number of steps, collecting the logits/probs/scores of
        every step. When gold structures are available it additionally runs
        ``gold_feed`` on the gold — and, optionally, on randomly corrupted
        gold — structures.

        :param predicate_representation: predicate node representations.
        :param extra_representation: extra (argument-side) representations
            passed through to ``one_iteration``/``gold_feed``.
        :param input_arc_tag_logits: initial arc-tag logits.
        :param input_sense_logits: initial sense logits.
        :param embedded_candidate_preds:
            (batch_size, sequence_length, max_senses, node_dim)
        :param graph_mask: mask over (argument, predicate) cells.
        :param sense_mask: mask over candidate senses,
            shape (batch_size, predicates_len, batch_max_senses).
        :param predicate_mask: mask over predicate positions.
        :param soft_tags: gold one-hot arc tags (training only).
        :param soft_index: gold one-hot sense indicators (training only).
        :param arc_tag_probs_soft: optional precomputed soft tag probs.
        :param sense_probs_soft: optional precomputed soft sense probs.
        :return: (per-iteration lists, corrupted-gold lists, gold results).
        '''
        # shape (batch_size, predicates_len, batch_max_senses )  sense_mask

        # scores = scores + (arc_tag_probs * input_arc_tag_logits).sum(-1, keepdim=True)/scores.size(-1)*graph_mask

        if self.corrupt_input and self.training:
            # Corrupt the input logits with Gumbel noise centered away from
            # the gold targets before decoding the initial probabilities.
            arc_tag_logits = input_arc_tag_logits + self.gumbel_t * (
                _sample_gumbel(input_arc_tag_logits.size(),
                               out=input_arc_tag_logits.new()) - soft_tags)
            sense_logits = input_sense_logits + self.gumbel_t * (
                _sample_gumbel(input_sense_logits.size(),
                               out=input_sense_logits.new()) - soft_index)

            arc_tag_probs_soft = torch.nn.functional.softmax(arc_tag_logits,
                                                             dim=-1)

            sense_probs_soft = torch.nn.functional.softmax(sense_logits,
                                                           dim=-1)

            arc_tag_probs = hard(
                arc_tag_probs_soft,
                graph_mask) if self.stright_through else arc_tag_probs_soft
            sense_probs = hard(
                sense_probs_soft,
                sense_mask) if self.stright_through else sense_probs_soft

        else:
            # Decode from the clean input logits, reusing precomputed soft
            # probabilities when the caller supplied them.
            if arc_tag_probs_soft is None or sense_probs_soft is None:

                arc_tag_probs_soft = torch.nn.functional.softmax(
                    input_arc_tag_logits, dim=-1)

                sense_probs_soft = torch.nn.functional.softmax(
                    input_sense_logits, dim=-1)

            arc_tag_probs = hard(
                arc_tag_probs_soft,
                graph_mask) if self.stright_through else arc_tag_probs_soft
            sense_probs = hard(
                sense_probs_soft,
                sense_mask) if self.stright_through else sense_probs_soft

        # Trajectory of every iteration; index 0 holds the input state.
        arc_tag_logits_list = [input_arc_tag_logits]
        arc_tag_probs_list = [arc_tag_probs]
        sense_logits_list = [input_sense_logits]
        sense_probs_list = [sense_probs]
        scores_list = []

        old_arc_tag_logits = input_arc_tag_logits
        old_sense_logits = input_sense_logits
        old_scores = None

        # Optionally cut gradients flowing into the refinement loop.
        if self.detach_type == "all":
            arc_tag_probs = arc_tag_probs.detach()
            sense_probs = sense_probs.detach()
            input_arc_tag_logits = input_arc_tag_logits.detach()
            input_sense_logits = input_sense_logits.detach()
        elif self.detach_type == "probs":
            arc_tag_probs = arc_tag_probs.detach()
            sense_probs = sense_probs.detach()
        elif self.detach_type == "logits":
            input_arc_tag_logits = input_arc_tag_logits.detach()
            input_sense_logits = input_sense_logits.detach()
        else:
            assert self.detach_type == "no", (
                "detach_type is set as " + self.detach_type +
                " need to be one of no, all, probs, logits")

        if self.dropout_local:
            input_sense_logits = self._dropout(input_sense_logits)
            input_arc_tag_logits = self._dropout(input_arc_tag_logits)

        iterations = self.iterations if self.training else self.testing_iterations
        #      arg_intermediates_list = []
        #  predicate_representation = predicate_representation.detach()
        for i in range(iterations):

            arc_tag_logits, arc_tag_probs, sense_logits, sense_probs, scores, grad_to_arc_tag_probs = self.one_iteration(
                predicate_representation, input_sense_logits,
                extra_representation, input_arc_tag_logits,
                embedded_candidate_preds, graph_mask, sense_mask,
                predicate_mask, soft_tags, soft_index, arc_tag_probs,
                sense_probs, old_arc_tag_logits, old_sense_logits, old_scores)

            old_arc_tag_logits = arc_tag_logits
            old_sense_logits = sense_logits
            old_scores = scores

            arc_tag_logits_list.append(arc_tag_logits)
            sense_logits_list.append(sense_logits)
            arc_tag_probs_list.append(arc_tag_probs)
            sense_probs_list.append(sense_probs)
            scores_list.append(scores)

        # Pad so scores_list has the same length as the logits/probs lists.
        scores_list.append(None)

        c_arc_tag_logits_list = []
        c_arc_tag_probs_list = []
        c_sense_logits_list = []
        c_sense_probs_list = []
        c_scores_list = []

        if soft_tags is not None:
            # Score the gold structures themselves (e.g. for margin losses).
            gold_arc_tag_logits, gold_arc_tag_probs, gold_sense_logits, gold_sense_probs, gold_scores, grad_to_arc_tag_probs = self.gold_feed(
                predicate_representation, input_sense_logits,
                extra_representation, input_arc_tag_logits,
                embedded_candidate_preds, graph_mask, sense_mask,
                predicate_mask, soft_tags, soft_index)

            gold_results = [
                gold_scores, gold_sense_logits, gold_sense_probs,
                gold_arc_tag_logits, gold_arc_tag_probs
            ]
        else:
            gold_results = [None, None, None, None, None]

        # Additionally feed randomly corrupted gold structures (negatives).
        iterations = self.corruption_iterations if self.training else 1
        if self.corruption_rate and soft_tags is not None and self.training:
            for i in range(iterations):
                c_soft_tags = self.corrupt_one_hot(soft_tags, graph_mask)
                c_soft_index = self.corrupt_index(soft_index, sense_mask)
                gold_arc_tag_logits, gold_arc_tag_probs, gold_sense_logits, gold_sense_probs, gold_scores, grad_to_arc_tag_probs = self.gold_feed(
                    predicate_representation, input_sense_logits,
                    extra_representation, input_arc_tag_logits,
                    embedded_candidate_preds, graph_mask, sense_mask,
                    predicate_mask, c_soft_tags, c_soft_index)

                c_arc_tag_logits_list.append(gold_arc_tag_logits)
                c_arc_tag_probs_list.append(gold_arc_tag_probs)
                c_sense_logits_list.append(gold_sense_logits)
                c_sense_probs_list.append(gold_sense_probs)
                c_scores_list.append(gold_scores)

        return (arc_tag_logits_list, arc_tag_probs_list, sense_logits_list, sense_probs_list, scores_list ),\
               (c_arc_tag_logits_list, c_arc_tag_probs_list, c_sense_logits_list, c_sense_probs_list, c_scores_list ), gold_results
    def one_iteration(self,
                      predicate_representation,
                      input_sense_logits,
                      extra_representation,
                      input_arc_tag_logits,
                      embedded_candidate_preds,
                      graph_mask,
                      sense_mask,
                      predicate_mask,
                      soft_tags,
                      soft_index,
                      arc_tag_probs,
                      sense_probs,
                      old_arc_tag_logits,
                      old_sense_logits,
                      old_scores=None):
        """One gradient-style refinement step with optional score gating.

        Scores the current soft graph, converts the scorer's feedback into a
        softmax-Jacobian-style update, optionally gates the update so it is
        only applied where the score improved over ``old_scores``, and adds
        the learn-rate-scaled, masked update onto the previous logits.

        Returns updated logits, decoded probabilities, the raw scores, and
        the arc-tag feedback tensor.
        """
        learn_rate = self.learn_rate.clamp(min=self.minimum_rate)
        # Expected predicate embedding under the current sense distribution.
        predicate_emb = (
            sense_probs.unsqueeze(2).matmul(embedded_candidate_preds)
        ).squeeze(2) * predicate_mask.unsqueeze(-1) + predicate_representation

        scores, grad_to_predicate_emb, grad_to_arc_tag_probs = self.graph_scorer(
            predicate_emb, extra_representation, arc_tag_probs * graph_mask,
            graph_mask)

        # Chain rule: node-embedding feedback -> candidate-sense feedback.
        grad_to_sense_probs = embedded_candidate_preds.matmul(
            grad_to_predicate_emb.unsqueeze(-1)).squeeze(-1)

        # Softmax-Jacobian-style transform: g <- p * (g - sum(g * p)).
        grad_to_arc_tag_probs = grad_to_arc_tag_probs * arc_tag_probs
        grad_to_arc_tag_probs = grad_to_arc_tag_probs - grad_to_arc_tag_probs.sum(
            -1, keepdim=True) * arc_tag_probs

        grad_to_sense_probs = grad_to_sense_probs * sense_probs
        grad_to_sense_probs = grad_to_sense_probs - grad_to_sense_probs.sum(
            -1, keepdim=True) * sense_probs

        if self.gating and old_scores is not None:
            # Only keep updates where the score increased since the previous
            # step, aggregated globally or per edge.
            if self.global_gating:
                delta = (scores - old_scores).sum(3, keepdim=True).sum(
                    2, keepdim=True)
            else:
                delta = (scores - old_scores).sum(-1, keepdim=True)
            update_mask = (delta > 0).float()  #.sigmoid()
            grad_to_arc_tag_probs = grad_to_arc_tag_probs * update_mask

            sense_update_mask = (delta.sum(1) > 0).float()  #.sigmoid()
            grad_to_sense_probs = sense_update_mask * grad_to_sense_probs

        # Additive (not interpolated) update onto the previous logits.
        input_arc_tag_logits = old_arc_tag_logits + learn_rate * grad_to_arc_tag_probs * graph_mask
        input_sense_logits = old_sense_logits + learn_rate * grad_to_sense_probs * sense_mask

        arc_tag_logits = input_arc_tag_logits
        sense_logits = input_sense_logits
        if self.denoise and self.training:
            # Gumbel noise centered away from the gold one-hot targets.
            arc_tag_logits = arc_tag_logits + self.gumbel_t * (_sample_gumbel(
                arc_tag_logits.size(), out=arc_tag_logits.new()) - soft_tags)
            sense_logits = sense_logits + self.gumbel_t * (_sample_gumbel(
                sense_logits.size(), out=sense_logits.new()) - soft_index)

        arc_tag_probs_soft = torch.nn.functional.softmax(arc_tag_logits,
                                                         dim=-1)

        sense_probs_soft = torch.nn.functional.softmax(sense_logits, dim=-1)

        # NOTE(review): `hard` presumably discretizes the distribution
        # (straight-through estimator) — confirm in its definition.
        arc_tag_probs = hard(
            arc_tag_probs_soft,
            graph_mask) if self.stright_through else arc_tag_probs_soft
        sense_probs = hard(
            sense_probs_soft,
            sense_mask) if self.stright_through else sense_probs_soft

        return input_arc_tag_logits, arc_tag_probs, input_sense_logits, sense_probs, scores, grad_to_arc_tag_probs