def _get_extractive_summary(self, reference_input_ids, reference_sentence_indicator, gumbel_output, max_length=200):
        """Build decoder labels from the sentences selected in the reference document.

        ``utils.convert_attention_mask`` turns the sentence selection ``gumbel_output`` into a
        token-level mask over ``reference_input_ids``. Unselected tokens are replaced by the pad
        token, and the result is decoded to text. That text is re-tokenized in target mode, and
        pad positions in the labels are set to -100.

        Args:
            reference_input_ids: token ids of the reference document. Its device index selects
                the tokenizer from ``self.tokenizers``, and the returned tensors are placed on
                its device.
            reference_sentence_indicator: presumably the per-token sentence index of the
                reference document (its max is compared against ``gumbel_output.size(1)``).
            gumbel_output: per-sentence selection weights. It is right-padded with zero columns
                when the reference has more sentences than it has columns.
            max_length: length every label sequence is padded/truncated to. The default of 200
                is the previously hard-coded value.

        Returns:
            ``(label_ids, label_attention_mask)`` as tensors on ``reference_input_ids.device``.
        """
        device = reference_input_ids.device
        tokenizer = self.tokenizers[device.index]

        # The selection must have at least ref_max + 1 columns to cover every sentence index.
        ref_max = int(reference_sentence_indicator.max())
        missing = ref_max + 1 - gumbel_output.size(1)
        if missing > 0:
            # new_zeros keeps gumbel_output's device and dtype. The old float32 `.cuda()` pad
            # ignored both, which breaks on CPU and can break torch.cat under mixed precision.
            pad = gumbel_output.new_zeros(gumbel_output.size(0), missing)
            gumbel_output = torch.cat((gumbel_output, pad), -1)

        attention_mask = utils.convert_attention_mask(reference_sentence_indicator, gumbel_output).long().detach()
        # Keep the selected tokens and replace everything else with the pad token.
        extractive_summary_ids = reference_input_ids * attention_mask + (1 - attention_mask) * tokenizer.pad_token_id
        extractive_summary = tokenizer.batch_decode(extractive_summary_ids, skip_special_tokens=True)

        with tokenizer.as_target_tokenizer():
            labels = tokenizer(extractive_summary, max_length=max_length, padding="max_length", truncation=True)

        # -100 is the ignore_index used by the cross-entropy loss, so padding is not scored.
        label_ids = [
            [(tok if tok != tokenizer.pad_token_id else -100) for tok in label]
            for label in labels["input_ids"]
        ]
        return (
            torch.tensor(label_ids, device=device),
            torch.tensor(labels["attention_mask"], device=device),
        )
