def single_extraction(self, hidden_states, sentence_indicator, sentence_labels):
    # Extract salient sentences: mean-pool encoder hidden states over each sentence span.
    sentences = []
    for i in range(sentence_indicator.max() + 1):
        mask = (sentence_indicator == i).long().cuda()
        sentences.append(
            torch.sum(hidden_states * mask.unsqueeze(-1), dim=1) /
            (mask.sum(dim=1).view(-1, 1) + 1e-12))
    sentences = torch.stack(sentences, dim=1)

    sentence_logits = self.sentence_classifier(sentences)
    sentence_logits = utils.mask_sentences(sentence_logits, sentence_indicator)

    if self.training:
        if self.config.teacher_forcing:
            # Use the oracle extractive labels directly.
            gumbel_output = utils.convert_one_hot(sentence_labels,
                                                  sentence_logits.size(1))
        else:
            # Differentiable top-k selection via Gumbel-softmax.
            gumbel_output = utils.gumbel_softmax_topk(
                sentence_logits.squeeze(-1), self.config.extraction_k)
    else:
        # At inference time, take the hard top-k sentences.
        gumbel_output = torch.topk(sentence_logits.squeeze(-1),
                                   self.config.extraction_k, dim=-1)[1]
        gumbel_output = utils.convert_one_hot(gumbel_output,
                                              sentence_logits.size(1))
    return gumbel_output, sentence_logits
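# Illustrative sketch (assumption, not taken from the project's utils module): at
# inference time, single_extraction converts the top-k sentence indices into a
# one-hot selection mask over sentences. A minimal stand-in for utils.convert_one_hot
# could look like the hypothetical helper below.
#
#   def convert_one_hot_sketch(indices, num_sentences):
#       # indices: (batch, k) top-k sentence positions -> (batch, num_sentences) mask
#       one_hot = torch.zeros(indices.size(0), num_sentences, device=indices.device)
#       return one_hot.scatter_(1, indices, 1.0)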
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    sentence_indicator=None,
    sentence_labels=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        # Convert encoder inputs in embeddings if needed
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    hidden_states = encoder_outputs[0]

    # extract salient sentences
    if self.config.sequential_extraction:
        gumbel_output, all_sentence_logits = self.selection_loop(
            hidden_states, sentence_indicator, sentence_labels)
    else:
        gumbel_output, sentence_logits = self.single_extraction(
            hidden_states, sentence_indicator, sentence_labels)

    new_attention_mask = utils.convert_attention_mask(
        sentence_indicator, gumbel_output)
    masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)

    if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels)

    # If decoding with past key value states, only the last tokens
    # should be given as an input
    if past_key_values is not None:
        assert labels is None, "Decoder should not use cached key value states when training."
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)
        hidden_states = hidden_states.to(self.decoder.first_device)
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.decoder.first_device)
        if decoder_attention_mask is not None:
            decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

    # Decode, attending only to the extracted (unmasked) sentences
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=masked_hidden_states,
        encoder_attention_mask=new_attention_mask,
        head_mask=decoder_head_mask,
        encoder_head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = decoder_outputs[0]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.encoder.first_device)
        self.lm_head = self.lm_head.to(self.encoder.first_device)
        self.sentence_classifier = self.sentence_classifier.to(self.encoder.first_device)
        sequence_output = sequence_output.to(self.lm_head.weight.device)

    if self.config.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)

    lm_logits = self.lm_head(sequence_output)

    loss = None
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        # sentence_loss_fct = nn.BCEWithLogitsLoss()
        # loss = 0

        if self.config.sequential_extraction:
            # Encourage the extracted sentences to stay close to the full document
            # in embedding space (mean-pooled cosine similarity).
            sim_loss = nn.CosineSimilarity()
            pooled_embedding = hidden_states.mean(1)
            pooled_extractive = masked_hidden_states.mean(1)
            loss -= sim_loss(pooled_embedding, pooled_extractive).mean()
            ## sentence_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            # for i, logits in enumerate(all_sentence_logits):
            #     loss += sentence_loss_fct(logits, sentence_labels[:, i])
        else:
            # Supervise the sentence extractor with the oracle sentence labels.
            sentence_label_one_hot = utils.convert_one_hot(
                sentence_labels, sentence_logits.size(1)).float().detach()
            loss += 2 * -torch.mean(
                torch.sum(sentence_label_one_hot * torch.log_softmax(
                    sentence_logits.squeeze(-1), dim=-1), dim=-1))
            # loss += 2*sentence_loss_fct(sentence_logits.squeeze(-1)[sentence_mask], sentence_label_one_hot[sentence_mask])
            # loss += 2*loss_fct(sentence_logits.view(-1, sentence_logits.size(-1)), sentence_label_one_hot.view(-1))

        # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

    if not return_dict:
        output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
        return ((loss,) + output) if loss is not None else output

    return ExtractorAbstractorOutput(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
        extracted_attentions=new_attention_mask,
        gumbel_output=None if self.training else gumbel_output,
    )
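# Usage sketch (assumption, not part of the original file): the snippet below only
# illustrates how this forward pass might be driven. The class name
# T5ExtractorAbstractor is hypothetical; `sentence_indicator` is assumed to assign a
# sentence id to every token, matching the pooling in single_extraction above.
#
#   model = T5ExtractorAbstractor.from_pretrained(...)   # hypothetical class name
#   outputs = model(
#       input_ids=input_ids,                    # (batch, seq_len)
#       attention_mask=attention_mask,          # (batch, seq_len)
#       sentence_indicator=sentence_indicator,  # (batch, seq_len) sentence id per token
#       sentence_labels=sentence_labels,        # oracle extractive labels (training)
#       labels=labels,                          # abstractive target token ids
#   )
#   loss, lm_logits = outputs.loss, outputs.logits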
def load_dataset(self, dataset_filepaths, parameters, annotator):
    '''
    dataset_filepaths : dictionary with keys 'train', 'valid', 'test'
    '''
    start_time = time.time()
    print('Load dataset... ', end='', flush=True)

    if parameters['do_split']:
        dataset_filepaths = self._do_split(parameters)

    all_pretrained_tokens = []
    if parameters['token_pretrained_embedding_filepath'] != '':
        all_pretrained_tokens = utils_nlp.load_tokens_from_pretrained_token_embeddings(
            parameters)
    if self.verbose:
        print("len(all_pretrained_tokens): {0}".format(len(all_pretrained_tokens)))

    # Load the pretraining dataset to ensure that index_to_label is compatible with the
    # pretrained model, and that token embeddings learned in the pretrained model are
    # loaded properly.
    all_tokens_in_pretraining_dataset = []

    self.UNK_TOKEN_INDEX = 0
    self.PADDING_TOKEN_INDEX = 1
    self.tokens_mapped_to_unk = []
    self.UNK = '_UNK_'
    self.PAD = '_PAD_'
    self.unique_labels = []
    labels = {}
    tokens = {}
    token_count = {}
    label_count = {}
    self.max_tokens = -1

    # Look for the maximum sequence length across the splits
    for dataset_type in ['train', 'valid', 'test']:
        max_tokens = self._find_max_length(
            dataset_filepaths.get(dataset_type, None),
            annotator,
            force_preprocessing=parameters['do_split'])
        if parameters['max_length_sentence'] == -1:
            self.max_tokens = max(self.max_tokens, max_tokens)
        else:
            if self.max_tokens == -1:
                self.max_tokens = max_tokens
            self.max_tokens = min(parameters['max_length_sentence'], self.max_tokens)

    for dataset_type in ['train', 'valid', 'test']:
        labels[dataset_type], tokens[dataset_type], token_count[dataset_type], \
            label_count[dataset_type] = self._parse_dataset(
                dataset_filepaths.get(dataset_type, None),
                annotator,
                force_preprocessing=parameters['do_split'],
                limit=self.max_tokens)
        if self.verbose:
            print("dataset_type: {0}".format(dataset_type))
        if self.verbose:
            print("len(token_count[dataset_type]): {0}".format(
                len(token_count[dataset_type])))

    token_count['all'] = {}
    for token in list(token_count['train'].keys()) + list(
            token_count['valid'].keys()) + list(token_count['test'].keys()):
        token_count['all'][token] = token_count['train'].get(token, 0) \
            + token_count['valid'].get(token, 0) \
            + token_count['test'].get(token, 0)

    for dataset_type in dataset_filepaths.keys():
        if self.verbose:
            print("dataset_type: {0}".format(dataset_type))
        if self.verbose:
            print("len(token_count[dataset_type]): {0}".format(
                len(token_count[dataset_type])))

    label_count['all'] = {}
    for character in list(label_count['train'].keys()) + list(
            label_count['valid'].keys()) + list(label_count['test'].keys()):
        label_count['all'][character] = label_count['train'].get(character, 0) \
            + label_count['valid'].get(character, 0) \
            + label_count['test'].get(character, 0)

    token_count['all'] = utils.order_dictionary(token_count['all'], 'value_key', reverse=True)
    label_count['all'] = utils.order_dictionary(label_count['all'], 'key', reverse=False)

    # Build the token vocabulary, remapping rare/unknown tokens to UNK when requested
    token_to_index = {}
    token_to_index[self.UNK] = self.UNK_TOKEN_INDEX
    token_to_index[self.PAD] = self.PADDING_TOKEN_INDEX
    iteration_number = 0
    number_of_unknown_tokens = 0
    if self.verbose:
        print("parameters['remap_unknown_tokens_to_unk']: {0}".format(
            parameters['remap_unknown_tokens_to_unk']))
    if self.verbose:
        print("len(token_count['train'].keys()): {0}".format(
            len(token_count['train'].keys())))
    for token, count in token_count['all'].items():
        if iteration_number == self.UNK_TOKEN_INDEX:
            iteration_number += 1
        if iteration_number == self.PADDING_TOKEN_INDEX:
            iteration_number += 1
        if parameters['remap_unknown_tokens_to_unk'] and (
                token_count['train'].get(token, 0) == 0
                or parameters['load_only_pretrained_token_embeddings']
        ) and not utils_nlp.is_token_in_pretrained_embeddings(
                token, all_pretrained_tokens, parameters
        ) and token not in all_tokens_in_pretraining_dataset:
            if self.verbose:
                print("token: {0}".format(token))
            if self.verbose:
                print("token.lower(): {0}".format(token.lower()))
            if self.verbose:
                print(r"re.sub('\d', '0', token.lower()): {0}".format(
                    re.sub(r'\d', '0', token.lower())))
            token_to_index[token] = self.UNK_TOKEN_INDEX
            number_of_unknown_tokens += 1
            self.tokens_mapped_to_unk.append(token)
        else:
            token_to_index[token] = iteration_number
            iteration_number += 1
    if self.verbose:
        print("number_of_unknown_tokens: {0}".format(number_of_unknown_tokens))

    infrequent_token_indices = []
    for token, count in token_count['train'].items():
        if 0 < count <= parameters['remap_to_unk_count_threshold']:
            infrequent_token_indices.append(token_to_index[token])
    if self.verbose:
        print("len(token_count['train']): {0}".format(len(token_count['train'])))
    if self.verbose:
        print("len(infrequent_token_indices): {0}".format(len(infrequent_token_indices)))

    # Build the label vocabulary
    label_to_index = {}
    iteration_number = 0
    for label, count in label_count['all'].items():
        label_to_index[label] = iteration_number
        iteration_number += 1
        self.unique_labels.append(label)

    if self.verbose:
        print('self.unique_labels: {0}'.format(self.unique_labels))
    if self.verbose:
        print('token_count[\'train\'][0:10]: {0}'.format(
            list(token_count['train'].items())[0:10]))
    token_to_index = utils.order_dictionary(token_to_index, 'value', reverse=False)
    if self.verbose:
        print('token_to_index: {0}'.format(token_to_index))
    index_to_token = utils.reverse_dictionary(token_to_index)
    if parameters['remap_unknown_tokens_to_unk'] == 1:
        index_to_token[self.UNK_TOKEN_INDEX] = self.UNK
    index_to_token[self.PADDING_TOKEN_INDEX] = self.PAD
    if self.verbose:
        print('index_to_token: {0}'.format(index_to_token))
    if self.verbose:
        print('label_count[\'train\']: {0}'.format(label_count['train']))
    label_to_index = utils.order_dictionary(label_to_index, 'value', reverse=False)
    if self.verbose:
        print('label_to_index: {0}'.format(label_to_index))
    index_to_label = utils.reverse_dictionary(label_to_index)
    if self.verbose:
        print('index_to_label: {0}'.format(index_to_label))
    if self.verbose:
        print('labels[\'train\'][0:10]: {0}'.format(labels['train'][0:10]))
    if self.verbose:
        print('tokens[\'train\'][0:10]: {0}'.format(tokens['train'][0:10]))

    # Map tokens and labels to their indices
    token_indices = {}
    label_indices = {}
    token_lengths = {}
    token_indices_padded = {}
    for dataset_type in dataset_filepaths.keys():
        token_indices[dataset_type] = []
        token_lengths[dataset_type] = []
        token_indices_padded[dataset_type] = []
        # Tokens
        for token_sequence in tokens[dataset_type]:
            token_indices[dataset_type].append(
                [token_to_index[token] for token in token_sequence])
            token_lengths[dataset_type].append(len(token_sequence))
        # Labels
        label_indices[dataset_type] = []
        for label in labels[dataset_type]:
            label_indices[dataset_type].append(label_to_index[label])

    # Pad tokens
    for dataset_type in dataset_filepaths.keys():
        token_indices_padded[dataset_type] = [
            utils.pad_list(temp_token_indices, self.max_tokens, self.PADDING_TOKEN_INDEX)
            for temp_token_indices in token_indices[dataset_type]
        ]

    if self.verbose:
        print('token_lengths[\'train\'][0:10]: {0}'.format(token_lengths['train'][0:10]))
    if self.verbose:
        print('token_indices[\'train\'][0][0:10]: {0}'.format(
            token_indices['train'][0][0:10]))
    if self.verbose:
        print('token_indices_padded[\'train\'][0][0:10]: {0}'.format(
            token_indices_padded['train'][0][0:10]))
    if self.verbose:
        print('label_indices[\'train\'][0:10]: {0}'.format(label_indices['train'][0:10]))

    self.token_to_index = token_to_index
    self.index_to_token = index_to_token
    self.token_indices = token_indices
    self.label_indices = label_indices
    self.token_indices_padded = token_indices_padded
    self.token_lengths = token_lengths
    self.tokens = tokens
    self.labels = labels
    self.index_to_label = index_to_label
    self.label_to_index = label_to_index
    if self.verbose:
        print("len(self.token_to_index): {0}".format(len(self.token_to_index)))
    if self.verbose:
        print("len(self.index_to_token): {0}".format(len(self.index_to_token)))
    self.number_of_classes = max(self.index_to_label.keys()) + 1
    self.vocabulary_size = max(self.index_to_token.keys()) + 1
    if self.verbose:
        print("self.number_of_classes: {0}".format(self.number_of_classes))
    if self.verbose:
        print("self.vocabulary_size: {0}".format(self.vocabulary_size))

    self.infrequent_token_indices = infrequent_token_indices

    # Binarize labels (one-hot vectors over the label set)
    label_vector_indices = {}
    for dataset_type, labels in label_indices.items():
        label_vector_indices[dataset_type] = []
        for label in labels:
            label_vector_indices[dataset_type].append(
                utils.convert_one_hot(label, self.number_of_classes))
    self.label_vector_indices = label_vector_indices

    elapsed_time = time.time() - start_time
    print('done ({0:.2f} seconds)'.format(elapsed_time))
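# Usage sketch (assumption, not part of the original file): load_dataset is typically
# called once per experiment with a parameters dictionary. The key names below are the
# ones read in the method above; the class name Dataset, the annotator object, and the
# file paths are illustrative placeholders only.
#
#   parameters = {
#       'do_split': False,
#       'token_pretrained_embedding_filepath': '',   # '' disables pretrained embeddings
#       'max_length_sentence': -1,                   # -1 keeps the longest sequence found
#       'remap_unknown_tokens_to_unk': 1,
#       'load_only_pretrained_token_embeddings': False,
#       'remap_to_unk_count_threshold': 1,
#   }
#   dataset = Dataset(verbose=False)                 # hypothetical class name
#   dataset.load_dataset(
#       dataset_filepaths={'train': ..., 'valid': ..., 'test': ...},
#       parameters=parameters,
#       annotator=annotator)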