示例#1
0
def train(model_name, restore=True):
    """Build a PHVM model for ``model_name`` and run the training loop on it.

    Side effects: binds the module-level ``config`` (to ``Config.config``)
    and ``logger`` (to ``utils.get_logger(model_name)``), prepares the dataset,
    calls ``config.check_ckpt(model_name)`` and opens a
    ``tf.summary.FileWriter`` on ``config.summary_dir``.

    :param model_name: name used for the logger, the checkpoint sub-directory
        ``<config.checkpoint_dir>/<model_name>`` and the final log message.
    :param restore: when true, ``model_utils.restore_model`` is called with the
        temporary-model directory followed by the best-model directory; its
        result supplies the starting epoch, the ``worse_step`` counter and the
        (restored) model. When false, training starts from epoch 0 with
        ``worse_step`` 0.
    """
    global config, logger
    import_lib()
    config = Config.config

    dataset = Dataset.Dataset()
    dataset.prepare_dataset()
    logger = utils.get_logger(model_name)

    vocab = dataset.vocab
    model = PHVM.PHVM(
        len(vocab.id2featCate),
        len(vocab.id2featVal),
        len(vocab.id2word),
        len(vocab.id2category),
        key_wordvec=None,
        val_wordvec=None,
        tgt_wordvec=vocab.id2vec,
        type_vocab_size=len(vocab.id2type),
    )

    # Starting point handed to _train(); overwritten when resuming.
    start_state = {'epoch': 0, 'worse_step': 0}
    if restore:
        ckpt_root = config.checkpoint_dir + "/" + model_name
        # The temporary directory is passed first here (main() passes the best one first).
        epoch, worse_step, model = model_utils.restore_model(
            model,
            ckpt_root + config.tmp_model_dir,
            ckpt_root + config.best_model_dir)
        start_state.update(epoch=epoch, worse_step=worse_step)

    config.check_ckpt(model_name)
    writer = tf.summary.FileWriter(config.summary_dir, model.graph)
    _train(model_name, model, dataset, writer, start_state)
    logger.info("finish training {}".format(model_name))
示例#2
0
def main():
    """Command-line entry point: train a model or run inference on the test split.

    When ``args.train`` is truthy this delegates to ``train(args.model_name,
    args.restore)``. Otherwise it builds a PHVM model and restores its weights,
    passing the best-model directory first and the temporary directory second.
    It then runs ``infer`` on ``dataset.test`` and dumps the generated texts to
    ``<config.result_dir>/<args.model_name>.json``.
    """
    global config
    args = get_args()

    if args.train:
        train(args.model_name, args.restore)
        return

    import_lib()
    # Bind the shared module-level ``config`` the same way train() does.
    # Nothing visible on this path assigns it otherwise; NOTE(review):
    # import_lib's body is not shown here, so confirm it does not already.
    config = Config.config

    dataset = Dataset.Dataset()
    model = PHVM.PHVM(len(dataset.vocab.id2featCate),
                      len(dataset.vocab.id2featVal),
                      len(dataset.vocab.id2word),
                      len(dataset.vocab.id2category),
                      key_wordvec=None,
                      val_wordvec=None,
                      tgt_wordvec=dataset.vocab.id2vec,
                      type_vocab_size=len(dataset.vocab.id2type))

    ckpt_root = config.checkpoint_dir + "/" + args.model_name
    best_checkpoint_dir = ckpt_root + config.best_model_dir
    tmp_checkpoint_dir = ckpt_root + config.tmp_model_dir
    # Best-model directory first for inference (train() passes the temporary one first).
    model_utils.restore_model(model, best_checkpoint_dir, tmp_checkpoint_dir)

    dataset.prepare_dataset()
    texts = infer(model, dataset, dataset.test)
    dump(texts, config.result_dir + "/{}.json".format(args.model_name))
    utils.print_out("finish file test")
def gen_sen(make_test, hide_strategy, STRATEGY, gen_model=None):
    """Generate sentences for ``make_test`` while embedding bits from ``config.bit_file``.

    Builds a PHVM model (forwarding ``hide_strategy`` and ``STRATEGY``) and
    writes ``make_test`` out via ``to_testfile``. It restores the checkpoint
    under ``<config.checkpoint_dir>/PHVM``, passing the best-model directory
    first and the temporary directory second. It then prepares the dataset and
    runs ``infer`` on ``dataset.test`` with the secret bit stream.

    :param make_test: test input handed to ``to_testfile``. The original note
        describes it as manually entered pairs (2-tuples); exact format not
        visible here, so TODO confirm.
    :param hide_strategy: hiding strategy, forwarded to ``PHVM.PHVM``.
        - ``'AC'``, or any strategy when ``config.bit_per_word == 1``: uses the
          raw bit list.
        - ``'RS'``: packs each ``config.bit_per_word`` consecutive bits into
          one value via ``bits2int``.
        - Anything else: an empty bit stream.
    :param STRATEGY: forwarded unchanged to the ``PHVM.PHVM`` constructor.
    :param gen_model: name intended to distinguish runs. It is currently unused
        (it only named a result dump that is disabled) and is kept for
        interface compatibility.
    :return: ``(texts_id, bit_len)`` as produced by ``infer``.
    """
    import_lib()
    from Models.model_utils import to_testfile, load_data
    dataset = Dataset.EPWDataset()
    model = PHVM.PHVM(len(dataset.vocab.id2featCate), len(dataset.vocab.id2featVal), len(dataset.vocab.id2word),
                      len(dataset.vocab.id2category), hide_strategy=hide_strategy, STRATEGY=STRATEGY,
                      key_wordvec=None, val_wordvec=None, tgt_wordvec=dataset.vocab.id2vec,
                      type_vocab_size=len(dataset.vocab.id2type))
    config = Config.config
    # NOTE(review): presumably prepare_dataset() below reads the file written
    # here; keep this call ahead of it.
    to_testfile(make_test)
    best_checkpoint_dir = config.checkpoint_dir + "/PHVM" + config.best_model_dir
    tmp_checkpoint_dir = config.checkpoint_dir + "/PHVM" + config.tmp_model_dir
    # Per the original note: when a best model exists, it is the one used.
    model_utils.restore_model(model, best_checkpoint_dir, tmp_checkpoint_dir)

    dataset.prepare_dataset()

    # Repeat the secret bits 16 times, as the former four in-place
    # ``extend(self)`` doublings did. This builds a new list, so the object
    # returned by load_data is not mutated.
    raw_bits = list(load_data(config.bit_file)) * 16

    bit_per_word = config.bit_per_word
    if bit_per_word == 1 or hide_strategy == 'AC':
        bitfile = raw_bits
    elif hide_strategy == 'RS':
        # Pack each full group of bit_per_word bits into one value; a trailing
        # partial group is dropped.
        n_words = len(raw_bits) // bit_per_word
        bitfile = [bits2int(raw_bits[i * bit_per_word:(i + 1) * bit_per_word])
                   for i in range(n_words)]
    else:
        # NOTE(review): unknown strategies silently get an empty bit stream;
        # confirm infer() is meant to handle that.
        bitfile = []

    _texts, texts_id, bit_len = infer(model, dataset, dataset.test, bitfile)
    return texts_id, bit_len