# Example no. 2 (rating: 0)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        sentence_indicator=None,
        sentence_labels=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the document, select salient sentences, and decode.

        The encoder states go to ``self.selection_loop`` (when ``config.sequential_extraction``)
        or to ``self.single_extraction``, which return ``gumbel_output``.
        ``utils.convert_attention_mask`` turns it into ``new_attention_mask``, a token mask
        covering the selected sentences.

        In training, the decoder attends to the full encoder states after
        ``self.attention_dropout`` is applied to the attention mask. Otherwise it attends only
        to the selected sentences.

        When ``labels`` are given, the loss is the LM cross-entropy minus twice the mean cosine
        similarity between two mean-pooled encodings:
        - the document encoding, with padding zeroed;
        - the selected sentences: re-encoded on their own in training, the masked encoder
          states otherwise.

        ``input_ids`` is required in training because the selection is re-encoded from it.

        Returns:
            ``ExtractorAbstractorOutput`` when ``return_dict`` is true. Otherwise the tuple
            ``([loss,] lm_logits, *decoder_outputs[1:], *encoder_outputs)``.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass).
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]
        if attention_mask is None:
            # No mask means "every position is real". The mask is multiplied into the hidden
            # states below, so build it explicitly (None used to crash here).
            attention_mask = torch.ones(
                hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device)
        hidden_states_non_pad = attention_mask.unsqueeze(-1) * hidden_states

        # Extract salient sentences. The per-sentence logits are not used by this method.
        if self.config.sequential_extraction:
            gumbel_output, _ = self.selection_loop(
                hidden_states, sentence_indicator, sentence_labels)
        else:
            gumbel_output, _ = self.single_extraction(
                hidden_states, sentence_indicator, sentence_labels)

        # Token-level mask over the selected sentences.
        new_attention_mask = utils.convert_attention_mask(sentence_indicator, gumbel_output)
        masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # Get decoder inputs by shifting the LM labels to the right.
            decoder_input_ids = self._shift_right(labels)

        # When decoding with cached key/value states, only the last tokens are fed in.
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Set device for model parallelism.
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        if self.training:
            attention_mask = self.attention_dropout(attention_mask)
            hidden_states = hidden_states * attention_mask.unsqueeze(-1)
            # Re-encode only the selected tokens for the similarity loss.
            extracted_sentence_encoding = self.encoder(
                input_ids=input_ids * new_attention_mask.long(),
                attention_mask=new_attention_mask)
            decoder_context, decoder_context_mask = hidden_states, attention_mask
        else:
            decoder_context, decoder_context_mask = masked_hidden_states, new_attention_mask
            if self.model_parallel:
                # The eval context was previously left on the encoder device.
                decoder_context = decoder_context.to(self.decoder.first_device)
                decoder_context_mask = decoder_context_mask.to(self.decoder.first_device)

        # Decode.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=decoder_context,
            encoder_attention_mask=decoder_context_mask,
            head_mask=decoder_head_mask,
            encoder_head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism.
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            self.sentence_classifier = self.sentence_classifier.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab.
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

            sim_loss_fct = nn.CosineSimilarity()
            pooled_hidden_states = hidden_states_non_pad.mean(1)
            if self.training:
                pooled_encoded_summary = (
                    extracted_sentence_encoding[0] * new_attention_mask.unsqueeze(-1)).mean(1)
            else:
                pooled_encoded_summary = masked_hidden_states.mean(1)
            # Reward similarity between the document and the extracted sentences.
            loss -= 2 * sim_loss_fct(pooled_hidden_states, pooled_encoded_summary).mean()

            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            # A ModelOutput is a dict subclass and cannot be added to a tuple directly.
            encoder_tuple = (encoder_outputs.to_tuple()
                             if hasattr(encoder_outputs, "to_tuple") else tuple(encoder_outputs))
            output = (lm_logits, ) + decoder_outputs[1:] + encoder_tuple
            return ((loss, ) + output) if loss is not None else output

        return ExtractorAbstractorOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            extracted_attentions=new_attention_mask,
            gumbel_output=None if self.training else gumbel_output)
