def build_decoder(cls, args, task, embed_tokens):
    _args = copy.deepcopy(args)
    if args.adaptor_proj or args.encoder_proj:  # not V0 arch
        _args.encoder_embed_dim = _args.decoder_embed_dim
    _args.dropout = args.decoder_dropout
    _args.attention_dropout = args.decoder_attention_dropout
    _args.activation_dropout = args.decoder_activation_dropout

    decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens)
    decoder = cls.maybe_load_pretrained(
        decoder, getattr(args, "load_pretrained_decoder_from", None)
    )

    # Only parameters whose names match --finetune-decoder-params stay
    # trainable; the rest of the (possibly pretrained) decoder is frozen.
    for k, p in decoder.named_parameters():
        p.requires_grad = need_finetuning(args.finetune_decoder_params, k)
    return decoder
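
# The loop above gates requires_grad through a name-matching helper
# (need_finetuning). A minimal sketch of such a helper is shown below,
# assuming --finetune-decoder-params is either "all" or a comma-separated
# list of name fragments; treat it as an illustration, not the upstream
# fairseq implementation.
def need_finetuning_sketch(ft_params: str, param_name: str) -> bool:
    """Return True if the parameter named `param_name` should stay trainable."""
    if ft_params == "all":
        return True
    return any(fragment in param_name for fragment in ft_params.split(","))


# Hypothetical filter value: only the last decoder layer and the final layer
# norm would remain trainable.
# need_finetuning_sketch("layers.11,layer_norm", "layers.11.self_attn.k_proj.weight") -> True
# need_finetuning_sketch("layers.11,layer_norm", "layers.0.fc1.weight") -> False
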
def build_decoder(cls, args, task):
    _args = copy.deepcopy(args)
    _args.dropout = args.mbart_dropout
    _args.attention_dropout = args.mbart_attention_dropout
    _args.activation_dropout = args.mbart_activation_dropout
    _args.max_target_positions = 1024
    dec_emb = nn.Embedding(
        len(task.tgt_dict), _args.encoder_embed_dim, task.tgt_dict.pad()
    )
    decoder = TransformerDecoder(_args, task.tgt_dict, dec_emb)
    if getattr(args, "load_pretrained_mbart_from", None):
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=args.load_pretrained_mbart_from
        )
    if getattr(args, "no_final_norm_decoder", False):
        decoder.layer_norm = None
    for k, p in decoder.named_parameters():
        # Freeze pretrained parameters by default; only those matching
        # --finetune-mbart-decoder-params remain trainable.
        if safe_hasattr(args, "finetune_mbart_decoder_params") and need_finetuning(
            args.finetune_mbart_decoder_params, k
        ):
            p.requires_grad = True
        else:
            p.requires_grad = False

    compute_cross_attentive_loss = (
        getattr(args, "attentive_cost_regularization", 0.0) > 0.0
    )
    cross_attentive_loss_without_norm = getattr(
        args, "attentive_cost_without_normalize", False
    )
    cross_attentive_loss_reverse = False  # getattr(args, "attentive_cost_reverse", False)

    # The same decoder instance is shared between the speech and text inputs.
    decoder = TransformerMultiInputDecoder(
        dictionary=task.target_dictionary,
        spch_decoder=decoder,
        text_decoder=decoder,
        compute_cross_attentive_loss=compute_cross_attentive_loss,
        cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
        cross_attentive_loss_reverse=cross_attentive_loss_reverse,
    )
    return decoder
def build_decoder(cls, args, task, embed_tokens):
    _args = copy.deepcopy(args)
    _args.dropout = args.decoder_dropout
    _args.attention_dropout = args.decoder_attention_dropout
    _args.activation_dropout = args.decoder_activation_dropout
    _args.max_target_positions = 1024

    decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens)
    if getattr(args, "load_pretrained_decoder_from", None):
        decoder = checkpoint_utils.load_pretrained_component_from_model(
            component=decoder, checkpoint=args.load_pretrained_decoder_from
        )
    for k, p in decoder.named_parameters():
        # Freeze pretrained parameters by default; only those matching
        # --finetune-decoder-params remain trainable.
        if safe_hasattr(
            args, "finetune_decoder_params"
        ) and XMTransformerModel.finetune_params(args.finetune_decoder_params, k):
            p.requires_grad = True
        else:
            p.requires_grad = False
    return decoder
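
# Self-contained illustration of the freeze-by-name pattern shared by the
# three builders above: iterate named_parameters() and keep only matching
# parameters trainable. The toy module and the "1." filter below are made-up
# stand-ins, not fairseq components.
import torch.nn as nn

toy_decoder = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
finetune_filter = "1."  # hypothetical: finetune only the second sub-module
for name, param in toy_decoder.named_parameters():
    # requires_grad stays True only where the name contains the filter fragment
    param.requires_grad = finetune_filter in name

assert not toy_decoder[0].weight.requires_grad
assert toy_decoder[1].weight.requires_grad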