# NOTE(review): the line below is a training-script fragment whose newlines
# and indentation were lost. Because it begins with "#", Python parses the
# entire line as one comment and none of it executes. It also stops at the
# header of the inner `for u in range(opt.updateD):` loop, whose body is not
# present, so the block structure cannot be restored from this line alone.
#
# Read as code, the fragment:
#   - starts with `testData = data.load(opt,test=True)`, which appears to be
#     commented out in the original script;
#   - builds `saver_D` over `varsD` with max_to_keep=20 and a summary
#     FileWriter under main_folder + "summary_<opt.group>/<opt.name>";
#   - opens a tf.Session configured with allow_soft_placement=True and
#     gpu_options.allow_growth=True, runs the global variable initializer and
#     adds the session graph to the summary writer;
#   - calls util.restoreModelFromIt(..., "D", opt.fromIt) when opt.fromIt != 0,
#     otherwise util.restoreModel(..., opt.loadD, "D") when opt.loadD is truthy;
#   - for each i in range(opt.fromIt, opt.toIt): computes
#     lrD = opt.lrD * opt.lrDdecay ** (i // opt.lrDstep), builds a batch with
#     data.makeBatch, stores lrD under the lrD_PH key, and sets
#     runList = [optimD, loss_D, grad_D_norm_mean] ahead of the
#     discriminator-update loop.
# testData = data.load(opt,test=True) # prepare model saver/summary writer saver_D = tf.train.Saver(var_list=varsD,max_to_keep=20) summaryWriter = tf.summary.FileWriter(main_folder + "summary_{0}/{1}".format(opt.group,opt.name)) print(util.toYellow("======= TRAINING START =======")) timeStart = time.time() # start session tfConfig = tf.ConfigProto(allow_soft_placement=True) tfConfig.gpu_options.allow_growth = True with tf.Session(config=tfConfig) as sess: sess.run(tf.global_variables_initializer()) summaryWriter.add_graph(sess.graph) if opt.fromIt!=0: util.restoreModelFromIt(opt,sess,saver_D,"D",opt.fromIt) print(util.toMagenta("resuming from iteration {0}...".format(opt.fromIt))) elif opt.loadD: util.restoreModel(opt,sess,saver_D,opt.loadD,"D") print(util.toMagenta("loading pretrained D {0}...".format(opt.loadD))) print(util.toMagenta("start training...")) # training loop for i in range(opt.fromIt,opt.toIt): lrD = opt.lrD*opt.lrDdecay**(i//opt.lrDstep) # make training batch batch = data.makeBatch(opt,trainData,PH) batch[lrD_PH] = lrD # update discriminator runList = [optimD,loss_D,grad_D_norm_mean] for u in range(opt.updateD):
# NOTE(review): the nesting below was reconstructed from the statement syntax
# (with -> if/elif -> for); confirm that the "start training..." print and the
# training loop belong directly inside the session block.

# prepare model saver/summary writer
# The saver only tracks the variables in varsGP and retains at most 10
# checkpoints.
saver_GP = tf.train.Saver(var_list=varsGP, max_to_keep=10)
# Summaries are written to "summary_<group>/<name>", relative to the current
# working directory.
summaryWriter = tf.summary.FileWriter(
    "summary_{0}/{1}".format(opt.group, opt.name))

print(util.toYellow("======= TRAINING START ======="))
timeStart = time.time()

# start session
# Soft placement lets ops fall back to another device when the requested one
# is unavailable; allow_growth makes TensorFlow claim GPU memory on demand.
tfConfig = tf.ConfigProto(allow_soft_placement=True)
tfConfig.gpu_options.allow_growth = True
with tf.Session(config=tfConfig) as sess:
    sess.run(tf.global_variables_initializer())
    summaryWriter.add_graph(sess.graph)
    # Resuming from a given iteration takes precedence over loading a named
    # pretrained "GP" model; with neither set, the freshly initialized
    # variables are used.
    if opt.fromIt != 0:
        util.restoreModelFromIt(opt, sess, saver_GP, "GP", opt.fromIt)
        print(util.toMagenta(
            "resuming from iteration {0}...".format(opt.fromIt)))
    elif opt.loadGP:
        util.restoreModel(opt, sess, saver_GP, opt.loadGP, "GP")
        print(util.toMagenta(
            "loading pretrained GP {0}...".format(opt.loadGP)))
    print(util.toMagenta("start training..."))

    # training loop
    for i in range(opt.fromIt, opt.toIt):
        # Step decay: the base rate is multiplied by lrGPdecay once for
        # every lrGPstep iterations completed.
        lrGP = opt.lrGP * opt.lrGPdecay**(i // opt.lrGPstep)
        # make training batch
        batch = data.makeBatch_homo(opt, trainData, PH)
        batch[lrGP_PH] = lrGP