class BartDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder initialized from a pretrained BART decoder. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        cfg (Wav2BartPoolConfig): model configuration; ``cfg.bart_path`` should
            point to a directory containing a BART ``model.pt`` checkpoint
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg: Wav2BartPoolConfig,
        dictionary=None,
        embed_tokens=None,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)
        self.cfg = cfg

        # bart = torch.hub.load('pytorch/fairseq', 'bart.base')
        from fairseq.models.bart import BARTModel

        if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
            print('loading bart from cfg path')
            bart = BARTModel.from_pretrained(cfg.bart_path, checkpoint_file='model.pt')
        else:
            print('loading bart from relative path')
            bart = BARTModel.from_pretrained('models/bart.base', checkpoint_file='model.pt')

        # Rebuild a TransformerDecoder with BART's architecture and copy its weights.
        bart_decoder = bart.model.decoder
        self.decoder = TransformerDecoder(
            bart_decoder.args, bart_decoder.dictionary, bart_decoder.embed_tokens
        )
        self.decoder.load_state_dict(bart_decoder.state_dict())

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # with torch.no_grad() if self.cfg.fix_decoder else contextlib.ExitStack():
        x, extra = self.decoder(prev_output_tokens, encoder_out, incremental_state)
        return x, extra

    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        return self.decoder.extract_features(prev_output_tokens, encoder_out, incremental_state)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.decoder.max_positions()

    def buffered_future_mask(self, tensor):
        return self.decoder.buffered_future_mask(tensor)

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
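# Minimal usage sketch (not part of the original file): assumes a BART-base
# checkpoint is available under `models/bart.base/model.pt` and that
# Wav2BartPoolConfig is a config dataclass with a `bart_path` field, as the
# constructor above suggests. Shown only to illustrate the teacher-forced
# calling convention of this decoder wrapper.
if __name__ == "__main__":
    import torch

    cfg = Wav2BartPoolConfig(bart_path='models/bart.base')  # hypothetical instance
    decoder = BartDecoder(cfg)

    # Teacher-forced step with dummy target tokens of shape (batch=2, tgt_len=5);
    # encoder_out=None means the underlying decoder layers skip encoder-side attention.
    bos = decoder.decoder.dictionary.bos()
    prev_output_tokens = torch.full((2, 5), bos, dtype=torch.long)
    logits, extra = decoder(prev_output_tokens, encoder_out=None)
    print(logits.shape)  # (2, 5, len(decoder.decoder.dictionary))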
class BartDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder initialized from a pretrained BART decoder. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        cfg (Wav2Vec2BartConfig): model configuration
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        transform_embed (nn.Module, optional): transformation applied on top of
            the BART token embedding
        bart (BARTModel): pretrained BART model whose decoder is reused
    """

    def __init__(
        self,
        cfg: Wav2Vec2BartConfig,
        dictionary=None,
        embed_tokens=None,
        no_encoder_attn=False,
        transform_embed=None,
        bart=None,
    ):
        super().__init__(dictionary)
        self.cfg = cfg

        # bart = torch.hub.load('pytorch/fairseq', 'bart.base')
        # Rebuild a TransformerDecoder with BART's architecture and copy its weights.
        bart_decoder = bart.model.decoder
        self.decoder = TransformerDecoder(
            bart_decoder.args, bart_decoder.dictionary, bart_decoder.embed_tokens
        )
        self.decoder.load_state_dict(bart_decoder.state_dict())
        # Wrap the token embedding so that `transform_embed` is applied after lookup.
        self.decoder.embed_tokens = EmbeddingTransformed(self.decoder.embed_tokens, transform_embed)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # with torch.no_grad() if self.cfg.fix_decoder else contextlib.ExitStack():
        x, extra = self.decoder(prev_output_tokens, encoder_out, incremental_state)
        # Pass auxiliary wav2vec outputs through so the criterion can compute CTC/CE losses.
        for k in ['wav2vec_logits', 'wav2vec_padding_mask', 'ctc_weight', 'ce_weight']:
            extra[k] = encoder_out[k]
        print('bart decoder extra.keys()', extra.keys())
        return x, extra

    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        return self.decoder.extract_features(prev_output_tokens, encoder_out, incremental_state)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.decoder.max_positions()

    def buffered_future_mask(self, tensor):
        return self.decoder.buffered_future_mask(tensor)

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
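# Hedged sketch (an assumption, not the repo's actual definition): EmbeddingTransformed
# is referenced above but defined elsewhere. One plausible form is a thin wrapper that
# looks tokens up in the original BART embedding and then applies the `transform_embed`
# module (e.g. a linear projection) to the looked-up vectors.
import torch.nn as nn


class EmbeddingTransformed(nn.Module):
    def __init__(self, embed_tokens: nn.Embedding, transform: nn.Module):
        super().__init__()
        self.embed_tokens = embed_tokens  # original BART token embedding
        self.transform = transform        # e.g. nn.Linear(embed_dim, embed_dim), may be None
        # Preserve attributes that fairseq code may read from an embedding module.
        self.embedding_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx

    def forward(self, tokens):
        x = self.embed_tokens(tokens)
        return self.transform(x) if self.transform is not None else x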