def main(_):
    """Print the parsed flags, make sure the checkpoint dir exists, then train.

    The positional argument is ignored; presumably it exists so this can be
    used as an ``app.run``-style entry point (confirm with the caller).
    """
    pp.pprint(flags.FLAGS.__flags)

    ckpt_dir = FLAGS.checkpoint_dir
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    with tf.Session() as session:
        model = CGAN(
            session,
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            c_dim=FLAGS.c_dim,
            checkpoint_dir=ckpt_dir,
            sample_dir=FLAGS.sample_dir,
        )
        model.train(FLAGS)
def test():
    """Restore ``conf.model_path_test`` and save one generated image per test sample.

    Each result is written to ``conf.output_path + "/" + name``; the output
    directory is created if it does not exist.
    """
    import numpy as np  # local import: not every caller module imports numpy

    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)
    data = load_data()
    model = CGAN()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, conf.model_path_test)
        for img, cond, name in data["test"]():
            pimg, pcond = prepocess_test(img, cond)
            gen_img = sess.run(model.gen_img, feed_dict={
                model.image: pimg,
                model.cond: pcond,
            })
            # Reshaping to shape[1:-1] only succeeds when the batch size and
            # the channel count are both 1; the result is a 2-D image.
            gen_img = gen_img.reshape(gen_img.shape[1:-1])
            # Map the generator output (presumably in [-1, 1]) to [0, 255].
            # Clip and cast to uint8 explicitly: scipy.misc.imsave byte-scales
            # any non-uint8 array to its own min/max, which would silently
            # replace this fixed mapping with a per-image contrast stretch.
            gen_img1 = np.clip((gen_img + 1.) * 127.5, 0, 255).astype(np.uint8)
            print(gen_img1)
            path_save = conf.output_path + "/" + "%s" % (name)
            print(path_save)
            scipy.misc.imsave(path_save, gen_img1)
def train(np_path="/home/chenyiru/FluxPreservation/demo_out/test/587724648721678356.npy",
          out_path="/home/chenyiru/FluxPreservation/demo_out/587724648721678356.npy"):
    """Run the restored generator on one saved ``.npy`` sample and save the result.

    Despite its name, this function performs no training. It restores
    ``conf.model_path`` and loads the array at ``np_path``. The first
    ``conf.img_size`` columns are the image and the remaining columns are the
    condition. It generates an image from that pair and saves the array
    ``[generated | condition]`` (concatenated along axis 1) to ``out_path``.

    Args:
        np_path: input ``.npy`` file (defaults to the original hard-coded path).
        out_path: output ``.npy`` file (defaults to the original hard-coded path).
    """
    model = CGAN()
    # The optimizers are never run here. Building them still adds Adam slot
    # variables to the graph, so the Saver below sees the same variable set as
    # the training graph (presumably why they are built -- confirm).
    tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, conf.model_path)
        sample = np.load(np_path)  # renamed from ``all`` to avoid shadowing the builtin
        img, cond = sample[:, :conf.img_size], sample[:, conf.img_size:]
        pimg, pcond = prepocess_test(img, cond)
        gen_img = sess.run(model.gen_img, feed_dict={
            model.image: pimg,
            model.cond: pcond,
        })
        # Drop the leading batch axis (assumes a batch of one).
        gen_img = gen_img.reshape(gen_img.shape[1:])
        image = np.concatenate((gen_img, cond), axis=1)
        np.save(out_path, image)
def test():
    """Restore ``conf.model_path_test`` and save ``[generated | condition]`` images.

    Results are written to ``./test/<name>``; the ``test`` directory is
    created if missing.
    """
    if not os.path.exists("test"):
        os.makedirs("test")
    data = load_data()
    model = CGAN()
    # Optimizer ops are built but never run; they add Adam slot variables so
    # the Saver's variable set presumably matches the training checkpoint.
    tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, conf.model_path_test)
        for img, cond, name in data["test"]():
            pimg, pcond = prepocess_test(img, cond)
            gen_img = sess.run(model.gen_img, feed_dict={
                model.image: pimg,
                model.cond: pcond,
            })
            # Drop the leading batch axis (assumes a batch of one).
            gen_img = gen_img.reshape(gen_img.shape[1:])
            # Map generator output (presumably in [-1, 1]) to [0, 255].
            gen_img = (gen_img + 1.) * 127.5
            # ``np.int`` was just an alias of the builtin ``int`` and was removed
            # in NumPy 1.24; ``int`` selects the same dtype.
            image = np.concatenate((gen_img, cond), axis=1).astype(int)
            imsave(image, "./test" + "/%s" % name)
def __init__(self, flags):
    """Set up the TF session, dataset and CGAN model, then initialize variables.

    Args:
        flags: configuration object; ``flags.dataset`` selects the dataset and
            the whole object is forwarded to ``Dataset`` and ``CGAN``.
    """
    session_config = tf.ConfigProto()
    # Grow GPU memory on demand instead of reserving it all at start-up.
    session_config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=session_config)

    self.flags = flags
    self.dataset = Dataset(flags.dataset, flags)
    self.model = CGAN(self.sess, flags, self.dataset.image_size)

    self.best_auc_sum = 0.
    self._make_folders()

    # Constructed after the model so the Saver covers the model's variables.
    self.saver = tf.train.Saver()
    self.sess.run(tf.global_variables_initializer())
    tf_utils.show_all_variables()
