def train(model_name, restore=True):
    """Build the PHVM model for ``model_name``, optionally resume from a
    checkpoint, and run the training loop.

    :param model_name: identifier used for checkpoint dirs, logger, summaries.
    :param restore: when True, resume the epoch / worse-step counters and the
        model weights via ``model_utils.restore_model``.
    """
    import_lib()
    global config, logger
    config = Config.config

    ds = Dataset.Dataset()
    ds.prepare_dataset()
    logger = utils.get_logger(model_name)

    # Vocabulary sizes drive every embedding table of the model.
    vocab = ds.vocab
    net = PHVM.PHVM(len(vocab.id2featCate),
                    len(vocab.id2featVal),
                    len(vocab.id2word),
                    len(vocab.id2category),
                    key_wordvec=None,
                    val_wordvec=None,
                    tgt_wordvec=vocab.id2vec,
                    type_vocab_size=len(vocab.id2type))

    state = {'epoch': 0, 'worse_step': 0}
    if restore:
        # NOTE(review): here the tmp dir is passed before the best dir,
        # while main()/gen_sen() pass best first — confirm against
        # model_utils.restore_model's signature.
        tmp_dir = config.checkpoint_dir + "/" + model_name + config.tmp_model_dir
        best_dir = config.checkpoint_dir + "/" + model_name + config.best_model_dir
        state['epoch'], state['worse_step'], net = model_utils.restore_model(
            net, tmp_dir, best_dir)
    config.check_ckpt(model_name)

    writer = tf.summary.FileWriter(config.summary_dir, net.graph)
    _train(model_name, net, ds, writer, state)
    logger.info("finish training {}".format(model_name))
def main():
    """Entry point: either train a model or restore the best checkpoint and
    run inference over the test split, depending on CLI flags.

    Side effects: writes generated texts to ``config.result_dir``.
    """
    args = get_args()
    if args.train:
        train(args.model_name, args.restore)
        return

    import_lib()
    # BUG FIX: `config` was read below (checkpoint_dir / result_dir) without
    # ever being bound on this branch — only train() assigns the global, and
    # train() never runs on this path. Bind it the same way train()/gen_sen() do.
    global config
    config = Config.config

    dataset = Dataset.Dataset()
    model = PHVM.PHVM(len(dataset.vocab.id2featCate),
                      len(dataset.vocab.id2featVal),
                      len(dataset.vocab.id2word),
                      len(dataset.vocab.id2category),
                      key_wordvec=None,
                      val_wordvec=None,
                      tgt_wordvec=dataset.vocab.id2vec,
                      type_vocab_size=len(dataset.vocab.id2type))

    best_checkpoint_dir = config.checkpoint_dir + "/" + args.model_name + config.best_model_dir
    tmp_checkpoint_dir = config.checkpoint_dir + "/" + args.model_name + config.tmp_model_dir
    model_utils.restore_model(model, best_checkpoint_dir, tmp_checkpoint_dir)

    dataset.prepare_dataset()
    texts = infer(model, dataset, dataset.test)
    dump(texts, config.result_dir + "/{}.json".format(args.model_name))
    utils.print_out("finish file test")
def gen_sen(make_test, hide_strategy, STRATEGY, gen_model=None):
    """Generate sentences that embed a bitstream using the chosen hiding
    strategy.

    :param make_test: manually supplied topic tuples, written out as the test
        file via ``to_testfile``.
    :param hide_strategy: ``'AC'`` keeps raw bits; ``'RS'`` packs groups of
        ``config.bit_per_word`` bits into integers via ``bits2int``.
    :param STRATEGY: forwarded verbatim to the PHVM model.
    :param gen_model: unused; kept for backward compatibility with callers.
    :return: ``(texts_id, bit_len)`` as produced by ``infer``.
    """
    import_lib()
    from Models.model_utils import to_testfile, load_data

    dataset = Dataset.EPWDataset()
    model = PHVM.PHVM(len(dataset.vocab.id2featCate),
                      len(dataset.vocab.id2featVal),
                      len(dataset.vocab.id2word),
                      len(dataset.vocab.id2category),
                      hide_strategy=hide_strategy,
                      STRATEGY=STRATEGY,
                      key_wordvec=None,
                      val_wordvec=None,
                      tgt_wordvec=dataset.vocab.id2vec,
                      type_vocab_size=len(dataset.vocab.id2type))

    config = Config.config
    to_testfile(make_test)

    # The best checkpoint is preferred over the tmp checkpoint when it exists.
    best_checkpoint_dir = config.checkpoint_dir + "/PHVM" + config.best_model_dir
    tmp_checkpoint_dir = config.checkpoint_dir + "/PHVM" + config.tmp_model_dir
    model_utils.restore_model(model, best_checkpoint_dir, tmp_checkpoint_dir)
    dataset.prepare_dataset()

    # Replicate the bitstream 16x so there are always enough bits to embed.
    # (Originally four successive self-extends, each doubling the list — this
    # also avoids mutating the list returned by load_data in place.)
    _bitfile = load_data(config.bit_file) * 16

    bpw = config.bit_per_word
    if bpw == 1 or hide_strategy == 'AC':
        bitfile = _bitfile
    elif hide_strategy == 'RS':
        # Pack each complete group of `bpw` bits into a single integer;
        # trailing bits that do not fill a group are dropped (as before).
        usable = (len(_bitfile) // bpw) * bpw
        bitfile = [bits2int(_bitfile[lo:lo + bpw]) for lo in range(0, usable, bpw)]
    else:
        bitfile = []

    texts, texts_id, bit_len = infer(model, dataset, dataset.test, bitfile)
    return texts_id, bit_len