def build_model(cls, args, src_dict, dst_dict):
    """Build a new model instance."""
    base_architecture(args)

    # This model variant is restricted to the sequence-LSTM path with
    # LSTM cells; fail fast on any other configuration.
    assert args.sequence_lstm, "CharRNNModel only supports sequence_lstm"
    assert args.cell_type == "lstm", "CharRNNModel only supports cell_type lstm"
    assert hasattr(args, "char_source_dict_size"), (
        "args.char_source_dict_size required. "
        "should be set by load_binarized_dataset()"
    )

    # Character-aware source-side encoder: a char-level RNN feeding a
    # (possibly bidirectional) word-level RNN stack.
    char_encoder = CharRNNEncoder(
        src_dict,
        num_chars=args.char_source_dict_size,
        char_embed_dim=args.char_embed_dim,
        token_embed_dim=args.encoder_embed_dim,
        freeze_embed=args.encoder_freeze_embed,
        char_rnn_units=args.char_rnn_units,
        char_rnn_layers=args.char_rnn_layers,
        num_layers=args.encoder_layers,
        hidden_dim=args.encoder_hidden_dim,
        dropout_in=args.encoder_dropout_in,
        dropout_out=args.encoder_dropout_out,
        residual_level=args.residual_level,
        bidirectional=bool(args.encoder_bidirectional),
        word_dropout_params=args.word_dropout_params,
    )

    # Standard word-level attentional RNN decoder over the target dict.
    rnn_decoder = rnn.RNNDecoder(
        src_dict=src_dict,
        dst_dict=dst_dict,
        vocab_reduction_params=args.vocab_reduction_params,
        encoder_hidden_dim=args.encoder_hidden_dim,
        embed_dim=args.decoder_embed_dim,
        freeze_embed=args.decoder_freeze_embed,
        out_embed_dim=args.decoder_out_embed_dim,
        cell_type=args.cell_type,
        num_layers=args.decoder_layers,
        hidden_dim=args.decoder_hidden_dim,
        attention_type=args.attention_type,
        dropout_in=args.decoder_dropout_in,
        dropout_out=args.decoder_dropout_out,
        residual_level=args.residual_level,
        averaging_encoder=args.averaging_encoder,
    )

    return cls(char_encoder, rnn_decoder)
def build_model(cls, args, task):
    """Build a new model instance.

    Selects the source-side character encoder based on args:
    a CharCNNEncoder when ``args.char_cnn_params`` is present,
    otherwise a CharRNNEncoder. Pairs it with an attentional
    rnn.RNNDecoder and wraps both in ``cls(task, encoder, decoder)``.
    """
    src_dict, dst_dict = task.source_dictionary, task.target_dictionary
    base_architecture(args)

    # This model only supports the sequence-LSTM path with LSTM cells;
    # the char-source dict size must have been populated by data loading.
    assert args.sequence_lstm, "CharRNNModel only supports sequence_lstm"
    assert args.cell_type == "lstm", "CharRNNModel only supports cell_type lstm"
    assert hasattr(args, "char_source_dict_size"), (
        "args.char_source_dict_size required. "
        "should be set by load_binarized_dataset()"
    )

    if hasattr(args, "char_cnn_params"):
        args.embed_bytes = getattr(args, "embed_bytes", False)

        # If we embed bytes then the number of indices is fixed and does not
        # depend on the dictionary
        if args.embed_bytes:
            num_chars = vocab_constants.NUM_BYTE_INDICES + len(TAGS) + 1
        else:
            num_chars = args.char_source_dict_size

        # In case use_pretrained_weights is true, verify the model params
        # are correctly set
        if args.embed_bytes and getattr(args, "use_pretrained_weights", False):
            verify_pretrain_params(args)

        encoder = CharCNNEncoder(
            src_dict,
            num_chars=num_chars,
            unk_only_char_encoding=args.unk_only_char_encoding,
            embed_dim=args.char_embed_dim,
            token_embed_dim=args.encoder_embed_dim,
            freeze_embed=args.encoder_freeze_embed,
            normalize_embed=args.encoder_normalize_embed,
            char_cnn_params=args.char_cnn_params,
            char_cnn_nonlinear_fn=args.char_cnn_nonlinear_fn,
            char_cnn_pool_type=args.char_cnn_pool_type,
            char_cnn_num_highway_layers=args.char_cnn_num_highway_layers,
            char_cnn_output_dim=getattr(args, "char_cnn_output_dim", -1),
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
            word_dropout_params=args.word_dropout_params,
            use_pretrained_weights=getattr(args, "use_pretrained_weights", False),
            finetune_pretrained_weights=getattr(
                args, "finetune_pretrained_weights", False
            ),
            weights_file=getattr(args, "pretrained_weights_file", ""),
        )
    else:
        # The RNN char encoder has no unk-only mode; reject that setting.
        assert (
            args.unk_only_char_encoding is False
        ), "unk_only_char_encoding should be False when using CharRNNEncoder"
        encoder = CharRNNEncoder(
            src_dict,
            num_chars=args.char_source_dict_size,
            char_embed_dim=args.char_embed_dim,
            token_embed_dim=args.encoder_embed_dim,
            normalize_embed=args.encoder_normalize_embed,
            char_rnn_units=args.char_rnn_units,
            char_rnn_layers=args.char_rnn_layers,
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
        )

    decoder = rnn.RNNDecoder(
        src_dict=src_dict,
        dst_dict=dst_dict,
        vocab_reduction_params=args.vocab_reduction_params,
        encoder_hidden_dim=args.encoder_hidden_dim,
        embed_dim=args.decoder_embed_dim,
        freeze_embed=args.decoder_freeze_embed,
        out_embed_dim=args.decoder_out_embed_dim,
        cell_type=args.cell_type,
        num_layers=args.decoder_layers,
        hidden_dim=args.decoder_hidden_dim,
        attention_type=args.attention_type,
        dropout_in=args.decoder_dropout_in,
        dropout_out=args.decoder_dropout_out,
        residual_level=args.residual_level,
        averaging_encoder=args.averaging_encoder,
    )
    return cls(task, encoder, decoder)