# Beispiel #1 (German: "Example #1") — paste/scrape artifact from the page this
# code was copied from, kept as a comment so the file stays syntactically inert
# 0  (vote/score counter from the same page, also not code)
# testData = data.load(opt,test=True)

# prepare model saver/summary writer
# Saver restricted to the discriminator variables (varsD); keeps the 20 most
# recent checkpoints.
saver_D = tf.train.Saver(var_list=varsD,max_to_keep=20)
# TensorBoard event files go under <main_folder>summary_<group>/<name>
# NOTE(review): main_folder is defined outside this chunk — presumably ends
# with a path separator; verify against its definition.
summaryWriter = tf.summary.FileWriter(main_folder + "summary_{0}/{1}".format(opt.group,opt.name))

print(util.toYellow("======= TRAINING START ======="))
timeStart = time.time()  # wall-clock start; presumably used for timing later in the loop
# start session
# allow_soft_placement: fall back to another device when an op has no kernel
# on the requested one; allow_growth: grab GPU memory on demand rather than
# reserving it all up front.
tfConfig = tf.ConfigProto(allow_soft_placement=True)
tfConfig.gpu_options.allow_growth = True
with tf.Session(config=tfConfig) as sess:
	sess.run(tf.global_variables_initializer())
	summaryWriter.add_graph(sess.graph)
	# Resuming from a mid-training checkpoint takes priority over loading a
	# pretrained discriminator.
	if opt.fromIt!=0:
		util.restoreModelFromIt(opt,sess,saver_D,"D",opt.fromIt)
		print(util.toMagenta("resuming from iteration {0}...".format(opt.fromIt)))
	elif opt.loadD:
		util.restoreModel(opt,sess,saver_D,opt.loadD,"D")
		print(util.toMagenta("loading pretrained D {0}...".format(opt.loadD)))
	print(util.toMagenta("start training..."))

	# training loop
	for i in range(opt.fromIt,opt.toIt):
		# step-decayed learning rate: lrD * lrDdecay ** (i // lrDstep)
		lrD = opt.lrD*opt.lrDdecay**(i//opt.lrDstep)
		# make training batch
		batch = data.makeBatch(opt,trainData,PH)
		# feed the decayed learning rate through its placeholder
		batch[lrD_PH] = lrD
		# update discriminator
		runList = [optimD,loss_D,grad_D_norm_mean]
		# NOTE(review): the body of this inner update loop is missing — the file
		# is truncated/spliced right after this line, leaving the `for` empty.
		for u in range(opt.updateD):
# prepare model saver/summary writer
# NOTE(review): from here down the file appears to be a second, pasted-in
# variant of the script above, retargeted from the discriminator ("D") to the
# "GP" variables; it rebinds summaryWriter/timeStart/tfConfig defined earlier.
# Saver restricted to the GP variables; keeps the 10 most recent checkpoints.
saver_GP = tf.train.Saver(var_list=varsGP, max_to_keep=10)
# TensorBoard event files go under summary_<group>/<name> (relative path here,
# unlike the main_folder-prefixed path used above).
summaryWriter = tf.summary.FileWriter("summary_{0}/{1}".format(
    opt.group, opt.name))

print(util.toYellow("======= TRAINING START ======="))
timeStart = time.time()  # wall-clock start of this (second) training run
# start session
# allow_soft_placement: fall back to another device when an op has no kernel
# on the requested one; allow_growth: grab GPU memory on demand.
tfConfig = tf.ConfigProto(allow_soft_placement=True)
tfConfig.gpu_options.allow_growth = True
with tf.Session(config=tfConfig) as sess:
    sess.run(tf.global_variables_initializer())
    summaryWriter.add_graph(sess.graph)
    # Resuming from a mid-training checkpoint takes priority over loading a
    # pretrained GP model.
    if opt.fromIt != 0:
        util.restoreModelFromIt(opt, sess, saver_GP, "GP", opt.fromIt)
        print(
            util.toMagenta("resuming from iteration {0}...".format(
                opt.fromIt)))
    elif opt.loadGP:
        util.restoreModel(opt, sess, saver_GP, opt.loadGP, "GP")
        print(util.toMagenta("loading pretrained GP {0}...".format(
            opt.loadGP)))
    print(util.toMagenta("start training..."))

    # training loop
    for i in range(opt.fromIt, opt.toIt):
        # step-decayed learning rate: lrGP * lrGPdecay ** (i // lrGPstep)
        lrGP = opt.lrGP * opt.lrGPdecay**(i // opt.lrGPstep)
        # make training batch ("_homo" presumably = homography variant — verify
        # against data module)
        batch = data.makeBatch_homo(opt, trainData, PH)
        # feed the decayed learning rate through its placeholder.
        # NOTE(review): the loop body continues past the end of this chunk.
        batch[lrGP_PH] = lrGP