def train(args):
    """Train a CGAN on MNIST and always export a servable afterwards.

    The export in the ``finally`` clause runs whether ``cgan.train`` returns
    normally, is interrupted with Ctrl-C, or raises; in the last case the
    exception propagates after the export.
    """
    cgan = CGAN(args.name)
    train_dataset, _ = get_mnist_dataset(args.batch_size)
    with tf.Session() as sess:
        try:
            cgan.train(
                sess,
                train_dataset,
                base_lr=args.lr,
                epochs=args.epochs,
                save_period=10,
                reset_logs=args.reset_logs,
                version=args.version,
            )
        except KeyboardInterrupt:
            print_with_time('Interrupted by user.')
        else:
            print_with_time('Training finished.')
        finally:
            print_with_time('Saving servable..')
            cgan.export(sess, export_dir=f'{args.name}_export', version=args.version)
def train():
    """Train the CGAN with separate discriminator/generator learning rates.

    Each batch runs one discriminator step then one generator step. Losses
    are printed every 50 iterations, and the model is checkpointed every
    ``conf.save_per_epoch`` epochs to
    ``<conf.data_path_checkpoint>/checkpoint/model_<epoch>.ckpt``. Training
    resumes from ``conf.model_path_train`` unless it is the empty string.
    """
    data = load_data()
    model = CGAN()
    d_opt = tf.train.AdamOptimizer(
        learning_rate=conf.learning_rateD, beta1=conf.beta1
    ).minimize(model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(
        learning_rate=conf.learning_rateG, beta1=conf.beta1
    ).minimize(model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver()
    start_time = time.time()

    checkpoint_dir = conf.data_path_checkpoint + "/checkpoint"
    for directory in (checkpoint_dir, conf.output_path):
        if not os.path.exists(directory):
            os.makedirs(directory)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        if conf.model_path_train != "":
            saver.restore(sess, conf.model_path_train)
        else:
            sess.run(tf.global_variables_initializer())

        for epoch in range(conf.max_epoch):
            batches = data["train"]()
            for counter, (img, cond, _name) in enumerate(batches, start=1):
                pimg, pcond = prepocess_train(img, cond)
                feed = {model.image: pimg, model.cond: pcond}
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, M = sess.run([g_opt, model.g_loss], feed_dict=feed)
                if counter % 50 == 0:
                    print("Epoch [%s], Iteration [%s]: time: %s, d_loss: %s, g_loss: %s"
                          % (epoch, counter, time.time() - start_time, m, M))

            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(
                    sess, checkpoint_dir + "/" + "model_%d.ckpt" % (epoch + 1))
                print("Model saved in file: %s" % (save_path))
def train():
    """Train the CGAN, checkpointing and exporting test composites periodically.

    Each batch runs two discriminator updates followed by one generator
    update. Every ``conf.save_per_epoch`` epochs the model is saved to
    ``<conf.data_path>/checkpoint/model_<epoch>.ckpt`` and each test sample is
    written to ``<conf.output_path>/<name>`` as ``[generated | condition]``.
    Training resumes from ``conf.model_path_train`` unless it is empty.
    """
    data = load_data()
    model = CGAN()
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver()
    start_time = time.time()
    if not os.path.exists(conf.data_path + "/checkpoint"):
        os.makedirs(conf.data_path + "/checkpoint")
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        if conf.model_path_train == "":
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path_train)
        for epoch in range(conf.max_epoch):
            counter = 0
            for img, cond, _name in data["train"]():
                img, cond = prepocess_train(img, cond)
                feed = {model.image: img, model.cond: cond}
                # Two discriminator updates per generator update.
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, M = sess.run([g_opt, model.g_loss], feed_dict=feed)
                counter += 1
                if counter % 50 == 0:
                    print("Epoch [%d], Iteration [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                          % (epoch, counter, time.time() - start_time, m, M))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(
                    sess, conf.data_path + "/checkpoint/" + "model_%d.ckpt" % (epoch + 1))
                print("Model saved in file: %s" % save_path)
                for img, cond, name in data["test"]():
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img,
                                       feed_dict={model.image: pimg, model.cond: pcond})
                    # Drop the leading batch axis (assumes a batch of one).
                    gen_img = gen_img.reshape(gen_img.shape[1:])
                    # Map generator output (presumably in [-1, 1]) to [0, 255].
                    gen_img = (gen_img + 1.) * 127.5
                    # ``np.int`` (alias of ``int``) was removed in NumPy 1.24.
                    image = np.concatenate((gen_img, cond), axis=1).astype(int)
                    imsave(image, conf.output_path + "/%s" % name)
