def get_random_item():
    """Pick a random dataset sample and return it as a base64-encoded PNG.

    Lazily loads the model and dataset on first use, and caches the chosen
    high-resolution tensor in the global ``hr`` for the other endpoints.
    """
    global SR_model
    global dataset
    global hr

    if SR_model is None:
        SR_model, dataset = load_model_and_dataset()

    # Sample one item uniformly; keep a leading batch dimension for the model.
    item_index = random.randint(0, len(dataset) - 1)
    hr = dataset[item_index].unsqueeze(0)

    # Scale into [0, 1] by the max value before converting to a display image.
    img = to_img(hr / hr.max(), SR_model.opt['mode'], normalize=False)
    ok, encoded = cv2.imencode(".png", cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    png_bytes = encoded.tobytes()

    # NOTE(review): str(base64.b64encode(...)) yields the literal "b'...'"
    # wrapper; the client presumably strips it — kept to match the rest of
    # this API's responses.
    return jsonify({"img": str(base64.b64encode(png_bytes))})
def sensetivity_adjustment():
    """Perturb one LR pixel and return SR output, LR input, and a sensitivity map.

    Reads the perturbation magnitude from the ``change`` query argument, adds
    it to the low-resolution pixel selected by the global ``pixel_target``
    (normalized [-1, 1] coordinates), runs super-resolution, and computes
    |d(SR)/d(LR)| via autograd as a sensitivity image.  Returns all three
    images base64-PNG encoded.
    """
    global SR_model
    global dataset
    global hr
    global pixel_target
    global feature_maps

    if SR_model is None:
        SR_model, dataset = load_model_and_dataset()
    if hr is None:
        _ = get_random_item()

    change = float(request.args.get('change'))

    hr = hr.to(SR_model.opt['device'])
    real_shape = hr.shape

    # Downscale factor: the continuous model uses the user-chosen scale,
    # otherwise the fixed training-time ratio.
    sf_to_use = continuous_scale_factor if SR_model.upscaling_model.continuous \
        else (1 / SR_model.opt['spatial_downscale_ratio'])
    size = []
    for i in range(2, len(hr.shape)):
        size.append(round(hr.shape[i] / sf_to_use))

    lr = F.interpolate(hr, size=size,
                       mode='bilinear' if SR_model.opt['mode'] == "2D" else "trilinear",
                       align_corners=False, recompute_scale_factor=False)

    # Map normalized [-1, 1] target coordinates to LR pixel indices.
    lr_pix = [int(pixel_target[0] * lr.shape[2] / 2 + lr.shape[2] / 2),
              int(pixel_target[1] * lr.shape[3] / 2 + lr.shape[3] / 2)]
    lr[:, :, lr_pix[0], lr_pix[1]] += change

    # Nearest-neighbor upscale of the perturbed LR input, purely for display.
    lr_up = F.interpolate(lr, size=hr.shape[2:], mode='nearest')
    lr_im = to_img(lr_up / hr.max(), SR_model.opt['mode'], normalize=False)

    # Gradients w.r.t. the LR input are needed for the sensitivity map below.
    lr.requires_grad = True
    if SR_model.upscaling_model.continuous:
        hr_coords, _ = to_pixel_samples(hr, flatten=False)
        cell_sizes = torch.ones_like(hr_coords)
        for i in range(cell_sizes.shape[-1]):
            cell_sizes[:, :, i] *= 2 / real_shape[2 + i]
        feature_maps = SR_model.feature_extractor(lr.clone().detach()).clone().detach()
        lr_upscaled = SR_model(lr, hr_coords, cell_sizes)
        if SR_model.opt['mode'] == "2D":
            lr_upscaled = lr_upscaled.permute(2, 0, 1).unsqueeze(0)
        else:
            lr_upscaled = lr_upscaled.permute(3, 0, 1, 2).unsqueeze(0)
    else:
        feature_maps = SR_model.feature_extractor(lr.clone().detach()).clone().detach()
        lr_upscaled = SR_model(lr)

    changed_sr_im = to_img(lr_upscaled / hr.max(), SR_model.opt['mode'], normalize=False)
    # TODO(review): debug artifact written to the working directory — remove
    # once no longer needed.
    imageio.imwrite("test.png", changed_sr_im)

    # Sensitivity: magnitude of the gradient of every SR output w.r.t. the LR input.
    sense = torch.autograd.grad(outputs=[lr_upscaled], inputs=[lr],
                                grad_outputs=torch.ones_like(lr_upscaled),
                                allow_unused=True, retain_graph=True,
                                create_graph=True)[0]
    sense = torch.abs(sense)
    sense *= (1 / sense.max())
    sense = F.interpolate(sense, size=hr.shape[2:], mode='nearest')
    sense_img = sense[0, 0] * 255
    sense_img = sense_img.detach().cpu().numpy().astype(np.uint8)

    success, changed_sr_im = cv2.imencode(".png",
                                          cv2.cvtColor(changed_sr_im, cv2.COLOR_BGR2RGB))
    changed_sr_im = changed_sr_im.tobytes()
    success, changed_lr_img = cv2.imencode(".png",
                                           cv2.cvtColor(lr_im, cv2.COLOR_BGR2RGB))
    changed_lr_img = changed_lr_img.tobytes()
    # BUGFIX: sense_img is single-channel; COLOR_BGR2RGB asserts a 3/4-channel
    # input and raises cv2.error on a 2-D array.  Replicate the gray channel
    # instead so the PNG encode succeeds.
    success, sense_img = cv2.imencode(".png",
                                      cv2.cvtColor(sense_img, cv2.COLOR_GRAY2RGB))
    sense_img = sense_img.tobytes()

    return jsonify({
        "changed_sr_img": str(base64.b64encode(changed_sr_im)),
        "changed_lr_img": str(base64.b64encode(changed_lr_img)),
        "sensitivity_img": str(base64.b64encode(sense_img)),
    })
def perform_SR():
    """Run super-resolution on the cached HR item and return GT/SR/LR images.

    Downscales the cached ``hr`` sample, upscales it back with the model, and
    reports L1 error and PSNR alongside the three base64-PNG-encoded images.
    """
    global SR_model
    global dataset
    global hr
    global continuous_scale_factor

    if SR_model is None:
        SR_model, dataset = load_model_and_dataset()
    if hr is None:
        _ = get_random_item()

    # Ground-truth image is rendered before moving hr to the compute device.
    hr_im = to_img(hr / hr.max(), SR_model.opt['mode'], normalize=False)
    hr = hr.to(SR_model.opt['device'])
    real_shape = hr.shape

    # Downscale factor: user-chosen for the continuous model, fixed otherwise.
    sf_to_use = continuous_scale_factor if SR_model.upscaling_model.continuous \
        else (1 / SR_model.opt['spatial_downscale_ratio'])
    size = []
    for i in range(2, len(hr.shape)):
        size.append(round(hr.shape[i] / sf_to_use))

    lr = F.interpolate(hr, size=size,
                       mode='bilinear' if SR_model.opt['mode'] == "2D" else "trilinear",
                       align_corners=False, recompute_scale_factor=False)
    # Nearest-neighbor upscale of the LR input, purely for display.
    lr_up = F.interpolate(lr, size=hr.shape[2:], mode='nearest')
    lr_im = to_img(lr_up / hr.max(), SR_model.opt['mode'], normalize=False)

    if SR_model.upscaling_model.continuous:
        hr_coords, _ = to_pixel_samples(hr, flatten=False)
        # ones_like already inherits device and dtype from hr_coords.
        cell_sizes = torch.ones_like(hr_coords)
        for i in range(cell_sizes.shape[-1]):
            cell_sizes[:, :, i] *= 2 / real_shape[2 + i]
        lr_upscaled = SR_model(lr, hr_coords, cell_sizes)
        # Model output is channels-last; move channels first and add batch dim.
        if SR_model.opt['mode'] == "2D":
            lr_upscaled = lr_upscaled.permute(2, 0, 1).unsqueeze(0)
        else:
            lr_upscaled = lr_upscaled.permute(3, 0, 1, 2).unsqueeze(0)
    else:
        lr_upscaled = SR_model(lr)

    l1 = torch.abs(lr_upscaled - hr).mean().item()
    psnr = PSNR(lr_upscaled, hr).item()
    sr_im = to_img(lr_upscaled / hr.max(), SR_model.opt['mode'], normalize=False)

    success, hr_im = cv2.imencode(".png", cv2.cvtColor(hr_im, cv2.COLOR_BGR2RGB))
    hr_im = hr_im.tobytes()
    success, sr_im = cv2.imencode(".png", cv2.cvtColor(sr_im, cv2.COLOR_BGR2RGB))
    sr_im = sr_im.tobytes()
    success, lr_im = cv2.imencode(".png", cv2.cvtColor(lr_im, cv2.COLOR_BGR2RGB))
    lr_im = lr_im.tobytes()

    return jsonify({
        "gt_img": str(base64.b64encode(hr_im)),
        "sr_img": str(base64.b64encode(sr_im)),
        "lr_img": str(base64.b64encode(lr_im)),
        "l1": "%0.04f" % l1,
        "psnr": "%0.02f" % psnr,
    })
# Merge explicitly-passed command-line arguments over the loaded options.
for k in args.keys():
    if args[k] is not None:
        opt[k] = args[k]
opt['cropping_resolution'] = 16
opt['data_folder'] = os.path.join(input_folder, args['data_folder'])

model = load_model(opt, args["device"]).to(args['device'])
dataset = LocalDataset(opt)

# Evaluation only — no gradients needed.
with torch.no_grad():
    if (args['increasing_size_test']):
        # Build an image sequence from a single random sample: the HR item,
        # its max_sf-downscaled version, and (below) repeated LR frames.
        img_sequence = []
        rand_dataset_item = random.randint(0, len(dataset) - 1)
        hr = dataset[rand_dataset_item].unsqueeze(0).to(args['device'])
        # to_img gives channels-last; transpose to CHW, keep first 3 channels.
        hr_im = torch.from_numpy(
            np.transpose(to_img(hr, args['mode']), [2, 0, 1])[0:3]).unsqueeze(0)
        lr = F.interpolate(
            hr,
            scale_factor=(1 / args['max_sf']),
            mode='bilinear' if args['mode'] == "2D" else "trilinear",
            align_corners=False)
        lr_im = torch.from_numpy(
            np.transpose(to_img(lr, args['mode']), [2, 0, 1])[0:3]).unsqueeze(0)
        # Nearest-neighbor upscale so LR frames match the HR frame size.
        lr_im = F.interpolate(lr_im, size=hr_im.shape[2:], mode='nearest')
        # Hold the upscaled-LR frame for 15 frames at the start of the
        # sequence.  NOTE(review): this script continues beyond the visible
        # excerpt; the sequence is presumably consumed further down.
        for i in range(15):
            img_sequence.append(lr_im[0].permute(1, 2, 0).detach().cpu().numpy())
def train_distributed(self, rank, model, opt, dataset):
    """Train ``model`` on ``dataset`` with DistributedDataParallel.

    One process per GPU; ``rank`` is this process's global rank.  Only
    rank 0 logs to TensorBoard, prints progress, and saves checkpoints.
    The training loop mirrors ``train_single``.
    """
    opt['device'] = "cuda:" + str(rank)
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=self.opt['num_nodes'] * self.opt['gpus_per_node'],
                            rank=rank)

    model = model.to(rank)
    model = DDP(model, device_ids=[rank])
    model_optim = optim.Adam(model.parameters(), lr=self.opt["g_lr"],
                             betas=(self.opt["beta_1"], self.opt["beta_2"]))
    optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=model_optim,
        milestones=[200, 400, 600, 800],
        gamma=self.opt['gamma'])

    if rank == 0:
        writer = SummaryWriter(os.path.join('tensorboard', opt['save_name']))
    start_time = time.time()

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=opt['num_nodes'] * opt['gpus_per_node'],
        rank=rank)
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        shuffle=False,  # the distributed sampler handles sharding/shuffling
        num_workers=opt["num_workers"],
        pin_memory=True,
        sampler=train_sampler)

    L1loss = nn.L1Loss().to(opt["device"])
    step = 0
    for epoch in range(opt['epoch_number'], opt['epochs']):
        opt["epoch_number"] = epoch
        for batch_num, real_hr in enumerate(dataloader):
            model.zero_grad()
            real_hr = real_hr.to(self.opt['device'])
            if rank == 0:
                hr_im = torch.from_numpy(
                    np.transpose(to_img(real_hr, self.opt['mode']),
                                 [2, 0, 1])[0:3]).unsqueeze(0)
            real_shape = real_hr.shape

            # Continuous models train over a random scale factor per batch;
            # discrete ones always use the fixed training-time ratio.
            if model.upscaling_model.continuous:
                scale_factor = torch.rand([1], device=real_hr.device, dtype=real_hr.dtype) * \
                    (self.opt['scale_factor_end'] - self.opt['scale_factor_start']) + \
                    self.opt['scale_factor_start']
            else:
                scale_factor = (1 / self.opt['spatial_downscale_ratio'])

            real_lr = F.interpolate(real_hr,
                                    scale_factor=(1 / scale_factor),
                                    mode="bilinear" if "2D" in self.opt['mode'] else "trilinear",
                                    align_corners=False,
                                    recompute_scale_factor=False)
            if rank == 0:
                lr_im = torch.from_numpy(
                    np.transpose(to_img(real_lr, self.opt['mode']),
                                 [2, 0, 1])[0:3]).unsqueeze(0)
                lr_im = F.interpolate(lr_im, size=hr_im.shape[2:], mode='nearest')

            if model.upscaling_model.continuous:
                hr_coords, real_hr = to_pixel_samples(real_hr, flatten=False)
                cell_sizes = torch.ones_like(hr_coords)
                for i in range(cell_sizes.shape[-1]):
                    cell_sizes[:, :, i] *= 2 / real_shape[2 + i]
                lr_upscaled = model(real_lr, hr_coords, cell_sizes)
                if "2D" in self.opt['mode']:
                    lr_upscaled = lr_upscaled.permute(2, 0, 1).unsqueeze(0)
                else:
                    lr_upscaled = lr_upscaled.permute(3, 0, 1, 2).unsqueeze(0)
                # CONSISTENCY FIX: snapshot the displayable image BEFORE
                # flattening for the loss, matching train_single — previously
                # sr_im was built from the flattened tensor.
                if rank == 0:
                    sr_im = torch.from_numpy(
                        np.transpose(to_img(lr_upscaled, self.opt['mode']),
                                     [2, 0, 1])[0:3]).unsqueeze(0)
                lr_upscaled = torch.flatten(lr_upscaled,
                                            start_dim=1, end_dim=-1).permute(1, 0)
            else:
                lr_upscaled = model(real_lr)
                if rank == 0:
                    sr_im = torch.from_numpy(
                        np.transpose(to_img(lr_upscaled, self.opt['mode']),
                                     [2, 0, 1])[0:3]).unsqueeze(0)

            L1 = L1loss(lr_upscaled, real_hr)
            L1.backward()
            model_optim.step()
            optim_scheduler.step()
            psnr = PSNR(lr_upscaled, real_hr)

            if rank == 0 and step % self.opt['save_every'] == 0:
                print("Epoch %i batch %i, sf: x%0.02f, L1: %0.04f, PSNR (dB): %0.02f" % \
                      (epoch, batch_num, scale_factor, L1.item(), psnr.item()))
                writer.add_scalar('L1', L1.item(), step)
                writer.add_images("LR, SR, HR",
                                  torch.cat([lr_im, sr_im, hr_im]), global_step=step)
            step += 1
        if rank == 0 and epoch % self.opt['save_every'] == 0:
            save_model(model, self.opt)
            print("Saved model")

    end_time = time.time()
    # BUGFIX: was start_time - end_time, which reported a negative duration.
    total_time = end_time - start_time
    if rank == 0:
        print("Time to train: " + str(total_time))
        save_model(model, self.opt)
        print("Saved model")
