Example #1
def main(_):
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)
    if not os.path.exists(FLAGS.summaries_dir):
        os.makedirs(FLAGS.summaries_dir)
    sess = tf.InteractiveSession()

    srcnn = SRCNN(sess,
                  imageSize=FLAGS.image_size,
                  labelSize=FLAGS.label_size,
                  batchSize=FLAGS.batch_size,
                  channel=FLAGS.input_channel,
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)

    srcnn.mode()
Example #2
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session() as sess:
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)

        srcnn.train(FLAGS)
Example #3
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        srcnn = SRCNN(sess,
                      image_size=FLAGS.image_size,
                      label_size=FLAGS.label_size,
                      batch_size=FLAGS.batch_size,
                      c_dim=FLAGS.c_dim,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)

        srcnn.train(FLAGS)
Example #4
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.device('/gpu:0'):
        with tf.Session(config=config) as sess:
            srcnn = SRCNN(sess,
                          image_size=FLAGS.image_size,
                          label_size=FLAGS.label_size,
                          batch_size=FLAGS.batch_size,
                          c_dim=FLAGS.c_dim,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          sample_dir=FLAGS.sample_dir,
                          config=FLAGS)
            srcnn.train(FLAGS)
Example #5
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    srcnn = SRCNN(image_size=FLAGS.image_size,
                  label_size=FLAGS.label_size,
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim,
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir,
                  FLAGS=FLAGS)
    srcnn.call()
    if FLAGS.is_train:
        srcnn.train()
    else:
        srcnn.inference()
Example #6
def main(args):
    srcnn = SRCNN(image_size=args.image_size,
                  c_dim=args.c_dim,
                  is_training=False)
    X_pre_test, X_test, Y_test = load_test(scale=args.scale)
    predicted_list = []
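    # Super-resolve each test image individually; inputs are reshaped to a (1, H, W, 1) single-channel batch.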
    for img in X_test:
        predicted = srcnn.process(img.reshape(1, img.shape[0], img.shape[1],
                                              1))
        predicted_list.append(
            predicted.reshape(predicted.shape[1], predicted.shape[2], 1))
    n_img = len(predicted_list)
    dirname = './result'
    for i in range(n_img):
        imgname = 'image{:02}'.format(i)
        cv2.imwrite(os.path.join(dirname, imgname + '_original.bmp'),
                    X_pre_test[i])
        cv2.imwrite(os.path.join(dirname, imgname + '_input.bmp'), X_test[i])
        cv2.imwrite(os.path.join(dirname, imgname + '_answer.bmp'), Y_test[i])
        cv2.imwrite(os.path.join(dirname, imgname + '_predicted.bmp'),
                    predicted_list[i])
Example #7
def test():
    print("process the image to h5file.....")
    test_dir = flags.test_dir
    test_h5_dir = flags.test_h5_dir
    stride = flags.test_stride
    if not os.path.exists(test_h5_dir):
        os.makedirs(test_h5_dir)

    test_set5 = os.path.join(test_dir, 'Set5')
    test_set14 = os.path.join(test_dir, 'Set14')
    path_set5 = os.path.join(test_h5_dir, 'Set5')
    path_set14 = os.path.join(test_h5_dir, 'Set14')
    data_helper.gen_input_image(test_set5, path_set5, stride)
    data_helper.gen_input_image(test_set14, path_set14, stride)

    print("initialize the model......")
    model_dir = flags.model_dir
    model = SRCNN(flags)
    model.build_graph()
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(model.sess, ckpt.model_checkpoint_path)
    else:
        print("model checkpoint does not exist!")
        raise ValueError("no checkpoint found in {}".format(model_dir))

    print("test in Set5......")
    test_h5_path = os.path.join(path_set5, "data.h5")
    data_set5, label_set5 = data_helper.load_data(test_h5_path)
    accu = model.test(data_set5, label_set5)
    print("the accuracy in Set5 is %.5f" % accu)

    print("test in Set14......")
    test_h5_path = os.path.join(path_set14, "data.h5")
    data_set14, label_set14 = data_helper.load_data(test_h5_path)
    accu2 = model.test(data_set14, label_set14)
    print("the accuracy in Set14 is %.5f" % accu2)
Example #8
def train(training_data, dev_data, args):
    training_gen = data.DataLoader(training_data, batch_size=2)
    dev_gen = data.DataLoader(dev_data, batch_size=2)
    device = torch.device('cuda' if cuda.is_available() else 'cpu')
    print('Initializing model')
    model = SRCNN()
    loss = RMSE()
    if cuda.device_count() > 1:
        print('Using %d CUDA devices' % cuda.device_count())
        model = nn.DataParallel(
            model, device_ids=[i for i in range(cuda.device_count())])
    model.to(device)
    loss.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

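    # Run one pass over `data`: accumulate the loss and, when opt=True, backpropagate and step the optimizer.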
    def _train(data, opt=True):
        total = 0
        for y, x in data:
            y, x = y.to(device), x.to(device)
            pred_y = model(x)
            l = loss(pred_y, y)
            total += l.item()
            if opt:
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
        cuda.synchronize()
        return total

    print('Training')
    for ep in range(args.ep):
        train_loss = _train(training_gen)
        dev_loss = _train(dev_gen, opt=False)
        print_flush('Epoch %d: Train %.4f Dev %.4f' %
                    (ep, train_loss, dev_loss))
        if ep % 50 == 0:
            save_model(model, args.o)
    return model
Example #9
def main(_):
    """3.print configurations"""
    print('tf version:', tf.__version__)
    print('tf setup:')
    for k, v in FLAGS.flag_values_dict().items():
        print(k, v)
    FLAGS.TB_dir += '_' + str(FLAGS.c_dim)
    """4.check/create folders"""
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.TB_dir):
        os.makedirs(FLAGS.TB_dir)
    """5.begin tf session"""
    with tf.Session() as sess:
        """6.init srcnn model"""
        srcnn = SRCNN(sess, FLAGS)
        """7.start to train/test"""
        if FLAGS.is_train:
            srcnn.train()
        elif FLAGS.patch_test:
            srcnn.test()
        else:
            srcnn.test_whole_img()
Example #10
def main():
    dataloaders = myDataloader()
    test_loader = dataloaders.getTestLoader(batch_size)

    model = SRCNN().cuda()
    model.load_state_dict(
        torch.load("./result/train/" + str(epoch) + "srcnnParms.pth"))
    model.eval()
    with torch.no_grad():
        for i, (pic, blurPic, index) in enumerate(test_loader):
            pic = pic.cuda()
            blurPic = blurPic.cuda()
            out = model(blurPic)
            res = torch.cosine_similarity(pic, out, dim=1)
            res = res[0]
            minValue = torch.min(res)
            meanValue = res.mean()
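            # Binarise the cosine-similarity map: pixels below the mean similarity become 1, the rest 0.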
            output = 1 - ((res >= meanValue) + 0)
            plt.figure()
            plt.title(index)
            plt.imshow(output.cpu(), cmap="gray")
            plt.show()
            break
Example #11
def upscaling():
    data = json.loads(request.get_data(as_text=True))
    times = data['times']
    picname = data['picname']
    # Check whether this picture has already been processed
    timespic = picname[0:picname.find('.')] + times + "x_"
    picture = Picture.query.filter(Picture.name.like(timespic + "%")).first()
    if picture is not None:
        return jsonify(code=400, message="The picture has been processed")
    # Path to the picture
    path = os.path.join(os.getcwd(), FLAGS.sample_dir, picname)
    print(path)
    # Build the new file name
    picture = Picture.query.filter_by(name=picname).first()
    picname = picname[0:picname.find('.')] + times + 'x_.' + picture.suffix
    # url
    url = picture.url
    url = url[0:url.rfind('/')] + '/' + picname[0:picname.find('_') + 1]
    # Upscale the image
    with tf.Session() as sess:
        srcnn = SRCNN(sess,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)
        srcnn.upscaling(picname, path, FLAGS, int(times))
    # Save to the database
    action = 'Upscale_' + times + 'X'
    newpic = Picture(picname, url, action, picture.id)
    db.session.add(newpic)
    db.session.flush()
    db.session.commit()

    return jsonify(code=200,
                   message="success upscaling",
                   name=picname,
                   id=newpic.id,
                   url=url,
                   action=action)
Example #12
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)
    config = tf.ConfigProto(allow_soft_placement=True)

    with tf.device('/device:GPU:0'):
        with tf.Session(config=config) as sess:
            srcnn = SRCNN(sess,
                          image_size=FLAGS.image_size,
                          label_size=FLAGS.label_size,
                          batch_size=FLAGS.batch_size,
                          ci_dim=FLAGS.ci_dim,
                          co_dim=FLAGS.co_dim,
                          scale_factor=FLAGS.scale_factor,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          sample_dir=FLAGS.sample_dir,
                          model_carac=FLAGS.model_carac,
                          train_dir=FLAGS.train_dir,
                          test_dir=FLAGS.test_dir)
            srcnn.train(FLAGS)
Example #13
        os.makedirs('checkpoint_superresolution')
    checkpoint_path = "./checkpoint_superresolution/" + "trial" + str(trialNumber) + "checkpoint.pt"
elif args.model == 'deep_residual_network_SRCNN':
    model = Net_Superresolution(withRedNet=False, withSRCNN=True)
    dataloaders = get_CUB200_loader()
    if not os.path.exists('checkpoint_superresolution'):
        os.makedirs('checkpoint_superresolution')
    checkpoint_path = "./checkpoint_superresolution/" + "trial" + str(trialNumber) + "checkpoint.pt"
elif args.model == 'rednet':
    model = REDNet_model()
    if not os.path.exists('checkpoint_rednet'):
        os.makedirs('checkpoint_rednet')
    checkpoint_path = "./checkpoint_rednet/" + "trial" + str(trialNumber) + "checkpoint.pt"
    dataloaders = get_CUB200_loader(three_channel_bayer=True)
elif args.model == 'SRCNN':
    SRCNN_model = SRCNN(num_channels=3)
    model = model_with_upsampling(SRCNN_model)
    if not os.path.exists('checkpoint_SRCNN'):
        os.makedirs('checkpoint_SRCNN')
    checkpoint_path = "./checkpoint_SRCNN/" + "trial" + str(trialNumber) + "checkpoint.pt"
    if args.not_cropped:
        dataloaders = get_CUB200_loader(three_channel_bayer=True, no_crop=True, batch_size=8)
    else:
        dataloaders = get_CUB200_loader(three_channel_bayer=True)

elif args.model == 'VDSR':
    VDSR_model = VDSR_Net()
    model = model_with_upsampling(VDSR_model)
    if not os.path.exists('checkpoint_VDSR'):
        os.makedirs('checkpoint_VDSR')
    checkpoint_path = "./checkpoint_VDSR/" + "trial" + str(trialNumber) + "checkpoint.pt"
Example #14
    parser.add_argument('--input', type=str, required=True)
    parser.add_argument('--scale', type=float, default=2.)
    return vars(parser.parse_args())


if __name__ == '__main__':
    args = get_arguments()

    weight = args.get('weight')
    p = args.get('input')
    upscale = args.get('scale')

    dirpath = os.path.dirname(p)
    input_filename = os.path.basename(p)

    gen = SRCNN()
    gen.load_state_dict(torch.load(weight))
    gen.eval()

    img = Image.open(p)
    new_size = [int(x * upscale) for x in img.size]

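    # Bicubically upsample the input to the target size first; the SRCNN then refines the interpolated image.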
    converter = T.Compose([
        T.Resize(size=new_size[::-1], interpolation=Image.BICUBIC),
        T.ToTensor()
    ])

    with torch.no_grad():
        x = converter(img)
        pred = gen(x[None, :, :, :])[0]
Example #15
def main():

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--epochs',
                        type=int,
                        default=55,
                        help='Number of epoch [15000]')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='The size of batch images [128]')
    parser.add_argument('--image_size',
                        type=int,
                        default=33,
                        help='The size of image to use [33]')
    parser.add_argument('--label_size',
                        type=int,
                        default=33,
                        help='The size of label [33]')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=1e-4,
        help='The learning rate of gradient descent algorithm [1e-4]')
    parser.add_argument('--c_dim',
                        type=int,
                        default=3,
                        help='Dimension of image color [3]')
    parser.add_argument(
        '--scale',
        type=int,
        default=3,
        help='The size of scale factor for preprocessing input image [3]')
    parser.add_argument('--stride',
                        type=int,
                        default=14,
                        help='The size of stride to apply input image [14]')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='checkpoint',
                        help='Name of checkpoint directory [checkpoint]')
    parser.add_argument('--sample_dir',
                        type=str,
                        default='sample',
                        help='Name of sample directory [sample]')
    parser.add_argument('--is_train',
                        action='store_true',
                        help='Pass this flag to train; omit it to test [False]')

    args = parser.parse_args()

    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    if not os.path.exists(args.sample_dir):
        os.makedirs(args.sample_dir)

    srcnn = SRCNN(image_size=args.image_size,
                  label_size=args.label_size,
                  batch_size=args.batch_size,
                  c_dim=args.c_dim)

    # Stochastic gradient descent optimizer.
    optimizer = tf.keras.optimizers.Adam(args.learning_rate)

    # Optimization process.
    def run_optimization(x, y):
        # Wrap the forward pass inside a GradientTape for automatic differentiation.
        with tf.GradientTape() as g:
            # Forward pass.
            pred = srcnn(x, is_training=True)
            # Compute loss.
            loss = mse(pred, y)
        # Variables to update, i.e. trainable variables.
        trainable_variables = srcnn.trainable_variables
        # Compute gradients outside the tape context.
        gradients = g.gradient(loss, trainable_variables)
        # Update the weights by following the gradients.
        optimizer.apply_gradients(zip(gradients, trainable_variables))

    def train(args):

        if args.is_train:
            input_setup(args)
        else:
            nx, ny = input_setup(args)

        counter = 0
        start_time = time.time()

        if args.is_train:
            print("Training...")
            data_dir = os.path.join('./{}'.format(args.checkpoint_dir),
                                    "train.h5")
            train_data, train_label = read_data(data_dir)

            display_step = 5
            for step in range(args.epochs):
                batch_idxs = len(train_data) // args.batch_size

                for idx in range(0, batch_idxs):

                    batch_images = train_data[idx * args.batch_size:(idx + 1) *
                                              args.batch_size]
                    batch_labels = train_label[idx *
                                               args.batch_size:(idx + 1) *
                                               args.batch_size]
                    run_optimization(batch_images, batch_labels)

                    if step % display_step == 0:
                        pred = srcnn(batch_images)
                        loss = mse(pred, batch_labels)
                        #psnr_loss = psnr(batch_labels, pred)
                        #acc = accuracy(pred, batch_y)

                        #print("step: %i, loss: %f", "psnr_loss: %f" %(step, loss, psnr_loss))
                        #print("Step:'{0}', Loss:'{1}', PSNR: '{2}'".format(step, loss, psnr_loss))

                        print("step: %i, loss: %f" % (step, loss))

        else:
            print("Testing...")
            data_dir = os.path.join('./{}'.format(args.checkpoint_dir),
                                    "test.h5")
            test_data, test_label = read_data(data_dir)

            result = srcnn(test_data)
            result = merge(result, [nx, ny])
            result = result.squeeze()

            image_path = os.path.join(os.getcwd(), args.sample_dir)
            image_path = os.path.join(image_path, "test_image.png")
            print(result.shape)
            imsave(result, image_path)

    train(args)
Example #16
args = parser.parse_args()

image_size = args.Image_patch_size  # image patch size
stride = args.Stride
scale = args.Scale
learning_rate = args.lr
batch_size = args.BATCH_SIZE
epochs = args.Epochs
is_training = args.is_training
dirname_train = args.dirname_train
dirname_test = args.dirname_test

if is_training:

    X_train, Y_train = load_train(image_size=image_size, stride=stride,
                                   scale=scale, dirname=dirname_train)
    srcnn = SRCNN(image_size=image_size, learning_rate=learning_rate)
    optimizer = Adam(lr=learning_rate)
    srcnn.compile(optimizer=optimizer, loss='mean_squared_error')
    history = srcnn.fit(X_train, Y_train, batch_size=batch_size,
                        epochs=epochs, verbose=2, validation_split=0.1)
    draw_loss_plot(history=history)

    # Save the trained model and weights in the current workspace.
    # Create a folder named "srcnn" in your current workspace; the model JSON and weights will be stored there.
    json_string = srcnn.to_json()
    open(os.path.join('./srcnn/','srcnn_model.json'),'w').write(json_string)
    srcnn.save_weights(os.path.join('./srcnn/','srcnn_weight.hdf5'))

else:

    dir_list = os.listdir(dirname_test)
    for img in dir_list:
Example #17
# torch.manual_seed(opt.seed)
# if use_cuda:
#     torch.cuda.manual_seed(opt.seed)

train_set = get_training_set(opt.upscale_factor)
test_set = get_test_set(opt.upscale_factor)
training_data_loader = DataLoader(dataset=train_set,
                                  num_workers=opt.threads,
                                  batch_size=opt.batch_size,
                                  shuffle=True)
testing_data_loader = DataLoader(dataset=test_set,
                                 num_workers=opt.threads,
                                 batch_size=opt.test_batch_size,
                                 shuffle=False)

srcnn = SRCNN()
criterion = nn.MSELoss()

if use_cuda:
    torch.cuda.set_device(opt.gpuid)
    srcnn.cuda()
    criterion = criterion.cuda()

optimizer = optim.SGD(srcnn.parameters(), lr=opt.lr)
#optimizer = optim.Adam(srcnn.parameters(),lr=opt.lr)


def train(epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(training_data_loader, 1):
        input, target = Variable(batch[0]), Variable(batch[1])
Example #18
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = SR_dataset(
    lr_path=lr_path,
    hr_path=hr_path,
    transform=transform,
    interpolation_mode=Config.interpolation_mode,
    interpolation_scale=Config.interpolation_scale
)

train_loader = DataLoader(
    train_dataset,
    batch_size=Config.batch_size,
    shuffle=True
)

model = SRCNN().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=Config.lr)

epochs = Config.epochs
model.train()
for epoch in range(epochs):
    print("{}/{} EPOCHS".format(epoch + 1, epochs))
    for x, y in tqdm(train_loader):
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        pred = model(x)

        loss = torch.nn.functional.mse_loss(pred, y)
Example #19
from model import SRCNN
from dataset import DatasetFromFolder, DatasetFromFolderEval

import argparse
parser = argparse.ArgumentParser(description='predictionCNN Example')
parser.add_argument('--cuda', action='store_true', default=False)
parser.add_argument('--weight_path', type=str, default=None)
parser.add_argument('--save_dir', type=str, default=None)
opt = parser.parse_args()

test_set = DatasetFromFolderEval(image_dir='./data/General-100/test',
                                 scale_factor=4)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

model = SRCNN()
criterion = nn.MSELoss()
if opt.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

model.load_state_dict(
    torch.load(opt.weight_path, map_location='cuda' if opt.cuda else 'cpu'))

model.eval()
total_loss, total_psnr = 0, 0
total_loss_b, total_psnr_b = 0, 0
with torch.no_grad():
    for batch in test_loader:
        inputs, targets = batch[0], batch[1]
        if opt.cuda:
Example #20
    return nx,ny

def merge(images, size):
    print(images.shape)
    h, w = images.shape[1], images.shape[2]  # I think the indices should be 0 and 1
    # h, w = images.shape[0], images.shape[1]
    img = np.zeros([h * size[0], w * size[1], 1])
    print(img.shape)
    j = 0
    k = 0
    # Stitch the 21x21 patches back into a 13x24 grid (sizes hard-coded for this test image).
    for i in range(13):
        while j < 24:
            img[(i * 21):((i + 1) * 21), (j * 21):((j + 1) * 21), :] = images[k, 0:21, 0:21, :]
            j += 1
            k += 1
        if j == 24:
            j = 0
    print(k)
    return img

def make_data(data, label):
    savepath = os.path.join(os.getcwd(), 'checkpoint', 'test.h5')
    with h5py.File(savepath, 'w') as hf:
        hf.create_dataset('data', data=data)
        hf.create_dataset('label', data=label)

if __name__ == '__main__':
    with tf.Session() as sess:
        srcnn = SRCNN(sess)
        test(srcnn, sess)
Example #21
    def test(self, mode, inference):
        # images = low resolution, labels = high resolution
        sess = self.sess

        # for training a particular image(one image)
        test_label_list = sorted(glob.glob('./dataset/test/gray/*.*'))

        num_image = len(test_label_list)

        assert mode == 'SRCNN' or mode == 'VDSR'
        if mode == 'SRCNN':
            sr_model = SRCNN(channel_length=self.c_length, image=self.x)
            _, _, prediction = sr_model.build_model()
        elif mode == 'VDSR':
            sr_model = VDSR(channel_length=self.c_length, image=self.x)
            prediction, residual, _ = sr_model.build_model()

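        # PSNR = 10 * log10(255^2 / MSE) between the ground-truth label and the prediction.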
        with tf.name_scope("PSNR"):
            psnr = 10 * tf.log(255 * 255 * tf.reciprocal(
                tf.reduce_mean(tf.square(self.y - prediction)))) / tf.log(
                    tf.constant(10, dtype='float32'))

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver()

        saver.restore(sess, self.save_path)

        for j in range(2, 5):
            avg_psnr = 0
            for i in range(num_image):
                test_image_list = sorted(
                    glob.glob('./dataset/test/X{}/*.*'.format(j)))
                test_image = np.array(Image.open(test_image_list[i]))
                test_image = test_image[np.newaxis, :, :, np.newaxis]
                test_label = np.array(Image.open(test_label_list[i]))
                h = test_label.shape[0]
                w = test_label.shape[1]
                h -= h % j
                w -= w % j
                test_label = test_label[np.newaxis, 0:h, 0:w, np.newaxis]
                # print(test_image.shape, test_label.shape)

                final_psnr = sess.run(psnr,
                                      feed_dict={
                                          self.x: test_image,
                                          self.y: test_label
                                      })

                print('X{} : Test PSNR is '.format(j), final_psnr)
                avg_psnr += final_psnr

                if inference:
                    pred = sess.run(prediction,
                                    feed_dict={
                                        self.x: test_image,
                                        self.y: test_label
                                    })
                    pred = np.squeeze(pred).astype(dtype='uint8')
                    pred_image = Image.fromarray(pred)
                    filename = './restored_{0}/{3}/{1}_X{2}.png'.format(
                        mode, i, j, self.date)
                    pred_image.save(filename)
                    if mode == 'VDSR':
                        res = sess.run(residual,
                                       feed_dict={
                                           self.x: test_image,
                                           self.y: test_label
                                       })
                        res = np.squeeze(res).astype(dtype='uint8')
                        res_image = Image.fromarray(res)
                        filename = './restored_{0}/{3}/{1}_X{2}_res.png'.format(
                            mode, i, j, self.date)
                        res_image.save(filename)

            print('X{} : Avg PSNR is '.format(j), avg_psnr / num_image)
Example #22
        model = FSRCNN(num_channels=3, upscale_factor=4)

    if opt.model == "FSRCNN" and opt.upscale == 4:
        model = FSRCNN(num_channels=3, upscale_factor=4)

    if opt.model == "FALSR_A" or opt.model == "FALSR_B":
        if opt.upscale != 2:
            raise ValueError("ONLY SUPPORT 2X")
        else:
            if opt.model == "FALSR_A":
                model = FALSR_A()
            if opt.model == "FALSR_B":
                model = FALSR_B()

    if opt.model == "SRCNN" and opt.upscale == 4:
        model = SRCNN(num_channels=3, upscale_factor=4)

    if opt.model == "VDSR" and opt.upscale == 4:
        model = VDSR(num_channels=3, base_channels=3, num_residual=20)

    if opt.model == "ESPCN" and opt.upscale == 4:
        model = ESPCN(num_channels=3, feature=64, upscale_factor=4)

if opt.criterion:
    if opt.criterion == "l1":
        criterion = nn.L1Loss()
    if opt.criterion == "l2":
        criterion = nn.MSELoss()
    if opt.criterion == "custom":
        pass
Example #23
    parser.add_argument('--cuda',
                        action='store_true',
                        help='whether to use cuda')
    args = parser.parse_args()

    if args.cuda and not torch.cuda.is_available():
        raise Exception('No GPU found')
    device = torch.device('cuda' if args.cuda else 'cpu')
    print('Use device:', device)

    filenames = os.listdir(args.img_dir)
    image_filenames = [os.path.join(args.img_dir, x) for x in filenames \
                       if is_image_file(x)]
    image_filenames = sorted(image_filenames)

    model = SRCNN().to(device)
    if args.cuda:
        ckpt = torch.load(args.model)
    else:
        ckpt = torch.load(args.model, map_location='cpu')
    model.load_state_dict(ckpt['model'])

    res = {}

    for i, f in enumerate(image_filenames):
        # Read test image.
        img = Image.open(f).convert('RGB')
        width, height = img.size[0], img.size[1]

        # Crop test image so that it has size that can be downsampled by the upscale factor.
        pad_width = width % args.upscale_factor
Example #24
        self.checkpoint_dir = "checkpoint1"
        self.learning_rate = 1e-4
        self.batch_size = 128
        self.result_dir = 'result'
        self.test_img = ''  # Do not change this.


arg = this_config()
print(
    "Hello TA!  We are group 7. Thank you for your work for us. Hope you have a happy day!"
)

with tf.Session() as sess:
    FLAGS = arg
    srcnn = SRCNN(sess,
                  image_size=FLAGS.image_size,
                  label_size=FLAGS.label_size,
                  c_dim=FLAGS.c_dim)
    srcnn.train(FLAGS)

    # Testing
    files = glob.glob(os.path.join(os.getcwd(), 'train_set', 'LR0', '*.jpg'))
    test_files = random.sample(files, len(files) // 5)

    FLAGS.is_train = False
    count = 1
    for f in test_files:
        FLAGS.test_img = f
        print('Saving ', count, '/', len(test_files), ': ', FLAGS.test_img,
              '\n')
        count += 1
        srcnn.test(FLAGS)