def test(mode):
    """Restore ``conf.model_path`` and export generated FITS files for one split.

    ``data[str(mode)]()`` yields ``(img, cond, name)`` triples. Let ``e`` be
    any epoch number in ``1..conf.max_epoch`` that is divisible by
    ``conf.save_per_epoch``. For each such ``e``, every result is written to
    ``<conf.result_path>/epoch_<e>/fits_output/<name>-<conf.filter_>.fits``,
    replacing any existing file.

    Only a single checkpoint is restored, so every epoch directory receives
    the same images. Each image is therefore generated once and written to
    all target directories, instead of re-running the whole split per epoch.
    """
    data = load_data()
    model = CGAN()
    saver = tf.train.Saver()
    out_dir = conf.result_path
    filter_string = conf.filter_
    if not os.path.exists(conf.save_path):
        os.makedirs(conf.save_path)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    start_epoch = 0
    # One output directory per 1-based epoch number that gets an export.
    save_dirs = ['%s/epoch_%s/fits_output' % (out_dir, epoch + 1)
                 for epoch in range(start_epoch, conf.max_epoch)
                 if (epoch + 1) % conf.save_per_epoch == 0]
    with tf.Session() as sess:
        saver.restore(sess, conf.model_path)
        if not save_dirs:
            return
        for img, cond, name in data[str(mode)]():
            name = name.replace('-' + filter_string + '.npy', '')
            pimg, pcond = prepocess_test(img, cond)
            gen_img = sess.run(model.gen_img, feed_dict={
                model.image: pimg,
                model.cond: pcond,
            })
            # Drop the leading batch axis (assumes a batch of one).
            gen_img = gen_img.reshape(gen_img.shape[1:])
            fits_recover = conf.unstretch(gen_img[:, :, 0])
            for save_dir in save_dirs:
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                filename = '%s/%s-%s.fits' % (save_dir, name, filter_string)
                if os.path.exists(filename):
                    os.remove(filename)
                fits.PrimaryHDU(fits_recover).writeto(filename)
def train(args):
    """Train a regressor or a CGAN on the donut dataset, testing after each epoch.

    Args:
        args: namespace providing ``model`` (``'regressor'`` or ``'gan'``),
            ``data_root``, ``batch_size`` and ``epochs``; it is also passed to
            the model constructor.

    Raises:
        NotImplementedError: if ``args.model`` is not a supported model name
            (a subclass of ``Exception``, so existing handlers still catch it).
    """
    # Validate the model choice before spending time building datasets.
    if args.model == 'regressor':
        model = Model(args)
    elif args.model == 'gan':
        model = CGAN(args)
    else:
        raise NotImplementedError('Not implemented: model %r' % (args.model,))

    train_loader = DataLoader(DonutDataset(root=args.data_root, is_train=True),
                              batch_size=args.batch_size, shuffle=True,
                              drop_last=True, num_workers=3)
    # Evaluation does not need random order; keep it deterministic.
    test_loader = DataLoader(DonutDataset(root=args.data_root, is_train=False),
                             batch_size=args.batch_size, shuffle=False,
                             drop_last=False, num_workers=3)

    for epoch in range(args.epochs):
        print("EPOCH: ", epoch)
        model.train_one_epoch(train_loader, epoch)
        model.test_one_epoch(test_loader, epoch)
