def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=True,
):
    all_hidden_states = () if output_hidden_states else None
    all_self_attentions = () if output_attentions else None
    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
    next_decoder_cache = () if use_cache else None

    for i, layer_module in enumerate(self.layer):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        layer_head_mask = head_mask[i] if head_mask is not None else None
        past_key_value = past_key_values[i] if past_key_values is not None else None

        layer_outputs = layer_module(
            hidden_states,
            attention_mask,
            layer_head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        hidden_states = layer_outputs[0]

        # Collect the per-layer cache and attention maps so the returned output is populated.
        if use_cache:
            next_decoder_cache += (layer_outputs[-1],)
        if output_attentions:
            all_self_attentions = all_self_attentions + (layer_outputs[1],)
            if self.config.add_cross_attention:
                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_decoder_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
        cross_attentions=all_cross_attentions,
    )
def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: # logger.warn( # "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " # "`use_cache=False`..." # ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, )
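# A minimal, self-contained sketch of the gradient-checkpointed layer loop used in the
# encoder forwards above. ToyLayer, the hidden size, and the layer count are invented for
# illustration; the point is only how torch.utils.checkpoint.checkpoint wraps each layer
# call through a small closure so extra (non-tensor) arguments can be captured.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyLayer(nn.Module):
    def __init__(self, d_model=64):
        super().__init__()
        self.ff = nn.Linear(d_model, d_model)

    def forward(self, hidden_states, attention_mask=None):
        # Return a tuple, mirroring the (hidden_states, ...) convention above.
        return (torch.relu(self.ff(hidden_states)),)


class ToyEncoder(nn.Module):
    def __init__(self, num_layers=4, d_model=64, gradient_checkpointing=True):
        super().__init__()
        self.layer = nn.ModuleList(ToyLayer(d_model) for _ in range(num_layers))
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, hidden_states, attention_mask=None):
        for layer_module in self.layer:
            if self.gradient_checkpointing and self.training:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward

                # Activations inside the layer are freed now and recomputed in backward.
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module), hidden_states, attention_mask
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask)
            hidden_states = layer_outputs[0]
        return hidden_states


if __name__ == "__main__":
    enc = ToyEncoder().train()
    x = torch.randn(2, 10, 64, requires_grad=True)
    enc(x).sum().backward()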
def forward( self, input_ids=None, past_key_values=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) batch_size = input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] batch_size = inputs_embeds.shape[0] else: raise ValueError( "You have to specify either input_ids or inputs_embeds") if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) if past_key_values is None: past_length = 0 past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) # Attention mask. if attention_mask is not None: assert batch_size > 0, "batch_size has to be defined and > 0" attention_mask = attention_mask.view(batch_size, -1) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. attention_mask = attention_mask[:, None, None, :] # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. 
attention_mask = attention_mask.to( dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size( ) encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_attention_mask = self.invert_attention_mask( encoder_attention_mask) else: encoder_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # head_mask has shape n_layer x batch x n_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1), ) presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure layer_past is on same device as hidden_states (might not be correct) if layer_past is not None: layer_past = tuple( past_state.to(hidden_states.device) for past_state in layer_past) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) if isinstance(head_mask, torch.Tensor): head_mask = head_mask.to(hidden_states.device) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: use_cache = False def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache, output_attentions) return custom_forward outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, None, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, ) else: outputs = block( hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, ) hidden_states = outputs[0] if use_cache is True: presents = presents + (outputs[1], ) if output_attentions: all_self_attentions = all_self_attentions + ( outputs[2 if use_cache else 1], ) all_cross_attentions = all_cross_attentions + ( outputs[3 if use_cache else 2], ) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: for k, v in self.device_map.items(): if i == v[-1] and "cuda:" + str(k) != self.last_device: hidden_states = hidden_states.to("cuda:" + str(k + 1)) hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(*output_shape) # Add last hidden state if 
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
            if v is not None
        )

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
        cross_attentions=all_cross_attentions,
    )
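# Hedged sketch of the mask and position-id bookkeeping in the GPT-2-style forward above.
# The batch size, past length, and padding pattern are invented for illustration.
import torch

batch_size, past_length, new_tokens = 2, 3, 5

# 1 = attend, 0 = padding; the mask covers past + current tokens.
attention_mask = torch.ones(batch_size, past_length + new_tokens)
attention_mask[0, :2] = 0  # pretend the first example starts with 2 padding positions

# Broadcastable additive mask: 0.0 where we attend, -10000.0 where masked,
# added to the raw attention scores before the softmax.
additive_mask = attention_mask[:, None, None, :]
additive_mask = (1.0 - additive_mask) * -10000.0
print(additive_mask.shape)  # torch.Size([2, 1, 1, 8])

# Position ids continue from past_length when a key/value cache is used.
position_ids = torch.arange(past_length, past_length + new_tokens, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
print(position_ids[0])  # tensor([3, 4, 5, 6, 7])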
def rerank(infile: AnyStr, outfile: AnyStr, extractive_reader_outfile: AnyStr, model_sd: dict, config: dict, device: torch.device, gt_file: Union[None, AnyStr] = None): """ :param infile: reader input (ranker/reranker outputs) :param outfile: where to save re-scored outputs :param extractive_reader_outfile: outputs from extractive reader to re-rank :param model_sd: :param config: :param device: :param gt_file: file with ground truth answers used for online validation, Optional :return: method always switches model into eval state """ if gt_file is not None: hits = 0 rr_hits = 0 gen_hits = 0 total = 0 gt_file_path = os.path.join(gt_file["directory"], gt_file["name"]) if gt_file_path.endswith(".zip"): gt_file_path = gt_file_path[:-len(".zip")] with jsonlines.open(gt_file_path, mode="r") as reader: correct_answers = dict((OpenQA_WikiPassages.get_qa_from_example(e) for e in reader)) logging.info(f"Re-ranking {len(correct_answers)} data samples") reader_tokenizer = FIDTrainer.init_tokenizer(config) generative_reader = T5FusionInDecoder.from_pretrained(config, do_not_download_weights=True) generative_reader.resize_token_embeddings(len(reader_tokenizer)) if "state_dict" in model_sd: model_sd = model_sd["state_dict"] generative_reader.load_state_dict(model_sd) generative_reader = generative_reader.float().to(device) # make sure 32bit precision model: T5FusionInDecoder = generative_reader.eval() db = PassageDB(db_path=config['pass_database']) fields = FusionInDecoderDataset.prepare_fields( pad_t=reader_tokenizer.pad_token_id) include_passage_masks = config["fusion_strategy"] == "passages" test = FusionInDecoderDataset(infile, fields=fields, tokenizer=reader_tokenizer, database=db, transformer=config["reader_transformer_type"], cache_dir=config["data_cache_dir"], max_len=config.get("reader_max_input_length", None), context_length=config["context_length"], include_passage_masks=include_passage_masks, preprocessing_truncation=config["preprocessing_truncation"], use_cache=False, is_training=False) test_iter = Iterator(test, sort=False, shuffle=False, batch_size=1, repeat=False, device=device) it = tqdm(enumerate(test_iter), total=len(test_iter.data()) // test_iter.batch_size + 1) # load extractive reader's top-K predictions with jsonlines.open(extractive_reader_outfile) as reader_outputs: ext_reader_predictions = {e['raw_question']: e for e in reader_outputs} for i, b in it: if gt_file is not None: ####################### Compute extractive_reader's hit############################### total += 1 original_max_i = argmax(ext_reader_predictions[b.question[0]]['reader_scores']) original_prediction = ext_reader_predictions[b.question[0]]["answers"][original_max_i] hits += int(eval_utils.metric_max_over_ground_truths( metric_fn=eval_utils.exact_match_score, prediction=original_prediction, ground_truths=correct_answers[b.question[0]])) ###################################################################################### # encode passages concatenated_encoder_output, concatenated_encoder_attention = model(input_ids=b.src[0], attention_mask=b.src_mask[0], encode_only=True) # tokenize & numericalize answers from extractive reader tokenized_answers = FusionInDecoderDataset.assemble_target_sequences( answers=ext_reader_predictions[b.question[0]]["answers"], tokenizer=reader_tokenizer) answer_masks = [[1] * len(a) for a in tokenized_answers] # rather do this in for cycle, to not further increase memory complexity scores = [] for ans, mask in zip(tokenized_answers, answer_masks): tensorized_answer = 
torch.LongTensor(ans).to(device).unsqueeze(0) tensorized_answer_mask = torch.LongTensor(mask).to(device).unsqueeze(0) b.doc_mask = b.doc_mask[0] if include_passage_masks else None concatenated_encoder_output_copy = BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=copy.deepcopy(concatenated_encoder_output['last_hidden_state'])) concatenated_encoder_attention_copy = copy.deepcopy(concatenated_encoder_attention) lm_logits = model(input_ids=None, attention_mask=concatenated_encoder_attention_copy, passage_mask=b.doc_mask, encoder_outputs=concatenated_encoder_output_copy, decoder_input_ids=tensorized_answer[:, :-1].contiguous(), decoder_attention_mask=tensorized_answer_mask[:, :-1].contiguous())[0] labels = tensorized_answer[:, 1:].reshape(-1) logprobs = - F.cross_entropy(lm_logits.view(-1, get_model(model).config.vocab_size), labels, reduction='none') logprobs[labels == reader_tokenizer.pad_token_id] = 0. scores.append(logprobs.sum().item()) # save the scores from the generative reader ext_reader_predictions[b.question[0]]["reader_scores"] = scores if gt_file is not None: ####################### Compute abstractive_reader's hit############################### tensorised_answers = get_model(model).generate(input_ids=concatenated_encoder_attention, # num_beams=5, # num_return_sequences=5, attention_mask=concatenated_encoder_attention, encoder_outputs=concatenated_encoder_output, decoder_start_token_id=b.target[0][0]) generated_prediction = reader_tokenizer.decode(tensorised_answers[0], skip_special_tokens=True) gen_hits += int(eval_utils.metric_max_over_ground_truths( metric_fn=eval_utils.exact_match_score, prediction=generated_prediction, ground_truths=correct_answers[b.question[0]])) ######################################################################################## ####################### Compute re-ranked ############hit############################### reranked_max_i = argmax(scores) reranked_prediction = ext_reader_predictions[b.question[0]]["answers"][reranked_max_i] rr_hits += int(eval_utils.metric_max_over_ground_truths( metric_fn=eval_utils.exact_match_score, prediction=reranked_prediction, ground_truths=correct_answers[b.question[0]])) ######################################################################################## it.set_description( f"Original EM: {hits / total * 100:.2f}; Reranked EM: {rr_hits / total * 100:.2f}; Generative EM: {gen_hits / total * 100:.2f}") # Write-out generatively re-scored predictions with jsonlines.open(outfile, "w") as ofwriter: ofwriter.write_all(ext_reader_predictions.values()) if gt_file is not None: logging.info(f"Extractive EM: {hits / total * 100.}") logging.info(f"Re-ranked EM: {rr_hits / total * 100.}") logging.info(f"Generative EM: {gen_hits / total * 100.}") print(f"Extractive EM: {hits / total * 100.}") print(f"Re-ranked EM: {rr_hits / total * 100.}") print(f"Generative EM: {gen_hits / total * 100.}")
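# Sketch of the generative re-scoring step inside rerank() above: each candidate answer
# is scored by the sum of its token log-probabilities, with padding positions zeroed out,
# and the candidate with the highest score is picked. The vocabulary size, pad id, and
# logits below are invented; only the scoring arithmetic mirrors the code above.
import torch
import torch.nn.functional as F

vocab_size, pad_token_id = 100, 0


def score_answer(lm_logits, labels):
    # lm_logits: [1, seq_len, vocab_size]; labels: [1, seq_len], pad id on padding
    logprobs = -F.cross_entropy(lm_logits.view(-1, vocab_size), labels.view(-1), reduction="none")
    logprobs[labels.view(-1) == pad_token_id] = 0.0  # padding contributes nothing
    return logprobs.sum().item()


candidates = [torch.tensor([[5, 7, 2, pad_token_id]]), torch.tensor([[9, 3, 2, pad_token_id]])]
scores = [score_answer(torch.randn(1, 4, vocab_size), labels) for labels in candidates]
reranked_max_i = max(range(len(scores)), key=scores.__getitem__)  # same role as argmax(scores)
print(scores, reranked_max_i)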
def validate(self, model: T5FusionInDecoder, val_iter: BucketIterator, optimizer_dict=None, log_results=False): """ Does not compute validation loss for now """ model = model.eval() it = tqdm(enumerate(val_iter), total=len(val_iter.data()) // val_iter.batch_size + 1) total = 0 hits = 0 losslist = [] if log_results: import csv model_type = self.config['reader_transformer_type'].replace("/", "_") outf = open(f"results/gen_reader_{model_type}.csv", "w", encoding="utf-8") csvw = csv.writer(outf, delimiter=',') csvw.writerow(["Correct", "Question", "Predicted Answer", "GT Answer", "Input"]) for i, batch in it: batch.src = batch.src[0] batch.src_mask = batch.src_mask[0] batch.doc_mask = batch.doc_mask[0] if hasattr(batch, "doc_mask") else None total += len(batch) concatenated_encoder_output, concatenated_encoder_attention = model(input_ids=batch.src, attention_mask=batch.src_mask, encode_only=True) concatenated_encoder_output_copy = BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=copy.deepcopy(concatenated_encoder_output['last_hidden_state'])) concatenated_encoder_attention_copy = copy.deepcopy(concatenated_encoder_attention) outputs: Seq2SeqLMOutput = model(input_ids=None, attention_mask=concatenated_encoder_attention_copy, encoder_outputs=concatenated_encoder_output_copy, passage_mask=batch.doc_mask, decoder_input_ids=batch.target[:, :-1].contiguous(), decoder_attention_mask=batch.target_mask[:, :-1].contiguous()) lm_logits = outputs.logits labels = batch.target[:, 1:].reshape(-1) losses = F.cross_entropy(lm_logits.view(-1, get_model(model).config.vocab_size), labels, reduction='none') losslist += losses.tolist() # hacky, provide just some tensor as input ids, such that it matches batch dimension 1, # do not provide input ids, as they should not be needed (and have pre-concatenation batch dim) tokenized_answers = get_model(model).generate(input_ids=concatenated_encoder_attention, # num_beams=5, # num_return_sequences=5, attention_mask=concatenated_encoder_attention, encoder_outputs=concatenated_encoder_output, decoder_start_token_id=batch.target[0][0]) predicted_answers = [self.tokenizer.decode(ans, skip_special_tokens=True) for ans in tokenized_answers] for i in range(len(batch)): hit = eval_utils.metric_max_over_ground_truths( metric_fn=eval_utils.exact_match_score, prediction=predicted_answers[i], ground_truths=batch.answers[i]) hits += int(hit) if log_results: csvw.writerow([ hit, batch.question[i], predicted_answers[i], batch.answers[i], self.tokenizer.decode(batch.src[i]) ]) it.set_description(f"Val Loss: {sum(losslist) / len(losslist):.3f} EM: {hits / total:.3f}") EM = hits / total logging.info(f"S: {get_model(model).training_steps} Validation Loss: {sum(losslist) / len(losslist)}") logging.info(f"Validation EM: {EM}") if log_results: outf.close() if EM > self.best_em and not self.config['test_only']: logging.info(f"{EM} ---> New BEST!") self.best_em = EM serializable_model_name = self.config['reader_transformer_type'].replace("/", "_") saveable_model = get_model(model) saveable_model.optimizer_state_dict = optimizer_dict # Note that model training is fully resumable # it contains .optimizer_state_dict and .training_steps (=number of updates) saved_name = os.path.join(self.config['save_dir'], f"generative_reader_" f"EM{EM:.4f}_" f"S{get_model(model).training_steps}_" f"M{serializable_model_name}_" f"{get_timestamp()}_{socket.gethostname()}") self.best_ckpt_name = saved_name torch.save(saveable_model, saved_name) model = model.train() return EM
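# Simplified stand-in for the exact-match bookkeeping used in validate() above. This is
# a sketch of the usual SQuAD-style normalization, not the project's actual eval_utils
# implementation, which may differ in details.
import re
import string


def _normalize(text: str) -> str:
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match_score(prediction: str, ground_truth: str) -> bool:
    return _normalize(prediction) == _normalize(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(metric_fn(prediction, gt) for gt in ground_truths)


hit = metric_max_over_ground_truths(exact_match_score, "The Eiffel Tower", ["eiffel tower", "Paris"])
print(int(hit))  # 1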
def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, num_layers=None, num_layers_total=None, rng_seed=None, drop_unused_layers=False, approximate_unused_layers=False, start_sampling_from=0, exclude_layers=tuple(), layer_normalizers=None, keep_last_layer=False, ): exclude_layers = tuple() if not isinstance(exclude_layers, (list, tuple)) else exclude_layers all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = ( ) if output_attentions and self.config.add_cross_attention else None n_forward_layers = 0 n_grad_forward_layers = 0 next_decoder_cache = () if use_cache else None total_layers = len( self.layer) if num_layers_total is None else num_layers_total layers = [self.layer[i] for i in range(total_layers)] selected_layers = list(range(total_layers)) probas = None if num_layers is not None and num_layers < total_layers: odd = False if approximate_unused_layers: assert total_layers % 2 == 0 odd = num_layers % 2 != 0 start_sampling_from = start_sampling_from / 2 num_layers = int(num_layers / 2) + (1 if odd else 0) total_layers = int(total_layers / 2) assert exclude_layers is None or len(exclude_layers) == 0 g_cpu = None if rng_seed is not None: g_cpu = torch.Generator() g_cpu = g_cpu.manual_seed(rng_seed) probas = torch.tensor([ ((total_layers - i) / total_layers) if (i >= start_sampling_from and i not in exclude_layers) else 0.0 for i in range(total_layers) ])**max(self.sampling_alpha, 0.01) # print((total_layers, start_sampling_from), "Probas = ", probas) selected_layers = sorted( torch.multinomial(probas, num_layers, replacement=False, generator=g_cpu).long().tolist()) if keep_last_layer: selected_layers[-1] = total_layers - 1 if approximate_unused_layers: num_layers = num_layers * 2 - (1 if odd else 0) start_sampling_from = int(start_sampling_from * 2) total_layers = total_layers * 2 selected_layers = [ i for s in selected_layers for i in [2 * s, 2 * s + 1] ] if odd: if random.random() < 0.5: selected_layers = selected_layers[:-1] else: selected_layers = selected_layers[:-2] + [ selected_layers[-1] ] # selected_layers = list(range(len(layers))) # print((len(layers), len(selected_layers), start_sampling_from), selected_layers, exclude_layers, self.sampling_alpha, probas) drop_unused_layers = drop_unused_layers or approximate_unused_layers approximate_unused_layers = False prev_grad_layer = max(start_sampling_from - 1, 0) layer_scales = dict( zip(self.scale_factors_idx, torch.abs(self.scale_factors_emb))) layer_scales_loss = 0.0 for i, layer_module in enumerate(layers): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[ i] if past_key_values is not None else None grad_layer = ( i in selected_layers or drop_unused_layers or not approximate_unused_layers ) and self.training and i >= start_sampling_from and torch.is_grad_enabled( ) scale_factor = 1 if i > prev_grad_layer + 1 and ( drop_unused_layers or approximate_unused_layers ) and layer_normalizers is not None and self.enable_layer_normalizers: scale_factor = 1 + layer_scales[(prev_grad_layer, i)] if self.enable_layer_normalizers_statistics: self.distance_statistics[i - prev_grad_layer].mul_(0.999).add_( 0.001 * scale_factor.detach()) if i < start_sampling_from: pass elif drop_unused_layers and i not in selected_layers: continue elif 
drop_unused_layers and layer_normalizers is not None and self.enable_layer_normalizers: hidden_states = hidden_states * scale_factor if isinstance(scale_factor, torch.Tensor): layer_scales_loss = layer_scales_loss + layer_scales[ (prev_grad_layer, i)] # print((i, prev_grad_layer, len(layers)), (grad_layer, drop_unused_layers, approximate_unused_layers,), scale_factor) prev_grad_layer = i if self.enable_layer_normalizers_statistics: layer_normalizer_fn( hidden_states.detach(), self.layer_normalizers_statistics[i].detach(), self.training, self.train_layer_normalizers, False, self.enable_layer_normalizers_statistics) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) return custom_forward with torch.set_grad_enabled(grad_layer): if getattr(self.config, "gradient_checkpointing", False) and self.training and grad_layer: if use_cache: logger.warn( "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " "`use_cache=False`...") use_cache = False layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states if grad_layer else hidden_states.detach(), attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, layer_normalizers[i + 1] if layer_normalizers is not None else None, ) else: layer_outputs = create_custom_forward(layer_module)( hidden_states if grad_layer else hidden_states.detach(), attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, layer_normalizers[i + 1] if layer_normalizers is not None else None, ) if grad_layer: hidden_states = layer_outputs[0] n_grad_forward_layers += 1 else: hidden_states = hidden_states + (layer_outputs[0].detach() - hidden_states.detach()) n_forward_layers += 1 if use_cache: next_decoder_cache += (layer_outputs[-1], ) if output_attentions: all_self_attentions = all_self_attentions + ( layer_outputs[1], ) if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + ( layer_outputs[2], ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: return tuple(v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None) rv = BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) rv["layer_scales_loss"] = layer_scales_loss rv["selected_layers"] = selected_layers rv["n_grad_forward_layers"] = n_grad_forward_layers rv["n_forward_layers"] = n_forward_layers return rv
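# Standalone illustration of the layer-subsampling rule in the forward above: layers
# before start_sampling_from and excluded layers get probability 0, earlier remaining
# layers get higher probability, the distribution is sharpened by sampling_alpha, and
# num_layers indices are drawn without replacement. The concrete values are arbitrary.
import torch

total_layers, num_layers = 12, 6
start_sampling_from, exclude_layers = 2, (5,)
sampling_alpha = 0.5

g_cpu = torch.Generator().manual_seed(0)
probas = torch.tensor([
    ((total_layers - i) / total_layers) if (i >= start_sampling_from and i not in exclude_layers) else 0.0
    for i in range(total_layers)
]) ** max(sampling_alpha, 0.01)

selected_layers = sorted(
    torch.multinomial(probas, num_layers, replacement=False, generator=g_cpu).long().tolist()
)
print(selected_layers)  # a sorted subset of layer indices, e.g. [2, 3, 4, 7, 8, 10]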
def forward( self, input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, gat_hidden_states=None, gat_attention_mask=None, head_mask=None, encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): # print('\n+++ decoder forward') # print('encoder_hidden_states ', encoder_hidden_states.shape if encoder_hidden_states is not None else None) # print('encoder_attention_mask ', encoder_attention_mask.shape if encoder_attention_mask is not None else None) # print('gat_hidden_states ', gat_hidden_states.shape if gat_hidden_states is not None else None) # print('gat_attention_mask ', gat_attention_mask.shape if gat_attention_mask is not None else None) assert gat_hidden_states is not None, 'gat_hidden_states is None' assert not output_attentions, f'output_attentions is {output_attentions}' output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" ) elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError( "You have to specify either decoder_input_ids or decoder_inputs_embeds" ) # past_key_values_length past_key_values_length = past_key_values[0][0].shape[ 2] if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale attention_mask = self._prepare_decoder_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) # expand gat attention mask if gat_hidden_states is not None and gat_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] gat_attention_mask = _expand_mask(gat_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) hidden_states = inputs_embeds + positions hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if ( output_attentions and (encoder_hidden_states is not None or gat_hidden_states is not None)) else None next_decoder_cache = () if use_cache else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( len(self.layers) ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states, ) dropout_probability = random.uniform(0, 1) if self.training and (dropout_probability < self.layerdrop): continue past_key_value = past_key_values[ idx] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: logger.warn( "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " "`use_cache=False`...") use_cache = False def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, use_cache) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, attention_mask, encoder_hidden_states, gat_hidden_states, encoder_attention_mask, gat_attention_mask, head_mask[idx] if head_mask is not None else None, encoder_head_mask[idx] if encoder_head_mask is not None else None, None, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, gat_hidden_states=gat_hidden_states, gat_attention_mask=gat_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += ( layer_outputs[3 if output_attentions else 1], ) if output_attentions: all_self_attns += (layer_outputs[1], ) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2], ) hidden_states = self.layer_norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states, ) next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple(v for v in [ hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions ] if v is not None) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, )
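# Sketch of the mask expansion the decoder above relies on for both the encoder and the
# GAT cross-attention masks. It mirrors the usual BART-style _expand_mask helper; the
# exact helper in this codebase may differ slightly.
import torch


def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None) -> torch.Tensor:
    # mask: [bsz, src_len] with 1 = attend, 0 = masked
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # 0.0 where attention is allowed, a large negative value where it is masked.
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)


gat_attention_mask = torch.tensor([[1, 1, 0]])
print(expand_mask(gat_attention_mask, torch.float32, tgt_len=2))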
def alt_forward(input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, inputs_embeds=None, head_mask=None, encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, grad_chk_pnt_rate=None): # Model parallel if self.model_parallel: torch.cuda.set_device(self.first_device) self.embed_tokens = self.embed_tokens.to(self.first_device) use_cache = use_cache if use_cache is not None else self.config.use_cache if self.training and use_cache: assert (grad_chk_pnt_rate is None), "Can't use grad checkpoint and cache." output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError( f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time" ) elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError( f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds" ) if inputs_embeds is None: assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" inputs_embeds = self.embed_tokens(input_ids) ### CHANGE BELOW if self.window_mode: # set input shape to window size to format attention masks correctly batch_size, seq_length = input_shape new_batch_size = batch_size * self.windows_per_sample input_shape = (new_batch_size, min(seq_length, self.window_size)) if encoder_hidden_states is not None: # match window batches # Get duplicated encodings together with indices [1,1, 2,2, 3,3, etc] encoding_index = torch.arange(batch_size).repeat( self.windows_per_sample, 1).T.reshape(-1) encoder_hidden_states = encoder_hidden_states[encoding_index] ### CHANGE ABOVE batch_size, seq_length = input_shape # required mask seq length can be calculated via length of past mask_seq_length = past_key_values[0][0].shape[ 2] + seq_length if past_key_values is not None else seq_length if use_cache is True: assert self.is_decoder, ":obj:`use_cache` can only be set to `True` if {} is used as a decoder".format( self) if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length).to( inputs_embeds.device) if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: encoder_seq_length = encoder_hidden_states.shape[1] encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long) # initialize past_key_values with `None` if past does not exist if past_key_values is None: past_key_values = [None] * len(self.block) # ourselves in which case we just need to make it broadcastable to all heads. 
extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_shape, inputs_embeds.device) if self.is_decoder and encoder_attention_mask is not None: encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) encoder_head_mask = self.get_head_mask(encoder_head_mask, self.config.num_layers) present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None position_bias = None encoder_decoder_position_bias = None hidden_states = self.dropout(inputs_embeds) for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): layer_head_mask = head_mask[i] encoder_layer_head_mask = encoder_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: encoder_hidden_states = encoder_hidden_states.to( hidden_states.device) if encoder_extended_attention_mask is not None: encoder_extended_attention_mask = encoder_extended_attention_mask.to( hidden_states.device) if encoder_decoder_position_bias is not None: encoder_decoder_position_bias = encoder_decoder_position_bias.to( hidden_states.device) if layer_head_mask is not None: layer_head_mask = layer_head_mask.to(hidden_states.device) if encoder_layer_head_mask is not None: encoder_layer_head_mask = encoder_layer_head_mask.to( hidden_states.device) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) ### CHANGE BELOW if grad_chk_pnt_rate and self.training and i % grad_chk_pnt_rate != 0: if use_cache: logger.warn( "`use_cache=True` is incompatible with `grad_chk_pnt_rate=True`. 
Setting " "`use_cache=False`...") use_cache = False # recacluate gradients later assert (past_key_value is None) def create_custom_forward(module): x = i def custom_forward(*inputs): return module(*inputs, head_mask[x], encoder_head_mask[x], None, False, False) return custom_forward layer_outputs = checkpoint( create_custom_forward(layer_module), hidden_states, extended_attention_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, encoder_decoder_position_bias, ) else: # std way of calculating gradients layer_outputs = layer_module( hidden_states, attention_mask=extended_attention_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, encoder_layer_head_mask=encoder_layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, ) ### CHANGE ABOVE # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) hidden_states, present_key_value_state = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention weights), # (self-attention position bias), (cross-attention weights), (cross-attention position bias) position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[ 4 if output_attentions else 3] # append next layer key value states if use_cache: present_key_value_states = present_key_value_states + ( present_key_value_state, ) if output_attentions: all_attentions = all_attentions + (layer_outputs[3], ) if self.is_decoder: all_cross_attentions = all_cross_attentions + ( layer_outputs[5], ) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: for k, v in self.device_map.items(): if i == v[-1] and "cuda:" + str(k) != self.last_device: hidden_states = hidden_states.to("cuda:" + str(k + 1)) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) # Add last layer if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: return tuple(v for v in [ hidden_states, present_key_value_states, all_hidden_states, all_attentions, all_cross_attentions, ] if v is not None) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_value_states, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, )
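# Standalone illustration of the grad_chk_pnt_rate idea in alt_forward above: every
# grad_chk_pnt_rate-th layer keeps its activations, while the layers in between run
# under torch.utils.checkpoint and are recomputed during backward. The toy layer stack
# and sizes are invented for illustration.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layers = nn.ModuleList(nn.Linear(32, 32) for _ in range(8)).train()
grad_chk_pnt_rate = 4

hidden_states = torch.randn(2, 32, requires_grad=True)
for i, layer_module in enumerate(layers):
    if grad_chk_pnt_rate and i % grad_chk_pnt_rate != 0:
        hidden_states = checkpoint(layer_module, hidden_states)  # recomputed in backward
    else:
        hidden_states = layer_module(hidden_states)  # activations kept as usual
hidden_states.sum().backward()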
def forward(
    self,
    input_ids=None,
    past_key_values=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, input_shape[-1])
    if position_ids is not None:
        position_ids = position_ids.view(-1, input_shape[-1])

    if past_key_values is None:
        past_length = 0
        past_key_values = [None] * len(self.h)
    else:
        past_length = past_key_values[0][0].size(-2)
    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

    # Attention mask.
    if attention_mask is not None:
        assert batch_size > 0, "batch_size has to be defined and > 0"
        attention_mask = attention_mask.view(batch_size, -1)
        # We create a 3D attention mask from a 2D tensor mask. Sizes are
        # [batch_size, 1, 1, to_seq_length], so we can broadcast to
        # [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is more simple than the triangular masking of causal
        # attention used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        attention_mask = attention_mask[:, None, None, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        attention_mask = (1.0 - attention_mask) * -10000.0

    # If a 2D or 3D attention mask is provided for the cross-attention,
    # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length].
    if self.config.add_cross_attention and encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
        encoder_attention_mask = None

    # Prepare head mask if needed.
    # 1.0 in head_mask indicates we keep the head.
    # attention_probs has shape bsz x n_heads x N x N
    # head_mask has shape n_layer x batch x n_heads x N x N
    head_mask = self.get_head_mask(head_mask, self.config.n_layer)

    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids)
    position_embeds = self.wpe(position_ids)
    hidden_states = inputs_embeds + position_embeds

    if token_type_ids is not None:
        token_type_embeds = self.wte(token_type_ids)
        hidden_states = hidden_states + token_type_embeds

    hidden_states = self.drop(hidden_states)

    output_shape = input_shape + (hidden_states.size(-1),)

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
    all_hidden_states = () if output_hidden_states else None
    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

        # if getattr(self.config, "gradient_checkpointing", False):
        if self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # checkpointing only works with tuple returns, not with lists
                    return tuple(output for output in module(*inputs, use_cache, output_attentions))

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                layer_past,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        if i == self.split_point - 1:
            # Move activations and any collected outputs to the second device.
            if self.gradient_checkpointing:
                outputs = list(outputs)
                for j in range(len(outputs)):
                    outputs[j] = outputs[j].to(1)
                outputs = tuple(outputs)
            else:
                for j in range(len(outputs)):
                    outputs[j] = outputs[j].to(1)
            if use_cache:
                presents = list(presents)
                for j in range(len(presents)):
                    presents[j] = presents[j].to(1)
                presents = tuple(presents)
            if output_attentions:
                all_self_attentions = list(all_self_attentions)
                all_cross_attentions = list(all_cross_attentions)
                for j in range(len(all_self_attentions)):
                    all_self_attentions[j] = all_self_attentions[j].to(1)
                for j in range(len(all_cross_attentions)):
                    all_cross_attentions[j] = all_cross_attentions[j].to(1)

        hidden_states, present = outputs[:2]
        if use_cache is True:
            presents = presents + (present,)

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[2],)
            if self.config.add_cross_attention:
                all_cross_attentions = all_cross_attentions + (outputs[3],)

    hidden_states = self.ln_f(hidden_states)
    hidden_states = hidden_states.view(*output_shape)

    # Add last hidden state
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
            if v is not None
        )

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
        cross_attentions=all_cross_attentions,
    )
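# Sketch of the two-device split performed in the GPT-2 forward above: blocks up to
# self.split_point run on the first device, then activations (and any cached key/values)
# hop to the second device via .to(1). The sketch falls back to CPU so it stays runnable
# without two GPUs; block sizes and the split point are invented.
import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    first_device, second_device = torch.device("cuda:0"), torch.device("cuda:1")
else:
    first_device = second_device = torch.device("cpu")

blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))
split_point = 2
for i, block in enumerate(blocks):
    block.to(first_device if i < split_point else second_device)

hidden_states = torch.randn(2, 16, device=first_device)
for i, block in enumerate(blocks):
    hidden_states = block(hidden_states)
    if i == split_point - 1:
        hidden_states = hidden_states.to(second_device)  # hand over to the next device
print(hidden_states.device)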
def forward( self, input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, past_key_values=None, head_mask=None, inputs_embeds=None, decoder_inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, ################################################# # Modification to T5ForConditionalGeneration (MF) ################################################# encode_only=False, passage_mask=None, ################################################# ############### END OF MODIFICATION (MF)######### ################################################# ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` Returns: Examples:: >>> from transformers import T5Tokenizer, T5ForConditionalGeneration >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') >>> model = T5FusionInDecoder.from_pretrained('t5-small') >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids >>> outputs = model(input_ids=input_ids, labels=labels) >>> loss = outputs.loss >>> logits = outputs.logits >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 >>> outputs = model.generate(input_ids) """ use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode if needed (training, first prediction pass) if encoder_outputs is None: # Convert encoder inputs in embeddings if needed encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) ################################################# # Modification to T5ForConditionalGeneration (MF) ################################################# concatenated_encoder_outputs, concatenated_attention_mask = self.concatenate_encoder_outputs( encoder_outputs=encoder_outputs, encoder_attention_mask=attention_mask, passage_mask=passage_mask) if encode_only: return concatenated_encoder_outputs, concatenated_attention_mask ################################################# ############### END OF MODIFICATION (MF)######### ################################################# elif return_dict and not isinstance( encoder_outputs, BaseModelOutputWithPastAndCrossAttentions): # Assume concatenated encoder outputs are passed! concatenated_encoder_outputs = BaseModelOutputWithPastAndCrossAttentions( # Renamed (MF) last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) concatenated_attention_mask = attention_mask # Minor modification (MF) else: # Minor modification (MF) # Assume concatenated encoder outputs are passed! 
concatenated_encoder_outputs = encoder_outputs concatenated_attention_mask = attention_mask hidden_states = concatenated_encoder_outputs[0] # Renamed (MF) if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) # If decoding with past key value states, only the last tokens # should be given as an input if past_key_values is not None: assert labels is None, "Decoder should not use cached key value states when training." if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] if decoder_inputs_embeds is not None: decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, past_key_values=past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=concatenated_attention_mask, # Renamed (MF) head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = decoder_outputs[0] if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) lm_logits = self.lm_head(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: output = (lm_logits, ) + decoder_outputs[ 1:] + concatenated_encoder_outputs # Renamed (MF) return ((loss, ) + output) if loss is not None else output return Seq2SeqLMOutput( loss=loss, logits=lm_logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=concatenated_encoder_outputs. last_hidden_state, # Renamed (MF) encoder_hidden_states=concatenated_encoder_outputs. hidden_states, # Renamed (MF) encoder_attentions=concatenated_encoder_outputs. attentions, # Renamed (MF) )
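# Hedged sketch of what concatenate_encoder_outputs() in the Fusion-in-Decoder forward
# above does conceptually: the encoder processes (batch * n_passages) inputs
# independently, and their hidden states and attention masks are then concatenated along
# the length dimension so the decoder cross-attends over all passages of a question at
# once. Shapes are invented; the actual method may handle passage masks and edge cases
# differently.
import torch

batch_size, n_passages, passage_len, hidden = 2, 3, 5, 8

encoder_hidden = torch.randn(batch_size * n_passages, passage_len, hidden)
encoder_mask = torch.ones(batch_size * n_passages, passage_len, dtype=torch.long)

concatenated_hidden = encoder_hidden.view(batch_size, n_passages * passage_len, hidden)
concatenated_mask = encoder_mask.view(batch_size, n_passages * passage_len)

print(concatenated_hidden.shape)  # torch.Size([2, 15, 8])
print(concatenated_mask.shape)    # torch.Size([2, 15])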
def forward( self, input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, encoder_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): r""" Args: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using :class:`~transformers.BartTokenizer`. See :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for details. `What are input IDs? <../glossary.html#input-ids>`__ attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the heas is **masked**. encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention on hidden heads. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the heas is **masked**. past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices into associated vectors than the model's internal embedding lookup matrix. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. 
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
            for more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    # past_key_values_length
    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )

    # expand encoder attention mask
    if encoder_hidden_states is not None and encoder_attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

    # embed positions
    positions = self.embed_positions(input_shape, past_key_values_length)

    hidden_states = inputs_embeds + positions
    hidden_states = self.layernorm_embedding(hidden_states)
    hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
    next_decoder_cache = () if use_cache else None

    # check if head_mask has a correct number of layers specified if desired
    if head_mask is not None:
        assert head_mask.size()[0] == len(self.layers), (
            f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
        )

    for idx, decoder_layer in enumerate(self.layers):
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):
            continue

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if getattr(self.config, "gradient_checkpointing", False) and self.training:
            if use_cache:
                logger.warn(
                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                    "`use_cache=False`..."
                )
                use_cache = False

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, use_cache)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                head_mask[idx] if head_mask is not None else None,
                encoder_head_mask[idx] if encoder_head_mask is not None else None,
                None,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

            if encoder_hidden_states is not None:
                all_cross_attentions += (layer_outputs[2],)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
        cross_attentions=all_cross_attentions,
    )
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    vis_inputs=None,
    vis_attention_mask=None,
    inputs_embeds=None,
    head_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    if inputs_embeds is None:
        assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
        inputs_embeds = self.embed_tokens(input_ids)

    B, L = inputs_embeds.size()[:-1]

    vis_feats = vis_inputs[0]
    boxes = vis_inputs[1]
    img_order_ids = None
    obj_order_ids = None
    if len(vis_inputs) >= 3:
        img_order_ids = vis_inputs[2]
    if len(vis_inputs) == 4:
        obj_order_ids = vis_inputs[3]

    vis_embeds = self.visual_embedding(vis_feats, boxes, img_order_ids, obj_order_ids)
    V_L = vis_embeds.size(1)

    # append visual embeddings after the text embeddings along the sequence dimension
    inputs_embeds = torch.cat([inputs_embeds, vis_embeds], dim=1)

    if attention_mask is None:
        attention_mask = input_ids.ne(self.config.pad_token_id).to(
            dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

    if vis_attention_mask is None:
        vis_attention_mask = attention_mask.new_ones(B, V_L)

    attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1)

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask = self.get_extended_attention_mask(
        attention_mask, (B, L + V_L), inputs_embeds.device
    )

    # initialize past_key_values with `None` if past does not exist
    if past_key_values is None:
        past_key_values = [None] * len(self.block)

    # Prepare head mask if needed
    head_mask = self.get_head_mask(head_mask, self.config.num_layers)
    present_key_value_states = () if use_cache else None
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and self.is_decoder) else None
    # position_bias = None
    # encoder_decoder_position_bias = None

    hidden_states = self.dropout(inputs_embeds)

    if self.config.num_layers > 0:
        assert self.block[0].layer[0].SelfAttention.has_relative_attention_bias

        seq_length = L + V_L
        q_len = seq_length
        k_len = seq_length

        # [1, n_heads, Q_len, K_len]
        text_position_bias = self.block[0].layer[0].SelfAttention.compute_bias(L, L)
        num_heads = text_position_bias.size(1)
        position_bias = text_position_bias.new_zeros(1, num_heads, seq_length, seq_length)
        position_bias[:, :, :L, :L] = text_position_bias

        # relative position bias only between Text <-> Text
        # no relative position bias Text -> Vision
        # no relative position bias Vision -> Text
        # no relative position bias Vision <-> Vision
        # position_bias[:, :, L:, :] = 0
        # position_bias[:, :, :, L:] = 0
        position_bias = position_bias + extended_attention_mask

    for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
        # if output_hidden_states:
        #     all_hidden_states = all_hidden_states + (hidden_states,)

        layer_outputs = layer_module(
            hidden_states,
            attention_mask=extended_attention_mask,
            position_bias=position_bias,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            encoder_decoder_position_bias=None,
            head_mask=head_mask[i],
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # layer_outputs is a tuple with:
        # hidden-states, key-value-states, (self-attention weights),
        # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
        hidden_states, present_key_value_state = layer_outputs[:2]

        # We share the position biases between the layers - the first layer stores them.
        position_bias = layer_outputs[2]

        # append next layer key value states
        if use_cache:
            present_key_value_states = present_key_value_states + (present_key_value_state,)

        # if output_attentions:
        #     all_attentions = all_attentions + (layer_outputs[3],)
        #     if self.is_decoder:
        #         all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

    hidden_states = self.final_layer_norm(hidden_states)
    hidden_states = self.dropout(hidden_states)

    # Add last layer
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(
            v
            for v in [
                hidden_states,
                present_key_value_states,
                all_hidden_states,
                all_attentions,
                all_cross_attentions,
            ]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=present_key_value_states,
        hidden_states=all_hidden_states,
        attentions=all_attentions,
        cross_attentions=all_cross_attentions,
    )
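
# --- Illustrative sketch (not part of the original model code) ---
# The encoder above concatenates L text embeddings with V_L visual embeddings and
# builds a (L + V_L) x (L + V_L) relative position bias in which only the text<->text
# block carries T5-style relative-position values; every position involving a visual
# token keeps a zero bias and relies on the attention mask alone. The standalone
# helper below reproduces that padding logic; `compute_text_bias` stands in for
# `self.block[0].layer[0].SelfAttention.compute_bias` and the function name is an
# assumption made for illustration only.
def _joint_position_bias_sketch(compute_text_bias, L, V_L, num_heads):
    seq_length = L + V_L

    # [1, num_heads, L, L] relative position bias over the text tokens only
    text_bias = compute_text_bias(L, L)
    assert text_bias.shape == (1, num_heads, L, L)

    # Zero bias everywhere, then copy the text block into the top-left corner;
    # text->vision, vision->text and vision<->vision entries stay at zero.
    bias = text_bias.new_zeros(1, num_heads, seq_length, seq_length)
    bias[:, :, :L, :L] = text_bias

    # In the encoder above, the additive extended attention mask is then added to
    # this bias before it is passed to the first block.
    return bias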