Example #1
0
def get_random_item():
    """Select a random dataset sample and return it as a base64-encoded PNG.

    Side effects:
        Loads the global ``SR_model`` / ``dataset`` pair on first use and
        stores the chosen sample (with a leading batch axis added) in the
        global ``hr``.

    Returns:
        The ``jsonify`` response for ``{"img": ...}``, where the value is
        ``str()`` applied to the base64-encoded PNG bytes.
    """
    global SR_model
    global dataset
    global hr
    # The model and dataset are created lazily on the first request.
    if SR_model is None:
        SR_model, dataset = load_model_and_dataset()

    last_index = len(dataset) - 1
    chosen_index = random.randint(0, last_index)
    hr = dataset[chosen_index].unsqueeze(0)

    # Scale by the sample's maximum before converting to an image array.
    scaled = hr / hr.max()
    img = to_img(scaled, SR_model.opt['mode'], normalize=False)

    channel_swapped = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    success, encoded = cv2.imencode(".png", channel_swapped)
    png_bytes = encoded.tobytes()
    # NOTE(review): under Python 3, str() of a bytes object yields "b'...'",
    # so the client presumably strips that wrapper -- confirm before changing.
    payload = {"img": str(base64.b64encode(png_bytes))}
    return jsonify(payload)
Example #2
0
def sensetivity_adjustment():
    """Perturb one low-resolution pixel and report how the SR output reacts.

    The cached global ``hr`` is downsampled, the ``change`` query parameter
    is added to the low-resolution pixel selected by the global
    ``pixel_target``, and ``SR_model`` is run on the perturbed input. The
    gradient of the summed output with respect to the low-resolution input
    is then computed, made absolute, scaled to [0, 1] and rendered as a
    grayscale image.

    Side effects:
        * Loads ``SR_model`` / ``dataset`` and picks a sample if not yet set.
        * Moves the global ``hr`` to ``SR_model.opt['device']``.
        * Stores the detached feature-extractor output in the global
          ``feature_maps``.
        * Writes the perturbed SR image to ``test.png``.

    Returns:
        The ``jsonify`` response with keys ``changed_sr_img``,
        ``changed_lr_img`` and ``sensitivity_img``; each value is ``str()``
        applied to base64-encoded PNG bytes.

    Raises:
        TypeError: if the ``change`` query parameter is missing.
        ValueError: if ``change`` cannot be parsed as a float.
    """
    global SR_model
    global dataset
    global hr
    global pixel_target
    global feature_maps
    if SR_model is None:
        SR_model, dataset = load_model_and_dataset()
    if hr is None:
        _ = get_random_item()
    change = float(request.args.get('change'))
    hr = hr.to(SR_model.opt['device'])
    real_shape = hr.shape
    is_continuous = SR_model.upscaling_model.continuous
    is_2d = SR_model.opt['mode'] == "2D"
    sf_to_use = continuous_scale_factor if is_continuous \
        else (1/SR_model.opt['spatial_downscale_ratio'])
    # Target low-resolution extent for every axis after (batch, channel).
    size = [round(extent/sf_to_use) for extent in hr.shape[2:]]

    lr = F.interpolate(hr, size=size,
            mode='bilinear' if is_2d else "trilinear",
            align_corners=False, recompute_scale_factor=False)

    # p*s/2 + s/2 maps -1 to 0 and +1 to s, which is one past the last valid
    # index; clamp so that a target on the far edge stays inside the tensor.
    lr_pix = [
        min(max(int(pixel_target[axis]*extent/2 + extent/2), 0), extent - 1)
        for axis, extent in enumerate(lr.shape[2:4])
    ]
    lr[:,:,lr_pix[0],lr_pix[1]] += change

    lr_up = F.interpolate(lr, size=hr.shape[2:], mode='nearest')
    lr_im = to_img(lr_up/hr.max(), SR_model.opt['mode'], normalize=False)

    # Track gradients on the perturbed input from here on.
    lr.requires_grad = True

    # The cached feature maps are detached anyway, so skip graph construction.
    with torch.no_grad():
        feature_maps = SR_model.feature_extractor(lr.clone().detach()).clone().detach()

    if is_continuous:
        hr_coords, _ = to_pixel_samples(hr, flatten=False)
        cell_sizes = torch.ones_like(hr_coords)

        # Scale each entry of the trailing axis (the one this loop iterates
        # over). For a rank-3 grid this equals [:, :, i]; for a higher-rank
        # grid [:, :, i] would index a spatial axis instead.
        for i in range(cell_sizes.shape[-1]):
            cell_sizes[..., i] *= 2 / real_shape[2+i]

        lr_upscaled = SR_model(lr, hr_coords, cell_sizes)
        # Move the trailing axis to the front and add a batch axis.
        if is_2d:
            lr_upscaled = lr_upscaled.permute(2, 0, 1).unsqueeze(0)
        else:
            lr_upscaled = lr_upscaled.permute(3, 0, 1, 2).unsqueeze(0)
    else:
        lr_upscaled = SR_model(lr)

    changed_sr_im = to_img(lr_upscaled/hr.max(), SR_model.opt['mode'], normalize=False)
    # NOTE(review): looks like a debugging artefact (fixed filename, rewritten
    # on every request) -- confirm whether anything still reads this file.
    imageio.imwrite("test.png", changed_sr_im)

    # grad_outputs of ones gives the gradient of the summed output. Nothing
    # differentiates through the result, so no graph is created or retained.
    sense = torch.autograd.grad(outputs=[lr_upscaled],
        inputs=[lr], grad_outputs=torch.ones_like(lr_upscaled),
        allow_unused=True)[0]
    if sense is None:
        # allow_unused=True yields None when the output does not depend on lr.
        sense = torch.zeros_like(lr)

    sense = torch.abs(sense)
    peak = sense.max()
    if peak > 0:
        # Skip normalisation for an all-zero map (0 * inf would give NaN).
        sense = sense / peak
    sense = F.interpolate(sense, size=hr.shape[2:], mode='nearest')
    sense_img = sense[0,0]*255
    sense_img = sense_img.detach().cpu().numpy().astype(np.uint8)

    success, changed_sr_im = cv2.imencode(".png", cv2.cvtColor(changed_sr_im, cv2.COLOR_BGR2RGB))
    changed_sr_im = changed_sr_im.tobytes()
    success, changed_lr_img = cv2.imencode(".png", cv2.cvtColor(lr_im, cv2.COLOR_BGR2RGB))
    changed_lr_img = changed_lr_img.tobytes()
    # sense[0,0] has no channel axis, and COLOR_BGR2RGB only accepts 3- or
    # 4-channel input, so the grayscale map is encoded without a colour swap.
    success, sense_img = cv2.imencode(".png", sense_img)
    sense_img = sense_img.tobytes()

    return jsonify(
            {
                "changed_sr_img":str(base64.b64encode(changed_sr_im)),
                "changed_lr_img":str(base64.b64encode(changed_lr_img)),
                "sensitivity_img":str(base64.b64encode(sense_img))
            }
        )