class Solver(object):
    """Drive training and evaluation of a ``CGAN`` segmentation model.

    Owns the TF session, the ``Dataset`` and the model. It tracks the best
    ``auc_sum`` (AUC-ROC + AUC-PR) seen during training and checkpoints the
    model whenever that value improves.
    """

    def __init__(self, flags):
        run_config = tf.ConfigProto()
        # Allocate GPU memory on demand instead of reserving it all up front.
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.dataset = Dataset(self.flags.dataset, self.flags)
        self.model = CGAN(self.sess, self.flags, self.dataset.image_size)

        self.best_auc_sum = 0.
        self._make_folders()

        # Built after the model so the Saver covers all of its variables.
        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())
        tf_utils.show_all_variables()

    def _make_folders(self):
        """Create the output directories for the current mode.

        The model directory is always created. Test mode also creates the
        result-image and AUC directories; otherwise the sample directory is
        created.
        """
        self.model_out_dir = "{}/model_{}_{}_{}".format(
            self.flags.dataset, self.flags.discriminator,
            self.flags.train_interval, self.flags.batch_size)
        if not os.path.isdir(self.model_out_dir):
            os.makedirs(self.model_out_dir)

        if self.flags.is_test:
            self.img_out_dir = "{}/seg_result_{}_{}_{}".format(
                self.flags.dataset, self.flags.discriminator,
                self.flags.train_interval, self.flags.batch_size)
            self.auc_out_dir = "{}/auc_{}_{}_{}".format(
                self.flags.dataset, self.flags.discriminator,
                self.flags.train_interval, self.flags.batch_size)
            if not os.path.isdir(self.img_out_dir):
                os.makedirs(self.img_out_dir)
            if not os.path.isdir(self.auc_out_dir):
                os.makedirs(self.auc_out_dir)
        else:
            self.sample_out_dir = "{}/sample_{}_{}_{}".format(
                self.flags.dataset, self.flags.discriminator,
                self.flags.train_interval, self.flags.batch_size)
            if not os.path.isdir(self.sample_out_dir):
                os.makedirs(self.sample_out_dir)

    def train(self):
        """Alternate D and G training in chunks of ``train_interval`` iterations.

        Before each chunk a sample figure may be written (see ``sample``).
        After it, the model is evaluated on the validation set (see ``eval``)
        and saved when ``auc_sum`` beats the best value so far.
        """
        for iter_time in range(0, self.flags.iters + 1, self.flags.train_interval):
            self.sample(iter_time)  # sampling images and save them

            # train discriminator
            for iter_ in range(1, self.flags.train_interval + 1):
                x_imgs, y_imgs = self.dataset.train_next_batch(
                    batch_size=self.flags.batch_size)
                d_loss = self.model.train_dis(x_imgs, y_imgs)
                self.print_info(iter_time + iter_, 'd_loss', d_loss)

            # train generator
            for iter_ in range(1, self.flags.train_interval + 1):
                x_imgs, y_imgs = self.dataset.train_next_batch(
                    batch_size=self.flags.batch_size)
                g_loss = self.model.train_gen(x_imgs, y_imgs)
                self.print_info(iter_time + iter_, 'g_loss', g_loss)

            auc_sum = self.eval(iter_time, phase='train')
            if self.best_auc_sum < auc_sum:
                self.best_auc_sum = auc_sum
                self.save_model(iter_time)

    def test(self):
        """Load the latest checkpoint and, if that succeeds, evaluate on the test set."""
        if self.load_model():
            print(' [*] Load Success!\n')
            self.eval(phase='test')
        else:
            print(' [!] Load Failed!\n')

    def sample(self, iter_time):
        """Every ``sample_freq`` iterations, plot predictions for 2 random validation images."""
        if np.mod(iter_time, self.flags.sample_freq) == 0:
            idx = np.random.choice(self.dataset.num_val, 2, replace=False)
            x_imgs, y_imgs = self.dataset.val_imgs[idx], self.dataset.val_vessels[idx]
            samples = self.model.sample_imgs(x_imgs)

            # masking
            seg_samples = utils.remain_in_mask(samples, self.dataset.val_masks[idx])

            # crop to original image shape
            x_imgs_ = utils.crop_to_original(x_imgs, self.dataset.ori_shape)
            seg_samples_ = utils.crop_to_original(seg_samples, self.dataset.ori_shape)
            y_imgs_ = utils.crop_to_original(y_imgs, self.dataset.ori_shape)

            # sampling
            self.plot(x_imgs_, seg_samples_, y_imgs_, iter_time, idx=idx,
                      save_file=self.sample_out_dir, phase='train')

    def plot(self, x_imgs, samples, y_imgs, iter_time, idx=None, save_file=None,
             phase='train'):
        """Save a grid figure comparing inputs, predictions and targets.

        Each row is one image. The columns are the de-normalized input, the
        prediction and the target, with the last two stacked to 3 channels for
        display. In the test phase the prediction is also saved alone as an
        8-bit PNG named after the test image file.

        Args:
            x_imgs: normalized input images, one per row.
            samples: predicted maps (3-D: they are stacked on axis 3).
            y_imgs: target maps (3-D, like ``samples``).
            iter_time: iteration number, used in the train-phase file name.
            idx: dataset indices of the rows; used to look up per-image
                mean/std and, in the test phase, the output file name.
            save_file: output directory.
            phase: ``'train'`` (validation statistics) or ``'test'``.
        """
        # initialize grid size
        cell_size_h, cell_size_w = self.dataset.ori_shape[
            0] / 100, self.dataset.ori_shape[1] / 100
        num_columns, margin = 3, 0.05
        width = cell_size_w * num_columns
        height = cell_size_h * x_imgs.shape[0]
        fig = plt.figure(figsize=(width, height))  # (column, row)
        gs = gridspec.GridSpec(x_imgs.shape[0], num_columns)  # (row, column)
        gs.update(wspace=margin, hspace=margin)

        # convert from normalized to original image
        x_imgs_norm = np.zeros_like(x_imgs)
        std, mean = 0., 0.
        for i in range(x_imgs.shape[0]):
            if phase == 'train':
                std = self.dataset.val_mean_std[idx[i]]['std']
                mean = self.dataset.val_mean_std[idx[i]]['mean']
            elif phase == 'test':
                std = self.dataset.test_mean_std[idx[i]]['std']
                mean = self.dataset.test_mean_std[idx[i]]['mean']
            x_imgs_norm[i] = np.expand_dims(x_imgs[i], axis=0) * std + mean
        x_imgs_norm = x_imgs_norm.astype(np.uint8)

        # 1 channel to 3 channels
        samples_3 = np.stack((samples, samples, samples), axis=3)
        y_imgs_3 = np.stack((y_imgs, y_imgs, y_imgs), axis=3)

        imgs = [x_imgs_norm, samples_3, y_imgs_3]
        for col_index in range(len(imgs)):
            for row_index in range(x_imgs.shape[0]):
                ax = plt.subplot(gs[row_index * num_columns + col_index])
                plt.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect('equal')
                plt.imshow(imgs[col_index][row_index].reshape(
                    self.dataset.ori_shape[0], self.dataset.ori_shape[1], 3), cmap='Greys_r')

        if phase == 'train':
            plt.savefig(save_file + '/{}_{}.png'.format(str(iter_time), idx[0]),
                        bbox_inches='tight')
            plt.close(fig)
        else:
            # save compared image
            plt.savefig(os.path.join(save_file, 'compared_{}.png'.format(
                os.path.basename(self.dataset.test_img_files[idx[0]])[:-4])),
                bbox_inches='tight')
            plt.close(fig)

            # save vessel alone, vessel should be uint8 type
            Image.fromarray(np.squeeze(samples * 255).astype(np.uint8)).save(
                os.path.join(save_file, '{}.png'.format(
                    os.path.basename(self.dataset.test_img_files[idx[0]][:-4]))))

    def print_info(self, iter_time, name, loss):
        """Every ``print_freq`` iterations, print ``loss`` with run settings."""
        if np.mod(iter_time, self.flags.print_freq) == 0:
            ord_output = collections.OrderedDict([
                (name, loss),
                ('dataset', self.flags.dataset),
                ('discriminator', self.flags.discriminator),
                ('train_interval', np.float32(self.flags.train_interval)),
                ('gpu_index', self.flags.gpu_index),
            ])
            utils.print_metrics(iter_time, ord_output)

    def eval(self, iter_time=0, phase='train'):
        """Predict on the validation or test set and return ``auc_sum``.

        Runs only when ``iter_time`` is a multiple of ``eval_freq``; otherwise
        it returns 0. In the test phase it also saves one figure per image.
        """
        total_time, auc_sum = 0., 0.
        if np.mod(iter_time, self.flags.eval_freq) == 0:
            num_data, imgs, vessels, masks = None, None, None, None
            if phase == 'train':
                num_data = self.dataset.num_val
                imgs = self.dataset.val_imgs
                vessels = self.dataset.val_vessels
                masks = self.dataset.val_masks
            elif phase == 'test':
                num_data = self.dataset.num_test
                imgs = self.dataset.test_imgs
                vessels = self.dataset.test_vessels
                masks = self.dataset.test_masks

            generated = []
            for iter_ in range(num_data):
                x_img = imgs[iter_]
                x_img = np.expand_dims(x_img, axis=0)  # (H, W, C) to (1, H, W, C)

                # measure inference time
                start_time = time.time()
                generated_vessel = self.model.sample_imgs(x_img)
                total_time += (time.time() - start_time)

                generated.append(np.squeeze(generated_vessel, axis=(0, 3)))  # (1, H, W, 1) to (H, W)

            generated = np.asarray(generated)
            # calculate measurements
            auc_sum = self.measure(generated, vessels, masks, num_data, iter_time,
                                   phase, total_time)

            if phase == 'test':
                # save test images
                segmented_vessel = utils.remain_in_mask(generated, masks)

                # crop to original image shape
                imgs_ = utils.crop_to_original(imgs, self.dataset.ori_shape)
                cropped_vessel = utils.crop_to_original(segmented_vessel,
                                                        self.dataset.ori_shape)
                vessels_ = utils.crop_to_original(vessels, self.dataset.ori_shape)

                for idx in range(num_data):
                    self.plot(np.expand_dims(imgs_[idx], axis=0),
                              np.expand_dims(cropped_vessel[idx], axis=0),
                              np.expand_dims(vessels_[idx], axis=0),
                              'test', idx=[idx], save_file=self.img_out_dir,
                              phase='test')

        return auc_sum

    def measure(self, generated, vessels, masks, num_data, iter_time, phase,
                total_time):
        """Compute segmentation metrics, log them, and return AUC-ROC + AUC-PR."""
        # masking
        vessels_in_mask, generated_in_mask = utils.pixel_values_in_mask(
            vessels, generated, masks)

        # average processing time per image, in milliseconds
        avg_pt = (total_time / num_data) * 1000

        # evaluate Area Under the Curve of ROC and Precision-Recall
        auc_roc = utils.AUC_ROC(vessels_in_mask, generated_in_mask)
        auc_pr = utils.AUC_PR(vessels_in_mask, generated_in_mask)

        # binarize to calculate Dice Coefficient
        binarys_in_mask = utils.threshold_by_otsu(generated, masks)
        dice_coeff = utils.dice_coefficient_in_train(vessels_in_mask, binarys_in_mask)
        acc, sensitivity, specificity = utils.misc_measures(vessels_in_mask,
                                                            binarys_in_mask)
        score = auc_pr + auc_roc + dice_coeff + acc + sensitivity + specificity

        # auc_sum for saving best model in training
        auc_sum = auc_roc + auc_pr

        # print information
        ord_output = collections.OrderedDict([
            ('auc_pr', auc_pr), ('auc_roc', auc_roc), ('dice_coeff', dice_coeff),
            ('acc', acc), ('sensitivity', sensitivity),
            ('specificity', specificity), ('score', score),
            ('auc_sum', auc_sum), ('best_auc_sum', self.best_auc_sum),
            ('avg_pt', avg_pt),
        ])
        utils.print_metrics(iter_time, ord_output)

        # write in tensorboard when in train mode only
        if phase == 'train':
            self.model.measure_assign(auc_pr, auc_roc, dice_coeff, acc,
                                      sensitivity, specificity, score, iter_time)
        elif phase == 'test':
            # write in npy format for evaluation
            utils.save_obj(vessels_in_mask, generated_in_mask,
                           os.path.join(self.auc_out_dir, "auc_roc.npy"),
                           os.path.join(self.auc_out_dir, "auc_pr.npy"))

        return auc_sum

    def save_model(self, iter_time):
        """Store ``best_auc_sum`` in the graph and write a checkpoint."""
        self.model.best_auc_sum_assign(self.best_auc_sum)

        model_name = "iter_{}_auc_sum_{:.3}".format(iter_time, self.best_auc_sum)
        self.saver.save(self.sess, os.path.join(self.model_out_dir, model_name))

        print('===================================================')
        print(' Model saved! ')
        print(' Best auc_sum: {:.3}'.format(self.best_auc_sum))
        print('===================================================\n')

    def load_model(self):
        """Restore the latest checkpoint in ``model_out_dir``.

        Returns:
            True if a checkpoint was found and restored, False otherwise.
        """
        print(' [*] Reading checkpoint...')

        ckpt = tf.train.get_checkpoint_state(self.model_out_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(self.model_out_dir, ckpt_name))

            self.best_auc_sum = self.sess.run(self.model.best_auc_sum)
            print('====================================================')
            # Previously printed "Model saved!" here, although this path loads.
            print(' Model loaded! ')
            print(' Best auc_sum: {:.3}'.format(self.best_auc_sum))
            print('====================================================')
            return True
        else:
            return False
