コード例 #1
0
ファイル: main.py プロジェクト: z123z123d/contk
def main(args):
    """Build a single-turn dialog model and run training or testing.

    Args:
        args: Run configuration. The attributes read directly here are
            ``debug``, ``cuda``, ``dataset``, ``wvclass``, ``cache``,
            ``datapath``, ``cache_dir``, ``wvpath``, ``embedding_size`` and
            ``mode``; the whole object is also forwarded to ``create_model``
            and to the model's ``train_process`` / ``test_process``.

    When ``args.mode == "train"`` the model's ``train_process`` is run; any
    other value of ``args.mode`` runs ``test_process``.
    """
    if args.debug:
        debug()

    # Session configuration: let GPU memory grow on demand when CUDA is
    # enabled; otherwise request zero GPUs and hide all CUDA devices.
    if args.cuda:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = SingleTurnDialog.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    # Compare with None by identity (PEP 8); fall back to Glove when no
    # word-vector class was resolved from args.wvclass.
    if wordvec_class is None:
        wordvec_class = Glove

    if args.cache:
        # NOTE(review): try_cache presumably memoizes fn(*arg_tuple) under
        # cache_dir (optionally keyed by the trailing name) — TODO confirm.
        data = try_cache(data_class, (args.datapath,), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load(ez, vl),
                          (args.wvpath, args.embedding_size, vocab),
                          args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load(args.embedding_size, vocab)

    # Convert the loaded embeddings to a float32 NumPy array before they are
    # passed to create_model.
    embed = np.array(embed, dtype=np.float32)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            model.test_process(sess, data, args)
コード例 #2
0
ファイル: main.py プロジェクト: z123z123d/contk
def main(args):
    """Configure logging, load data and word vectors, then run Seq2seq.

    Args:
        args: Run configuration. It is logged via ``json.dumps``, so it must
            be JSON-serializable. Attributes read here: ``debug``, ``cuda``,
            ``dataset``, ``wvclass``, ``cache``, ``datapath``, ``cache_dir``,
            ``wvpath``, ``embedding_size`` and ``mode``. The whole object is
            also stored as ``param.args`` for the model.

    Raises:
        ValueError: if ``args.mode`` is neither ``"train"`` nor ``"test"``.
    """
    # No filename or stream is passed, so basicConfig attaches a
    # StreamHandler writing to sys.stderr (and does nothing at all if the
    # root logger already has handlers).
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
        datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    # NOTE(review): presumably selects device 0 when args.cuda is set —
    # confirm against cuda_init.
    cuda_init(0, args.cuda)

    # ``volatile`` collects the objects built here (data manager and word
    # vectors), kept separate from the user-supplied ``args``.
    volatile = Storage()
    data_class = SingleTurnDialog.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove
    if args.cache:
        dm = try_cache(data_class, (args.datapath, ), args.cache_dir)
        volatile.wordvec = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load(ez, vl),
            (args.wvpath, args.embedding_size, dm.vocab_list),
            args.cache_dir, wordvec_class.__name__)
    else:
        dm = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        volatile.wordvec = wv.load(args.embedding_size, dm.vocab_list)

    volatile.dm = dm

    # The model receives both the configuration and the loaded resources.
    param = Storage()
    param.args = args
    param.volatile = volatile

    model = Seq2seq(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        model.test_process()
    else:
        raise ValueError("Unknown mode")
コード例 #3
0
def main(args):
    """Build a generator/discriminator/rollout model and train or test it.

    Args:
        args: Run configuration. Setup reads ``debug``, ``cuda``,
            ``dataset``, ``wvclass``, ``cache``, ``datapath``, ``cache_dir``,
            ``wvpath``, ``embedding_size`` and ``mode``. Training also reads
            ``pre_train``, ``dis_pre_epoch_num``, ``total_adv_batch``,
            ``batch_size``, ``test_per_epoch`` and ``dis_adv_epoch_num``.
            The whole object is forwarded to ``create_model``.

    In ``"train"`` mode the flow is:
      1. Optionally (``args.pre_train``) pre-train the generator, then the
         discriminator, and then run one generator test pass.
      2. Run ``args.total_adv_batch`` adversarial rounds. Each round
         alternates a generator update and a discriminator update.

    Any other mode only runs ``generator.test_process``.
    """
    if args.debug:
        debug()

    # Session configuration: let GPU memory grow on demand when CUDA is
    # enabled; otherwise request zero GPUs and hide all CUDA devices.
    if args.cuda:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = LanguageGeneration.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    # Compare with None by identity (PEP 8); fall back to Glove when no
    # word-vector class was resolved from args.wvclass.
    if wordvec_class is None:
        wordvec_class = Glove
    if args.cache:
        data = try_cache(data_class, (args.datapath, ), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load(ez, vl),
                          (args.wvpath, args.embedding_size, vocab),
                          args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load(args.embedding_size, vocab)

    # Convert the loaded embeddings to a float32 NumPy array before they are
    # passed to create_model.
    embed = np.array(embed, dtype=np.float32)
    with tf.Session(config=config) as sess:
        generator, discriminator, rollout_gen = create_model(
            sess, data, args, embed)
        if args.mode == "train":
            if args.pre_train:
                # Pre-training: generator first, then the discriminator
                # (which is trained against the pre-trained generator).
                print('Start pre-training generator...')
                generator.pre_train_process(sess, data)

                print('Start pre-training discriminator...')
                discriminator.train_process(generator, data, sess,
                                            args.dis_pre_epoch_num)

                # Evaluate the pre-trained generator once.
                generator.test_process(sess, data)

            # Adversarial training: each round updates the generator (using
            # rollout_gen and the discriminator), then the discriminator.
            for batch in range(args.total_adv_batch):
                print("Adversarial  training %d" % batch)
                print('Start adversarial training generator...')
                generator.adv_train_process(sess, data, rollout_gen,
                                            discriminator)
                # NOTE(review): this test-set evaluation runs every round,
                # but its result is only used on reporting rounds below.
                # Consider moving it under that condition if pre_evaluate
                # has no side effects that later steps rely on.
                testout = generator.pre_evaluate(sess, data, args.batch_size,
                                                 "test")
                # Report every test_per_epoch rounds and on the final round,
                # but never on round 0.
                if (batch % args.test_per_epoch == 0
                        or batch == args.total_adv_batch - 1) and batch != 0:
                    print('total_batch: ', batch, 'test_loss: ', testout[0])
                    generator.test_process(sess, data)

                print('Start adversarial training discriminator...')
                discriminator.train_process(generator, data, sess,
                                            args.dis_adv_epoch_num)
        else:
            print("Start testing...")
            generator.test_process(sess, data)