Example #3
0
def perform_SR():
    """Downsample the cached sample, super-resolve it, and return images and metrics.

    The global ``hr`` is reduced with bilinear/trilinear interpolation, fed
    through ``SR_model`` and compared against the original. Inference runs
    under ``torch.no_grad()`` because no gradients are needed here.

    Side effects:
        Loads ``SR_model`` / ``dataset`` and picks a sample if not yet set,
        and moves the global ``hr`` to ``SR_model.opt['device']``.

    Returns:
        The ``jsonify`` response with keys ``gt_img``, ``sr_img`` and
        ``lr_img`` (each ``str()`` of base64-encoded PNG bytes), ``l1``
        (mean absolute error, four decimals) and ``psnr`` (two decimals).
    """
    global SR_model
    global dataset
    global hr
    global continuous_scale_factor
    if SR_model is None:
        SR_model, dataset = load_model_and_dataset()
    if hr is None:
        _ = get_random_item()

    hr_im = to_img(hr/hr.max(), SR_model.opt['mode'], normalize=False)
    hr = hr.to(SR_model.opt['device'])
    real_shape = hr.shape
    is_continuous = SR_model.upscaling_model.continuous
    is_2d = SR_model.opt['mode'] == "2D"
    sf_to_use = continuous_scale_factor if is_continuous \
        else (1/SR_model.opt['spatial_downscale_ratio'])
    # Target low-resolution extent for every axis after (batch, channel).
    size = [round(extent/sf_to_use) for extent in hr.shape[2:]]

    with torch.no_grad():
        lr = F.interpolate(hr, size=size,
                mode='bilinear' if is_2d else "trilinear",
                align_corners=False, recompute_scale_factor=False)
        # Nearest-neighbour blow-up so the LR image matches the HR size.
        lr_up = F.interpolate(lr, size=hr.shape[2:], mode='nearest')
        lr_im = to_img(lr_up/hr.max(), SR_model.opt['mode'], normalize=False)

        if is_continuous:
            hr_coords, _ = to_pixel_samples(hr, flatten=False)
            cell_sizes = torch.ones_like(hr_coords)
            # Scale each entry of the trailing axis (the one this loop
            # iterates over). For a rank-3 grid this equals [:, :, i]; for a
            # higher-rank grid [:, :, i] would index a spatial axis instead.
            for i in range(cell_sizes.shape[-1]):
                cell_sizes[..., i] *= 2 / real_shape[2+i]

            lr_upscaled = SR_model(lr, hr_coords, cell_sizes)
            # Move the trailing axis to the front and add a batch axis.
            if is_2d:
                lr_upscaled = lr_upscaled.permute(2, 0, 1).unsqueeze(0)
            else:
                lr_upscaled = lr_upscaled.permute(3, 0, 1, 2).unsqueeze(0)
        else:
            lr_upscaled = SR_model(lr)

        l1 = torch.abs(lr_upscaled-hr).mean().item()
        psnr = PSNR(lr_upscaled, hr).item()
        sr_im = to_img(lr_upscaled/hr.max(), SR_model.opt['mode'], normalize=False)

    success, hr_im = cv2.imencode(".png", cv2.cvtColor(hr_im, cv2.COLOR_BGR2RGB))
    hr_im = hr_im.tobytes()
    success, sr_im = cv2.imencode(".png", cv2.cvtColor(sr_im, cv2.COLOR_BGR2RGB))
    sr_im = sr_im.tobytes()
    success, lr_im = cv2.imencode(".png", cv2.cvtColor(lr_im, cv2.COLOR_BGR2RGB))
    lr_im = lr_im.tobytes()
    return jsonify(
            {
                "gt_img":str(base64.b64encode(hr_im)),
                "sr_img":str(base64.b64encode(sr_im)),
                "lr_img":str(base64.b64encode(lr_im)),
                "l1": "%0.04f" % l1,
                "psnr": "%0.02f" % psnr
            }
        )
