Example #1
# All three examples assume `os` and `torch` are imported, and that `Embed`,
# the padding symbols (`paddingkey` / `PAD`), and the pickle helper `pcl`
# come from the surrounding project's utility modules.
def pre_embed(config, alphabet):
    """
    Load the pre-trained word embedding specified in the config.
    :param config: configuration with the pretrained-embedding flags
    :param alphabet: alphabet holding the word vocabulary
    :return: the loaded embedding, or None when pretraining is disabled
    """
    print("............................")
    pretrain_embed = None
    embed_types = ""
    # Pick the initialization scheme for words missing from the embedding file.
    if config.pretrained_embed and config.zeros:
        embed_types = "zeros"
    elif config.pretrained_embed and config.avg:
        embed_types = "avg"
    elif config.pretrained_embed and config.uniform:
        embed_types = "uniform"
    elif config.pretrained_embed and config.nnembed:
        embed_types = "nn"
    if config.pretrained_embed is True:
        p = Embed(path=config.pretrained_embed_file,
                  words_dict=alphabet.word_alphabet.id2words,
                  embed_type=embed_types,
                  pad=paddingkey)
        pretrain_embed = p.get_embed()

        # Cache the embedding so later runs can reload it instead of re-parsing the file.
        embed_dict = {"pretrain_embed": pretrain_embed}
        # pcl.save(obj=embed_dict, path=os.path.join(config.pkl_directory, config.pkl_embed))
        torch.save(obj=embed_dict,
                   f=os.path.join(config.pkl_directory, config.pkl_embed))

    return pretrain_embed
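
A minimal sketch of how this function might be called. `SimpleNamespace` stands in for the project's real config class; the attribute names mirror those read inside `pre_embed`, while the file paths and the `alphabet` object are hypothetical placeholders.

from types import SimpleNamespace

# Hypothetical config; attribute names mirror those read by pre_embed above.
config = SimpleNamespace(
    pretrained_embed=True,
    zeros=True, avg=False, uniform=False, nnembed=False,
    pretrained_embed_file="data/word_embed.txt",  # assumed path
    pkl_directory="pkl", pkl_embed="embed.pkl",
)
pretrain_embed = pre_embed(config, alphabet)  # alphabet is built elsewhere by the project
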
Example #2
def pre_embed(config, alphabet):
    """
    Load the pre-trained word embedding for the extended word vocabulary.
    :param config: configuration with the pretrained-embedding flags
    :param alphabet: alphabet dict holding the extended word vocabulary
    :return: the pre-trained embedding, or None when disabled
    """
    print("***************************************")
    pretrain_embed = None
    embed_types = ""
    if config.pretrained_embed and config.zeros:
        embed_types = "zero"
    elif config.pretrained_embed and config.avg:
        embed_types = "avg"
    elif config.pretrained_embed and config.uniform:
        embed_types = "uniform"
    elif config.pretrained_embed and config.nnembed:
        embed_types = "nn"
    if config.pretrained_embed is True:
        p = Embed(path=config.pretrained_embed_file,
                  words_dict=alphabet.ext_word_alphabet.id2words,
                  embed_type=embed_types,
                  pad=PAD)
        pretrain_embed = p.get_embed()

        embed_dict = {"pretrain_embed": pretrain_embed}
        torch.save(obj=embed_dict, f=os.path.join(config.pkl_directory, config.pkl_embed))

    return pretrain_embed
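
Both variants above cache the loaded embedding with `torch.save`, so a later run can skip the slow text-file parse by reloading the pickle. A sketch, assuming the same `config` attributes as before:

import os
import torch

embed_path = os.path.join(config.pkl_directory, config.pkl_embed)
if os.path.exists(embed_path):
    # Reload the dict written by pre_embed and pull out the embedding tensor.
    pretrain_embed = torch.load(embed_path)["pretrain_embed"]
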
Example #3
def pre_embed(config, alphabet, alphabet_static):
    """
    Load pre-trained character and bi-character (bigram) embeddings.
    :param config: configuration with the pretrained-embedding flags
    :param alphabet: alphabet dict (unused here; kept for a uniform signature)
    :param alphabet_static: static alphabets providing the char/bichar vocabularies
    :return: the char and bichar pre-trained embeddings (either may be None)
    """
    print("***************************************")
    char_pretrain_embed, bichar_pretrain_embed = None, None
    embed_types = ""
    use_pretrained = config.char_pretrained_embed is True or config.bichar_pretrained_embed is True
    # Pick the initialization scheme for words missing from the embedding file.
    if use_pretrained and config.zeros:
        embed_types = "zero"
    elif use_pretrained and config.avg:
        embed_types = "avg"
    elif use_pretrained and config.uniform:
        embed_types = "uniform"
    elif use_pretrained and config.nnembed:
        embed_types = "nn"
    if config.char_pretrained_embed is True:
        # Load the character (unigram) embedding.
        p = Embed(path=config.char_pretrained_embed_file, words_dict=alphabet_static.char_alphabet.id2words,
                  embed_type=embed_types,
                  pad=paddingkey)
        char_pretrain_embed = p.get_embed()

    if config.bichar_pretrained_embed is True:
        # Load the bi-character (bigram) embedding.
        p = Embed(path=config.bichar_pretrained_embed_file, words_dict=alphabet_static.bichar_alphabet.id2words,
                  embed_type=embed_types,
                  pad=paddingkey)
        bichar_pretrain_embed = p.get_embed()

    if use_pretrained:
        embed_dict = {"char_pretrain_embed": char_pretrain_embed, "bichar_pretrain_embed": bichar_pretrain_embed}
        if config.save_pkl is True:
            # Cache both embeddings with the project's pickle helper (pcl).
            pcl.save(obj=embed_dict, path=os.path.join(config.pkl_directory, config.pkl_embed))

    return char_pretrain_embed, bichar_pretrain_embed
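
The returned tensors are typically copied into embedding layers. A sketch of that step, assuming `get_embed()` yields a float array of shape [vocab_size, dim]; the layer names here are illustrative, not part of the original project:

import torch
import torch.nn as nn

char_embed, bichar_embed = pre_embed(config, alphabet, alphabet_static)
if char_embed is not None:
    # Initialize an embedding layer from the pre-trained weights; freeze=False
    # lets the vectors keep training with the rest of the model.
    char_embedding = nn.Embedding.from_pretrained(torch.as_tensor(char_embed), freeze=False)
if bichar_embed is not None:
    bichar_embedding = nn.Embedding.from_pretrained(torch.as_tensor(bichar_embed), freeze=False)
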