Example #1
0
def make_base_model(model_opt, fields, gpu, checkpoint=None):
    """
    Build an NMTModel (encoder + decoder + generator) from options.

    Args:
        model_opt: the option loaded from checkpoint.
        fields: `Field` objects for the model.
        gpu(bool): whether to use gpu.
        checkpoint: the model generated by train phase, or a resumed snapshot
                    model from a stopped training.
    Returns:
        the NMTModel.
    """
    assert model_opt.model_type in ["text", "img", "audio"], \
        ("Unsupported model type %s" % (model_opt.model_type))

    # Make encoder.
    if model_opt.model_type == "text":
        src_dict = fields["src"].vocab
        feature_dicts = onmt.io.collect_feature_vocabs(fields, 'src')
        src_embeddings = make_embeddings(model_opt, src_dict, feature_dicts)
        encoder = make_encoder(model_opt, src_embeddings)
    elif model_opt.model_type == "img":
        encoder = ImageEncoder(model_opt.enc_layers, model_opt.brnn,
                               model_opt.rnn_size, model_opt.dropout)
    elif model_opt.model_type == "audio":
        encoder = AudioEncoder(model_opt.enc_layers, model_opt.brnn,
                               model_opt.rnn_size, model_opt.dropout,
                               model_opt.sample_rate, model_opt.window_size)

    # Make decoder.
    tgt_dict = fields["tgt"].vocab
    feature_dicts = onmt.io.collect_feature_vocabs(fields, 'tgt')
    tgt_embeddings = make_embeddings(model_opt,
                                     tgt_dict,
                                     feature_dicts,
                                     for_encoder=False)

    # Share the embedding matrix - preprocess with share_vocab required.
    if model_opt.share_embeddings:
        # src/tgt vocab should be the same if `-share_vocab` is specified.
        if src_dict != tgt_dict:
            raise AssertionError('The `-share_vocab` should be set during '
                                 'preprocess if you use share_embeddings!')

        tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight

    decoder = make_decoder(model_opt, tgt_embeddings)

    # Make NMTModel(= encoder + decoder).
    model = NMTModel(encoder, decoder)
    model.model_type = model_opt.model_type

    # Make Generator.
    if not model_opt.copy_attn:
        generator = nn.Sequential(
            nn.Linear(model_opt.rnn_size, len(fields["tgt"].vocab)),
            nn.LogSoftmax(dim=-1))
        if model_opt.share_decoder_embeddings:
            # Tie the output projection to the decoder embedding matrix.
            generator[0].weight = decoder.embeddings.word_lut.weight
    else:
        generator = CopyGenerator(model_opt.rnn_size, fields["tgt"].vocab)

    # Load the model states from checkpoint or initialize them.
    if checkpoint is not None:
        print('Loading model parameters.')

        state = model.state_dict()
        if state.keys() == checkpoint["model"].keys():
            # Exact key match: plain strict load.
            model.load_state_dict(checkpoint["model"])
        else:
            # Partial load: keep current values for keys absent from the
            # checkpoint, then re-initialize the embedding parameters.
            state.update(checkpoint["model"])
            model.load_state_dict(state)
            for name, param in model.named_parameters():
                if "embedding" in name:
                    if model_opt.param_init != 0.0:
                        param.data.uniform_(-model_opt.param_init,
                                            model_opt.param_init)
                    # BUG FIX: originally called xavier_uniform(p), but `p`
                    # is undefined here — the loop variable is `param`.
                    if model_opt.param_init_glorot and param.dim() > 1:
                        xavier_uniform(param)

        state = generator.state_dict()
        if state.keys() == checkpoint["generator"].keys():
            generator.load_state_dict(checkpoint["generator"])
        else:
            # Same partial-load strategy for the generator; re-initialize
            # the linear projection parameters afterwards.
            state.update(checkpoint["generator"])
            generator.load_state_dict(state)
            for name, param in generator.named_parameters():
                if "linear.bias" in name or "linear.weight" in name:
                    if model_opt.param_init != 0.0:
                        param.data.uniform_(-model_opt.param_init,
                                            model_opt.param_init)
                    # BUG FIX: was xavier_uniform(p) — NameError (see above).
                    if model_opt.param_init_glorot and param.dim() > 1:
                        xavier_uniform(param)
    else:
        # Fresh model: uniform and/or Glorot initialization per options.
        if model_opt.param_init != 0.0:
            print('Initializing model parameters.')
            for p in model.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
            for p in generator.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
        if model_opt.param_init_glorot:
            for p in model.parameters():
                if p.dim() > 1:
                    xavier_uniform(p)
            for p in generator.parameters():
                if p.dim() > 1:
                    xavier_uniform(p)

        # Optionally load pre-trained word vectors into the embeddings.
        if hasattr(model.encoder, 'embeddings'):
            model.encoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_enc, model_opt.fix_word_vecs_enc)
        if hasattr(model.decoder, 'embeddings'):
            model.decoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_dec, model_opt.fix_word_vecs_dec)

    # Add generator to model (this registers it as parameter of model).
    model.generator = generator

    # Make the whole model leverage GPU if indicated to do so.
    if gpu:
        model.cuda()
    else:
        model.cpu()

    return model