Example #4
0
    # Copy every entry of args whose value is not None over the same key in opt.
    for k in args.keys():
        if args[k] is not None:
            opt[k] = args[k]

    # Overrides applied after the merge above, regardless of what args held.
    opt['cropping_resolution'] = 16
    opt['data_folder'] = os.path.join(input_folder, args['data_folder'])
    model = load_model(opt, args["device"]).to(args['device'])
    dataset = LocalDataset(opt)

    # Gradient tracking is disabled for everything inside this block.
    with torch.no_grad():
        if (args['increasing_size_test']):
            img_sequence = []
            # Pick one random sample, add a leading batch axis and move it to
            # the requested device.
            rand_dataset_item = random.randint(0, len(dataset) - 1)
            hr = dataset[rand_dataset_item].unsqueeze(0).to(args['device'])
            # The to_img result is transposed with axes [2, 0, 1], cut to at
            # most the first three entries of the new leading axis, and given
            # a batch axis.
            hr_im = torch.from_numpy(
                np.transpose(to_img(hr, args['mode']),
                             [2, 0, 1])[0:3]).unsqueeze(0)

            # Downsample by a factor of args['max_sf'].
            lr = F.interpolate(
                hr,
                scale_factor=(1 / args['max_sf']),
                mode='bilinear' if args['mode'] == "2D" else "trilinear",
                align_corners=False)
            lr_im = torch.from_numpy(
                np.transpose(to_img(lr, args['mode']),
                             [2, 0, 1])[0:3]).unsqueeze(0)
            # Resize the low-resolution image to the spatial size of hr_im.
            lr_im = F.interpolate(lr_im, size=hr_im.shape[2:], mode='nearest')

            # Append the same upsampled frame (permuted with axes (1, 2, 0),
            # as a NumPy array) 15 times.
            # NOTE(review): the snippet ends here; the loop body may continue
            # beyond what is visible.
            for i in range(15):
                img_sequence.append(lr_im[0].permute(1, 2,
                                                     0).detach().cpu().numpy())
