def check_encoder_decoder_model(
    self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
):
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    self.assertTrue(enc_dec_model.config.decoder.is_decoder)
    self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
    self.assertTrue(enc_dec_model.config.is_encoder_decoder)
    enc_dec_model.to(torch_device)
    outputs_encoder_decoder = enc_dec_model(
        pixel_values=pixel_values,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        output_hidden_states=True,
    )
    self.assertEqual(
        outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
    )

    encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
    outputs_encoder_decoder = enc_dec_model(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
    )
    self.assertEqual(
        outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
    )
def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=False,
):
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None
    for i, layer_module in enumerate(self.layer):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        layer_head_mask = head_mask[i] if head_mask is not None else None

        if getattr(self.config, "gradient_checkpointing", False):

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs, output_attentions)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(layer_module),
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
            )
        else:
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions,
            )
        hidden_states = layer_outputs[0]
        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
    return BaseModelOutput(
        last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
    )
def forward(
    self,
    x,
    attn_mask=None,
    head_mask=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=None,
    encoder_history_states=None,
):  # docstyle-ignore
    """
    Parameters:
        x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
        attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.

    Returns:
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hidden states in the last (top) layer.
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer.
            Optional: only if output_attentions=True
    """
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None

    hidden_state = x
    for i, layer_module in enumerate(self.layer):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        history_state = None if encoder_history_states is None else encoder_history_states[i]
        layer_outputs = layer_module(
            x=hidden_state,
            attn_mask=attn_mask,
            head_mask=head_mask[i],
            output_attentions=output_attentions,
            history_state=history_state,
        )
        hidden_state = layer_outputs[-1]

        if output_attentions:
            assert len(layer_outputs) == 2
            attentions = layer_outputs[0]
            all_attentions = all_attentions + (attentions,)
        else:
            assert len(layer_outputs) == 1

    # Add last layer
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_state,)

    if not return_dict:
        return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
    return BaseModelOutput(
        last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
    )
def check_encoder_decoder_model(
    self,
    config,
    input_ids,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    **kwargs,
):
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    self.assertTrue(enc_dec_model.config.decoder.is_decoder)
    self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
    self.assertTrue(enc_dec_model.config.is_encoder_decoder)
    enc_dec_model.to(torch_device)
    outputs_encoder_decoder = enc_dec_model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        return_dict=True,
    )
    self.assertEqual(
        outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
    )
    self.assertEqual(
        outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
    )

    encoder_outputs = BaseModelOutput(last_hidden_state=encoder_hidden_states)
    outputs_encoder_decoder = enc_dec_model(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        return_dict=True,
    )
    self.assertEqual(
        outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
    )
    self.assertEqual(
        outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
    )
def check_model_with_encoder_outputs(
    self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
):
    self.assertIsNotNone(config.question_encoder)
    self.assertIsNotNone(config.generator)

    for model_class in self.all_model_classes:
        model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
        model.eval()
        self.assertTrue(model.config.is_encoder_decoder)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )

        encoder_outputs = BaseModelOutput(outputs.generator_enc_last_hidden_state)

        # run only generator
        outputs = model(
            encoder_outputs=encoder_outputs,
            doc_scores=outputs.doc_scores,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )

        # logits
        self.assertEqual(
            outputs.logits.shape,
            (self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
        )
        # generator encoder last hidden states
        self.assertEqual(
            outputs.generator_enc_last_hidden_state.shape,
            (self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
        )
        # doc scores
        self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
def generate(self, input_ids, attention_mask, max_length):
    self.encoder.n_passages = input_ids.size(1)
    kwargs = dict()
    kwargs["attention_mask"] = attention_mask.view(attention_mask.size(0), -1)
    updated_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
        input_ids.view(input_ids.size(0), -1), kwargs
    )
    base_encoder_outputs = BaseModelOutput()
    base_encoder_outputs["last_hidden_state"] = updated_kwargs["encoder_outputs"][0]
    return super().generate(
        input_ids=input_ids.view(input_ids.size(0), -1),
        attention_mask=attention_mask.view(attention_mask.size(0), -1),
        max_length=max_length,
        encoder_outputs=base_encoder_outputs,
        # past = ((updated_kwargs['encoder_outputs']), None)
    )
def forward(
    self,
    input_ids,
    attention_mask,
    inputs_embeds=None,
    head_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    # Run the exported encoder (e.g. an onnxruntime InferenceSession held in self.encoder)
    # and wrap its first output so downstream code can treat it like a regular encoder output.
    encoder_hidden_state = torch.from_numpy(
        self.encoder.run(
            None,
            {
                "input_ids": input_ids.cpu().numpy(),
                "attention_mask": attention_mask.cpu().numpy(),
            },
        )[0]
    )
    return BaseModelOutput(last_hidden_state=encoder_hidden_state)
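# The `self.encoder.run(None, {...})` call above matches the onnxruntime.InferenceSession API,
# so the wrapper presumably holds an exported encoder graph. A minimal, hedged sketch of how such
# a session might be created and queried follows; the file name "encoder.onnx" and the input names
# are illustrative assumptions, not part of the original code.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("encoder.onnx")  # hypothetical exported encoder

input_ids = np.ones((1, 8), dtype=np.int64)
attention_mask = np.ones((1, 8), dtype=np.int64)

# The first output is taken as the encoder's last hidden state, mirroring `run(...)[0]` above.
last_hidden_state = session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask})[0]
print(last_hidden_state.shape)  # expected: (batch, seq_len, hidden_size)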
def forward( self, input_ids=None, attention_mask=None, sentence_indicator=None, sentence_labels=None, decoder_input_ids=None, decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, decoder_inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode if needed (training, first prediction pass) if encoder_outputs is None: # Convert encoder inputs in embeddings if needed encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) hidden_states = encoder_outputs[0] # extract salient sentences if self.config.sequential_extraction: gumbel_output, all_sentence_logits = self.selection_loop( hidden_states, sentence_indicator, sentence_labels) else: gumbel_output, sentence_logits = self.single_extraction( hidden_states, sentence_indicator, sentence_labels) new_attention_mask = utils.convert_attention_mask( sentence_indicator, gumbel_output) masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states if self.model_parallel: torch.cuda.set_device(self.decoder.first_device) if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) if self.training: reconstruction_decoder_input_ids = self._shift_right(input_ids) # If decoding with past key value states, only the last tokens # should be given as an input if past_key_values is not None: assert labels is None, "Decoder should not use cached key value states when training." 
if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] if decoder_inputs_embeds is not None: decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] # Set device for model parallelism if self.model_parallel: torch.cuda.set_device(self.decoder.first_device) hidden_states = hidden_states.to(self.decoder.first_device) if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids.to( self.decoder.first_device) if attention_mask is not None: attention_mask = attention_mask.to(self.decoder.first_device) if decoder_attention_mask is not None: decoder_attention_mask = decoder_attention_mask.to( self.decoder.first_device) # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, past_key_values=past_key_values, encoder_hidden_states=masked_hidden_states, encoder_attention_mask=new_attention_mask, head_mask=decoder_head_mask, encoder_head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if self.training: summary = self.greedy_decode(input_ids, masked_hidden_states, new_attention_mask) encoded_summary = self.get_encoder()( summary, attention_mask=(summary != 0).long()) reconstruction_decoder_output = self.decoder( input_ids=reconstruction_decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, past_key_values=past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, encoder_head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = reconstruction_decoder_output[ 0] if self.training else decoder_outputs[0] # Set device for model parallelism if self.model_parallel: torch.cuda.set_device(self.encoder.first_device) self.lm_head = self.lm_head.to(self.encoder.first_device) self.sentence_classifier = self.sentence_classifier.to( self.encoder.first_device) sequence_output = sequence_output.to(self.lm_head.weight.device) if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) lm_logits = self.lm_head(sequence_output) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss(ignore_index=-100) labels = input_ids * attention_mask + (-100) * (1 - attention_mask) if self.training: loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) sim_loss_fct = nn.CosineSimilarity() pooled_hidden_states = hidden_states.mean( 1) if self.config.mean_pool_similarity else torch.max( hidden_states, 1)[0] pooled_encoded_summary = encoded_summary[0].mean( 1) if self.config.mean_pool_similarity else torch.max( encoded_summary[0], 1)[0] # pooled_encoded_summary = masked_hidden_states.mean(1) loss -= (sim_loss_fct(pooled_hidden_states, pooled_encoded_summary)).mean() else: loss = torch.tensor(0.).cuda() # sentence_loss_fct = nn.BCEWithLogitsLoss() # loss = 0 # if self.config.sequential_extraction: # sentence_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) # for i, logits in enumerate(all_sentence_logits): # loss += sentence_loss_fct(logits, sentence_labels[:, i]) # else: # sentence_label_one_hot = utils.convert_one_hot(sentence_labels, 
sentence_logits.size(1)).float().detach() # loss += 2 * -torch.mean(torch.sum( # sentence_label_one_hot * torch.log_softmax(sentence_logits.squeeze(-1), dim=-1), # dim=-1)) # loss += 2*sentence_loss_fct(sentence_logits.squeeze(-1)[sentence_mask], sentence_label_one_hot[sentence_mask]) # loss += 2*loss_fct(sentence_logits.view(-1, sentence_logits.size(-1)), sentence_label_one_hot.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs return ((loss, ) + output) if loss is not None else output return ExtractorAbstractorOutput( loss=loss, logits=lm_logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, extracted_attentions=new_attention_mask, gumbel_output=None if self.training else gumbel_output)
def forward( self, input_ids, attention_mask=None, decoder_input_ids=None, encoder_outputs: Optional[Tuple] = None, decoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): if "decoder_past_key_values" in kwargs: warnings.warn( "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", FutureWarning, ) past_key_values = kwargs.pop("decoder_past_key_values") if decoder_input_ids is None: use_cache = False output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # make masks if user doesn't supply if not use_cache: (decoder_input_ids, decoder_padding_mask, causal_mask,) = _prepare_meena_decoder_inputs( self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_padding_mask=decoder_attention_mask, causal_mask_dtype=self.shared.weight.dtype, ) else: decoder_padding_mask, causal_mask = None, None assert decoder_input_ids is not None if encoder_outputs is None: encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOuput when return_dict=False elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) decoder_outputs = self.decoder( decoder_input_ids, encoder_outputs[0], attention_mask, decoder_padding_mask, decoder_causal_mask=causal_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if not return_dict: return decoder_outputs + encoder_outputs return Seq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, )
def forward(
    self,
    input_ids,
    attention_mask=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=False,
):
    """
    Args:
        input_ids (LongTensor): tokens in the source language of shape `(batch, src_len)`
        attention_mask (torch.LongTensor): indicating which indices are padding tokens.

    Returns:
        BaseModelOutput or Tuple comprised of:

            - **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
            - **encoder_states** (tuple(torch.FloatTensor)): all intermediate hidden states of shape
              `(src_len, batch, embed_dim)`. Only populated if *output_hidden_states* is True.
            - **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
              During training might not be of length n_layers because of layer dropout.
    """
    # check attention mask and invert
    if attention_mask is not None:
        attention_mask = invert_mask(attention_mask)

    bsz, seq_len = input_ids.shape[:2]
    inputs_embeds = self.embed_tokens(input_ids)
    inputs_embeds = inputs_embeds * (self.embed_dim ** 0.5)
    positions = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
    embed_pos = self.embed_positions(positions)
    x = inputs_embeds + embed_pos
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    encoder_states = [] if output_hidden_states else None
    all_attentions = () if output_attentions else None
    for encoder_layer in self.layers:
        if output_hidden_states:
            encoder_states.append(x)
        x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
        if output_attentions:
            all_attentions = all_attentions + (attn,)

    if output_hidden_states:
        encoder_states.append(x)
        # T x B x C -> B x T x C
        encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if not return_dict:
        return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
    return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
def forward(
    self,
    input_ids=None,
    labels=None,
    attention_mask=None,
    encoder_outputs=None,
    decoder_input_ids=None,
    latent=None,
    use_cache=None,
    return_dict=True,
    **unused_kwargs,
):
    assert return_dict, "Need return_dict=True, using tuples is not implemented"
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    if input_ids is not None:
        if decoder_input_ids is not None and input_ids.equal(decoder_input_ids) is False:
            raise ValueError(
                "`input_ids` and `decoder_input_ids` do not match. Funnel-VAE can only reproduce its input sequence."
            )
        if self.config.prepend_eos_token:
            raise NotImplementedError()
        if attention_mask is None:
            attention_mask = input_ids.ne(self.transformer.config.pad_token_id).long()
        if encoder_outputs is None:
            encoder_outputs = self._get_encoder_outputs(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )

    if encoder_outputs is not None and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    vae_outputs = self.vae(
        input_encoding=encoder_outputs.last_hidden_state if encoder_outputs else None,
        latent=latent,
        global_step=self.global_step,
    )

    # TODO allow more options here
    if self.config.padding_input:
        upsampled_encoding = upsample(
            vae_outputs.reconstructed_encoding,
            stride=2 ** (len(self.config.transformer.block_sizes) - 1),
            target_len=self.config.transformer_decoder.n_positions,
            separate_cls=self.config.transformer.separate_cls,
            truncate_seq=self.config.transformer.truncate_seq,
        )
        if self.config.use_skip_connections:
            # TODO use skip connections like in the O.G. Funnel model
            raise NotImplementedError()
    else:
        upsampled_encoding = vae_outputs.reconstructed_encoding

    # Now using gpt2 decoder
    if labels is not None and decoder_input_ids is None:
        # get decoder inputs from shifting labels to the right
        decoder_input_ids = self._shift_right(input_ids)
        # use old attention mask shifted right
        attention_mask = torch.cat(
            (torch.ones(attention_mask.size(0), 1, device=attention_mask.device), attention_mask), 1
        )[:, : attention_mask.size(1) - 1]

    # TODO is this letting the model cheat by just looking at its labels?
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        encoder_hidden_states=upsampled_encoding,
        attention_mask=attention_mask,
        labels=labels,
        return_dict=True,
    )

    reg_loss_w = self._regulariser_loss_weight_schedule()
    loss = decoder_outputs.loss + vae_outputs.reg_loss * reg_loss_w

    if self.training and self.config.use_extra_logs:
        self._update_logs(
            decoder_ce=decoder_outputs.loss.item(),
            reg_loss=vae_outputs.reg_loss.item(),
            reg_loss_w=reg_loss_w,
        )

    return BaseTransformerVAE_Output(
        loss=loss,
        logits=decoder_outputs.logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state if encoder_outputs else None,
        encoder_hidden_states=encoder_outputs.hidden_states if encoder_outputs else None,
        encoder_attentions=encoder_outputs.attentions if encoder_outputs else None,
        latent=vae_outputs.latent,
        reg_loss=vae_outputs.reg_loss,
        decoder_ce=decoder_outputs.loss,
    )
def forward(
    self,
    input_ids=None,
    labels=None,
    attention_mask=None,
    encoder_outputs=None,
    decoder_input_ids=None,
    latent=None,
    use_cache=None,
    return_dict=True,
    **unused_kwargs,
):
    assert return_dict, "Need return_dict=True, using tuples is not implemented"
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    if input_ids is not None:
        if decoder_input_ids is not None and input_ids.equal(decoder_input_ids) is False:
            raise ValueError(
                "`input_ids` and `decoder_input_ids` do not match. Funnel-VAE can only reproduce its input sequence."
            )
        if self.config.prepend_eos_token:
            raise NotImplementedError()
        if attention_mask is None:
            attention_mask = input_ids.ne(self.transformer.config.pad_token_id).long()
        if encoder_outputs is None:
            encoder_outputs = self._get_encoder_outputs(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )

    if encoder_outputs is not None and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    vae_outputs = self.vae(
        input_encoding=encoder_outputs.last_hidden_state if encoder_outputs else None,
        latent=latent,
        global_step=self.global_step,
    )

    # TODO allow more options here, specifically allow an extra encoder block after upsampling
    if self.config.padding_input:
        upsampled_encoding = upsample(
            vae_outputs.reconstructed_encoding,
            stride=2 ** (len(self.config.transformer.block_sizes) - 1),
            target_len=self.config.transformer_decoder.n_positions,
            separate_cls=self.config.transformer.separate_cls,
            truncate_seq=self.config.transformer.truncate_seq,
        )
    else:
        upsampled_encoding = vae_outputs.reconstructed_encoding

    skip_conn_w = 0
    if encoder_outputs and self.config.use_skip_connection:
        skip_conn_w = self._skip_conn_schedule()
        upsampled_encoding += (
            skip_conn_w
            * encoder_outputs.hidden_states[self.config.transformer.block_sizes[0]][:, : upsampled_encoding.size(1)]
        )

    # Now using T5 decoder
    if labels is not None and decoder_input_ids is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels) if labels is not None else None

    decoder_outputs = self.transformer.decoder(
        input_ids=decoder_input_ids,
        encoder_hidden_states=upsampled_encoding,
        use_cache=use_cache,
        return_dict=True,
    )
    sequence_output = decoder_outputs.last_hidden_state
    # Rescale output before projecting on vocab
    # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
    sequence_output = sequence_output * (self.config.transformer.d_model ** -0.5)
    lm_logits = self.transformer.lm_head(sequence_output)

    decoder_ce = torch.tensor(0.0, device=lm_logits.device)
    seq_accuracy = torch.tensor(0.0, device=lm_logits.device)
    token_accuracy = torch.tensor(0.0, device=lm_logits.device)
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        decoder_ce = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        chosen_tokens = torch.argmax(lm_logits, 2)
        pad_tokens = (labels == -100).int()
        correct_tokens = (chosen_tokens == labels).int() + pad_tokens
        seq_accuracy = (torch.min(correct_tokens, dim=1).values.sum() / labels.size(0)).detach()
        num_pad_tokens = pad_tokens.sum()
        token_accuracy = ((correct_tokens.sum() - num_pad_tokens) / (labels.numel() - num_pad_tokens)).detach()

    reg_loss_w = self._regulariser_loss_weight_schedule()
    loss = decoder_ce + vae_outputs.reg_loss * reg_loss_w

    if self.training and self.config.use_extra_logs:
        self._update_logs(
            decoder_ce=decoder_ce.item(),
            seq_accuracy=seq_accuracy,
            token_accuracy=token_accuracy,
            reg_loss=vae_outputs.reg_loss.item(),
            reg_loss_w=reg_loss_w,
            skip_conn_w=skip_conn_w,
            latent_dropout=vae_outputs.latent_dropout,
        )

    return BaseTransformerVAE_Output(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state if encoder_outputs else None,
        encoder_hidden_states=encoder_outputs.hidden_states if encoder_outputs else None,
        encoder_attentions=encoder_outputs.attentions if encoder_outputs else None,
        latent=vae_outputs.latent,
        reg_loss=vae_outputs.reg_loss,
        decoder_ce=decoder_ce,
        seq_accuracy=seq_accuracy,
        token_accuracy=token_accuracy,
    )
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    head_mask=None,
    inputs_embeds=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
            provide it.

            Indices can be obtained using :class:`~transformers.BartTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
            for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
            representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
            into associated vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
            returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
            for more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

    embed_pos = self.embed_positions(input_shape)

    hidden_states = inputs_embeds + embed_pos
    hidden_states = self.layernorm_embedding(hidden_states)
    hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

    # expand attention_mask
    if attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

    encoder_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None

    # check if head_mask has a correct number of layers specified if desired
    if head_mask is not None:
        assert head_mask.size()[0] == (
            len(self.layers)
        ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

    for idx, encoder_layer in enumerate(self.layers):
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):  # skip the layer
            layer_outputs = (None, None)
        else:
            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)

    if output_hidden_states:
        encoder_states = encoder_states + (hidden_states,)

    if not return_dict:
        return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
    return BaseModelOutput(
        last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
    )
def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False, med=None, ): all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) # if getattr(self.config, "gradient_checkpointing", False): # # def create_custom_forward(module): # def custom_forward(*inputs): # return module(*inputs, output_attentions) # # return custom_forward # # layer_outputs = torch.utils.checkpoint.checkpoint( # create_custom_forward(layer_module), # hidden_states, # attention_mask, # head_mask[i], # encoder_hidden_states, # encoder_attention_mask, # ) # else: # -- attention_output, ff_output2, ff_output1, ret_attn_info, layer_outputs = layer_module( hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, output_attentions, ) # -- # zmod cur_output = ff_output2 # ff-output cur_layer_info = {"hid": cur_output} cur_layer_info.update(ret_attn_info) add_expr, early_exit = med.layer_end(cur_layer_info) # check if add_expr is not None: # reuse the original one to avoid extra parameters! # follow mad-x for the adapter architecture! cur_output = layer_module.output.LayerNorm(attention_output + ff_output1 + add_expr) # -- hidden_states = cur_output # hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1], ) # -- # zmod if early_exit: break # adaptive exit! # -- if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: return tuple( v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions)
def forward(
    self,
    input_ids,
    attention_mask=None,
    decoder_input_ids=None,
    encoder_outputs: Optional[Tuple] = None,
    decoder_attention_mask=None,
    decoder_past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_tuple=None,
    **kwargs,
):
    if decoder_input_ids is None:
        use_cache = False

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

    # make masks if user doesn't supply
    if not use_cache:
        decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
            self.config,
            input_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_padding_mask=decoder_attention_mask,
            causal_mask_dtype=self.shared.weight.dtype,
        )
    else:
        decoder_padding_mask, causal_mask = None, None

    # only unmask this position when a causal mask was actually built
    if causal_mask is not None:
        causal_mask[0, 1] = 0

    assert decoder_input_ids is not None

    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )
    # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_tuple=False
    elif not return_tuple and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    decoder_outputs = self.decoder(
        decoder_input_ids,
        encoder_outputs[0],
        attention_mask,
        decoder_padding_mask,
        decoder_causal_mask=causal_mask,
        decoder_past_key_values=decoder_past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_tuple=return_tuple,
    )

    if return_tuple:
        return decoder_outputs + encoder_outputs

    return Seq2SeqModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        decoder_past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
def forward( self, input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, past_key_values=None, head_mask=None, inputs_embeds=None, decoder_inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, task=None, task_embedding=None, **kwargs, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): Used to hide legacy arguments that have been deprecated. Returns: Examples:: >>> from transformers import T5Tokenizer, T5ForConditionalGeneration >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') >>> model = T5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True) >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids >>> outputs = model(input_ids=input_ids, labels=labels) >>> loss = outputs.loss >>> logits = outputs.logits >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids# Batch size 1 >>> outputs = model.generate(input_ids) """ if "lm_labels" in kwargs: warnings.warn( "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", FutureWarning, ) labels = kwargs.pop("lm_labels") if "decoder_past_key_value_states" in kwargs: warnings.warn( "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", FutureWarning, ) past_key_values = kwargs.pop("decoder_past_key_value_states") if "decoder_past_key_values" in kwargs: warnings.warn( "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", FutureWarning, ) past_key_values = kwargs.pop("decoder_past_key_values") assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode if needed (training, first prediction pass) if encoder_outputs is None: # Convert encoder inputs in embeddings if needed encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, task=task, task_embedding=self.task_embedding_controller(task) if self.train_adapters \ and isinstance(self.adapter_config, MetaAdapterConfig) else None ) elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) hidden_states = encoder_outputs[0] if self.fixed_length_emb: # Appends the attention mask for the projection of fixed length embeddings # to the attention mask of hidden states. 
if self.concat_projection_token: projection_length = 1 else: projection_length = self.config.projection_length attention_mask_projection = torch.ones(hidden_states.shape[0], projection_length, device=attention_mask.device, dtype=torch.long) if self.only_projection_bottleneck: attention_mask = attention_mask_projection else: attention_mask = torch.cat((attention_mask_projection, attention_mask), dim=1) if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) # If decoding with past key value states, only the last tokens # should be given as an input if past_key_values is not None: assert labels is None, "Decoder should not use cached key value states when training." if decoder_input_ids is not None: decoder_input_ids = decoder_input_ids[:, -1:] if decoder_inputs_embeds is not None: decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, past_key_values=past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, task=task, task_embedding=self.task_embedding_controller(task) \ if (self.train_adapters and isinstance(self.adapter_config, MetaAdapterConfig)) else None ) sequence_output = decoder_outputs[0] # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim ** -0.5) lm_logits = self.lm_head(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs return ((loss,) + output) if loss is not None else output return RuseSeq2SeqLMOutput( loss=loss, logits=lm_logits, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, pooled_enc_hidden_state=encoder_outputs.pooled_enc_hidden_state, )
def encoder_forward(self, fusion_map, input_ids, attention_mask, return_hidden_states=False): embed_dim = self.transformer.config.hidden_size batch_size = len(fusion_map) encoder_outputs = self.transformer.encoder( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=return_hidden_states, return_dict=True) encoder_hidden_states = encoder_outputs.last_hidden_state longest_fused_seq = max( [attention_mask[start:end].sum() for start, end in fusion_map]) encoder_fused_states = torch.zeros( (batch_size, longest_fused_seq, embed_dim), device=self.device) fused_attention_mask = torch.zeros((batch_size, longest_fused_seq), device=self.device) layer_fused_encoder_states = [] if return_hidden_states: encoder_layers_hidden_states = encoder_outputs.hidden_states layers = len(encoder_layers_hidden_states) encoder_layers_fused_states = torch.zeros( (batch_size, longest_fused_seq, layers, embed_dim), device=self.device) for (start, end), i in zip(fusion_map, range(batch_size)): encoder_layers_hidden_states = torch.einsum( 'ijkl->jkil', torch.stack(encoder_layers_hidden_states)) if isinstance( encoder_layers_hidden_states, tuple) else encoder_layers_hidden_states selected_states = encoder_layers_hidden_states[start:end] encoder_attention_mask = attention_mask[start:end].reshape( -1).to(torch.bool) flat_encoder_layer_states = selected_states.reshape( -1, layers, embed_dim)[encoder_attention_mask] encoder_layers_fused_states[ i, :flat_encoder_layer_states. shape[0]] = flat_encoder_layer_states fused_encoder_states = [] for (start, end), i in zip(fusion_map, range(batch_size)): selected_states = encoder_hidden_states[start:end] encoder_attention_mask = attention_mask[start:end].reshape(-1).to( torch.bool) flat_encoder_states = selected_states.reshape( -1, embed_dim)[encoder_attention_mask] encoder_fused_states[ i, :flat_encoder_states.shape[0]] = flat_encoder_states fused_attention_mask[i, :flat_encoder_states.shape[0]] = 1 encoder_outputs = BaseModelOutput( last_hidden_state=encoder_fused_states, hidden_states=encoder_layers_fused_states if return_hidden_states else None, attentions=fused_attention_mask) return encoder_outputs
def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False, layer_config=None, length_config=None, always_keep_cls_token=True, ): bsz, tsz, dim = hidden_states.size() if length_config is not None: restored_hidden_states = hidden_states remain_indices = torch.arange(tsz, device=hidden_states.device).unsqueeze(0).repeat(bsz, 1) all_hidden_states = () if output_hidden_states else None if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if layer_config is not None and i not in layer_config: continue layer_head_mask = head_mask[i] if head_mask is not None else None layer_output_length = length_config[i] if length_config is not None else None if getattr(self.config, "gradient_checkpointing", False): def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, output_attentions, layer_output_length, always_keep_cls_token) return custom_forward layer_outputs, keep_indices = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, ) else: layer_outputs, keep_indices = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions, output_length=layer_output_length, always_keep_cls_token=always_keep_cls_token, ) hidden_states = layer_outputs[0] if layer_output_length: remain_indices = remain_indices.gather(1, keep_indices) restored_hidden_states = restored_hidden_states.scatter(1, remain_indices.unsqueeze(-1).expand(-1, -1, dim), hidden_states) if attention_mask is not None: attention_mask = expand_gather(attention_mask, 3, keep_indices.unsqueeze(1).unsqueeze(2)) if attention_mask.size(2) > 1: attention_mask = expand_gather(attention_mask, 2, keep_indices.unsqueeze(1).unsqueeze(3)) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) last_hidden_state = restored_hidden_states if length_config is not None else hidden_states if not return_dict: return tuple(v for v in [last_hidden_state, all_hidden_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=last_hidden_state, hidden_states=all_hidden_states, attentions=all_attentions )
def __call__(self, *args, **kwargs):
    kwargs["return_dict"] = False
    res = super().__call__(*args, **kwargs)
    return BaseModelOutput(last_hidden_state=torch.tensor(res[0]))
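# Most of these snippets share one pattern: precomputed encoder hidden states are wrapped in a
# BaseModelOutput and handed back to a seq2seq model as `encoder_outputs`, so the encoder is not
# run a second time. A minimal, self-contained sketch of that pattern follows; the "t5-small"
# checkpoint and the prompt are illustrative assumptions, not taken from the original code.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import BaseModelOutput

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: How are you?", return_tensors="pt")

# Run the encoder once and keep its hidden states.
with torch.no_grad():
    encoder_hidden = model.get_encoder()(**inputs).last_hidden_state

# Wrap the precomputed states; generate() skips re-encoding when `encoder_outputs` is supplied.
encoder_outputs = BaseModelOutput(last_hidden_state=encoder_hidden)
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    encoder_outputs=encoder_outputs,
    max_length=20,
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))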
def forward( self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None, layer_config=None, length_config=None, always_keep_cls_token=True, ): """ Parameters ---------- hidden_states: torch.tensor(bs, seq_length, dim) Input sequence embedded. attention_mask: torch.tensor(bs, seq_length) Attention mask on the sequence. Outputs ------- hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hiddens states in the last (top) layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)] Tuple of length n_layers with the hidden states from each layer. Optional: only if output_hidden_states=True all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] Tuple of length n_layers with the attention weights from each layer Optional: only if output_attentions=True """ bsz, tsz, dim = hidden_states.size() if length_config is not None: restored_hidden_states = hidden_states remain_indices = torch.arange( tsz, device=hidden_states.device).unsqueeze(0).repeat(bsz, 1) all_hidden_states = () if output_hidden_states else None if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if layer_config is not None and i not in layer_config: continue layer_head_mask = head_mask[i] if head_mask is not None else None layer_output_length = length_config[ i] if length_config is not None else None layer_outputs, keep_indices = layer_module( hidden_states, attention_mask, layer_head_mask, output_attentions, output_length=layer_output_length, always_keep_cls_token=always_keep_cls_token, ) hidden_states = layer_outputs[-1] if layer_output_length: remain_indices = remain_indices.gather(1, keep_indices) restored_hidden_states = restored_hidden_states.scatter( 1, remain_indices.unsqueeze(-1).expand(-1, -1, dim), hidden_states) if attention_mask is not None: attention_mask = attention_mask.gather(1, keep_indices) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if output_attentions: assert len(layer_outputs) == 2 attentions = layer_outputs[0] all_attentions = all_attentions + (attentions, ) else: assert len(layer_outputs) == 1 last_hidden_state = restored_hidden_states if length_config is not None else hidden_states if not return_dict: return tuple( v for v in [last_hidden_state, all_hidden_states, all_attentions] if v is not None) return BaseModelOutput(last_hidden_state=last_hidden_state, hidden_states=all_hidden_states, attentions=all_attentions)
def forward(
    self,
    x,
    attn_mask=None,
    head_mask=None,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=None,
):  # docstyle-ignore
    """
    Parameters:
        x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
        attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.

    Returns:
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hidden states in the last (top) layer.
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer.
            Optional: only if output_attentions=True
    """
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None

    if self.training:
        inference_layers = []
        for i in range(self.scc_n_layer):
            if self.bernoulli.sample() == 1:  # REPLACE
                inference_layers.append(self.scc_layer[i])
            else:  # KEEP the original
                for offset in range(self.compress_ratio):
                    inference_layers.append(self.layer[i * self.compress_ratio + offset])
    else:  # inference with compressed model
        inference_layers = self.scc_layer

    hidden_state = x
    for i, layer_module in enumerate(inference_layers):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        layer_outputs = layer_module(
            x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
        )
        hidden_state = layer_outputs[-1]

        if output_attentions:
            assert len(layer_outputs) == 2
            attentions = layer_outputs[0]
            all_attentions = all_attentions + (attentions,)
        else:
            assert len(layer_outputs) == 1

    # Add last layer
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_state,)

    if not return_dict:
        return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
    return BaseModelOutput(
        last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
    )
def forward( self, input_ids=None, attention_mask=None, vis_inputs=None, vis_attention_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError( "You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input_shape) inputs_embeds = inputs_embeds + embed_pos B, L = inputs_embeds.size()[:-1] vis_feats = vis_inputs[0] boxes = vis_inputs[1] img_order_ids = None obj_order_ids = None if len(vis_inputs) >= 3: img_order_ids = vis_inputs[2] if len(vis_inputs) == 4: obj_order_ids = vis_inputs[3] vis_embeds = self.visual_embedding(vis_feats, boxes, img_order_ids, obj_order_ids) V_L = vis_embeds.size(1) if self.config.share_vis_lang_layer_norm: inputs_embeds = torch.cat([inputs_embeds, vis_embeds], dim=1) inputs_embeds = self.layernorm_embedding(inputs_embeds) else: inputs_embeds = self.layernorm_embedding(inputs_embeds) inputs_embeds = torch.cat([inputs_embeds, vis_embeds], dim=1) hidden_states = F.dropout(inputs_embeds, p=self.dropout, training=self.training) if attention_mask is None: attention_mask = input_ids.ne(self.config.pad_token_id).to( dtype=inputs_embeds.dtype, device=inputs_embeds.device) if vis_attention_mask is None: vis_attention_mask = torch.ones(B, V_L, dtype=inputs_embeds.dtype, device=inputs_embeds.device) # print('attention_mask, ', attention_mask.size()) # print('vis_attention_mask, ', vis_attention_mask.size()) attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) # print('ext_attention_mask, ', attention_mask.size()) # print('attention_mask') # print(attention_mask.size()) # print(attention_mask) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states, ) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: if getattr(self.config, "gradient_checkpointing", False): def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, output_attentions) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(encoder_layer), hidden_states, attention_mask, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + 
(layer_outputs[1], ) if output_hidden_states: encoder_states = encoder_states + (hidden_states, ) if not return_dict: return tuple( v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
def forward(
    self,
    input_ids=None,
    # real_input_ids=None,
    attention_mask=None,
    sentence_indicator=None,
    sentence_labels=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        # Convert encoder inputs in embeddings if needed
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    hidden_states = encoder_outputs[0]
    hidden_states_non_pad = attention_mask.unsqueeze(-1) * hidden_states

    # extract salient sentences
    if self.config.sequential_extraction:
        gumbel_output, all_sentence_logits = self.selection_loop(
            hidden_states, sentence_indicator, sentence_labels)
    else:
        gumbel_output, sentence_logits = self.single_extraction(
            hidden_states, sentence_indicator, sentence_labels)

    new_attention_mask = utils.convert_attention_mask(sentence_indicator, gumbel_output)
    masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)

    if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels)

    # If decoding with past key value states, only the last tokens
    # should be given as an input
    if past_key_values is not None:
        assert labels is None, "Decoder should not use cached key value states when training."
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)
        hidden_states = hidden_states.to(self.decoder.first_device)
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.decoder.first_device)
        if decoder_attention_mask is not None:
            decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

    if self.training:
        attention_mask = self.attention_dropout(attention_mask)
        hidden_states = hidden_states * attention_mask.unsqueeze(-1)

    extracted_sentence_encoding = self.encoder(
        input_ids=input_ids * new_attention_mask.long(),
        attention_mask=new_attention_mask)

    # if not self.training:
    #     if real_input_ids is None:
    #         real_input_ids = input_ids
    #     extracted_sentence_encoding = self.encoder(
    #         input_ids=real_input_ids * new_attention_mask.long(), attention_mask=new_attention_mask)
    #     hidden_states = extracted_sentence_encoding[0]
    #     attention_mask = new_attention_mask

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states if self.training else masked_hidden_states,
        encoder_attention_mask=attention_mask if self.training else new_attention_mask,
        head_mask=decoder_head_mask,
        encoder_head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = decoder_outputs[0]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.encoder.first_device)
        self.lm_head = self.lm_head.to(self.encoder.first_device)
        self.sentence_classifier = self.sentence_classifier.to(self.encoder.first_device)
        sequence_output = sequence_output.to(self.lm_head.weight.device)

    if self.config.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)

    lm_logits = self.lm_head(sequence_output)

    loss = None
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        sim_loss_fct = nn.CosineSimilarity()
        pooled_hidden_states = hidden_states_non_pad.mean(1)  # detach()?
        pooled_encoded_summary = masked_hidden_states.mean(1) if not self.training else (
            extracted_sentence_encoding[0] * new_attention_mask.unsqueeze(-1)).mean(1)
        loss -= 2 * sim_loss_fct(pooled_hidden_states, pooled_encoded_summary).mean()
        # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

    if not return_dict:
        output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
        return ((loss,) + output) if loss is not None else output

    return ExtractorAbstractorOutput(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
        extracted_attentions=new_attention_mask,
        gumbel_output=None if self.training else gumbel_output,
    )
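# `utils.convert_attention_mask(sentence_indicator, gumbel_output)` above turns the per-sentence
# selection produced by the Gumbel extractor into a per-token attention mask. Its definition is
# not part of this excerpt; the following is only a hypothetical sketch of that mapping (the
# tensor shapes and the gather-based implementation are assumptions, not the project's code):
def convert_attention_mask(sentence_indicator, gumbel_output):
    # sentence_indicator: (batch, seq_len) long tensor giving each token's sentence index
    # gumbel_output:      (batch, num_sentences) soft/one-hot selection weight per sentence
    # result:             (batch, seq_len) selection weight copied onto every token of its sentence
    return torch.gather(gumbel_output, 1, sentence_indicator)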
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    vis_inputs=None,
    vis_attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,
):
    # unlike other models, Bart automatically creates decoder_input_ids from
    # input_ids if no decoder_input_ids are provided
    if decoder_input_ids is None and decoder_inputs_embeds is None:
        decoder_input_ids = shift_tokens_right(
            input_ids, self.config.pad_token_id, self.config.decoder_start_token_id)

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            vis_inputs=vis_inputs,
            vis_attention_mask=vis_attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    if attention_mask is None:
        attention_mask = input_ids.ne(self.config.pad_token_id).to(
            dtype=torch.float, device=input_ids.device)
    if vis_attention_mask is None:
        B, L = attention_mask.size()
        V_L = encoder_outputs[0].size(1) - L
        vis_attention_mask = attention_mask.new_ones(B, V_L)
    encoder_attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1)

    # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_outputs[0],
        encoder_attention_mask=encoder_attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    if not return_dict:
        return decoder_outputs + encoder_outputs

    return Seq2SeqModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
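# The Bart-style `shift_tokens_right` used above to build decoder inputs is defined elsewhere.
# A minimal sketch, assuming the helper matches the one in transformers' Bart code (prepend the
# decoder start token, drop the last position, and replace any -100 label padding with pad_token_id):
def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    # replace possible -100 values in labels by pad_token_id
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids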
def forward(self, input_ids, attention_mask=None, output_attentions=False,
            output_hidden_states=False, return_tuple=False, visual=None):
    # check attention mask and invert
    if attention_mask is not None:
        attention_mask = invert_mask(attention_mask)

    inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
    if self.visual is not None:
        visual = self.visual
    # prepend the visual prefix to the token embeddings; the visual slots are padded with
    # zero ids so that only the text positions receive meaningful positional embeddings
    inputs_embeds = torch.cat([visual, inputs_embeds], dim=1)
    visual_zeros = torch.zeros(
        [visual.size()[0], visual.size()[1]],
        dtype=input_ids.dtype, device=input_ids.device)
    embed_pos = self.embed_positions(torch.cat([visual_zeros, input_ids], dim=1))
    x = inputs_embeds + embed_pos
    x = self.layernorm_embedding(x)
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    encoder_states = [] if output_hidden_states else None
    all_attentions = () if output_attentions else None
    for encoder_layer in self.layers:
        if output_hidden_states:
            encoder_states.append(x)
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):
            # skip the layer
            attn = None
        else:
            x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)

        if output_attentions:
            all_attentions = all_attentions + (attn,)

    if self.layer_norm:
        x = self.layer_norm(x)
    if output_hidden_states:
        encoder_states.append(x)
        # T x B x C -> B x T x C
        encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if return_tuple:
        return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
    return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
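# `invert_mask` above flips a {0, 1} padding mask into the boolean "positions to ignore" form
# expected by the older Bart attention layers. A minimal sketch, assuming the helper matches the
# one shipped in older versions of transformers' modeling_bart.py:
def invert_mask(attention_mask):
    """Turns 1 -> False and 0 -> True, i.e. True marks padding positions to be masked out."""
    assert attention_mask.dim() == 2
    return attention_mask.eq(0)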
def forward(self,
            input_ids=None,
            labels=None,
            attention_mask=None,
            encoder_outputs=None,
            decoder_input_ids=None,
            latent=None,
            use_cache=None,
            return_dict=True,
            **unused_kwargs):
    assert return_dict, "Need return_dict=True, returning tuples is not implemented."
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    if input_ids is not None:
        if self.config.prepend_eos_token:
            input_ids = self._shift_input_right(input_ids)
        if attention_mask is None:
            attention_mask = input_ids.ne(self.transformer.config.pad_token_id).long()
        if encoder_outputs is None:
            encoder_outputs = self.transformer.encoder(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True)

    if encoder_outputs is not None and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    vae_outputs = self.vae(
        input_encoding=encoder_outputs.last_hidden_state if encoder_outputs else None,
        latent=latent,
        global_step=self.global_step)

    if labels is not None and decoder_input_ids is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self.transformer._shift_right(labels) if labels is not None else None

    decoder_outputs = self.transformer.decoder(
        input_ids=decoder_input_ids,
        encoder_hidden_states=vae_outputs.reconstructed_encoding,
        use_cache=use_cache,
        return_dict=True,
    )

    sequence_output = decoder_outputs.last_hidden_state
    # Rescale output before projecting on vocab
    # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
    sequence_output = sequence_output * (self.config.transformer.d_model ** -0.5)
    lm_logits = self.transformer.lm_head(sequence_output)

    decoder_ce = torch.tensor(0.0, device=lm_logits.device)
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        decoder_ce = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

    reg_loss_w = self._regulariser_loss_weight_schedule()
    loss = decoder_ce + vae_outputs.reg_loss * reg_loss_w

    if self.training and self.config.use_extra_logs:
        self._update_logs(decoder_ce=decoder_ce.item(),
                          reg_loss=vae_outputs.reg_loss.item(),
                          reg_loss_w=reg_loss_w)

    return BaseTransformerVAE_Output(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state if encoder_outputs else None,
        encoder_hidden_states=encoder_outputs.hidden_states if encoder_outputs else None,
        encoder_attentions=encoder_outputs.attentions if encoder_outputs else None,
        latent=vae_outputs.latent,
        reg_loss=vae_outputs.reg_loss,
        decoder_ce=decoder_ce,
        accuracy=None,
    )
def forward(self,
            input_ids=None,
            labels=None,
            attention_mask=None,
            encoder_outputs=None,
            decoder_input_ids=None,
            latent=None,
            return_dict=True,
            **unused_kwargs):
    assert return_dict, "Need return_dict=True, returning tuples is not implemented."

    if input_ids is not None:
        if decoder_input_ids is not None and input_ids.equal(decoder_input_ids) is False:
            raise ValueError(
                "`input_ids` and `decoder_input_ids` do not match. Funnel-VAE can only reproduce its input sequence."
            )
        if self.config.prepend_eos_token:
            raise NotImplementedError()
        if attention_mask is None:
            attention_mask = input_ids.ne(self.transformer.config.pad_token_id).long()
        if encoder_outputs is None:
            encoder_outputs = self._get_encoder_outputs(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )

    if encoder_outputs is not None and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    vae_outputs = self.vae(
        input_encoding=encoder_outputs.last_hidden_state if encoder_outputs else None,
        latent=latent,
        global_step=self.global_step)

    initial_encoding_size = (
        vae_outputs.reconstructed_encoding.size(0),
        self.config.transformer.n_positions,
        self.config.transformer.d_model,
    )
    decoder_outputs = self.transformer.funnel.decoder(
        final_hidden=vae_outputs.reconstructed_encoding,
        # Don't allow for residual connections; instead just send an empty tensor.
        first_block_hidden=torch.zeros(
            initial_encoding_size,
            device=vae_outputs.reconstructed_encoding.device),
        return_dict=True,
    )

    last_hidden_state = decoder_outputs.last_hidden_state
    prediction_logits = self.transformer.lm_head(last_hidden_state)

    decoder_ce = torch.tensor(0.0, device=prediction_logits.device)
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
        decoder_ce = loss_fct(
            prediction_logits.view(-1, self.config.transformer.vocab_size),
            labels.view(-1))

    reg_loss_w = self._regulariser_loss_weight_schedule()
    loss = decoder_ce + vae_outputs.reg_loss * reg_loss_w

    if self.training and self.config.use_extra_logs:
        self._update_logs(decoder_ce=decoder_ce.item(),
                          reg_loss=vae_outputs.reg_loss.item(),
                          reg_loss_w=reg_loss_w)

    return BaseTransformerVAE_Output(
        loss=loss,
        logits=prediction_logits,
        past_key_values=None,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=None,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state if encoder_outputs else None,
        encoder_hidden_states=encoder_outputs.hidden_states if encoder_outputs else None,
        encoder_attentions=encoder_outputs.attentions if encoder_outputs else None,
        latent=vae_outputs.latent,
        reg_loss=vae_outputs.reg_loss,
        decoder_ce=decoder_ce,
    )
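# Both VAE forwards above weight the regulariser term by `self._regulariser_loss_weight_schedule()`,
# which is defined elsewhere in the source. A rough, hypothetical sketch only: the schedule shape
# and the config fields `reg_warmup_steps` / `reg_max_weight` are assumptions, not taken from this
# file; a simple linear warm-up over optimiser steps could look like this:
def _regulariser_loss_weight_schedule(self):
    # ramp the regulariser weight linearly from 0 to reg_max_weight over reg_warmup_steps steps
    if self.global_step is None:
        return 1.0
    progress = min(1.0, self.global_step / max(1, self.config.reg_warmup_steps))
    return progress * self.config.reg_max_weight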
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    encoder_outputs=None,
    vis_inputs=None,
    vis_attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    past_key_values=None,
    use_cache=None,
    labels=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    head_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    reduce_loss=False,
    return_hidden_state=False,
    **kwargs,
):
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            vis_inputs=vis_inputs,
            vis_attention_mask=vis_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    hidden_states = encoder_outputs[0]

    if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels)

    # If decoding with past key value states, only the last tokens
    # should be given as an input
    if past_key_values is not None:
        assert labels is None, "Decoder should not use cached key value states when training."
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]
        if decoder_inputs_embeds is not None:
            decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

    if attention_mask is None:
        attention_mask = input_ids.ne(self.config.pad_token_id).to(
            dtype=hidden_states.dtype, device=hidden_states.device)
    if vis_attention_mask is None:
        B, L = attention_mask.size()
        V_L = encoder_outputs[0].size(1) - L
        vis_attention_mask = attention_mask.new_ones(B, V_L)
    encoder_attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1)

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = decoder_outputs[0]

    assert self.config.tie_word_embeddings is True
    if self.config.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)

    if return_hidden_state:
        return sequence_output

    lm_logits = self.lm_head(sequence_output)

    loss = None
    if labels is not None:
        # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
        if reduce_loss:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
        else:
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

    # The plain-tuple return path is intentionally disabled here:
    # if not return_dict:
    #     output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
    #     return ((loss,) + output) if loss is not None else output

    return VLSeq2SeqLMOutput(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_last_hidden_state=decoder_outputs.last_hidden_state,
        decoder_hidden_states=decoder_outputs.hidden_states,
        # decoder_attentions=decoder_outputs.attentions,
        # encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        # encoder_hidden_states=encoder_outputs.hidden_states,
        # encoder_attentions=encoder_outputs.attentions,
        # vis_encoder_last_hidden_state=vis_encoder_outputs.last_hidden_state,
        # vis_encoder_hidden_states=vis_encoder_outputs.hidden_states,
        # vis_encoder_attentions=vis_encoder_outputs.attentions,
        # cross_encoder_outputs=cross_encoder_outputs
    )
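# Hypothetical call sketch for the forward above. The names `model`, `input_ids`, `vis_feats`,
# `boxes`, and `target_ids` are illustrative placeholders, not taken from the source; `vis_inputs`
# packs the visual stream as (region_features, boxes[, img_order_ids[, obj_order_ids]]), mirroring
# how the encoder unpacks it.
#
# outputs = model(
#     input_ids=input_ids,              # (B, L) token ids
#     vis_inputs=(vis_feats, boxes),    # (B, n_regions, feat_dim), (B, n_regions, 4)
#     labels=target_ids,                # (B, T) with -100 at positions to ignore
#     reduce_loss=True,                 # mean-reduced CrossEntropyLoss instead of per-token
# )
# loss, logits = outputs.loss, outputs.logits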