예제 #1
0
def build_encoder(opt, embeddings):
    """
    Various encoder dispatcher function.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    enc_type = opt.encoder_type
    if enc_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.enc_rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings)
    if enc_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                          opt.cnn_kernel_width, opt.dropout,
                          embeddings)
    if enc_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    # Fall-through handles "rnn" and "brnn".
    logger.info('RNNEncoder: type %s, bidir %d, layers %d, '
                'hidden size %d, dropout %.2f' %
                (opt.rnn_type, opt.brnn, opt.enc_layers,
                 opt.enc_rnn_size, opt.dropout))
    return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                      opt.enc_rnn_size, opt.dropout, embeddings,
                      opt.bridge)
예제 #2
0
def build_encoder(opt, embeddings):
    """
    Various encoder dispatcher function.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    enc_type = opt.encoder_type
    if enc_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.enc_rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings,
                                  ablation=opt.ablation)
    if enc_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                          opt.cnn_kernel_width, opt.dropout,
                          embeddings)
    if enc_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    # Default: "rnn" or "brnn".
    return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                      opt.enc_rnn_size, opt.dropout, embeddings,
                      opt.bridge)
예제 #3
0
def build_encoder(opt, embeddings, embeddings_latt=False, feat_vec_size=512):
    """
    Various encoder dispatcher function.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
        embeddings_latt: embeddings of senses if a lattice is used,
            or ``False`` when there is no lattice input.
        feat_vec_size: adaptable feature vector size used together with
            the lattice embeddings.
    """
    if opt.encoder_type == "transformer":
        # The original code branched on `embeddings_latt != False`, but both
        # branches constructed the exact same TransformerEncoder, so the
        # check was dead code.  The lattice arguments are forwarded
        # unconditionally, which is behaviorally identical.
        return TransformerEncoder(opt.enc_layers, opt.rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings,
                                  embeddings_latt, feat_vec_size)
    elif opt.encoder_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.rnn_size,
                          opt.cnn_kernel_width,
                          opt.dropout, embeddings)
    elif opt.encoder_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    else:
        # "rnn" or "brnn"
        return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                          opt.rnn_size, opt.dropout, embeddings,
                          opt.bridge)
