def main(_):
    """Create the output folders, build an SRCNN model and call ``srcnn.mode()``.

    The positional argument is ignored. The TensorFlow session is always
    closed when this function returns, even if model construction or
    ``srcnn.mode()`` raises.
    """
    for folder in (FLAGS.checkpoint_dir, FLAGS.sample_dir, FLAGS.summaries_dir):
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(folder, exist_ok=True)

    sess = tf.InteractiveSession()
    try:
        srcnn = SRCNN(sess,
                      imageSize=FLAGS.image_size,
                      labelSize=FLAGS.label_size,
                      batchSize=FLAGS.batch_size,
                      channel=FLAGS.input_channel,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)
        srcnn.mode()
    finally:
        # The original never closed the session, leaking its resources.
        sess.close()
def main(_):
    """Dump the flag table, make sure the output folders exist, then train SRCNN.

    The positional argument is ignored.
    """
    pp.pprint(flags.FLAGS.__flags)

    for folder in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if os.path.exists(folder):
            continue
        os.makedirs(folder)

    model_kwargs = {
        'image_size': FLAGS.image_size,
        'label_size': FLAGS.label_size,
        'batch_size': FLAGS.batch_size,
        'c_dim': FLAGS.c_dim,
        'checkpoint_dir': FLAGS.checkpoint_dir,
        'sample_dir': FLAGS.sample_dir,
    }
    with tf.Session() as session:
        model = SRCNN(session, **model_kwargs)
        model.train(FLAGS)
def main(_):
    """Print the flags, create the output folders and train SRCNN.

    The session is created with ``gpu_options.allow_growth`` enabled.
    The positional argument is ignored.
    """
    def ensure_dir(path):
        # Create *path* only when nothing exists there yet.
        if not os.path.exists(path):
            os.makedirs(path)

    pp.pprint(flags.FLAGS.__flags)
    ensure_dir(FLAGS.checkpoint_dir)
    ensure_dir(FLAGS.sample_dir)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    with tf.Session(config=session_config) as session:
        model = SRCNN(
            session,
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            batch_size=FLAGS.batch_size,
            c_dim=FLAGS.c_dim,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        model.train(FLAGS)
def main(_):
    """Print the flags, create the output folders and train SRCNN on GPU 0.

    The positional argument is ignored. The model is built under
    ``tf.device('/gpu:0')`` in a session with ``allow_soft_placement=True``,
    and ``FLAGS`` is also passed to the model as ``config``.
    """
    pp.pprint(flags.FLAGS.__flags)
    for folder in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(folder, exist_ok=True)

    # Default to the first GPU, but respect a device selection the user has
    # already made in the environment instead of silently overriding it.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.device('/gpu:0'), tf.Session(config=config) as sess:
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir,
                      config=FLAGS)
        srcnn.train(FLAGS)
def main(_):
    """Print the flags, create the output folders, then train or run inference.

    The positional argument is ignored. ``FLAGS`` is also handed to the model
    through its ``FLAGS`` keyword, and ``call()`` is invoked before training
    or inference.
    """
    pp.pprint(flags.FLAGS.__flags)
    for folder in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(folder):
            os.makedirs(folder)

    model = SRCNN(
        image_size=FLAGS.image_size,
        label_size=FLAGS.label_size,
        batch_size=FLAGS.batch_size,
        c_dim=FLAGS.c_dim,
        checkpoint_dir=FLAGS.checkpoint_dir,
        sample_dir=FLAGS.sample_dir,
        FLAGS=FLAGS,
    )
    model.call()

    # Compared with ``== True`` (not plain truthiness), exactly as before.
    run = model.train if FLAGS.is_train == True else model.inference  # noqa: E712
    run()