Example #5
0
    def train_distributed(self, rank, model, opt, dataset):
        """Train ``model`` on one GPU as a member of an NCCL process group.

        Each batch is downsampled, upscaled again by the model and compared
        with the original using an L1 loss. Rank 0 additionally logs scalars
        and images to TensorBoard and saves checkpoints.

        Args:
            rank: Index of this process; also used as its CUDA device index.
            model: Network to train; must expose ``upscaling_model.continuous``.
            opt: Options dict. ``opt['device']`` is overwritten with this
                rank's device and ``opt['epoch_number']`` is updated as
                training progresses.
            dataset: Dataset to shard across processes.
        """
        opt['device'] = "cuda:" + str(rank)
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=self.opt['num_nodes'] *
                                self.opt['gpus_per_node'],
                                rank=rank)
        model = model.to(rank)
        # Read the flag before wrapping: the DDP wrapper only exposes the
        # network as ``.module``, so ``model.upscaling_model`` would raise
        # AttributeError afterwards.
        is_continuous = model.upscaling_model.continuous
        model = DDP(model, device_ids=[rank])

        model_optim = optim.Adam(model.parameters(),
                                 lr=self.opt["g_lr"],
                                 betas=(self.opt["beta_1"],
                                        self.opt["beta_2"]))
        optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=model_optim,
            milestones=[200, 400, 600, 800],
            gamma=self.opt['gamma'])

        # Only rank 0 owns a TensorBoard writer.
        if (rank == 0):
            writer = SummaryWriter(
                os.path.join('tensorboard', opt['save_name']))

        start_time = time.time()

        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=opt['num_nodes'] * opt['gpus_per_node'],
            rank=rank)
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            shuffle=False,
            num_workers=opt["num_workers"],
            pin_memory=True,
            sampler=train_sampler)
        L1loss = nn.L1Loss().to(opt["device"])
        step = 0
        for epoch in range(opt['epoch_number'], opt['epochs']):
            opt["epoch_number"] = epoch
            # Without this the sampler yields the same ordering every epoch.
            train_sampler.set_epoch(epoch)
            for batch_num, real_hr in enumerate(dataloader):
                model.zero_grad()
                # Use this rank's device (set at the top of the method) so the
                # batch lands on the same GPU as the model replica.
                real_hr = real_hr.to(opt['device'])
                if (rank == 0):
                    hr_im = torch.from_numpy(
                        np.transpose(to_img(real_hr, self.opt['mode']),
                                     [2, 0, 1])[0:3]).unsqueeze(0)
                real_shape = real_hr.shape

                if is_continuous:
                    # Uniform draw from [scale_factor_start, scale_factor_end).
                    scale_factor = torch.rand([1], device=real_hr.device, dtype=real_hr.dtype) * \
                        (self.opt['scale_factor_end'] - self.opt['scale_factor_start']) + \
                        self.opt['scale_factor_start']
                else:
                    scale_factor = (1 / self.opt['spatial_downscale_ratio'])

                real_lr = F.interpolate(real_hr,
                                        scale_factor=(1 / scale_factor),
                                        mode="bilinear" if "2D"
                                        in self.opt['mode'] else "trilinear",
                                        align_corners=False,
                                        recompute_scale_factor=False)
                if (rank == 0):
                    lr_im = torch.from_numpy(
                        np.transpose(to_img(real_lr, self.opt['mode']),
                                     [2, 0, 1])[0:3]).unsqueeze(0)
                    # Resize to the HR image size so the images can be
                    # concatenated for logging.
                    lr_im = F.interpolate(lr_im,
                                          size=hr_im.shape[2:],
                                          mode='nearest')

                if is_continuous:
                    hr_coords, real_hr = to_pixel_samples(real_hr,
                                                          flatten=False)
                    cell_sizes = torch.ones_like(hr_coords)

                    # Scale each entry of the trailing axis (the one this
                    # loop iterates over). For a rank-3 grid this equals
                    # [:, :, i]; for a higher-rank grid [:, :, i] would index
                    # a spatial axis instead.
                    for i in range(cell_sizes.shape[-1]):
                        cell_sizes[..., i] *= 2 / real_shape[2 + i]

                    lr_upscaled = model(real_lr, hr_coords, cell_sizes)
                    if ("2D" in self.opt['mode']):
                        lr_upscaled = lr_upscaled.permute(2, 0, 1).unsqueeze(0)
                    else:
                        lr_upscaled = lr_upscaled.permute(3, 0, 1,
                                                          2).unsqueeze(0)
                    # Build the preview while the tensor still has its image
                    # layout, i.e. before it is flattened for the loss.
                    if (rank == 0):
                        sr_im = torch.from_numpy(
                            np.transpose(to_img(lr_upscaled, self.opt['mode']),
                                         [2, 0, 1])[0:3]).unsqueeze(0)
                    lr_upscaled = torch.flatten(lr_upscaled,
                                                start_dim=1,
                                                end_dim=-1).permute(1, 0)
                else:
                    lr_upscaled = model(real_lr)
                    if (rank == 0):
                        sr_im = torch.from_numpy(
                            np.transpose(to_img(lr_upscaled, self.opt['mode']),
                                         [2, 0, 1])[0:3]).unsqueeze(0)

                L1 = L1loss(lr_upscaled, real_hr)
                L1.backward()
                model_optim.step()
                # NOTE(review): the scheduler advances once per batch, so the
                # milestones count optimisation steps, not epochs -- confirm
                # that this is intended.
                optim_scheduler.step()

                psnr = PSNR(lr_upscaled, real_hr)
                if (rank == 0 and step % self.opt['save_every'] == 0):
                    print("Epoch %i batch %i, sf: x%0.02f, L1: %0.04f, PSNR (dB): %0.02f" % \
                    (epoch, batch_num, scale_factor, L1.item(), psnr.item()))
                    writer.add_scalar('L1', L1.item(), step)
                    writer.add_images("LR, SR, HR",
                                      torch.cat([lr_im, sr_im, hr_im]),
                                      global_step=step)
                step += 1

            if (rank == 0 and epoch % self.opt['save_every'] == 0):
                save_model(model, self.opt)
                print("Saved model")

        end_time = time.time()
        # Elapsed wall-clock seconds (end minus start, so it is non-negative).
        total_time = end_time - start_time
        if (rank == 0):
            print("Time to train: " + str(total_time))
            save_model(model, self.opt)
            print("Saved model")
            # Flush pending events and release the log file.
            writer.close()
