示例#1
0
def train_generator(
        model_save_suffix=model_save_suffixes["[no_prefix]train_generator"]):
    """Train the generator graph (``stepA``) starting from pretrained checkpoints.

    The model class is selected by ``flags["adv_type"]``:
    ``"adv"`` -> ``AdversarialDDGModel`` and ``"vir_adv"`` ->
    ``VirtualAdversarialDDGModel``. Either one is created with its
    ``stepA_modules``, built with ``stepA=True`` and then fitted.

    Args:
        model_save_suffix: Path, relative to ``flags.save_model_dir``, where
            the trained model is saved. The default is looked up in
            ``model_save_suffixes`` only once, when this function is defined.

    Raises:
        AssertionError: If ``flags.pretrain_model_dir`` is empty or unset.
            This check only runs when assertions are enabled (not under ``-O``).
        ValueError: If ``flags["adv_type"]`` is neither ``"adv"`` nor
            ``"vir_adv"``. ``ValueError`` subclasses ``Exception``, so handlers
            written for the previous plain ``Exception`` still catch it.
    """
    assert flags.pretrain_model_dir, "pretrain_model_dir is required"
    save_model_path = osp.join(flags.save_model_dir, model_save_suffix)
    # Maps module tag -> checkpoint path. It is handed to fit() as
    # ``pretrain_model_pathes``, presumably to warm-start each tagged module.
    # NOTE(review): "[no_prefix]"-keyed and plain-keyed suffixes presumably
    # resolve to different checkpoint locations; confirm against
    # model_save_suffixes.
    pretrained_model_pathes = {
        "EMBEDDING":
        osp.join(flags.pretrain_model_dir,
                 model_save_suffixes["[no_prefix]pre_train_cl_model"]),
        "FG_S":
        osp.join(flags.pretrain_model_dir,
                 model_save_suffixes["[no_prefix]train_lm_model"]),
        "SEQ_G_LSTM_1":
        osp.join(flags.pretrain_model_dir,
                 model_save_suffixes["train_lm_model"]),
        "SEQ_G_LSTM_2":
        osp.join(flags.pretrain_model_dir,
                 model_save_suffixes["train_ae_model"]),
        "RNN_TO_EMBEDDING":
        osp.join(flags.pretrain_model_dir,
                 model_save_suffixes["train_lm_model"]),
    }
    adv_type = flags["adv_type"]
    # An explicit if/elif is used rather than a lookup table, so that
    # VirtualAdversarialDDGModel is only resolved when it is actually requested.
    if adv_type == "adv":
        generator_model = AdversarialDDGModel(
            init_modules=AdversarialDDGModel.stepA_modules)
    elif adv_type == "vir_adv":
        generator_model = VirtualAdversarialDDGModel(
            init_modules=VirtualAdversarialDDGModel.stepA_modules)
    else:
        raise ValueError("Unknown adv_type: %s" % adv_type)
    # NOTE(review): restorer_tag_notifier presumably marks "EMBEDDING" for
    # special handling while restoring; the semantics live in build(), so
    # confirm there.
    generator_model.build(stepA=True, restorer_tag_notifier=["EMBEDDING"])
    generator_model.fit(save_model_path=save_model_path,
                        pretrain_model_pathes=pretrained_model_pathes)
示例#2
0
def eval_cl_model():
    """Evaluate a saved classifier checkpoint.

    ``flags.eval_from`` selects which checkpoint is used:

    * ``"final_cl"`` uses ``model_save_suffixes["train_cl_model"]``;
    * ``"pretrain_cl"`` uses ``model_save_suffixes["pre_train_cl_model"]``.

    The checkpoint path is resolved relative to ``flags.save_model_dir``, and
    the classifier is evaluated with an ``AdversarialDDGModel`` built for
    ``eval_cl``.
    """
    eval_from_vals = ["pretrain_cl", "final_cl"]
    source = flags.eval_from
    assert source in eval_from_vals, "eval_from must be one of %s" % eval_from_vals
    # Any value other than "final_cl" falls back to the pretrain checkpoint.
    ckpt_suffix = model_save_suffixes[
        "train_cl_model" if source == "final_cl" else "pre_train_cl_model"]
    checkpoint_path = osp.join(flags.save_model_dir, ckpt_suffix)

    classifier = AdversarialDDGModel(
        init_modules=AdversarialDDGModel.eval_cl_modules)
    classifier.build(eval_cl=True)
    classifier.eval(save_model_path=checkpoint_path)
