Example #1
    def check_encoder_decoder_model(
        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
    ):
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        self.assertTrue(enc_dec_model.config.decoder.is_decoder)
        self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
        self.assertTrue(enc_dec_model.config.is_encoder_decoder)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            output_hidden_states=True,
        )
        self.assertEqual(
            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
        )
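        # wrap the encoder's final hidden state so the model can be run a second time
        # from precomputed `encoder_outputs` instead of raw pixel_values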
        encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
        outputs_encoder_decoder = enc_dec_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(
            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
        )
Example #2
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )

            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1], )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states,
                               hidden_states=all_hidden_states,
                               attentions=all_attentions)
Example #3
    def forward(self,
                x,
                attn_mask=None,
                head_mask=None,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=None,
                encoder_history_states=None):  # docstyle-ignore
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.

        Returns:
            hidden_state: torch.tensor(bs, seq_length, dim)
                Sequence of hidden states in the last (top) layer.
            all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
                Tuple of length n_layers with the hidden states from each layer.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state, )

            history_state = (None if encoder_history_states is None
                             else encoder_history_states[i])
            layer_outputs = layer_module(x=hidden_state,
                                         attn_mask=attn_mask,
                                         head_mask=head_mask[i],
                                         output_attentions=output_attentions,
                                         history_state=history_state)
            hidden_state = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 2
                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions, )
            else:
                assert len(layer_outputs) == 1

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state, )

        if not return_dict:
            return tuple(
                v for v in [hidden_state, all_hidden_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_state,
                               hidden_states=all_hidden_states,
                               attentions=all_attentions)
Example #4
    def check_encoder_decoder_model(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs,
    ):
        encoder_model, decoder_model = self.get_encoder_decoder_model(
            config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                            decoder=decoder_model)
        self.assertTrue(enc_dec_model.config.decoder.is_decoder)
        self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
        self.assertTrue(enc_dec_model.config.is_encoder_decoder)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )
        self.assertEqual(outputs_encoder_decoder["logits"].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape,
            (input_ids.shape + (config.hidden_size, )))

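        # run the same model again, this time from precomputed encoder hidden states
        # rather than from input_ids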
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_hidden_states)
        outputs_encoder_decoder = enc_dec_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            return_dict=True,
        )

        self.assertEqual(outputs_encoder_decoder["logits"].shape,
                         (decoder_input_ids.shape +
                          (decoder_config.vocab_size, )))
        self.assertEqual(
            outputs_encoder_decoder["encoder_last_hidden_state"].shape,
            (input_ids.shape + (config.hidden_size, )))
Example #5
    def check_model_with_encoder_outputs(self, config, input_ids,
                                         attention_mask, decoder_input_ids,
                                         decoder_attention_mask, **kwargs):
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(
                config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            encoder_outputs = BaseModelOutput(
                outputs.generator_enc_last_hidden_state)

            # run only generator
            outputs = model(
                encoder_outputs=encoder_outputs,
                doc_scores=outputs.doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            # logits
            self.assertEqual(
                outputs.logits.shape,
                (self.n_docs * decoder_input_ids.shape[0],
                 decoder_input_ids.shape[1], config.generator.vocab_size),
            )
            # generator encoder last hidden states
            self.assertEqual(
                outputs.generator_enc_last_hidden_state.shape,
                (self.n_docs * decoder_input_ids.shape[0],
                 self.max_combined_length, config.generator.hidden_size),
            )
            # doc scores
            self.assertEqual(outputs.doc_scores.shape,
                             (input_ids.shape[0], self.n_docs))
Example #6
    def generate(self, input_ids, attention_mask, max_length):
        # inputs arrive as (batch, n_passages, seq_len); flatten the passages into one
        # long sequence per sample before encoding (fusion-in-decoder style)
        self.encoder.n_passages = input_ids.size(1)
        kwargs = dict()
        kwargs['attention_mask'] = attention_mask.view(attention_mask.size(0), -1)
        updated_kwargs = super()._prepare_encoder_decoder_kwargs_for_generation(
            input_ids.view(input_ids.size(0), -1), kwargs)
        base_encoder_outputs = BaseModelOutput()
        base_encoder_outputs['last_hidden_state'] = updated_kwargs['encoder_outputs'][0]
        return super().generate(
            input_ids=input_ids.view(input_ids.size(0), -1),
            attention_mask=attention_mask.view(attention_mask.size(0), -1),
            max_length=max_length,
            encoder_outputs=base_encoder_outputs,
            # past = ((updated_kwargs['encoder_outputs']), None)
        )
Example #7
    def forward(
        self,
        input_ids,
        attention_mask,
        inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

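        # the encoder here appears to be an ONNX Runtime session: run it on numpy
        # inputs and wrap the result back into a torch tensor inside a BaseModelOutput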
        encoder_hidden_state = torch.from_numpy(
            self.encoder.run(
                None, {
                    "input_ids": input_ids.cpu().numpy(),
                    "attention_mask": attention_mask.cpu().numpy()
                })[0])

        return BaseModelOutput(encoder_hidden_state)
Example #8
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        sentence_indicator=None,
        sentence_labels=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # extract salient sentences
        if self.config.sequential_extraction:
            gumbel_output, all_sentence_logits = self.selection_loop(
                hidden_states, sentence_indicator, sentence_labels)
        else:
            gumbel_output, sentence_logits = self.single_extraction(
                hidden_states, sentence_indicator, sentence_labels)

        new_attention_mask = utils.convert_attention_mask(
            sentence_indicator, gumbel_output)
        masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        if self.training:
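            # during training, also prepare decoder inputs for reconstructing the original input sequence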
            reconstruction_decoder_input_ids = self._shift_right(input_ids)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(
                    self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(
                    self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=masked_hidden_states,
            encoder_attention_mask=new_attention_mask,
            head_mask=decoder_head_mask,
            encoder_head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.training:
            summary = self.greedy_decode(input_ids, masked_hidden_states,
                                         new_attention_mask)
            encoded_summary = self.get_encoder()(
                summary, attention_mask=(summary != 0).long())

            reconstruction_decoder_output = self.decoder(
                input_ids=reconstruction_decoder_input_ids,
                attention_mask=decoder_attention_mask,
                inputs_embeds=decoder_inputs_embeds,
                past_key_values=past_key_values,
                encoder_hidden_states=hidden_states,
                encoder_attention_mask=attention_mask,
                head_mask=decoder_head_mask,
                encoder_head_mask=head_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        sequence_output = reconstruction_decoder_output[0] if self.training else decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            self.sentence_classifier = self.sentence_classifier.to(
                self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            labels = input_ids * attention_mask + (-100) * (1 - attention_mask)

            if self.training:
                loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                                labels.view(-1))

                sim_loss_fct = nn.CosineSimilarity()

                pooled_hidden_states = hidden_states.mean(
                    1) if self.config.mean_pool_similarity else torch.max(
                        hidden_states, 1)[0]
                pooled_encoded_summary = encoded_summary[0].mean(
                    1) if self.config.mean_pool_similarity else torch.max(
                        encoded_summary[0], 1)[0]

                #                pooled_encoded_summary = masked_hidden_states.mean(1)
                loss -= (sim_loss_fct(pooled_hidden_states,
                                      pooled_encoded_summary)).mean()
            else:
                loss = torch.tensor(0.).cuda()

#            sentence_loss_fct = nn.BCEWithLogitsLoss()
#            loss = 0

#            if self.config.sequential_extraction:
#                sentence_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
#                for i, logits in enumerate(all_sentence_logits):
#                    loss += sentence_loss_fct(logits, sentence_labels[:, i])
#            else:
#                sentence_label_one_hot = utils.convert_one_hot(sentence_labels, sentence_logits.size(1)).float().detach()
#                loss += 2 * -torch.mean(torch.sum(
#                    sentence_label_one_hot * torch.log_softmax(sentence_logits.squeeze(-1), dim=-1),
#                    dim=-1))
#               loss += 2*sentence_loss_fct(sentence_logits.squeeze(-1)[sentence_mask], sentence_label_one_hot[sentence_mask])
#               loss += 2*loss_fct(sentence_logits.view(-1, sentence_logits.size(-1)), sentence_label_one_hot.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs
            return ((loss, ) + output) if loss is not None else output

        return ExtractorAbstractorOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            extracted_attentions=new_attention_mask,
            gumbel_output=None if self.training else gumbel_output)
Example #9
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        encoder_outputs: Optional[Tuple] = None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")

        if decoder_input_ids is None:
            use_cache = False

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # make masks if user doesn't supply
        if not use_cache:
            (decoder_input_ids, decoder_padding_mask, causal_mask,) = _prepare_meena_decoder_inputs(
                self.config,
                input_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_padding_mask=decoder_attention_mask,
                causal_mask_dtype=self.shared.weight.dtype,
            )
        else:
            decoder_padding_mask, causal_mask = None, None

        assert decoder_input_ids is not None

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids,
            encoder_outputs[0],
            attention_mask,
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
Example #10
    def forward(
        self,
        input_ids,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """
        Args:
            input_ids (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            attention_mask (torch.LongTensor): indicating which indices are padding tokens.
        Returns:
            BaseModelOutput or Tuple comprised of:
                - **x** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_states** (tuple(torch.FloatTensor)): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *output_hidden_states:* is True.
                - **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
                During training might not be of length n_layers because of layer dropout.
        """
        # check attention mask and invert
        if attention_mask is not None:
            attention_mask = invert_mask(attention_mask)

        bsz, seq_len = input_ids.shape[:2]

        inputs_embeds = self.embed_tokens(input_ids)
        inputs_embeds = inputs_embeds * (self.embed_dim ** 0.5)

        positions = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        embed_pos = self.embed_positions(positions)

        x = inputs_embeds + embed_pos
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = [] if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states.append(x)

            x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)

            if output_attentions:
                all_attentions = all_attentions + (attn,)

        if output_hidden_states:
            encoder_states.append(x)
            # T x B x C -> B x T x C
            encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if not return_dict:
            return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
Example #11
    def forward(self,
                input_ids=None,
                labels=None,
                attention_mask=None,
                encoder_outputs=None,
                decoder_input_ids=None,
                latent=None,
                use_cache=None,
                return_dict=True,
                **unused_kwargs):
        assert return_dict, "Need return_dict=True, returning tuples is not implemented"
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None:
            if decoder_input_ids is not None and input_ids.equal(
                    decoder_input_ids) is False:
                raise ValueError(
                    "`input_ids` and `decoder_input_ids` do not match. Funnel-VAE can only reproduce its input sequence."
                )
            if self.config.prepend_eos_token:
                raise NotImplementedError()
            if attention_mask is None:
                attention_mask = input_ids.ne(
                    self.transformer.config.pad_token_id).long()
            if encoder_outputs is None:
                encoder_outputs = self._get_encoder_outputs(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True,
                )
        if encoder_outputs is not None and not isinstance(
                encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        vae_outputs = self.vae(input_encoding=encoder_outputs.last_hidden_state
                               if encoder_outputs else None,
                               latent=latent,
                               global_step=self.global_step)

        # TODO allow more options here
        if self.config.padding_input:
            upsampled_encoding = upsample(
                vae_outputs.reconstructed_encoding,
                stride=2**(len(self.config.transformer.block_sizes) - 1),
                target_len=self.config.transformer_decoder.n_positions,
                separate_cls=self.config.transformer.separate_cls,
                truncate_seq=self.config.transformer.truncate_seq,
            )
            if self.config.use_skip_connections:
                # TODO use skip connections like in the O.G. Funnel model
                raise NotImplementedError()
        else:
            upsampled_encoding = vae_outputs.reconstructed_encoding

        # Now using gpt2 decoder

        if labels is not None and decoder_input_ids is None:
            # get decoder inputs from shifting labels to the right
            decoder_input_ids = self._shift_right(input_ids)
            # use old attention mask shifted right
            attention_mask = torch.cat(
                (torch.ones(attention_mask.size(0),
                            1,
                            device=attention_mask.device), attention_mask),
                1)[:, :attention_mask.size(1) - 1]

        # TODO is this letting the model cheat by just looking at its labels?
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=upsampled_encoding,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True)

        reg_loss_w = self._regulariser_loss_weight_schedule()
        loss = decoder_outputs.loss + vae_outputs.reg_loss * reg_loss_w

        if self.training and self.config.use_extra_logs:
            self._update_logs(decoder_ce=decoder_outputs.loss.item(),
                              reg_loss=vae_outputs.reg_loss.item(),
                              reg_loss_w=reg_loss_w)

        return BaseTransformerVAE_Output(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state
            if encoder_outputs else None,
            encoder_hidden_states=encoder_outputs.hidden_states
            if encoder_outputs else None,
            encoder_attentions=encoder_outputs.attentions
            if encoder_outputs else None,
            latent=vae_outputs.latent,
            reg_loss=vae_outputs.reg_loss,
            decoder_ce=decoder_outputs.loss,
        )
Example #12
    def forward(self,
                input_ids=None,
                labels=None,
                attention_mask=None,
                encoder_outputs=None,
                decoder_input_ids=None,
                latent=None,
                use_cache=None,
                return_dict=True,
                **unused_kwargs):
        assert return_dict, "Need return_dict=True, returning tuples is not implemented"
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None:
            if decoder_input_ids is not None and input_ids.equal(
                    decoder_input_ids) is False:
                raise ValueError(
                    "`input_ids` and `decoder_input_ids` do not match. Funnel-VAE can only reproduce its input sequence."
                )
            if self.config.prepend_eos_token:
                raise NotImplementedError()
            if attention_mask is None:
                attention_mask = input_ids.ne(
                    self.transformer.config.pad_token_id).long()
            if encoder_outputs is None:
                encoder_outputs = self._get_encoder_outputs(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True,
                )
        if encoder_outputs is not None and not isinstance(
                encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        vae_outputs = self.vae(input_encoding=encoder_outputs.last_hidden_state
                               if encoder_outputs else None,
                               latent=latent,
                               global_step=self.global_step)

        # TODO allow more options here, specifically allow an extra encoder block after upsampling
        if self.config.padding_input:
            upsampled_encoding = upsample(
                vae_outputs.reconstructed_encoding,
                stride=2**(len(self.config.transformer.block_sizes) - 1),
                target_len=self.config.transformer_decoder.n_positions,
                separate_cls=self.config.transformer.separate_cls,
                truncate_seq=self.config.transformer.truncate_seq,
            )
        else:
            upsampled_encoding = vae_outputs.reconstructed_encoding

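        # optionally mix a scheduled skip connection from an early encoder hidden state
        # into the upsampled encoding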
        skip_conn_w = 0
        if encoder_outputs and self.config.use_skip_connection:
            skip_conn_w = self._skip_conn_schedule()
            upsampled_encoding += skip_conn_w * encoder_outputs.hidden_states[
                self.config.transformer.block_sizes[0]][:, :upsampled_encoding.size(1)]

        # Now using T5 decoder

        if labels is not None and decoder_input_ids is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        decoder_outputs = self.transformer.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=upsampled_encoding,
            use_cache=use_cache,
            return_dict=True)

        sequence_output = decoder_outputs.last_hidden_state
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.config.transformer.d_model**
                                             -0.5)
        lm_logits = self.transformer.lm_head(sequence_output)

        decoder_ce = torch.tensor(0.0, device=lm_logits.device)
        seq_accuracy = torch.tensor(0.0, device=lm_logits.device)
        token_accuracy = torch.tensor(0.0, device=lm_logits.device)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            decoder_ce = loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                                  labels.view(-1))
            chosen_tokens = torch.argmax(lm_logits, 2)
            pad_tokens = (labels == -100).int()
            correct_tokens = (chosen_tokens == labels).int() + pad_tokens
            seq_accuracy = (torch.min(correct_tokens, dim=1).values.sum() /
                            labels.size(0)).detach()
            num_pad_tokens = pad_tokens.sum()
            token_accuracy = ((correct_tokens.sum() - num_pad_tokens) /
                              (labels.numel() - num_pad_tokens)).detach()

        reg_loss_w = self._regulariser_loss_weight_schedule()
        loss = decoder_ce + vae_outputs.reg_loss * reg_loss_w

        if self.training and self.config.use_extra_logs:
            self._update_logs(decoder_ce=decoder_ce.item(),
                              seq_accuracy=seq_accuracy,
                              token_accuracy=token_accuracy,
                              reg_loss=vae_outputs.reg_loss.item(),
                              reg_loss_w=reg_loss_w,
                              skip_conn_w=skip_conn_w,
                              latent_dropout=vae_outputs.latent_dropout)

        return BaseTransformerVAE_Output(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state
            if encoder_outputs else None,
            encoder_hidden_states=encoder_outputs.hidden_states
            if encoder_outputs else None,
            encoder_attentions=encoder_outputs.attentions
            if encoder_outputs else None,
            latent=vae_outputs.latent,
            reg_loss=vae_outputs.reg_loss,
            decoder_ce=decoder_ce,
            seq_accuracy=seq_accuracy,
            token_accuracy=token_accuracy)
Example #13
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using :class:`~transformers.BartTokenizer`. See
                :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
                for details.

                `What are input IDs? <../glossary.html#input-ids>`__
            attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                `What are attention masks? <../glossary.html#attention-mask>`__
            head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
                Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
                Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
                representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
                into associated vectors than the model's internal embedding lookup matrix.
            output_attentions (:obj:`bool`, `optional`):
                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
                returned tensors for more detail.
            output_hidden_states (:obj:`bool`, `optional`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
                for more detail.
            return_dict (:obj:`bool`, `optional`):
                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input_shape)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = F.dropout(hidden_states,
                                  p=self.dropout,
                                  training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states, )
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability <
                                  self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                if getattr(self.config, "gradient_checkpointing",
                           False) and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(head_mask[idx]
                                         if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1], )

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states, )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, encoder_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states,
                               hidden_states=encoder_states,
                               attentions=all_attentions)
Example #14
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
        med=None,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            # if getattr(self.config, "gradient_checkpointing", False):
            #
            #     def create_custom_forward(module):
            #         def custom_forward(*inputs):
            #             return module(*inputs, output_attentions)
            #
            #         return custom_forward
            #
            #     layer_outputs = torch.utils.checkpoint.checkpoint(
            #         create_custom_forward(layer_module),
            #         hidden_states,
            #         attention_mask,
            #         head_mask[i],
            #         encoder_hidden_states,
            #         encoder_attention_mask,
            #     )
            # else:
            # --
            attention_output, ff_output2, ff_output1, ret_attn_info, layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions,
            )
            # --
            # zmod
            cur_output = ff_output2  # ff-output
            cur_layer_info = {"hid": cur_output}
            cur_layer_info.update(ret_attn_info)
            add_expr, early_exit = med.layer_end(cur_layer_info)  # check
            if add_expr is not None:  # reuse the original one to avoid extra parameters!
                # follow mad-x for the adapter architecture!
                cur_output = layer_module.output.LayerNorm(attention_output +
                                                           ff_output1 +
                                                           add_expr)
            # --

            hidden_states = cur_output
            # hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1], )

            # --
            # zmod
            if early_exit:
                break  # adaptive exit!
            # --

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states,
                               hidden_states=all_hidden_states,
                               attentions=all_attentions)
Example #15
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        encoder_outputs: Optional[Tuple] = None,
        decoder_attention_mask=None,
        decoder_past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
        **kwargs,
    ):

        if decoder_input_ids is None:
            use_cache = False

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        # make masks if user doesn't supply
        if not use_cache:
            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                self.config,
                input_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_padding_mask=decoder_attention_mask,
                causal_mask_dtype=self.shared.weight.dtype,
            )
        else:
            decoder_padding_mask, causal_mask = None, None

        # only applicable when a causal mask was actually built (the use_cache=False path)
        if causal_mask is not None:
            causal_mask[0, 1] = 0

        assert decoder_input_ids is not None

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_tuple=return_tuple,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_tuple=False
        elif not return_tuple and not isinstance(encoder_outputs,
                                                 BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids,
            encoder_outputs[0],
            attention_mask,
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            decoder_past_key_values=decoder_past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        if return_tuple:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
Example #16
    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=None,
            decoder_attention_mask=None,
            encoder_outputs=None,
            past_key_values=None,
            head_mask=None,
            inputs_embeds=None,
            decoder_inputs_embeds=None,
            labels=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            task=None,
            task_embedding=None,
            **kwargs,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`,
        `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to ``-100`` are ignored
            (masked), the loss is only computed for
            labels in ``[0, ..., config.vocab_size]``
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Examples::

            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small',
            return_dict=True)

            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1>
            park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the
            <extra_id_2> </s>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits

            >>> input_ids = tokenizer("summarize: studies have shown that owning
            a dog is good for you ", return_tensors="pt").input_ids# Batch size 1
            >>> outputs = model.generate(input_ids)
        """
        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("lm_labels")
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                task=task,
                task_embedding=self.task_embedding_controller(task) if self.train_adapters \
                                                                       and isinstance(self.adapter_config,
                                                                                      MetaAdapterConfig) else None
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
        hidden_states = encoder_outputs[0]
        if self.fixed_length_emb:
            # Appends the attention mask for the projection of fixed length embeddings
            # to the attention mask of hidden states.
            if self.concat_projection_token:
                projection_length = 1
            else:
                projection_length = self.config.projection_length
            attention_mask_projection = torch.ones(hidden_states.shape[0],
                                                   projection_length, device=attention_mask.device, dtype=torch.long)
            if self.only_projection_bottleneck:
                attention_mask = attention_mask_projection
            else:
                attention_mask = torch.cat((attention_mask_projection,
                                            attention_mask), dim=1)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            task=task,
            task_embedding=self.task_embedding_controller(task)
            if self.train_adapters and isinstance(self.adapter_config, MetaAdapterConfig)
            else None,
        )

        sequence_output = decoder_outputs[0]
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)
        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return RuseSeq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            pooled_enc_hidden_state=encoder_outputs.pooled_enc_hidden_state,
        )
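A note on the encoder_outputs shortcut above: because the encoder only runs when encoder_outputs is None, a caller can encode an input once and reuse the result across any number of decoder passes. The sketch below illustrates this with a stock t5-small checkpoint (the adapter-specific arguments such as task and task_embedding do not exist there); the model name, prompt and tolerance are placeholders, not part of the original code.

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

enc = tokenizer("summarize: studies have shown that owning a dog is good for you",
                return_tensors="pt")
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.no_grad():
    # First pass: encoder_outputs is None, so the encoder runs.
    first = model(input_ids=enc.input_ids,
                  attention_mask=enc.attention_mask,
                  decoder_input_ids=decoder_input_ids)
    # Second pass: reuse the encoder states; only the decoder runs.
    second = model(encoder_outputs=(first.encoder_last_hidden_state,),
                   attention_mask=enc.attention_mask,
                   decoder_input_ids=decoder_input_ids)

assert torch.allclose(first.logits, second.logits, atol=1e-5)

Recent transformers versions apply the same shortcut inside generate(): when encoder_outputs is already present in the keyword arguments, the encoder is not run again.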
Exemple #17
0
    def encoder_forward(self,
                        fusion_map,
                        input_ids,
                        attention_mask,
                        return_hidden_states=False):
        embed_dim = self.transformer.config.hidden_size
        batch_size = len(fusion_map)
        encoder_outputs = self.transformer.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=return_hidden_states,
            return_dict=True)
        encoder_hidden_states = encoder_outputs.last_hidden_state

        longest_fused_seq = max(
            [attention_mask[start:end].sum() for start, end in fusion_map])
        encoder_fused_states = torch.zeros(
            (batch_size, longest_fused_seq, embed_dim), device=self.device)
        fused_attention_mask = torch.zeros((batch_size, longest_fused_seq),
                                           device=self.device)

        layer_fused_encoder_states = []
        if return_hidden_states:
            encoder_layers_hidden_states = encoder_outputs.hidden_states
            layers = len(encoder_layers_hidden_states)
            encoder_layers_fused_states = torch.zeros(
                (batch_size, longest_fused_seq, layers, embed_dim),
                device=self.device)
            for (start, end), i in zip(fusion_map, range(batch_size)):
                encoder_layers_hidden_states = torch.einsum(
                    'ijkl->jkil',
                    torch.stack(encoder_layers_hidden_states)) if isinstance(
                        encoder_layers_hidden_states,
                        tuple) else encoder_layers_hidden_states
                selected_states = encoder_layers_hidden_states[start:end]

                encoder_attention_mask = attention_mask[start:end].reshape(
                    -1).to(torch.bool)

                flat_encoder_layer_states = selected_states.reshape(
                    -1, layers, embed_dim)[encoder_attention_mask]
                encoder_layers_fused_states[
                    i, :flat_encoder_layer_states.
                    shape[0]] = flat_encoder_layer_states

        fused_encoder_states = []
        for (start, end), i in zip(fusion_map, range(batch_size)):
            selected_states = encoder_hidden_states[start:end]
            encoder_attention_mask = attention_mask[start:end].reshape(-1).to(
                torch.bool)
            flat_encoder_states = selected_states.reshape(
                -1, embed_dim)[encoder_attention_mask]

            encoder_fused_states[
                i, :flat_encoder_states.shape[0]] = flat_encoder_states
            fused_attention_mask[i, :flat_encoder_states.shape[0]] = 1

        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_fused_states,
            hidden_states=encoder_layers_fused_states
            if return_hidden_states else None,
            attentions=fused_attention_mask)
        return encoder_outputs
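Worth flagging in the snippet above: the fused padding mask is returned in the attentions field of BaseModelOutput, which normally carries per-layer attention weights, so downstream code has to know about this repurposing. A hedged sketch of how such an output might be consumed; model.transformer is assumed to be the same encoder-decoder used for encoding, and the helper name is illustrative.

from transformers.modeling_outputs import BaseModelOutput

def generate_from_fused(model, fused_outputs, **gen_kwargs):
    # Recover the padding mask that encoder_forward stashed in `attentions`
    # so the decoder's cross-attention ignores the zero-padded positions.
    fused_attention_mask = fused_outputs.attentions
    encoder_outputs = BaseModelOutput(
        last_hidden_state=fused_outputs.last_hidden_state)
    # Recent transformers versions skip re-encoding when encoder_outputs
    # is already supplied to generate().
    return model.transformer.generate(
        encoder_outputs=encoder_outputs,
        attention_mask=fused_attention_mask,
        **gen_kwargs,
    )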
Exemple #18
0
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
        layer_config=None,
        length_config=None,
        always_keep_cls_token=True,
    ):
        bsz, tsz, dim = hidden_states.size()

        if length_config is not None:
            restored_hidden_states = hidden_states
            remain_indices = torch.arange(tsz, device=hidden_states.device).unsqueeze(0).repeat(bsz, 1)

        all_hidden_states = () if output_hidden_states else None
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if layer_config is not None and i not in layer_config:
                continue

            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_output_length = length_config[i] if length_config is not None else None

            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions, layer_output_length, always_keep_cls_token)

                    return custom_forward

                layer_outputs, keep_indices = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs, keep_indices = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                    output_length=layer_output_length,
                    always_keep_cls_token=always_keep_cls_token,
                )
            hidden_states = layer_outputs[0]

            if layer_output_length:
                remain_indices = remain_indices.gather(1, keep_indices)
                restored_hidden_states = restored_hidden_states.scatter(1, remain_indices.unsqueeze(-1).expand(-1, -1, dim), hidden_states)

                if attention_mask is not None:
                    attention_mask = expand_gather(attention_mask, 3, keep_indices.unsqueeze(1).unsqueeze(2))
                    if attention_mask.size(2) > 1:
                        attention_mask = expand_gather(attention_mask, 2, keep_indices.unsqueeze(1).unsqueeze(3))

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        last_hidden_state = restored_hidden_states if length_config is not None else hidden_states
        if not return_dict:
            return tuple(v for v in [last_hidden_state, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )
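The create_custom_forward closure in the checkpointed branch above is the usual workaround for torch.utils.checkpoint.checkpoint forwarding only positional tensor arguments: non-tensor flags such as output_attentions, layer_output_length and always_keep_cls_token are baked into the closure instead. A stripped-down, self-contained sketch of the same pattern (toy module, not the model above):

import torch
import torch.nn as nn
import torch.utils.checkpoint

class CheckpointedStack(nn.Module):
    # Toy layer stack that recomputes activations in the backward pass.
    def __init__(self, num_layers=4, dim=32):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, hidden_states, scale=1.0):
        for layer in self.layers:
            if self.training:
                # Bake the non-tensor argument `scale` into the closure so
                # checkpoint() only ever receives tensor inputs.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs) * scale
                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer), hidden_states)
            else:
                hidden_states = layer(hidden_states) * scale
        return hidden_states

stack = CheckpointedStack().train()
out = stack(torch.randn(2, 8, 32, requires_grad=True))
out.sum().backward()  # gradients flow normally; activations are recomputed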
Exemple #19
0
    def __call__(self, *args, **kwargs):
        kwargs["return_dict"] = False
        res = super().__call__(*args, **kwargs)
        return BaseModelOutput(last_hidden_state=torch.tensor(res[0]))
Exemple #20
0
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=None,
        layer_config=None,
        length_config=None,
        always_keep_cls_token=True,
    ):
        """
        Parameters
        ----------
        hidden_states: torch.tensor(bs, seq_length, dim)
            Input sequence embedded.
        attention_mask: torch.tensor(bs, seq_length)
            Attention mask on the sequence.

        Outputs
        -------
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hiddens states in the last (top) layer
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        bsz, tsz, dim = hidden_states.size()

        if length_config is not None:
            restored_hidden_states = hidden_states
            remain_indices = torch.arange(
                tsz, device=hidden_states.device).unsqueeze(0).repeat(bsz, 1)

        all_hidden_states = () if output_hidden_states else None
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if layer_config is not None and i not in layer_config:
                continue

            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_output_length = length_config[
                i] if length_config is not None else None

            layer_outputs, keep_indices = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                output_attentions,
                output_length=layer_output_length,
                always_keep_cls_token=always_keep_cls_token,
            )
            hidden_states = layer_outputs[-1]

            if layer_output_length:
                remain_indices = remain_indices.gather(1, keep_indices)
                restored_hidden_states = restored_hidden_states.scatter(
                    1,
                    remain_indices.unsqueeze(-1).expand(-1, -1, dim),
                    hidden_states)

                if attention_mask is not None:
                    attention_mask = attention_mask.gather(1, keep_indices)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            if output_attentions:
                assert len(layer_outputs) == 2
                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions, )
            else:
                assert len(layer_outputs) == 1

        last_hidden_state = restored_hidden_states if length_config is not None else hidden_states
        if not return_dict:
            return tuple(
                v for v in
                [last_hidden_state, all_hidden_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=last_hidden_state,
                               hidden_states=all_hidden_states,
                               attentions=all_attentions)
Exemple #21
0
    def forward(self,
                x,
                attn_mask=None,
                head_mask=None,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=None):  # docstyle-ignore
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.

        Returns:
            hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top)
            layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
                Tuple of length n_layers with the hidden states from each layer.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if self.training:
            inference_layers = []
            for i in range(self.scc_n_layer):
                if self.bernoulli.sample() == 1:  # REPLACE
                    inference_layers.append(self.scc_layer[i])
                else:  # KEEP the original
                    for offset in range(self.compress_ratio):
                        inference_layers.append(
                            self.layer[i * self.compress_ratio + offset])

        else:  # inference with compressed model
            inference_layers = self.scc_layer

        hidden_state = x
        for i, layer_module in enumerate(inference_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state, )

            layer_outputs = layer_module(x=hidden_state,
                                         attn_mask=attn_mask,
                                         head_mask=head_mask[i],
                                         output_attentions=output_attentions)
            hidden_state = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 2
                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions, )
            else:
                assert len(layer_outputs) == 1

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state, )

        if not return_dict:
            return tuple(
                v for v in [hidden_state, all_hidden_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_state,
                               hidden_states=all_hidden_states,
                               attentions=all_attentions)
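The self.bernoulli sample above decides, per successor layer, whether to run the compressed (scc) layer or the original blocks it replaces; the sampling probability is typically annealed during training. Below is a hypothetical linear schedule for that probability, assuming the encoder exposes a bernoulli attribute exactly as in the snippet (the attribute name, constants and schedule are illustrative, not taken from the original code):

import torch
from torch.distributions.bernoulli import Bernoulli

class LinearReplacementScheduler:
    # Sketch: raise the replacement probability linearly with training steps.
    def __init__(self, encoder, base_rate=0.3, k=5e-4):
        self.encoder = encoder
        self.base_rate = base_rate
        self.k = k
        self.step_count = 0
        self.encoder.bernoulli = Bernoulli(torch.tensor(base_rate))

    def step(self):
        self.step_count += 1
        rate = min(1.0, self.base_rate + self.k * self.step_count)
        self.encoder.bernoulli = Bernoulli(torch.tensor(rate))
        return rate

Calling scheduler.step() once per optimizer step gradually shifts training from the original layers toward the compressed ones.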
Exemple #22
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        vis_inputs=None,
        vis_attention_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input_shape)

        inputs_embeds = inputs_embeds + embed_pos

        B, L = inputs_embeds.size()[:-1]

        vis_feats = vis_inputs[0]
        boxes = vis_inputs[1]
        img_order_ids = None
        obj_order_ids = None
        if len(vis_inputs) >= 3:
            img_order_ids = vis_inputs[2]
        if len(vis_inputs) == 4:
            obj_order_ids = vis_inputs[3]

        vis_embeds = self.visual_embedding(vis_feats, boxes, img_order_ids,
                                           obj_order_ids)
        V_L = vis_embeds.size(1)

        if self.config.share_vis_lang_layer_norm:
            inputs_embeds = torch.cat([inputs_embeds, vis_embeds], dim=1)

            inputs_embeds = self.layernorm_embedding(inputs_embeds)
        else:
            inputs_embeds = self.layernorm_embedding(inputs_embeds)
            inputs_embeds = torch.cat([inputs_embeds, vis_embeds], dim=1)

        hidden_states = F.dropout(inputs_embeds,
                                  p=self.dropout,
                                  training=self.training)

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=inputs_embeds.dtype, device=inputs_embeds.device)

        if vis_attention_mask is None:
            vis_attention_mask = torch.ones(B,
                                            V_L,
                                            dtype=inputs_embeds.dtype,
                                            device=inputs_embeds.device)

        attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states, )
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability <
                                  self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                if getattr(self.config, "gradient_checkpointing", False):

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        output_attentions=output_attentions)

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1], )

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states, )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, encoder_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states,
                               hidden_states=encoder_states,
                               attentions=all_attentions)
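_expand_mask above is taken to be the BART-style helper that turns a [bsz, src_len] padding mask into the additive [bsz, 1, tgt_len, src_len] form the attention layers add to their scores; a minimal sketch under that assumption (the exact upstream implementation may differ):

import torch

def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
    # Kept positions become 0.0, masked positions a large negative number,
    # so the result can be added directly to the raw attention scores.
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)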
Exemple #23
0
    def forward(
        self,
        input_ids=None,
        #           real_input_ids=None,
        attention_mask=None,
        sentence_indicator=None,
        sentence_labels=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs into embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]
        hidden_states_non_pad = attention_mask.unsqueeze(-1) * hidden_states

        # extract salient sentences
        if self.config.sequential_extraction:
            gumbel_output, all_sentence_logits = self.selection_loop(
                hidden_states, sentence_indicator, sentence_labels)
        else:
            gumbel_output, sentence_logits = self.single_extraction(
                hidden_states, sentence_indicator, sentence_labels)

        new_attention_mask = utils.convert_attention_mask(
            sentence_indicator, gumbel_output)
        masked_hidden_states = new_attention_mask.unsqueeze(-1) * hidden_states

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(
                    self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(
                    self.decoder.first_device)

        if self.training:
            attention_mask = self.attention_dropout(attention_mask)
            hidden_states = hidden_states * attention_mask.unsqueeze(-1)
            extracted_sentence_encoding = self.encoder(
                input_ids=input_ids * new_attention_mask.long(),
                attention_mask=new_attention_mask)

#        if not self.training:
#            if real_input_ids is None:
#                real_input_ids = input_ids
#            extracted_sentence_encoding = self.encoder(input_ids=real_input_ids*new_attention_mask.long(), attention_mask=new_attention_mask)
#            hidden_states = extracted_sentence_encoding[0]
#            attention_mask = new_attention_mask

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states
            if self.training else masked_hidden_states,
            encoder_attention_mask=attention_mask
            if self.training else new_attention_mask,
            head_mask=decoder_head_mask,
            encoder_head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            self.sentence_classifier = self.sentence_classifier.to(
                self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                            labels.view(-1))
            sim_loss_fct = nn.CosineSimilarity()
            pooled_hidden_states = hidden_states_non_pad.mean(1)  #detach()?
            pooled_encoded_summary = masked_hidden_states.mean(
                1) if not self.training else (
                    extracted_sentence_encoding[0] *
                    (new_attention_mask.unsqueeze(-1))).mean(1)

            loss -= 2 * (sim_loss_fct(pooled_hidden_states,
                                      pooled_encoded_summary)).mean()

            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs
            return ((loss, ) + output) if loss is not None else output

        return ExtractorAbstractorOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            extracted_attentions=new_attention_mask,
            gumbel_output=None if self.training else gumbel_output)
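The auxiliary term above subtracts twice the mean cosine similarity between the mean-pooled document encoding and the mean-pooled encoding of the extracted sentences, so selecting sentences that preserve the document representation lowers the loss. An illustrative, standalone reconstruction of that term (the function name and the weight of 2.0 follow the snippet; this is not the original helper):

import torch
import torch.nn as nn

def summary_similarity_bonus(doc_hidden, doc_mask, summary_hidden, summary_mask, weight=2.0):
    # Mask out padding, mean-pool over the sequence dimension, then reward
    # (i.e. subtract from the loss) the cosine similarity of the two poolings.
    pooled_doc = (doc_hidden * doc_mask.unsqueeze(-1)).mean(dim=1)
    pooled_summary = (summary_hidden * summary_mask.unsqueeze(-1)).mean(dim=1)
    cos = nn.CosineSimilarity(dim=-1)
    return -weight * cos(pooled_doc, pooled_summary).mean()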
Exemple #24
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        vis_inputs=None,
        vis_attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):

        # different to other models, Bart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id,
                self.config.decoder_start_token_id)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                vis_inputs=vis_inputs,
                vis_attention_mask=vis_attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=False
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=torch.float, device=input_ids.device)
        if vis_attention_mask is None:
            B, L = attention_mask.size()
            V_L = encoder_outputs[0].size(1) - L
            vis_attention_mask = attention_mask.new_ones(B, V_L)
        encoder_attention_mask = torch.cat(
            [attention_mask, vis_attention_mask], dim=1)

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            # encoder_attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
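shift_tokens_right above is assumed to be the BART-family helper that builds decoder_input_ids by prepending decoder_start_token_id and dropping the final token; a minimal sketch of that behaviour (the packaged implementation may differ in details such as error handling):

import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    # Shift ids one position to the right and prepend the decoder start token;
    # replace any -100 label placeholders with pad_token_id so every position
    # is a valid embedding index.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted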
Exemple #25
0
    def forward(self,
                input_ids,
                attention_mask=None,
                output_attentions=False,
                output_hidden_states=False,
                return_tuple=False,
                visual=None):
        # check attention mask and invert
        if attention_mask is not None:
            attention_mask = invert_mask(attention_mask)

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        if self.visual is not None:
            visual = self.visual

        inputs_embeds = torch.cat([visual, inputs_embeds], dim=1)

        visual_zeros = torch.zeros(
            [visual.size()[0], visual.size()[1]],
            dtype=input_ids.dtype).to(torch.device("cuda"))
        embed_pos = self.embed_positions(
            torch.cat([visual_zeros, input_ids], dim=1))

        x = inputs_embeds + embed_pos
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = [] if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states.append(x)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability <
                                  self.layerdrop):  # skip the layer
                attn = None
            else:
                x, attn = encoder_layer(x,
                                        attention_mask,
                                        output_attentions=output_attentions)

            if output_attentions:
                all_attentions = all_attentions + (attn, )

        if self.layer_norm:
            x = self.layer_norm(x)
        if output_hidden_states:
            encoder_states.append(x)
            # T x B x C -> B x T x C
            encoder_states = tuple(
                hidden_state.transpose(0, 1)
                for hidden_state in encoder_states)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if return_tuple:
            return tuple(v for v in [x, encoder_states, all_attentions]
                         if v is not None)
        return BaseModelOutput(last_hidden_state=x,
                               hidden_states=encoder_states,
                               attentions=all_attentions)
Exemple #26
0
    def forward(self,
                input_ids=None,
                labels=None,
                attention_mask=None,
                encoder_outputs=None,
                decoder_input_ids=None,
                latent=None,
                use_cache=None,
                return_dict=True,
                **unused_kwargs):
        assert return_dict, "Need return_dict=True, using tuple's is not implimented"
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None:
            if self.config.prepend_eos_token:
                input_ids = self._shift_input_right(input_ids)
            if attention_mask is None:
                attention_mask = input_ids.ne(
                    self.transformer.config.pad_token_id).long()
            if encoder_outputs is None:
                encoder_outputs = self.transformer.encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True)
        if encoder_outputs is not None and not isinstance(
                encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        vae_outputs = self.vae(input_encoding=encoder_outputs.last_hidden_state
                               if encoder_outputs else None,
                               latent=latent,
                               global_step=self.global_step)

        if labels is not None and decoder_input_ids is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self.transformer._shift_right(labels)

        decoder_outputs = self.transformer.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=vae_outputs.reconstructed_encoding,
            use_cache=use_cache,
            return_dict=True,
        )

        sequence_output = decoder_outputs.last_hidden_state
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.config.transformer.d_model**
                                             -0.5)
        lm_logits = self.transformer.lm_head(sequence_output)

        decoder_ce = torch.tensor(0.0, device=lm_logits.device)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            decoder_ce = loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                                  labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        reg_loss_w = self._regulariser_loss_weight_schedule()
        loss = decoder_ce + vae_outputs.reg_loss * reg_loss_w

        if self.training and self.config.use_extra_logs:
            self._update_logs(decoder_ce=decoder_ce.item(),
                              reg_loss=vae_outputs.reg_loss.item(),
                              reg_loss_w=reg_loss_w)

        return BaseTransformerVAE_Output(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state
            if encoder_outputs else None,
            encoder_hidden_states=encoder_outputs.hidden_states
            if encoder_outputs else None,
            encoder_attentions=encoder_outputs.attentions
            if encoder_outputs else None,
            latent=vae_outputs.latent,
            reg_loss=vae_outputs.reg_loss,
            decoder_ce=decoder_ce,
            accuracy=None,
        )
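_regulariser_loss_weight_schedule is not shown above; VAE-style seq2seq models commonly anneal the regulariser weight so the reconstruction term dominates early training and the latent regulariser is phased in gradually. A purely hypothetical linear-warmup stand-in (the real schedule may be cyclical or otherwise different):

def linear_regulariser_weight(global_step: int, warmup_steps: int = 10000, max_weight: float = 1.0):
    # Ramp the weight from 0 to max_weight over warmup_steps, then hold it.
    if warmup_steps <= 0:
        return max_weight
    return max_weight * min(1.0, global_step / warmup_steps)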
Exemple #27
0
    def forward(self,
                input_ids=None,
                labels=None,
                attention_mask=None,
                encoder_outputs=None,
                decoder_input_ids=None,
                latent=None,
                return_dict=True,
                **unused_kwargs):
        assert return_dict, "Need return_dict=True, using tuple's is not implimented"

        if input_ids is not None:
            if decoder_input_ids is not None and input_ids.equal(
                    decoder_input_ids) is False:
                raise ValueError(
                    "`input_ids` and `decoder_input_ids` do not match. Funnel-VAE can only reproduce its input sequence."
                )
            if self.config.prepend_eos_token:
                raise NotImplementedError()
            if attention_mask is None:
                attention_mask = input_ids.ne(
                    self.transformer.config.pad_token_id).long()
            if encoder_outputs is None:
                encoder_outputs = self._get_encoder_outputs(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True,
                )
        if encoder_outputs is not None and not isinstance(
                encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        vae_outputs = self.vae(input_encoding=encoder_outputs.last_hidden_state
                               if encoder_outputs else None,
                               latent=latent,
                               global_step=self.global_step)

        initial_encoding_size = (
            vae_outputs.reconstructed_encoding.size(0),
            self.config.transformer.n_positions,
            self.config.transformer.d_model,
        )

        decoder_outputs = self.transformer.funnel.decoder(
            final_hidden=vae_outputs.reconstructed_encoding,
            # Don't allow for residual connections, instead just send an empty tensor.
            first_block_hidden=torch.zeros(
                initial_encoding_size,
                device=vae_outputs.reconstructed_encoding.device),
            return_dict=True,
        )

        last_hidden_state = decoder_outputs.last_hidden_state
        prediction_logits = self.transformer.lm_head(last_hidden_state)

        decoder_ce = torch.tensor(0.0, device=prediction_logits.device)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
            decoder_ce = loss_fct(
                prediction_logits.view(-1, self.config.transformer.vocab_size),
                labels.view(-1))

        reg_loss_w = self._regulariser_loss_weight_schedule()
        loss = decoder_ce + vae_outputs.reg_loss * reg_loss_w

        if self.training and self.config.use_extra_logs:
            self._update_logs(decoder_ce=decoder_ce.item(),
                              reg_loss=vae_outputs.reg_loss.item(),
                              reg_loss_w=reg_loss_w)

        return BaseTransformerVAE_Output(
            loss=loss,
            logits=prediction_logits,
            past_key_values=None,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=None,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state
            if encoder_outputs else None,
            encoder_hidden_states=encoder_outputs.hidden_states
            if encoder_outputs else None,
            encoder_attentions=encoder_outputs.attentions
            if encoder_outputs else None,
            latent=vae_outputs.latent,
            reg_loss=vae_outputs.reg_loss,
            decoder_ce=decoder_ce,
        )
Exemple #28
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        vis_inputs=None,
        vis_attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        labels=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        return_hidden_state=False,
        **kwargs,
    ):

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:

            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                vis_inputs=vis_inputs,
                vis_attention_mask=vis_attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=hidden_states.dtype, device=hidden_states.device)
        if vis_attention_mask is None:
            B, L = attention_mask.size()
            V_L = encoder_outputs[0].size(1) - L
            vis_attention_mask = attention_mask.new_ones(B, V_L)
        encoder_attention_mask = torch.cat(
            [attention_mask, vis_attention_mask], dim=1)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        assert self.config.tie_word_embeddings is True

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        if return_hidden_state:
            return sequence_output

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            # loss_fct = CrossEntropyLoss(ignore_index=-100)
            # loss = loss_fct(
            #     lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100,
                                            reduction='none')
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                            labels.view(-1))

        # if not return_dict:
        #     output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
        #     return ((loss,) + output) if loss is not None else output

        return VLSeq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            # decoder_attentions=decoder_outputs.attentions,
            # encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            # encoder_hidden_states=encoder_outputs.hidden_states,
            # encoder_attentions=encoder_outputs.attentions,
            # vis_encoder_last_hidden_state=vis_encoder_outputs.last_hidden_state,
            # vis_encoder_hidden_states=vis_encoder_outputs.hidden_states,
            # vis_encoder_attentions=vis_encoder_outputs.attentions,
            # cross_encoder_outputs=cross_encoder_outputs
        )
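With reduce_loss=False the CrossEntropyLoss above uses reduction='none', so loss is a flat per-token vector rather than a scalar; callers typically reshape it to [batch, seq_len] and average over the positions whose label is not -100 (CrossEntropyLoss already returns 0 at ignored positions). A hedged sketch of that aggregation, with illustrative names:

import torch

def per_example_loss(flat_token_loss: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
    # Reshape the reduction='none' output back to [batch, seq_len] and divide
    # by the number of real (non-ignored) tokens in each example.
    token_loss = flat_token_loss.view(labels.size(0), labels.size(1))
    valid = (labels != ignore_index).float()
    return (token_loss * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)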