    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

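        # default max_positions to the task's tokens_per_sample when it is not
        # set explicitly (e.g. by older configs)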
        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

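        # build the token-embedding table from the source dictionary,
        # optionally initializing it from a pretrained embedding file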
        encoder_embed_tokens = TransformerModel.build_embedding(
            args,
            task.source_dictionary,
            args.encoder_embed_dim,
            args.encoder_embed_path,
        )
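        # the encoder is a Linformer variant (linear-complexity self-attention)
        # built on top of the standard Transformer encoder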
        encoder = LinformerEncoderFromTransformerEncoder(
            args, task.source_dictionary, encoder_embed_tokens
        )
        # constructor signature below is an assumption; fairseq encoder-only
        # models typically finish build_model by wrapping the encoder like this
        return cls(args, encoder)
Example #2
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
        # assert isinstance(task, MultilingualTranslationTask)

        # make sure all arguments are present in older models
        base_architecture(args)

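        # sharing the whole encoder implies sharing its embeddings as well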
        if args.share_encoders:
            args.share_encoder_embeddings = True

        # NAT model: build shared source/target embeddings (if applicable)
        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
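            # a single embedding table is shared by encoder and decoder inputs
            # (and tied to the decoder output projection via
            # share_decoder_input_output_embed below)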
            encoder_embed_tokens = TransformerModel.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = TransformerModel.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = TransformerModel.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

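        # build the student: look up its architecture class in fairseq's
        # ARCH_MODEL_REGISTRY and assemble it from the embeddings built above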
        student_cls = ARCH_MODEL_REGISTRY[args.student_arch]
        encoder = student_cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = student_cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        student = student_cls(args, encoder, decoder)

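        # build the teacher the same way; if the requested teacher architecture
        # is not a NATransformerModel subclass, substitute PatchedTransformerModel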
        teacher_cls = ARCH_MODEL_REGISTRY[args.teacher_arch]
        if not issubclass(teacher_cls, NATransformerModel):
            teacher_cls = PatchedTransformerModel

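        # the teacher reuses the student's embeddings only when sharing is
        # enabled; otherwise it gets freshly built embedding tables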
        teacher_encoder = teacher_cls.build_encoder(
            args,
            src_dict,
            encoder_embed_tokens
            if args.share_encoder_embeddings
            else TransformerModel.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            ),
        )
        teacher_decoder = teacher_cls.build_decoder(
            args,
            tgt_dict,
            decoder_embed_tokens
            if args.share_decoder_embeddings
            else TransformerModel.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            ),
        )
        teacher = teacher_cls(args, teacher_encoder, teacher_decoder)

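        # the final model wraps both the student and the teacher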
        return cls(args, student, teacher)