예제 #4
0
def build_encoder(opt, embeddings):
    """
    Various encoder dispatcher function.

    Returns a ``(word_encoder, sentence_encoder)`` pair; the sentence-level
    encoder is only built in the RNN branch and is ``None`` otherwise.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    enc_type = opt.encoder_type
    if enc_type == "transformer":
        word_enc = TransformerEncoder(opt.enc_layers, opt.rnn_size,
                                      opt.heads, opt.transformer_ff,
                                      opt.dropout, embeddings)
        return word_enc, None
    if enc_type == "cnn":
        word_enc = CNNEncoder(opt.enc_layers, opt.rnn_size,
                              opt.cnn_kernel_width,
                              opt.dropout, embeddings)
        return word_enc, None
    if enc_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings), None
    # "rnn" or "brnn": build a word-level encoder plus a sentence-level
    # encoder stacked on its flattened final state.
    word_encoder = RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                              opt.rnn_size, opt.dropout, embeddings,
                              None, opt.bridge)
    # LSTM states carry both h and c, so the flattened state is twice as big.
    state_factor = 2 if opt.rnn_type == "LSTM" else 1
    emb_size = opt.enc_layers * opt.rnn_size * state_factor
    sen_encoder = RNNEncoder(opt.rnn_type, opt.brnn, opt.sen_enc_layers,
                             opt.sen_rnn_size, opt.dropout, None,
                             emb_size, opt.bridge)
    return word_encoder, sen_encoder
예제 #5
0
def build_encoder(opt, embeddings, main_encoder=None):
    """
    Various encoder dispatcher function.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
        main_encoder: optional encoder forwarded (with the full option set
            as ``mtl_opt``) to the transformer and RNN constructors.
    """
    enc_type = opt.encoder_type
    if enc_type == "transformer":
        return TransformerEncoder(opt.enc_layers,
                                  opt.enc_rnn_size,
                                  opt.heads,
                                  opt.transformer_ff,
                                  opt.dropout,
                                  embeddings,
                                  main_encoder=main_encoder,
                                  mtl_opt=opt)
    if enc_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                          opt.cnn_kernel_width, opt.dropout, embeddings)
    if enc_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    # Default: RNN-family encoder.
    return RNNEncoder(opt.rnn_type,
                      opt.brnn,
                      opt.enc_layers,
                      opt.enc_rnn_size,
                      opt.dropout,
                      embeddings,
                      opt.bridge,
                      main_encoder=main_encoder,
                      mtl_opt=opt)
예제 #6
0
def build_encoder(opt, embeddings):
    """
    Various encoder dispatcher function.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    enc_type = opt.encoder_type
    if enc_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.enc_rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings)
    if enc_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                          opt.cnn_kernel_width, opt.dropout, embeddings)
    if enc_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    if enc_type == "hr_brnn":
        # Hierarchical encoder; always bidirectional.
        return HREncoder(opt.rnn_type, True, opt.enc_layers,
                         opt.enc_rnn_size, opt.dropout, embeddings,
                         opt.bridge)
    if enc_type == "seq_hr_brnn":
        # Sequential hierarchical encoder; always bidirectional.
        return SeqHREncoder(opt.rnn_type, True, opt.enc_layers,
                            opt.enc_rnn_size, opt.dropout, embeddings,
                            opt.bridge)
    if enc_type == "tg_brnn":
        # TG encoder; always bidirectional.
        return TGEncoder(opt.rnn_type, True, opt.enc_layers,
                         opt.enc_rnn_size, opt.dropout, embeddings)
    # Plain RNN; bidirectional whenever the type name contains "brnn".
    return RNNEncoder(opt.rnn_type,
                      'brnn' in enc_type,
                      opt.enc_layers,
                      opt.enc_rnn_size,
                      opt.dropout,
                      embeddings,
                      opt.bridge,
                      use_catSeq_dp=opt.use_catSeq_dp)
예제 #7
0
def build_encoder(opt, embeddings):
    """
    Various encoder dispatcher function.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
    """
    enc_type = opt.encoder_type
    if enc_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.enc_rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings)
    if enc_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.enc_rnn_size,
                          opt.cnn_kernel_width,
                          opt.dropout, embeddings)
    if enc_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    if enc_type in ("rnntreelstm", "treelstm"):
        # Side effect kept from the original: opt.brnn is overwritten here.
        opt.brnn = enc_type == "rnntreelstm"
        return TreeLSTMEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                               opt.rnn_size, opt.dropout, embeddings,
                               opt.bridge, False)
    if enc_type in ("rnnbitreelstm", "bitreelstm"):
        # Side effect kept from the original: opt.brnn is overwritten here.
        opt.brnn = enc_type == "rnnbitreelstm"
        return TreeLSTMEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                               opt.rnn_size, opt.dropout, embeddings,
                               opt.bridge, True)
    if enc_type in ("rnngcn", "gcn"):
        # Side effect kept from the original: opt.brnn is overwritten here.
        opt.brnn = enc_type == "rnngcn"
        return GCNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                          opt.rnn_size, opt.dropout, embeddings,
                          opt.bridge, opt.gcn_dropout,
                          opt.gcn_edge_dropout, opt.n_gcn_layers,
                          opt.activation, opt.highway)
    # "rnn" or "brnn"
    return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                      opt.enc_rnn_size, opt.dropout, embeddings,
                      opt.bridge)
