예제 #1
0
    def build_model(cls, args, task):
        """Build a new model instance.

        Applies the base architecture defaults to ``args``, validates the
        char-encoder settings, then delegates encoder/decoder construction
        to ``CharSourceHybridModel`` and wraps both in ``cls``.
        """
        source_dict = task.source_dictionary
        target_dict = task.target_dictionary
        base_architecture(args)

        assert hasattr(args, "char_source_dict_size"), (
            "args.char_source_dict_size required. "
            "should be set by load_binarized_dataset()")

        assert hasattr(
            args, "char_cnn_params"
        ), "Only char CNN is supported for the char encoder hybrid model"

        args.embed_bytes = getattr(args, "embed_bytes", False)

        # When byte embeddings are combined with pretrained weights, check
        # that the model params are consistent before building anything.
        if args.embed_bytes and getattr(args, "use_pretrained_weights", False):
            char_source_model.verify_pretrain_params(args)

        encoder = CharSourceHybridModel.build_encoder(
            args=args, src_dict=source_dict
        )
        decoder = CharSourceHybridModel.build_decoder(
            args=args, src_dict=source_dict, dst_dict=target_dict
        )
        return cls(task, encoder, decoder)
    def build_model(cls, args, task):
        """Build a new model instance.

        Applies architecture defaults to ``args``, validates the
        embedding-sharing flags, builds source/target token embeddings,
        then assembles a ``CharCNNEncoder`` + ``transformer.TransformerDecoder``
        pair and wraps them in ``cls``.
        """
        src_dict, dst_dict = task.source_dictionary, task.target_dictionary
        base_architecture(args)

        assert hasattr(args, "char_source_dict_size"), (
            "args.char_source_dict_size required. "
            "should be set by load_binarized_dataset()"
        )

        if args.share_all_embeddings:
            # Sharing embeddings only makes sense with a joined dictionary,
            # matching dims, and matching pretrained-embedding settings.
            if src_dict != dst_dict:
                raise RuntimeError(
                    "--share-all-embeddings requires a joined dictionary"
                )
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    "--share-all-embeddings requires --encoder-embed-dim "
                    "to match --decoder-embed-dim"
                )
            if args.decoder_pretrained_embed and (
                args.decoder_pretrained_embed != args.encoder_pretrained_embed
            ):
                raise RuntimeError(
                    "--share-all-embeddings not compatible with "
                    "--decoder-pretrained-embed"
                )

        # The encoder embedding was built identically in both branches of the
        # original if/else; build it once (after validation, so the sharing
        # checks still raise before any embedding work happens).
        encoder_embed_tokens = transformer.build_embedding(
            src_dict,
            args.encoder_embed_dim,
            args.encoder_pretrained_embed,
            args.encoder_freeze_embed,
        )
        if args.share_all_embeddings:
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            decoder_embed_tokens = transformer.build_embedding(
                dst_dict,
                args.decoder_embed_dim,
                args.decoder_pretrained_embed,
                args.decoder_freeze_embed,
            )

        args.embed_bytes = getattr(args, "embed_bytes", False)

        # If we embed bytes then the number of indices is fixed and does not
        # depend on the dictionary.
        if args.embed_bytes:
            num_chars = vocab_constants.NUM_BYTE_INDICES + len(TAGS) + 1
        else:
            num_chars = args.char_source_dict_size

        # In case use_pretrained_weights is true, verify the model params
        # are correctly set.
        if args.embed_bytes and getattr(args, "use_pretrained_weights", False):
            verify_pretrain_params(args)

        encoder = CharCNNEncoder(
            args,
            src_dict,
            encoder_embed_tokens,
            num_chars=num_chars,
            embed_dim=args.char_embed_dim,
            char_cnn_params=args.char_cnn_params,
            char_cnn_nonlinear_fn=args.char_cnn_nonlinear_fn,
            char_cnn_pool_type=args.char_cnn_pool_type,
            char_cnn_num_highway_layers=args.char_cnn_num_highway_layers,
            char_cnn_output_dim=getattr(args, "char_cnn_output_dim", -1),
            use_pretrained_weights=getattr(args, "use_pretrained_weights", False),
            finetune_pretrained_weights=getattr(
                args, "finetune_pretrained_weights", False
            ),
            weights_file=getattr(args, "pretrained_weights_file", ""),
        )
        decoder = transformer.TransformerDecoder(
            args=args,
            src_dict=src_dict,
            dst_dict=dst_dict,
            embed_tokens=decoder_embed_tokens,
        )
        return cls(task, encoder, decoder)
예제 #3
0
    def build_model(cls, args, task):
        """Build a new model instance.

        Applies architecture defaults to ``args``, validates the char-encoder
        settings, builds source/target token embeddings, then assembles a
        ``CharCNNEncoder`` + ``HybridRNNDecoder`` pair and wraps them in
        ``cls``.
        """
        src_dict, dst_dict = task.source_dictionary, task.target_dictionary
        base_architecture(args)

        assert hasattr(args, "char_source_dict_size"), (
            "args.char_source_dict_size required. "
            "should be set by load_binarized_dataset()")

        assert hasattr(
            args, "char_cnn_params"
        ), "Only char CNN is supported for the char encoder hybrid model"

        args.embed_bytes = getattr(args, "embed_bytes", False)

        # If we embed bytes then the number of indices is fixed and does not
        # depend on the dictionary.
        if args.embed_bytes:
            # Idiomatic len(TAGS) instead of TAGS.__len__().
            num_chars = vocab_constants.NUM_BYTE_INDICES + len(TAGS) + 1
        else:
            num_chars = args.char_source_dict_size

        # In case use_pretrained_weights is true, verify the model params
        # are correctly set.
        if args.embed_bytes and getattr(args, "use_pretrained_weights", False):
            char_source_model.verify_pretrain_params(args)

        encoder_embed_tokens = pytorch_translate_transformer.build_embedding(
            src_dict,
            args.encoder_embed_dim,
            args.encoder_pretrained_embed,
            args.encoder_freeze_embed,
        )
        encoder = CharCNNEncoder(
            args,
            src_dict,
            encoder_embed_tokens,
            num_chars=num_chars,
            embed_dim=args.char_embed_dim,
            char_cnn_params=args.char_cnn_params,
            char_cnn_nonlinear_fn=args.char_cnn_nonlinear_fn,
            char_cnn_pool_type=args.char_cnn_pool_type,
            char_cnn_num_highway_layers=args.char_cnn_num_highway_layers,
            char_cnn_output_dim=getattr(args, "char_cnn_output_dim", -1),
            use_pretrained_weights=getattr(args, "use_pretrained_weights",
                                           False),
            finetune_pretrained_weights=getattr(args,
                                                "finetune_pretrained_weights",
                                                False),
            weights_file=getattr(args, "pretrained_weights_file", ""),
            left_pad=False,
        )

        decoder_embed_tokens = pytorch_translate_transformer.build_embedding(
            dst_dict,
            args.decoder_embed_dim,
            args.decoder_pretrained_embed,
            args.decoder_freeze_embed,
        )
        decoder = hybrid_transformer_rnn.HybridRNNDecoder(
            args, src_dict, dst_dict, decoder_embed_tokens)

        return cls(task, encoder, decoder)