def main(args):
    """Run a pretrained SRCNN on the test set and write the results to ./result.

    For every test image ``i`` four BMP files are written:
    ``imageNN_original.bmp`` (``X_pre_test[i]``), ``imageNN_input.bmp``
    (``X_test[i]``), ``imageNN_answer.bmp`` (``Y_test[i]``) and
    ``imageNN_predicted.bmp`` (the network output).

    Args:
        args: namespace providing ``image_size``, ``c_dim`` and ``scale``.

    Raises:
        IOError: if OpenCV fails to write one of the output files.
    """
    srcnn = SRCNN(image_size=args.image_size, c_dim=args.c_dim, is_training=False)
    X_pre_test, X_test, Y_test = load_test(scale=args.scale)

    predicted_list = []
    for img in X_test:
        # Feed one single-channel image as a batch of size 1, then drop the
        # batch axis from the output.
        predicted = srcnn.process(img.reshape(1, img.shape[0], img.shape[1], 1))
        predicted_list.append(
            predicted.reshape(predicted.shape[1], predicted.shape[2], 1))

    dirname = './result'
    # cv2.imwrite reports failure (e.g. a missing folder) by returning False
    # rather than raising, so the folder must exist and the result be checked.
    os.makedirs(dirname, exist_ok=True)
    for i, predicted in enumerate(predicted_list):
        imgname = 'image{:02}'.format(i)
        outputs = (
            ('_original.bmp', X_pre_test[i]),
            ('_input.bmp', X_test[i]),
            ('_answer.bmp', Y_test[i]),
            ('_predicted.bmp', predicted),
        )
        for suffix, image in outputs:
            path = os.path.join(dirname, imgname + suffix)
            if not cv2.imwrite(path, image):
                raise IOError('failed to write %s' % path)
def test():
    """Build HDF5 test data for Set5/Set14, restore the model and print its accuracy.

    Reads ``test_dir``, ``test_h5_dir``, ``test_stride`` and ``model_dir``
    from ``flags``. For each test set the images under ``test_dir/<set>`` are
    converted by ``data_helper.gen_input_image`` into ``test_h5_dir/<set>``,
    and ``data.h5`` from that folder is evaluated with ``model.test``.

    Raises:
        ValueError: if no checkpoint is found in ``flags.model_dir``.
    """
    set_names = ('Set5', 'Set14')

    print("process the image to h5file.....")
    test_dir = flags.test_dir
    test_h5_dir = flags.test_h5_dir
    stride = flags.test_stride
    os.makedirs(test_h5_dir, exist_ok=True)
    h5_dirs = {}
    for name in set_names:
        h5_dirs[name] = os.path.join(test_h5_dir, name)
        data_helper.gen_input_image(os.path.join(test_dir, name), h5_dirs[name], stride)

    print("initialize the model......")
    model_dir = flags.model_dir
    model = SRCNN(flags)
    model.build_graph()
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        print("model info didn't exist!")
        raise ValueError("no model checkpoint found in %r" % (model_dir,))
    saver.restore(model.sess, ckpt.model_checkpoint_path)

    for name in set_names:
        print("test in %s......" % name)
        test_h5_path = os.path.join(h5_dirs[name], "data.h5")
        data, label = data_helper.load_data(test_h5_path)
        accu = model.test(data, label)
        # The original passed the format string and the value as separate
        # print() arguments, so "%.5f" was never applied.
        print("the accuracy in %s is %.5f" % (name, accu))
def train(training_data, dev_data, args):
    """Train SRCNN with Adam on an RMSE loss, reporting summed losses per epoch.

    Args:
        training_data: dataset used for optimisation (batched by 2).
        dev_data: dataset evaluated after every epoch without updating weights
            (batched by 2).
        args: namespace providing ``lr`` (learning rate), ``ep`` (number of
            epochs) and ``o`` (passed to ``save_model`` every 50 epochs,
            starting at epoch 0).

    Returns:
        The trained model, wrapped in ``nn.DataParallel`` when more than one
        CUDA device is visible.
    """
    training_gen = data.DataLoader(training_data, batch_size=2)
    dev_gen = data.DataLoader(dev_data, batch_size=2)
    device = torch.device('cuda' if cuda.is_available() else 'cpu')

    print('Initializing model')
    model = SRCNN()
    loss = RMSE()
    n_devices = cuda.device_count()
    if n_devices > 1:
        print('Using %d CUDA devices' % n_devices)
        model = nn.DataParallel(model, device_ids=list(range(n_devices)))
    model.to(device)
    loss.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    def _run_epoch(loader, opt=True):
        """Return the summed loss over *loader*; step the optimizer only if *opt*."""
        # Train mode while optimising, eval mode for the dev pass.
        model.train(opt)
        total = 0
        # The dev pass needs no autograd graph; skipping it saves memory.
        with torch.set_grad_enabled(opt):
            for y, x in loader:
                y, x = y.to(device), x.to(device)
                pred_y = model(x)
                batch_loss = loss(pred_y, y)
                total += batch_loss.item()
                if opt:
                    optimizer.zero_grad()
                    batch_loss.backward()
                    optimizer.step()
        # torch.cuda.synchronize() fails on machines without CUDA, which this
        # function otherwise supports via the CPU device.
        if device.type == 'cuda':
            cuda.synchronize()
        return total

    print('Training')
    for ep in range(args.ep):
        train_loss = _run_epoch(training_gen)
        dev_loss = _run_epoch(dev_gen, opt=False)
        print_flush('Epoch %d: Train %.4f Dev %.4f' % (ep, train_loss, dev_loss))
        if ep % 50 == 0:
            save_model(model, args.o)
    return model