def run_gan():
    """Train a conditional GAN on MNIST and plot the loss histories.

    Images are reshaped to ``(N, 28, 28, 1)`` float32 and scaled to [-1, 1].
    Labels are converted with ``to_categorical`` and used as the condition.
    When ``model_test`` is truthy, the untrained generator and discriminator
    are exercised once on a single sample before training.
    """
    (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
    train_images = (train_images - 127.5) / 127.5  # Normalize images to [-1,1]
    print(train_images.shape)
    train_labels = to_categorical(train_labels)
    print(train_labels.shape)

    # Batch and shuffle the data
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, train_labels)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

    gan = CGAN(gen_lr, disc_lr, noise_dim=NOISE_DIM)
    gan.create_generator()
    gan.create_discriminator()

    if model_test:
        # Test generator
        random_noise = tf.random.normal([1, NOISE_DIM])
        condition = tf.zeros(shape=(1, 10))
        generated_image = gan.generator([random_noise, condition])
        plt.imshow(generated_image[0, :, :, 0], cmap='gray')
        plt.show()
        # Test discriminator
        prob = gan.discriminator([generated_image, condition])
        print("Probability of image being real: {}".format(sigmoid(prob)))

    gan.set_noise_seed(num_examples_to_generate)
    print(gan.label_seed.shape)
    gan.set_checkpoint(path=save_ckpt_path)
    gen_loss_array, disc_loss_array = gan.train(train_dataset, epochs=EPOCHS)

    # Plot generator and discriminator losses. Each curve uses its own length
    # so a history shorter than EPOCHS does not raise a length mismatch.
    plt.plot(range(len(gen_loss_array)), gen_loss_array, label='generator loss')
    plt.plot(range(len(disc_loss_array)), disc_loss_array, label='discriminator loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.show()
def train(evalset):
    """Train the CGAN and export FITS results for ``evalset`` at each checkpoint.

    Each batch runs two discriminator updates then one generator update.
    When ``conf.model_path`` is non-empty, the model is restored from it and
    the epoch counter resumes from ``<conf.save_path>/log``. Every
    ``conf.save_per_epoch`` epochs the following happens:

    - the model is saved to ``<conf.save_path>/model.ckpt``;
    - the next epoch number is written to the log;
    - each ``evalset`` sample is written as FITS to
      ``<conf.result_path>/epoch_<e>/fits_output/<name>-<conf.filter_>.fits``.

    Args:
        evalset: split to export, looked up as ``data[str(evalset)]``.
    """
    data = load_data()
    model = CGAN()
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver()
    counter = 0
    start_time = time.time()
    out_dir = conf.result_path
    filter_string = conf.filter_
    log_path = conf.save_path + "/log"
    if not os.path.exists(conf.save_path):
        os.makedirs(conf.save_path)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    start_epoch = 0
    with tf.Session() as sess:
        if conf.model_path == "":
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path)
            # A missing or malformed log means "start at epoch 0". Any other
            # error propagates instead of being swallowed by a bare except.
            # ``with`` closes the file even if int() raises.
            try:
                with open(log_path) as log:
                    start_epoch = int(log.readline())
            except (IOError, ValueError):
                pass
        for epoch in range(start_epoch, conf.max_epoch):
            for img, cond, _name in data["train"]():
                img, cond = prepocess_train(img, cond)
                feed = {model.image: img, model.cond: cond}
                # Two discriminator updates per generator update.
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, M, flux = sess.run([g_opt, model.g_loss, model.delta], feed_dict=feed)
                counter += 1
                print("Iterate [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f, flux: %.8f"
                      % (counter, time.time() - start_time, m, M, flux))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(sess, conf.save_path + "/model.ckpt")
                print("Model at epoch %s saved in file: %s" % (epoch + 1, save_path))
                with open(log_path, "w") as log:
                    log.write(str(epoch + 1))

                save_dir = '%s/epoch_%s/fits_output' % (out_dir, epoch + 1)
                for img, cond, name in data[str(evalset)]():
                    name = name.replace('-' + filter_string + '.npy', '')
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img,
                                       feed_dict={model.image: pimg, model.cond: pcond})
                    # Drop the leading batch axis (assumes a batch of one).
                    gen_img = gen_img.reshape(gen_img.shape[1:])
                    fits_recover = conf.unstretch(gen_img[:, :, 0])
                    hdu = fits.PrimaryHDU(fits_recover)
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir)
                    filename = '%s/%s-%s.fits' % (save_dir, name, filter_string)
                    if os.path.exists(filename):
                        os.remove(filename)
                    hdu.writeto(filename)
