@classmethod
def build_model(cls, args, task):
    """Build a new model instance."""
    # make sure all arguments are present
    base_architecture(args)

    if not hasattr(args, "max_positions"):
        args.max_positions = args.tokens_per_sample

    encoder_embed_tokens = TransformerModel.build_embedding(
        args,
        task.source_dictionary,
        args.encoder_embed_dim,
        args.encoder_embed_path,
    )
    encoder = LinformerEncoderFromTransformerEncoder(
        args, task.source_dictionary, encoder_embed_tokens
    )
    return cls(args, encoder)
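# ---------------------------------------------------------------------------
# Context (a minimal sketch, not code from this repo): fairseq reaches a
# build_model classmethod like the one above through its model registry, via
# task.build_model(args) after --arch is parsed. The model name
# "linformer_roberta" is the one used by fairseq's Linformer example and may
# differ here; the class name below is a hypothetical placeholder. Only
# register_model and BaseFairseqModel are stock fairseq APIs.
from fairseq.models import BaseFairseqModel, register_model

@register_model("linformer_roberta")
class LinformerExampleModel(BaseFairseqModel):
    @classmethod
    def build_model(cls, args, task):
        # build the encoder as above and wrap it in the model instance
        ...
# ---------------------------------------------------------------------------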
@classmethod
def build_model(cls, args, task):
    """Build a new model instance."""
    # from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
    # assert isinstance(task, MultilingualTranslationTask)

    # make sure all arguments are present in older models
    base_architecture(args)

    if args.share_encoders:
        args.share_encoder_embeddings = True

    ### nat model
    # build shared embeddings (if applicable)
    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
    if args.share_all_embeddings:
        if src_dict != tgt_dict:
            raise ValueError("--share-all-embeddings requires a joined dictionary")
        if args.encoder_embed_dim != args.decoder_embed_dim:
            raise ValueError(
                "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
            )
        if args.decoder_embed_path and (
            args.decoder_embed_path != args.encoder_embed_path
        ):
            raise ValueError(
                "--share-all-embeddings not compatible with --decoder-embed-path"
            )
        encoder_embed_tokens = TransformerModel.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = encoder_embed_tokens
        args.share_decoder_input_output_embed = True
    else:
        encoder_embed_tokens = TransformerModel.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = TransformerModel.build_embedding(
            args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
        )

    # build the student from the architecture registry
    student_cls = ARCH_MODEL_REGISTRY[args.student_arch]
    encoder = student_cls.build_encoder(args, src_dict, encoder_embed_tokens)
    decoder = student_cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
    student = student_cls(args, encoder, decoder)

    # build the teacher; fall back to the patched autoregressive transformer
    # when the requested teacher class is not a NAT model
    teacher_cls = ARCH_MODEL_REGISTRY[args.teacher_arch]
    if not issubclass(teacher_cls, NATransformerModel):
        teacher_cls = PatchedTransformerModel

    # the teacher reuses the student's embeddings only when sharing is enabled;
    # otherwise it gets freshly built embedding tables
    teacher_encoder = teacher_cls.build_encoder(
        args,
        src_dict,
        encoder_embed_tokens
        if args.share_encoder_embeddings
        else TransformerModel.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        ),
    )
    teacher_decoder = teacher_cls.build_decoder(
        args,
        tgt_dict,
        decoder_embed_tokens
        if args.share_decoder_embeddings
        else TransformerModel.build_embedding(
            args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
        ),
    )
    teacher = teacher_cls(args, teacher_encoder, teacher_decoder)

    return cls(args, student, teacher)
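# ---------------------------------------------------------------------------
# Context (a minimal sketch, assuming a stock fairseq install): the
# ARCH_MODEL_REGISTRY lookups above map an architecture name registered with
# @register_model_architecture to its model class. The two names below are
# standard fairseq registrations; args.student_arch / args.teacher_arch would
# carry names like these.
from fairseq.models import ARCH_MODEL_REGISTRY

student_cls = ARCH_MODEL_REGISTRY["nonautoregressive_transformer"]  # NATransformerModel
teacher_cls = ARCH_MODEL_REGISTRY["transformer"]                    # TransformerModel
print(student_cls.__name__, teacher_cls.__name__)
# ---------------------------------------------------------------------------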