def make_base_model(model_opt,
                    fields,
                    gpu,
                    checkpoint=None,
                    init_encoder=False,
                    rev_checkpoint=None,
                    top_layer=100):
    """
    Build an NMTModel (encoder + decoder + generator) from options.

    Args:
        model_opt: the option loaded from checkpoint.
        fields: `Field` objects for the model.
        gpu(bool): whether to use gpu.
        checkpoint: the model generated by train phase, or a resumed snapshot
                    model from a stopped training.
        init_encoder(bool): if True, only initialize (and freeze) encoder
                            parameters from `checkpoint` instead of loading
                            the full model state.
        rev_checkpoint: currently unused; kept for interface compatibility.
        top_layer(int): highest encoder layer index to copy from the
                        checkpoint when `init_encoder` is set.
    Returns:
        the NMTModel.
    """
    assert model_opt.model_type in ["text", "img", "audio"], \
        ("Unsupported model type %s" % (model_opt.model_type))

    # Make encoder.
    if model_opt.model_type == "text":
        src_dict = fields["src"].vocab
        feature_dicts = onmt.io.collect_feature_vocabs(fields, 'src')
        src_embeddings = make_embeddings(model_opt, src_dict, feature_dicts)
        encoder = make_encoder(model_opt, src_embeddings)
    elif model_opt.model_type == "img":
        encoder = ImageEncoder(model_opt.enc_layers, model_opt.brnn,
                               model_opt.rnn_size, model_opt.dropout)
    elif model_opt.model_type == "audio":
        encoder = AudioEncoder(model_opt.enc_layers, model_opt.brnn,
                               model_opt.rnn_size, model_opt.dropout,
                               model_opt.sample_rate, model_opt.window_size)

    # Make decoder.
    tgt_dict = fields["tgt"].vocab
    feature_dicts = onmt.io.collect_feature_vocabs(fields, 'tgt')
    tgt_embeddings = make_embeddings(model_opt,
                                     tgt_dict,
                                     feature_dicts,
                                     for_encoder=False)

    # Share the embedding matrix - preprocess with share_vocab required.
    if model_opt.share_embeddings:
        # src/tgt vocab should be the same if `-share_vocab` is specified.
        if src_dict != tgt_dict:
            raise AssertionError('The `-share_vocab` should be set during '
                                 'preprocess if you use share_embeddings!')

        tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight

    decoder = make_decoder(model_opt, tgt_embeddings)
    if model_opt.share_rnn:
        if model_opt.input_feed == 1:
            raise AssertionError('Cannot share encoder and decoder weights'
                                 'when using input feed in decoder')
        # BUG FIX: originally compared src_word_vec_size against itself
        # (always equal), so mismatched embedding sizes were never caught.
        if model_opt.src_word_vec_size != model_opt.tgt_word_vec_size:
            raise AssertionError('Cannot share encoder and decoder weights'
                                 'if embeddings are different sizes')
        encoder.rnn = decoder.rnn

    # Make NMTModel(= encoder + decoder).
    model = NMTModel(encoder, decoder)
    model.model_type = model_opt.model_type

    # Make Generator.
    if not model_opt.copy_attn:
        # dim=-1 added for consistency with the other builder and to avoid
        # the deprecated implicit-dim LogSoftmax behavior.
        generator = nn.Sequential(
            nn.Linear(model_opt.rnn_size, len(fields["tgt"].vocab)),
            nn.LogSoftmax(dim=-1))
        if model_opt.share_decoder_embeddings:
            generator[0].weight = decoder.embeddings.word_lut.weight
    else:
        generator = CopyGenerator(model_opt.rnn_size, fields["tgt"].vocab)

    # Load the model states from checkpoint or initialize them.
    if checkpoint is not None and not init_encoder:
        print('Loading model parameters from checkpoint.')
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
    else:
        if model_opt.param_init != 0.0:
            print('Initializing model parameters.')
            for p in model.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
            for p in generator.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
        # Optionally load pre-trained word vectors into the embeddings.
        if hasattr(model.encoder, 'embeddings'):
            model.encoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_enc, model_opt.fix_word_vecs_enc)
        if hasattr(model.decoder, 'embeddings'):
            model.decoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_dec, model_opt.fix_word_vecs_dec)
        if init_encoder:
            # Copy only encoder parameters (up to `top_layer`) from the
            # checkpoint, then freeze them.
            model_dict = checkpoint['model']

            model_dict_keys = []
            for key in model_dict.keys():
                if key[:7] == 'encoder':
                    # NOTE(review): layer index is parsed from a fixed
                    # position in the key name — fragile for >9 layers.
                    if key[-7:] == 'reverse':
                        if int(key[-9]) > top_layer:
                            continue
                    else:
                        if key[8:18] != 'embeddings' and int(
                                key[-1]) > top_layer:
                            continue
                    model_dict_keys.append(key)
            print(model_dict_keys)

            # Load encoder parameters into a copy of the current state.
            new_model_dict = model.state_dict()
            for key, value in model_dict.items():
                if key in model_dict_keys:
                    new_model_dict[key] = value
            model.load_state_dict(new_model_dict)

            # Freeze the copied encoder parameters.
            for name, param in model.named_parameters():
                if name in model_dict_keys:
                    param.requires_grad = False

    # Add generator to model (this registers it as parameter of model).
    model.generator = generator

    # Make the whole model leverage GPU if indicated to do so.
    if gpu:
        model.cuda()
    else:
        model.cpu()

    return model