def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary):
    encoder = roberta_enc.encoder.sentence_encoder
    vocab_size, embed_dim = encoder.embed_tokens.weight.shape

    if args.share_all_embeddings:
        lm_head = roberta_enc.encoder.lm_head
        assert encoder.embed_tokens.weight is lm_head.weight, (
            "Can't use --share-all-embeddings with a model "
            "that was pretrained with --untie-weights-roberta_enc")
    else:
        lm_head = roberta.RobertaLMHead(
            embed_dim, vocab_size, roberta_enc.args.activation_fn)

    dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad())
    if args.share_all_embeddings or args.share_decoder_input_output_embed:
        # Note: I wasn't able to use the Embedding _weight parameter to achieve this sharing.
        dec_embs.weight = lm_head.weight

    decoder = TransformerDecoder(
        RobertaEncDecModel.read_args_from_roberta(roberta_enc.args),
        dictionary,
        dec_embs,
        no_encoder_attn=False,
        output_projection=lm_head,
    )
    if getattr(args, "pretrained_decoder", False):
        decoder_dict = encoder.state_dict()

        # TODO: hide setting "encoder_attn" layers behind a flag.
        for k, w in list(decoder_dict.items()):
            if ".self_attn" in k:
                k_enc_attn = k.replace(".self_attn", ".encoder_attn")
                decoder_dict[k_enc_attn] = w.detach().clone()

        for k, w in lm_head.state_dict().items():
            decoder_dict["output_projection." + k] = w

        missing_keys, unexpected_keys = decoder.load_state_dict(
            decoder_dict, strict=False)
        # missing_keys = [m for m in missing_keys if ".encoder_attn" not in m]
        assert not missing_keys and not unexpected_keys, (
            "Failed to load state dict. "
            f"Missing keys: {missing_keys}. "
            f"Unexpected keys: {unexpected_keys}.")

    if args.share_all_embeddings:
        assert decoder.output_projection.weight is decoder.embed_tokens.weight
        assert encoder.embed_tokens.weight is decoder.embed_tokens.weight
    elif args.share_decoder_input_output_embed:
        assert decoder.output_projection.weight is decoder.embed_tokens.weight
        assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
    else:
        assert decoder.output_projection.weight is not decoder.embed_tokens.weight
        assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight

    return RobertaEncDecModel(encoder, decoder)
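# Illustrative sketch (standalone, not part of the model code above): direct attribute
# assignment of a Parameter ties the weights, which is why the `is` asserts at the end
# of from_roberta hold when --share-all-embeddings or --share-decoder-input-output-embed
# is set.
import torch
import torch.nn as nn

lm_head_weight = nn.Parameter(torch.randn(100, 16))
dec_embs = nn.Embedding(100, 16, padding_idx=1)
dec_embs.weight = lm_head_weight
assert dec_embs.weight is lm_head_weight  # same Parameter object: gradients are shared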
class BartDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg: Wav2BartPoolConfig,
        dictionary=None,
        embed_tokens=None,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)
        self.cfg = cfg

        # bart = torch.hub.load('pytorch/fairseq', 'bart.base')
        from fairseq.models.bart import BARTModel
        if os.path.isfile(os.path.join(cfg.bart_path, 'model.pt')):
            print('loading bart from cfg path')
            bart = BARTModel.from_pretrained(cfg.bart_path, checkpoint_file='model.pt')
        else:
            print('loading bart from relative path')
            bart = BARTModel.from_pretrained('models/bart.base', checkpoint_file='model.pt')
        bart_decoder = bart.model.decoder

        # Rebuild a TransformerDecoder with BART's args/dictionary/embeddings and
        # copy over the pretrained weights.
        self.decoder = TransformerDecoder(
            bart_decoder.args,
            bart_decoder.dictionary,
            bart_decoder.embed_tokens,
        )
        self.decoder.load_state_dict(bart_decoder.state_dict())

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # with torch.no_grad() if self.cfg.fix_decoder else contextlib.ExitStack():
        x, extra = self.decoder(prev_output_tokens, encoder_out, incremental_state)
        return x, extra

    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        return self.decoder.extract_features(prev_output_tokens, encoder_out, incremental_state)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.decoder.max_positions()

    def buffered_future_mask(self, tensor):
        return self.decoder.buffered_future_mask(tensor)

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
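# Minimal usage sketch (an assumption, not part of the class above; requires network
# access to the fairseq torch.hub entry point): this is how the commented-out hub path
# in __init__ would obtain the same pretrained decoder that BartDecoder wraps.
import torch

bart = torch.hub.load('pytorch/fairseq', 'bart.base')
bart_decoder = bart.model.decoder  # the TransformerDecoder whose args/dict/embeddings are reused
print(type(bart_decoder).__name__, 'with', len(bart_decoder.layers), 'layers')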
def build_model(cls, args, task):
    encoder = TrOCREncoder(args=args, dictionary=task.source_dictionary)

    args.encoder_embed_dim = encoder.deit.embed_dim

    if getattr(args, "max_target_positions", None) is None:
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    if getattr(args, "decoder_pretrained", None).startswith('roberta2'):
        logger.info('Using the learned pos embedding version loading roberta.')
        decoder_embed_tokens = cls.build_embedding(
            args, task.target_dictionary, args.decoder_embed_dim, args.decoder_embed_path)

        pretrained_model = getattr(args, "decoder_pretrained", None)
        specified = pretrained_model.find('-') != -1

        if specified:
            pretrained_model = pretrained_model.replace('-', '.')
            logger.info('Load pre-trained decoder parameters from {}'.format(pretrained_model))
            roberta = torch.hub.load('pytorch/fairseq:main', pretrained_model)
        elif args.decoder_layers == 6:
            logger.info('Load pre-trained decoder parameters from roberta.base')
            roberta = torch.hub.load('pytorch/fairseq:main', 'roberta.base')
        elif args.decoder_layers == 12:
            logger.info('Load pre-trained decoder parameters from roberta.large')
            roberta = torch.hub.load('pytorch/fairseq:main', 'roberta.large')
        else:
            raise AttributeError('Cannot determine the pre-trained model')

        roberta.model.args.encoder_layers = args.decoder_layers
        roberta.model.args.fp16 = args.fp16
        roberta_args = TrOCRModel.read_args_from_roberta(roberta.model.args)
        roberta_args.encoder_embed_dim = args.encoder_embed_dim

        decoder = TransformerDecoder(
            roberta_args,
            task.target_dictionary,
            decoder_embed_tokens,
            no_encoder_attn=False,
        )

        roberta_layers = roberta.model.encoder.sentence_encoder.layers
        decoder_layers = decoder.layers
        offset = len(roberta_layers) - len(decoder_layers)
        assert offset >= 0

        # Re-number the RoBERTa layer keys so that the top len(decoder_layers)
        # layers map onto decoder layers 0..N-1; lower layers are dropped.
        decoder_dict = roberta.state_dict()
        new_decoder_dict = {}
        for key, val in decoder_dict.items():
            if key.startswith('model.encoder.sentence_encoder.layers.'):
                layer_num = int(key[len('model.encoder.sentence_encoder.layers.'):].split('.')[0])
                if layer_num - offset < 0:
                    continue
                new_key = 'model.encoder.sentence_encoder.layers.{}.'.format(
                    layer_num - offset) + '.'.join(
                        key[len('model.encoder.sentence_encoder.layers.'):].split('.')[1:])
                new_decoder_dict[new_key] = val
            else:
                new_decoder_dict[key] = val
        decoder_dict = new_decoder_dict

        for k, w in list(decoder_dict.items()):
            if '.lm_head' in k:
                k_proj = "output_projection." + k[len('model.encoder.lm_head.'):]
                decoder_dict[k_proj] = w.detach().clone()
                del decoder_dict[k]

        del decoder_dict['_float_tensor']
        del decoder_dict['output_projection.weight']
        del decoder_dict['output_projection.bias']
        del decoder_dict['output_projection.dense.weight']
        del decoder_dict['output_projection.dense.bias']
        del decoder_dict['output_projection.layer_norm.weight']
        del decoder_dict['output_projection.layer_norm.bias']

        new_decoder_dict = {}
        for key, val in decoder_dict.items():
            if "sentence_encoder" in key:
                key = key[len('model.encoder.sentence_encoder.'):]
            elif "encoder" in key:
                key = key[len('model.encoder.'):]
            new_decoder_dict[key] = val

        missing_keys, unexpected_keys = decoder.load_state_dict(
            new_decoder_dict, strict=False)

    elif getattr(args, "decoder_pretrained", None) == 'unilm':
        logger.info('Decoder is pretrained using the unilm.')
        prefix_of_parameter = 'bert'
        decoder_embed_tokens = cls.build_embedding(
            args, task.target_dictionary, args.decoder_embed_dim, args.decoder_embed_path)
        decoder = UniLMDecoder(
            args,
            task.target_dictionary,
            decoder_embed_tokens,
            no_encoder_attn=False,
        )

        if (hasattr(args, 'decoder_pretrained_url')
                and args.decoder_pretrained_url is not None
                and args.decoder_pretrained_url != ''):
            unilm_url = args.decoder_pretrained_url
            logger.info('The unilm model url: {}.'.format(unilm_url[:unilm_url.find('?')]))
            unilm_state_dict = torch.hub.load_state_dict_from_url(unilm_url)

            unilm_layers = OrderedDict([
                (k, unilm_state_dict[k]) for k in unilm_state_dict.keys()
                if k.startswith(prefix_of_parameter + '.encoder.layer.')
            ])
            unilm_layers_num = []
            for k in unilm_layers.keys():
                t = k.replace(prefix_of_parameter + '.encoder.layer.', '')
                t = t[:t.find('.')]
                unilm_layers_num.append(int(t))
            unilm_layers_num = max(unilm_layers_num) + 1

            offset = unilm_layers_num - len(decoder.layers)
            assert offset == 0

            decoder_dict = decoder.state_dict()

            # embeddings: fairseq's learned positional embeddings are right-shifted
            # by padding_idx + 1, so copy the unilm position table with that offset
            new_pos_weight = torch.zeros_like(decoder_dict['embed_positions.weight'])
            new_pos_weight[task.target_dictionary.pad() + 1:, :] = unilm_state_dict[
                prefix_of_parameter + '.embeddings.position_embeddings.weight']
            new_decoder_dict = {
                'embed_tokens.weight':
                    unilm_state_dict[prefix_of_parameter + '.embeddings.word_embeddings.weight'],
                'embed_positions.weight':
                    new_pos_weight,
                'layernorm_embedding.weight':
                    unilm_state_dict[prefix_of_parameter + '.embeddings.LayerNorm.weight'],
                'layernorm_embedding.bias':
                    unilm_state_dict[prefix_of_parameter + '.embeddings.LayerNorm.bias'],
            }

            # layers: map fairseq decoder parameter names onto unilm/BERT names
            key_map = {
                'self_attn.k_proj': 'attention.self.key',
                'self_attn.v_proj': 'attention.self.value',
                'self_attn.q_proj': 'attention.self.query',
                'self_attn.out_proj': 'attention.output.dense',
                'self_attn_layer_norm': 'attention.output.LayerNorm',
                'fc1': 'intermediate.dense',
                'fc2': 'output.dense',
                'final_layer_norm': 'output.LayerNorm',
            }
            for layer_id in range(unilm_layers_num):
                unilm_prefix = prefix_of_parameter + '.encoder.layer.{}.'.format(layer_id)
                decoder_prefix = 'layers.{}.'.format(layer_id)
                for key in key_map:
                    for suffix in ['.weight', '.bias']:
                        decoder_key = decoder_prefix + key + suffix
                        unilm_key = unilm_prefix + key_map[key] + suffix
                        if decoder_key in decoder_dict and unilm_key in unilm_state_dict:
                            new_decoder_dict[decoder_key] = unilm_state_dict[unilm_key]

            if hasattr(args, "reset_dictionary") and args.reset_dictionary:
                logger.info('Reset token embedding weights during decoder initialization.')
                del new_decoder_dict['embed_tokens.weight']
            elif hasattr(args, "adapt_dictionary") and args.adapt_dictionary:
                unilm_embed_tokens_weight = new_decoder_dict['embed_tokens.weight']
                logger.info(
                    'Adapt token embedding weights during decoder initialization from {} to {}'
                    .format(unilm_embed_tokens_weight.shape[0],
                            decoder_embed_tokens.weight.shape[0]))
                num_copied = min(unilm_embed_tokens_weight.shape[0],
                                 decoder_dict['embed_tokens.weight'].shape[0])
                new_decoder_dict['embed_tokens.weight'] = torch.zeros_like(
                    decoder_dict['embed_tokens.weight'])
                new_decoder_dict['embed_tokens.weight'][:num_copied, :] = \
                    unilm_embed_tokens_weight[:num_copied, :]

            missing_keys, unexpected_keys = decoder.load_state_dict(
                new_decoder_dict, strict=False)
        else:
            logger.warning(
                'You must specify the unilm model url or the decoder is randomly initialized.')

        # freeze k_proj bias
        for layer in decoder.layers:
            layer.self_attn.k_proj.bias.requires_grad = False

    elif (getattr(args, "decoder_pretrained", None) is None
          or getattr(args, "decoder_pretrained", None).upper() == 'NONE'):
        logger.info('Decoder is randomly initialized.')
        decoder_embed_tokens = cls.build_embedding(
            args, task.target_dictionary, args.decoder_embed_dim, args.decoder_embed_path)
        decoder = TransformerDecoder(
            args=args,
            dictionary=task.target_dictionary,
            embed_tokens=decoder_embed_tokens,
            no_encoder_attn=False)

    elif getattr(args, "decoder_pretrained", None).startswith('roberta'):
        logger.info('Using the old version loading roberta.')
        decoder_embed_tokens = cls.build_embedding(
            args, task.target_dictionary, args.decoder_embed_dim, args.decoder_embed_path)
        decoder = TransformerDecoder(
            args=args,
            dictionary=task.target_dictionary,
            embed_tokens=decoder_embed_tokens,
            no_encoder_attn=False)

        pretrained_model = getattr(args, "decoder_pretrained", None)
        specified = pretrained_model.find('-') != -1

        if specified:
            pretrained_model = pretrained_model.replace('-', '.')
            logger.info('Load pre-trained decoder parameters from {}'.format(pretrained_model))
            roberta = torch.hub.load('pytorch/fairseq:main', pretrained_model)
        elif args.decoder_layers == 6:
            logger.info('Load pre-trained decoder parameters from roberta.base')
            roberta = torch.hub.load('pytorch/fairseq:main', 'roberta.base')
        elif args.decoder_layers == 12:
            logger.info('Load pre-trained decoder parameters from roberta.large')
            roberta = torch.hub.load('pytorch/fairseq:main', 'roberta.large')
        else:
            raise AttributeError('Cannot determine the pre-trained model')

        decoder.embed_tokens.load_state_dict(
            roberta.model.encoder.sentence_encoder.embed_tokens.state_dict())

        roberta_layers = roberta.model.encoder.sentence_encoder.layers
        decoder_layers = decoder.layers
        offset = len(roberta_layers) - len(decoder_layers)
        assert offset >= 0

        for i in range(len(decoder_layers)):
            roberta_i = i + offset
            decoder_layers[i].self_attn.load_state_dict(
                roberta_layers[roberta_i].self_attn.state_dict())
            decoder_layers[i].self_attn_layer_norm.load_state_dict(
                roberta_layers[roberta_i].self_attn_layer_norm.state_dict())

    else:
        raise Exception('Undefined decoder pretraining method.')

    model = cls(encoder, decoder)
    return model
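# Toy illustration (standalone, no tensors involved) of the layer-offset remapping used
# in the roberta2 branch above: when RoBERTa has more layers than the decoder, the
# bottom `offset` layers are dropped and the remaining ones are renumbered from 0.
prefix = 'model.encoder.sentence_encoder.layers.'
roberta_keys = [prefix + '{}.self_attn.k_proj.weight'.format(i) for i in range(12)]
num_decoder_layers = 6
offset = 12 - num_decoder_layers  # same rule as offset = len(roberta_layers) - len(decoder_layers)

remapped = {}
for key in roberta_keys:
    layer_num = int(key[len(prefix):].split('.')[0])
    if layer_num - offset < 0:
        continue  # this RoBERTa layer has no counterpart in the smaller decoder
    rest = '.'.join(key[len(prefix):].split('.')[1:])
    remapped[prefix + '{}.{}'.format(layer_num - offset, rest)] = key

# old layer 6 becomes new layer 0, old layer 11 becomes new layer 5
assert remapped[prefix + '0.self_attn.k_proj.weight'] == prefix + '6.self_attn.k_proj.weight'
assert len(remapped) == num_decoder_layers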
def build_model(cls, args, task):
    """Build a new model instance."""

    # make sure all arguments are present in older models
    base_lm_architecture(args)

    if args.decoder_layers_to_keep:
        args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

    if getattr(args, "max_target_positions", None) is None:
        args.max_target_positions = getattr(
            args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS)

    if args.character_embeddings:
        embed_tokens = CharacterTokenEmbedder(
            task.source_dictionary,
            eval(args.character_filters),
            args.character_embedding_dim,
            args.decoder_embed_dim,
            args.char_embedder_highway_layers,
        )
    elif args.adaptive_input:
        embed_tokens = AdaptiveInput(
            len(task.source_dictionary),
            task.source_dictionary.pad(),
            args.decoder_input_dim,
            args.adaptive_input_factor,
            args.decoder_embed_dim,
            options.eval_str_list(args.adaptive_input_cutoff, type=int),
            args.quant_noise_pq,
            args.quant_noise_pq_block_size,
        )
    else:
        embed_tokens = cls.build_embedding(
            args, task.source_dictionary, args.decoder_input_dim)

    if args.tie_adaptive_weights:
        assert args.adaptive_input
        assert args.adaptive_input_factor == args.adaptive_softmax_factor
        assert (args.adaptive_softmax_cutoff == args.adaptive_input_cutoff
                ), "{} != {}".format(args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
        assert args.decoder_input_dim == args.decoder_output_dim

    decoder = TransformerDecoder(
        args, task.target_dictionary, embed_tokens, no_encoder_attn=True)

    if getattr(args, "lm_path", None):
        print('load Transformer_LM from {}'.format(args.lm_path))
        state = checkpoint_utils.load_checkpoint_to_cpu(args.lm_path)
        lm_args = state["args"]
        lm_args.data = args.data
        assert getattr(lm_args, "lm_path", None) is None
        task = tasks.setup_task(lm_args)
        decoder = task.build_model(lm_args)
        print('restore Transformer_LM from {}'.format(args.lm_path))
        decoder.load_state_dict(state["model"], strict=True)
        decoder.dim_output = len(task.dictionary)

    return cls(decoder)
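# Small sketch (assumes fairseq is installed) of the cutoff parsing used above:
# --adaptive-input-cutoff and --adaptive-softmax-cutoff arrive as comma-separated
# strings, and tying adaptive weights requires the two parsed lists to match.
from fairseq import options

adaptive_input_cutoff = options.eval_str_list("20000,60000", type=int)
adaptive_softmax_cutoff = options.eval_str_list("20000,60000", type=int)
assert adaptive_input_cutoff == adaptive_softmax_cutoff == [20000, 60000]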
class BartDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg: Wav2Vec2BartConfig,
        dictionary=None,
        embed_tokens=None,
        no_encoder_attn=False,
        transform_embed=None,
        bart=None,
    ):
        super().__init__(dictionary)
        self.cfg = cfg

        # bart = torch.hub.load('pytorch/fairseq', 'bart.base')
        bart_decoder = bart.model.decoder

        # Rebuild a TransformerDecoder with BART's args/dictionary/embeddings, copy the
        # pretrained weights, then wrap the token embedding with an extra transform.
        self.decoder = TransformerDecoder(
            bart_decoder.args,
            bart_decoder.dictionary,
            bart_decoder.embed_tokens,
        )
        self.decoder.load_state_dict(bart_decoder.state_dict())
        self.decoder.embed_tokens = EmbeddingTransformed(self.decoder.embed_tokens, transform_embed)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # with torch.no_grad() if self.cfg.fix_decoder else contextlib.ExitStack():
        x, extra = self.decoder(prev_output_tokens, encoder_out, incremental_state)
        # forward wav2vec/CTC tensors from the encoder alongside the decoder output
        for k in ['wav2vec_logits', 'wav2vec_padding_mask', 'ctc_weight', 'ce_weight']:
            extra[k] = encoder_out[k]
        print('bart decoder extra.keys()', extra.keys())
        return x, extra

    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        return self.decoder.extract_features(prev_output_tokens, encoder_out, incremental_state)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.decoder.max_positions()

    def buffered_future_mask(self, tensor):
        return self.decoder.buffered_future_mask(tensor)

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
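# Hypothetical sketch of what a wrapper like EmbeddingTransformed could look like
# (the real class is defined elsewhere in this codebase and may differ): it applies
# an extra transform module on top of the pretrained BART token embedding.
import torch
import torch.nn as nn

class EmbeddingTransformedSketch(nn.Module):  # illustrative stand-in, not the real class
    def __init__(self, embed_tokens: nn.Embedding, transform: nn.Module):
        super().__init__()
        self.embed_tokens = embed_tokens
        self.transform = transform

    def forward(self, tokens):
        return self.transform(self.embed_tokens(tokens))

embed = nn.Embedding(1000, 16, padding_idx=1)
wrapped = EmbeddingTransformedSketch(embed, nn.Linear(16, 16))
out = wrapped(torch.tensor([[0, 5, 2]]))
assert out.shape == (1, 3, 16)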