def train():
    """Train the CGAN and track mean test-set PSNR at each checkpoint.

    Each batch runs two discriminator updates then one generator update.
    Every ``conf.save_per_epoch`` epochs the following happens:

    - the model is saved to ``<conf.data_path>/checkpoint/model_<epoch>.ckpt``;
    - each test sample is written as ``[generated | condition]`` to
      ``<conf.output_path>/<name>``;
    - the mean PSNR of the generated image is recorded, once against the
      target image and once against the condition.

    The PSNR histories are printed and plotted after training.
    """
    data = load_data()
    model = CGAN()
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver()
    counter = 0
    start_time = time.time()
    if not os.path.exists(conf.data_path + "/checkpoint"):
        os.makedirs(conf.data_path + "/checkpoint")
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    mpsnr_img = []
    mpsnr_cond = []
    with tf.Session(config=config) as sess:
        if conf.model_path == "":
            # Replaces tf.initialize_all_variables, deprecated since TF 0.12.
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path)
        print(conf.max_epoch)
        for epoch in range(conf.max_epoch):
            for img, cond, _name in data["train"]():
                img, cond = prepocess_train(img, cond)
                feed = {model.image: img, model.cond: cond}
                # Two discriminator updates per generator update.
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, M = sess.run([g_opt, model.g_loss], feed_dict=feed)
                counter += 1
                print("Iterate [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (counter, time.time() - start_time, m, M))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(
                    sess, conf.data_path + "/checkpoint/" + "model_%d.ckpt" % (epoch + 1))
                print("Model saved in file: %s" % save_path)
                psnr_img_total = 0.
                psnr_cond_total = 0.
                num_test = 0
                for img, cond, name in data["test"]():
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img,
                                       feed_dict={model.image: pimg, model.cond: pcond})
                    # Drop the leading batch axis (assumes a batch of one).
                    gen_img = gen_img.reshape(gen_img.shape[1:])
                    # Map generator output (presumably in [-1, 1]) to [0, 255].
                    gen_img = (gen_img + 1.) * 127.5
                    # Clip before casting: a float->uint8 cast does not saturate,
                    # so out-of-range pixels would otherwise corrupt the PSNR.
                    gen_u8 = np.clip(gen_img, 0, 255).astype(np.uint8)
                    psnr_img_total += skimage.measure.compare_psnr(img, gen_u8)
                    psnr_cond_total += skimage.measure.compare_psnr(cond, gen_u8)
                    # ``np.int`` (alias of ``int``) was removed in NumPy 1.24.
                    image = np.concatenate((gen_img, cond), axis=1).astype(int)
                    num_test += 1
                    imsave(image, conf.output_path + "/%s" % name)
                # Guard against an empty test split (previously ZeroDivisionError).
                if num_test:
                    mpsnr_img.append(psnr_img_total / num_test)
                    mpsnr_cond.append(psnr_cond_total / num_test)
    print(mpsnr_cond)
    print(mpsnr_img)
    plt.plot(mpsnr_cond)
    plt.show()
