Example #1
0
    def __init__(self,
                 data_list,
                 patch_size=64,
                 scale=2,
                 denoise=False,
                 max_noise=0.0748,
                 min_noise=0.0,
                 downsampler='avg'):
        """Store the dataset configuration and read the image list.

        Args:
            data_list: Path to a text file. For every non-blank line, the
                first whitespace-separated token is appended to
                ``self.data_lists``.
            patch_size: Stored as ``self.patch_size``.
            scale: Stored as ``self.scale``.
            denoise: Stored as ``self.denoise``.
            max_noise: Stored as ``self.max_noise``.
            min_noise: Stored as ``self.min_noise``.
            downsampler: Stored as ``self.downsampler``.
        """
        super(LoadRawSR, self).__init__()
        self.scale = scale
        self.patch_size = patch_size
        self.data_lists = []
        self.denoise = denoise
        self.max_noise = max_noise
        self.min_noise = min_noise
        self.downsampler = downsampler
        self.raw_stack = DownsamplingShuffle(2)

        # Read the image list from the txt file. The `with` block closes the
        # file even if parsing fails. Blank lines (e.g. a trailing newline)
        # are skipped instead of raising IndexError on `fields[0]`.
        with open(data_list) as fin:
            for line in fin:
                fields = line.split()
                if fields:
                    self.data_lists.append(fields[0])
Example #2
0
def main():
    """Run the selected model over every image in ``args.test_path``.

    Parses options via ``TestArgs``, builds ``model.<args.model>.NET``,
    optionally restores weights from ``args.pretrained_model``, then for each
    image:

    - splits it into ``args.crop_scale ** 2`` tiles with ``crop_imgs``;
    - runs every tile through the network and clamps the result to [0, 1];
    - re-assembles the tiles;
    - writes the result with ``cv2.imwrite`` into ``args.save_path``.

    Raises:
        ValueError: if ``args.model`` is not one of the supported model names.
    """
    ##############################################################################
    # args parse
    parser = argparse.ArgumentParser(
        description='PyTorch implementation of demosaicking')
    parsers = TestArgs()
    args = parsers.initialize(parser)
    if args.show_info:
        parsers.print_args()

    ##############################################################################
    # load model architecture
    print('===> Loading the network ...')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    module = importlib.import_module("model.{}".format(args.model))
    model = module.NET(args).to(device)

    if args.show_info:
        print(model)
        print_model_parm_nums(model)

    ##############################################################################
    # load pre-trained
    if os.path.isfile(args.pretrained_model):
        print("=====> loading checkpoint '{}'".format(args.pretrained_model))
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(args.pretrained_model, map_location=device)
        best_psnr = checkpoint['best_psnr']
        model.load_state_dict(checkpoint['state_dict'])
        print(
            "The pretrained_model is at checkpoint {}k, and its best PSNR is {}."
            .format(checkpoint['iter'] / 1000, best_psnr))
    else:
        print("=====> no checkpoint found at '{}'".format(
            args.pretrained_model))

    ##############################################################################
    # in channels, out channels per model.
    # in_channels == 4: raw input, packed by DownsamplingShuffle(2) below.
    # in_channels == 3: RGB (HWC) input.
    model_channels = {
        'denorgb': (3, 3),
        'denoraw': (4, 1),
        'demo': (4, 3),
        'srraw': (4, 1),
        'srrgb': (3, 3),
        'tenet1': (4, 3),
        'tenet2': (4, 3),
    }
    if args.model not in model_channels:
        raise ValueError('not supported model')
    in_channels, out_channels = model_channels[args.model]

    ##############################################################################
    # test
    model.eval()
    raw_down_sample = DownsamplingShuffle(2)
    demosaic_layer = nn.PixelShuffle(2)

    img_path = args.test_path
    dst_path = args.save_path
    os.makedirs(dst_path, exist_ok=True)

    im_files = _all_images(img_path)
    num_files = len(im_files)

    with torch.no_grad():
        for i, im_file in enumerate(im_files):
            im_name = os.path.basename(im_file)

            img = _read_image(im_file)

            # Shift images so that the Bayer pattern is RGGB.
            # - Drop the first column/row.
            # - Pad with a copy of the second-to-last one, so the size is
            #   unchanged and the padded column/row keeps the right parity.
            if args.shift_x > 0:
                img = np.concatenate((img[:, 1:], img[:, -2:-1]), 1)
            if args.shift_y > 0:
                img = np.concatenate((img[1:], img[-2:-1]), 0)

            h, w = img.shape[0], img.shape[1]
            if in_channels == 4:
                # Raw input: crop height/width (axes 0 and 1) down to even
                # sizes so the 2x2 packing below divides the image exactly.
                h -= h % 2
                w -= w % 2
                img = img[:h, :w]

                img = torch.from_numpy(img).float().contiguous().view(
                    -1, 1, h, w)
                img = raw_down_sample(img)
            else:
                # RGB input: HWC numpy array -> NCHW tensor.
                img = torch.from_numpy(np.transpose(img, [2, 0, 1])).float()
                img = img.contiguous().view(-1, 3, h, w)

            if args.denoise:
                # Append a constant noise-level map as an extra input channel.
                noise_map = torch.ones(
                    [1, 1, img.shape[-2], img.shape[-1]]) * args.noise_level
                img = torch.cat((img, noise_map), 1)

            im_inputs = crop_imgs(img, args.crop_scale)
            del img

            im_inputs = im_inputs.to(device)
            h = im_inputs.shape[-2]
            w = im_inputs.shape[-1]

            # NOTE(review): each tile's output is allocated at 2 * scale times
            # the tile size, which matches the packed-raw path (input was
            # halved by DownsamplingShuffle). Confirm this also holds for
            # the 3-channel (RGB-input) models.
            num_tiles = args.crop_scale**2
            output = torch.zeros((num_tiles, 1, out_channels,
                                  h * args.scale * 2, w * args.scale * 2))

            for j in range(num_tiles):
                tile = im_inputs[j].unsqueeze(0)
                if args.model == 'tenet2':
                    # tenet2's forward result is indexed; element 1 is saved.
                    sr = model(tile)[1]
                else:
                    sr = model(tile)
                output[j] = torch.clamp(sr.cpu(), min=0., max=1.)

            if args.crop_scale > 1:
                rgb_output = binning_imgs(output, args.crop_scale)
            else:
                rgb_output = output.view(1, output.shape[-3], output.shape[-2],
                                         output.shape[-1])

            # NOTE(review): unreachable with the channel table above (no model
            # has 4 output channels). Presumably kept for a model that emits
            # packed 4-channel output.
            if out_channels == 4:
                rgb_output = demosaic_layer(rgb_output)

            rgb_output = _tensor2cvimage(rgb_output, np.uint8)

            # splitext strips only the final extension ("a.b.png" -> "a.b").
            im_name = os.path.splitext(im_name)[0] + '-' + args.post
            save_file = os.path.join(dst_path, im_name)
            # cv2.imwrite signals failure via its return value, not an exception.
            if not cv2.imwrite(save_file, rgb_output):
                print('warning: failed to write {}'.format(save_file))
            if args.show_info:
                print('saving: {}, size: {} [{}]/[{}]'.format(
                    save_file, rgb_output.shape, i, num_files - 1))