def train_generator(
        model_save_suffix=model_save_suffixes["[no_prefix]train_generator"]):
    # Step A: train the generator, restoring pre-trained embedding / LM / AE weights.
    assert flags.pretrain_model_dir, "pretrain_model_dir is required"
    save_model_path = osp.join(flags.save_model_dir, model_save_suffix)
    pretrained_model_pathes = {
        "EMBEDDING": osp.join(flags.pretrain_model_dir,
                              model_save_suffixes["[no_prefix]pre_train_cl_model"]),
        "FG_S": osp.join(flags.pretrain_model_dir,
                         model_save_suffixes["[no_prefix]train_lm_model"]),
        "SEQ_G_LSTM_1": osp.join(flags.pretrain_model_dir,
                                 model_save_suffixes["train_lm_model"]),
        "SEQ_G_LSTM_2": osp.join(flags.pretrain_model_dir,
                                 model_save_suffixes["train_ae_model"]),
        "RNN_TO_EMBEDDING": osp.join(flags.pretrain_model_dir,
                                     model_save_suffixes["train_lm_model"]),
    }
    if flags["adv_type"] == "adv":
        generator_model = AdversarialDDGModel(
            init_modules=AdversarialDDGModel.stepA_modules)
    elif flags["adv_type"] == "vir_adv":
        generator_model = VirtualAdversarialDDGModel(
            init_modules=VirtualAdversarialDDGModel.stepA_modules)
    else:
        raise ValueError("Unknown adv_type: %s" % flags["adv_type"])
    generator_model.build(stepA=True, restorer_tag_notifier=["EMBEDDING"])
    generator_model.fit(save_model_path=save_model_path,
                        pretrain_model_pathes=pretrained_model_pathes)
def eval_cl_model():
    eval_from_vals = ["pretrain_cl", "final_cl"]
    assert flags.eval_from in eval_from_vals, \
        "eval_from must be one of %s" % eval_from_vals
    if flags.eval_from == "final_cl":
        model_save_suffix = model_save_suffixes["train_cl_model"]
    else:
        model_save_suffix = model_save_suffixes["pre_train_cl_model"]
    save_model_path = osp.join(flags.save_model_dir, model_save_suffix)
    generator_model = AdversarialDDGModel(
        init_modules=AdversarialDDGModel.eval_cl_modules)
    generator_model.build(eval_cl=True)
    generator_model.eval(save_model_path=save_model_path)
def eval_generator(eval_batch_size=flags["eval_batch_size"],
                   eval_topic_count=flags["eval_topic_count"],
                   eval_seq_length=flags["eval_seq_length"]):
    eval_from_vals = ["generator", "topic_generator"]
    assert flags.eval_from in eval_from_vals, \
        "eval_from must be one of %s" % eval_from_vals
    if flags.eval_from == "generator":
        model_save_suffix = model_save_suffixes["train_generator"]
    else:
        model_save_suffix = model_save_suffixes["train_topic_generator"]
    save_model_path = osp.join(flags.save_model_dir, model_save_suffix)
    generator_model = AdversarialDDGModel(
        init_modules=AdversarialDDGModel.eval_graph_modules)
    generator_model.build(eval_seq=True, batch_size=eval_batch_size,
                          topic_count=eval_topic_count,
                          seq_length=eval_seq_length)
    generator_model.eval(save_model_path=save_model_path)
def pre_train_cl_model(
        model_save_suffix=model_save_suffixes["pre_train_cl_model"]):
    # Step B: pre-train the classifier, restoring pre-trained embedding / LM weights.
    assert flags.pretrain_model_dir, "pretrain_model_dir is required"
    save_model_path = osp.join(flags.save_model_dir, model_save_suffix)
    pretrained_model_pathes = {
        "EMBEDDING": osp.join(flags.pretrain_model_dir,
                              model_save_suffixes["[no_prefix]train_lm_model"]),
        "T_S": osp.join(flags.pretrain_model_dir,
                        model_save_suffixes["[no_prefix]train_lm_model"])
    }
    if flags["adv_type"] == "adv":
        adv_cl_model = AdversarialDDGModel(
            init_modules=AdversarialDDGModel.stepB_modules)
    elif flags["adv_type"] == "vir_adv":
        adv_cl_model = VirtualAdversarialDDGModel(
            init_modules=VirtualAdversarialDDGModel.stepB_modules)
    else:
        raise ValueError("Unknown adv_type: %s" % flags["adv_type"])
    adv_cl_model.build(stepB=True, restorer_tag_notifier=[])
    adv_cl_model.fit(save_model_path=save_model_path,
                     pretrain_model_pathes=pretrained_model_pathes)
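# Note: every entry point above resolves its checkpoint file name through a
# `model_save_suffixes` mapping that is defined elsewhere in the project and is
# not part of this excerpt. A minimal sketch of the shape such a mapping might
# take is given below, kept commented out so it does not shadow the real
# definition; the keys are the ones referenced above, but the file-name values
# are placeholders rather than the project's actual suffixes.
#
# model_save_suffixes = {
#     "train_lm_model": "lm_model",
#     "train_ae_model": "ae_model",
#     "pre_train_cl_model": "pre_train_cl_model",
#     "train_cl_model": "cl_model",
#     "train_generator": "generator_model",
#     "train_topic_generator": "topic_generator_model",
#     "[no_prefix]train_lm_model": "lm_model",
#     "[no_prefix]pre_train_cl_model": "pre_train_cl_model",
#     "[no_prefix]train_generator": "generator_model",
# }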
import sys
sys.path.insert(0, ".")
from adversarial_net.AdversarialDDGModel import AdversarialDDGModel
from adversarial_net.preprocessing import WordCounter
from adversarial_net import osp
from adversarial_net import arguments as flags

if __name__ == "__main__":
    # Load the IMDB word-frequency table and register it as the vocab_freqs flag
    # before building the model graph.
    vocab_freqs = WordCounter().load(
        osp.join(flags["inputs"]["datapath"], "imdb_word_freqs.pickle")
    ).most_common_freqs(flags["lm_sequence"]["vocab_size"])
    flags.add_variable(name="vocab_freqs", value=vocab_freqs)
    adv_model = AdversarialDDGModel(
        init_modules=AdversarialDDGModel.stepA_modules)
    # adv_model.build(eval_seq=True, batch_size=2, topic_count=2, seq_length=200)
    adv_model.build(stepA=True)
    # adv_model.eval(None)
    adv_model.fit()
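# The commented-out calls above hint at the evaluation path of this demo. A
# hedged sketch of that path, assuming a generator checkpoint already saved by
# train_generator() and assuming model_save_suffixes is available in this
# script; the batch_size / topic_count / seq_length values are illustrative only.
#
# eval_model = AdversarialDDGModel(
#     init_modules=AdversarialDDGModel.eval_graph_modules)
# eval_model.build(eval_seq=True, batch_size=2, topic_count=2, seq_length=200)
# eval_model.eval(save_model_path=osp.join(
#     flags.save_model_dir, model_save_suffixes["train_generator"]))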