def train():
    """Train the CGAN, resuming when ``conf.model_path`` is given.

    Each batch runs two discriminator updates then one generator update.
    Every ``conf.save_per_epoch`` epochs the following happens:

    - the model is saved to ``<conf.save_path>/model.ckpt``;
    - the next epoch number is written to ``<conf.save_path>/log``;
    - each test sample's ``[generated | condition]`` array is saved with
      ``np.save`` under ``conf.output_path``.
    """
    data = load_data()
    model = CGAN()
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver()
    counter = 0
    start_time = time.time()
    log_path = conf.save_path + "/log"
    if not os.path.exists(conf.save_path):
        os.makedirs(conf.save_path)
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)
    start_epoch = 0
    with tf.Session() as sess:
        if conf.model_path == "":
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path)
            # Resume the epoch counter only together with the weights. A stale
            # log must not skip epochs of a run that starts from scratch.
            # A missing or malformed log means "start at epoch 0".
            try:
                with open(log_path) as log:
                    start_epoch = int(log.readline())
            except (IOError, ValueError):
                pass
        for epoch in range(start_epoch, conf.max_epoch):
            for img, cond, _name in data["train"]():
                img, cond = prepocess_train(img, cond)
                feed = {model.image: img, model.cond: cond}
                # Two discriminator updates per generator update.
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, M, flux = sess.run([g_opt, model.g_loss, model.delta], feed_dict=feed)
                counter += 1
                print("Iterate [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f, flux: %.8f"
                      % (counter, time.time() - start_time, m, M, flux))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(sess, conf.save_path + "/model.ckpt")
                print("Model saved in file: %s" % save_path)
                with open(log_path, "w") as log:
                    log.write(str(epoch + 1))
                for img, cond, name in data["test"]():
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img,
                                       feed_dict={model.image: pimg, model.cond: pcond})
                    # Drop the leading batch axis (assumes a batch of one).
                    gen_img = gen_img.reshape(gen_img.shape[1:])
                    image = np.concatenate((gen_img, cond), axis=1)
                    np.save(conf.output_path + "/" + name, image)