def main(_):
    """Print the TensorFlow setup, prepare folders, then train or test SRCNN.

    The positional argument is ignored. Side effect: ``FLAGS.TB_dir`` gets
    ``'_' + str(FLAGS.c_dim)`` appended before the folders are created.
    Dispatch: ``is_train`` -> ``train()``, else ``patch_test`` -> ``test()``,
    else ``test_whole_img()``.
    """
    # Print configurations.
    print('tf version:', tf.__version__)
    print('tf setup:')
    for flag_name, flag_value in FLAGS.flag_values_dict().items():
        print(flag_name, flag_value)
    FLAGS.TB_dir = FLAGS.TB_dir + '_' + str(FLAGS.c_dim)

    # Check/create folders.
    for folder in (FLAGS.checkpoint_dir, FLAGS.TB_dir):
        if not os.path.exists(folder):
            os.makedirs(folder)

    # Build the model inside a session and run the selected action.
    with tf.Session() as sess:
        model = SRCNN(sess, FLAGS)
        if FLAGS.is_train:
            action = model.train
        elif FLAGS.patch_test:
            action = model.test
        else:
            action = model.test_whole_img
        action()
def main():
    """Load saved SRCNN weights and display a thresholded similarity map.

    Takes the first batch of the test loader and runs the blurred images
    through the model. It then computes the cosine similarity along ``dim=1``
    between the sharp images and the model output. The first image's map is
    shown thresholded at its mean: 0 where similarity >= mean, 1 otherwise.

    Uses the free names ``batch_size`` and ``epoch``, which are defined
    outside this function.
    """
    dataloaders = myDataloader()
    test_loader = dataloaders.getTestLoader(batch_size)
    model = SRCNN().cuda()
    model.load_state_dict(
        torch.load("./result/train/" + str(epoch) + "srcnnParms.pth"))
    model.eval()
    with torch.no_grad():
        # Only the first batch is visualised (the original looped and broke
        # immediately); an empty loader shows nothing.
        first_batch = next(iter(test_loader), None)
        if first_batch is None:
            return
        pic, blurPic, index = first_batch
        pic = pic.cuda()
        blurPic = blurPic.cuda()
        out = model(blurPic)
        res = torch.cosine_similarity(pic, out, dim=1)[0]
        mean_value = res.mean()
        # Same values as the original ``1 - ((res >= mean) + 0)``.
        output = 1 - (res >= mean_value).long()
        plt.figure()
        plt.title(index)
        plt.imshow(output.cpu(), cmap="gray")
        plt.show()
