Code Example #1
    def build_encoder(cls,
                      args,
                      src_dict,
                      embed_tokens,
                      proj_to_decoder=False):
        # Pass the flag through rather than hardcoding False, so callers
        # can actually request projection to the decoder dimension.
        return pytorch_translate_transformer.TransformerEncoder(
            args, src_dict, embed_tokens, proj_to_decoder=proj_to_decoder)
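
This classmethod is a hook: a subclass can swap in a differently configured encoder without reimplementing the rest of model construction. A hedged illustration of that pattern (the subclass name is made up, not code from the repo):

# Hypothetical subclass: change only the encoder hook to enable projection
# to the decoder dimension; illustrative, not repo code.
class ProjectedHybridModel(HybridTransformerRNNModel):
    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens,
                      proj_to_decoder=True):
        return pytorch_translate_transformer.TransformerEncoder(
            args, src_dict, embed_tokens, proj_to_decoder=proj_to_decoder)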
Code Example #2
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted
        # (in case there are any new ones)
        base_architecture(args)

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        encoder_embed_tokens = pytorch_translate_transformer.build_embedding(
            dictionary=src_dict,
            embed_dim=args.encoder_embed_dim,
            path=args.encoder_pretrained_embed,
            freeze=args.encoder_freeze_embed,
        )
        decoder_embed_tokens = pytorch_translate_transformer.build_embedding(
            dictionary=tgt_dict,
            embed_dim=args.decoder_embed_dim,
            path=args.decoder_pretrained_embed,
            freeze=args.decoder_freeze_embed,
        )

        encoder = pytorch_translate_transformer.TransformerEncoder(
            args, src_dict, encoder_embed_tokens, proj_to_decoder=False)
        decoder = HybridRNNDecoder(args, src_dict, tgt_dict,
                                   decoder_embed_tokens)
        return HybridTransformerRNNModel(task, encoder, decoder)
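
As a usage note: build_model is not called directly. In fairseq-style code the training entry point resolves --arch to a registered model class and invokes this classmethod with the parsed args and the task. A minimal, hedged driver sketch (assumes the standard fairseq option helpers; the exact flow in pytorch_translate may differ):

from fairseq import options, tasks

# Parse CLI args (including --arch), set up the task, and let the task
# dispatch to the registered model's build_model(args, task) classmethod.
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
task = tasks.setup_task(args)
model = task.build_model(args)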
Code Example #3
File: transformer_aan.py  Project: zbn123/translate
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise RuntimeError(
                    "--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    """--share-all-embeddings requires --encoder-embed-dim \
                    to match --decoder-embed-dim""")
            if args.decoder_embed_path and (args.decoder_embed_path !=
                                            args.encoder_embed_path):
                raise RuntimeError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(src_dict,
                                                   args.encoder_embed_dim,
                                                   args.encoder_embed_path)
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(src_dict,
                                                   args.encoder_embed_dim,
                                                   args.encoder_embed_path)
            decoder_embed_tokens = build_embedding(tgt_dict,
                                                   args.decoder_embed_dim,
                                                   args.decoder_embed_path)

        encoder = pytorch_translate_transformer.TransformerEncoder(
            args, src_dict, encoder_embed_tokens)
        decoder = TransformerAANDecoder(args, src_dict, tgt_dict,
                                        decoder_embed_tokens)
        return TransformerAANModel(task, encoder, decoder)
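
This example's nested build_embedding calls a module-level Embedding factory. A sketch consistent with fairseq's usual convention (an assumption about this file's import, not verified against it):

import torch.nn as nn

def Embedding(num_embeddings, embed_dim, padding_idx):
    # fairseq-style init: normal weights scaled by embed_dim, with the
    # padding row zeroed so pad tokens contribute nothing downstream.
    m = nn.Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embed_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m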
Code Example #4
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted
        # (in case there are any new ones)
        base_architecture(args)

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        encoder_embed_tokens = pytorch_translate_transformer.build_embedding(
            dictionary=src_dict,
            embed_dim=args.encoder_embed_dim,
            path=args.encoder_pretrained_embed,
            freeze=args.encoder_freeze_embed,
        )

        teacher_decoder_embed_tokens = pytorch_translate_transformer.build_embedding(
            dictionary=tgt_dict,
            embed_dim=args.decoder_embed_dim,
            path=args.decoder_pretrained_embed,
            freeze=args.decoder_freeze_embed,
        )

        student_decoder_embed_tokens = pytorch_translate_transformer.build_embedding(
            dictionary=tgt_dict, embed_dim=args.student_decoder_embed_dim)

        encoder = pytorch_translate_transformer.TransformerEncoder(
            args, src_dict, encoder_embed_tokens, proj_to_decoder=True)

        teacher_decoder = pytorch_translate_transformer.TransformerModel.build_decoder(
            args,
            src_dict,
            tgt_dict,
            embed_tokens=teacher_decoder_embed_tokens)

        student_decoder = StudentHybridRNNDecoder(
            args, src_dict, tgt_dict, student_decoder_embed_tokens)

        return DualDecoderKDModel(
            task=task,
            encoder=encoder,
            teacher_decoder=teacher_decoder,
            student_decoder=student_decoder,
        )
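
The dual-decoder arrangement shares one encoder between a transformer teacher and an RNN student, which is what enables knowledge distillation against the teacher's output distribution. A hedged sketch of how a forward pass might route the shared encoder states (class and method names are illustrative, not the repo's API):

import torch.nn as nn

class DualDecoderKDSketch(nn.Module):
    """Illustrative only: one encoder feeding both decoders."""

    def __init__(self, encoder, teacher_decoder, student_decoder):
        super().__init__()
        self.encoder = encoder
        self.teacher_decoder = teacher_decoder
        self.student_decoder = student_decoder

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        # Encode once; both decoders consume the same encoder states, so
        # the student can be trained to match the teacher's predictions.
        encoder_out = self.encoder(src_tokens, src_lengths)
        teacher_out = self.teacher_decoder(prev_output_tokens, encoder_out)
        student_out = self.student_decoder(prev_output_tokens, encoder_out)
        return teacher_out, student_out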