Example #6
0
    def train_single(self, model, dataset):
        """Train ``model`` on a single device.

        Each batch is downsampled, upscaled again by the model and compared
        with the original using an L1 loss. Scalars and images are logged to
        TensorBoard every ``save_every`` steps and the model is saved every
        ``save_every`` epochs and once more at the end.

        When ``self.opt['fine_tuning']`` is set, the feature extractor's
        parameters are frozen and only ``model.upscaling_model`` is optimised.

        Args:
            model: Network to train; must expose ``feature_extractor`` and
                ``upscaling_model`` (with a ``continuous`` flag).
            dataset: Dataset wrapped in a shuffling ``DataLoader``.
        """
        model = model.to(self.opt['device'])

        if not self.opt['fine_tuning']:
            model_optim = optim.Adam(model.parameters(),
                                     lr=self.opt["g_lr"],
                                     betas=(self.opt["beta_1"],
                                            self.opt["beta_2"]))
        else:
            # Freeze the feature extractor; train the upscaling head only.
            for param in model.feature_extractor.parameters():
                param.requires_grad = False
            model_optim = optim.Adam(model.upscaling_model.parameters(),
                                     lr=self.opt["g_lr"],
                                     betas=(self.opt["beta_1"],
                                            self.opt["beta_2"]))
        optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=model_optim,
            milestones=[200, 400, 600, 800],
            gamma=self.opt['gamma'])

        writer = SummaryWriter(
            os.path.join('tensorboard', self.opt['save_name']))

        start_time = time.time()

        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            shuffle=True,
            num_workers=self.opt["num_workers"],
            pin_memory=True)
        L1loss = nn.L1Loss().to(self.opt["device"])
        is_continuous = model.upscaling_model.continuous
        step = 0
        for epoch in range(self.opt['epoch_number'], self.opt['epochs']):
            self.opt["epoch_number"] = epoch
            for batch_num, real_hr in enumerate(dataloader):
                model.zero_grad()
                real_hr = real_hr.to(self.opt['device'])
                hr_im = torch.from_numpy(
                    np.transpose(to_img(real_hr, self.opt['mode']),
                                 [2, 0, 1])[0:3]).unsqueeze(0)
                real_shape = real_hr.shape

                if is_continuous:
                    # Uniform draw from [scale_factor_start, scale_factor_end).
                    scale_factor = torch.rand([1], device=real_hr.device, dtype=real_hr.dtype) * \
                        (self.opt['scale_factor_end'] - self.opt['scale_factor_start']) + \
                        self.opt['scale_factor_start']
                else:
                    scale_factor = (1 / self.opt['spatial_downscale_ratio'])

                real_lr = F.interpolate(real_hr,
                                        scale_factor=(1 / scale_factor),
                                        mode="bilinear" if "2D"
                                        in self.opt['mode'] else "trilinear",
                                        align_corners=False,
                                        recompute_scale_factor=False)
                lr_im = torch.from_numpy(
                    np.transpose(to_img(real_lr, self.opt['mode']),
                                 [2, 0, 1])[0:3]).unsqueeze(0)
                # Resize to the HR image size so the images can be
                # concatenated for logging.
                lr_im = F.interpolate(lr_im,
                                      mode='nearest',
                                      size=hr_im.shape[2:])

                if is_continuous:
                    hr_coords, real_hr = to_pixel_samples(real_hr,
                                                          flatten=False)
                    cell_sizes = torch.ones_like(hr_coords)

                    # Scale each entry of the trailing axis (the one this
                    # loop iterates over). For a rank-3 grid this equals
                    # [:, :, i]; for a higher-rank grid [:, :, i] would index
                    # a spatial axis instead.
                    for i in range(cell_sizes.shape[-1]):
                        cell_sizes[..., i] *= 2 / real_shape[2 + i]

                    lr_upscaled = model(real_lr, hr_coords, cell_sizes)
                    if ("2D" in self.opt['mode']):
                        lr_upscaled = lr_upscaled.permute(2, 0, 1).unsqueeze(0)
                    else:
                        lr_upscaled = lr_upscaled.permute(3, 0, 1,
                                                          2).unsqueeze(0)
                    # Build the preview before the tensor is flattened for
                    # the loss.
                    sr_im = torch.from_numpy(
                        np.transpose(to_img(lr_upscaled, self.opt['mode']),
                                     [2, 0, 1])[0:3]).unsqueeze(0)
                    lr_upscaled = torch.flatten(lr_upscaled,
                                                start_dim=1,
                                                end_dim=-1).permute(1, 0)
                else:
                    lr_upscaled = model(real_lr)
                    sr_im = torch.from_numpy(
                        np.transpose(to_img(lr_upscaled, self.opt['mode']),
                                     [2, 0, 1])[0:3]).unsqueeze(0)

                L1 = L1loss(lr_upscaled, real_hr)
                L1.backward()
                model_optim.step()
                # NOTE(review): the scheduler advances once per batch, so the
                # milestones count optimisation steps, not epochs -- confirm
                # that this is intended.
                optim_scheduler.step()
                psnr = PSNR(lr_upscaled, real_hr)

                if (step % self.opt['save_every'] == 0):
                    print("Epoch %i batch %i, sf: x%0.02f, L1: %0.04f, PSNR (dB): %0.02f" % \
                        (epoch, batch_num, scale_factor, L1.item(), psnr.item()))
                    writer.add_scalar('L1', L1.item(), step)
                    writer.add_images("LR, SR, HR",
                                      torch.cat([lr_im, sr_im, hr_im]),
                                      global_step=step)
                step += 1

            if (epoch % self.opt['save_every'] == 0):
                save_model(model, self.opt)
                print("Saved model")

        end_time = time.time()
        # Elapsed wall-clock seconds (end minus start, so it is non-negative).
        total_time = end_time - start_time
        print("Time to train: " + str(total_time))
        save_model(model, self.opt)
        print("Saved model")
        # Flush pending events and release the log file.
        writer.close()