def upscaling():
    """Upscale a stored picture with SRCNN and register the result.

    Reads a JSON request body containing two fields:

    - ``times``: the integer upscale factor.
    - ``picname``: a bare file name inside ``FLAGS.sample_dir`` that also
      exists in the ``Picture`` table.

    The new picture is named ``<name before first dot><times>x_.<suffix>``
    and saved as a child of the source picture.

    Returns a JSON response whose ``code`` field is one of:

    - 200 on success.
    - 400 for a malformed or already-processed request.
    - 404 when the source picture is not in the database.
    """
    try:
        data = json.loads(request.get_data(as_text=True))
        times = str(data['times'])
        picname = data['picname']
    except (ValueError, KeyError, TypeError):
        return jsonify(code=400, message="Invalid request: 'times' and 'picname' are required")

    try:
        scale = int(times)
    except ValueError:
        scale = 0
    if scale < 1:
        return jsonify(code=400, message="'times' must be a positive integer")

    # picname comes from the client and is joined into a filesystem path
    # below, so only a bare file name is accepted (no directory traversal).
    if (not isinstance(picname, str) or picname in ('', '.', '..')
            or os.path.basename(picname) != picname):
        return jsonify(code=400, message="Invalid picture name")

    # Text before the first dot (the whole name when there is no dot; the
    # original picname.find('.') returned -1 there and dropped the last char).
    stem = picname.split('.', 1)[0]

    # Reject a duplicate request. LIKE wildcards are escaped so that '_' and
    # '%' (including the literal "x_" marker) only match themselves.
    timespic = stem + times + "x_"
    pattern = (timespic.replace('/', '//').replace('%', '/%').replace('_', '/_')
               + '%')
    picture = Picture.query.filter(Picture.name.like(pattern, escape='/')).first()
    if picture is not None:
        return jsonify(code=400, message="The picture has been processed")

    # Path of the source image on disk.
    path = os.path.join(os.getcwd(), FLAGS.sample_dir, picname)
    print(path)

    picture = Picture.query.filter_by(name=picname).first()
    if picture is None:
        return jsonify(code=404, message="The picture does not exist")

    # Name of the upscaled picture.
    picname = stem + times + 'x_.' + picture.suffix
    # URL: the source picture's folder plus the new name up to its first '_'.
    # NOTE(review): if the source name itself contains '_', this prefix is cut
    # short — confirm what the URL is meant to point to.
    url = picture.url
    url = url[0:url.rfind('/')] + '/' + picname[0:picname.find('_') + 1]

    # Upscale the picture.
    with tf.Session() as sess:
        srcnn = SRCNN(sess, checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir)
        srcnn.upscaling(picname, path, FLAGS, scale)

    # Save to the database.
    action = 'Upscale_' + times + 'X'
    newpic = Picture(picname, url, action, picture.id)
    db.session.add(newpic)
    db.session.flush()
    db.session.commit()
    return jsonify(code=200, message="success upscaling", name=picname,
                   id=newpic.id, url=url, action=action)
def main(_):
    """Print the flags, create the output folders and train SRCNN on GPU 0.

    The model is built under ``tf.device('/device:GPU:0')`` in a session with
    ``allow_soft_placement=True``. The positional argument is ignored.
    """
    pp.pprint(flags.FLAGS.__flags)

    for out_dir in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.device('/device:GPU:0'), tf.Session(config=session_config) as session:
        network = SRCNN(
            session,
            image_size=FLAGS.image_size,
            label_size=FLAGS.label_size,
            batch_size=FLAGS.batch_size,
            ci_dim=FLAGS.ci_dim,
            co_dim=FLAGS.co_dim,
            scale_factor=FLAGS.scale_factor,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            model_carac=FLAGS.model_carac,
            train_dir=FLAGS.train_dir,
            test_dir=FLAGS.test_dir,
        )
        network.train(FLAGS)
