Example no. 1
def build_language_model(opt, dicts):

    onmt.Constants.layer_norm = opt.layer_norm
    onmt.Constants.weight_norm = opt.weight_norm
    onmt.Constants.activation_layer = opt.activation_layer
    onmt.Constants.version = 1.0
    onmt.Constants.attention_out = opt.attention_out
    onmt.Constants.residual_type = opt.residual_type

    from onmt.modules.TransformerLM.Models import TransformerLM, TransformerLMDecoder

    positional_encoder = PositionalEncoding(opt.model_size, len_max=MAX_LEN)

    decoder = TransformerLMDecoder(opt, dicts['tgt'], positional_encoder)

    generators = [
        onmt.modules.BaseModel.Generator(opt.model_size, dicts['tgt'].size())
    ]

    model = TransformerLM(None, decoder, nn.ModuleList(generators))

    if opt.tie_weights:
        print("Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    for g in model.generator:
        init.xavier_uniform_(g.linear.weight)

    init.normal_(model.decoder.word_lut.weight,
                 mean=0,
                 std=opt.model_size**-0.5)

    return model
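
Judging by the print message above, model.tie_weights() joins the decoder input embedding with the output projection so both use a single weight matrix. Below is a minimal sketch of that idea in plain PyTorch; the TinyDecoderLM class and its attribute names are illustrative assumptions, not the actual onmt TransformerLM.

# Sketch: decoder input/output weight tying (hypothetical class, not the onmt implementation).
import torch.nn as nn

class TinyDecoderLM(nn.Module):
    def __init__(self, vocab_size, model_size):
        super().__init__()
        self.word_lut = nn.Embedding(vocab_size, model_size)          # input embedding
        self.linear = nn.Linear(model_size, vocab_size, bias=False)   # generator projection

    def tie_weights(self):
        # Reuse the embedding matrix as the output projection weight.
        self.linear.weight = self.word_lut.weight

model = TinyDecoderLM(vocab_size=100, model_size=16)
model.tie_weights()
assert model.linear.weight is model.word_lut.weight  # one shared parameter tensor
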
Example no. 2
def build_tm_model(opt, dicts):

    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        positional_encoder = PositionalEncoding(opt.model_size, len_max=MAX_LEN)
    else:
        raise NotImplementedError

    # BUILD GENERATOR
    generators = [onmt.modules.BaseModel.Generator(opt.model_size, dicts['tgt'].size())]

    # BUILD EMBEDDING
    if 'src' in dicts:
        embedding_src = nn.Embedding(dicts['src'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.Constants.PAD)
    else:
        embedding_src = None

    if opt.join_embedding and embedding_src is not None:
        embedding_tgt = embedding_src
        print("* Joining the weights of encoder and decoder word embeddings")
    else:
        embedding_tgt = nn.Embedding(dicts['tgt'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.Constants.PAD)

    if 'atb' in dicts and dicts['atb'] is not None:
        from onmt.modules.Utilities import AttributeEmbeddings

        attribute_embeddings = AttributeEmbeddings(dicts['atb'], opt.model_size)

    else:
        attribute_embeddings = None

    if opt.ctc_loss != 0:
        generators.append(onmt.modules.BaseModel.Generator(opt.model_size, dicts['tgt'].size() + 1))

    if opt.model == 'transformer':
        # raise NotImplementedError

        onmt.Constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = TransformerEncoder(opt, embedding_src, positional_encoder, opt.encoder_type)
        elif opt.encoder_type == "audio":
            encoder = TransformerEncoder(opt, None, positional_encoder, opt.encoder_type)
        elif opt.encoder_type == "mix":
            text_encoder = TransformerEncoder(opt, embedding_src, positional_encoder, "text")
            audio_encoder = TransformerEncoder(opt, None, positional_encoder, "audio")
            encoder = MixedEncoder(text_encoder, audio_encoder)
        else:
            print ("Unknown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = TransformerDecoder(opt, embedding_tgt, positional_encoder, attribute_embeddings=attribute_embeddings)

        model = Transformer(encoder, decoder, nn.ModuleList(generators))

    elif opt.model == 'stochastic_transformer':
        
        from onmt.modules.StochasticTransformer.Models import StochasticTransformerEncoder, StochasticTransformerDecoder

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init
        
        if opt.encoder_type == "text":
            encoder = StochasticTransformerEncoder(opt, embedding_src, positional_encoder, opt.encoder_type)
        elif opt.encoder_type == "audio":
            encoder = StochasticTransformerEncoder(opt, None, positional_encoder, opt.encoder_type)
        elif opt.encoder_type == "mix":
            text_encoder = StochasticTransformerEncoder(opt, embedding_src, positional_encoder, "text")
            audio_encoder = StochasticTransformerEncoder(opt, None, positional_encoder, "audio")
            encoder = MixedEncoder(text_encoder, audio_encoder)
        else:
            print ("Unknown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = StochasticTransformerDecoder(opt, embedding_tgt, positional_encoder, attribute_embeddings=attribute_embeddings)

        model = Transformer(encoder, decoder, nn.ModuleList(generators))

    elif opt.model == 'dlcltransformer':

        from onmt.modules.DynamicTransformer.Models import DlclTransformerDecoder, DlclTransformerEncoder

        if opt.encoder_type == "text":
            encoder = DlclTransformerEncoder(opt, embedding_src, positional_encoder, opt.encoder_type)
        elif opt.encoder_type == "audio":
            encoder = DlclTransformerEncoder(opt, None, positional_encoder, opt.encoder_type)

        decoder = DlclTransformerDecoder(opt, embedding_tgt, positional_encoder)

        model = Transformer(encoder, decoder, nn.ModuleList(generators))

    else:
        raise NotImplementedError

    if opt.tie_weights:  
        print("Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    for g in model.generator:
        init.xavier_uniform_(g.linear.weight)

    if opt.encoder_type == "audio":

        if opt.init_embedding == 'xavier':
            init.xavier_uniform_(model.decoder.word_lut.weight)
        elif opt.init_embedding == 'normal':
            init.normal_(model.decoder.word_lut.weight, mean=0, std=opt.model_size ** -0.5)
    else:
        if opt.init_embedding == 'xavier':
            init.xavier_uniform_(model.encoder.word_lut.weight)
            init.xavier_uniform_(model.decoder.word_lut.weight)
        elif opt.init_embedding == 'normal':
            init.normal_(model.encoder.word_lut.weight, mean=0, std=opt.model_size ** -0.5)
            init.normal_(model.decoder.word_lut.weight, mean=0, std=opt.model_size ** -0.5)

    return model
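
The opt.join_embedding branch above shares one nn.Embedding object between the source and target sides, which only makes sense with a joint vocabulary. A minimal sketch of the same idea (the PAD index, vocabulary size and model size are assumed values, not read from opt):

# Sketch: one embedding module shared by encoder and decoder (joint vocabulary assumed).
import torch.nn as nn

PAD = 0             # assumed padding index (onmt.Constants.PAD in the code above)
vocab_size = 32000  # assumed joint vocabulary size
model_size = 512

embedding_src = nn.Embedding(vocab_size, model_size, padding_idx=PAD)
embedding_tgt = embedding_src  # same module: one weight matrix, one set of gradients

assert embedding_tgt.weight is embedding_src.weight
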
Example no. 3
def build_model(opt, dicts):

    model = None

    opt = update_backward_compatibility(opt)

    onmt.Constants.layer_norm = opt.layer_norm
    onmt.Constants.weight_norm = opt.weight_norm
    onmt.Constants.activation_layer = opt.activation_layer
    onmt.Constants.version = 1.0
    onmt.Constants.attention_out = opt.attention_out
    onmt.Constants.residual_type = opt.residual_type

    MAX_LEN = onmt.Constants.max_position_length  # This should be the longest sentence from the dataset

    if opt.model == 'recurrent' or opt.model == 'rnn':

        from onmt.modules.rnn.Models import RecurrentEncoder, RecurrentDecoder, RecurrentModel

        encoder = RecurrentEncoder(opt, dicts['src'])

        decoder = RecurrentDecoder(opt, dicts['tgt'])

        generator = onmt.modules.BaseModel.Generator(opt.rnn_size,
                                                     dicts['tgt'].size())

        model = RecurrentModel(encoder, decoder, generator)

    elif opt.model == 'transformer':
        # raise NotImplementedError

        onmt.Constants.init_value = opt.param_init

        if opt.time == 'positional_encoding':
            positional_encoder = PositionalEncoding(opt.model_size,
                                                    len_max=MAX_LEN)
        else:
            positional_encoder = None

        encoder = TransformerEncoder(opt, dicts['src'], positional_encoder)
        decoder = TransformerDecoder(opt, dicts['tgt'], positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)

        #~ print(encoder)

    elif opt.model == 'stochastic_transformer':

        from onmt.modules.StochasticTransformer.Models import StochasticTransformerEncoder, StochasticTransformerDecoder

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
        #~ positional_encoder = None

        encoder = StochasticTransformerEncoder(opt, dicts['src'],
                                               positional_encoder)

        decoder = StochasticTransformerDecoder(opt, dicts['tgt'],
                                               positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)

    elif opt.model == 'fctransformer':

        from onmt.modules.FCTransformer.Models import FCTransformerEncoder, FCTransformerDecoder

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)

        encoder = FCTransformerEncoder(opt, dicts['src'], positional_encoder)
        decoder = FCTransformerDecoder(opt, dicts['tgt'], positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)
    elif opt.model == 'ptransformer':

        from onmt.modules.ParallelTransformer.Models import ParallelTransformerEncoder, ParallelTransformerDecoder

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)

        encoder = ParallelTransformerEncoder(opt, dicts['src'],
                                             positional_encoder)
        decoder = ParallelTransformerDecoder(opt, dicts['tgt'],
                                             positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)

    elif opt.model in ['universal_transformer', 'utransformer']:

        from onmt.modules.UniversalTransformer.Models import UniversalTransformerDecoder, UniversalTransformerEncoder
        from onmt.modules.UniversalTransformer.Layers import TimeEncoding

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
        time_encoder = TimeEncoding(opt.model_size, len_max=32)

        encoder = UniversalTransformerEncoder(opt, dicts['src'],
                                              positional_encoder, time_encoder)
        decoder = UniversalTransformerDecoder(opt, dicts['tgt'],
                                              positional_encoder, time_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)

    elif opt.model in ['iid_stochastic_transformer']:

        from onmt.modules.IIDStochasticTransformer.Models import IIDStochasticTransformerEncoder, IIDStochasticTransformerDecoder

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
        #~ positional_encoder = None

        encoder = IIDStochasticTransformerEncoder(opt, dicts['src'],
                                                  positional_encoder)

        decoder = IIDStochasticTransformerDecoder(opt, dicts['tgt'],
                                                  positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)

    elif opt.model in ['reinforce_transformer']:

        from onmt.modules.StochasticTransformer.Models import StochasticTransformerEncoder, StochasticTransformerDecoder
        from onmt.modules.ReinforceTransformer.Models import ReinforcedStochasticDecoder, ReinforceTransformer

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)

        encoder = StochasticTransformerEncoder(opt, dicts['src'],
                                               positional_encoder)

        decoder = ReinforcedStochasticDecoder(
            opt, dicts['tgt'], positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = ReinforceTransformer(encoder, decoder, generator)

    else:
        raise NotImplementedError

    if opt.tie_weights:
        print("Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    if opt.join_embedding:
        print("Joining the weights of encoder and decoder word embeddings")
        model.share_enc_dec_embedding()

    init = torch.nn.init

    init.xavier_uniform_(model.generator.linear.weight)

    if opt.init_embedding == 'xavier':
        init.xavier_uniform_(model.encoder.word_lut.weight)
        init.xavier_uniform_(model.decoder.word_lut.weight)
    elif opt.init_embedding == 'normal':
        init.normal_(model.encoder.word_lut.weight,
                     mean=0,
                     std=opt.model_size**-0.5)
        init.normal_(model.decoder.word_lut.weight,
                     mean=0,
                     std=opt.model_size**-0.5)

    return model
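
The last step above picks the embedding initialization: Xavier-uniform, or a zero-mean normal with standard deviation model_size ** -0.5, the usual Transformer scaling. A standalone sketch of the two options on a plain embedding (vocabulary and model sizes are arbitrary example values):

# Sketch: the two embedding-initialization choices used in build_model.
import torch.nn as nn
from torch.nn import init

vocab_size, model_size = 1000, 512
word_lut = nn.Embedding(vocab_size, model_size)

init_embedding = 'normal'  # or 'xavier'
if init_embedding == 'xavier':
    init.xavier_uniform_(word_lut.weight)
elif init_embedding == 'normal':
    # std = model_size ** -0.5 keeps the embedding scale compatible with the
    # sqrt(model_size) factor Transformer embeddings are typically multiplied by.
    init.normal_(word_lut.weight, mean=0, std=model_size ** -0.5)
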
Example no. 4
def build_tm_model(opt, dicts):

    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        # by me
        # should len_max be adjusted?
        positional_encoder = PositionalEncoding(opt.model_size, len_max=MAX_LEN)
    else:
        raise NotImplementedError

    # BUILD GENERATOR
    generators = [onmt.modules.BaseModel.Generator(opt.model_size, dicts['tgt'].size())]

    # BUILD EMBEDDING
    if 'src' in dicts:
        # embedding_src = nn.Embedding(dicts['src'].size(),
        #                              opt.model_size,
        #                              padding_idx=onmt.Constants.PAD)

        # by me: we use the BERT word vectors as the embedding; if the BERT hidden size
        # differs from the Transformer model size, apply a linear projection
        if onmt.Constants.BERT_HIDDEN != opt.model_size:
            bert_linear = nn.Linear(onmt.Constants.BERT_HIDDEN, opt.model_size)
        else:
            bert_linear = None

        # keep embedding_src defined: the source side is embedded by BERT in this variant
        embedding_src = None

    else:
        embedding_src = None

    if opt.join_embedding and embedding_src is not None:
        embedding_tgt = embedding_src
        print("* Joining the weights of encoder and decoder word embeddings")
    else:
        embedding_tgt = nn.Embedding(dicts['tgt'].size(),
                                     opt.model_size,
                                     padding_idx=onmt.Constants.PAD)

    if opt.ctc_loss != 0:
        generators.append(onmt.modules.BaseModel.Generator(opt.model_size, dicts['tgt'].size() + 1))

    if opt.model == 'transformer':
        onmt.Constants.init_value = opt.param_init
        if opt.encoder_type == "text":
            encoder = TransformerEncoder(opt, bert_linear, positional_encoder, opt.encoder_type)
            # do not load the state; only construct the BERT object
            if opt.not_load_bert_state:
                print("we don't load the BERT state from a PyTorch model or from a pretrained model")
                bert_config = BertConfig.from_json_file(opt.bert_config_dir + "/" + opt.bert_config_name)
                bert = BertModel(bert_config)
            # here bert_model_dir can be a pretrained model provided by PyTorch or a BERT model fine-tuned by ourselves
            else:
                if opt.bert_state_dict:
                    print("after builing bert we load the state from finetuned Bert")
                    finetuned_state_dict = torch.load(opt.bert_state_dict)
                    bert = BertModel.from_pretrained(cache_dir=opt.bert_config_dir, state_dict=finetuned_state_dict)
                else:
                    print("after builing bert we load the state from Pytorch")
                    bert = BertModel.from_pretrained(cache_dir=opt.bert_config_dir)
            replace_layer_norm(bert, "Transformer")

        else:
            print("Unknown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = TransformerDecoder(opt, embedding_tgt, positional_encoder, attribute_embeddings=None)

        model = Transformer(bert, encoder, decoder, nn.ModuleList(generators))


    elif opt.model == 'relative_transformer':
        from onmt.modules.RelativeTransformer.Models import RelativeTransformer
        positional_encoder = SinusoidalPositionalEmbedding(opt.model_size)
        # if opt.encoder_type == "text":
        # encoder = TransformerEncoder(opt, embedding_src, positional_encoder, opt.encoder_type)
        # encoder = RelativeTransformerEncoder(opt, embedding_src, relative_positional_encoder, opt.encoder_type)
        if opt.encoder_type == "audio":
            raise NotImplementedError
            # encoder = TransformerEncoder(opt, None, positional_encoder, opt.encoder_type)
        generator = nn.ModuleList(generators)
        model = RelativeTransformer(opt, [embedding_src, embedding_tgt], positional_encoder, generator=generator)

    else:
        raise NotImplementedError

    if opt.tie_weights:  
        print("Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    for g in model.generator:
        init.xavier_uniform_(g.linear.weight)

    if opt.encoder_type == "audio":

        if opt.init_embedding == 'xavier':
            init.xavier_uniform_(model.decoder.word_lut.weight)
        elif opt.init_embedding == 'normal':
            init.normal_(model.decoder.word_lut.weight, mean=0, std=opt.model_size ** -0.5)
    else:
        if opt.init_embedding == 'xavier':
            if model.encoder.word_lut:
                init.xavier_uniform_(model.encoder.word_lut.weight)
            init.xavier_uniform_(model.decoder.word_lut.weight)
        elif opt.init_embedding == 'normal':
            if model.encoder.word_lut:
                init.normal_(model.encoder.word_lut.weight, mean=0, std=opt.model_size ** -0.5)
            init.normal_(model.decoder.word_lut.weight, mean=0, std=opt.model_size ** -0.5)

    return model
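
The bert_linear built above is just a dimension adapter for the case where the BERT hidden size differs from the Transformer model size. A minimal sketch of that projection on a batch of hypothetical BERT outputs (768 and 512 are assumed sizes, not values read from opt or onmt.Constants):

# Sketch: projecting BERT hidden states to the Transformer model size.
import torch
import torch.nn as nn

BERT_HIDDEN = 768   # assumed, stands in for onmt.Constants.BERT_HIDDEN
model_size = 512    # assumed, stands in for opt.model_size

bert_linear = nn.Linear(BERT_HIDDEN, model_size) if BERT_HIDDEN != model_size else None

bert_output = torch.randn(8, 30, BERT_HIDDEN)  # (batch, seq_len, hidden) from BERT
encoder_input = bert_linear(bert_output) if bert_linear is not None else bert_output
print(encoder_input.shape)  # torch.Size([8, 30, 512])
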
Example no. 5
def build_model(opt, dicts):

    model = None

    if not hasattr(opt, 'model'):
        opt.model = 'recurrent'

    if not hasattr(opt, 'layer_norm'):
        opt.layer_norm = 'slow'

    if not hasattr(opt, 'attention_out'):
        opt.attention_out = 'default'

    if not hasattr(opt, 'residual_type'):
        opt.residual_type = 'regular'

    onmt.Constants.layer_norm = opt.layer_norm
    onmt.Constants.weight_norm = opt.weight_norm
    onmt.Constants.activation_layer = opt.activation_layer
    onmt.Constants.version = 1.0
    onmt.Constants.attention_out = opt.attention_out
    onmt.Constants.residual_type = opt.residual_type

    if opt.model == 'recurrent' or opt.model == 'rnn':

        from onmt.modules.rnn.Models import RecurrentEncoder, RecurrentDecoder, RecurrentModel

        encoder = RecurrentEncoder(opt, dicts['src'])

        decoder = RecurrentDecoder(opt, dicts['tgt'])

        generator = onmt.modules.BaseModel.Generator(opt.rnn_size,
                                                     dicts['tgt'].size())

        model = RecurrentModel(encoder, decoder, generator)

    elif opt.model == 'transformer':
        # raise NotImplementedError

        max_size = 262  # This should be the longest sentence from the dataset
        onmt.Constants.init_value = opt.param_init

        if opt.time == 'positional_encoding':
            positional_encoder = PositionalEncoding(opt.model_size,
                                                    len_max=max_size)
        else:
            positional_encoder = None
        #~ elif opt.time == 'gru':
        #~ positional_encoder = nn.GRU(opt.model_size, opt.model_size, 1, batch_first=True)
        #~ elif opt.time == 'lstm':
        #~ positional_encoder = nn.LSTM(opt.model_size, opt.model_size, 1, batch_first=True)

        encoder = TransformerEncoder(opt, dicts['src'], positional_encoder)
        decoder = TransformerDecoder(opt, dicts['tgt'], positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)

        #~ print(encoder)

    elif opt.model == 'stochastic_transformer':

        from onmt.modules.StochasticTransformer.Models import StochasticTransformerEncoder, StochasticTransformerDecoder

        max_size = 256  # This should be the longest sentence from the dataset
        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=max_size)
        #~ positional_encoder = None

        encoder = StochasticTransformerEncoder(opt, dicts['src'],
                                               positional_encoder)

        decoder = StochasticTransformerDecoder(opt, dicts['tgt'],
                                               positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)

    elif opt.model == 'fctransformer':

        from onmt.modules.FCTransformer.Models import FCTransformerEncoder, FCTransformerDecoder

        max_size = 256  # This should be the longest sentence from the dataset
        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=max_size)

        encoder = FCTransformerEncoder(opt, dicts['src'], positional_encoder)
        decoder = FCTransformerDecoder(opt, dicts['tgt'], positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)
    elif opt.model == 'ptransformer':

        from onmt.modules.ParallelTransformer.Models import ParallelTransformerEncoder, ParallelTransformerDecoder

        max_size = 256  # This should be the longest sentence from the dataset
        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=max_size)

        encoder = ParallelTransformerEncoder(opt, dicts['src'],
                                             positional_encoder)
        decoder = ParallelTransformerDecoder(opt, dicts['tgt'],
                                             positional_encoder)

        generator = onmt.modules.BaseModel.Generator(opt.model_size,
                                                     dicts['tgt'].size())

        model = Transformer(encoder, decoder, generator)

        #~ print(encoder)

    else:
        raise NotImplementedError

        #~
    #~ init = torch.nn.init
    #~
    #~ init.xavier_uniform(model.encoder.word_lut.weight)
    #~ init.xavier_uniform(model.decoder.word_lut.weight)

    # Weight tying between decoder input and output embedding:
    if opt.tie_weights:
        print("Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    if opt.join_embedding:
        print("Joining the weights of encoder and decoder word embeddings")
        model.share_enc_dec_embedding()

    return model
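
Every branch above relies on PositionalEncoding(opt.model_size, len_max=...). The sketch below builds a standard sinusoidal positional-encoding table in the style of "Attention Is All You Need"; it is an assumption about what that class computes, not the actual onmt implementation.

# Sketch: a standard sinusoidal positional-encoding table (assumed behaviour of PositionalEncoding).
import math
import torch

def sinusoidal_table(model_size, len_max):
    pos = torch.arange(len_max, dtype=torch.float).unsqueeze(1)           # (len_max, 1)
    div = torch.exp(torch.arange(0, model_size, 2, dtype=torch.float)
                    * (-math.log(10000.0) / model_size))                  # (model_size/2,)
    pe = torch.zeros(len_max, model_size)
    pe[:, 0::2] = torch.sin(pos * div)   # even dimensions
    pe[:, 1::2] = torch.cos(pos * div)   # odd dimensions
    return pe                            # (len_max, model_size)

pe = sinusoidal_table(model_size=512, len_max=262)  # 262 matches max_size in the transformer branch
print(pe.shape)  # torch.Size([262, 512])
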
Example no. 6
def build_tm_model(opt, dicts):

    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
    else:
        raise NotImplementedError

    # if dicts['atb'].size() > 0:
    #     feat_embedding = nn.Embedding(dicts['atb'].size(), opt.model_size)
    # else:
    feat_embedding = None

    # BUILD GENERATOR
    generators = [
        onmt.modules.BaseModel.Generator(opt.model_size, dicts['tgt'].size())
    ]

    if opt.ctc_loss != 0:
        generators.append(
            onmt.modules.BaseModel.Generator(opt.model_size,
                                             dicts['tgt'].size() + 1))

    if opt.model == 'transformer':
        # raise NotImplementedError

        onmt.Constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = TransformerEncoder(opt, dicts['src'], positional_encoder,
                                         opt.encoder_type)
        elif opt.encoder_type == "audio":
            encoder = TransformerEncoder(opt, opt.input_size,
                                         positional_encoder, opt.encoder_type)
        elif opt.encoder_type == "mix":
            text_encoder = TransformerEncoder(opt, dicts['src'],
                                              positional_encoder, "text")
            audio_encoder = TransformerEncoder(opt, opt.input_size,
                                               positional_encoder, "audio")
            encoder = MixedEncoder(text_encoder, audio_encoder)
        else:
            print("Unkown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = TransformerDecoder(opt,
                                     dicts['tgt'],
                                     positional_encoder,
                                     feature_embedding=feat_embedding)

        model = Transformer(encoder, decoder, nn.ModuleList(generators))

    elif opt.model == 'stochastic_transformer':
        """
        The stochastic implementation of the Transformer as in 
        "Very Deep Self-Attention Networks for End-to-End Speech Recognition"
        """
        from onmt.modules.StochasticTransformer.Models import StochasticTransformerEncoder, StochasticTransformerDecoder

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = StochasticTransformerEncoder(opt, dicts['src'],
                                                   positional_encoder,
                                                   opt.encoder_type)
        elif opt.encoder_type == "audio":
            encoder = StochasticTransformerEncoder(opt, opt.input_size,
                                                   positional_encoder,
                                                   opt.encoder_type)
        elif opt.encoder_type == "mix":
            text_encoder = StochasticTransformerEncoder(
                opt, dicts['src'], positional_encoder, "text")
            audio_encoder = StochasticTransformerEncoder(
                opt, opt.input_size, positional_encoder, "audio")
            encoder = MixedEncoder(text_encoder, audio_encoder)
        else:
            print("Unknown encoder type:", opt.encoder_type)
            exit(-1)

        decoder = StochasticTransformerDecoder(opt, dicts['tgt'],
                                               positional_encoder)

        model = Transformer(encoder, decoder, nn.ModuleList(generators))

    elif opt.model in ['universal_transformer', 'utransformer']:

        from onmt.modules.UniversalTransformer.Models import UniversalTransformerDecoder, UniversalTransformerEncoder
        from onmt.modules.UniversalTransformer.Layers import TimeEncoding

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        time_encoder = TimeEncoding(opt.model_size, len_max=32)

        encoder = UniversalTransformerEncoder(opt, dicts['src'],
                                              positional_encoder, time_encoder)
        decoder = UniversalTransformerDecoder(opt, dicts['tgt'],
                                              positional_encoder, time_encoder)

        model = Transformer(encoder, decoder, nn.ModuleList(generators))

    else:
        raise NotImplementedError

    if opt.tie_weights:
        print("Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    if opt.join_embedding:
        print("Joining the weights of encoder and decoder word embeddings")
        model.share_enc_dec_embedding()

    for g in model.generator:
        init.xavier_uniform_(g.linear.weight)

    if opt.encoder_type == "audio":
        init.xavier_uniform_(model.encoder.audio_trans.weight.data)
        if opt.init_embedding == 'xavier':
            init.xavier_uniform_(model.decoder.word_lut.weight)
        elif opt.init_embedding == 'normal':
            init.normal_(model.decoder.word_lut.weight,
                         mean=0,
                         std=opt.model_size**-0.5)
    elif opt.encoder_type == "text":
        if opt.init_embedding == 'xavier':
            init.xavier_uniform_(model.encoder.word_lut.weight)
            init.xavier_uniform_(model.decoder.word_lut.weight)
        elif opt.init_embedding == 'normal':
            init.normal_(model.encoder.word_lut.weight,
                         mean=0,
                         std=opt.model_size**-0.5)
            init.normal_(model.decoder.word_lut.weight,
                         mean=0,
                         std=opt.model_size**-0.5)
    elif opt.encoder_type == "mix":
        init.xavier_uniform_(
            model.encoder.audio_encoder.audio_trans.weight.data)
        if opt.init_embedding == 'xavier':
            init.xavier_uniform_(model.encoder.text_encoder.word_lut.weight)
            init.xavier_uniform_(model.decoder.word_lut.weight)
        elif opt.init_embedding == 'normal':
            init.normal_(model.encoder.text_encoder.word_lut.weight,
                         mean=0,
                         std=opt.model_size**-0.5)
            init.normal_(model.decoder.word_lut.weight,
                         mean=0,
                         std=opt.model_size**-0.5)

    else:
        print("Unkown encoder type:", opt.encoder_type)
        exit(-1)

    return model
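
When opt.ctc_loss != 0, a second generator with dicts['tgt'].size() + 1 outputs is appended; the extra class is the CTC blank symbol. A minimal sketch of wiring such an output layer to torch.nn.CTCLoss (vocabulary size, sequence lengths and the choice of blank index are illustrative assumptions):

# Sketch: why the CTC generator has vocab_size + 1 outputs (the extra class is the blank).
import torch
import torch.nn as nn

vocab_size, model_size = 100, 512
blank_idx = vocab_size                       # assume the blank is the last class
ctc_generator = nn.Linear(model_size, vocab_size + 1)
ctc_loss = nn.CTCLoss(blank=blank_idx, zero_infinity=True)

enc_out = torch.randn(50, 4, model_size)     # (input_len, batch, model_size) encoder states
log_probs = ctc_generator(enc_out).log_softmax(-1)   # (input_len, batch, vocab_size + 1)
targets = torch.randint(0, vocab_size, (4, 20))      # (batch, target_len), blank never appears
input_lengths = torch.full((4,), 50, dtype=torch.long)
target_lengths = torch.full((4,), 20, dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss.item())
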
Example no. 7
def build_tm_model(opt, dicts):

    # BUILD POSITIONAL ENCODING
    if opt.time == 'positional_encoding':
        positional_encoder = PositionalEncoding(opt.model_size,
                                                len_max=MAX_LEN)
    else:
        raise NotImplementedError

    # BUILD GENERATOR
    generators = [
        onmt.modules.BaseModel.Generator(opt.model_size, dicts['tgt'].size())
    ]

    if opt.ctc_loss != 0:
        generators.append(
            onmt.modules.BaseModel.Generator(opt.model_size,
                                             dicts['tgt'].size() + 1))

    if opt.model == 'transformer':
        # raise NotImplementedError

        onmt.Constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = TransformerEncoder(opt, dicts['src'], positional_encoder)
        else:
            encoder = TransformerEncoder(opt, opt.input_size,
                                         positional_encoder)

        decoder = TransformerDecoder(opt, dicts['tgt'], positional_encoder)

        model = Transformer(encoder, decoder, nn.ModuleList(generators))

    elif opt.model == 'stochastic_transformer':

        from onmt.modules.StochasticTransformer.Models import StochasticTransformerEncoder, StochasticTransformerDecoder

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        if opt.encoder_type == "text":
            encoder = StochasticTransformerEncoder(opt, dicts['src'],
                                                   positional_encoder)
        else:
            encoder = StochasticTransformerEncoder(opt, opt.input_size,
                                                   positional_encoder)

        decoder = StochasticTransformerDecoder(opt, dicts['tgt'],
                                               positional_encoder)

        model = Transformer(encoder, decoder, nn.ModuleList(generators))

    elif opt.model in ['universal_transformer', 'utransformer']:

        from onmt.modules.UniversalTransformer.Models import UniversalTransformerDecoder, UniversalTransformerEncoder
        from onmt.modules.UniversalTransformer.Layers import TimeEncoding

        onmt.Constants.weight_norm = opt.weight_norm
        onmt.Constants.init_value = opt.param_init

        time_encoder = TimeEncoding(opt.model_size, len_max=32)

        encoder = UniversalTransformerEncoder(opt, dicts['src'],
                                              positional_encoder, time_encoder)
        decoder = UniversalTransformerDecoder(opt, dicts['tgt'],
                                              positional_encoder, time_encoder)

        model = Transformer(encoder, decoder, nn.ModuleList(generators))

    else:
        raise NotImplementedError

    if opt.tie_weights:
        print("Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    if opt.join_embedding:
        print("Joining the weights of encoder and decoder word embeddings")
        model.share_enc_dec_embedding()

    for g in model.generator:
        init.xavier_uniform_(g.linear.weight)

    if opt.encoder_type == "audio":
        init.xavier_uniform_(model.encoder.audio_trans.weight.data)
        if opt.init_embedding == 'xavier':
            init.xavier_uniform_(model.decoder.word_lut.weight)
        elif opt.init_embedding == 'normal':
            init.normal_(model.decoder.word_lut.weight,
                         mean=0,
                         std=opt.model_size**-0.5)
    else:
        if opt.init_embedding == 'xavier':
            init.xavier_uniform_(model.encoder.word_lut.weight)
            init.xavier_uniform_(model.decoder.word_lut.weight)
        elif opt.init_embedding == 'normal':
            init.normal_(model.encoder.word_lut.weight,
                         mean=0,
                         std=opt.model_size**-0.5)
            init.normal_(model.decoder.word_lut.weight,
                         mean=0,
                         std=opt.model_size**-0.5)

    return model