# Example no. 3 (rating: 0)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        sentence_indicator=None,
        sentence_labels=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode, extract sentences, and train with an input-reconstruction objective.

        The summary decoder always attends only to the selected sentences. In training,
        ``lm_logits`` come from a second decoder pass that reconstructs ``input_ids`` from the
        full encoder states. When ``labels`` are given, the training loss is:
        - the reconstruction cross-entropy (padding ignored),
        - minus the mean cosine similarity between pooled encoder states and a pooled
          encoding of ``self.greedy_decode``'s summary (mean or max pooling per
          ``config.mean_pool_similarity``).

        Outside training, the logits come from the summary decoder and the loss is a
        zero scalar.

        Returns:
            ``ExtractorAbstractorOutput`` when ``return_dict`` is true. Otherwise the tuple
            ``([loss,] lm_logits, *decoder_outputs[1:], *encoder_outputs)``.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass).
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]
        if attention_mask is None:
            # No mask means "every position is real". The mask is used arithmetically for the
            # reconstruction labels, so build it explicitly (None used to crash there).
            attention_mask = torch.ones(
                hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device)

        # Extract salient sentences. The per-sentence logits are not used by this method.
        if self.config.sequential_extraction:
            gumbel_output, _ = self.selection_loop(
                hidden_states, sentence_indicator, sentence_labels)
        else:
            gumbel_output, _ = self.single_extraction(
                hidden_states, sentence_indicator, sentence_labels)

        # Token-level mask over the selected sentences.
        new_attention_mask = utils.convert_attention_mask(sentence_indicator, gumbel_output)
        masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # Get decoder inputs by shifting the LM labels to the right.
            decoder_input_ids = self._shift_right(labels)

        if self.training:
            # Teacher-forced inputs for reconstructing the source document.
            reconstruction_decoder_input_ids = self._shift_right(input_ids)

        # When decoding with cached key/value states, only the last tokens are fed in.
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        decoder_context = masked_hidden_states
        decoder_context_mask = new_attention_mask

        # Set device for model parallelism.
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            # The summary decoder's context was previously left on the encoder device.
            decoder_context = decoder_context.to(self.decoder.first_device)
            decoder_context_mask = decoder_context_mask.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode the summary from the selected sentences only.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=decoder_context,
            encoder_attention_mask=decoder_context_mask,
            head_mask=decoder_head_mask,
            encoder_head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.training:
            summary = self.greedy_decode(input_ids, masked_hidden_states, new_attention_mask)
            # NOTE(review): `summary != 0` assumes pad id 0 (as in T5); confirm for other tokenizers.
            encoded_summary = self.get_encoder()(summary, attention_mask=(summary != 0).long())

            # Reconstruct the input document from the full encoder states. The summary
            # decoder's mask, embeddings and cache have the summary's length, not the input's.
            # Passing them here mismatched shapes, and embeds plus input_ids makes the decoder
            # raise, so they are not forwarded.
            reconstruction_decoder_output = self.decoder(
                input_ids=reconstruction_decoder_input_ids,
                attention_mask=None,
                inputs_embeds=None,
                past_key_values=None,
                encoder_hidden_states=hidden_states,
                encoder_attention_mask=attention_mask,
                head_mask=decoder_head_mask,
                encoder_head_mask=head_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            sequence_output = reconstruction_decoder_output[0]
        else:
            sequence_output = decoder_outputs[0]

        # Set device for model parallelism.
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            self.sentence_classifier = self.sentence_classifier.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab.
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            if self.training:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
                # The reconstruction target is the input itself, with padding set to ignore_index.
                reconstruction_labels = input_ids * attention_mask + (-100) * (1 - attention_mask)
                loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                                reconstruction_labels.view(-1))

                sim_loss_fct = nn.CosineSimilarity()
                if self.config.mean_pool_similarity:
                    pooled_hidden_states = hidden_states.mean(1)
                    pooled_encoded_summary = encoded_summary[0].mean(1)
                else:
                    pooled_hidden_states = torch.max(hidden_states, 1)[0]
                    pooled_encoded_summary = torch.max(encoded_summary[0], 1)[0]
                loss -= sim_loss_fct(pooled_hidden_states, pooled_encoded_summary).mean()
            else:
                # Device-agnostic zero loss (was a hard-coded `.cuda()`).
                loss = torch.zeros((), device=lm_logits.device)

        # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            # A ModelOutput is a dict subclass and cannot be added to a tuple directly.
            encoder_tuple = (encoder_outputs.to_tuple()
                             if hasattr(encoder_outputs, "to_tuple") else tuple(encoder_outputs))
            output = (lm_logits, ) + decoder_outputs[1:] + encoder_tuple
            return ((loss, ) + output) if loss is not None else output

        return ExtractorAbstractorOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            extracted_attentions=new_attention_mask,
            gumbel_output=None if self.training else gumbel_output)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        sentence_indicator=None,
        pmi_features=None,
        sentence_labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        encoder_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the document, select sentences, and re-encode only the selected ones.

        ``input_ids`` is required, because the selected tokens are rebuilt from it.
        ``return_dict`` is accepted for interface compatibility. The result is always an
        ``ExtractorModelOutput``, and the first-pass encoder is always asked for a
        ``ModelOutput`` because its fields are read by attribute.

        Returns:
            ``ExtractorModelOutput`` carrying the first-pass encoder outputs plus:
            - ``new_attention_mask``: long mask of the tokens in the selected sentences;
            - ``masked_hidden_states``: first-pass states zeroed outside the selection;
            - ``new_hidden_states``: the selection re-encoded on its own, with other tokens
              replaced by ``config.pad_token_id``;
            - ``gumbel_output``: the sentence selection.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            encoder_head_mask=encoder_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Fields are read by attribute below. With return_dict=False the encoder returned a
            # plain tuple and `outputs.last_hidden_state` raised AttributeError.
            return_dict=True,
        )

        hidden_states = outputs[0]

        # The per-sentence logits are not used by this method.
        if self.config.sequential_extraction:
            gumbel_output, _ = self.selection_loop(
                hidden_states,
                sentence_indicator,
                sentence_labels,
                pmi_features=pmi_features)
        else:
            gumbel_output, _ = self.single_extraction(
                hidden_states, sentence_indicator, sentence_labels)

        new_attention_mask = utils.convert_attention_mask(
            sentence_indicator, gumbel_output).long()
        # Keep the selected tokens and replace the rest with the pad token before re-encoding.
        new_input_ids = input_ids * new_attention_mask + self.config.pad_token_id * (
            1 - new_attention_mask)

        new_hidden_states = self.encoder(new_input_ids,
                                         attention_mask=new_attention_mask)[0]
        masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

        return ExtractorModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
            input_ids=input_ids,
            new_attention_mask=new_attention_mask,
            masked_hidden_states=masked_hidden_states,
            new_hidden_states=new_hidden_states,
            gumbel_output=gumbel_output)
    def forward(
        self,
        input_ids=None,
        reference_input_ids=None,
        shuffled_input_ids=None,
        pmi_features=None,
        attention_mask=None,
        sentence_indicator=None,
        reference_sentence_indicator=None,
        shuffled_sentence_indicator=None,
        sentence_labels=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Extract sentences, re-encode them, and decode an abstractive summary.

        Sentence selection:
        - In training, or when ``encoder_outputs`` carries no precomputed extraction, sentences
          are selected from the encoder states. The selected tokens are re-encoded, taken from
          ``shuffled_input_ids`` / ``shuffled_sentence_indicator`` in training and from
          ``input_ids`` / ``sentence_indicator`` otherwise.
        - Outside training, a precomputed extraction in an ``ExtractorModelOutput`` is reused.

        In training, ``labels`` are replaced by the extractive summary that
        ``self._get_extractive_summary`` builds from the reference document. The decoder
        always attends to the re-encoded selection.

        Loss (when ``labels`` are given): cross-entropy plus a cosine-similarity term between
        the mean-pooled encoder states and the mean-pooled selected states. It is either a
        margin term against the unselected states (``config.use_max_margin_sim_loss``) or minus
        the similarity.

        Returns:
            ``ExtractorAbstractorOutput`` when ``return_dict`` is true. Otherwise the tuple
            ``([loss,] lm_logits, *decoder_outputs[1:], *encoder_outputs)``.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass).
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, ExtractorModelOutput):
            encoder_outputs = ExtractorModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Reuse a stored extraction only outside training and only if one is actually present.
        # A tuple wrapped just above has new_attention_mask=None. The old isinstance-only check
        # treated it as precomputed and fed None masks to the decoder.
        use_precomputed = (
            not self.training
            and isinstance(encoder_outputs, ExtractorModelOutput)
            and encoder_outputs.new_attention_mask is not None
        )

        if not use_precomputed:
            hidden_states = encoder_outputs[0]
            tokenizer = self.tokenizers[hidden_states.device.index]

            # Extract salient sentences. The per-sentence logits are not used here.
            if self.config.sequential_extraction:
                gumbel_output, gumbel_output1, _ = self.selection_loop(
                    hidden_states, sentence_indicator, sentence_labels, pmi_features)
            else:
                gumbel_output, _ = self.single_extraction(
                    hidden_states, sentence_indicator, sentence_labels)
                # NOTE(review): single extraction yields one selection, so it is reused for the
                # reference summary. gumbel_output1 used to be undefined here (NameError in
                # training); confirm this fallback is the intended semantics.
                gumbel_output1 = gumbel_output

            selection_indicator = shuffled_sentence_indicator if self.training else sentence_indicator
            new_attention_mask = utils.convert_attention_mask(selection_indicator, gumbel_output).long()

            # Masks over the unshuffled document, used for the similarity loss.
            original_selected_attention_mask = utils.convert_attention_mask(sentence_indicator, gumbel_output)
            masked_hidden_states = original_selected_attention_mask.unsqueeze(-1) * hidden_states
            non_masked_hidden_states = (1 - original_selected_attention_mask).unsqueeze(-1) * hidden_states

            # Re-encode only the selected tokens, replacing the rest with the pad token.
            source_ids = shuffled_input_ids if self.training else input_ids
            new_input_ids = source_ids * new_attention_mask + tokenizer.pad_token_id * (1 - new_attention_mask)
            new_hidden_states = self.encoder(new_input_ids, attention_mask=new_attention_mask)[0]
        else:
            hidden_states = encoder_outputs.last_hidden_state
            new_attention_mask = encoder_outputs.new_attention_mask
            new_hidden_states = encoder_outputs.new_hidden_states
            masked_hidden_states = encoder_outputs.masked_hidden_states
            gumbel_output = encoder_outputs.gumbel_output
            # Previously undefined on this path (NameError when labels were given). Built as
            # the complement of the selection, like the branch above.
            non_masked_hidden_states = (1 - new_attention_mask).unsqueeze(-1) * hidden_states

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if self.training:
            labels, _ = self._get_extractive_summary(
                reference_input_ids, reference_sentence_indicator, gumbel_output1)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # Get decoder inputs by shifting the LM labels to the right.
            decoder_input_ids = self._shift_right(labels)

        # When decoding with cached key/value states, only the last tokens are fed in.
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Set device for model parallelism. The decoder reads new_hidden_states and
        # new_attention_mask. hidden_states stays put: it is only used by the similarity loss,
        # next to masked_hidden_states on the encoder device.
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            new_hidden_states = new_hidden_states.to(self.decoder.first_device)
            new_attention_mask = new_attention_mask.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode from the re-encoded selection.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=new_hidden_states,
            encoder_attention_mask=new_attention_mask,
            head_mask=decoder_head_mask,
            encoder_head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism.
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            self.sentence_classifier = self.sentence_classifier.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab.
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

            sim_loss_fct = nn.CosineSimilarity()
            pooled_hidden_states = hidden_states.mean(1)
            pooled_encoded_summary = masked_hidden_states.mean(1)
            pooled_non_masked_hidden_states = non_masked_hidden_states.mean(1)
            if self.config.use_max_margin_sim_loss:
                # NOTE(review): this margin term is not clamped at 0 (no hinge); confirm intent.
                sim_loss = (self.config.max_margin
                            - sim_loss_fct(pooled_hidden_states, pooled_encoded_summary)
                            + sim_loss_fct(pooled_hidden_states, pooled_non_masked_hidden_states))
                loss += sim_loss.mean()
            else:
                loss -= sim_loss_fct(pooled_hidden_states, pooled_encoded_summary).mean()

        if not return_dict:
            # A ModelOutput (e.g. ExtractorModelOutput) is a dict subclass and cannot be added
            # to a tuple directly.
            encoder_tuple = (encoder_outputs.to_tuple()
                             if hasattr(encoder_outputs, "to_tuple") else tuple(encoder_outputs))
            output = (lm_logits,) + decoder_outputs[1:] + encoder_tuple
            return ((loss,) + output) if loss is not None else output

        return ExtractorAbstractorOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            extracted_attentions=new_attention_mask,
            gumbel_output=None if self.training else gumbel_output
        )