예제 #8
0
def build_encoder(opt, embeddings, fields=None):
    """
    Various encoder dispatcher function.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.
        fields: accepted for interface compatibility; not used here.
    """
    enc_type = opt.encoder_type
    if enc_type in ('simple_context_0', 'simple_context_1'):
        # Bottom n-1 layers are shared; the trailing digit of the type
        # name selects which context is used (0 or 1).
        return SimpleContextTransformerEncoder(
            opt.enc_layers - 1, opt.rnn_size,
            opt.heads, opt.transformer_ff,
            opt.dropout, embeddings,
            selected_ctx=int(enc_type[-1]))
    if enc_type == "transformer":
        return TransformerEncoder(opt.enc_layers, opt.rnn_size,
                                  opt.heads, opt.transformer_ff,
                                  opt.dropout, embeddings)
    if enc_type == "cnn":
        return CNNEncoder(opt.enc_layers, opt.rnn_size,
                          opt.cnn_kernel_width,
                          opt.dropout, embeddings)
    if enc_type == "mean":
        return MeanEncoder(opt.enc_layers, embeddings)
    # "rnn" or "brnn"
    return RNNEncoder(opt.rnn_type, opt.brnn, opt.enc_layers,
                      opt.rnn_size, opt.dropout, embeddings,
                      opt.bridge)
