def build_model(cls, args, task):
    """Build both the primal and dual models and wrap them for dual learning.

    For simplicity, both models share the same architecture: each one is
    constructed independently from the same ``args`` hyperparameters, only
    the source/target dictionaries differ. Support for different
    models/archs would be added in further iterations.

    Args:
        args: parsed model/training arguments. ``base_architecture(args)`` is
            applied first (presumably to fill in architecture defaults).
            ``args.pretrained_forward_checkpoint`` /
            ``args.pretrained_backward_checkpoint``, when truthy, are paths of
            checkpoints whose ``"model"`` state is loaded (strictly) into the
            primal / dual model respectively.
        task: dual-learning task exposing ``primal_src_dict``,
            ``primal_tgt_dict``, ``dual_src_dict`` and ``dual_tgt_dict``.

    Returns:
        An ``RNNDualLearningModel`` built from the primal model, the dual
        model, and no language model (``None``) for now.
    """
    base_architecture(args)
    if args.sequence_lstm:
        encoder_class = LSTMSequenceEncoder
    else:
        encoder_class = RNNEncoder
    decoder_class = RNNDecoder

    def _build_direction(src_dict, tgt_dict):
        """Build embeddings, encoder, decoder and RNNModel for one direction."""
        encoder_embed_tokens, decoder_embed_tokens = RNNModel.build_embed_tokens(
            args, src_dict, tgt_dict
        )
        encoder = encoder_class(
            src_dict,
            embed_dim=args.encoder_embed_dim,
            embed_tokens=encoder_embed_tokens,
            cell_type=args.cell_type,
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
        )
        decoder = decoder_class(
            src_dict=src_dict,
            dst_dict=tgt_dict,
            embed_tokens=decoder_embed_tokens,
            vocab_reduction_params=args.vocab_reduction_params,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            out_embed_dim=args.decoder_out_embed_dim,
            cell_type=args.cell_type,
            num_layers=args.decoder_layers,
            hidden_dim=args.decoder_hidden_dim,
            attention_type=args.attention_type,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            residual_level=args.residual_level,
            averaging_encoder=args.averaging_encoder,
        )
        direction_task = PytorchTranslateTask(args, src_dict, tgt_dict)
        return rnn.RNNModel(direction_task, encoder, decoder)

    def _maybe_load_pretrained(model, checkpoint_path, label):
        """Strictly load ``checkpoint_path``'s "model" state into ``model`` if set."""
        if not checkpoint_path:
            return
        pretrained_state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint_path)
        model.load_state_dict(pretrained_state["model"], strict=True)
        print(f"Loaded pretrained {label} model from {checkpoint_path}")

    # Build (and optionally load) the primal model completely before the dual
    # one, matching the original construction order.
    primal_model = _build_direction(task.primal_src_dict, task.primal_tgt_dict)
    _maybe_load_pretrained(
        primal_model, args.pretrained_forward_checkpoint, "primal"
    )

    dual_model = _build_direction(task.dual_src_dict, task.dual_tgt_dict)
    _maybe_load_pretrained(dual_model, args.pretrained_backward_checkpoint, "dual")

    # TODO (T36875783): instantiate a language model
    lm_model = None
    return RNNDualLearningModel(args, task, primal_model, dual_model, lm_model)
def build_model(cls, args, task):
    """Build both the primal and dual models and wrap them for dual learning.

    For simplicity, both models share the same architecture: each one is
    constructed independently from the same ``args`` hyperparameters, only
    the source/target dictionaries differ. Support for different
    models/archs would be added in further iterations.

    Args:
        args: parsed model/training arguments. ``base_architecture(args)`` is
            applied first (presumably to fill in architecture defaults).
        task: dual-learning task exposing ``primal_src_dict``,
            ``primal_tgt_dict``, ``dual_src_dict`` and ``dual_tgt_dict``.

    Returns:
        An ``RNNDualLearningModel`` built from the primal model, the dual
        model, and no language model (``None``) for now.
    """
    base_architecture(args)
    if args.sequence_lstm:
        encoder_class = LSTMSequenceEncoder
    else:
        encoder_class = RNNEncoder
    decoder_class = RNNDecoder

    def _build_direction(src_dict, tgt_dict):
        """Build encoder, decoder and RNNModel for one translation direction."""
        encoder = encoder_class(
            src_dict,
            embed_dim=args.encoder_embed_dim,
            freeze_embed=args.encoder_freeze_embed,
            cell_type=args.cell_type,
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
        )
        decoder = decoder_class(
            src_dict=src_dict,
            dst_dict=tgt_dict,
            vocab_reduction_params=args.vocab_reduction_params,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            freeze_embed=args.decoder_freeze_embed,
            out_embed_dim=args.decoder_out_embed_dim,
            cell_type=args.cell_type,
            num_layers=args.decoder_layers,
            hidden_dim=args.decoder_hidden_dim,
            attention_type=args.attention_type,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            residual_level=args.residual_level,
            averaging_encoder=args.averaging_encoder,
        )
        direction_task = PytorchTranslateTask(args, src_dict, tgt_dict)
        return rnn.RNNModel(direction_task, encoder, decoder)

    # Primal is built before dual, matching the original construction order.
    primal_model = _build_direction(task.primal_src_dict, task.primal_tgt_dict)
    dual_model = _build_direction(task.dual_src_dict, task.dual_tgt_dict)

    # TODO (T36875783): instantiate a language model
    lm_model = None
    return RNNDualLearningModel(args, task, primal_model, dual_model, lm_model)