def _get_extractive_summary(self, reference_input_ids, reference_sentence_indicator, gumbel_output):
    """Decode the selected reference sentences and re-tokenize them as decoder labels."""
    tokenizer = self.tokenizers[reference_input_ids.device.index]
    ref_max = int(reference_sentence_indicator.max())
    if ref_max >= gumbel_output.size(1):
        # pad the selection vector so every sentence index in the reference is covered
        pad = torch.zeros(reference_sentence_indicator.size(0),
                          ref_max + 1 - gumbel_output.size(1),
                          device=gumbel_output.device)
        gumbel_output = torch.cat((gumbel_output, pad), -1)

    attention_mask = utils.convert_attention_mask(reference_sentence_indicator, gumbel_output).long().detach()
    # keep tokens of the selected sentences, replace everything else with the pad id
    extractive_summary_ids = reference_input_ids * attention_mask + (1 - attention_mask) * tokenizer.pad_token_id
    extractive_summary = tokenizer.batch_decode(extractive_summary_ids, skip_special_tokens=True)
    # print('CLEAN:', extractive_summary)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(extractive_summary, max_length=200, padding="max_length", truncation=True)

    # replace pad ids with -100 so they are ignored by the cross-entropy loss
    labels["input_ids"] = [
        [(l if l != tokenizer.pad_token_id else -100) for l in label]
        for label in labels["input_ids"]
    ]
    return (torch.tensor(labels['input_ids'], device=reference_input_ids.device),
            torch.tensor(labels['attention_mask'], device=reference_input_ids.device))
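# Illustrative sketch (not part of the model): _get_extractive_summary above relies on
# utils.convert_attention_mask to expand a sentence-level selection into a token-level
# mask. The helper below is a hypothetical stand-in that assumes the mask is a gather of
# the selection vector by each token's sentence index; the real utils implementation may differ.
def _sketch_convert_attention_mask():
    import torch

    def convert_attention_mask_sketch(sentence_indicator, gumbel_output):
        # sentence_indicator: (batch, seq_len), sentence id of each token
        # gumbel_output:      (batch, num_sentences), ~multi-hot sentence selection
        return torch.gather(gumbel_output, 1, sentence_indicator)

    sentence_indicator = torch.tensor([[0, 0, 1, 1, 2, 2]])
    gumbel_output = torch.tensor([[1.0, 0.0, 1.0]])  # sentences 0 and 2 selected
    token_mask = convert_attention_mask_sketch(sentence_indicator, gumbel_output)
    # token_mask == tensor([[1., 1., 0., 0., 1., 1.]]); tokens of unselected sentences are
    # then replaced by pad_token_id and dropped by batch_decode(skip_special_tokens=True).
    return token_mask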
def forward(
    self,
    input_ids=None,
    # real_input_ids=None,
    attention_mask=None,
    sentence_indicator=None,
    sentence_labels=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        # Convert encoder inputs in embeddings if needed
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    hidden_states = encoder_outputs[0]
    hidden_states_non_pad = attention_mask.unsqueeze(-1) * hidden_states

    # extract salient sentences
    if self.config.sequential_extraction:
        gumbel_output, all_sentence_logits = self.selection_loop(
            hidden_states, sentence_indicator, sentence_labels)
    else:
        gumbel_output, sentence_logits = self.single_extraction(
            hidden_states, sentence_indicator, sentence_labels)

    # expand the sentence-level selection to a token-level attention mask
    new_attention_mask = utils.convert_attention_mask(sentence_indicator, gumbel_output)
    masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)

    if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels)

    # If decoding with past key value states, only the last tokens
    # should be given as an input
    if past_key_values is not None:
        assert labels is None, "Decoder should not use cached key value states when training."
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)
        hidden_states = hidden_states.to(self.decoder.first_device)
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.decoder.first_device)
        if decoder_attention_mask is not None:
            decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

    if self.training:
        # randomly drop encoder positions during training
        attention_mask = self.attention_dropout(attention_mask)
        hidden_states = hidden_states * attention_mask.unsqueeze(-1)

    # re-encode the document with only the selected sentences visible
    extracted_sentence_encoding = self.encoder(
        input_ids=input_ids * new_attention_mask.long(),
        attention_mask=new_attention_mask)

    # if not self.training:
    #     if real_input_ids is None:
    #         real_input_ids = input_ids
    #     extracted_sentence_encoding = self.encoder(input_ids=real_input_ids*new_attention_mask.long(), attention_mask=new_attention_mask)
    #     hidden_states = extracted_sentence_encoding[0]
    #     attention_mask = new_attention_mask

    # Decode: attend to the full document while training, to the extracted sentences at inference
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states if self.training else masked_hidden_states,
        encoder_attention_mask=attention_mask if self.training else new_attention_mask,
        head_mask=decoder_head_mask,
        encoder_head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = decoder_outputs[0]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.encoder.first_device)
        self.lm_head = self.lm_head.to(self.encoder.first_device)
        self.sentence_classifier = self.sentence_classifier.to(self.encoder.first_device)
        sequence_output = sequence_output.to(self.lm_head.weight.device)

    if self.config.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)

    lm_logits = self.lm_head(sequence_output)

    loss = None
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        # similarity regularizer: pull the extracted-summary encoding toward the document encoding
        sim_loss_fct = nn.CosineSimilarity()
        pooled_hidden_states = hidden_states_non_pad.mean(1)  # detach()?
        pooled_encoded_summary = masked_hidden_states.mean(1) if not self.training else (
            extracted_sentence_encoding[0] * new_attention_mask.unsqueeze(-1)).mean(1)
        loss -= 2 * (sim_loss_fct(pooled_hidden_states, pooled_encoded_summary)).mean()

    # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
    if not return_dict:
        output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
        return ((loss,) + output) if loss is not None else output

    return ExtractorAbstractorOutput(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
        extracted_attentions=new_attention_mask,
        gumbel_output=None if self.training else gumbel_output,
    )
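# Illustrative sketch (not part of the model): the loss block above subtracts a scaled
# cosine similarity between the mean-pooled document encoding and the mean-pooled encoding
# of the extracted sentences. The toy shapes and the token mask below are made up purely
# to show the computation.
def _sketch_similarity_regularizer():
    import torch
    import torch.nn as nn

    B, L, H = 2, 6, 4
    hidden_states = torch.randn(B, L, H)                     # encoder states of the document
    token_mask = torch.tensor([[1., 1., 0., 0., 1., 1.],
                               [1., 1., 1., 0., 0., 0.]])    # tokens of selected sentences

    pooled_doc = hidden_states.mean(1)                                    # (B, H)
    pooled_summary = (hidden_states * token_mask.unsqueeze(-1)).mean(1)   # (B, H)

    sim = nn.CosineSimilarity(dim=-1)(pooled_doc, pooled_summary)         # (B,)
    # forward() adds this term (negated, weighted by 2) to the cross-entropy loss, so a
    # selection that resembles the whole document lowers the total loss.
    return -2 * sim.mean()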
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    sentence_indicator=None,
    sentence_labels=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        # Convert encoder inputs in embeddings if needed
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    hidden_states = encoder_outputs[0]

    # extract salient sentences
    if self.config.sequential_extraction:
        gumbel_output, all_sentence_logits = self.selection_loop(
            hidden_states, sentence_indicator, sentence_labels)
    else:
        gumbel_output, sentence_logits = self.single_extraction(
            hidden_states, sentence_indicator, sentence_labels)

    new_attention_mask = utils.convert_attention_mask(sentence_indicator, gumbel_output)
    masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)

    if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels)

    if self.training:
        # during training the decoder also reconstructs the input document
        reconstruction_decoder_input_ids = self._shift_right(input_ids)

    # If decoding with past key value states, only the last tokens
    # should be given as an input
    if past_key_values is not None:
        assert labels is None, "Decoder should not use cached key value states when training."
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)
        hidden_states = hidden_states.to(self.decoder.first_device)
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.decoder.first_device)
        if decoder_attention_mask is not None:
            decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

    # Decode, attending only to the extracted sentences
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=masked_hidden_states,
        encoder_attention_mask=new_attention_mask,
        head_mask=decoder_head_mask,
        encoder_head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    if self.training:
        # greedily decode a summary from the extracted sentences and encode it for the similarity loss
        summary = self.greedy_decode(input_ids, masked_hidden_states, new_attention_mask)
        encoded_summary = self.get_encoder()(summary, attention_mask=(summary != 0).long())

        # a second decoder pass reconstructs the document from the full encoder states
        reconstruction_decoder_output = self.decoder(
            input_ids=reconstruction_decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            encoder_head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    sequence_output = reconstruction_decoder_output[0] if self.training else decoder_outputs[0]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.encoder.first_device)
        self.lm_head = self.lm_head.to(self.encoder.first_device)
        self.sentence_classifier = self.sentence_classifier.to(self.encoder.first_device)
        sequence_output = sequence_output.to(self.lm_head.weight.device)

    if self.config.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)

    lm_logits = self.lm_head(sequence_output)

    loss = None
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        # the reconstruction target is the input itself, with padded positions ignored
        labels = input_ids * attention_mask + (-100) * (1 - attention_mask)
        if self.training:
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

            sim_loss_fct = nn.CosineSimilarity()
            pooled_hidden_states = hidden_states.mean(1) if self.config.mean_pool_similarity \
                else torch.max(hidden_states, 1)[0]
            pooled_encoded_summary = encoded_summary[0].mean(1) if self.config.mean_pool_similarity \
                else torch.max(encoded_summary[0], 1)[0]
            # pooled_encoded_summary = masked_hidden_states.mean(1)
            loss -= (sim_loss_fct(pooled_hidden_states, pooled_encoded_summary)).mean()
        else:
            loss = torch.tensor(0., device=lm_logits.device)

        # sentence_loss_fct = nn.BCEWithLogitsLoss()
        # loss = 0
        # if self.config.sequential_extraction:
        #     sentence_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        #     for i, logits in enumerate(all_sentence_logits):
        #         loss += sentence_loss_fct(logits, sentence_labels[:, i])
        # else:
        #     sentence_label_one_hot = utils.convert_one_hot(sentence_labels, sentence_logits.size(1)).float().detach()
        #     loss += 2 * -torch.mean(torch.sum(
        #         sentence_label_one_hot * torch.log_softmax(sentence_logits.squeeze(-1), dim=-1),
        #         dim=-1))
        #     loss += 2 * sentence_loss_fct(sentence_logits.squeeze(-1)[sentence_mask], sentence_label_one_hot[sentence_mask])
        #     loss += 2 * loss_fct(sentence_logits.view(-1, sentence_logits.size(-1)), sentence_label_one_hot.view(-1))

    # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
    if not return_dict:
        output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
        return ((loss,) + output) if loss is not None else output

    return ExtractorAbstractorOutput(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
        extracted_attentions=new_attention_mask,
        gumbel_output=None if self.training else gumbel_output,
    )
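# Illustrative sketch (not part of the model): in the training branch above the labels are
# rebuilt from input_ids so the decoder reconstructs the document, with padded positions set
# to -100 so CrossEntropyLoss(ignore_index=-100) skips them. Toy values only:
def _sketch_reconstruction_labels():
    import torch
    import torch.nn as nn

    input_ids = torch.tensor([[12, 7, 99, 0, 0]])
    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
    labels = input_ids * attention_mask + (-100) * (1 - attention_mask)
    # labels == tensor([[  12,    7,   99, -100, -100]])

    vocab_size = 128
    lm_logits = torch.randn(1, 5, vocab_size)
    loss = nn.CrossEntropyLoss(ignore_index=-100)(
        lm_logits.view(-1, vocab_size), labels.view(-1))
    return loss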
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    sentence_indicator=None,
    pmi_features=None,
    sentence_labels=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    inputs_embeds=None,
    head_mask=None,
    encoder_head_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    outputs = self.encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        inputs_embeds=inputs_embeds,
        head_mask=head_mask,
        encoder_head_mask=encoder_head_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    if self.config.sequential_extraction:
        gumbel_output, all_sentence_logits = self.selection_loop(
            hidden_states, sentence_indicator, sentence_labels, pmi_features=pmi_features)
    else:
        gumbel_output, sentence_logits = self.single_extraction(
            hidden_states, sentence_indicator, sentence_labels)

    new_attention_mask = utils.convert_attention_mask(sentence_indicator, gumbel_output).long()
    new_input_ids = input_ids * new_attention_mask + self.config.pad_token_id * (1 - new_attention_mask)
    new_hidden_states = self.encoder(new_input_ids, attention_mask=new_attention_mask)[0]
    masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

    return ExtractorModelOutput(
        last_hidden_state=outputs.last_hidden_state,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        cross_attentions=outputs.cross_attentions,
        input_ids=input_ids,
        new_attention_mask=new_attention_mask,
        masked_hidden_states=masked_hidden_states,
        new_hidden_states=new_hidden_states,
        gumbel_output=gumbel_output,
    )
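# Illustrative sketch (not part of the model): the extractor above re-encodes only the
# selected sentences by padding out every unselected token before the second encoder pass.
# pad_token_id is 0 for T5; the ids below are arbitrary.
def _sketch_reencode_inputs():
    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[101, 102, 103, 104, 105, 106]])
    new_attention_mask = torch.tensor([[1, 1, 0, 0, 1, 1]])  # token-level selection mask
    new_input_ids = input_ids * new_attention_mask + pad_token_id * (1 - new_attention_mask)
    # new_input_ids == tensor([[101, 102,   0,   0, 105, 106]])
    return new_input_ids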
def forward(
    self,
    input_ids=None,
    reference_input_ids=None,
    shuffled_input_ids=None,
    pmi_features=None,
    attention_mask=None,
    sentence_indicator=None,
    reference_sentence_indicator=None,
    shuffled_sentence_indicator=None,
    sentence_labels=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        # Convert encoder inputs in embeddings if needed
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, ExtractorModelOutput):
        encoder_outputs = ExtractorModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    if self.training or not isinstance(encoder_outputs, ExtractorModelOutput):
        hidden_states = encoder_outputs[0]
        # hidden_states_non_pad = attention_mask.unsqueeze(-1) * hidden_states
        tokenizer = self.tokenizers[hidden_states.device.index]

        # extract salient sentences
        if self.config.sequential_extraction:
            gumbel_output, gumbel_output1, all_sentence_logits = self.selection_loop(
                hidden_states, sentence_indicator, sentence_labels, pmi_features)
        else:
            gumbel_output, sentence_logits = self.single_extraction(
                hidden_states, sentence_indicator, sentence_labels)

        # during training the selection is applied to the shuffled document
        new_attention_mask = utils.convert_attention_mask(
            shuffled_sentence_indicator if self.training else sentence_indicator, gumbel_output)
        original_selected_attention_mask = utils.convert_attention_mask(sentence_indicator, gumbel_output)

        # masked_hidden_states = new_attention_mask.unsqueeze(-1) * detached_hidden_states
        masked_hidden_states = original_selected_attention_mask.unsqueeze(-1) * hidden_states
        non_masked_hidden_states = (1 - original_selected_attention_mask).unsqueeze(-1) * hidden_states

        selected_input_ids = input_ids * original_selected_attention_mask + (
            1 - original_selected_attention_mask) * tokenizer.pad_token_id
        # encoded_hidden_states = self.encoder(selected_input_ids.long(), attention_mask=original_selected_attention_mask.long())[0]

        new_attention_mask = new_attention_mask.long()
        new_input_ids = (shuffled_input_ids if self.training else input_ids) * new_attention_mask \
            + tokenizer.pad_token_id * (1 - new_attention_mask)
        # print("Shuffled", tokenizer.batch_decode(new_input_ids, skip_special_tokens=True))
        new_hidden_states = self.encoder(new_input_ids, attention_mask=new_attention_mask)[0]
    else:
        # reuse the extractor outputs produced on the first generation pass
        new_attention_mask = encoder_outputs.new_attention_mask
        new_hidden_states = encoder_outputs.new_hidden_states
        masked_hidden_states = encoder_outputs.masked_hidden_states
        gumbel_output = encoder_outputs.gumbel_output

    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)

    if self.training:
        # build decoder labels from the selected reference sentences
        labels, label_attention_mask = self._get_extractive_summary(
            reference_input_ids, reference_sentence_indicator, gumbel_output1)

    if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels)

    # If decoding with past key value states, only the last tokens
    # should be given as an input
    if past_key_values is not None:
        assert labels is None, "Decoder should not use cached key value states when training."
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)
        hidden_states = hidden_states.to(self.decoder.first_device)
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.decoder.first_device)
        if decoder_attention_mask is not None:
            decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=new_hidden_states,  # masked_hidden_states,
        encoder_attention_mask=new_attention_mask,
        head_mask=decoder_head_mask,
        encoder_head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = decoder_outputs[0]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.encoder.first_device)
        self.lm_head = self.lm_head.to(self.encoder.first_device)
        self.sentence_classifier = self.sentence_classifier.to(self.encoder.first_device)
        sequence_output = sequence_output.to(self.lm_head.weight.device)

    if self.config.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)

    lm_logits = self.lm_head(sequence_output)

    loss = None
    if labels is not None:
        # if self.training:
        #     gumbel = F.gumbel_softmax(lm_logits, hard=True, dim=-1)
        #     indices = torch.arange(gumbel.size(-1)).view(1, 1, -1).expand(gumbel.size(0), gumbel.size(1), -1).cuda()
        #     summary = (gumbel*indices).long().sum(-1)
        #
        #     encoded_summary = self.get_encoder()(summary, attention_mask=label_attention_mask)
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        sim_loss_fct = nn.CosineSimilarity()
        pooled_hidden_states = hidden_states.mean(1)  # detach()?
        pooled_encoded_summary = masked_hidden_states.mean(1)
        # pooled_encoded_summary = encoded_hidden_states.mean(1)
        pooled_non_masked_hidden_states = non_masked_hidden_states.mean(1)
        # pooled_encoded_summary = new_hidden_states.mean(1)
        # pooled_encoded_summary = encoded_summary[0].mean(1) if self.training else masked_hidden_states.mean(1)

        if self.config.use_max_margin_sim_loss:
            # reward similarity to the selected sentences, penalize similarity to the rest
            sim_loss = self.config.max_margin \
                - sim_loss_fct(pooled_hidden_states, pooled_encoded_summary) \
                + sim_loss_fct(pooled_hidden_states, pooled_non_masked_hidden_states)
            loss += sim_loss.mean()
        else:
            loss -= (sim_loss_fct(pooled_hidden_states, pooled_encoded_summary)).mean()

    if not return_dict:
        output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
        return ((loss,) + output) if loss is not None else output

    return ExtractorAbstractorOutput(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
        extracted_attentions=new_attention_mask,
        gumbel_output=None if self.training else gumbel_output,
    )
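# Illustrative sketch (not part of the model): the max-margin branch above
# (config.use_max_margin_sim_loss) rewards the selected sentences for being close to the
# document embedding and penalizes closeness of the unselected remainder. The margin and
# shapes below are arbitrary stand-ins for config.max_margin and real encoder states.
def _sketch_max_margin_similarity():
    import torch
    import torch.nn as nn

    margin = 1.0
    B, H = 2, 8
    pooled_doc = torch.randn(B, H)
    pooled_selected = torch.randn(B, H)
    pooled_unselected = torch.randn(B, H)

    cos = nn.CosineSimilarity(dim=-1)
    sim_loss = margin - cos(pooled_doc, pooled_selected) + cos(pooled_doc, pooled_unselected)
    # forward() adds sim_loss.mean() to the cross-entropy loss (no clamping at zero,
    # matching the code above).
    return sim_loss.mean()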