def test_copy_gen_outp_has_no_prob_of_pad(self):
     for params, init_case in itertools.product(self.PARAMS,
                                                self.INIT_CASES):
         cgen = CopyGenerator(**init_case)
         dummy_in = self.dummy_inputs(params, init_case)
         res = cgen(*dummy_in)
         self.assertTrue(res[:, init_case["pad_idx"]].allclose(
             torch.tensor(0.0)))
 def test_copy_gen_forward_shape(self):
     for params, init_case in itertools.product(self.PARAMS,
                                                self.INIT_CASES):
         cgen = CopyGenerator(**init_case)
         dummy_in = self.dummy_inputs(params, init_case)
         res = cgen(*dummy_in)
         expected_shape = self.expected_shape(params, init_case)
         self.assertEqual(res.shape, expected_shape, init_case.__str__())
 def test_copy_gen_trainable_params_update(self):
     for params, init_case in itertools.product(
             self.PARAMS, self.INIT_CASES):
         cgen = CopyGenerator(**init_case)
         trainable_params = {n: p for n, p in cgen.named_parameters()
                             if p.requires_grad}
         assert len(trainable_params) > 0  # sanity check
         old_weights = deepcopy(trainable_params)
         dummy_in = self.dummy_inputs(params, init_case)
         res = cgen(*dummy_in)
         pretend_loss = res.sum()
         pretend_loss.backward()
         dummy_optim = torch.optim.SGD(trainable_params.values(), 1)
         dummy_optim.step()
         for param_name in old_weights.keys():
             self.assertTrue(
                 trainable_params[param_name]
                 .ne(old_weights[param_name]).any(),
                 param_name + " " + init_case.__str__())
# Example 5
def build_tm_model(opt, dicts):
    onmt.constants.neg_log_sigma1 = opt.neg_log_sigma1
    onmt.constants.neg_log_sigma2 = opt.neg_log_sigma2
    onmt.constants.prior_pi = opt.prior_pi
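    # assumption from the option names: these constants parameterize the
    # scale-mixture prior of the Bayesian (Bayes-by-Backprop) layers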

    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
    else:
        raise NotImplementedError

    # BUILD GENERATOR
    if opt.copy_generator:
        generators = [
            CopyGenerator(opt.model_size,
                          dicts['tgt'].size(),
                          fix_norm=opt.fix_norm_output_embedding)
        ]
    else:
        generators = [
            onmt.modules.base_seq2seq.Generator(
                opt.model_size,
                dicts['tgt'].size(),
                fix_norm=opt.fix_norm_output_embedding)
        ]

    # BUILD EMBEDDINGS
    if 'src' in dicts:
        embedding_src = nn.Embedding(dicts['src'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.PAD)
    else:
        embedding_src = None

    if opt.join_embedding and embedding_src is not None:
        embedding_tgt = embedding_src
        print("* Joining the weights of encoder and decoder word embeddings")
    else:
        embedding_tgt = nn.Embedding(dicts['tgt'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.PAD)

    if opt.use_language_embedding:
        print("* Create language embeddings with %d languages" %
              len(dicts['langs']))
        language_embeddings = nn.Embedding(len(dicts['langs']), opt.model_size)
    else:
        language_embeddings = None

    if opt.encoder_type == "text":
        encoder = RelativeTransformerEncoder(
            opt,
            embedding_src,
            None,
            opt.encoder_type,
            language_embeddings=language_embeddings)
    if opt.encoder_type == "audio":
        # raise NotImplementedError
        encoder = RelativeTransformerEncoder(
            opt,
            None,
            None,
            encoder_type=opt.encoder_type,
            language_embeddings=language_embeddings)

    generator = nn.ModuleList(generators)
    decoder = RelativeTransformerDecoder(
        opt, embedding_tgt, None, language_embeddings=language_embeddings)

    if opt.reconstruct:
        rev_decoder = RelativeTransformerDecoder(
            opt, embedding_src, None, language_embeddings=language_embeddings)
        rev_generator = [
            onmt.modules.base_seq2seq.Generator(
                opt.model_size,
                dicts['src'].size(),
                fix_norm=opt.fix_norm_output_embedding)
        ]
        rev_generator = nn.ModuleList(rev_generator)
    else:
        rev_decoder = None
        rev_generator = None

    model = BayesianTransformer(encoder,
                                decoder,
                                generator,
                                rev_decoder,
                                rev_generator,
                                mirror=opt.mirror_loss)

    if opt.tie_weights:
        print("* Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    return model
# Example 6
def build_tm_model(opt, dicts):
    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
    else:
        raise NotImplementedError

    if opt.reconstruct:
        # reconstruction is only compatible with model 'relative_transformer' and text input
        assert opt.model == 'relative_transformer'
        assert opt.encoder_type == 'text'

    # BUILD GENERATOR
    if opt.copy_generator:
        if opt.nce_noise > 0:
            print("[INFO] Copy generator overrides NCE.")
            opt.nce = False
            opt.nce_noise = 0
        generators = [
            CopyGenerator(opt.model_size,
                          dicts['tgt'].size(),
                          fix_norm=opt.fix_norm_output_embedding)
        ]
    elif opt.nce_noise > 0:
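        # noise-contrastive estimation: sample negatives from the unigram
        # distribution over target-word frequencies built below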
        from onmt.modules.nce.nce_linear import NCELinear
        from onmt.modules.nce.nce_utils import build_unigram_noise
        noise_distribution = build_unigram_noise(
            torch.FloatTensor(list(dicts['tgt'].frequencies.values())))

        generator = NCELinear(opt.model_size,
                              dicts['tgt'].size(),
                              fix_norm=opt.fix_norm_output_embedding,
                              noise_distribution=noise_distribution,
                              noise_ratio=opt.nce_noise)
        generators = [generator]
    else:
        generators = [
            onmt.modules.base_seq2seq.Generator(
                opt.model_size,
                dicts['tgt'].size(),
                fix_norm=opt.fix_norm_output_embedding)
        ]

    # BUILD EMBEDDINGS
    if 'src' in dicts:
        embedding_src = nn.Embedding(dicts['src'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.PAD)
    else:
        embedding_src = None

    if opt.join_embedding and embedding_src is not None:
        embedding_tgt = embedding_src
        print("* Joining the weights of encoder and decoder word embeddings")
    else:
        embedding_tgt = nn.Embedding(dicts['tgt'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.PAD)

    if opt.use_language_embedding:
        print("* Create language embeddings with %d languages" %
              len(dicts['langs']))
        language_embeddings = nn.Embedding(len(dicts['langs']), opt.model_size)
    else:
        language_embeddings = None

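    # CTC needs one extra output class for the blank symbol,
    # hence dicts['tgt'].size() + 1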
    if opt.ctc_loss != 0:
        generators.append(
            onmt.modules.base_seq2seq.Generator(opt.model_size,
                                                dicts['tgt'].size() + 1))

    if opt.model in ['conformer', 'speech_transformer', 'hybrid_transformer']:
        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_recognizer.relative_transformer import \
            SpeechTransformerEncoder, SpeechTransformerDecoder

        if opt.model == 'conformer':
            from onmt.models.speech_recognizer.conformer import ConformerEncoder, Conformer
            from onmt.models.speech_recognizer.lstm import SpeechLSTMDecoder
            opt.cnn_downsampling = True  # force CNN downsampling so decoder masking is computed correctly
            encoder = ConformerEncoder(opt, None, None, 'audio')

            decoder = SpeechLSTMDecoder(
                opt, embedding_tgt, language_embeddings=language_embeddings)

            model = Conformer(encoder,
                              decoder,
                              nn.ModuleList(generators),
                              ctc=opt.ctc_loss > 0.0)
        elif opt.model == 'hybrid_transformer':
            from onmt.models.speech_recognizer.lstm import SpeechLSTMDecoder, SpeechLSTMEncoder, SpeechLSTMSeq2Seq
            encoder = SpeechTransformerEncoder(opt, None, positional_encoder,
                                               opt.encoder_type)

            decoder = SpeechLSTMDecoder(
                opt, embedding_tgt, language_embeddings=language_embeddings)

            model = SpeechLSTMSeq2Seq(encoder,
                                      decoder,
                                      nn.ModuleList(generators),
                                      ctc=opt.ctc_loss > 0.0)
        else:
            encoder = SpeechTransformerEncoder(opt, None, positional_encoder,
                                               opt.encoder_type)

            decoder = SpeechTransformerDecoder(
                opt,
                embedding_tgt,
                positional_encoder,
                language_embeddings=language_embeddings)
            model = RelativeTransformer(encoder,
                                        decoder,
                                        nn.ModuleList(generators),
                                        None,
                                        None,
                                        mirror=opt.mirror_loss,
                                        ctc=opt.ctc_loss > 0.0)

        # If we use the multilingual model and weights are partitioned:
        if opt.multilingual_partitioned_weights:

            # this is basically the language embeddings
            factor_embeddings = nn.Embedding(len(dicts['langs']),
                                             opt.mpw_factor_size)

            encoder.factor_embeddings = factor_embeddings
            decoder.factor_embeddings = factor_embeddings

    elif opt.model in ["LSTM", 'lstm']:
        # print("LSTM")
        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_recognizer.lstm import SpeechLSTMDecoder, SpeechLSTMEncoder, SpeechLSTMSeq2Seq

        encoder = SpeechLSTMEncoder(opt, None, opt.encoder_type)

        decoder = SpeechLSTMDecoder(opt,
                                    embedding_tgt,
                                    language_embeddings=language_embeddings)

        model = SpeechLSTMSeq2Seq(encoder,
                                  decoder,
                                  nn.ModuleList(generators),
                                  ctc=opt.ctc_loss > 0.0)

    elif opt.model in ['multilingual_translator', 'translator']:
        onmt.constants.init_value = opt.param_init
        from onmt.models.multilingual_translator.relative_transformer import \
            RelativeTransformerEncoder, RelativeTransformerDecoder

        encoder = RelativeTransformerEncoder(
            opt,
            embedding_src,
            None,
            opt.encoder_type,
            language_embeddings=language_embeddings)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        model = RelativeTransformer(encoder,
                                    decoder,
                                    nn.ModuleList(generators),
                                    None,
                                    None,
                                    mirror=opt.mirror_loss)

    elif opt.model in ['transformer', 'stochastic_transformer']:
        onmt.constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = TransformerEncoder(opt, None, positional_encoder,
                                         opt.encoder_type)
        elif opt.encoder_type == "mix":
            text_encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                "text",
                language_embeddings=language_embeddings)
            audio_encoder = TransformerEncoder(opt, None, positional_encoder,
                                               "audio")
            encoder = MixedEncoder(text_encoder, audio_encoder)
        else:
            print("Unknown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = TransformerDecoder(opt,
                                     embedding_tgt,
                                     positional_encoder,
                                     language_embeddings=language_embeddings)

        model = Transformer(encoder,
                            decoder,
                            nn.ModuleList(generators),
                            mirror=opt.mirror_loss)

    elif opt.model == 'relative_transformer':
        from onmt.models.relative_transformer import \
            RelativeTransformerEncoder, RelativeTransformerDecoder

        if opt.encoder_type == "text":
            encoder = RelativeTransformerEncoder(
                opt,
                embedding_src,
                None,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        if opt.encoder_type == "audio":
            # raise NotImplementedError
            encoder = RelativeTransformerEncoder(
                opt,
                None,
                None,
                encoder_type=opt.encoder_type,
                language_embeddings=language_embeddings)

        generator = nn.ModuleList(generators)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        if opt.reconstruct:
            rev_decoder = RelativeTransformerDecoder(
                opt,
                embedding_src,
                None,
                language_embeddings=language_embeddings)
            rev_generator = [
                onmt.modules.base_seq2seq.Generator(
                    opt.model_size,
                    dicts['src'].size(),
                    fix_norm=opt.fix_norm_output_embedding)
            ]
            rev_generator = nn.ModuleList(rev_generator)
        else:
            rev_decoder = None
            rev_generator = None

        model = RelativeTransformer(encoder,
                                    decoder,
                                    generator,
                                    rev_decoder,
                                    rev_generator,
                                    mirror=opt.mirror_loss)

    elif opt.model == 'distance_transformer':

        # from onmt.models.relative_transformer import RelativeTransformerDecoder, RelativeTransformer
        from onmt.models.distance_transformer import DistanceTransformerEncoder

        if opt.encoder_type == "text":
            encoder = DistanceTransformerEncoder(
                opt,
                embedding_src,
                None,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        if opt.encoder_type == "audio":
            # raise NotImplementedError
            encoder = DistanceTransformerEncoder(
                opt,
                None,
                None,
                encoder_type=opt.encoder_type,
                language_embeddings=language_embeddings)

        generator = nn.ModuleList(generators)
        decoder = DistanceTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)
        model = Transformer(encoder,
                            decoder,
                            generator,
                            mirror=opt.mirror_loss)

    elif opt.model == 'universal_transformer':
        from onmt.models.universal_transformer import UniversalTransformerDecoder, UniversalTransformerEncoder

        generator = nn.ModuleList(generators)

        if opt.encoder_type == "text":
            encoder = UniversalTransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = UniversalTransformerEncoder(opt, None,
                                                  positional_encoder,
                                                  opt.encoder_type)

        decoder = UniversalTransformerDecoder(
            opt,
            embedding_tgt,
            positional_encoder,
            language_embeddings=language_embeddings)

        model = Transformer(encoder,
                            decoder,
                            generator,
                            mirror=opt.mirror_loss)

    elif opt.model == 'relative_universal_transformer':
        from onmt.models.relative_universal_transformer import \
            RelativeUniversalTransformerEncoder, RelativeUniversalTransformerDecoder
        generator = nn.ModuleList(generators)

        if opt.encoder_type == "text":
            encoder = RelativeUniversalTransformerEncoder(
                opt,
                embedding_src,
                None,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = RelativeUniversalTransformerEncoder(
                opt, None, None, opt.encoder_type)

        decoder = RelativeUniversalTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        model = RelativeTransformer(encoder,
                                    decoder,
                                    generator,
                                    mirror=opt.mirror_loss)

    elif opt.model == 'relative_unified_transformer':
        from onmt.models.relative_unified_transformer import RelativeUnifiedTransformer

        if opt.encoder_type == "audio":
            raise NotImplementedError

        generator = nn.ModuleList(generators)
        model = RelativeUnifiedTransformer(
            opt,
            embedding_src,
            embedding_tgt,
            generator,
            positional_encoder,
            language_embeddings=language_embeddings)

    elif opt.model == 'memory_transformer':
        from onmt.models.memory_transformer import MemoryTransformer

        if opt.encoder_type == "audio":
            raise NotImplementedError

        generator = nn.ModuleList(generators)
        model = MemoryTransformer(opt,
                                  embedding_src,
                                  embedding_tgt,
                                  generator,
                                  positional_encoder,
                                  language_embeddings=language_embeddings,
                                  dictionary=dicts['tgt'])

    else:
        raise NotImplementedError

    if opt.tie_weights:
        print("* Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    return model
# Example 7
def build_tm_model(opt, dicts):
    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
    else:
        raise NotImplementedError

    # BUILD GENERATOR
    if opt.copy_generator:
        generators = [
            CopyGenerator(opt.model_size,
                          dicts['tgt'].size(),
                          fix_norm=opt.fix_norm_output_embedding)
        ]
    else:
        generators = [
            onmt.modules.base_seq2seq.Generator(
                opt.model_size,
                dicts['tgt'].size(),
                fix_norm=opt.fix_norm_output_embedding)
        ]

    # BUILD EMBEDDINGS
    if 'src' in dicts:
        embedding_src = nn.Embedding(dicts['src'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.PAD)
    else:
        embedding_src = None

    if opt.join_embedding and embedding_src is not None:
        embedding_tgt = embedding_src
        print("* Joining the weights of encoder and decoder word embeddings")
    else:
        embedding_tgt = nn.Embedding(dicts['tgt'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.PAD)

    if opt.use_language_embedding:
        print("* Create language embeddings with %d languages" %
              len(dicts['langs']))
        language_embeddings = nn.Embedding(len(dicts['langs']), opt.model_size)
    else:
        language_embeddings = None

    if opt.ctc_loss != 0:
        generators.append(
            onmt.modules.base_seq2seq.Generator(opt.model_size,
                                                dicts['tgt'].size() + 1))

    if opt.model in ['transformer', 'stochastic_transformer']:
        onmt.constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = TransformerEncoder(opt, None, positional_encoder,
                                         opt.encoder_type)
        elif opt.encoder_type == "mix":
            text_encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                "text",
                language_embeddings=language_embeddings)
            audio_encoder = TransformerEncoder(opt, None, positional_encoder,
                                               "audio")
            encoder = MixedEncoder(text_encoder, audio_encoder)
        else:
            print("Unknown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = TransformerDecoder(opt,
                                     embedding_tgt,
                                     positional_encoder,
                                     language_embeddings=language_embeddings)

        model = Transformer(encoder,
                            decoder,
                            nn.ModuleList(generators),
                            mirror=opt.mirror_loss)

    elif opt.model == 'relative_transformer':

        from onmt.models.relative_transformer import RelativeTransformerEncoder, RelativeTransformerDecoder, \
            RelativeTransformer

        if opt.encoder_type == "text":
            encoder = RelativeTransformerEncoder(
                opt,
                embedding_src,
                None,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        if opt.encoder_type == "audio":
            # raise NotImplementedError
            encoder = RelativeTransformerEncoder(
                opt,
                None,
                None,
                encoder_type=opt.encoder_type,
                language_embeddings=language_embeddings)

        generator = nn.ModuleList(generators)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)
        model = RelativeTransformer(encoder,
                                    decoder,
                                    generator,
                                    mirror=opt.mirror_loss)

    elif opt.model == 'distance_transformer':

        from onmt.models.relative_transformer import RelativeTransformerDecoder, RelativeTransformer
        from onmt.models.distance_transformer import DistanceTransformerEncoder

        if opt.encoder_type == "text":
            encoder = DistanceTransformerEncoder(
                opt,
                embedding_src,
                None,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        if opt.encoder_type == "audio":
            # raise NotImplementedError
            encoder = DistanceTransformerEncoder(
                opt,
                None,
                None,
                encoder_type=opt.encoder_type,
                language_embeddings=language_embeddings)

        generator = nn.ModuleList(generators)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)
        model = RelativeTransformer(encoder,
                                    decoder,
                                    generator,
                                    mirror=opt.mirror_loss)

    elif opt.model == 'unified_transformer':
        from onmt.models.unified_transformer import UnifiedTransformer

        if opt.encoder_type == "audio":
            raise NotImplementedError

        generator = nn.ModuleList(generators)
        model = UnifiedTransformer(opt,
                                   embedding_src,
                                   embedding_tgt,
                                   generator,
                                   positional_encoder,
                                   language_embeddings=language_embeddings)

    elif opt.model == 'relative_unified_transformer':
        from onmt.models.relative_unified_transformer import RelativeUnifiedTransformer

        if opt.encoder_type == "audio":
            raise NotImplementedError

        generator = nn.ModuleList(generators)
        model = RelativeUnifiedTransformer(
            opt,
            embedding_src,
            embedding_tgt,
            generator,
            positional_encoder,
            language_embeddings=language_embeddings)

    elif opt.model == 'memory_transformer':
        from onmt.models.memory_transformer import MemoryTransformer

        if opt.encoder_type == "audio":
            raise NotImplementedError

        generator = nn.ModuleList(generators)
        model = MemoryTransformer(opt,
                                  embedding_src,
                                  embedding_tgt,
                                  generator,
                                  positional_encoder,
                                  language_embeddings=language_embeddings,
                                  dictionary=dicts['tgt'])

    else:
        raise NotImplementedError

    if opt.tie_weights:
        print("* Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    return model
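# Hypothetical usage sketch for the variant above. `opt` mirrors only the
# option names read by build_tm_model itself; the encoder/decoder
# constructors read further options (layers, heads, dropout, ...) that are
# elided here. `dicts` stands in for the repo's vocabulary objects, of which
# only .size() is used. All names and values are illustrative.
from types import SimpleNamespace


class _StubDict:
    def __init__(self, n):
        self._n = n

    def size(self):
        return self._n


opt = SimpleNamespace(
    time='positional_encoding', model_size=512, param_init=0.1,
    copy_generator=False, fix_norm_output_embedding=False,
    join_embedding=False, use_language_embedding=False, ctc_loss=0,
    model='transformer', encoder_type='text', mirror_loss=False,
    tie_weights=False)  # plus the model-specific options
dicts = {'src': _StubDict(32000), 'tgt': _StubDict(32000)}
model = build_tm_model(opt, dicts)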
# Example 8
def build_tm_model(opt, dicts):
    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
    else:
        raise NotImplementedError

    # temporary fix for speech autoencoder

    if opt.reconstruct:
        # reconstruction is only compatible with model 'relative_transformer' and text input
        assert opt.model == 'relative_transformer'
        assert opt.encoder_type == 'text'

    # BUILD GENERATOR
    if opt.copy_generator:
        if opt.nce_noise > 0:
            print("[INFO] Copy generator overrides NCE.")
            opt.nce = False
            opt.nce_noise = 0
        generators = [
            CopyGenerator(opt.model_size,
                          dicts['tgt'].size(),
                          fix_norm=opt.fix_norm_output_embedding)
        ]
    elif opt.nce_noise > 0:
        from onmt.modules.nce.nce_linear import NCELinear
        from onmt.modules.nce.nce_utils import build_unigram_noise
        noise_distribution = build_unigram_noise(
            torch.FloatTensor(list(dicts['tgt'].frequencies.values())))

        generator = NCELinear(opt.model_size,
                              dicts['tgt'].size(),
                              fix_norm=opt.fix_norm_output_embedding,
                              noise_distribution=noise_distribution,
                              noise_ratio=opt.nce_noise)
        generators = [generator]
    else:
        if "tgt" in dicts:
            generators = [
                onmt.modules.base_seq2seq.Generator(
                    opt.model_size,
                    dicts['tgt'].size(),
                    fix_norm=opt.fix_norm_output_embedding)
            ]
        else:
            # no target dictionary (e.g. speech-to-speech below): no generator
            generators = []

    # BUILD EMBEDDINGS
    if 'src' in dicts:
        embedding_src = nn.Embedding(dicts['src'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.PAD)
    else:
        embedding_src = None

    if opt.join_embedding and embedding_src is not None:
        embedding_tgt = embedding_src
        print("* Joining the weights of encoder and decoder word embeddings")
    elif 'tgt' in dicts:

        embedding_tgt = nn.Embedding(dicts['tgt'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.PAD)
    else:
        embedding_tgt = None

    if opt.use_language_embedding:
        print("* Create language embeddings with %d languages" %
              len(dicts['langs']))
        language_embeddings = nn.Embedding(len(dicts['langs']), opt.model_size)
    else:
        language_embeddings = None

    if opt.ctc_loss != 0:
        generators.append(
            onmt.modules.base_seq2seq.Generator(opt.model_size,
                                                dicts['tgt'].size() + 1))

    if opt.model in ['conformer', 'speech_transformer']:
        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_recognizer.relative_transformer import \
            SpeechTransformerEncoder, SpeechTransformerDecoder

        if opt.model == 'conformer':
            from onmt.models.speech_recognizer.conformer import ConformerEncoder
            opt.cnn_downsampling = True  # force CNN downsampling so decoder masking is computed correctly
            encoder = ConformerEncoder(opt, None, None, 'audio')
        else:
            encoder = SpeechTransformerEncoder(opt, None, positional_encoder,
                                               opt.encoder_type)

        decoder = SpeechTransformerDecoder(
            opt,
            embedding_tgt,
            positional_encoder,
            language_embeddings=language_embeddings)
        model = RelativeTransformer(encoder,
                                    decoder,
                                    nn.ModuleList(generators),
                                    None,
                                    None,
                                    mirror=opt.mirror_loss)

    elif opt.model == "LSTM_v2":
        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_recognizer.lstm_v2 import SpeechLSTMDecoder, SpeechLSTMEncoder, SpeechLSTMSeq2Seq

        encoder = SpeechLSTMEncoder(opt, None, opt.encoder_type)

        decoder = SpeechLSTMDecoder(opt,
                                    embedding_tgt,
                                    language_embeddings=language_embeddings)

        model = SpeechLSTMSeq2Seq(encoder, decoder, nn.ModuleList(generators))

    elif opt.model == "LSTM":

        onmt.constants.init_value = opt.param_init
        from onmt.models.LSTM_based import SpeechLSTMDecoder, SpeechLSTMEncoder, SpeechLSTMSeq2Seq

        encoder = SpeechLSTMEncoder(opt, None, opt.encoder_type)

        decoder = SpeechLSTMDecoder(opt,
                                    embedding_tgt,
                                    language_embeddings=language_embeddings)

        model = SpeechLSTMSeq2Seq(encoder, decoder, nn.ModuleList(generators))
    elif opt.model in ['multilingual_translator', 'translator']:
        onmt.constants.init_value = opt.param_init
        from onmt.models.multilingual_translator.relative_transformer import \
            RelativeTransformerEncoder, RelativeTransformerDecoder

        encoder = RelativeTransformerEncoder(
            opt,
            embedding_src,
            None,
            opt.encoder_type,
            language_embeddings=language_embeddings)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        model = RelativeTransformer(encoder,
                                    decoder,
                                    nn.ModuleList(generators),
                                    None,
                                    None,
                                    mirror=opt.mirror_loss)
    elif opt.model == "speech_ae" or opt.model == "speech2speech":
        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_metamorphosis.speechae import SpeechLSTMEncoder
        from onmt.models.speech_metamorphosis.speechae import TacotronDecoder, SpeechAE

        encoder = SpeechLSTMEncoder(opt, None, opt.encoder_type)

        decoder = TacotronDecoder(opt)

        model = SpeechAE(encoder, decoder, opt)

        print("Create speech autoencoder successfully")

    elif opt.model == "speech_FN":

        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_metamorphosis.speechae import SpeechLSTMEncoder, LatentDiscrinator
        from onmt.models.speech_metamorphosis.speechae import TacotronDecoder, SpeechAE

        encoder = SpeechLSTMEncoder(opt, None, opt.encoder_type)

        decoder = TacotronDecoder(opt, accent_emdedding=language_embeddings)

        model_ae = SpeechAE(encoder, decoder, opt)

        lat_dis = LatentDiscrinator(opt,
                                    hidden_size=128,
                                    output_size=len(dicts['langs']))

        model = (model_ae, lat_dis)

        print("Create speech fader network successfully")

    elif opt.model in ['transformer', 'stochastic_transformer']:
        onmt.constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = TransformerEncoder(opt, None, positional_encoder,
                                         opt.encoder_type)
        elif opt.encoder_type == "mix":
            text_encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                "text",
                language_embeddings=language_embeddings)
            audio_encoder = TransformerEncoder(opt, None, positional_encoder,
                                               "audio")
            encoder = MixedEncoder(text_encoder, audio_encoder)
        else:
            print("Unknown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = TransformerDecoder(opt,
                                     embedding_tgt,
                                     positional_encoder,
                                     language_embeddings=language_embeddings)

        model = Transformer(encoder,
                            decoder,
                            nn.ModuleList(generators),
                            mirror=opt.mirror_loss)

    elif opt.model == 'relative_transformer':
        from onmt.models.relative_transformer import \
            RelativeTransformerEncoder, RelativeTransformerDecoder

        if opt.encoder_type == "text":
            encoder = RelativeTransformerEncoder(
                opt,
                embedding_src,
                None,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        if opt.encoder_type == "audio":
            # raise NotImplementedError
            encoder = RelativeTransformerEncoder(
                opt,
                None,
                None,
                encoder_type=opt.encoder_type,
                language_embeddings=language_embeddings)

        generator = nn.ModuleList(generators)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        if opt.reconstruct:
            rev_decoder = RelativeTransformerDecoder(
                opt,
                embedding_src,
                None,
                language_embeddings=language_embeddings)
            rev_generator = [
                onmt.modules.base_seq2seq.Generator(
                    opt.model_size,
                    dicts['src'].size(),
                    fix_norm=opt.fix_norm_output_embedding)
            ]
            rev_generator = nn.ModuleList(rev_generator)
        else:
            rev_decoder = None
            rev_generator = None

        model = RelativeTransformer(encoder,
                                    decoder,
                                    generator,
                                    rev_decoder,
                                    rev_generator,
                                    mirror=opt.mirror_loss)

    elif opt.model == 'universal_transformer':
        from onmt.models.universal_transformer import UniversalTransformerDecoder, UniversalTransformerEncoder

        generator = nn.ModuleList(generators)

        if opt.encoder_type == "text":
            encoder = UniversalTransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = UniversalTransformerEncoder(opt, None,
                                                  positional_encoder,
                                                  opt.encoder_type)

        decoder = UniversalTransformerDecoder(
            opt,
            embedding_tgt,
            positional_encoder,
            language_embeddings=language_embeddings)

        model = Transformer(encoder,
                            decoder,
                            generator,
                            mirror=opt.mirror_loss)

    elif opt.model == 'relative_universal_transformer':
        from onmt.models.relative_universal_transformer import \
            RelativeUniversalTransformerEncoder, RelativeUniversalTransformerDecoder
        generator = nn.ModuleList(generators)

        if opt.encoder_type == "text":
            encoder = RelativeUniversalTransformerEncoder(
                opt,
                embedding_src,
                None,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = RelativeUniversalTransformerEncoder(
                opt, None, None, opt.encoder_type)

        decoder = RelativeUniversalTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        model = RelativeTransformer(encoder,
                                    decoder,
                                    generator,
                                    mirror=opt.mirror_loss)

    elif opt.model == 'relative_unified_transformer':
        from onmt.models.relative_unified_transformer import RelativeUnifiedTransformer

        if opt.encoder_type == "audio":
            raise NotImplementedError

        generator = nn.ModuleList(generators)
        model = RelativeUnifiedTransformer(
            opt,
            embedding_src,
            embedding_tgt,
            generator,
            positional_encoder,
            language_embeddings=language_embeddings)

    elif opt.model == 'memory_transformer':
        from onmt.models.memory_transformer import MemoryTransformer

        if opt.encoder_type == "audio":
            raise NotImplementedError

        generator = nn.ModuleList(generators)
        model = MemoryTransformer(opt,
                                  embedding_src,
                                  embedding_tgt,
                                  generator,
                                  positional_encoder,
                                  language_embeddings=language_embeddings,
                                  dictionary=dicts['tgt'])

    else:
        raise NotImplementedError

    if opt.tie_weights:
        print("* Joining the weights of decoder input and output embeddings")
        model.tie_weights()
    if opt.tie_weights_lid:
        print("* Joining the weights of output lid  and language embeddings")
        model.tie_weights_lid()
    return model
# Example 9
def build_tm_model(opt, dicts):
    onmt.constants = add_tokenidx(opt, onmt.constants, dicts)
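    # assumption from the name and the SRC_PAD/TGT_PAD uses below:
    # add_tokenidx registers the special-token indices in onmt.constants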

    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
    else:
        raise NotImplementedError

    if opt.reconstruct:
        # reconstruction is only compatible with model 'relative_transformer' and text input
        assert opt.model == 'relative_transformer'
        assert opt.encoder_type == 'text'

    # BUILD GENERATOR
    if opt.copy_generator:
        if opt.nce_noise > 0:
            print("[INFO] Copy generator overrides NCE.")
            opt.nce = False
            opt.nce_noise = 0
        generators = [
            CopyGenerator(opt.model_size,
                          dicts['tgt'].size(),
                          fix_norm=opt.fix_norm_output_embedding)
        ]
    elif opt.nce_noise > 0:
        from onmt.modules.nce.nce_linear import NCELinear
        from onmt.modules.nce.nce_utils import build_unigram_noise
        noise_distribution = build_unigram_noise(
            torch.FloatTensor(list(dicts['tgt'].frequencies.values())))

        generator = NCELinear(opt.model_size,
                              dicts['tgt'].size(),
                              fix_norm=opt.fix_norm_output_embedding,
                              noise_distribution=noise_distribution,
                              noise_ratio=opt.nce_noise)
        generators = [generator]
    else:
        generators = [
            onmt.modules.base_seq2seq.Generator(
                opt.model_size,
                dicts['tgt'].size(),
                fix_norm=opt.fix_norm_output_embedding)
        ]

    # BUILD EMBEDDINGS
    if 'src' in dicts and not getattr(opt, "enc_pretrained_model", None):
        embedding_src = nn.Embedding(dicts['src'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.SRC_PAD)
    else:
        # either no source dictionary, or a pretrained encoder that
        # brings its own embeddings
        embedding_src = None

    if opt.join_embedding and embedding_src is not None:
        embedding_tgt = embedding_src
        print("* Joining the weights of encoder and decoder word embeddings")
    elif not opt.dec_pretrained_model:
        embedding_tgt = nn.Embedding(dicts['tgt'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.constants.TGT_PAD)
    else:
        assert opt.model == "pretrain_transformer"
        embedding_tgt = None

    if opt.use_language_embedding:
        print("* Create language embeddings with %d languages" %
              len(dicts['langs']))
        language_embeddings = nn.Embedding(len(dicts['langs']), opt.model_size)
    else:
        language_embeddings = None

    if opt.ctc_loss != 0:
        generators.append(
            onmt.modules.base_seq2seq.Generator(opt.model_size,
                                                dicts['tgt'].size() + 1))

    if opt.model in ['conformer', 'speech_transformer', 'hybrid_transformer']:
        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_recognizer.relative_transformer import \
            SpeechTransformerEncoder, SpeechTransformerDecoder

        if opt.model == 'conformer':
            from onmt.models.speech_recognizer.conformer import ConformerEncoder, Conformer
            from onmt.models.speech_recognizer.lstm import SpeechLSTMDecoder
            opt.cnn_downsampling = True  # force CNN downsampling so decoder masking is computed correctly
            encoder = ConformerEncoder(opt, None, None, 'audio')

            # decoder = SpeechLSTMDecoder(opt, embedding_tgt, language_embeddings=language_embeddings)
            decoder = SpeechTransformerDecoder(
                opt,
                embedding_tgt,
                positional_encoder,
                language_embeddings=language_embeddings)

            # model = Conformer(encoder, decoder, nn.ModuleList(generators), ctc=opt.ctc_loss > 0.0)
            model = RelativeTransformer(encoder,
                                        decoder,
                                        nn.ModuleList(generators),
                                        None,
                                        None,
                                        mirror=opt.mirror_loss,
                                        ctc=opt.ctc_loss > 0.0)
        elif opt.model == 'hybrid_transformer':
            from onmt.models.speech_recognizer.lstm import SpeechLSTMDecoder, SpeechLSTMEncoder, SpeechLSTMSeq2Seq
            encoder = SpeechTransformerEncoder(opt, None, positional_encoder,
                                               opt.encoder_type)

            decoder = SpeechLSTMDecoder(
                opt, embedding_tgt, language_embeddings=language_embeddings)

            model = SpeechLSTMSeq2Seq(encoder,
                                      decoder,
                                      nn.ModuleList(generators),
                                      ctc=opt.ctc_loss > 0.0)
        else:
            encoder = SpeechTransformerEncoder(opt, None, positional_encoder,
                                               opt.encoder_type)

            decoder = SpeechTransformerDecoder(
                opt,
                embedding_tgt,
                positional_encoder,
                language_embeddings=language_embeddings)
            model = RelativeTransformer(encoder,
                                        decoder,
                                        nn.ModuleList(generators),
                                        None,
                                        None,
                                        mirror=opt.mirror_loss,
                                        ctc=opt.ctc_loss > 0.0)

        # If we use the multilingual model and weights are partitioned:
        if opt.multilingual_partitioned_weights:
            # this is basically the language embeddings
            factor_embeddings = nn.Embedding(len(dicts['langs']),
                                             opt.mpw_factor_size)

            encoder.factor_embeddings = factor_embeddings
            decoder.factor_embeddings = factor_embeddings

    elif opt.model in ["LSTM", 'lstm']:
        # print("LSTM")
        onmt.constants.init_value = opt.param_init
        from onmt.models.speech_recognizer.lstm import SpeechLSTMDecoder, SpeechLSTMEncoder, SpeechLSTMSeq2Seq

        encoder = SpeechLSTMEncoder(opt, None, opt.encoder_type)

        decoder = SpeechLSTMDecoder(opt,
                                    embedding_tgt,
                                    language_embeddings=language_embeddings)

        model = SpeechLSTMSeq2Seq(encoder,
                                  decoder,
                                  nn.ModuleList(generators),
                                  ctc=opt.ctc_loss > 0.0)

    elif opt.model in ['multilingual_translator', 'translator']:
        onmt.constants.init_value = opt.param_init
        from onmt.models.multilingual_translator.relative_transformer import \
            RelativeTransformerEncoder, RelativeTransformerDecoder

        encoder = RelativeTransformerEncoder(
            opt,
            embedding_src,
            None,
            opt.encoder_type,
            language_embeddings=language_embeddings)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        model = RelativeTransformer(encoder,
                                    decoder,
                                    nn.ModuleList(generators),
                                    None,
                                    None,
                                    mirror=opt.mirror_loss)

    elif opt.model in ['transformer', 'stochastic_transformer']:
        onmt.constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = TransformerEncoder(opt, None, positional_encoder,
                                         opt.encoder_type)
        elif opt.encoder_type == "mix":
            text_encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                "text",
                language_embeddings=language_embeddings)
            audio_encoder = TransformerEncoder(opt, None, positional_encoder,
                                               "audio")
            encoder = MixedEncoder(text_encoder, audio_encoder)
        else:
            print("Unknown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = TransformerDecoder(opt,
                                     embedding_tgt,
                                     positional_encoder,
                                     language_embeddings=language_embeddings)

        model = Transformer(encoder,
                            decoder,
                            nn.ModuleList(generators),
                            mirror=opt.mirror_loss)

    elif opt.model == 'relative_transformer':
        from onmt.models.relative_transformer import \
            RelativeTransformerEncoder, RelativeTransformerDecoder

        if opt.encoder_type == "text":
            encoder = RelativeTransformerEncoder(
                opt,
                embedding_src,
                None,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        if opt.encoder_type == "audio":
            # raise NotImplementedError
            encoder = RelativeTransformerEncoder(
                opt,
                None,
                None,
                encoder_type=opt.encoder_type,
                language_embeddings=language_embeddings)

        generator = nn.ModuleList(generators)
        decoder = RelativeTransformerDecoder(
            opt, embedding_tgt, None, language_embeddings=language_embeddings)

        if opt.reconstruct:
            rev_decoder = RelativeTransformerDecoder(
                opt,
                embedding_src,
                None,
                language_embeddings=language_embeddings)
            rev_generator = [
                onmt.modules.base_seq2seq.Generator(
                    opt.model_size,
                    dicts['src'].size(),
                    fix_norm=opt.fix_norm_output_embedding)
            ]
            rev_generator = nn.ModuleList(rev_generator)
        else:
            rev_decoder = None
            rev_generator = None

        model = RelativeTransformer(encoder,
                                    decoder,
                                    generator,
                                    rev_decoder,
                                    rev_generator,
                                    mirror=opt.mirror_loss)

    elif opt.model == 'universal_transformer':
        from onmt.legacy.old_models.universal_transformer import UniversalTransformerDecoder, UniversalTransformerEncoder

        generator = nn.ModuleList(generators)

        if opt.encoder_type == "text":
            encoder = UniversalTransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        elif opt.encoder_type == "audio":
            encoder = UniversalTransformerEncoder(opt, None,
                                                  positional_encoder,
                                                  opt.encoder_type)

        decoder = UniversalTransformerDecoder(
            opt,
            embedding_tgt,
            positional_encoder,
            language_embeddings=language_embeddings)

        model = Transformer(encoder,
                            decoder,
                            generator,
                            mirror=opt.mirror_loss)

    elif opt.model == 'pretrain_transformer':
        assert (opt.enc_pretrained_model or opt.dec_pretrained_model)
        from onmt.models.pretrain_transformer import PretrainTransformer
        print(f"pos_emb_type: {opt.pos_emb_type}")
        print(f"max_pos_length: {opt.max_pos_length }")
        print(
            f"Share position embeddings cross heads: {not opt.diff_head_pos}")
        print()
        if opt.enc_pretrained_model:
            print("* Build encoder with enc_pretrained_model: {}".format(
                opt.enc_pretrained_model))
        if opt.enc_pretrained_model == "bert":
            from pretrain_module.configuration_bert import BertConfig
            from pretrain_module.modeling_bert import BertModel

            enc_bert_config = BertConfig.from_json_file(opt.enc_config_file)
            encoder = BertModel(
                enc_bert_config,
                bert_word_dropout=opt.enc_pretrain_word_dropout,
                bert_emb_dropout=opt.enc_pretrain_emb_dropout,
                bert_atten_dropout=opt.enc_pretrain_attn_dropout,
                bert_hidden_dropout=opt.enc_pretrain_hidden_dropout,
                bert_hidden_size=opt.enc_pretrain_hidden_size,
                is_decoder=False,
                before_plm_output_ln=opt.before_enc_output_ln,
                gradient_checkpointing=opt.enc_gradient_checkpointing,
                max_pos_len=opt.max_pos_length,
                diff_head_pos=opt.diff_head_pos,
                pos_emb_type=opt.pos_emb_type,
            )

        elif opt.enc_pretrained_model == "roberta":
            from pretrain_module.configuration_roberta import RobertaConfig
            from pretrain_module.modeling_roberta import RobertaModel
            enc_roberta_config = RobertaConfig.from_json_file(
                opt.enc_config_file)

            encoder = RobertaModel(
                enc_roberta_config,
                bert_word_dropout=opt.enc_pretrain_word_dropout,
                bert_emb_dropout=opt.enc_pretrain_emb_dropout,
                bert_atten_dropout=opt.enc_pretrain_attn_dropout,
                bert_hidden_dropout=opt.enc_pretrain_hidden_dropout,
                bert_hidden_size=opt.enc_pretrain_hidden_size,
                is_decoder=False,
                before_plm_output_ln=opt.before_enc_output_ln,
                gradient_checkpointing=opt.enc_gradient_checkpointing,
                max_pos_len=opt.max_pos_length,
                diff_head_pos=opt.diff_head_pos,
                pos_emb_type=opt.pos_emb_type,
            )
        elif not opt.enc_pretrained_model:
            print(" Encoder is not from pretrained model")
            encoder = TransformerEncoder(
                opt,
                embedding_src,
                positional_encoder,
                opt.encoder_type,
                language_embeddings=language_embeddings)
        else:
            print("Warning: only bert and roberta are implemented for encoder")
            exit(-1)

        if opt.load_from or not opt.enc_state_dict:
            if opt.verbose:
                print("  No weights loading from {} for encoder".format(
                    opt.enc_pretrained_model))
        elif opt.enc_pretrained_model:
            print("  Loading weights for encoder from: \n", opt.enc_state_dict)

            enc_model_state_dict = torch.load(opt.enc_state_dict,
                                              map_location="cpu")

            encoder.from_pretrained(state_dict=enc_model_state_dict,
                                    model=encoder,
                                    output_loading_info=opt.verbose,
                                    model_prefix=opt.enc_pretrained_model)

        if opt.dec_pretrained_model:
            print("* Build decoder with dec_pretrained_model: {}".format(
                opt.dec_pretrained_model))

        if opt.dec_pretrained_model == "bert":
            if opt.enc_pretrained_model != "bert":
                from pretrain_module.configuration_bert import BertConfig
                from pretrain_module.modeling_bert import BertModel
            dec_bert_config = BertConfig.from_json_file(opt.dec_config_file)
            decoder = BertModel(
                dec_bert_config,
                bert_word_dropout=opt.dec_pretrain_word_dropout,
                bert_emb_dropout=opt.dec_pretrain_emb_dropout,
                bert_atten_dropout=opt.dec_pretrain_attn_dropout,
                bert_hidden_dropout=opt.dec_pretrain_hidden_dropout,
                bert_hidden_size=opt.dec_pretrain_hidden_size,
                is_decoder=True,
                gradient_checkpointing=opt.dec_gradient_checkpointing,
                max_pos_len=opt.max_pos_length,
                diff_head_pos=opt.diff_head_pos,
                pos_emb_type=opt.pos_emb_type,
            )

        elif opt.dec_pretrained_model == "roberta":
            if opt.enc_pretrained_model != "roberta":
                from pretrain_module.configuration_roberta import RobertaConfig
                from pretrain_module.modeling_roberta import RobertaModel

            dec_roberta_config = RobertaConfig.from_json_file(
                opt.dec_config_file)

            decoder = RobertaModel(
                dec_roberta_config,
                bert_word_dropout=opt.dec_pretrain_word_dropout,
                bert_emb_dropout=opt.dec_pretrain_emb_dropout,
                bert_atten_dropout=opt.dec_pretrain_attn_dropout,
                bert_hidden_dropout=opt.dec_pretrain_hidden_dropout,
                bert_hidden_size=opt.dec_pretrain_hidden_size,
                is_decoder=True,
                gradient_checkpointing=opt.dec_gradient_checkpointing,
                max_pos_len=opt.max_pos_length,
                diff_head_pos=opt.diff_head_pos,
                pos_emb_type=opt.pos_emb_type,
            )

        elif not opt.dec_pretrained_model:
            print(" Decoder is not from pretrained model")
            decoder = TransformerDecoder(
                opt,
                embedding_tgt,
                positional_encoder,
                language_embeddings=language_embeddings)
        else:
            print("Warning: only bert and roberta are implemented for decoder")
            exit(-1)

        if opt.load_from or not opt.dec_state_dict:
            if opt.verbose:
                print("  No weights loading from {} for decoder".format(
                    opt.dec_pretrained_model))
        elif opt.dec_pretrained_model:
            print("  Loading weights for decoder from: \n", opt.dec_state_dict)
            dec_model_state_dict = torch.load(opt.dec_state_dict,
                                              map_location="cpu")

            decoder.from_pretrained(state_dict=dec_model_state_dict,
                                    model=decoder,
                                    output_loading_info=opt.verbose,
                                    model_prefix=opt.dec_pretrained_model)

        encoder.enc_pretrained_model = opt.enc_pretrained_model
        decoder.dec_pretrained_model = opt.dec_pretrained_model

        encoder.input_type = opt.encoder_type

        model = PretrainTransformer(encoder, decoder,
                                    nn.ModuleList(generators))
    else:
        raise NotImplementedError

    if opt.tie_weights:
        print("* Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    return model
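# Note: tie_weights() in ONMT-style code conventionally shares the decoder
# input embedding with the generator's output projection, roughly
# (a sketch of the convention, not necessarily this repo's exact code):
#
#     self.generator[0].linear.weight = self.decoder.word_lut.weight
#
# which is why tied setups require embedding and model dimensions to match.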