示例#3
0
def eval_generator(eval_batch_size=flags["eval_batch_size"],
                   eval_topic_count=flags["eval_topic_count"],
                   eval_seq_length=flags["eval_seq_length"]):
    """Evaluate a trained generator by building its sequence-evaluation graph.

    ``flags.eval_from`` selects the checkpoint: ``"generator"`` uses
    ``model_save_suffixes["train_generator"]`` and ``"topic_generator"`` uses
    ``model_save_suffixes["train_topic_generator"]``. The path is resolved
    relative to ``flags.save_model_dir``.

    Args:
        eval_batch_size: Forwarded to ``build`` as ``batch_size``.
        eval_topic_count: Forwarded to ``build`` as ``topic_count``.
        eval_seq_length: Forwarded to ``build`` as ``seq_length``.

    Note:
        All three defaults are read from ``flags`` once, when this function
        is defined, not on every call.
    """
    eval_from_vals = ["generator", "topic_generator"]
    source = flags.eval_from
    assert source in eval_from_vals, "eval_from must be one of %s" % eval_from_vals
    if source != "generator":
        suffix_key = "train_topic_generator"
    else:
        suffix_key = "train_generator"
    ckpt_suffix = model_save_suffixes[suffix_key]
    checkpoint_path = osp.join(flags.save_model_dir, ckpt_suffix)

    evaluator = AdversarialDDGModel(
        init_modules=AdversarialDDGModel.eval_graph_modules)
    build_kwargs = dict(eval_seq=True,
                        batch_size=eval_batch_size,
                        topic_count=eval_topic_count,
                        seq_length=eval_seq_length)
    evaluator.build(**build_kwargs)
    evaluator.eval(save_model_path=checkpoint_path)
示例#4
0
def pre_train_cl_model(
        model_save_suffix=model_save_suffixes["pre_train_cl_model"]):
    """Pre-train the classifier graph (``stepB``) from a pretrained LM checkpoint.

    The model class is selected by ``flags["adv_type"]``:
    ``"adv"`` -> ``AdversarialDDGModel`` and ``"vir_adv"`` ->
    ``VirtualAdversarialDDGModel``. Either one is created with its
    ``stepB_modules``, built with ``stepB=True`` and then fitted.

    Args:
        model_save_suffix: Path, relative to ``flags.save_model_dir``, where
            the trained model is saved. The default is looked up in
            ``model_save_suffixes`` only once, when this function is defined.

    Raises:
        AssertionError: If ``flags.pretrain_model_dir`` is empty or unset.
            This check only runs when assertions are enabled (not under ``-O``).
        ValueError: If ``flags["adv_type"]`` is neither ``"adv"`` nor
            ``"vir_adv"``. ``ValueError`` subclasses ``Exception``, so handlers
            written for the previous plain ``Exception`` still catch it.
    """
    assert flags.pretrain_model_dir, "pretrain_model_dir is required"
    save_model_path = osp.join(flags.save_model_dir, model_save_suffix)
    # Maps module tag -> checkpoint path, handed to fit() as
    # ``pretrain_model_pathes``. Both tags point at the same language-model
    # checkpoint.
    pretrained_model_pathes = {
        "EMBEDDING":
        osp.join(flags.pretrain_model_dir,
                 model_save_suffixes["[no_prefix]train_lm_model"]),
        "T_S":
        osp.join(flags.pretrain_model_dir,
                 model_save_suffixes["[no_prefix]train_lm_model"])
    }
    adv_type = flags["adv_type"]
    # An explicit if/elif is used rather than a lookup table, so that
    # VirtualAdversarialDDGModel is only resolved when it is actually requested.
    if adv_type == "adv":
        adv_cl_model = AdversarialDDGModel(
            init_modules=AdversarialDDGModel.stepB_modules)
    elif adv_type == "vir_adv":
        adv_cl_model = VirtualAdversarialDDGModel(
            init_modules=VirtualAdversarialDDGModel.stepB_modules)
    else:
        raise ValueError("Unknown adv_type: %s" % adv_type)
    # NOTE(review): an empty restorer_tag_notifier presumably means no module
    # gets special restore handling; the semantics live in build(), so
    # confirm there.
    adv_cl_model.build(stepB=True, restorer_tag_notifier=[])
    adv_cl_model.fit(save_model_path=save_model_path,
                     pretrain_model_pathes=pretrained_model_pathes)
示例#5
0
import sys

sys.path.insert(0, ".")
from adversarial_net.AdversarialDDGModel import AdversarialDDGModel
from adversarial_net.preprocessing import WordCounter
from adversarial_net import osp
from adversarial_net import arguments as flags

if __name__ == "__main__":
    # Load the IMDB word-frequency table and keep what most_common_freqs
    # returns for `vocab_size` (presumably the top-`vocab_size` words).
    # The result is published on `flags` as "vocab_freqs".
    word_counter = WordCounter().load(
        osp.join(flags["inputs"]["datapath"], "imdb_word_freqs.pickle"))
    vocab_freqs = word_counter.most_common_freqs(
        flags["lm_sequence"]["vocab_size"])
    flags.add_variable(name="vocab_freqs", value=vocab_freqs)

    # Build the step-A graph and fit with fit()'s own defaults
    # (no save path or pretrained checkpoints are passed here).
    adv_model = AdversarialDDGModel(
        init_modules=AdversarialDDGModel.stepA_modules)
    adv_model.build(stepA=True)
    # Alternative evaluation run: build with
    #   build(eval_seq=True, batch_size=2, topic_count=2, seq_length=200)
    # and then call adv_model.eval(None) instead of fit().
    adv_model.fit()