def cvt_output(model_output):
    """Convert the first image of a generator output batch to an RGB numpy array.

    Presumably ``model_output`` is laid out as (N, C, H, W) with values in
    [-1, 1] and BGR channel order; confirm against ``CGAN``. The result is the
    first item transposed to (H, W, C), mapped by ``0.5 * x + 0.5`` and
    converted BGR -> RGB.
    """
    # detach() is the supported replacement for ``.data``; cpu() is a no-op for
    # CPU tensors and keeps this helper working for CUDA outputs too.
    img = model_output.detach().cpu().numpy()[0]
    img = np.transpose(img, (1, 2, 0))
    img = 0.5 * img + 0.5
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


# NOTE(review): torch.load and pickle.load unpickle arbitrary objects; only run
# this script on model/vocab files from a trusted source.
state = torch.load('model.tar', map_location=lambda storage, loc: storage)
config = state['config']
model = CGAN(config)
model.load_state_dict(state['state_dict'])
model.eval()

with open('vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

# Each line: text after the first comma is split on whitespace; token 0 is
# used as the hair value and token 2 as the eyes value (presumably
# "<id>,<color> hair <color> eyes" -- confirm against the input file).
conditions = []
with open(sys.argv[1]) as f:
    for line in f:
        if not line.strip():
            continue  # blank or trailing lines previously raised IndexError
        tokens = line.split(',')[1].split()
        conditions.append({'hair': [tokens[0]], 'eyes': [tokens[2]]})

generated_imgs = []
def train():
    """Train the CGAN and export numbered test composites at each checkpoint.

    Each batch runs two discriminator updates then one generator update.
    Every ``conf.save_per_epoch`` epochs the following happens:

    - the model is saved to ``<conf.data_path>/checkpoint/model_<epoch>.ckpt``;
    - each test pair is written as ``[generated | condition]`` to
      ``<conf.output_path>/<k>.jpg``, with ``k`` counting from 1.

    NOTE(review): ``data["train"]`` and ``data["test"]`` are iterated directly,
    not called. They presumably are re-iterable sequences of ``(img, cond)``
    pairs; a one-shot generator would yield nothing after the first epoch.
    Confirm with ``load_data``.
    """
    data = load_data()
    model = CGAN()
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)
    saver = tf.train.Saver()
    counter = 0
    start_time = time.time()
    if not os.path.exists(conf.data_path + "/checkpoint"):
        os.makedirs(conf.data_path + "/checkpoint")
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)
    with tf.Session() as sess:
        if conf.model_path == "":
            # Replaces tf.initialize_all_variables, deprecated since TF 0.12.
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path)
        for epoch in range(conf.max_epoch):
            for img, cond in data["train"]:
                img, cond = prepocess_train(img, cond)
                feed = {model.image: img, model.cond: cond}
                # Two discriminator updates per generator update.
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, m = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, M = sess.run([g_opt, model.g_loss], feed_dict=feed)
                counter += 1
                print("Iterate [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (counter, time.time() - start_time, m, M))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(
                    sess, conf.data_path + "/checkpoint/" + "model_%d.ckpt" % (epoch + 1))
                print("Model saved in file: %s" % save_path)
                test_count = 0
                for img, cond in data["test"]:
                    test_count += 1
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img,
                                       feed_dict={model.image: pimg, model.cond: pcond})
                    # Drop the leading batch axis (assumes a batch of one).
                    gen_img = gen_img.reshape(gen_img.shape[1:])
                    # Map generator output (presumably in [-1, 1]) to [0, 255].
                    gen_img = (gen_img + 1.) * 127.5
                    # ``np.int`` (alias of ``int``) was removed in NumPy 1.24.
                    image = np.concatenate((gen_img, cond), axis=1).astype(int)
                    imsave(image, conf.output_path + "/%d.jpg" % test_count)