def forward(self, input_ids, encoder_hidden_states, *args, **kwargs):
    # Unoptimized unconditional transfer to numpy for interfacing with polygraphy
    input_ids = input_ids.cpu().numpy().astype("int64")
    encoder_hidden_states = encoder_hidden_states.cpu().numpy().astype("float32")
    logits = self.trt_context.infer(
        {"input_ids": input_ids, "encoder_hidden_states": encoder_hidden_states}
    )["hidden_states"]
    return Seq2SeqLMOutput(logits=torch.from_numpy(logits))
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    if encoder_outputs is None:
        # Convert encoder inputs in embeddings if needed
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

    encoder_hidden_states = encoder_outputs[0]

    if past_key_values is not None:
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

    if past_key_values is None:
        # Runs only for the first step
        init_onnx_outputs = self.decoder_init(decoder_input_ids, attention_mask, encoder_hidden_states)
        logits, past_key_values = init_onnx_outputs
    else:
        onnx_outputs = self.decoder(
            decoder_input_ids,
            attention_mask,
            encoder_hidden_states,
            past_key_values,
        )
        logits, past_key_values = onnx_outputs

    return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
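# Hedged usage sketch (not part of the snippet above): a forward() like the ONNX wrapper above,
# which returns Seq2SeqLMOutput(logits, past_key_values), can be driven by a manual greedy loop.
# Names such as `model`, `start_token_id` and `eos_token_id` are assumptions for illustration.
import torch

@torch.no_grad()
def greedy_decode(model, input_ids, attention_mask, start_token_id, eos_token_id, max_length=64):
    decoder_input_ids = torch.full((input_ids.size(0), 1), start_token_id, dtype=torch.long)
    past_key_values = None
    for _ in range(max_length):
        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
        )
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        past_key_values = out.past_key_values  # reuse the cache on the next step
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
        if (next_token == eos_token_id).all():
            break
    return decoder_input_ids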
def forward(self, input_ids, encoder_hidden_states, *args, **kwargs):
    self.inputs["input_ids"][:, :input_ids.shape[1]] = input_ids
    self.inputs["encoder_hidden_states"][:, :TRTHFRunner.ENCODER_LENGTH, :] = \
        encoder_hidden_states[:, :TRTHFRunner.ENCODER_LENGTH, :]

    # TODO: This can be better generalized
    self.trt_context.set_binding_shape(0, input_ids.shape)
    self.trt_context.set_binding_shape(1, (1, TRTHFRunner.ENCODER_LENGTH, self.max_sequence_length))

    # Copy to device
    self.trt_context.execute_v2(bindings=self.bindings)

    # Transfer predictions back from GPU to do greedy search
    return Seq2SeqLMOutput(logits=self.outputs["hidden_states"][:, :input_ids.shape[1]].cpu())
def forward(self, input_ids, encoder_hidden_states, **kwargs):
    decoder_outputs = self.decoder(
        input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, **kwargs)

    # Rescale output before projecting on vocab with self.config.d_model ** -0.5,
    # as seen in https://huggingface.co/transformers/_modules/transformers/models/t5/modeling_t5.html#T5ForConditionalGeneration
    sequence_output = decoder_outputs[0] * self.config.d_model ** -0.5
    logits = self.lm_head(sequence_output)

    if not kwargs.get("return_dict", False):
        return (logits,) + decoder_outputs[1:]

    return Seq2SeqLMOutput(logits=logits)
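# Hedged note on the `d_model ** -0.5` rescaling above (assumed context, mirroring T5 with tied
# word embeddings): the decoder output is scaled by 1 / sqrt(d_model) before the vocabulary
# projection. A minimal, self-contained sketch of the same two steps:
import torch

d_model, vocab_size = 512, 32128                       # assumed T5-small-like dimensions
lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)
sequence_output = torch.randn(1, 4, d_model)           # stand-in for decoder_outputs[0]
logits = lm_head(sequence_output * d_model ** -0.5)    # rescale, then project on the vocab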
def forward(*args, **kwargs):
    outputs = origin_forward(*args, **kwargs)
    if outputs.past_key_values is not None:
        # Multiply by 1.0 to work around a bug in OpenVINO 2022.1 with
        # dynamic shapes on inputs connected to model outputs:
        past_key_values = []
        for i in range(12):
            past_key_values.append((
                outputs.past_key_values[i][0],
                outputs.past_key_values[i][1],
                outputs.past_key_values[i][2] * 1.0,
                outputs.past_key_values[i][3] * 1.0,
            ))
        outputs.past_key_values = tuple(past_key_values)
    return Seq2SeqLMOutput(
        logits=outputs.logits,
        past_key_values=outputs.past_key_values,
    )
def forward(self, fusion_map, input_ids, attention_mask, decoder_input_ids,
            decoder_attention_mask, return_hidden_states=False, **kwargs):
    encoder_outputs = self.encoder_forward(fusion_map, input_ids, attention_mask)
    encoder_fused_states = encoder_outputs.last_hidden_state
    fused_attention_mask = encoder_outputs.attentions
    encoder_layer_states = encoder_outputs.hidden_states

    dec_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_fused_states,
        encoder_attention_mask=fused_attention_mask,
        output_hidden_states=return_hidden_states)

    sequence_output = dec_outputs[0]
    lm_logits = self.lm_head(sequence_output)
    return Seq2SeqLMOutput(logits=lm_logits, encoder_hidden_states=encoder_layer_states)
def forward(self, batch, final_context, context_rnn_state, encoder_loss,
            current_token_id=None, decoder_wrapper=None, expansion_factor=1,
            generation_dict=None):
    context, context_limited = batch.context.value, batch.context.limited
    answer, answer_limited = batch.answer.value, batch.answer.limited
    decoder_vocab = self.numericalizer.decoder_vocab
    self.map_to_full = decoder_vocab.decode
    context_padding = context.data == self.pad_idx

    if self.training:
        if self.args.rnn_layers > 0:
            self.rnn_decoder.applyMasks(context_padding)
        else:
            self.context_attn.applyMasks(context_padding)

        answer_padding = (answer.data == self.pad_idx)[:, :-1]
        answer_embedded = self.decoder_embeddings(answer[:, :-1], padding=answer_padding)

        if self.args.rnn_layers > 0:
            rnn_decoder_outputs = self.rnn_decoder(
                answer_embedded, final_context, hidden=context_rnn_state)
            decoder_output, vocab_pointer_switch_input, context_attention, rnn_state = rnn_decoder_outputs
        else:
            context_decoder_output, context_attention = self.context_attn(
                answer_embedded, final_context)
            vocab_pointer_switch_input = torch.cat(
                (context_decoder_output, answer_embedded), dim=-1)
            decoder_output = self.dropout(context_decoder_output)

        vocab_pointer_switch = self.vocab_pointer_switch(vocab_pointer_switch_input)
        probs = self.probs(decoder_output, vocab_pointer_switch, context_attention,
                           context_limited, decoder_vocab)
        probs, targets = mask(answer_limited[:, 1:].contiguous(), probs.contiguous(),
                              pad_idx=decoder_vocab.pad_idx)
        loss = F.nll_loss(probs.log(), targets, ignore_index=decoder_vocab.pad_idx)
        if encoder_loss is not None:
            loss += self.args.encoder_loss_weight * encoder_loss
        return Seq2SeqLMOutput(loss=loss)
    else:
        if decoder_wrapper is None:
            decoder_wrapper = self.decoder_wrapper(
                final_context, context_padding, context_limited, decoder_vocab,
                rnn_state=context_rnn_state, expansion_factor=expansion_factor,
                generation_dict=generation_dict)
        else:
            current_token_id = current_token_id.clone().cpu().apply_(
                self.map_to_full).to(current_token_id.device)
        # (next_token_logits, past) where `past` includes all the states needed to continue generation
        logits = torch.log(decoder_wrapper.next_token_probs(current_token_id))
        return Seq2SeqLMOutput(logits=logits, past_key_values=decoder_wrapper)
def forward(
    self,
    input_image=None,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,
):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    kwargs_encoder = {
        argument: value
        for argument, value in kwargs.items()
        if not argument.startswith("decoder_")
    }
    kwargs_decoder = {
        argument[len("decoder_"):]: value
        for argument, value in kwargs.items()
        if argument.startswith("decoder_")
    }

    if encoder_outputs is None:
        # Get image embeddings.
        image_encoder_outputs = self.image_encoder(pixel_values=input_image)

        # Get command embeddings.
        command_encoder_outputs = self.command_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs_encoder,
        )

        # Concatenate outputs of both encoders along the "sequence" dimension.
        encoder_outputs = torch.cat(
            (image_encoder_outputs[0], command_encoder_outputs[0]), 1)

    # Back to the regular track.
    encoder_hidden_states = encoder_outputs

    # Optionally project encoder_hidden_states.
    # TODO: probably this is wrong! both encoders should have the same hidden size!
    # if (
    #     (self.decoder.config.hidden_size != self.image_encoder.config.hidden_size or
    #      self.decoder.config.hidden_size != self.command_encoder.config.hidden_size)
    #     and self.decoder.config.cross_attention_hidden_size is None
    # ):
    #     encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

    if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
        decoder_input_ids = shift_tokens_right(
            labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    # Create masks for dual-encoder inputs.
    encoder_hidden_shape = (encoder_outputs.shape[0], encoder_outputs.shape[1])
    dual_attention_masks = torch.ones(encoder_hidden_shape, dtype=torch.long, device=self.device)

    # Decode.
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=dual_attention_masks,
        inputs_embeds=decoder_inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        use_cache=use_cache,
        past_key_values=past_key_values,
        return_dict=return_dict,
        **kwargs_decoder,
    )

    # Compute loss independent of the decoder (as some decoders shift the logits internally).
    loss = None
    if labels is not None:
        # warnings.warn(DEPRECATION_WARNING, FutureWarning)
        logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

    if not return_dict:
        if loss is not None:
            return (loss,) + decoder_outputs + encoder_outputs
        else:
            return decoder_outputs + encoder_outputs

    return Seq2SeqLMOutput(
        loss=loss,
        logits=decoder_outputs.logits,
        past_key_values=decoder_outputs.past_key_values,  # tuple
        decoder_hidden_states=decoder_outputs.hidden_states,  # None!
        decoder_attentions=decoder_outputs.attentions,  # None!
        cross_attentions=decoder_outputs.cross_attentions,  # None!
        # TODO: Not sure about the following ones! What should be there?
        # encoder_last_hidden_state=command_encoder_outputs.last_hidden_state,
        # encoder_hidden_states=encoder_hidden_states.hidden_states,
        # encoder_attentions=command_encoder_outputs.attentions,
    )
def forward(
    self,
    input_ids,
    attention_mask=None,
    encoder_outputs=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    past_key_values=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **unused,
):
    if "lm_labels" in unused:
        warnings.warn(
            "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
            FutureWarning,
        )
        labels = unused.pop("lm_labels")
    if "decoder_cached_states" in unused:
        warnings.warn(
            "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
            FutureWarning,
        )
        past_key_values = unused.pop("decoder_cached_states")
    if "decoder_past_key_values" in unused:
        warnings.warn(
            "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
            FutureWarning,
        )
        past_key_values = unused.pop("decoder_past_key_values")

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if labels is not None:
        use_cache = False
        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

    outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    lm_logits = F.linear(outputs[0], self.model.shared.weight)

    masked_lm_loss = None
    if labels is not None:
        loss_fct = CrossEntropyLoss()
        # TODO(SS): do we need to ignore pad tokens in labels?
        masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

    if not return_dict:
        output = (lm_logits,) + outputs[1:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    return Seq2SeqLMOutput(
        loss=masked_lm_loss,
        logits=lm_logits,
        past_key_values=outputs.past_key_values,
        decoder_hidden_states=outputs.decoder_hidden_states,
        decoder_attentions=outputs.decoder_attentions,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        encoder_hidden_states=outputs.encoder_hidden_states,
        encoder_attentions=outputs.encoder_attentions,
    )
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    vis_inputs=None,
    vis_attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    reduce_loss=False,
    **kwargs,
):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if labels is not None:
        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        vis_inputs=vis_inputs,
        vis_attention_mask=vis_attention_mask,
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

    masked_lm_loss = None
    if labels is not None:
        # loss_fct = CrossEntropyLoss()
        if reduce_loss:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
        else:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
        masked_lm_loss = loss_fct(
            lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

    if not return_dict:
        output = (lm_logits,) + outputs[1:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    return Seq2SeqLMOutput(
        loss=masked_lm_loss,
        logits=lm_logits,
        past_key_values=outputs.past_key_values,
        decoder_hidden_states=outputs.decoder_hidden_states,
        decoder_attentions=outputs.decoder_attentions,
        cross_attentions=outputs.cross_attentions,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        encoder_hidden_states=outputs.encoder_hidden_states,
        encoder_attentions=outputs.encoder_attentions,
    )
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,  # TJH: In 4.4.2, `labels` contains what UnifiedQA calls decoder_input_ids
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    # TJH: Added for compatibility with 4.4.2
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if labels is not None:  # TJH: added for compatibility with other 4.4.2 seq2seq models
        if decoder_input_ids is None:
            # TJH: how it is done in modeling_bart.py. Using the UnifiedQA method instead:
            # decoder_input_ids = shift_tokens_right(
            #     labels, self.config.pad_token_id, self.config.decoder_start_token_id
            # )
            decoder_start_token_id = self.config.decoder_start_token_id
            decoder_input_ids = labels.new_zeros(labels.shape)
            decoder_input_ids[..., 1:] = labels[..., :-1].clone()
            decoder_input_ids[..., 0] = decoder_start_token_id

    # TJH: below from modeling_bart.py
    outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,  # TJH: no underscore
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)

    loss = None
    if labels is not None:  # TJH: labels is not None instead of is_training
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        losses = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
        loss = torch.sum(losses * decoder_attention_mask.float().view(-1))

    if not return_dict:  # TJH: from modeling_bart.py
        output = (lm_logits,) + outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return Seq2SeqLMOutput(  # TJH: from modeling_bart.py
        loss=loss,
        logits=lm_logits,
        past_key_values=outputs.past_key_values,
        decoder_hidden_states=outputs.decoder_hidden_states,
        decoder_attentions=outputs.decoder_attentions,
        cross_attentions=outputs.cross_attentions,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        encoder_hidden_states=outputs.encoder_hidden_states,
        encoder_attentions=outputs.encoder_attentions,
    )
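# Hedged note: the manual shift in the snippet above builds decoder_input_ids by prepending
# decoder_start_token_id and dropping the last label token. A reference shift_tokens_right close
# to the transformers BART implementation is sketched below; unlike the manual version it also
# replaces -100 (ignored label positions) with the pad token so they are valid decoder inputs.
import torch

def shift_tokens_right(labels, pad_token_id, decoder_start_token_id):
    shifted = labels.new_zeros(labels.shape)
    shifted[..., 1:] = labels[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted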
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    head_mask=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    #################################################
    # Modification to T5ForConditionalGeneration (MF)
    #################################################
    encode_only=False,
    passage_mask=None,
    #################################################
    ############### END OF MODIFICATION (MF)#########
    #################################################
):
    r"""
    labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
        Labels for computing the sequence classification/regression loss. Indices should be in
        :obj:`[-100, 0, ..., config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked),
        the loss is only computed for labels in ``[0, ..., config.vocab_size]``

    Returns:

    Examples::

        >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

        >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
        >>> model = T5FusionInDecoder.from_pretrained('t5-small')

        >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
        >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
    """
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        # Convert encoder inputs in embeddings if needed
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        #################################################
        # Modification to T5ForConditionalGeneration (MF)
        #################################################
        concatenated_encoder_outputs, concatenated_attention_mask = self.concatenate_encoder_outputs(
            encoder_outputs=encoder_outputs,
            encoder_attention_mask=attention_mask,
            passage_mask=passage_mask)

        if encode_only:
            return concatenated_encoder_outputs, concatenated_attention_mask
        #################################################
        ############### END OF MODIFICATION (MF)#########
        #################################################
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutputWithPastAndCrossAttentions):
        # Assume concatenated encoder outputs are passed!
        concatenated_encoder_outputs = BaseModelOutputWithPastAndCrossAttentions(  # Renamed (MF)
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )
        concatenated_attention_mask = attention_mask  # Minor modification (MF)
    else:  # Minor modification (MF)
        # Assume concatenated encoder outputs are passed!
        concatenated_encoder_outputs = encoder_outputs
        concatenated_attention_mask = attention_mask

    hidden_states = concatenated_encoder_outputs[0]  # Renamed (MF)

    if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
        # Get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels)

    # If decoding with past key value states, only the last tokens
    # should be given as an input
    if past_key_values is not None:
        assert labels is None, "Decoder should not use cached key value states when training."
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states,
        encoder_attention_mask=concatenated_attention_mask,  # Renamed (MF)
        head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = decoder_outputs[0]

    if self.config.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)

    lm_logits = self.lm_head(sequence_output)

    loss = None
    if labels is not None:
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

    if not return_dict:
        output = (lm_logits,) + decoder_outputs[1:] + concatenated_encoder_outputs  # Renamed (MF)
        return ((loss,) + output) if loss is not None else output

    return Seq2SeqLMOutput(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=concatenated_encoder_outputs.last_hidden_state,  # Renamed (MF)
        encoder_hidden_states=concatenated_encoder_outputs.hidden_states,  # Renamed (MF)
        encoder_attentions=concatenated_encoder_outputs.attentions,  # Renamed (MF)
    )
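# Hedged sketch of what concatenate_encoder_outputs (not shown above) typically does in a
# Fusion-in-Decoder setup: each (question, passage) pair is encoded independently, then the
# per-passage hidden states and masks are concatenated along the sequence dimension so the
# decoder can cross-attend over all passages at once. Shapes and names here are assumptions.
import torch

def concat_passages(last_hidden_state, attention_mask, n_passages):
    # last_hidden_state: (batch * n_passages, passage_len, d_model)
    # attention_mask:    (batch * n_passages, passage_len)
    bsz = last_hidden_state.size(0) // n_passages
    hidden = last_hidden_state.reshape(bsz, -1, last_hidden_state.size(-1))
    mask = attention_mask.reshape(bsz, -1)
    return hidden, mask  # (batch, n_passages * passage_len, d_model), (batch, n_passages * passage_len)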
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,
):
    r"""
    Returns:

    Examples:

    ```python
    >>> from transformers import EncoderDecoderModel, BertTokenizer
    >>> import torch

    >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    ...     "bert-base-uncased", "bert-base-uncased"
    ... )  # initialize Bert2Bert from pre-trained checkpoints

    >>> # training
    >>> model.config.decoder_start_token_id = tokenizer.cls_token_id
    >>> model.config.pad_token_id = tokenizer.pad_token_id
    >>> model.config.vocab_size = model.config.decoder.vocab_size

    >>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids
    >>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids
    >>> outputs = model(input_ids=input_ids, labels=labels)
    >>> loss, logits = outputs.loss, outputs.logits

    >>> # save and load from pretrained
    >>> model.save_pretrained("bert2bert")
    >>> model = EncoderDecoderModel.from_pretrained("bert2bert")

    >>> # generation
    >>> generated = model.generate(input_ids)
    ```"""
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    kwargs_encoder = {
        argument: value
        for argument, value in kwargs.items()
        if not argument.startswith("decoder_")
    }
    kwargs_decoder = {
        argument[len("decoder_"):]: value
        for argument, value in kwargs.items()
        if argument.startswith("decoder_")
    }

    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs_encoder,
        )

    encoder_hidden_states = encoder_outputs[0]

    # Optionally project encoder_hidden_states
    if (self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None):
        encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

    if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
        decoder_input_ids = shift_tokens_right(
            labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        use_cache=use_cache,
        past_key_values=past_key_values,
        return_dict=return_dict,
        **kwargs_decoder,
    )

    # Compute loss independent from decoder (as some shift the logits inside them)
    loss = None
    if labels is not None:
        # warnings.warn(DEPRECATION_WARNING, FutureWarning)
        logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

    if not return_dict:
        if loss is not None:
            return (loss,) + decoder_outputs + encoder_outputs
        else:
            return decoder_outputs + encoder_outputs

    return Seq2SeqLMOutput(
        loss=loss,
        logits=decoder_outputs.logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,
):
    custom_loss = None
    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # First occurrence of the speaker-1 token (id 30525) marks the persona boundary.
        persona_boundaries = (input_ids == 30525).nonzero(as_tuple=False)
        boundaries = []
        for item in persona_boundaries.tolist():
            if item[0] not in [x[0] for x in boundaries]:
                boundaries.append(item)
        boundaries = [item[1] for item in boundaries]

        # Keep only the persona prefix of each input; pad the rest (id 30524 / mask 0) up to 512.
        persona_input_ids, persona_attention_mask = [], []
        for i, (ids, attention) in enumerate(zip(input_ids, attention_mask)):
            persona_input_ids.append(
                ids.tolist()[:boundaries[i] + 1]
                + [30524 for _ in range(512 - (boundaries[i] + 1))])
            persona_attention_mask.append(
                attention.tolist()[:boundaries[i] + 1]
                + [0 for _ in range(512 - (boundaries[i] + 1))])

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        persona_input_ids = torch.Tensor(persona_input_ids).long().to(device)
        persona_attention_mask = torch.Tensor(persona_attention_mask).long().to(device)

        custom_encoder_outputs = self.encoder(
            input_ids=persona_input_ids,
            attention_mask=persona_attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Custom loss: similarity between the full-input and persona-only pooled representations.
        custom_loss = self.cs(
            encoder_outputs['pooler_output'],
            custom_encoder_outputs['pooler_output'],
        ).mean()
        # print(custom_encoder_outputs.keys(), type(custom_encoder_outputs['pooler_output']), custom_encoder_outputs['pooler_output'].size())
        # print(custom_loss)
        # sys.exit()

    encoder_hidden_states = encoder_outputs[0]

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        labels=labels,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        # use_cache=use_cache,
        # past_key_values=past_key_values,
        return_dict=return_dict,
    )

    if not return_dict:
        return decoder_outputs + encoder_outputs

    return Seq2SeqLMOutput(
        loss=decoder_outputs.loss,
        logits=decoder_outputs.logits,
        # closs=custom_loss,
        past_key_values=custom_loss,  # NOTE: custom_loss is returned through this field
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        # cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    return_dict=None,
    labels=None,
    latent_z=None,
):
    if decoder_input_ids is None:
        decoder_input_ids = shift_tokens_right(
            input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
        )

    return_dict = self.config.use_return_dict

    encoder_outputs = self.encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict=return_dict,
    )

    ###########################
    batch_size = input_ids.shape[0]
    # seq_repr = torch.cat((encoder_outputs[0][:, 0, :], encoder_outputs[0].mean(dim=1), encoder_outputs[0][:, -1, :]), -1).view(batch_size, -1)
    seq_repr = torch.cat(
        (encoder_outputs[0][:, 0, :], encoder_outputs[0][:, -1, :]), -1
    ).view(batch_size, -1)

    # Reparameterize
    mu = self.to_mu(seq_repr)
    logvar = self.to_logvar(seq_repr)
    z = self.reparameterize(mu, logvar)

    # # add noise
    # if self.word_dropout_rate > 0:
    #     # randomly replace decoder input with <unk>
    #     prob = torch.rand(decoder_input_ids.size())
    #     if torch.cuda.is_available():
    #         prob = prob.cuda()
    #     prob[(decoder_input_ids.data - self.config.pad_token_id) == 0] = 1
    #     decoder_input_sequence = decoder_input_ids.clone()
    #     decoder_input_sequence[prob < self.word_dropout_rate] = self.config.unk_token_id
    #     decoder_input_ids = decoder_input_sequence

    latent_z = self.to_emb(z)  # fit the embedding size
    ###########################

    # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_outputs[0],
        encoder_attention_mask=attention_mask,
        return_dict=return_dict,
        latent_z=latent_z,
    )

    lm_logits = self.lm_head(decoder_outputs[0]) + self.final_logits_bias

    masked_lm_loss = None
    if labels is None:
        labels = input_ids.clone()
    loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)

    # Reconstruction loss
    masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

    # KL divergence
    kl_loss = self.loss_kl(mu, logvar)

    loss_total = masked_lm_loss + self.config.lambda_kl * kl_loss

    if not return_dict:
        output = (lm_logits,) + decoder_outputs[1:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    return Seq2SeqLMOutput(
        loss=loss_total,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
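# Hedged sketch: self.reparameterize and self.loss_kl are referenced above but not shown.
# A standard VAE formulation of both (an assumption, not necessarily the author's code):
import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps the sample differentiable w.r.t. mu and logvar
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

def loss_kl(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ), summed over latent dims and averaged over the batch
    return torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1))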