os.makedirs('checkpoint_superresolution') checkpoint_path = "./checkpoint_superresolution/"+"trial"+str(trialNumber)+"checkpoint.pt" elif args.model == 'deep_residual_network_SRCNN': model = Net_Superresolution(withRedNet=False,withSRCNN=True) dataloaders = get_CUB200_loader() if not os.path.exists('checkpoint_superresolution'): os.makedirs('checkpoint_superresolution') checkpoint_path = "./checkpoint_superresolution/"+"trial"+str(trialNumber)+"checkpoint.pt" elif args.model == 'rednet': model = REDNet_model() if not os.path.exists('checkpoint_rednet'): os.makedirs('checkpoint_rednet') checkpoint_path = "./checkpoint_rednet/"+"trial"+str(trialNumber)+"checkpoint.pt" dataloaders = get_CUB200_loader(three_channel_bayer=True) elif args.model == 'SRCNN': SRCNN_model = SRCNN(num_channels=3) model = model_with_upsampling(SRCNN_model) if not os.path.exists('checkpoint_SRCNN'): os.makedirs('checkpoint_SRCNN') checkpoint_path = "./checkpoint_SRCNN/"+"trial"+str(trialNumber)+"checkpoint.pt" if args.not_cropped: dataloaders = get_CUB200_loader(three_channel_bayer=True,no_crop=True,batch_size=8) else: dataloaders = get_CUB200_loader(three_channel_bayer=True) elif args.model == 'VDSR': VDSR_model = VDSR_Net() model = model_with_upsampling(VDSR_model) if not os.path.exists('checkpoint_VDSR'): os.makedirs('checkpoint_VDSR') checkpoint_path = "./checkpoint_VDSR/"+"trial"+str(trialNumber)+"checkpoint.pt"
parser.add_argument('--input', type=str, required=True) parser.add_argument('--scale', type=float, default=2.) return vars(parser.parse_args()) if __name__ == '__main__': args = get_arguments() weight = args.get('weight') p = args.get('input') upscale = args.get('scale') dirpath = os.path.dirname(p) input_filename = os.path.basename(p) gen = SRCNN() gen.load_state_dict(torch.load(weight)) gen.eval() img = Image.open(p) new_size = [int(x * upscale) for x in img.size] converter = T.Compose([ T.Resize(size=new_size[::-1], interpolation=Image.BICUBIC), T.ToTensor() ]) with torch.no_grad(): x = converter(img) pred = gen(x[None, :, :, :])[0]
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` is wrong for argparse because ``bool('False')`` is True;
    this converter keeps the ``--is_train True`` / ``--is_train False``
    syntax working with the intended meaning.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))


def _build_arg_parser():
    """Return the command-line parser (defaults are appended by the formatter)."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--epochs', type=int, default=55,
                        help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='The size of batch images')
    parser.add_argument('--image_size', type=int, default=33,
                        help='The size of image to use')
    parser.add_argument('--label_size', type=int, default=33,
                        help='The size of label')
    # float, not int: int('0.001') raises, so any CLI override used to crash.
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='The learning rate of gradient descent algorithm')
    parser.add_argument('--c_dim', type=int, default=3,
                        help='Dimension of image color')
    parser.add_argument('--scale', type=int, default=3,
                        help='The size of scale factor for preprocessing input image')
    parser.add_argument('--stride', type=int, default=14,
                        help='The size of stride to apply input image')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Name of checkpoint directory')
    parser.add_argument('--sample_dir', type=str, default='sample',
                        help='Name of sample directory')
    parser.add_argument('--is_train', type=_str2bool, default=False,
                        help='True for training, False for testing')
    return parser


def main():
    """Parse the command line, then train SRCNN or reconstruct the test image.

    With ``--is_train True`` the model is trained with Adam on ``train.h5``
    from the checkpoint folder. The loss of the last batch is printed every
    5 epochs. Otherwise the patches in ``test.h5`` are run through the model,
    merged back into one image and saved as ``<sample_dir>/test_image.png``.
    """
    args = _build_arg_parser().parse_args()
    for folder in (args.checkpoint_dir, args.sample_dir):
        os.makedirs(folder, exist_ok=True)

    srcnn = SRCNN(image_size=args.image_size, label_size=args.label_size,
                  batch_size=args.batch_size, c_dim=args.c_dim)
    # Adam optimizer (the original comment mislabelled it as SGD).
    optimizer = tf.keras.optimizers.Adam(args.learning_rate)

    def run_optimization(x, y):
        """Take one optimizer step on the batch (x, y) using the MSE loss."""
        # Record the forward pass so the tape can differentiate the loss.
        with tf.GradientTape() as g:
            pred = srcnn(x, is_training=True)
            loss = mse(pred, y)
        trainable_variables = srcnn.trainable_variables
        gradients = g.gradient(loss, trainable_variables)
        optimizer.apply_gradients(zip(gradients, trainable_variables))

    def train(args):
        """Run training or testing according to ``args.is_train``."""
        if not args.is_train:
            nx, ny = input_setup(args)
            print("Testing...")
            data_dir = os.path.join('./{}'.format(args.checkpoint_dir), "test.h5")
            test_data, _ = read_data(data_dir)
            result = srcnn(test_data)
            # Stitch the predicted patches back into a single image.
            result = merge(result, [nx, ny])
            result = result.squeeze()
            image_path = os.path.join(os.getcwd(), args.sample_dir, "test_image.png")
            print(result.shape)
            imsave(result, image_path)
            return

        input_setup(args)
        print("Training...")
        data_dir = os.path.join('./{}'.format(args.checkpoint_dir), "train.h5")
        train_data, train_label = read_data(data_dir)
        display_step = 5
        batch_size = args.batch_size
        batch_idxs = len(train_data) // batch_size
        if batch_idxs == 0:
            # Otherwise no batch would ever be built and the loss report
            # below would fail on an unbound name.
            raise ValueError('training set has %d samples, fewer than batch_size=%d'
                             % (len(train_data), batch_size))
        for step in range(args.epochs):
            for idx in range(batch_idxs):
                batch_images = train_data[idx * batch_size:(idx + 1) * batch_size]
                batch_labels = train_label[idx * batch_size:(idx + 1) * batch_size]
                run_optimization(batch_images, batch_labels)
            # Report the loss on the last batch every display_step epochs.
            if step % display_step == 0:
                pred = srcnn(batch_images)
                loss = mse(pred, batch_labels)
                print("step: %i, loss: %f" % (step, loss))

    train(args)
args = parser.parse_args() image_size = args.Image_patch_size #image_patch_size stride = args.Stride scale = args.Scale learning_rate = args.lr batch_size = args.BATCH_SIZE epochs = args.Epochs is_training =args.is_training dirname_train = args.dirname_train dirname_test = args.dirname_test if is_training: X_train,Y_train = load_train(image_size = image_size,stride = stride,scale = scale,dirname =dirname_train) srcnn = SRCNN( image_size = image_size, learning_rate=learning_rate) optimizer = Adam(lr = learning_rate) srcnn.compile(optimizer=optimizer, loss='mean_squared_error') history = srcnn.fit(X_train, Y_train, batch_size=batch_size, epochs=epochs, verbose=2,validation_split = 0.1 ) draw_loss_plot(history = history) #Saving trained model_weights in the current workspace . #make a folder named srcnn in your current workspace you are working in.In that folder your weights will be stored. json_string = srcnn.to_json() open(os.path.join('./srcnn/','srcnn_model.json'),'w').write(json_string) srcnn.save_weights(os.path.join('./srcnn/','srcnn_weight.hdf5')) else: dir_list = os.listdir(dirname_test) for img in dir_list:
# torch.manual_seed(opt.seed) # if use_cuda: # torch.cuda.manual_seed(opt.seed) train_set = get_training_set(opt.upscale_factor) test_set = get_test_set(opt.upscale_factor) training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True) testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False) srcnn = SRCNN() criterion = nn.MSELoss() if (use_cuda): torch.cuda.set_device(opt.gpuid) srcnn.cuda() criterion = criterion.cuda() optimizer = optim.SGD(srcnn.parameters(), lr=opt.lr) #optimizer = optim.Adam(srcnn.parameters(),lr=opt.lr) def train(epoch): epoch_loss = 0 for iteration, batch in enumerate(training_data_loader, 1): input, target = Variable(batch[0]), Variable(batch[1])
transform = transforms.Compose([transforms.ToTensor()]) train_dataset = SR_dataset( lr_path = lr_path, hr_path = hr_path, transform = transform, interpolation_mode=Config.interpolation_mode, interpolation_scale=Config.interpolation_scale ) train_loader = DataLoader( train_dataset, batch_size = Config.batch_size, shuffle =True ) model = SRCNN().to(DEVICE) optimizer =torch.optim.Adam(model.parameters(), lr = Config.lr) epochs = Config.epochs model.train() for epoch in range(epochs): print("{}/{} EPOCHS".format(epoch+1, epochs)) for x,y in tqdm(train_loader): x = x.to(DEVICE) y = y.to(DEVICE) pred = model(x) loss = torch.nn.functional.mse_loss(pred, y)
from model import SRCNN from dataset import DatasetFromFolder, DatasetFromFolderEval import argparse parser = argparse.ArgumentParser(description='predictionCNN Example') parser.add_argument('--cuda', action='store_true', default=False) parser.add_argument('--weight_path', type=str, default=None) parser.add_argument('--save_dir', type=str, default=None) opt = parser.parse_args() test_set = DatasetFromFolderEval(image_dir='./data/General-100/test', scale_factor=4) test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False) model = SRCNN() criterion = nn.MSELoss() if opt.cuda: model = model.cuda() criterion = criterion.cuda() model.load_state_dict( torch.load(opt.weight_path, map_location='cuda' if opt.cuda else 'cpu')) model.eval() total_loss, total_psnr = 0, 0 total_loss_b, total_psnr_b = 0, 0 with torch.no_grad(): for batch in test_loader: inputs, targets = batch[0], batch[1] if opt.cuda:
return nx,ny def merge(images, size): print(images.shape) h, w = images.shape[1], images.shape[2] #觉得下标应该是0,1 #h, w = images.shape[0], images.shape[1] img = np.zeros([h*size[0], w*size[1], 1]) print(img.shape) j = 0 k = 0 for i in range(13): while(j<24): img[(i*21):((i + 1) * 21), (j*21):((j + 1) * 21), :] = images[k, 0:21, 0:21, :] j += 1 k += 1 if(j == 24): j = 0 print(k) return img def make_data(data, label): savepath = os.path.join(os.getcwd(),'checkpoint\\test.h5') with h5py.File(savepath,'w')as hf: hf.create_dataset('data', data=data) hf.create_dataset('label', data=label) if __name__ == '__main__': with tf.Session() as sess: srcnn = SRCNN(sess) test(srcnn,sess)
def test(self, mode, inference):
    """Evaluate a restored SRCNN or VDSR model on the grayscale test set.

    For each scale ``j`` in 2, 3, 4, the i-th image (in sorted order) of
    ``./dataset/test/X{j}`` is paired with the i-th label of
    ``./dataset/test/gray``. The label is cropped to a multiple of ``j``.
    Per-image PSNR and the scale's average PSNR are printed.

    Args:
        mode: 'SRCNN' or 'VDSR'; selects the network built on ``self.x``.
        inference: when true, also save the restored images (and, for VDSR,
            the residuals) under ``./restored_{mode}/{self.date}/``.

    Raises:
        ValueError: if ``mode`` is neither 'SRCNN' nor 'VDSR'.
    """
    import os  # local import: used to create the output folder

    # images = low resolution, labels = high resolution
    sess = self.sess
    test_label_list = sorted(glob.glob('./dataset/test/gray/*.*'))
    num_image = len(test_label_list)

    if mode == 'SRCNN':
        sr_model = SRCNN(channel_length=self.c_length, image=self.x)
        _, _, prediction = sr_model.build_model()
    elif mode == 'VDSR':
        sr_model = VDSR(channel_length=self.c_length, image=self.x)
        prediction, residual, _ = sr_model.build_model()
    else:
        # An assert would be stripped under ``python -O``.
        raise ValueError("mode must be 'SRCNN' or 'VDSR', got %r" % (mode,))

    with tf.name_scope("PSNR"):
        # PSNR = 10 * log10(255^2 / MSE), expressed with natural logarithms.
        psnr = 10 * tf.log(255 * 255 * tf.reciprocal(
            tf.reduce_mean(tf.square(self.y - prediction)))) / tf.log(
                tf.constant(10, dtype='float32'))

    init = tf.global_variables_initializer()
    sess.run(init)
    saver = tf.train.Saver()
    saver.restore(sess, self.save_path)

    if inference:
        os.makedirs('./restored_{0}/{1}'.format(mode, self.date), exist_ok=True)

    for j in range(2, 5):
        # Hoisted out of the per-image loop: the listing depends only on j.
        test_image_list = sorted(glob.glob('./dataset/test/X{}/*.*'.format(j)))
        avg_psnr = 0
        for i in range(num_image):
            test_image = np.array(Image.open(test_image_list[i]))
            test_image = test_image[np.newaxis, :, :, np.newaxis]
            test_label = np.array(Image.open(test_label_list[i]))
            # Crop the label so both sides are multiples of the scale factor.
            h = test_label.shape[0]
            w = test_label.shape[1]
            h -= h % j
            w -= w % j
            test_label = test_label[np.newaxis, 0:h, 0:w, np.newaxis]
            feed = {self.x: test_image, self.y: test_label}

            final_psnr = sess.run(psnr, feed_dict=feed)
            print('X{} : Test PSNR is '.format(j), final_psnr)
            avg_psnr += final_psnr

            if inference:
                pred = sess.run(prediction, feed_dict=feed)
                # Clip before the cast: astype('uint8') wraps out-of-range
                # values (e.g. 256 -> 0) into speckle artefacts.
                pred = np.clip(np.squeeze(pred), 0, 255).astype(dtype='uint8')
                pred_image = Image.fromarray(pred)
                filename = './restored_{0}/{3}/{1}_X{2}.png'.format(
                    mode, i, j, self.date)
                pred_image.save(filename)
                if mode == 'VDSR':
                    res = sess.run(residual, feed_dict=feed)
                    # NOTE(review): negative residual values still wrap in
                    # this cast — confirm the intended visualisation.
                    res = np.squeeze(res).astype(dtype='uint8')
                    res_image = Image.fromarray(res)
                    filename = './restored_{0}/{3}/{1}_X{2}_res.png'.format(
                        mode, i, j, self.date)
                    res_image.save(filename)
        # Average over the actual number of test images (was hard-coded to 5);
        # max() keeps an empty test set from dividing by zero.
        print('X{} : Avg PSNR is '.format(j), avg_psnr / max(num_image, 1))
model = FSRCNN(num_channels=3, upscale_factor=4) if opt.model == "FSRCNN" and opt.upscale == 4: model = FSRCNN(num_channels=3, upscale_factor=4) if opt.model == "FALSR_A" or opt.model == "FALSR_B": if opt.upscale is not 2: raise ("ONLY SUPPORT 2X") else: if opt.model == "FALSR_A": model = FALSR_A() if opt.model == "FALSR_B": model = FALSR_B() if opt.model == "SRCNN" and opt.upscale == 4: model = SRCNN(num_channels=3, upscale_factor=4) if opt.model == "VDSR" and opt.upscale == 4: model = VDSR(num_channels=3, base_channels=3, num_residual=20) if opt.model == "ESPCN" and opt.upscale == 4: model = ESPCN(num_channels=3, feature=64, upscale_factor=4) if opt.criterion: if opt.criterion == "l1": criterion = nn.L1Loss() if opt.criterion == "l2": criterion = nn.MSELoss() if opt.criterion == "custom": pass
parser.add_argument('--cuda', action='store_true', help='whether to use cuda') args = parser.parse_args() if args.cuda and not torch.cuda.is_available(): raise Exception('No GPU found') device = torch.device('cuda' if args.cuda else 'cpu') print('Use device:', device) filenames = os.listdir(args.img_dir) image_filenames = [os.path.join(args.img_dir, x) for x in filenames \ if is_image_file(x)] image_filenames = sorted(image_filenames) model = SRCNN().to(device) if args.cuda: ckpt = torch.load(args.model) else: ckpt = torch.load(args.model, map_location='cpu') model.load_state_dict(ckpt['model']) res = {} for i, f in enumerate(image_filenames): # Read test image. img = Image.open(f).convert('RGB') width, height = img.size[0], img.size[1] # Crop test image so that it has size that can be downsampled by the upscale factor. pad_width = width % args.upscale_factor
self.checkpoint_dir = "checkpoint1" self.learning_rate = 1e-4 self.batch_size = 128 self.result_dir = 'result' self.test_img = '' # Do not change this. arg = this_config() print( "Hello TA! We are group 7. Thank you for your work for us. Hope you have a happy day!" ) with tf.Session() as sess: FLAGS = arg srcnn = SRCNN(sess, image_size=FLAGS.image_size, label_size=FLAGS.label_size, c_dim=FLAGS.c_dim) srcnn.train(FLAGS) # Testing files = glob.glob(os.path.join(os.getcwd(), 'train_set', 'LR0', '*.jpg')) test_files = random.sample(files, len(files) // 5) FLAGS.is_train = False count = 1 for f in test_files: FLAGS.test_img = f print('Saving ', count, '/', len(test_files), ': ', FLAGS.test_img, '\n') count += 1 srcnn.test(FLAGS)