예제 #9
0
def build_base_model(model_opt, fields, gpu, checkpoint=None):
    """
    Build the full model (encoder + decoder + generator) from options.

    Args:
        model_opt: the option loaded from checkpoint.
        fields: `Field` objects for the model.
        gpu(bool): whether to use gpu.
        checkpoint: the model generated by train phase, or a resumed snapshot
                    model from a stopped training.
    Returns:
        the NMTModel (or VecModel for "vecdif*" decoder types), moved to
        the selected device, with `.generator` attached.
    """
    assert model_opt.model_type in ["text", "img", "audio", "vector"], \
        ("Unsupported model type %s" % (model_opt.model_type))

    # When True, the raw src input is handed to the decoder directly;
    # only enabled below for the "vector" model type (and disabled again
    # for its CNN decoder branch).
    use_src_directly_for_dec = False
    # Build encoder.
    if model_opt.model_type == "text":
        src_dict = fields["src"].vocab
        feature_dicts = inputters.collect_feature_vocabs(fields, 'src')
        src_embeddings = build_embeddings(model_opt, src_dict, feature_dicts)
        encoder = build_encoder(model_opt, src_embeddings)
    elif model_opt.model_type == "img":
        # Older checkpoints may predate the image_channel_size option;
        # default to 3 (RGB) in that case.
        if ("image_channel_size" not in model_opt.__dict__):
            image_channel_size = 3
        else:
            image_channel_size = model_opt.image_channel_size

        encoder = ImageEncoder(model_opt.enc_layers, model_opt.brnn,
                               model_opt.enc_rnn_size, model_opt.dropout,
                               image_channel_size)
    elif model_opt.model_type == "audio":
        encoder = AudioEncoder(model_opt.rnn_type, model_opt.enc_layers,
                               model_opt.dec_layers, model_opt.brnn,
                               model_opt.enc_rnn_size, model_opt.dec_rnn_size,
                               model_opt.audio_enc_pooling, model_opt.dropout,
                               model_opt.sample_rate, model_opt.window_size)
    elif model_opt.model_type == "vector":
        use_src_directly_for_dec = True
        # Vector sources may have no vocab of their own; reuse the target's.
        if not hasattr(fields["src"], 'vocab'):
            fields["src"].vocab = fields["tgt"].vocab
        src_dict = fields["src"].vocab
        #self.word_lut.weight.requires_grad = False
        feature_dicts = inputters.collect_feature_vocabs(fields, 'src')
        # NOTE(review): built from the *src* dict here; overwritten with None
        # in both encoder branches below, so it only survives for the
        # `encoder = None` fallback — confirm this is intentional.
        tgt_embeddings = build_embeddings(model_opt, src_dict, feature_dicts)
        if model_opt.encoder_type == "rnn" or model_opt.encoder_type == "brnn":
            # embeddings argument is None: the encoder consumes the raw
            # vector input rather than looked-up token embeddings.
            encoder = RNNEncoder(model_opt.rnn_type, model_opt.brnn,
                                 model_opt.enc_layers, model_opt.enc_rnn_size,
                                 model_opt.dropout, None, model_opt.bridge)
            tgt_embeddings = None
        elif model_opt.decoder_type == "cnn":
            use_src_directly_for_dec = False
            encoder = CNNEncoder(model_opt.enc_layers, model_opt.enc_rnn_size,
                                 model_opt.cnn_kernel_width, model_opt.dropout,
                                 None)
            tgt_embeddings = None
        else:
            encoder = None

    # Build decoder.
    tgt_dict = fields["tgt"].vocab
    feature_dicts = inputters.collect_feature_vocabs(fields, 'tgt')
    if model_opt.model_type != "vector":
        # For non-vector models, tgt_embeddings was not set above.
        tgt_embeddings = build_embeddings(model_opt,
                                          tgt_dict,
                                          feature_dicts,
                                          for_encoder=False)
    # else:
    #     tgt_embeddings = None

    # Share the embedding matrix - preprocess with share_vocab required.
    if model_opt.share_embeddings:
        # src/tgt vocab should be the same if `-share_vocab` is specified.
        if src_dict != tgt_dict:
            raise AssertionError('The `-share_vocab` should be set during '
                                 'preprocess if you use share_embeddings!')

        # Tie the two lookup tables to the same weight tensor.
        tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight

    decoder = build_decoder(model_opt, tgt_embeddings)

    # Build NMTModel(= encoder + decoder).
    device = torch.device("cuda" if gpu else "cpu")
    if model_opt.decoder_type.startswith("vecdif"):
        model = onmt.models.VecModel(
            encoder,
            decoder,
            use_src_directly_for_dec=use_src_directly_for_dec)
    else:
        model = onmt.models.NMTModel(
            encoder,
            decoder,
            use_src_directly_for_dec=use_src_directly_for_dec)

    # Build Generator.
    if not model_opt.copy_attn:
        # Pick the output activation for the generator head.
        if model_opt.generator_function == "sparsemax":
            gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
        elif model_opt.generator_function == "sigmoid":
            gen_func = nn.Sigmoid()
        else:
            gen_func = nn.LogSoftmax(dim=-1)
        if model_opt.model_type == "vector":
            if model_opt.generator_function == "none":
                # if model_opt.final_vec_size != model_opt.dec_rnn_size:
                #     generator = nn.Sequential(
                #         nn.Linear(model_opt.dec_rnn_size, model_opt.final_vec_size))
                # else:
                # No generator head: the decoder output is used as-is.
                generator = None
            else:
                # Project decoder states to the final vector size.
                generator = nn.Sequential(
                    nn.Linear(model_opt.dec_rnn_size,
                              model_opt.final_vec_size), gen_func)
        else:
            # Project decoder states to target-vocabulary logits.
            generator = nn.Sequential(
                nn.Linear(model_opt.dec_rnn_size, len(fields["tgt"].vocab)),
                gen_func)
        if model_opt.share_decoder_embeddings:
            # Tie the output projection to the decoder embedding weights.
            generator[0].weight = decoder.embeddings.word_lut.weight
    else:
        generator = CopyGenerator(model_opt.dec_rnn_size, fields["tgt"].vocab)

    # Load the model states from checkpoint or initialize them.
    if checkpoint is not None:
        # strict=False tolerates missing/unexpected keys in the state dict.
        model.load_state_dict(checkpoint['model'], strict=False)
        if generator is not None:
            generator.load_state_dict(checkpoint['generator'], strict=False)
    else:
        if model_opt.param_init != 0.0:
            # Uniform init in [-param_init, param_init] for all parameters.
            for p in model.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
            if generator is not None:
                for p in generator.parameters():
                    p.data.uniform_(-model_opt.param_init,
                                    model_opt.param_init)
        if model_opt.param_init_glorot:
            # Xavier init only for matrices (dim > 1); biases are skipped.
            for p in model.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
            if generator is not None:
                for p in generator.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        # Pre-trained word vectors do not apply to the "vector" model type.
        if hasattr(model.encoder,
                   'embeddings') and model_opt.model_type != "vector":
            model.encoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_enc, model_opt.fix_word_vecs_enc)
        if hasattr(model.decoder,
                   'embeddings') and model_opt.model_type != "vector":
            model.decoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_dec, model_opt.fix_word_vecs_dec)

    # Add generator to model (this registers it as parameter of model).
    model.generator = generator
    model.to(device)

    return model