def train():
    """Train a CycleGAN from scratch, checkpointing into a timestamped directory.

    A new run directory ``checkpoints/YYYYmmdd-HHMM`` receives the TensorBoard
    event file and all checkpoints.

    The loop runs until the coordinator requests a stop, Ctrl-C is pressed, or
    an error occurs. On every step a summary is written. Every 100 steps the
    four losses are printed. Every 1000 steps a step-numbered checkpoint is
    saved.

    Whatever ends the loop, an un-numbered ``model.ckpt`` is saved and the
    queue-runner threads are stopped and joined.
    """
    run_dir = "checkpoints/{}".format(datetime.now().strftime("%Y%m%d-%H%M"))
    os.makedirs(run_dir, exist_ok=True)
    ckpt_prefix = run_dir + "/model.ckpt"

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN()
        G_loss, D_Y_loss, F_loss, D_X_loss = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
        train_writer = tf.summary.FileWriter(run_dir, graph)

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        step = 0
        try:
            while not coord.should_stop():
                (_, g_val, dy_val, f_val, dx_val, summary) = sess.run([
                    optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                    cycle_gan.summary
                ])
                train_writer.add_summary(summary, step)
                train_writer.flush()

                if not step % 100:
                    print('-----------Step %d:-------------' % step)
                    for template, value in ((' G_loss : {}', g_val),
                                            (' D_Y_loss : {}', dy_val),
                                            (' F_loss : {}', f_val),
                                            (' D_X_loss : {}', dx_val)):
                        print(template.format(value))

                if not step % 1000:
                    saved_to = cycle_gan.saver.save(sess, ckpt_prefix,
                                                    global_step=step)
                    print("Model saved in file: %s" % saved_to)

                step += 1
        except KeyboardInterrupt:
            print('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Hand the error to the coordinator; join() below re-raises it.
            coord.request_stop(e)
        finally:
            saved_to = cycle_gan.saver.save(sess, ckpt_prefix)
            print("Model saved in file: %s" % saved_to)
            # Ask the queue-runner threads to stop and wait for them.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train a CycleGAN, optionally resuming an existing run.

    Run directory:
      - If ``FLAGS.load_model`` is set, the graph is restored from the newest
        checkpoint in ``checkpoints/<FLAGS.load_model>``. Training continues
        from the step encoded in that checkpoint's file name.
      - Otherwise a new ``checkpoints/YYYYmmdd-HHMM`` directory is created and
        the variables are freshly initialized.

    Each step:
      1. Samples generated images.
      2. Runs the optimizers, feeding ``cycle_gan.fake_y`` / ``cycle_gan.fake_x``
         with images returned by two ``ImagePool`` objects.

    Cadence: summaries are written every step, losses are logged every 100
    steps, and a checkpoint is saved every 5000 steps. A final checkpoint is
    saved however the loop ends.

    Raises:
        ValueError: when resuming and no checkpoint exists in the run directory.
    """
    ckpt_root = "checkpoints/"
    if FLAGS.load_model is not None:
        # Accept both "20170715-1622" and "checkpoints/20170715-1622".
        # str.lstrip() removes a *set of characters*, not a prefix, so the
        # former lstrip("checkpoints/") also ate leading letters of run names
        # such as "sketch2photo" -> "2photo".
        run_name = FLAGS.load_model
        if run_name.startswith(ckpt_root):
            run_name = run_name[len(ckpt_root):]
        checkpoints_dir = ckpt_root + run_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError("No checkpoint found in %s" % checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Checkpoints are saved as "model.ckpt-<step>". Take the number
            # after the LAST dash so dashes in the run directory don't matter.
            step = int(meta_graph_path.rsplit("-", 1)[1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train; the discriminator inputs come from the image pools
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                         summary_op],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                        }))
                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info(' G_loss : {}'.format(G_loss_val))
                    logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info(' F_loss : {}'.format(F_loss_val))
                    logging.info(' D_X_loss : {}'.format(D_X_loss_val))

                if step % 5000 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Report to the coordinator; coord.join() re-raises it below.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train a CycleGAN, optionally resuming an existing run.

    Run directory:
      - If ``FLAGS.load_model`` is set, training resumes from the newest
        checkpoint in ``checkpoints/<FLAGS.load_model>``.
      - Otherwise a new ``checkpoints/YYYYmmdd-HHMM`` directory is created.

    Each step feeds the ``fake_y``/``fake_x`` placeholders with images returned
    by two ``ImagePool`` objects. Every 100 steps a summary is written and the
    losses are logged. Every 10000 steps a checkpoint is saved. A final
    checkpoint is saved however the loop ends.

    Raises:
        ValueError: when resuming and no checkpoint exists in the run directory.
    """
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            # Previously ``FLAGS.lambda1`` (copy/paste slip): the lambda2
            # argument received --lambda1 and --lambda2 was silently ignored.
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError("No checkpoint found in %s" % checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # "model.ckpt-<step>.meta": parse after the LAST dash so that
            # dashes in the run directory name cannot shift the index.
            step = int(meta_graph_path.rsplit("-", 1)[1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                         summary_op],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                        }))

                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info(' G_loss : {}'.format(G_loss_val))
                    logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info(' F_loss : {}'.format(F_loss_val))
                    logging.info(' D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Report to the coordinator; coord.join() re-raises it below.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train the teacher/student CycleGAN variant with periodic testing.

    Run directory:
      - If ``FLAGS.load_model`` is set, training resumes from the newest
        checkpoint in ``checkpoints/<FLAGS.load_model>``.
      - Otherwise a new ``checkpoints/YYYYmmdd-HHMM`` directory is created.

    Each step draws one X and one Y batch with ``get_train_batch`` from
    ``./dataset/``. It runs the optimizers and writes a summary. Every 100
    steps it prints the losses.

    Every 100 steps from step 100 on, it evaluates ``fake_y_correct`` on each
    Y test image (fed one at a time) and prints the fraction judged correct.

    A final checkpoint is saved however the loop ends.

    Raises:
        ValueError: when resuming and no checkpoint exists in the run directory.
    """
    ckpt_root = "checkpoints/"
    if FLAGS.load_model is not None:
        # Accept both "20190224-1130" and "checkpoints/20190224-1130";
        # str.lstrip() strips a character set, not a prefix, so it was unsafe.
        run_name = FLAGS.load_model
        if run_name.startswith(ckpt_root):
            run_name = run_name[len(ckpt_root):]
        checkpoints_dir = ckpt_root + run_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        (G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss,
         learning_loss, x_correct, y_correct, fake_y_correct) = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss,
                                        teacher_loss, student_loss,
                                        learning_loss)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # Resume from the newest checkpoint of the requested run. This
            # replaces a hard-coded "checkpoints/20190224-1130/model.ckpt-7792"
            # that ignored FLAGS.load_model entirely.
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError("No checkpoint found in %s" % checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            print('meta_graph_path', meta_graph_path)
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # "model.ckpt-<step>.meta": the step follows the LAST dash.
            step = int(meta_graph_path.rsplit("-", 1)[1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                x_image, x_label = get_train_batch(
                    "X", FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                    "./dataset/")
                y_image, y_label = get_train_batch(
                    "Y", FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                    "./dataset/")

                # train
                (_, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val,
                 teacher_loss_eval, student_loss_eval, learning_loss_eval,
                 summary) = sess.run(
                    [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                     teacher_loss, student_loss, learning_loss, summary_op],
                    feed_dict={
                        cycle_gan.x: x_image,
                        cycle_gan.y: y_image,
                        cycle_gan.x_label: x_label,
                        cycle_gan.y_label: y_label,
                    })
                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    print('-----------Step %d:-------------' % step)
                    print(' G_loss : {}'.format(G_loss_val))
                    print(' D_Y_loss : {}'.format(D_Y_loss_val))
                    print(' F_loss : {}'.format(F_loss_val))
                    print(' D_X_loss : {}'.format(D_X_loss_val))
                    print('teacher_loss: {}'.format(teacher_loss_eval))
                    print('student_loss: {}'.format(student_loss_eval))
                    print('learning_loss: {}'.format(learning_loss_eval))

                if step % 100 == 0 and step >= 10:
                    print('Now is in testing! Please wait result...')
                    test_images_y, test_labels_y = get_test_batch(
                        "Y", FLAGS.image_size, FLAGS.image_size, "./dataset/")
                    fake_y_correct_count = 0
                    for image, label in zip(test_images_y, test_labels_y):
                        # Each test image is fed alone as a batch of one.
                        fake_y_correct_eval = sess.run(
                            fake_y_correct,
                            feed_dict={
                                cycle_gan.y: [image],
                                cycle_gan.y_label: [label],
                            })
                        if fake_y_correct_eval:
                            fake_y_correct_count += 1
                    num_test = len(test_labels_y)
                    if num_test:
                        print('fake_y_accuracy: {}'.format(
                            fake_y_correct_count / num_test))
                    else:
                        # Previously a ZeroDivisionError that stopped training.
                        print('fake_y_accuracy: no test images found')

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Report to the coordinator; coord.join() re-raises it below.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            print("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Jointly train a CycleGAN and a fuzzy C-means (FCM) clustering step.

    Run directory:
      - When resuming, ``checkpoints/<FLAGS.load_model>``.
      - Otherwise a new ``checkpoints/YYYYmmdd-HHMM``.

    Flow:
      1. Fresh run: extract ``feature_x`` for every source image (under
         ``FLAGS.X + FLAGS.UC_name``) and ``feature_y`` for every target image
         (under ``FLAGS.Y``). Build ``Ux, Uy, Cx, Cy`` with
         ``fuzzy.initialize_UC_test`` and save them as text files.
         Resumed run: reload those text files instead. (Presumably U holds
         memberships and C cluster centres — confirm in ``fuzzy``.)
      2. Each step runs both optimizer groups on one source and one target
         batch, feeding the matching rows of ``Ux``/``Uy`` and the C matrices.
      3. Every 100 steps: re-extract target features, update ``Uy``/``Cy`` via
         ``fuzzy.updata_U`` and score them with ``computeAccuracy``. The best
         result so far (model, features, matrices) is saved under
         ``<run dir>/max``.

    Training stops when the step exceeds 10000, when accuracy reaches 1, or
    when the coordinator stops. On exit the model is saved, and so are the
    U/C matrices if they were created.

    Raises:
        ValueError: when resuming and no checkpoint exists in the run directory.
    """
    ckpt_root = "checkpoints/"
    if FLAGS.load_model is not None:
        # Accept both "<run>" and "checkpoints/<run>"; str.lstrip() strips a
        # character set rather than a prefix, so it mangled some run names.
        run_name = FLAGS.load_model
        if run_name.startswith(ckpt_root):
            run_name = run_name[len(ckpt_root):]
        checkpoints_dir = ckpt_root + run_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    os.makedirs(checkpoints_dir, exist_ok=True)

    def uc_file(directory, tag):
        """Return the text-file path for matrix *tag* ("Ux", "Cy", ...)."""
        return directory + "/" + tag + FLAGS.UC_name + '.txt'

    def save_uc(directory, ux, uy, cx, cy):
        """Write the four FCM matrices as comma-separated text into *directory*."""
        for tag, matrix in (("Ux", ux), ("Uy", uy), ("Cx", cx), ("Cy", cy)):
            np.savetxt(uc_file(directory, tag), matrix, fmt="%.20f",
                       delimiter=",")

    def extract_features(sess, tensor, placeholder, images):
        """Evaluate *tensor* on each image fed alone; collect row 0 of each."""
        return [sess.run(tensor, feed_dict={placeholder: [image]})[0]
                for image in images]

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            learning_rate2=FLAGS.learning_rate2,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        (G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, Disperse_loss,
         Fuzzy_loss, feature_x, feature_y, _, _) = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss,
                                        Disperse_loss)
        optimizers2 = cycle_gan.optimize2(Fuzzy_loss)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError("No checkpoint found in %s" % checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # "model.ckpt-<step>.meta": the step follows the LAST dash.
            step = int(meta_graph_path.rsplit("-", 1)[1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 1

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Bound before ``try`` so the ``finally`` clause can tell whether the
        # FCM matrices exist yet. An early failure used to raise NameError
        # there, hiding the real error and skipping coord.join().
        Ux = Uy = Cx = Cy = None
        try:
            x_path = FLAGS.X + FLAGS.UC_name
            print('now is in FCM initializing!')
            if FLAGS.load_model is None:
                x_images, _, x_len, _, _, _ = get_source_batch(
                    0, 256, 256, source_dir=x_path)
                y_images, _, y_len, _, _, _ = get_target_batch(
                    0, 256, 256, target_dir=FLAGS.Y)
                print('x_len', len(x_images))
                print('y_len', len(y_images))
                x_data = extract_features(sess, feature_x, cycle_gan.x, x_images)
                y_data = extract_features(sess, feature_y, cycle_gan.y, y_images)
                Ux, Uy, Cx, Cy = fuzzy.initialize_UC_test(
                    x_len, x_data, y_len, y_data, FLAGS.UC_name,
                    checkpoints_dir)
                save_uc(checkpoints_dir, Ux, Uy, Cx, Cy)
            else:
                # Reshape the reloaded Ux / Cx the same way as before.
                Ux = np.loadtxt(uc_file(checkpoints_dir, "Ux"), delimiter=",")
                Ux = [[u] for u in Ux]
                Uy = np.loadtxt(uc_file(checkpoints_dir, "Uy"), delimiter=",")
                Cx = [np.loadtxt(uc_file(checkpoints_dir, "Cx"), delimiter=",")]
                Cy = np.loadtxt(uc_file(checkpoints_dir, "Cy"), delimiter=",")
            print('FCM initialization is ended! Go to train')

            max_accuracy = 0
            while not coord.should_stop():
                images_x, idx_list, _, _, _, _ = get_source_batch(
                    FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                    source_dir=x_path)
                subUx = fuzzy.getSubU(Ux, idx_list)
                label_x = [row[0] for row in subUx]
                images_y, idy_list, _, _, _, _ = get_target_batch(
                    FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                    target_dir=FLAGS.Y)
                subUy = fuzzy.getSubU(Uy, idy_list)
                label_y = [row[0] for row in subUy]

                # Loss and feature values used to be fetched here but were
                # never used; only the summary is consumed.
                _, _, summary = sess.run(
                    [optimizers, optimizers2, summary_op],
                    feed_dict={
                        cycle_gan.x: images_x,
                        cycle_gan.y: images_y,
                        cycle_gan.Uy2x: subUy,
                        cycle_gan.Ux2y: subUx,
                        cycle_gan.x_label: label_x,
                        cycle_gan.y_label: label_y,
                        cycle_gan.ClusterX: Cx,
                        cycle_gan.ClusterY: Cy,
                    })
                train_writer.add_summary(summary, step)
                train_writer.flush()

                # Periodically refresh the target memberships with the
                # current feature extractor.
                if step % 100 == 0:
                    print('Now is in FCM training!')
                    y_images, _, _, y_labels, _, _ = get_target_batch(
                        0, 256, 256, target_dir=FLAGS.Y)
                    print('y_len', len(y_images))
                    y_data = extract_features(sess, feature_y, cycle_gan.y,
                                              y_images)
                    Uy, Cy = fuzzy.updata_U(checkpoints_dir, y_data, Uy,
                                            FLAGS.UC_name)
                    (accuracy, tp, tn, fp, fn, f1_score, recall, precision,
                     specificity) = computeAccuracy(Uy, y_labels)
                    print("accuracy:%.4f\ttp:%.4f\ttn:%.4f\tfp %d\tfn:%d"
                          % (accuracy, tp, tn, fp, fn))

                    if accuracy >= max_accuracy:
                        max_accuracy = accuracy
                        max_dir = checkpoints_dir + "/max"
                        os.makedirs(max_dir, exist_ok=True)
                        with open(max_dir + "/step.txt", 'w') as f:
                            f.write(str(step) + '\n')
                            f.write(str(accuracy) + '\taccuracy\n')
                        np.save(max_dir + "/feature_fcgan.npy", y_data)
                        save_uc(max_dir, Ux, Uy, Cx, Cy)
                        save_path = saver.save(sess, max_dir + "/model.ckpt",
                                               global_step=step)
                        print("Max model saved in file: %s" % save_path)
                    print('max_accuracy:', max_accuracy)
                    print('mean_U', np.min(Uy, 0))

                    # Checked after the best-model save so that a perfect
                    # result is also recorded under /max.
                    if accuracy == 1:
                        break

                step += 1
                if step > 10000:
                    logging.info('train stop!')
                    break
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Report to the coordinator; coord.join() re-raises it below.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            if not any(m is None for m in (Ux, Uy, Cx, Cy)):
                save_uc(checkpoints_dir, Ux, Uy, Cx, Cy)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train a CycleGAN, optionally resuming an existing run.

    Run directory:
      - If ``FLAGS.load_model`` is set, training resumes from the newest
        checkpoint in ``checkpoints/<FLAGS.load_model>``.
      - Otherwise a new ``checkpoints/YYYYmmdd-HHMM`` directory is created.

    Each step first samples generated images. It then runs the optimizers,
    feeding the ``fake_y``/``fake_x`` placeholders with images returned by two
    ``ImagePool`` objects.

    Every 100 steps a summary is written and the losses are logged. Every
    10000 steps a checkpoint is saved. A final checkpoint is saved however the
    loop ends.

    Raises:
        ValueError: when resuming and no checkpoint exists in the run directory.
    """
    if FLAGS.load_model is not None:
        # Resume inside the run directory named by the flag.
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        # New run directory named after the current time. The format has no
        # spaces; "%Y%m%d - %H%M" produced names like "20190101 - 1200",
        # unlike every other run directory.
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            # Previously ``FLAGS.lambda1`` (copy/paste slip): --lambda2 was
            # silently ignored.
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
        summary_op = tf.summary.merge_all()
        # The graph is written to the run directory for TensorBoard.
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # Restore structure and weights from the newest checkpoint.
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError("No checkpoint found in %s" % checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # "model.ckpt-<step>.meta": the step follows the LAST dash.
            step = int(meta_graph_path.rsplit("-", 1)[1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # Pools sized by --pool_size; this used to read ``FLASG.pool_size``
            # (typo), which raised NameError before the first training step.
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)
            while not coord.should_stop():
                # Generate images with the current generators...
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # ...then train. The cycle_gan.fake_y / fake_x placeholders
                # receive whatever the pools return for these images. This is
                # presumably a mix of current and earlier fakes — see ImagePool.
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                         summary_op],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                        }))

                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                    logging.info('----------step %d:--------------' % step)
                    logging.info(' G_loss : {}'.format(G_loss_val))
                    logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info(' F_loss : {}'.format(F_loss_val))
                    logging.info(' D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Report to the coordinator; coord.join() re-raises it below.
            coord.request_stop(e)
        finally:
            # Always save the model trained so far.
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            coord.request_stop()
            coord.join(threads)
def main():
    """Train a WGAN-loss CycleGAN on data copied from ``FLAGS.buckets``.

    The zip archives listed in ``trainfiles`` are fetched with
    ``utils.pai_copy`` and unpacked with ``utils.Unzip``. Training images are
    then read from ``temp/picG`` (X) and ``temp/picF`` (Y).

    Training runs ``num_epoch`` steps of one batch each:
      - every 25 steps: print losses and step time, and write a summary;
      - every 1000 steps: save a checkpoint (without the meta graph) to
        ``FLAGS.checkpointDir``.

    The input threads, summary writer and session are released even if
    training fails part-way.
    """
    num_epoch = 40000  # number of training steps (one batch per step)
    pool_size = 20
    batch_size = 1
    oldpath = FLAGS.buckets
    picFpath = 'picF'
    picGpath = 'picG'
    useCopyfile = True
    if useCopyfile:
        trainfiles = ['picf1.zip', 'picf2.zip', 'picg1.zip']
        # trainfiles.extend(['picf3.zip','picf4.zip','picg2.zip'])
        print(trainfiles)
        for f in trainfiles:
            fn = utils.pai_copy(f, oldpath)
            utils.Unzip(fn)
        picFpath = os.path.join('temp', picFpath)
        picGpath = os.path.join('temp', picGpath)
    print(picFpath)
    print(picGpath)

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))
    cycle_gan = CycleGAN(X_train_file=picGpath,
                         Y_train_file=picFpath,
                         batch_size=batch_size,
                         image_size=(270, 480),
                         use_lsgan=True,
                         lossfunc='wgan',
                         norm='instance',
                         learning_rate=3e-3,
                         start_decay_step=5000,
                         decay_steps=350000)
    G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.build()
    optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.checkpointDir)
    saver = tf.train.Saver(max_to_keep=0)

    sess.run(
        [tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    fake_Y_pool = ImagePool(pool_size)
    fake_X_pool = ImagePool(pool_size)
    print('start train')
    start_time = time.time()
    try:
        for step in range(1, num_epoch + 1):
            # get previously generated images
            fake_y_val, fake_x_val = sess.run([fake_y, fake_x])
            # train
            _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                sess.run(
                    [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                     summary_op],
                    feed_dict={
                        cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                        cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                    }))
            # Wall time of this single step.
            elapsed_time = time.time() - start_time
            start_time = time.time()

            if step % 25 == 0:
                print('G_loss : %s--D_Y_loss : %s--F_loss : %s--D_X_loss : %s--'
                      % (G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val))
                print('step : %s --elapsed_time : %s' % (step, elapsed_time))
                print('adding summary...')
                train_writer.add_summary(summary, step)
                train_writer.flush()

            if step % 1000 == 0:
                save_path = saver.save(sess,
                                       os.path.join(FLAGS.checkpointDir,
                                                    "model.ckpt"),
                                       global_step=step,
                                       write_meta_graph=False)
                print("Model saved in file: %s" % save_path)
    finally:
        # Previously these ran only after a successful run, and the writer and
        # session were never closed at all.
        coord.request_stop()
        try:
            coord.join(threads)
        finally:
            train_writer.close()
            sess.close()
def main(unused_argv):
    """Train or resume the real-to-cartoon CycleGAN.

    Checkpoints live in ``./models/real2cartoon``; summaries go to
    ``./summary``. If a checkpoint exists it is restored, and the step counter
    is taken from the last run of digits in its file name. Otherwise the
    variables are initialized.

    Each step:
      - writes a summary;
      - after step 1e5, runs ``cycle_gan.learning_rate_decay_op()``;
      - every 100 steps, logs the losses;
      - every 10000 steps, saves a checkpoint.

    A final checkpoint is saved however the loop ends.

    Args:
        unused_argv: ignored (present for the ``tf.app.run`` style of entry
            point — assumption, the caller is not visible here).
    """
    total_step = 0
    checkpoints_dir = './models/real2cartoon'
    summary_dir = './summary'
    # tf.train.Saver.save() refuses to write into a missing parent directory,
    # so a fresh run would fail on its very first (step 0) save.
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(batch_size=FLAGS.batch_size,
                             image_size=256,
                             use_mse=FLAGS.use_mse,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             filters=FLAGS.filters,
                             beta1=FLAGS.beta1,
                             mse_label=FLAGS.mse_label,
                             file_x=FLAGS.file_x,
                             file_y=FLAGS.file_y)
        G_loss, F_loss, D_X_loss, D_Y_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, F_loss, D_X_loss, D_Y_loss)

        summarys = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(summary_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        ckpt = tf.train.get_checkpoint_state(checkpoints_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Last run of digits in the name, e.g. "model.ckpt-42" -> 42.
            # Raw string: "\d" in a plain literal is an invalid escape.
            total_step = int(
                next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            logger.info('load model success' + ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())
            logger.info('start new model')

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            fake_X_pool = utils.ImagePool(FLAGS.pool_size)
            fake_Y_pool = utils.ImagePool(FLAGS.pool_size)
            while not coord.should_stop():
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])
                # NOTE(review): pooled *generated* images are fed into
                # cycle_gan.x / cycle_gan.y. Other variants feed the
                # fake_x / fake_y placeholders instead — confirm this is
                # intended for this CycleGAN class.
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                         summarys],
                        feed_dict={
                            cycle_gan.x: fake_X_pool.query(fake_x_val),
                            cycle_gan.y: fake_Y_pool.query(fake_y_val),
                        }))
                train_writer.add_summary(summary, total_step)
                train_writer.flush()
                logger.info('step: {}'.format(total_step))

                if total_step > 1e5:
                    # NOTE(review): if learning_rate_decay_op() builds a new op
                    # on each call, the graph grows every step. Consider
                    # creating the op once and re-running it.
                    sess.run(cycle_gan.learning_rate_decay_op())

                if total_step % 100 == 0:
                    logger.info('-----------Step %d:-------------' % total_step)
                    logger.info(' G_loss : {}'.format(G_loss_val))
                    logger.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                    logger.info(' F_loss : {}'.format(F_loss_val))
                    logger.info(' D_X_loss : {}'.format(D_X_loss_val))
                    # NOTE(review): if learning_rate is a tensor/variable this
                    # logs the object, not its value — sess.run it if so.
                    logger.info(' learning_rate : {}'.format(
                        cycle_gan.learning_rate))

                if total_step % 10000 == 0:
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                           global_step=total_step)
                    logger.info("Model saved in file: %s" % save_path)

                total_step += 1
        except KeyboardInterrupt:
            logger.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Report to the coordinator; coord.join() re-raises it below.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=total_step)
            logger.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train a CycleGAN for ``FLAGS.epho`` steps, saving sample images.

    Run directory:
      - If ``FLAGS.load_model`` is set, training resumes from the newest
        checkpoint in ``checkpoints/<FLAGS.load_model>``.
      - Otherwise a new ``checkpoints/YYYYmmdd-HHMM`` directory is created.
    Image results go to ``FLAGS.res_im_path``.

    Every step logs the losses and writes a summary. Every 1000 steps the
    current generated/real images are saved with ``ops.save_img_result`` and a
    checkpoint is written. Training stops once ``step`` reaches ``FLAGS.epho``
    (or on Ctrl-C / error). A final checkpoint — plus the last images, if any
    were produced — is saved on exit.

    Raises:
        ValueError: when resuming and no checkpoint exists in the run directory.
    """
    # Resume from an existing run directory if one was named, otherwise start
    # a new directory named after the current time.
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    # Created independently: in a single try/except, an existing checkpoint
    # dir (always the case when resuming) skipped creating res_im_path.
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(FLAGS.res_im_path, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        # Build the CycleGAN model and its graph.
        cycle_gan = CycleGAN(FLAGS)
        (G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, real_y,
         real_x) = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
        # Summaries
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver(max_to_keep=10)

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError("No checkpoint found in %s" % checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # "model.ckpt-<step>.meta": the step follows the LAST dash.
            step = int(meta_graph_path.rsplit("-", 1)[1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        # Start the input queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Bound up front so ``finally`` does not hit NameError when the loop
        # never completed a step.
        fake_y_val = fake_x_val = real_y_in = real_x_in = None
        try:
            # Pools of generated images fed to the fake_y / fake_x placeholders.
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)
            while not coord.should_stop():
                # get previously generated images (and the matching real ones)
                fake_y_val, fake_x_val, real_y_in, real_x_in = sess.run(
                    [fake_y, fake_x, real_y, real_x])
                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                         summary_op],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                        }))
                train_writer.add_summary(summary, step)
                train_writer.flush()

                # Log the current state every step.
                logging.info('-----------Step %d:-------------' % step)
                logging.info(' G_loss : {}'.format(G_loss_val))
                logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                logging.info(' F_loss : {}'.format(F_loss_val))
                logging.info(' D_X_loss : {}'.format(D_X_loss_val))

                if step % 1000 == 0:
                    ops.save_img_result(fake_y_val, fake_x_val, real_y_in,
                                        real_x_in, FLAGS.res_im_path, step)
                    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
                # ">=" rather than "==": a run resumed at or past FLAGS.epho
                # would otherwise never stop.
                if step >= FLAGS.epho:
                    coord.request_stop()  # signal the end of training
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Report to the coordinator; coord.join() re-raises it below.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            if fake_y_val is not None:
                ops.save_img_result(fake_y_val, fake_x_val, real_y_in,
                                    real_x_in, FLAGS.res_im_path, step)
            logging.info("Model saved in file: %s" % save_path)
            coord.request_stop()  # stop training
            coord.join(threads)
def train():
    """Train the CycleGAN + teacher/student model configured by ``FLAGS``.

    ``FLAGS.load_model`` may be given with or without the ``checkpoints/``
    prefix. When it is set, training resumes from the latest checkpoint
    recorded in that directory. Otherwise a new ``checkpoints/YYYYmmdd-HHMM``
    directory is created.

    Each step draws an X and a Y training batch via ``get_train_batch`` from
    ``./dataset/`` and runs one optimisation step. Losses are printed every
    500 steps.

    Every 2000 steps (excluding step 0):
      * ``y_correct`` / ``fake_x_correct`` are evaluated one image at a time
        on the ``get_roc_batch`` Y test set;
      * a regular checkpoint is saved;
      * when the fake-x accuracy beats the best seen so far (initially
        0.90), ``step.txt`` is rewritten and a checkpoint is saved under
        ``bestmodel/``.

    A final checkpoint is written on exit.

    Raises:
        ValueError: ``FLAGS.load_model`` is set but its directory contains no
            checkpoint.
    """
    # Only fake-x accuracies above this threshold produce a best-model save.
    max_accuracy = 0.90

    if FLAGS.load_model is not None:
        # Strip an exact "checkpoints/" prefix. str.lstrip removes a
        # *character set* and would eat the start of names like "cycle...".
        prefix = "checkpoints/"
        run_name = FLAGS.load_model
        if run_name.startswith(prefix):
            run_name = run_name[len(prefix):]
        checkpoints_dir = prefix + run_name
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        # Only the losses, y_correct and fake_x_correct are used below; the
        # remaining outputs are unpacked to match model()'s return tuple.
        (G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss,
         learning_loss, x_correct, y_correct, fake_x_correct, softmax3,
         fake_x_pre, f_fakeX, fake_x, fake_y_) = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss,
                                        teacher_loss, student_loss,
                                        learning_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()
        # A Saver keeps one max_to_keep history across every path it saves
        # to. Saving the best model through `saver` therefore let later
        # periodic checkpoints delete it, so it gets a dedicated saver.
        best_saver = tf.train.Saver(max_to_keep=1)

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # Resume from the latest checkpoint in checkpoints_dir. This was
            # previously hard-coded to
            # checkpoints/20190611-1650/model.ckpt-90000, ignoring the flag.
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint is None or not checkpoint.model_checkpoint_path:
                raise ValueError(
                    "No checkpoint found in {}".format(checkpoints_dir))
            checkpoint_path = checkpoint.model_checkpoint_path
            print('checkpoint_path', checkpoint_path)
            saver.restore(sess, checkpoint_path)
            # Checkpoints are saved as ".../model.ckpt-<step>".
            step = int(checkpoint_path.rsplit("-", 1)[-1])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                x_image, x_label = get_train_batch(
                    "X", FLAGS.batch_size, FLAGS.image_size,
                    FLAGS.image_size, "./dataset/")
                y_image, y_label = get_train_batch(
                    "Y", FLAGS.batch_size, FLAGS.image_size,
                    FLAGS.image_size, "./dataset/")
                (_, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val,
                 teacher_loss_eval, student_loss_eval, learning_loss_eval,
                 summary) = sess.run(
                    [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                     teacher_loss, student_loss, learning_loss, summary_op],
                    feed_dict={
                        cycle_gan.x: x_image,
                        cycle_gan.y: y_image,
                        cycle_gan.x_label: x_label,
                        cycle_gan.y_label: y_label,
                    })
                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 500 == 0:
                    print('-----------Step %d:-------------' % step)
                    print(' G_loss : {}'.format(G_loss_val))
                    print(' D_Y_loss : {}'.format(D_Y_loss_val))
                    print(' F_loss : {}'.format(F_loss_val))
                    print(' D_X_loss vb: {}'.format(D_X_loss_val))
                    print('teacher_loss: {}'.format(teacher_loss_eval))
                    print('student_loss: {}'.format(student_loss_eval))
                    print('learning_loss: {}'.format(learning_loss_eval))

                if step % 2000 == 0 and step > 0:
                    print('Now is in testing! Please wait result...')
                    test_images_x, test_labels_x, _ = get_test_batch1(
                        'X', 1000, FLAGS.image_size, FLAGS.image_size,
                        "./dataset/")
                    test_images_y, test_labels_y = get_roc_batch(
                        FLAGS.image_size, FLAGS.image_size, "./dataset/Y")
                    print(len(test_images_y))
                    print(len(test_images_x))

                    # Score the Y test images one at a time. The X batch is
                    # only used to cap how many are scored.
                    n_test = min(len(test_images_y), len(test_images_x))
                    y_correct_count = 0
                    fake_x_correct_count = 0
                    for i in range(n_test):
                        y_correct_eval, fake_x_correct_eval = sess.run(
                            [y_correct, fake_x_correct],
                            feed_dict={
                                cycle_gan.y: [test_images_y[i]],
                                cycle_gan.y_label: [test_labels_y[i]],
                            })
                        if y_correct_eval:
                            y_correct_count += 1
                        if fake_x_correct_eval:
                            fake_x_correct_count += 1
                    print('fake_x_correct_cout', fake_x_correct_count)

                    # Guard against an empty test set: dividing by zero here
                    # would abort training via the coordinator.
                    fake_x_accuracy = None
                    if n_test > 0:
                        fake_x_accuracy = fake_x_correct_count / n_test
                        print('x_accuracy: {}'.format(
                            y_correct_count / n_test))
                        print('fake_x_accuracy: {}'.format(fake_x_accuracy))

                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    print("Model saved in file: %s" % save_path)

                    if (fake_x_accuracy is not None
                            and fake_x_accuracy > max_accuracy):
                        max_accuracy = fake_x_accuracy
                        os.makedirs(checkpoints_dir, exist_ok=True)
                        # Record which step produced the best accuracy.
                        with open(checkpoints_dir + "/step.txt", 'w') as f:
                            f.write(str(step) + '\n')
                            f.write(str(fake_x_accuracy) + '\n')
                        best_dir = checkpoints_dir + "/bestmodel"
                        os.makedirs(best_dir, exist_ok=True)
                        save_path = best_saver.save(
                            sess, best_dir + "/model.ckpt", global_step=step)
                        print("Model saved in file: %s" % save_path)
                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            # Record the error; coord.join() below re-raises it.
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            print("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)