Example 1
def evaluate_net():
    create_dir(args.result_png_path)
    print('Testing path is %s' % args.blur_src_path)
    blurred_src_file_list = sorted(glob.glob(args.blur_src_path + '/*.png'))
    gt_src_file_list = sorted(glob.glob(args.gt_src_path + '/*.png'))

    if args.gt_src_path:
        psnr = np.zeros(len(gt_src_file_list))
        ssim = np.zeros(len(gt_src_file_list))
    # at most 96 patch inference times are recorded per test image
    test_time = np.zeros(len(gt_src_file_list) * 96)

    # Build model
    input_channel, output_channel = 5, 3

    model = make_model(input_channel, output_channel, args)

    if torch.cuda.is_available():
        model_dict = torch.load(args.ckpt_dir_test +
                                '/model_%04d_dict.pth' % args.epoch_test)
        model.load_state_dict(model_dict)
        model = model.cuda()
        print('Finished loading the model from epoch %d' % args.epoch_test)
    else:
        print('No CUDA devices are available!')

    model.eval()

    #=================#
    for index in range(len(gt_src_file_list)):
        out_patch_list = []
        img_name = os.path.split(gt_src_file_list[index])[-1].split('.')[0]

        # read the image
        gt_img = cv2.imread(gt_src_file_list[index])
        gt_img = gt_img[..., ::-1]
        gt_img = np.asarray(gt_img / 255, np.float64)
        in_img = cv2.imread(blurred_src_file_list[index])
        in_img = in_img[..., ::-1]
        in_img = np.asarray(in_img / 255, np.float64)

        # add noise
        if args.sigma:
            noise = np.random.normal(loc=0,
                                     scale=args.sigma / 255.0,
                                     size=in_img.shape)
            in_img = in_img + noise
            in_img = np.clip(in_img, 0.0, 1.0)

        # compute field
        in_img_wz_fld = compute_fld_info(in_img)
        [h, w, c] = in_img_wz_fld.shape
        padded_in_img_wz_fld = np.pad(in_img_wz_fld,
                                      ((50, 50), (50, 50), (0, 0)), 'edge')
        # crop_patch
        patch_list = crop_patch(padded_in_img_wz_fld,
                                patch_size=500,
                                pad_size=100)
        # run the model on each cropped patch
        print('process img: %s' % blurred_src_file_list[index])
        for i in range(len(patch_list)):
            in_patch = patch_list[i].copy()
            in_patch = transforms.functional.to_tensor(in_patch)
            in_patch = in_patch.unsqueeze_(0).float()
            if torch.cuda.is_available():
                in_patch = in_patch.cuda()

            torch.cuda.synchronize()
            start_time = time.time()
            with torch.no_grad():
                out_patch = model(in_patch)
            torch.cuda.synchronize()
            test_time[index * 96 + i] = time.time() - start_time

            rgb_patch = out_patch.cpu().detach().numpy().transpose(
                (0, 2, 3, 1))
            rgb_patch = np.clip(rgb_patch[0], 0, 1)
            out_patch_list.append(rgb_patch)

        rgb = sew_up_img(out_patch_list,
                         patch_size=500,
                         pad_size=100,
                         img_size=[3000, 4000])

        # compare psnr and ssim
        psnr[index] = compare_psnr(gt_img, rgb)
        ssim[index] = compare_ssim(gt_img, rgb, multichannel=True)
        # save image
        rgb = rgb[..., ::-1]
        cv2.imwrite(args.result_png_path + '/' + img_name + ".png",
                    np.uint8(rgb * 255))
        print('test image: %s saved!' % img_name)

    #===========
    # print per-image psnr, ssim and inference time; the first image is
    # excluded from the average to discard CUDA warm-up
    test_time_avr = 0
    for i in range(len(gt_src_file_list)):
        img_time = test_time[i * 96:(i + 1) * 96].sum()
        print('src_file: %s:' %
              (os.path.split(gt_src_file_list[i])[-1].split('.')[0]))
        if args.gt_src_path:
            print('psnr: %f, ssim: %f, time: %f' %
                  (psnr[i], ssim[i], img_time))

        if i > 0:
            test_time_avr += img_time

    test_time_avr = test_time_avr / (len(gt_src_file_list) - 1)
    print('average time: %f' % test_time_avr)
    # save the psnr, ssim information
    result_txt_path = args.result_png_path + '/' + "test_result.txt"
    test_info_generator(gt_src_file_list, psnr, ssim, test_time, test_time_avr,
                        result_txt_path)
    return 0
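Example 1 depends on `crop_patch` and `sew_up_img` helpers that are not shown. The sketch below is a minimal, hypothetical implementation of that overlap-tiled inference pattern, assuming each tile carries a `patch_size` core plus `pad_size // 2` pixels of context on every side and that the image dimensions are multiples of `patch_size` (as with the 3000x4000 images above); the repository's actual helpers may differ.

import numpy as np

def crop_patch(padded_img, patch_size=500, pad_size=100):
    """Hypothetical sketch: split a pre-padded image into overlapping tiles,
    each covering a patch_size x patch_size core plus pad_size // 2 pixels
    of context on every side (the caller already edge-padded the image)."""
    h = padded_img.shape[0] - pad_size
    w = padded_img.shape[1] - pad_size
    patches = []
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            patches.append(padded_img[top:top + patch_size + pad_size,
                                      left:left + patch_size + pad_size])
    return patches

def sew_up_img(patches, patch_size=500, pad_size=100, img_size=(3000, 4000)):
    """Hypothetical sketch: stitch the central core of each output tile back
    into a full-resolution image, discarding the overlapping context."""
    margin = pad_size // 2
    h, w = img_size
    out = np.zeros((h, w) + patches[0].shape[2:], dtype=patches[0].dtype)
    idx = 0
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            out[top:top + patch_size, left:left + patch_size] = \
                patches[idx][margin:margin + patch_size,
                             margin:margin + patch_size]
            idx += 1
    return out

With this geometry a 3000x4000 input yields 48 tiles, so the `* 96` sizing of `test_time` above would simply be generous head-room; the original tiling may well produce a different patch count.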
Example 2
def train():

    # Load dataset
    if args.real:
        dataset = Dataset_h5_real(args.src_path, patch_size=args.patch_size, train=True)
        dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    else:
        dataset = Dataset_from_h5(args.src_path, args.sigma, args.gray,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop((args.patch_size, args.patch_size)),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomVerticalFlip(),
                                      transforms.Lambda(lambda img: RandomRot(img)),
                                      transforms.ToTensor(),
                                      #transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
                                  ]))
        dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    if args.val_path:
        dataset_val = Dataset_h5_real(src_path=args.val_path, patch_size=args.val_patch_size, gray=args.gray, train=False)
        dataloader_val = DataLoader(dataset=dataset_val, batch_size=1, shuffle=False, num_workers=0, drop_last=True)
    # Build model
    if args.gray:
        input_channel, output_channel = 1, 1
    else:
        input_channel, output_channel = 3, 3

    model = make_model(input_channel, output_channel, args)
    model.initialize_weights()

    if args.finetune:
        model_dict = torch.load(args.ckpt_dir+'model_%04d_dict.pth' % args.init_epoch)
        model.load_state_dict(model_dict)


    if args.t_loss == 'L2':
        criterion = torch.nn.MSELoss()
    elif args.t_loss == 'L1':
        criterion = torch.nn.L1Loss()
    else:
        raise ValueError('Unknown loss type: %s' % args.t_loss)

    if torch.cuda.is_available():
        print(torch.cuda.device_count())
        if torch.cuda.device_count() > 1:
            #model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
            model = torch.nn.DataParallel(model).cuda()
            criterion = criterion.cuda()
        else:
            model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=milestone, gamma=0.1)
    #scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)  # learning rates
    writer = SummaryWriter(args.log_dir)

    for epoch in range(args.init_epoch, args.n_epoch):

        loss_sum = 0
        step_lr_adjust(optimizer, epoch, init_lr=args.lr, step_size=args.milestone, gamma=args.gamma)
        print('Epoch {}, lr {}'.format(epoch+1, optimizer.param_groups[0]['lr']))
        start_time = time.time()
        for i, data in enumerate(dataloader):
            input, label = data
            if torch.cuda.is_available():
                input, label = input.cuda(), label.cuda()
            input, label = Variable(input), Variable(label)

            model.train()
            model.zero_grad()
            optimizer.zero_grad()

            output = model(input)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()

            if (i % 100 == 0) and (i != 0) :
                loss_avg = loss_sum / 100
                loss_sum = 0.0
                print("Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.8f} Time: {:4.4f}s".format(
                    epoch + 1, args.n_epoch, i + 1, len(dataloader), loss_avg, time.time()-start_time))
                start_time = time.time()
                # Record train loss
                writer.add_scalars('Loss_group', {'train_loss': loss_avg}, epoch)
                # Record learning rate
                #writer.add_scalar('learning rate', scheduler.get_lr()[0], epoch)
                writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], epoch)
        # save model
        if epoch % args.save_epoch == 0:
            if torch.cuda.device_count() > 1:
                torch.save(model.module.state_dict(), os.path.join(args.ckpt_dir, 'model_%04d_dict.pth' % (epoch+1)))
            else:
                torch.save(model.state_dict(), os.path.join(args.ckpt_dir, 'model_%04d_dict.pth' % (epoch+1)))

        # validation
        if args.val_path:
            if epoch % args.val_epoch == 0:
                psnr = 0
                loss_val = 0
                model.eval()
                for i, data in enumerate(dataloader_val):
                    input, label = data
                    if torch.cuda.is_available():
                        input, label = input.cuda(), label.cuda()
                    input, label = Variable(input), Variable(label)

                    test_out = model(input)
                    test_out.detach_()

                    # compute loss
                    loss_val += criterion(test_out, label).item()
                    rgb_out = test_out.cpu().numpy().transpose((0,2,3,1))
                    clean = label.cpu().numpy().transpose((0,2,3,1))
                    for num in range(rgb_out.shape[0]):
                        denoised = np.clip(rgb_out[num], 0, 1)
                        psnr += compare_psnr(clean[num], denoised)
                img_nums = rgb_out.shape[0] * len(dataloader_val)
                #img_nums = batch_size * len(dataloader_val)
                psnr = psnr / img_nums
                loss_val = loss_val / len(dataloader_val)
                print('Validating: {:0>3} , loss: {:.8f}, PSNR: {:4.4f}'.format(img_nums, loss_val, psnr))
                #mpimg.imsave(ckpt_dir+"img/%04d_denoised.png" % epoch, rgb_out[0])
                writer.add_scalars('Loss_group', {'valid_loss': loss_val}, epoch)
                writer.add_scalar('valid_psnr', psnr, epoch)
                if args.save_val_img:
                    if args.gray:
                        mpimg.imsave(args.ckpt_dir+"img/%04d_denoised.png" % epoch, denoised[:,:,0])
                    else:
                        mpimg.imsave(args.ckpt_dir+"img/%04d_denoised.png" % epoch, denoised)
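`train()` adjusts the learning rate through a `step_lr_adjust` helper whose body is not shown. Judging from its arguments, it most likely implements a plain step decay; the sketch below is an assumption, not the repository's code.

def step_lr_adjust(optimizer, epoch, init_lr=1e-4, step_size=20, gamma=0.5):
    """Hypothetical sketch: multiply the initial learning rate by `gamma`
    once every `step_size` epochs and write it into every parameter group."""
    lr = init_lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Recomputing the rate from `init_lr` on every call (rather than decaying the current value in place) keeps the schedule independent of how many times the function has run, which matters when training resumes from `args.init_epoch`.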
Example 3
def evaluate_net():

    noise_src_folder_list = []
    dst_png_path_list = []
    if args.gt_src_path:
        gt_src_folder_list = []

    for item in args.test_items:
        noise_tmp = sorted(glob.glob(args.noise_src_path + item + '/'))
        noise_src_folder_list.extend(noise_tmp)
        dst_png_path_list.append(args.result_png_path + item + '/')
        if args.gt_src_path:
            gt_src_folder_list.extend(
                sorted(glob.glob(args.gt_src_path + item + '/')))

    if args.gt_src_path:
        psnr = np.zeros(len(gt_src_folder_list))
        ssim = np.zeros(len(gt_src_folder_list))
    test_time = np.zeros(len(noise_src_folder_list))

    # Build model
    if args.gray:
        input_channel, output_channel = 1, 1
    else:
        input_channel, output_channel = 3, 3

    model = make_model(input_channel, output_channel, args)

    if torch.cuda.is_available():
        model_dict = torch.load(args.ckpt_dir_test +
                                'model_%04d_dict.pth' % args.epoch_test)
        model.load_state_dict(model_dict)
        model = model.cuda()
    else:
        print('No CUDA devices are available!')

    model.eval()

    #=================#
    # iterate over the noisy folders so the loop also works without ground truth
    for i in range(len(noise_src_folder_list)):
        in_files = glob.glob(noise_src_folder_list[i] + '*')
        in_files.sort()
        if args.gt_src_path:
            gt_files = glob.glob(gt_src_folder_list[i] + '*')
            gt_files.sort()
        create_dir(dst_png_path_list[i])

        for ind in range(len(in_files)):
            if args.gt_src_path:
                clean = imread(gt_files[ind]).astype(np.float32) / 255
                clean = clean[0:(clean.shape[0] // 8) * 8,
                              0:(clean.shape[1] // 8) * 8]

            img_name = os.path.split(in_files[ind])[-1].split('.')[0]
            noisy = imread(in_files[ind]).astype(np.float32) / 255
            noisy = noisy[0:(noisy.shape[0] // 8) * 8,
                          0:(noisy.shape[1] // 8) * 8]

            img_test = transforms.functional.to_tensor(noisy)
            img_test = img_test.unsqueeze_(0).float()
            if torch.cuda.is_available():
                img_test = img_test.cuda()

            torch.cuda.synchronize()
            start_time = time.time()
            with torch.no_grad():
                out_image = model(img_test)
            torch.cuda.synchronize()
            if ind > 0:
                test_time[i] += (time.time() - start_time)
            print("took: %4.4fs" % (time.time() - start_time))
            print("process folder:%s" % noise_src_folder_list[i])
            print("[*] save images")

            rgb = out_image.cpu().detach().numpy().transpose((0, 2, 3, 1))
            if noisy.ndim == 3:
                rgb = np.clip(rgb[0], 0, 1)
            elif noisy.ndim == 2:
                rgb = np.clip(rgb[0, :, :, 0], 0, 1)

            # save image
            imwrite(dst_png_path_list[i] + img_name + ".png",
                    np.uint8(rgb * 255))

            if args.gt_src_path:

                psnr[i] += compare_psnr(clean, rgb)

                if clean.ndim == 2:
                    ssim[i] += compare_ssim(clean, rgb)
                elif clean.ndim == 3:
                    ssim[i] += compare_ssim(clean, rgb, multichannel=True)

        test_time[i] = test_time[i] / (len(in_files) - 1)
        if args.gt_src_path:
            psnr[i] = psnr[i] / len(in_files)
            ssim[i] = ssim[i] / len(in_files)
        #===========

    # print psnr, ssim
    for i in range(len(noise_src_folder_list)):
        print('src_folder: %s: ' % (noise_src_folder_list[i]))
        if args.gt_src_path:
            print('psnr: %f, ssim: %f, average time: %f' %
                  (psnr[i], ssim[i], test_time[i]))
        else:
            print('average time: %f' % (test_time[i]))

    return 0
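The `args` namespace these functions read from is defined elsewhere. Purely as an illustration, an argparse setup consistent with the attributes touched by Example 3 could look like the following; the flag defaults are placeholders, not the authors' values.

import argparse

# Hypothetical parser covering only the attributes Example 3 reads;
# every default below is a placeholder.
parser = argparse.ArgumentParser()
parser.add_argument('--noise_src_path', type=str, default='./data/noisy/')
parser.add_argument('--gt_src_path', type=str, default='./data/clean/')
parser.add_argument('--result_png_path', type=str, default='./results/')
parser.add_argument('--test_items', type=str, nargs='+', default=['set1'])
parser.add_argument('--gray', action='store_true')
parser.add_argument('--ckpt_dir_test', type=str, default='./ckpt/')
parser.add_argument('--epoch_test', type=int, default=100)
args = parser.parse_args()

if __name__ == '__main__':
    evaluate_net()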
Example 4
def train():

    # Load dataset
    dataset = Dataset_from_h5(src_path=args.src_path,
                              recrop_patch_size=args.patch_size,
                              sigma=args.sigma,
                              train=True)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=8,
                            drop_last=True)
    dataset_val = Dataset_from_h5(src_path=args.val_path,
                                  recrop_patch_size=args.val_patch_size,
                                  sigma=args.sigma,
                                  train=False)
    dataloader_val = DataLoader(dataset=dataset_val,
                                batch_size=args.val_batch_size,
                                shuffle=False,
                                num_workers=8,
                                drop_last=True)
    print('Training path: {:s}\nValidation path: {:s}'.format(
        args.src_path, args.val_path))
    # Build model
    input_channel, output_channel = 5, 3
    model = make_model(input_channel, output_channel, args)
    model.initialize_weights()

    if args.finetune:
        model_dict = torch.load(args.ckpt_dir +
                                'model_%04d_dict.pth' % args.init_epoch)
        model.load_state_dict(model_dict)

    if args.t_loss == 'L2':
        criterion = torch.nn.MSELoss()
        print('Training with L2Loss!')
    elif args.t_loss == 'L1':
        criterion = torch.nn.L1Loss()
        print('Training with L1Loss!')
    elif args.t_loss == 'L2_wz_TV':
        criterion = L2_wz_TV(args)
        print('Training with L2 and TV Loss!')
    elif args.t_loss == 'L2_wz_Perceptual':
        criterion = L2_wz_Perceptual(args)
        print('Training with L2 and Perceptual Loss!')
    else:
        raise ValueError('Unknown loss type: %s' % args.t_loss)

    if torch.cuda.is_available():
        print('Using {} GPU(s), device id(s): {:s}'.format(
            torch.cuda.device_count(), args.gpu))
        if torch.cuda.device_count() > 1:
            #model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
            model = torch.nn.DataParallel(model).cuda()
        else:
            model = model.cuda()

        criterion = criterion.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    writer = SummaryWriter(args.log_dir)

    for epoch in range(args.init_epoch, args.n_epoch):
        loss_sum = 0
        step_lr_adjust(optimizer,
                       epoch,
                       init_lr=args.lr,
                       step_size=args.milestone,
                       gamma=args.gamma)
        print('Epoch {}, lr {}'.format(epoch + 1,
                                       optimizer.param_groups[0]['lr']))
        start_time = time.time()
        for i, data in enumerate(dataloader):
            input, label = data
            if torch.cuda.is_available():
                input, label = input.cuda(), label.cuda()
            input, label = Variable(input), Variable(label)

            model.train()
            model.zero_grad()
            optimizer.zero_grad()

            output = model(input)

            # calculate loss
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()

            if (i % 100 == 0) and (i != 0):
                loss_avg = loss_sum / 100
                loss_sum = 0.0
                print(
                    "Training: Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.8f} Time: {:4.4f}s"
                    .format(epoch + 1, args.n_epoch, i + 1, len(dataloader),
                            loss_avg,
                            time.time() - start_time))
                start_time = time.time()
                # Record train loss
                writer.add_scalars('Loss_group', {'train_loss': loss_avg},
                                   epoch)
                # Record learning rate
                writer.add_scalar('learning rate',
                                  optimizer.param_groups[0]['lr'], epoch)
        # save model
        if epoch % args.save_epoch == 0:
            if torch.cuda.device_count() > 1:
                torch.save(
                    model.module.state_dict(),
                    os.path.join(args.ckpt_dir,
                                 'model_%04d_dict.pth' % (epoch + 1)))
            else:
                torch.save(
                    model.state_dict(),
                    os.path.join(args.ckpt_dir,
                                 'model_%04d_dict.pth' % (epoch + 1)))

        # validation
        if epoch % args.val_epoch == 0:
            psnr = 0
            loss_val = 0
            model.eval()
            for i, data in enumerate(dataloader_val):
                input, label = data
                if torch.cuda.is_available():
                    input, label = input.cuda(), label.cuda()
                input, label = Variable(input), Variable(label)

                test_out = model(input)
                test_out.detach_()

                # compute loss
                loss_val += criterion(test_out, label).item()
                rgb_out = test_out.cpu().numpy().transpose((0, 2, 3, 1))
                clean = label.cpu().numpy().transpose((0, 2, 3, 1))
                for num in range(rgb_out.shape[0]):
                    deblurred = np.clip(rgb_out[num], 0, 1)
                    psnr += compare_psnr(clean[num], deblurred)
            img_nums = rgb_out.shape[0] * len(dataloader_val)
            psnr = psnr / img_nums
            loss_val = loss_val / len(dataloader_val)
            print('Validating: {:0>3} , loss: {:.8f}, PSNR: {:4.4f}'.format(
                img_nums, loss_val, psnr))
            writer.add_scalars('Loss_group', {'valid_loss': loss_val}, epoch)
            writer.add_scalar('valid_psnr', psnr, epoch)
            if args.save_val_img:
                # convert RGB to BGR and scale to 8-bit before writing with OpenCV
                cv2.imwrite(args.ckpt_dir + "img/%04d_deblurred.png" % epoch,
                            np.uint8(deblurred[..., ::-1] * 255))

    writer.close()
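All four snippets are excerpts and omit their imports. A plausible shared preamble is sketched below; where exactly `make_model`, the `Dataset_*` classes, and helpers such as `create_dir` or `compute_fld_info` live is project-specific, and the metric imports depend on the scikit-image version (older releases export `compare_psnr`/`compare_ssim` directly, newer ones only the renamed functions).

# Assumed import preamble for the snippets above (project-specific modules
# such as make_model or the Dataset_* classes are not shown).
import os
import glob
import time

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms

import matplotlib.image as mpimg           # mpimg.imsave in Example 2
from imageio import imread, imwrite        # assumed source of imread/imwrite in Example 3

# TensorBoard logging: torch.utils.tensorboard (or the tensorboardX package)
# provides a SummaryWriter with the interface used above.
from torch.utils.tensorboard import SummaryWriter

# PSNR/SSIM: alias the renamed scikit-image functions when the legacy
# compare_* names are unavailable.
try:
    from skimage.measure import compare_psnr, compare_ssim
except ImportError:
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
    from skimage.metrics import structural_similarity as compare_ssim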