def train_single(self, model, dataset):
    """Train ``model`` on ``dataset`` on a single device.

    When ``opt['fine_tuning']`` is set, the feature extractor is frozen and
    only the upscaling model is optimized.  Logs L1/PSNR and an LR/SR/HR
    image triplet to TensorBoard every ``save_every`` steps and checkpoints
    every ``save_every`` epochs.
    """
    model = model.to(self.opt['device'])

    if not self.opt['fine_tuning']:
        model_optim = optim.Adam(model.parameters(), lr=self.opt["g_lr"],
                                 betas=(self.opt["beta_1"], self.opt["beta_2"]))
    else:
        # Fine-tuning: freeze the feature extractor, train only the upscaler.
        for param in model.feature_extractor.parameters():
            param.requires_grad = False
        model_optim = optim.Adam(model.upscaling_model.parameters(),
                                 lr=self.opt["g_lr"],
                                 betas=(self.opt["beta_1"], self.opt["beta_2"]))
    optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=model_optim,
        milestones=[200, 400, 600, 800],
        gamma=self.opt['gamma'])

    writer = SummaryWriter(os.path.join('tensorboard', self.opt['save_name']))
    start_time = time.time()

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        shuffle=True,
        num_workers=self.opt["num_workers"],
        pin_memory=True)

    L1loss = nn.L1Loss().to(self.opt["device"])
    step = 0
    for epoch in range(self.opt['epoch_number'], self.opt['epochs']):
        self.opt["epoch_number"] = epoch
        for batch_num, real_hr in enumerate(dataloader):
            model.zero_grad()
            real_hr = real_hr.to(self.opt['device'])
            hr_im = torch.from_numpy(
                np.transpose(to_img(real_hr, self.opt['mode']),
                             [2, 0, 1])[0:3]).unsqueeze(0)
            real_shape = real_hr.shape

            # Continuous models train over a random scale factor per batch;
            # discrete ones always use the fixed training-time ratio.
            if model.upscaling_model.continuous:
                scale_factor = torch.rand([1], device=real_hr.device, dtype=real_hr.dtype) * \
                    (self.opt['scale_factor_end'] - self.opt['scale_factor_start']) + \
                    self.opt['scale_factor_start']
            else:
                scale_factor = (1 / self.opt['spatial_downscale_ratio'])

            real_lr = F.interpolate(real_hr,
                                    scale_factor=(1 / scale_factor),
                                    mode="bilinear" if "2D" in self.opt['mode'] else "trilinear",
                                    align_corners=False,
                                    recompute_scale_factor=False)
            lr_im = torch.from_numpy(
                np.transpose(to_img(real_lr, self.opt['mode']),
                             [2, 0, 1])[0:3]).unsqueeze(0)
            lr_im = F.interpolate(lr_im, mode='nearest', size=hr_im.shape[2:])

            if model.upscaling_model.continuous:
                hr_coords, real_hr = to_pixel_samples(real_hr, flatten=False)
                cell_sizes = torch.ones_like(hr_coords)
                for i in range(cell_sizes.shape[-1]):
                    cell_sizes[:, :, i] *= 2 / real_shape[2 + i]
                lr_upscaled = model(real_lr, hr_coords, cell_sizes)
                if "2D" in self.opt['mode']:
                    lr_upscaled = lr_upscaled.permute(2, 0, 1).unsqueeze(0)
                else:
                    lr_upscaled = lr_upscaled.permute(3, 0, 1, 2).unsqueeze(0)
                # Snapshot the displayable image before flattening for the loss.
                sr_im = torch.from_numpy(
                    np.transpose(to_img(lr_upscaled, self.opt['mode']),
                                 [2, 0, 1])[0:3]).unsqueeze(0)
                lr_upscaled = torch.flatten(lr_upscaled,
                                            start_dim=1, end_dim=-1).permute(1, 0)
            else:
                lr_upscaled = model(real_lr)
                sr_im = torch.from_numpy(
                    np.transpose(to_img(lr_upscaled, self.opt['mode']),
                                 [2, 0, 1])[0:3]).unsqueeze(0)

            L1 = L1loss(lr_upscaled, real_hr)
            L1.backward()
            model_optim.step()
            optim_scheduler.step()
            psnr = PSNR(lr_upscaled, real_hr)

            if step % self.opt['save_every'] == 0:
                print("Epoch %i batch %i, sf: x%0.02f, L1: %0.04f, PSNR (dB): %0.02f" % \
                      (epoch, batch_num, scale_factor, L1.item(), psnr.item()))
                writer.add_scalar('L1', L1.item(), step)
                writer.add_images("LR, SR, HR",
                                  torch.cat([lr_im, sr_im, hr_im]), global_step=step)
            step += 1
        if epoch % self.opt['save_every'] == 0:
            save_model(model, self.opt)
            print("Saved model")

    end_time = time.time()
    # BUGFIX: was start_time - end_time, which reported a negative duration.
    total_time = end_time - start_time
    print("Time to train: " + str(total_time))

    save_model(model, self.opt)
    print("Saved model")