Code Example #1
    @classmethod
    def build_model(cls, args, task):
        """ Build both the primal and dual models.
        For simplicity, both models share the same arch, i.e. the same model
        params would be used to initialize both models.
        Support for different models/archs would be added in further iterations.
        """
        base_architecture(args)

        if args.sequence_lstm:
            encoder_class = LSTMSequenceEncoder
        else:
            encoder_class = RNNEncoder
        decoder_class = RNNDecoder

        encoder_embed_tokens, decoder_embed_tokens = RNNModel.build_embed_tokens(
            args, task.primal_src_dict, task.primal_tgt_dict)
        primal_encoder = encoder_class(
            task.primal_src_dict,
            embed_dim=args.encoder_embed_dim,
            embed_tokens=encoder_embed_tokens,
            cell_type=args.cell_type,
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
        )
        primal_decoder = decoder_class(
            src_dict=task.primal_src_dict,
            dst_dict=task.primal_tgt_dict,
            embed_tokens=decoder_embed_tokens,
            vocab_reduction_params=args.vocab_reduction_params,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            out_embed_dim=args.decoder_out_embed_dim,
            cell_type=args.cell_type,
            num_layers=args.decoder_layers,
            hidden_dim=args.decoder_hidden_dim,
            attention_type=args.attention_type,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            residual_level=args.residual_level,
            averaging_encoder=args.averaging_encoder,
        )
        primal_task = PytorchTranslateTask(args, task.primal_src_dict,
                                           task.primal_tgt_dict)
        primal_model = rnn.RNNModel(primal_task, primal_encoder,
                                    primal_decoder)
        if args.pretrained_forward_checkpoint:
            pretrained_forward_state = checkpoint_utils.load_checkpoint_to_cpu(
                args.pretrained_forward_checkpoint)
            primal_model.load_state_dict(pretrained_forward_state["model"],
                                         strict=True)
            print(
                f"Loaded pretrained primal model from {args.pretrained_forward_checkpoint}"
            )

        encoder_embed_tokens, decoder_embed_tokens = RNNModel.build_embed_tokens(
            args, task.dual_src_dict, task.dual_tgt_dict)
        dual_encoder = encoder_class(
            task.dual_src_dict,
            embed_dim=args.encoder_embed_dim,
            embed_tokens=encoder_embed_tokens,
            cell_type=args.cell_type,
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
        )
        dual_decoder = decoder_class(
            src_dict=task.dual_src_dict,
            dst_dict=task.dual_tgt_dict,
            embed_tokens=decoder_embed_tokens,
            vocab_reduction_params=args.vocab_reduction_params,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            out_embed_dim=args.decoder_out_embed_dim,
            cell_type=args.cell_type,
            num_layers=args.decoder_layers,
            hidden_dim=args.decoder_hidden_dim,
            attention_type=args.attention_type,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            residual_level=args.residual_level,
            averaging_encoder=args.averaging_encoder,
        )
        dual_task = PytorchTranslateTask(args, task.dual_src_dict,
                                         task.dual_tgt_dict)
        dual_model = rnn.RNNModel(dual_task, dual_encoder, dual_decoder)
        if args.pretrained_backward_checkpoint:
            pretrained_backward_state = checkpoint_utils.load_checkpoint_to_cpu(
                args.pretrained_backward_checkpoint)
            dual_model.load_state_dict(pretrained_backward_state["model"],
                                       strict=True)
            print(
                f"Loaded pretrained dual model from {args.pretrained_backward_checkpoint}"
            )

        # TODO (T36875783): instantiate a language model
        lm_model = None
        return RNNDualLearningModel(args, task, primal_model, dual_model,
                                    lm_model)
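
The listing above wires two direction-specific RNN models from one shared set of args, optionally warm-starts each from a pretrained forward/backward checkpoint, and wraps them in a single dual-learning module. The snippet below is a minimal, self-contained sketch of that pattern in plain PyTorch, not the pytorch_translate implementation; ToySeq2Seq, DualModel, and build_dual_model are hypothetical names used only for illustration.

# Minimal sketch of the dual-model pattern, assuming plain PyTorch only.
import torch
import torch.nn as nn


class ToySeq2Seq(nn.Module):
    """Stand-in for rnn.RNNModel: embedding -> LSTM encoder -> projection."""

    def __init__(self, src_vocab, tgt_vocab, embed_dim=32, hidden_dim=64):
        super().__init__()
        self.embed = nn.Embedding(src_vocab, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.proj = nn.Linear(hidden_dim, tgt_vocab)

    def forward(self, src_tokens):
        out, _ = self.encoder(self.embed(src_tokens))
        return self.proj(out)


class DualModel(nn.Module):
    """Stand-in for RNNDualLearningModel: holds the primal and dual sub-models."""

    def __init__(self, primal, dual, lm=None):
        super().__init__()
        self.primal = primal  # source -> target direction
        self.dual = dual      # target -> source direction
        self.lm = lm          # language-model slot, left None (cf. the TODO above)


def build_dual_model(src_vocab, tgt_vocab, primal_ckpt=None, dual_ckpt=None):
    # Same hyperparameters for both directions, mirroring build_model above.
    primal = ToySeq2Seq(src_vocab, tgt_vocab)
    dual = ToySeq2Seq(tgt_vocab, src_vocab)
    # Optional warm start from CPU-loaded checkpoints, analogous to
    # checkpoint_utils.load_checkpoint_to_cpu + load_state_dict(strict=True).
    # The "model" key is assumed to match how the checkpoint was saved.
    if primal_ckpt is not None:
        state = torch.load(primal_ckpt, map_location="cpu")
        primal.load_state_dict(state["model"], strict=True)
    if dual_ckpt is not None:
        state = torch.load(dual_ckpt, map_location="cpu")
        dual.load_state_dict(state["model"], strict=True)
    return DualModel(primal, dual)


# Usage: model = build_dual_model(src_vocab=1000, tgt_vocab=1200)
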
Code Example #2
    @classmethod
    def build_model(cls, args, task):
        """ Build both the primal and dual models.
        For simplicity, both models share the same arch, i.e. the same model
        params would be used to initialize both models.
        Support for different models/archs would be added in further iterations.
        """
        base_architecture(args)

        if args.sequence_lstm:
            encoder_class = LSTMSequenceEncoder
        else:
            encoder_class = RNNEncoder
        decoder_class = RNNDecoder

        primal_encoder = encoder_class(
            task.primal_src_dict,
            embed_dim=args.encoder_embed_dim,
            freeze_embed=args.encoder_freeze_embed,
            cell_type=args.cell_type,
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
        )
        primal_decoder = decoder_class(
            src_dict=task.primal_src_dict,
            dst_dict=task.primal_tgt_dict,
            vocab_reduction_params=args.vocab_reduction_params,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            freeze_embed=args.decoder_freeze_embed,
            out_embed_dim=args.decoder_out_embed_dim,
            cell_type=args.cell_type,
            num_layers=args.decoder_layers,
            hidden_dim=args.decoder_hidden_dim,
            attention_type=args.attention_type,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            residual_level=args.residual_level,
            averaging_encoder=args.averaging_encoder,
        )
        primal_task = PytorchTranslateTask(args, task.primal_src_dict,
                                           task.primal_tgt_dict)
        primal_model = rnn.RNNModel(primal_task, primal_encoder,
                                    primal_decoder)

        dual_encoder = encoder_class(
            task.dual_src_dict,
            embed_dim=args.encoder_embed_dim,
            freeze_embed=args.encoder_freeze_embed,
            cell_type=args.cell_type,
            num_layers=args.encoder_layers,
            hidden_dim=args.encoder_hidden_dim,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            residual_level=args.residual_level,
            bidirectional=bool(args.encoder_bidirectional),
        )
        dual_decoder = decoder_class(
            src_dict=task.dual_src_dict,
            dst_dict=task.dual_tgt_dict,
            vocab_reduction_params=args.vocab_reduction_params,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            freeze_embed=args.decoder_freeze_embed,
            out_embed_dim=args.decoder_out_embed_dim,
            cell_type=args.cell_type,
            num_layers=args.decoder_layers,
            hidden_dim=args.decoder_hidden_dim,
            attention_type=args.attention_type,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            residual_level=args.residual_level,
            averaging_encoder=args.averaging_encoder,
        )
        dual_task = PytorchTranslateTask(args, task.dual_src_dict,
                                         task.dual_tgt_dict)
        dual_model = rnn.RNNModel(dual_task, dual_encoder, dual_decoder)
        # TODO (T36875783): instantiate a language model
        lm_model = None
        return RNNDualLearningModel(args, task, primal_model, dual_model,
                                    lm_model)
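
Compared with Code Example #1, this variant does not build shared embedding tables via RNNModel.build_embed_tokens and does not load pretrained forward/backward checkpoints; instead it passes freeze_embed flags to the encoder and decoder. The snippet below is a hedged sketch of what a freeze_embed-style option usually amounts to in PyTorch; build_embedding is a hypothetical helper, not the RNNEncoder/RNNDecoder internals.

# Sketch of embedding freezing, assuming standard PyTorch semantics only.
import torch.nn as nn


def build_embedding(num_embeddings, embed_dim, freeze_embed=False):
    embed = nn.Embedding(num_embeddings, embed_dim)
    if freeze_embed:
        # Weights with requires_grad=False receive no gradients, so the
        # embedding table stays fixed during training.
        embed.weight.requires_grad = False
    return embed


# Usage: frozen = build_embedding(1000, 32, freeze_embed=True)
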