def build_encoder(cls, args, src_dict, embed_tokens, proj_to_decoder=False):
    """Build the TransformerEncoder for this model.

    Args:
        args: parsed architecture/command-line arguments.
        src_dict: source-side dictionary.
        embed_tokens: token embedding module for the encoder.
        proj_to_decoder: whether the encoder output should be projected
            to the decoder's dimensionality.

    Returns:
        A ``pytorch_translate_transformer.TransformerEncoder`` instance.
    """
    # BUG FIX: the original hard-coded proj_to_decoder=False in the call,
    # silently discarding the caller's argument. Forward it instead so
    # passing proj_to_decoder=True actually takes effect.
    return pytorch_translate_transformer.TransformerEncoder(
        args, src_dict, embed_tokens, proj_to_decoder=proj_to_decoder
    )
def build_model(cls, args, task):
    """Build a new model instance."""
    # Fill in defaults for any args absent from older configs/checkpoints.
    base_architecture(args)

    source_dict = task.source_dictionary
    target_dict = task.target_dictionary

    # Local alias keeps the two embedding constructions symmetric.
    make_embedding = pytorch_translate_transformer.build_embedding
    source_embedding = make_embedding(
        dictionary=source_dict,
        embed_dim=args.encoder_embed_dim,
        path=args.encoder_pretrained_embed,
        freeze=args.encoder_freeze_embed,
    )
    target_embedding = make_embedding(
        dictionary=target_dict,
        embed_dim=args.decoder_embed_dim,
        path=args.decoder_pretrained_embed,
        freeze=args.decoder_freeze_embed,
    )

    encoder = pytorch_translate_transformer.TransformerEncoder(
        args, source_dict, source_embedding, proj_to_decoder=False
    )
    decoder = HybridRNNDecoder(args, source_dict, target_dict, target_embedding)
    return HybridTransformerRNNModel(task, encoder, decoder)
def build_model(cls, args, task):
    """Build a new model instance."""
    # make sure all arguments are present in older models
    base_architecture(args)
    if not hasattr(args, "max_source_positions"):
        args.max_source_positions = 1024
    if not hasattr(args, "max_target_positions"):
        args.max_target_positions = 1024

    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    def make_embedding(dictionary, embed_dim, path=None):
        # Fresh embedding table sized to the dictionary; optionally
        # initialized from a pre-trained embedding file at `path`.
        table = Embedding(len(dictionary), embed_dim, dictionary.pad())
        if path:
            pretrained = utils.parse_embedding(path)
            utils.load_embedding(pretrained, dictionary, table)
        return table

    if args.share_all_embeddings:
        # Sharing one table across encoder/decoder requires matching
        # dictionaries and dimensions.
        if src_dict != tgt_dict:
            raise RuntimeError(
                "--share-all-embeddings requires a joined dictionary")
        if args.encoder_embed_dim != args.decoder_embed_dim:
            raise RuntimeError(
                """--share-all-embeddings requires --encoder-embed-dim \
 to match --decoder-embed-dim""")
        if args.decoder_embed_path and (
            args.decoder_embed_path != args.encoder_embed_path
        ):
            raise RuntimeError(
                "--share-all-embeddings not compatible with --decoder-embed-path"
            )
        encoder_embed_tokens = make_embedding(
            src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = encoder_embed_tokens
        # Sharing everything implies tying decoder input/output embeddings.
        args.share_decoder_input_output_embed = True
    else:
        encoder_embed_tokens = make_embedding(
            src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = make_embedding(
            tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
        )

    encoder = pytorch_translate_transformer.TransformerEncoder(
        args, src_dict, encoder_embed_tokens
    )
    decoder = TransformerAANDecoder(args, src_dict, tgt_dict, decoder_embed_tokens)
    return TransformerAANModel(task, encoder, decoder)
def build_model(cls, args, task):
    """Build a new model instance."""
    # Default any newly introduced architecture args.
    base_architecture(args)

    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    make_embedding = pytorch_translate_transformer.build_embedding

    encoder_embed_tokens = make_embedding(
        dictionary=src_dict,
        embed_dim=args.encoder_embed_dim,
        path=args.encoder_pretrained_embed,
        freeze=args.encoder_freeze_embed,
    )
    teacher_embed_tokens = make_embedding(
        dictionary=tgt_dict,
        embed_dim=args.decoder_embed_dim,
        path=args.decoder_pretrained_embed,
        freeze=args.decoder_freeze_embed,
    )
    # The student decoder gets its own table (no pre-trained init),
    # sized by its own embedding dimension.
    student_embed_tokens = make_embedding(
        dictionary=tgt_dict, embed_dim=args.student_decoder_embed_dim
    )

    encoder = pytorch_translate_transformer.TransformerEncoder(
        args, src_dict, encoder_embed_tokens, proj_to_decoder=True
    )
    teacher_decoder = pytorch_translate_transformer.TransformerModel.build_decoder(
        args, src_dict, tgt_dict, embed_tokens=teacher_embed_tokens
    )
    student_decoder = StudentHybridRNNDecoder(
        args, src_dict, tgt_dict, student_embed_tokens
    )
    return DualDecoderKDModel(
        task=task,
        encoder=encoder,
        teacher_decoder=teacher_decoder,
        student_decoder=student_decoder,
    )