Example #1
    def forward(self, input_ids, encoder_hidden_states, *args, **kwargs):
        # Unoptimized unconditional transfer to numpy for interfacing with polygraphy
        input_ids = input_ids.cpu().numpy().astype("int64")
        encoder_hidden_states = encoder_hidden_states.cpu().numpy().astype("float32")

        logits = self.trt_context.infer(
            {"input_ids": input_ids, "encoder_hidden_states": encoder_hidden_states}
        )["hidden_states"]

        return Seq2SeqLMOutput(logits=torch.from_numpy(logits))
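Every snippet on this page ultimately wraps its backend-specific result in transformers.modeling_outputs.Seq2SeqLMOutput, so the caller (typically a generate()-style search loop) can read .logits and .past_key_values the same way whether the decoder ran in PyTorch, ONNX Runtime, TensorRT, or OpenVINO. As a minimal, self-contained sketch (the shapes below are illustrative assumptions, not taken from the example above):

# Build a Seq2SeqLMOutput by hand and pick the next token greedily.
# Assumes only that torch and transformers are installed; shapes are made up.
import torch
from transformers.modeling_outputs import Seq2SeqLMOutput

batch_size, target_len, vocab_size = 2, 4, 32128
logits = torch.randn(batch_size, target_len, vocab_size)

output = Seq2SeqLMOutput(logits=logits)                  # unset fields default to None
next_token_ids = output.logits[:, -1, :].argmax(dim=-1)  # greedy choice per sequence
print(next_token_ids.shape)                              # torch.Size([2])
print(output.past_key_values)                            # None (not supplied here)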
Example #2
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(input_ids=input_ids,
                                           attention_mask=attention_mask)

        encoder_hidden_states = encoder_outputs[0]

        if past_key_values is not None:
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if past_key_values is None:

            # runs only for the first time:
            init_onnx_outputs = self.decoder_init(decoder_input_ids,
                                                  attention_mask,
                                                  encoder_hidden_states)

            logits, past_key_values = init_onnx_outputs

        else:

            onnx_outputs = self.decoder(
                decoder_input_ids,
                attention_mask,
                encoder_hidden_states,
                past_key_values,
            )

            logits, past_key_values = onnx_outputs

        return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
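The two branches in Example #2 implement incremental decoding: the first call runs decoder_init without a cache, and every later call feeds the returned past_key_values back in while the wrapper keeps only the newest decoder token (the [:, -1:] slice). A hedged sketch of the surrounding greedy loop follows; model, input_ids, attention_mask, and max_new_tokens are hypothetical placeholders, not part of the example itself.

# Hypothetical greedy-decoding loop driving a forward() shaped like Example #2.
import torch

decoder_input_ids = torch.full((input_ids.shape[0], 1),
                               model.config.decoder_start_token_id,
                               dtype=torch.long)
past_key_values = None
for _ in range(max_new_tokens):
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                past_key_values=past_key_values)
    next_ids = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)     # greedy pick
    decoder_input_ids = torch.cat([decoder_input_ids, next_ids], dim=-1)
    past_key_values = out.past_key_values                            # reuse the cache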
Example #3
    def forward(self, input_ids, encoder_hidden_states, *args, **kwargs):
        self.inputs["input_ids"][:, :input_ids.shape[1]] = input_ids
        self.inputs["encoder_hidden_states"][:, :TRTHFRunner.ENCODER_LENGTH, :] = encoder_hidden_states[:, :TRTHFRunner.ENCODER_LENGTH, :]

        # TODO: This can be better generalized
        self.trt_context.set_binding_shape(0, input_ids.shape)
        self.trt_context.set_binding_shape(1, (1, TRTHFRunner.ENCODER_LENGTH, self.max_sequence_length))

        # Copy to device
        self.trt_context.execute_v2(bindings=self.bindings)

        # Transfer predictions back from GPU to do greedy search
        return Seq2SeqLMOutput(logits=self.outputs["hidden_states"][:, :input_ids.shape[1]].cpu())
Example #4
        def forward(self, input_ids, encoder_hidden_states, **kwargs):
            decoder_outputs = self.decoder(
                input_ids=input_ids,
                encoder_hidden_states=encoder_hidden_states,
                **kwargs)

            # self.config.d_model ** -0.5 for rescaling output on vocab.
            # as seen in https://huggingface.co/transformers/_modules/transformers/models/t5/modeling_t5.html#T5ForConditionalGeneration
            sequence_output = decoder_outputs[0] * self.config.d_model**-0.5
            logits = self.lm_head(sequence_output)
            if not kwargs.get("return_dict", False):
                return (logits, ) + decoder_outputs[1:]

            return Seq2SeqLMOutput(logits=logits)
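The self.config.d_model ** -0.5 factor in Example #4 reproduces what T5ForConditionalGeneration does when input and output embeddings are tied: before the (tied) lm_head projection, the decoder hidden states are rescaled by 1/sqrt(d_model). With the illustrative d_model = 512 of t5-small, that factor is 512 ** -0.5 ≈ 0.0442.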
Example #5
        def forward(*args, **kwargs):
            outputs = origin_forward(*args, **kwargs)

            if outputs.past_key_values is not None:
                # Multiply by 1.0 to workaround a bug in OpenVINO 2022.1 with
                # dynamic shapes inputs connected to model outputs:
                past_key_values = []
                for i in range(12):
                    past_key_values.append((
                        outputs.past_key_values[i][0],
                        outputs.past_key_values[i][1],
                        outputs.past_key_values[i][2] * 1.0,
                        outputs.past_key_values[i][3] * 1.0,
                    ))
                outputs.past_key_values = tuple(past_key_values)

            return Seq2SeqLMOutput(
                logits=outputs.logits,
                past_key_values=outputs.past_key_values,
            )
Example #6
    def forward(self,
                fusion_map,
                input_ids,
                attention_mask,
                decoder_input_ids,
                decoder_attention_mask,
                return_hidden_states=False,
                **kwargs):
        encoder_outputs = self.encoder_forward(fusion_map, input_ids,
                                               attention_mask)
        encoder_fused_states = encoder_outputs.last_hidden_state
        fused_attention_mask = encoder_outputs.attentions
        encoder_layer_states = encoder_outputs.hidden_states

        dec_outputs = self.decoder(input_ids=decoder_input_ids,
                                   attention_mask=decoder_attention_mask,
                                   encoder_hidden_states=encoder_fused_states,
                                   encoder_attention_mask=fused_attention_mask,
                                   output_hidden_states=return_hidden_states)
        sequence_output = dec_outputs[0]
        lm_logits = self.lm_head(sequence_output)

        return Seq2SeqLMOutput(logits=lm_logits,
                               encoder_hidden_states=encoder_layer_states)
Example #7
    def forward(self,
                batch,
                final_context,
                context_rnn_state,
                encoder_loss,
                current_token_id=None,
                decoder_wrapper=None,
                expansion_factor=1,
                generation_dict=None):

        context, context_limited = batch.context.value, batch.context.limited
        answer, answer_limited = batch.answer.value, batch.answer.limited
        decoder_vocab = self.numericalizer.decoder_vocab
        self.map_to_full = decoder_vocab.decode
        context_padding = context.data == self.pad_idx
        if self.training:
            if self.args.rnn_layers > 0:
                self.rnn_decoder.applyMasks(context_padding)
            else:
                self.context_attn.applyMasks(context_padding)

            answer_padding = (answer.data == self.pad_idx)[:, :-1]
            answer_embedded = self.decoder_embeddings(answer[:, :-1],
                                                      padding=answer_padding)

            if self.args.rnn_layers > 0:
                rnn_decoder_outputs = self.rnn_decoder(
                    answer_embedded, final_context, hidden=context_rnn_state)
                decoder_output, vocab_pointer_switch_input, context_attention, rnn_state = rnn_decoder_outputs
            else:
                context_decoder_output, context_attention = self.context_attn(
                    answer_embedded, final_context)
                vocab_pointer_switch_input = torch.cat(
                    (context_decoder_output, answer_embedded), dim=-1)
                decoder_output = self.dropout(context_decoder_output)

            vocab_pointer_switch = self.vocab_pointer_switch(
                vocab_pointer_switch_input)

            probs = self.probs(decoder_output, vocab_pointer_switch,
                               context_attention, context_limited,
                               decoder_vocab)

            probs, targets = mask(answer_limited[:, 1:].contiguous(),
                                  probs.contiguous(),
                                  pad_idx=decoder_vocab.pad_idx)
            loss = F.nll_loss(probs.log(),
                              targets,
                              ignore_index=decoder_vocab.pad_idx)
            if encoder_loss is not None:
                loss += self.args.encoder_loss_weight * encoder_loss

            return Seq2SeqLMOutput(loss=loss)
        else:
            if decoder_wrapper is None:
                decoder_wrapper = self.decoder_wrapper(
                    final_context,
                    context_padding,
                    context_limited,
                    decoder_vocab,
                    rnn_state=context_rnn_state,
                    expansion_factor=expansion_factor,
                    generation_dict=generation_dict)
            else:
                current_token_id = current_token_id.clone().cpu().apply_(
                    self.map_to_full).to(current_token_id.device)
            # (next_token_logits, past) where `past` includes all the states needed to continue generation
            logits = torch.log(
                decoder_wrapper.next_token_probs(current_token_id))
            return Seq2SeqLMOutput(logits=logits,
                                   past_key_values=decoder_wrapper)
Example #8
    def forward(
        self,
        input_image=None,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("decoder_")
        }

        kwargs_decoder = {
            argument[len("decoder_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            # Get image embeddings.
            image_encoder_outputs = self.image_encoder(
                pixel_values=input_image)

            # Get command embeddings.
            command_encoder_outputs = self.command_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )

            # Concatenate outputs of both encoders along "sequence" dimension.
            encoder_outputs = torch.cat(
                (image_encoder_outputs[0], command_encoder_outputs[0]), 1)

        # Back to regular track ;)
        encoder_hidden_states = encoder_outputs

        # optionally project encoder_hidden_states
        # TODO: probably this is wrong! both encoders should have the same hidden size!
        #if (
        #    self.decoder.config.hidden_size != self.image_encoder.config.hidden_size  or \
        #       self.decoder.config.hidden_size != self.command_encoder.config.hidden_size)
        #    and self.decoder.config.cross_attention_hidden_size is None
        #):
        #    encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        if (labels is not None) and (decoder_input_ids is None
                                     and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id,
                self.config.decoder_start_token_id)

        # Create masks for dual-encoder inputs.
        encoder_hidden_shape = (encoder_outputs.shape[0],
                                encoder_outputs.shape[1])
        dual_attention_masks = torch.ones(encoder_hidden_shape,
                                          dtype=torch.long,
                                          device=self.device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=dual_attention_masks,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            #warnings.warn(DEPRECATION_WARNING, FutureWarning)
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size),
                            labels.view(-1))

        if not return_dict:
            if loss is not None:
                return (loss, ) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,  # tuple
            decoder_hidden_states=decoder_outputs.hidden_states,  # None!
            decoder_attentions=decoder_outputs.attentions,  # None!
            cross_attentions=decoder_outputs.cross_attentions,  # None!
            # TODO: Not sure about the following ones! what should be there?
            #encoder_last_hidden_state=command_encoder_outputs.last_hidden_state,
            #encoder_hidden_states=encoder_hidden_states.hidden_states,
            #encoder_attentions=command_encoder_outputs.attentions,
        )
Example #9
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **unused,
    ):
        if "lm_labels" in unused:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = unused.pop("lm_labels")
        if "decoder_cached_states" in unused:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = unused.pop("decoder_cached_states")
        if "decoder_past_key_values" in unused:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = unused.pop("decoder_past_key_values")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            use_cache = False
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        lm_logits = F.linear(outputs[0], self.model.shared.weight)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # TODO(SS): do we need to ignore pad tokens in labels?
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
Example #10
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        vis_inputs=None,
        vis_attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id,
                    self.config.decoder_start_token_id)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            vis_inputs=vis_inputs,
            vis_attention_mask=vis_attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            # loss_fct = CrossEntropyLoss()
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100,
                                            reduction='none')
            masked_lm_loss = loss_fct(
                lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits, ) + outputs[1:]
            return ((masked_lm_loss, ) +
                    output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
Example #11
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,  #TJH In 4.4.2 labels contains what in unifiedqa is called decoder_input_ids
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        #TJH: Added for compatibility with 4.4.2
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:  #TJH added for compatibility with other 4.4.2 seq2seq models
            if decoder_input_ids is None:
                #TJH: how it is done in modelling_bart.py. Using the unifiedQA method instead
                #                decoder_input_ids = shift_tokens_right(
                #                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                #                )
                decoder_start_token_id = self.config.decoder_start_token_id
                decoder_input_ids = labels.new_zeros(labels.shape)
                decoder_input_ids[..., 1:] = labels[..., :-1].clone()
                decoder_input_ids[..., 0] = decoder_start_token_id

        # TJH: below from modeling_bart.py
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,  #TJH: no underscore
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        lm_logits = F.linear(outputs[0],
                             self.model.shared.weight,
                             bias=self.final_logits_bias)

        loss = None
        if labels is not None:  #TJH labels is not None instead of is_training
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            losses = loss_fct(lm_logits.view(-1, self.config.vocab_size),
                              labels.view(-1))
            loss = torch.sum(losses * decoder_attention_mask.float().view(-1))

        if not return_dict:  #TJH: from modeling_bart.py
            output = (lm_logits, ) + outputs[1:]
            return ((loss, ) + output) if loss is not None else output

        return Seq2SeqLMOutput(  #TJH: from modeling_bart.py. 
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
Example #12
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,

        #################################################
        # Modification to T5ForConditionalGeneration (MF)
        #################################################
        encode_only=False,
        passage_mask=None,
        #################################################
        ############### END OF MODIFICATION (MF)#########
        #################################################
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for
            labels in ``[0, ..., config.vocab_size]``

        Returns:

        Examples::

            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5FusionInDecoder.from_pretrained('t5-small')

            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits

            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            #################################################
            # Modification to T5ForConditionalGeneration (MF)
            #################################################
            concatenated_encoder_outputs, concatenated_attention_mask = self.concatenate_encoder_outputs(
                encoder_outputs=encoder_outputs,
                encoder_attention_mask=attention_mask,
                passage_mask=passage_mask)
            if encode_only:
                return concatenated_encoder_outputs, concatenated_attention_mask
            #################################################
            ############### END OF MODIFICATION (MF)#########
            #################################################

        elif return_dict and not isinstance(
                encoder_outputs, BaseModelOutputWithPastAndCrossAttentions):
            # Assume concatenated encoder outputs are passed!
            concatenated_encoder_outputs = BaseModelOutputWithPastAndCrossAttentions(  # Renamed (MF)
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )
            concatenated_attention_mask = attention_mask  # Minor modification (MF)
        else:  # Minor modification (MF)
            # Assume concatenated encoder outputs are passed!
            concatenated_encoder_outputs = encoder_outputs
            concatenated_attention_mask = attention_mask

        hidden_states = concatenated_encoder_outputs[0]  # Renamed (MF)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=concatenated_attention_mask,  # Renamed (MF)
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
                            labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits, ) + decoder_outputs[
                1:] + concatenated_encoder_outputs  # Renamed (MF)
            return ((loss, ) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=concatenated_encoder_outputs.last_hidden_state,  # Renamed (MF)
            encoder_hidden_states=concatenated_encoder_outputs.hidden_states,  # Renamed (MF)
            encoder_attentions=concatenated_encoder_outputs.attentions,  # Renamed (MF)
        )
Example #13
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import EncoderDecoderModel, BertTokenizer
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "bert-base-uncased", "bert-base-uncased"
        ... )  # initialize Bert2Bert from pre-trained checkpoints

        >>> # training
        >>> model.config.decoder_start_token_id = tokenizer.cls_token_id
        >>> model.config.pad_token_id = tokenizer.pad_token_id
        >>> model.config.vocab_size = model.config.decoder.vocab_size

        >>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids
        >>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss, logits = outputs.loss, outputs.logits

        >>> # save and load from pretrained
        >>> model.save_pretrained("bert2bert")
        >>> model = EncoderDecoderModel.from_pretrained("bert2bert")

        >>> # generation
        >>> generated = model.generate(input_ids)
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("decoder_")
        }

        kwargs_decoder = {
            argument[len("decoder_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )

        encoder_hidden_states = encoder_outputs[0]

        # optionally project encoder_hidden_states
        if (self.encoder.config.hidden_size != self.decoder.config.hidden_size
                and self.decoder.config.cross_attention_hidden_size is None):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        if (labels is not None) and (decoder_input_ids is None
                                     and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id,
                self.config.decoder_start_token_id)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            #warnings.warn(DEPRECATION_WARNING, FutureWarning)
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size),
                            labels.view(-1))

        if not return_dict:
            if loss is not None:
                return (loss, ) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
Example #14
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):

        custom_loss = None
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # first speaker 1 occurrence
            persona_boundaries = (input_ids == 30525).nonzero(as_tuple=False)
            boundaries = []
            for item in persona_boundaries.tolist():
                if item[0] not in [x[0] for x in boundaries]:
                    boundaries.append(item)
            boundaries = [item[1] for item in boundaries]
            persona_input_ids, persona_attention_mask = [], []
            for i, (ids, attention) in enumerate(zip(input_ids,
                                                     attention_mask)):
                persona_input_ids.append(
                    ids.tolist()[:boundaries[i] + 1] +
                    [30524 for _ in range(512 - (boundaries[i] + 1))])
                persona_attention_mask.append(
                    attention.tolist()[:boundaries[i] + 1] +
                    [0 for _ in range(512 - (boundaries[i] + 1))])
            device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
            persona_input_ids = torch.Tensor(persona_input_ids).long().to(
                device)
            persona_attention_mask = torch.Tensor(
                persona_attention_mask).long().to(device)
            custom_encoder_outputs = self.encoder(
                input_ids=persona_input_ids,
                attention_mask=persona_attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            custom_loss = self.cs(
                encoder_outputs['pooler_output'],
                custom_encoder_outputs['pooler_output'],
            ).mean()
            # print(custom_encoder_outputs.keys(), type(custom_encoder_outputs['pooler_output']), custom_encoder_outputs['pooler_output'].size())
            # print(custom_loss)
            # sys.exit()

            # after 1st occurrence convert to 0
            # attention, input is
            # custom loss

        encoder_hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # use_cache=use_cache,
            # past_key_values=past_key_values,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=decoder_outputs.loss,
            logits=decoder_outputs.logits,
            # closs = custom_loss,
            past_key_values=custom_loss,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            # cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
Example #15
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        return_dict=None,
        labels=None,
        latent_z=None
    ):

        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        return_dict = self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
        )
            
        ###########################
        batch_size = input_ids.shape[0]
        # seq_repr = torch.cat((encoder_outputs[0][:, 0, :], encoder_outputs[0].mean(dim=1), encoder_outputs[0][:, -1, :]), -1).view(batch_size, -1)
        seq_repr = torch.cat((encoder_outputs[0][:, 0, :], encoder_outputs[0][:, -1, :]), -1).view(batch_size, -1)
        # Reparameterize
        mu = self.to_mu(seq_repr)
        logvar = self.to_logvar(seq_repr)
        z = self.reparameterize(mu, logvar)

        # # add noise
        # if self.word_dropout_rate > 0:
        #     # randomly replace decoder input with <unk>
        #     prob = torch.rand(decoder_input_ids.size())
        #     if torch.cuda.is_available():
        #         prob = prob.cuda()
        #     prob[ (decoder_input_ids.data - self.config.pad_token_id) == 0] = 1
        #     decoder_input_sequence = decoder_input_ids.clone()
        #     decoder_input_sequence[prob < self.word_dropout_rate] = self.config.unk_token_id
        #     decoder_input_ids = decoder_input_sequence

        latent_z = self.to_emb(z) # fitting embedding size
        ###########################

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            return_dict=return_dict,
            latent_z=latent_z
        )

        lm_logits = self.lm_head(decoder_outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        loss_total = None
        if labels is None:
            labels = input_ids.clone()
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            # reconstruction
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

            # kl div
            kl_loss = self.loss_kl(mu, logvar)
            loss_total = masked_lm_loss + self.config.lambda_kl * kl_loss

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss_total,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )