def forward(self, output, target1, target2, need_bpp=False):
    """Rate-distortion style loss over a two-image compression output.

    Args:
        output: dict with 'x1_hat' / 'x2_hat' reconstructions and, when
            need_bpp is True, 'likelihoods' holding 'y1', 'z1', 'y2', 'z2'.
        target1, target2: ground-truth images, NCHW (values presumably in
            [0, 1] given data_range=1 below — TODO confirm).
        need_bpp: when True, also report per-image bits-per-pixel.

    Returns:
        dict of loss / metric tensors.
    """
    N, _, H, W = target1.size()
    out = {}
    num_pixels = N * H * W

    # Compute each per-image MSE once and reuse it for both the combined
    # distortion term and the PSNR metrics (the original recomputed them).
    mse1 = self.mse(output['x1_hat'], target1)
    mse2 = self.mse(output['x2_hat'], target2)
    out['mse_loss'] = mse1 + mse2  # end-to-end distortion

    if need_bpp:
        # bpp = -sum(log2 p) / num_pixels, accumulated over latents y and z.
        def _bpp(name):
            return torch.log(output['likelihoods'][name]).sum() / (
                -math.log(2) * num_pixels)
        out['bpp1'] = _bpp('y1') + _bpp('z1')
        out['bpp2'] = _bpp('y2') + _bpp('z2')

    out['loss'] = self.lmbda * 255 ** 2 * out['mse_loss']  # + out['bpp_loss']
    # size_average=False yields a per-sample (N,) tensor; [0] keeps sample 0.
    out['ms_ssim1'] = ms_ssim(output['x1_hat'], target1, data_range=1,
                              size_average=False)[0]
    out['ms_ssim2'] = ms_ssim(output['x2_hat'], target2, data_range=1,
                              size_average=False)[0]
    out['ms_ssim'] = (out['ms_ssim1'] + out['ms_ssim2']) / 2
    out['psnr1'] = mse2psnr(mse1)
    out['psnr2'] = mse2psnr(mse2)
    return out
def validation_step(self, batch, batch_idx):
    """One validation step: forward pass, metric accumulation, and an image
    grid (target | bicubic baseline | prediction) logged on the epoch's
    final batch."""
    images, labels = batch
    preds = self.model(images)
    step_loss = criterion(preds, labels)

    is_last_batch = batch_idx + 1 == len(val_loader)
    if self.current_epoch is not None and is_last_batch:
        upscaled = F.interpolate(images, scale_factor=scale_factor, mode='bicubic')
        panel = make_grid(torch.cat([labels, upscaled, preds], dim=-1),
                          nrow=2).clamp(0, 1)
        val_writer.add_image('Images From Last Batch', panel, self.current_epoch)

    batch_ssim = pytorch_msssim.ms_ssim(preds, labels, data_range=1)
    batch_psnr = utils.psnr(preds, labels)
    n = len(labels)
    self.val_loss.update(step_loss.item(), n)
    self.val_psnr.update(batch_psnr.item(), n)
    self.val_ssim.update(batch_ssim.item(), n)
    return {
        'val_loss': f'{self.val_loss.compute():.4g}',
        'psnr': f'{self.val_psnr.compute():.2f}',
        'ssim': f'{self.val_ssim.compute():.4f}'
    }
def optimize_parameters(self, step):
    """Single generator update: pixel loss, optional contextual (CX) loss and
    optional (MS-)SSIM loss, followed by backward, optimizer step, and log
    bookkeeping."""
    self.optimizer_G.zero_grad()
    self.fake_H = self.netG(self.var_L)

    total = 0
    l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
    total = total + l_pix

    if self.cri_CX:
        feat_real = self.netF(self.real_H)
        feat_fake = self.netF(self.fake_H)
        l_CX = self.l_CX_w * self.cri_CX(feat_real, feat_fake)
        total = total + l_CX

    if self.cri_ssim:
        if self.cri_ssim == 'ssim':
            sim = ssim(self.fake_H, self.real_H, win_size=self.ssim_window,
                       data_range=1.0, size_average=True)
        elif self.cri_ssim == 'ms-ssim':
            # 4-level weights (one scale fewer than the library default).
            scale_weights = torch.FloatTensor(
                [0.0448, 0.2856, 0.3001, 0.2363]).to(self.fake_H.device,
                                                     dtype=self.fake_H.dtype)
            sim = ms_ssim(self.fake_H, self.real_H, win_size=self.ssim_window,
                          data_range=1.0, size_average=True,
                          weights=scale_weights)
        l_ssim = self.l_ssim_w * (1 - sim)
        total = total + l_ssim

    total.backward()
    self.optimizer_G.step()

    # set log
    self.log_dict['l_pix'] = l_pix.item()
    if self.cri_CX:
        self.log_dict['l_CX'] = l_CX.item()
    if self.cri_ssim:
        self.log_dict['l_ssim'] = l_ssim.item()
def forward(self, reconst, original):
    """Combined reconstruction loss: MSE minus MS-SSIM, shifted by +1 so the
    total stays non-negative. Returns (loss, ms_ssim_value)."""
    similarity = pytorch_msssim.ms_ssim(
        original, reconst, nonnegative_ssim=True).to(device)
    mse_term = (torch.nn.functional.mse_loss(original, reconst)).to(device)
    combined = mse_term - similarity + 1
    return combined, similarity
def train_step(self, batch, batch_idx):
    """One training step; on the epoch's last batch, logs a grid of two
    random (target | bicubic baseline | prediction) triples."""
    images, labels = batch
    preds = self.model(images)
    step_loss = criterion(preds, labels)

    with torch.no_grad():
        if batch_idx + 1 == len(train_loader):
            picks = torch.randperm(images.shape[0])[:2]
            upscaled = F.interpolate(images[picks],
                                     scale_factor=scale_factor,
                                     mode='bicubic')
            triple = labels[picks], upscaled, preds[picks]
            panel = make_grid(torch.cat(triple, dim=-1)).clamp(0, 1)
            train_writer.add_image('Random Images From Last Batch', panel,
                                   self.current_epoch)
        batch_ssim = pytorch_msssim.ms_ssim(preds, labels, data_range=1)
        batch_psnr = utils.psnr(preds, labels)
        n = len(labels)
        self.loss.update(step_loss.item(), n)
        self.psnr.update(batch_psnr.item(), n)
        self.ssim.update(batch_ssim.item(), n)
    return step_loss, {
        'loss': f'{self.loss.compute():.4g}',
        'psnr': f'{self.psnr.compute():.2f}',
        'ssim': f'{self.ssim.compute():.4f}'
    }
def inference(self, label, geo, image=None):
    """Run the generator on an encoded (label, geometry[, image]) input.

    Returns (fake_image, metrics); metrics holds per-sample SSIM / MS-SSIM
    against the real image when one was supplied, else is empty.
    """
    # Encode Inputs
    image = Variable(image) if image is not None else None
    geometry = Variable(geo) if geo is not None else None
    concat_input, real_image = self.encode_input(label, image,
                                                 geometry=geometry, infer=True)
    # BUG FIX: the original printed `input_label.data.device`, but no
    # `input_label` exists in this scope (NameError at runtime); report the
    # device of the actual encoded input instead.
    print(concat_input.data.device)

    # Fake Generation (torch 0.4 needs the explicit no_grad guard here)
    if torch.__version__.startswith('0.4'):
        with torch.no_grad():
            fake_image = self.netG.forward(concat_input)
    else:
        fake_image = self.netG.forward(concat_input)

    metrics = {}
    if image is not None:
        # metrics — map outputs from [-1, 1] to [0, 1] before scoring.
        GT_image = Variable((real_image + 1.) / 2.)
        fake_normalized = (fake_image + 1.) / 2.
        metrics['ssim'] = ssim(fake_normalized, GT_image, data_range=1,
                               size_average=False, nonnegative_ssim=True)
        metrics['ms_ssim'] = ms_ssim(fake_normalized, GT_image, data_range=1,
                                     size_average=False)
    return fake_image, metrics
def compute_metrics_for_frame(
    org_frame: Frame,
    dec_frame: Frame,
    bitdepth: int = 8,
) -> Dict[str, Any]:
    """Per-frame metrics: MSE for each YUV plane, plus RGB MSE and MS-SSIM.

    Planes are lifted to NCHW (1, 1, H, W) tensors; RGB conversions are
    quantized back to the integer range before scoring.
    """
    peak = 2**bitdepth - 1
    org_planes = tuple(p.unsqueeze(0).unsqueeze(0) for p in org_frame)  # type: ignore
    dec_planes = tuple(p.unsqueeze(0).unsqueeze(0) for p in dec_frame)  # type: ignore

    results: Dict[str, Any] = {
        f"mse-{name}": (org_planes[idx] - dec_planes[idx]).pow(2).mean()
        for idx, name in enumerate("yuv")
    }

    org_rgb = ycbcr2rgb(yuv_420_to_444(org_planes, mode="bicubic").true_divide(peak))  # type: ignore
    dec_rgb = ycbcr2rgb(yuv_420_to_444(dec_planes, mode="bicubic").true_divide(peak))  # type: ignore
    org_rgb = (org_rgb * peak).clamp(0, peak).round()
    dec_rgb = (dec_rgb * peak).clamp(0, peak).round()

    rgb_mse = (org_rgb - dec_rgb).pow(2).mean()
    rgb_msssim = ms_ssim(org_rgb, dec_rgb, data_range=peak)
    results.update({"ms-ssim-rgb": rgb_msssim, "mse-rgb": rgb_mse})
    return results
def compute_metrics_for_frame(
    org_frame: Frame,
    rec_frame: Tensor,
    device: str = "cpu",
    max_val: int = 255,
) -> Dict[str, Any]:
    """PSNR per YUV plane, combined YUV PSNR (4:1:1 weighting), and RGB
    PSNR / MSE / MS-SSIM for a single frame."""
    out: Dict[str, Any] = {}

    # YCbCr metrics
    org_yuv = to_tensors(org_frame, device=str(device), max_value=max_val)
    org_yuv = tuple(plane.unsqueeze(0).unsqueeze(0) for plane in org_yuv)  # type: ignore
    rec_yuv = convert_rgb_to_yuv420(rec_frame)
    log_peak = 20 * np.log10(max_val)
    for idx, name in enumerate("yuv"):
        ref = (org_yuv[idx] * max_val).clamp(0, max_val).round()
        hyp = (rec_yuv[idx] * max_val).clamp(0, max_val).round()
        out[f"psnr-{name}"] = log_peak - 10 * torch.log10(
            (ref - hyp).pow(2).mean())
    # Standard 6-point weighting: luma counts 4x the chroma planes.
    out["psnr-yuv"] = (4 * out["psnr-y"] + out["psnr-u"] + out["psnr-v"]) / 6

    # RGB metrics
    org_rgb = convert_yuv420_to_rgb(
        org_frame, device, max_val
    )  # ycbcr2rgb(yuv_420_to_444(org_frame, mode="bicubic")) # type: ignore
    org_rgb = (org_rgb * max_val).clamp(0, max_val).round()
    rec_frame = (rec_frame * max_val).clamp(0, max_val).round()
    mse_rgb = (org_rgb - rec_frame).pow(2).mean()
    psnr_rgb = log_peak - 10 * torch.log10(mse_rgb)
    ms_ssim_rgb = ms_ssim(org_rgb, rec_frame, data_range=max_val)
    out.update({"ms-ssim-rgb": ms_ssim_rgb, "mse-rgb": mse_rgb,
                "psnr-rgb": psnr_rgb})
    return out
def calculate_mssim(minibatch, reconstr_image, size_average=True):
    """Compute an (MS-)SSIM-based dissimilarity between a batch and its
    reconstruction.

    Note: despite the historical name, this returns ``1 - (MS-)SSIM`` — a
    loss-style value that is 0 for identical images — not the raw similarity
    score. 5-D inputs (with a temporal dimension) are scored per time step
    and averaged; images whose smallest spatial dim is below 160 fall back to
    plain SSIM because MS-SSIM's 4x downsampling needs larger inputs.

    :param minibatch: the input minibatch
    :param reconstr_image: the reconstructed image
    :param size_average: forwarded to the underlying (ms_)ssim call
    :returns: 1 - (MS-)SSIM dissimilarity
    :rtype: float
    """
    if minibatch.dim() == 5 and reconstr_image.dim(
    ) == 5:  # special case where we have temporal dim
        # Score each time step independently, then average over time.
        msssim = torch.cat([
            calculate_mssim(minibatch[:, i, ::],
                            reconstr_image[:, i, ::],
                            size_average=size_average).unsqueeze(-1)
            for i in range(minibatch.shape[1])
        ], -1)
        return torch.mean(msssim, -1)

    smallest_dim = min(minibatch.shape[-1], minibatch.shape[-2])
    # The ssim/ms_ssim kernels require matching dtypes; promote the input.
    if minibatch.dtype != reconstr_image.dtype:
        minibatch = minibatch.type(reconstr_image.dtype)

    if smallest_dim < 160:  # Limitation of ms-ssim library due to 4x downsample
        return 1 - ssim(X=minibatch,
                        Y=reconstr_image,
                        data_range=1,
                        size_average=size_average,
                        nonnegative_ssim=True)

    return 1 - ms_ssim(
        X=minibatch, Y=reconstr_image, data_range=1, size_average=size_average)
def calc_baseline(loader,
                  scale_factor=2,
                  use_cuda=torch.cuda.is_available(),
                  method='bicubic',
                  criterion=None):
    """Evaluate a pure-interpolation upscaling baseline over `loader`,
    accumulating weighted-average loss, PSNR and MS-SSIM and showing them on
    the progress bar.

    Args:
        loader: iterable of (inputs, targets) batches.
        scale_factor: upscaling factor passed to F.interpolate.
        use_cuda: move batches to GPU when True.
        method: interpolation mode for F.interpolate.
        criterion: pixel loss; defaults to nn.L1Loss(). (Fixed: the original
            instantiated the module as a default argument, evaluated once at
            import time and shared across calls.)
    """
    if criterion is None:
        criterion = nn.L1Loss()
    # Renamed from `ssim` to avoid shadowing a commonly-imported metric name.
    loss_m = WeightedAveragedMetric()
    psnr_m = WeightedAveragedMetric()
    ssim_m = WeightedAveragedMetric()
    with tqdm(loader, desc='Baseline') as bar:
        for inputs, targets in bar:
            if use_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()
            resized = F.interpolate(inputs,
                                    scale_factor=scale_factor,
                                    mode=method)
            size = len(targets)
            loss_m.update(criterion(resized, targets).item(), size)
            psnr_m.update(psnr(resized, targets).item(), size)
            ssim_m.update(
                pytorch_msssim.ms_ssim(resized, targets,
                                       data_range=1).item(), size)
            bar.set_postfix({
                'loss': f'{loss_m.compute():.4g}',
                'psnr': f'{psnr_m.compute():.2f}',
                'ssim': f'{ssim_m.compute():.4f}'
            })
def norm(img1, img2):
    """Print the MSE plus the SSIM and MS-SSIM dissimilarities (1 - score)
    between two image batches."""
    mse_value = nn.MSELoss()(img1, img2)
    print("loss_r:" + str(mse_value.item()))
    ssim_dist = 1 - pytorch_msssim.ssim(img1, img2)
    print("SSIM_h:" + str(ssim_dist.item()))
    msssim_dist = 1 - pytorch_msssim.ms_ssim(img1, img2)
    print("MS_SSIM_h:" + str(msssim_dist.item()))
def recons_loss(recon_x, x):
    """Reconstruction loss combining MS-SSIM and L1: (1 - MS-SSIM)/2 + L1."""
    # MS-SSIM: multi-scale structural similarity — SSIM evaluated over a
    # pyramid of progressively downscaled copies of the images, so different
    # resolutions are taken into account. SSIM itself (from neuroscience-
    # motivated work) weights structural similarity over raw per-pixel
    # differences, which reflects human perception better than MSE.
    msssim = ((1-pytorch_msssim.ms_ssim(x,recon_x)))/2
    # L1 loss: mean absolute per-pixel difference (vs. L2's squared
    # difference). L2 amplifies the gap between large and small errors
    # (e.g. 2*2 vs 0.1*0.1) and is more sensitive to outliers.
    # The literature finds MS-SSIM + L1 works best: MS-SSIM can drift in
    # brightness/colour but preserves high-frequency detail (edges), while
    # L1 keeps brightness and colour stable. The cited formulation is
    # Lmix = a*Lmsssim + (1-a)*G*L1 with a = 0.84 (empirical) and G a
    # Gaussian weighting term also used inside MS-SSIM.
    f1 = F.l1_loss(recon_x, x)
    return msssim+f1
def evaluate_ms_ssim(conf):
    """Evaluate MS-SSIM between model reconstructions and ground truth over a
    dataset described by `conf`, then write a histogram of the per-image
    values next to the model checkpoint and echo the mean.

    conf keys used: data_path, fourier, source_list, batch_size, model_path,
    arch_name, model_path_2, arch_name_2, amp_phase, format.
    """
    # create DataLoader
    loader = create_databunch(conf["data_path"], conf["fourier"],
                              conf["source_list"], conf["batch_size"])
    model_path = conf["model_path"]
    out_path = Path(model_path).parent / "evaluation"
    out_path.mkdir(parents=True, exist_ok=True)
    img_size = loader.dataset[0][0][0].shape[-1]
    model = load_pretrained_model(conf["arch_name"], conf["model_path"],
                                  img_size)
    # Optional second model whose predictions are concatenated channel-wise.
    if conf["model_path_2"] != "none":
        model_2 = load_pretrained_model(conf["arch_name_2"],
                                        conf["model_path_2"], img_size)
    vals = []
    if img_size < 160:
        # MS-SSIM needs >=160 px on the smallest side (4x downsampling);
        # smaller images get padded below and the result is indicative only.
        click.echo("\nThis is only a placeholder!\
 Images too small for meaningful ms ssim calculations.\n")
    # iterate trough DataLoader
    for i, (img_test, img_true) in enumerate(tqdm(loader)):
        pred = eval_model(img_test, model)
        if conf["model_path_2"] != "none":
            pred_2 = eval_model(img_test, model_2)
            pred = torch.cat((pred, pred_2), dim=1)
        ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
        ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])
        if img_size < 160:
            ifft_truth = pad_unsqueeze(torch.tensor(ifft_truth))
            ifft_pred = pad_unsqueeze(torch.tensor(ifft_pred))
        # NOTE(review): the comprehension variable `pred` shadows the outer
        # `pred` (harmless in py3 comprehension scope, but confusing).
        vals.extend([
            ms_ssim(pred.unsqueeze(0), truth.unsqueeze(0),
                    data_range=truth.max())
            for pred, truth in zip(ifft_pred, ifft_truth)
        ])
    click.echo("\nCreating ms-ssim histogram.\n")
    vals = torch.tensor(vals)
    histogram_ms_ssim(
        vals,
        out_path,
        plot_format=conf["format"],
    )
    click.echo(f"\nThe mean ms-ssim value is {vals.mean()}.\n")
def run_tests(
    num_samples=32,
):
    """Render test views by path tracing, averaging `num_samples` stochastic
    renders per view, and report L1/L2/PSNR and (MS-)SSIM against the
    reference images.

    Relies on module-level state: exp_imgs, cam_to_worlds, light_locs, focal,
    density_field, learned_bsdf, integrator, SIZE, device — presumably set up
    by the surrounding script (not visible here).
    """
    l1_losses = []
    l2_losses = []
    psnr_losses = []
    gots = []
    num = 100  # number of views to evaluate
    with torch.no_grad():
        for i, (c2w, lp) in enumerate(zip(tqdm(cam_to_worlds[:num]),
                                          light_locs)):
            exp = exp_imgs[i].clamp(min=0, max=1)
            cameras = NeRFCamera(cam_to_world=c2w.unsqueeze(0), focal=focal,
                                 device=device)
            lights = PointLights(intensity=[1, 1, 1], location=lp[None, ...],
                                 scale=300, device=device)
            got = None
            # Monte-Carlo average over `num_samples` independent renders.
            for _ in range(num_samples):
                sample = pt.pathtrace(
                    density_field,
                    size=SIZE,
                    chunk_size=min(SIZE, 100),
                    bundle_size=1,
                    bsdf=learned_bsdf,
                    integrator=integrator,
                    # 0 is for comparison, 1 is for display
                    background=0,
                    cameras=cameras,
                    lights=lights,
                    device=device,
                    silent=True,
                    w_isect=True,
                )[0]
                if got is None:
                    got = sample
                else:
                    got += sample
            got /= num_samples
            got = got.clamp(min=0, max=1)
            # Gamma-correct (1/2.2) both images before saving the comparison.
            save_plot(
                exp ** (1 / 2.2), got ** (1 / 2.2),
                f"outputs/path_nerv_armadillo_{i:03}.png",
            )
            l1_losses.append(F.l1_loss(exp, got).item())
            mse = F.mse_loss(exp, got)
            l2_losses.append(mse.item())
            psnr = mse2psnr(mse).item()
            psnr_losses.append(psnr)
            gots.append(got)
    print("Avg l1 loss", np.mean(l1_losses))
    print("Avg l2 loss", np.mean(l2_losses))
    print("Avg PSNR loss", np.mean(psnr_losses))
    with torch.no_grad():
        # HWC -> NCHW for the SSIM implementations; takes a lot of memory.
        gots = torch.stack(gots, dim=0).permute(0, 3, 1, 2)
        exps = torch.stack(exp_imgs[:num], dim=0).permute(0, 3, 1, 2)
        torch.cuda.empty_cache()
        ssim_loss = ms_ssim(gots, exps, data_range=1, size_average=True).item()
        print("MS-SSIM loss", ssim_loss)
        ssim_loss = ssim(gots, exps, data_range=1, size_average=True).item()
        print("SSIM loss", ssim_loss)
    return
def compute_metrics(a: Union[np.ndarray, Image.Image],
                    b: Union[np.ndarray, Image.Image],
                    max_val: float = 255.) -> Tuple[float, float]:
    """Returns PSNR and MS-SSIM between images `a` and `b`.

    Args:
        a, b: images as PIL images or numpy arrays (HWC; 3-channel inputs
            are permuted to NCHW for the metric computations).
        max_val: dynamic range of the pixel values.

    Returns:
        Tuple (psnr, ms_ssim) as Python floats.
    """
    # Fixed annotation: `np.array` is a constructor, not a type; the intended
    # type is np.ndarray.
    if isinstance(a, Image.Image):
        a = np.asarray(a)
    if isinstance(b, Image.Image):
        b = np.asarray(b)

    a = torch.from_numpy(a.copy()).float().unsqueeze(0)
    if a.size(3) == 3:
        a = a.permute(0, 3, 1, 2)
    b = torch.from_numpy(b.copy()).float().unsqueeze(0)
    if b.size(3) == 3:
        b = b.permute(0, 3, 1, 2)

    mse = torch.mean((a - b)**2).item()
    p = 20 * np.log10(max_val) - 10 * np.log10(mse)
    m = ms_ssim(a, b, data_range=max_val).item()
    return p, m
def calc_metrics(img_gt, img_out, hfen):
    '''
    compute vif, mssim, ssim, and psnr of img_out using img_gt as
    ground-truth reference, plus a scaled HFEN value.

    note: for msssim, made a slight mod to source code in line 200 of
    /home/vanveen/heck/lib/python3.8/site-packages/pytorch_msssim/ssim.py
    to compute msssim over images w smallest dim >=160

    :param img_gt: ground-truth image (array-like, single channel)
    :param img_out: reconstructed image to score
    :param hfen: callable computing the high-frequency error norm
    :returns: (vif, msssim, ssim, psnr, hfen) tuple of floats
    '''
    img_gt, img_out = norm_imgs(img_gt, img_out)
    img_gt, img_out = np.array(img_gt), np.array(img_out)
    vif_ = vifp_mscale(img_gt, img_out, sigma_nsq=img_out.mean())
    ssim_ = ssim(img_gt, img_out)
    psnr_ = psnr(img_gt, img_out)
    # Wrap as (1, 1, H, W) tensors for the torch-based metrics.
    img_out_ = torch.from_numpy(np.array([[img_out]]))
    img_gt_ = torch.from_numpy(np.array([[img_gt]]))
    msssim_ = float(ms_ssim(img_out_, img_gt_, data_range=img_gt_.max()))
    # 10000x scaling keeps HFEN in a readable range.
    hfen_ = 10000 * float(hfen(img_gt_.float(), img_out_.float()))
    return vif_, msssim_, ssim_, psnr_, hfen_
def inference(model, x):
    """Compress and decompress one image with `model`, reporting PSNR,
    MS-SSIM, bits-per-pixel and the encode/decode wall times.

    The input is zero-padded to a multiple of 64 (six stride-2 stages) and
    the reconstruction is cropped back before scoring.
    """
    x = x.unsqueeze(0)
    h, w = x.size(2), x.size(3)
    multiple = 64  # maximum 6 strides of 2
    padded_h = (h + multiple - 1) // multiple * multiple
    padded_w = (w + multiple - 1) // multiple * multiple
    pad_l = (padded_w - w) // 2
    pad_r = padded_w - w - pad_l
    pad_t = (padded_h - h) // 2
    pad_b = padded_h - h - pad_t
    x_padded = F.pad(x, (pad_l, pad_r, pad_t, pad_b), mode="constant", value=0)

    with torch.no_grad():
        tic = time.time()
        out_enc = model.compress(x_padded)
        enc_time = time.time() - tic
        tic = time.time()
        out_dec = model.decompress(out_enc["strings"], out_enc["shape"])
        dec_time = time.time() - tic

    # Negative padding crops the reconstruction back to the original size.
    out_dec["x_hat"] = F.pad(out_dec["x_hat"], (-pad_l, -pad_r, -pad_t, -pad_b))

    num_pixels = x.size(0) * x.size(2) * x.size(3)
    bpp = sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels
    return {
        "psnr": psnr(x, out_dec["x_hat"]),
        "ms-ssim": ms_ssim(x, out_dec["x_hat"], data_range=1.0).item(),
        "bpp": bpp,
        "encoding_time": enc_time,
        "decoding_time": dec_time,
    }
def inference(model, x):
    """Round-trip one image through `model.compress`/`model.decompress` and
    return quality (PSNR, MS-SSIM), rate (bpp), and timing metrics.

    Spatial dims are padded up to the next multiple of 64 so all six
    stride-2 stages divide evenly; the output is cropped back afterwards.
    """
    x = x.unsqueeze(0)
    height, width = x.size(2), x.size(3)
    stride_mult = 64  # maximum 6 strides of 2
    target_h = (height + stride_mult - 1) // stride_mult * stride_mult
    target_w = (width + stride_mult - 1) // stride_mult * stride_mult
    left = (target_w - width) // 2
    right = target_w - width - left
    top = (target_h - height) // 2
    bottom = target_h - height - top
    x_padded = F.pad(x, (left, right, top, bottom), mode='constant', value=0)

    with torch.no_grad():
        t0 = time.time()
        out_enc = model.compress(x_padded)
        enc_time = time.time() - t0
        t0 = time.time()
        out_dec = model.decompress(out_enc['strings'], out_enc['shape'])
        dec_time = time.time() - t0

    # Negative pads crop the reconstruction back to the source resolution.
    out_dec['x_hat'] = F.pad(out_dec['x_hat'], (-left, -right, -top, -bottom))

    num_pixels = x.size(0) * x.size(2) * x.size(3)
    bpp = sum(len(s[0]) for s in out_enc['strings']) * 8. / num_pixels
    return {
        'psnr': psnr(x, out_dec['x_hat']),
        'msssim': ms_ssim(x, out_dec['x_hat'], data_range=1.).item(),
        'bpp': bpp,
        'encoding_time': enc_time,
        'decoding_time': dec_time,
    }
def process(self, inputs, outputs):
    """
    Args:
        inputs: the inputs to a CFPN model. input dicts must contain an
            'image' key
        outputs: the outputs of a CFPN model. It is a list of dicts with
            key "instances" that contains :class:`Instances`. The
            :class:`Instances` object needs to have `densepose` field.
    """
    assert len(inputs) == 1
    # Resize the reference image the same way the model preprocesses it
    # (max side 512). Much shuffling of data, but the dataset is only 24
    # images so the cost is negligible.
    reference = inputs[0]['image'].permute(1, 2, 0)
    reference = self.transform.get_transform(reference).apply_image(
        reference.numpy())
    reference = torch.tensor(reference).permute(2, 0, 1)
    reference = torch.unsqueeze(reference, dim=0).float().to(
        outputs[self.eval_img].get_device())

    reconstruction = outputs[self.eval_img].float()
    reconstruction = reconstruction[:, :, 0:reference.shape[2],
                                    0:reference.shape[3]]

    assert reference.shape[1] == 3, "original image must have 3 channels"
    assert reconstruction.shape[
        1] == 3, "reconstructed image must have 3 channels"

    with torch.no_grad():
        ssim_val = ssim(reconstruction, reference, data_range=255,
                        size_average=False)
        ms_ssim_val = ms_ssim(reconstruction, reference, data_range=255,
                              size_average=False)
    self.ssim_vals.extend(ssim_val)
    self.ms_ssim_vals.extend(ms_ssim_val)
def validation(NORMAL_NUM, iteration, generator, discriminator, real_data,
               fake_data, is_end, AUC_LIST, END_ITER):
    """Score every test image of MVTec-AD category `NORMAL_NUM` with a
    combined MS-SSIM+L1 reconstruction term and a Gaussian-SVDD latent term,
    then compute the good-vs-defect ROC AUC and append it to the run log.

    Returns (auc_result, AUC_LIST) with the new AUC appended.
    Uses module-level state: device, p_stride, data_range, ssim_weights,
    ckpt_path, load_test.
    """
    generator.eval()
    discriminator.eval()
    # resnet.eval()
    y = []       # ground-truth labels: 0 = good, 1 = defect
    score = []   # per-image anomaly scores
    normal_gsvdd = []
    abnormal_gsvdd = []
    normal_recon = []
    abnormal_recon = []
    test_root = '/home/user/Documents/Public_Dataset/MVTec_AD/{}/test/'.format(
        NORMAL_NUM)
    list_test = os.listdir(test_root)
    with torch.no_grad():
        # One sub-directory per defect type (plus "good").
        for i in range(len(list_test)):
            current_defect = list_test[i]
            # print(current_defect)
            test_path = test_root + "{}".format(current_defect)
            valid_dataset_loader = load_test(test_path, sample_rate=1.)
            for index, (images, label) in enumerate(valid_dataset_loader):
                # img = five_crop_ready(images)
                # Cut the image into 32x32 patches (stride p_stride) and
                # stack all patches along the batch dimension.
                img_tmp = images.to(device)
                img_tmp = img_tmp.unfold(2, 32, p_stride).unfold(
                    3, 32, p_stride)
                # img = extract_patch(img_tmp)
                img_tmp = img_tmp.permute(0, 2, 3, 1, 4, 5)
                img = img_tmp.reshape(-1, 3, 32, 32)
                latent_z = generator.encoder(img)
                generate_result = generator(img)
                weight = 0.85
                # Per-patch reconstruction score: weighted MS-SSIM
                # dissimilarity plus a mean pixel difference.
                ms_ssim_batch_wise = 1 - ms_ssim(img, generate_result,
                                                 data_range=data_range,
                                                 size_average=False,
                                                 win_size=3,
                                                 weights=ssim_weights)
                # NOTE(review): no abs() here, so this is a signed mean
                # difference, not a true L1 distance — confirm intended.
                l1_batch_wise = (img - generate_result) / data_range
                l1_batch_wise = l1_batch_wise.mean(1).mean(1).mean(1)
                ms_ssim_l1 = weight * ms_ssim_batch_wise + (
                    1 - weight) * l1_batch_wise
                # Gaussian-SVDD term: distance of the latent from the learned
                # centre c, scaled by sigma.
                diff = (latent_z - generator.c)**2
                dist = -1 * torch.sum(diff, dim=1) / generator.sigma
                guass_svdd_loss = 1 - torch.exp(dist)
                # Image-level anomaly score = worst-scoring patch.
                anormaly_score = (
                    (0.5 * ms_ssim_l1 +
                     0.5 * guass_svdd_loss).max()).cpu().detach().numpy()
                score.append(float(anormaly_score))
                if label[0] == "good":
                    # (disabled) collect per-image stats on the final pass:
                    # normal_gsvdd / normal_recon via .max() of each term
                    y.append(0)
                else:
                    # (disabled) abnormal_gsvdd / abnormal_recon likewise
                    y.append(1)
    ###################################################
    # (disabled) on the final iteration, plot normal vs abnormal gsvdd and
    # reconstruction losses via Helper.plot_2d_chart into ./plot/.
    fpr, tpr, thresholds = metrics.roc_curve(y, score, pos_label=1)
    auc_result = auc(fpr, tpr)
    AUC_LIST.append(auc_result)
    # tqdm.write(str(auc_result), end='.....................')
    # Append this iteration's AUC to the persistent run log.
    auc_file = open(ckpt_path + "/auc.txt", "a")
    auc_file.write('Iter {}: {}\r\n'.format(str(iteration), str(auc_result)))
    auc_file.close()
    if iteration == END_ITER - 1:
        auc_file = open(ckpt_path + "/auc.txt", "a")
        auc_file.write('BEST AUC -> {}\r\n'.format(max(AUC_LIST)))
        auc_file.close()
    return auc_result, AUC_LIST
def train(args, NORMAL_NUM, generator, discriminator, optimizer_g,
          optimizer_d):
    """Adversarially train the patch autoencoder (generator) and
    discriminator on MVTec-AD category `NORMAL_NUM`, periodically
    re-estimating the Gaussian-SVDD centre/sigma and running validation.

    Uses module-level state: device, BATCH_SIZE, MAX_EPOCH, LR, p_stride,
    data_range, ssim_weights, l1_criterion, recorder, ckpt_path, test_auc.
    """
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    generator.train()
    discriminator.train()
    AUC_LIST = []
    global test_auc
    test_auc = 0
    BEST_AUC = 0
    train_path = '/home/user/Documents/Public_Dataset/MVTec_AD/{}/train/good'.format(
        NORMAL_NUM)
    START_ITER = 0
    train_size = len(os.listdir(train_path))
    END_ITER = int((train_size / BATCH_SIZE) * MAX_EPOCH)
    # END_ITER = 40500
    train_dataset_loader, train_size = load_train(train_path,
                                                  args.sample_rate)
    # Initialise the Gaussian-SVDD centre c and spread sigma from the data.
    generator.c = None
    generator.sigma = None
    generator.c = init_c(train_dataset_loader, generator)
    generator.sigma = init_sigma(train_dataset_loader, generator)
    print("gsvdd_sigma: {}".format(generator.sigma))
    train_data = iter(train_dataset_loader)
    process = tqdm(range(START_ITER, END_ITER), desc='{AUC: }')
    for iteration in process:
        # Polynomial learning-rate decay for both optimizers.
        poly_lr_scheduler(optimizer_d, init_lr=LR, iter=iteration,
                          max_iter=END_ITER)
        poly_lr_scheduler(optimizer_g, init_lr=LR, iter=iteration,
                          max_iter=END_ITER)
        # --------------------- Loader ------------------------
        batch = next(train_data, None)
        if batch is None:
            # train_dataset_loader = load_train(train_path)
            train_data = iter(train_dataset_loader)
            # NOTE(review): iterator .next() is Python 2 style and raises
            # AttributeError on Python 3 — should be next(train_data); confirm
            # the target Python version.
            batch = train_data.next()
        batch = batch[0]  # batch[1] contains labels
        batch_data = batch.to(device)
        # Cut each image into 32x32 patches (stride p_stride) and flatten
        # them into the batch dimension.
        data_tmp = batch_data.unfold(2, 32, p_stride).unfold(3, 32, p_stride)
        data_tmp = data_tmp.permute(0, 2, 3, 1, 4, 5)
        real_data = data_tmp.reshape(-1, 3, 32, 32)
        # --------------------- TRAIN E ------------------------
        optimizer_g.zero_grad()
        latent_z = generator.encoder(real_data)
        fake_data = generator(real_data)
        b, _ = latent_z.shape
        # Reconstruction loss: weighted MS-SSIM dissimilarity + scaled L1.
        weight = 0.85
        ms_ssim_batch_wise = 1 - ms_ssim(real_data, fake_data,
                                         data_range=data_range,
                                         size_average=True, win_size=3,
                                         weights=ssim_weights)
        l1_batch_wise = l1_criterion(real_data, fake_data) / data_range
        ms_ssim_l1 = weight * ms_ssim_batch_wise + (1 - weight) * l1_batch_wise
        ############ Interplote ############
        # Interpolate between latents and their batch-reversed counterparts;
        # the discriminator should score interpolations near zero.
        e1 = torch.flip(latent_z, dims=[0])
        alpha = torch.FloatTensor(b, 1).uniform_(0, 0.5).to(device)
        e2 = alpha * latent_z + (1 - alpha) * e1
        g2 = generator.generate(e2)
        reg_inter = torch.mean(discriminator(g2)**2)
        ############ DSVDD ############
        diff = (latent_z - generator.c)**2
        dist = -1 * (torch.sum(diff, dim=1) / generator.sigma)
        svdd_loss = torch.mean(1 - torch.exp(dist))
        encoder_loss = ms_ssim_l1 + svdd_loss + 0.1 * reg_inter
        encoder_loss.backward()
        optimizer_g.step()
        ############ Discriminator ############
        optimizer_d.zero_grad()
        g2 = generator.generate(e2).detach()
        fake_data = generator(real_data).detach()
        # Discriminator should predict the interpolation coefficient for
        # interpolated samples and zero for near-real reconstructions.
        d_loss_front = torch.mean((discriminator(g2) - alpha)**2)
        gamma = 0.2
        tmp = fake_data + gamma * (real_data - fake_data)
        d_loss_back = torch.mean(discriminator(tmp)**2)
        d_loss = d_loss_front + d_loss_back
        d_loss.backward()
        optimizer_d.step()
        # Re-estimate the SVDD centre and sigma every ~10 epochs.
        if iteration % int(
                (train_size / BATCH_SIZE) * 10) == 0 and iteration != 0:
            generator.sigma = init_sigma(train_dataset_loader, generator)
            generator.c = init_c(train_dataset_loader, generator)
        if recorder is not None:
            recorder.record(loss=svdd_loss,
                            epoch=int(iteration / BATCH_SIZE),
                            num_batches=len(train_data),
                            n_batch=iteration,
                            loss_name='DSVDD')
            recorder.record(loss=torch.mean(dist),
                            epoch=int(iteration / BATCH_SIZE),
                            num_batches=len(train_data),
                            n_batch=iteration,
                            loss_name='DIST')
            recorder.record(loss=ms_ssim_batch_wise,
                            epoch=int(iteration / BATCH_SIZE),
                            num_batches=len(train_data),
                            n_batch=iteration,
                            loss_name='MS-SSIM')
            recorder.record(loss=l1_batch_wise,
                            epoch=int(iteration / BATCH_SIZE),
                            num_batches=len(train_data),
                            n_batch=iteration,
                            loss_name='L1')
        # Validate every ~10 epochs and on the final iteration.
        if iteration % int(
                (train_size / BATCH_SIZE) * 10) == 0 or iteration == END_ITER - 1:
            is_end = True if iteration == END_ITER - 1 else False
            test_auc, AUC_LIST = validation(NORMAL_NUM, iteration, generator,
                                            discriminator, real_data,
                                            fake_data, is_end, AUC_LIST,
                                            END_ITER)
            process.set_description("{AUC: %.5f}" % test_auc)
            # NOTE(review): in the collapsed source the indentation of these
            # saves is ambiguous — they sit right after validation here
            # (checkpoint alongside each eval); confirm they were not meant
            # to run once after the loop instead.
            torch.save(optimizer_g.state_dict(),
                       ckpt_path + '/optimizer/g_opt.pth')
            torch.save(optimizer_d.state_dict(),
                       ckpt_path + '/optimizer/d_opt.pth')
for i in range(999999999): for k in range(batchSize): trainData[k] = torch.from_numpy(dReader.readImg()).float().cuda() if (i == 0): trainDataRgbFilter = rgbCompress.createRGGB121Filter(trainData).cuda() testDataRgbFilter = rgbCompress.createRGGB121Filter(testData).cuda() optimizer.zero_grad() decData = recNet(trainData * trainDataRgbFilter) currentMSEL = F.mse_loss(decData, trainData) currentMS_SSIM = pytorch_msssim.ms_ssim(trainData, decData, data_range=1, size_average=True) if (torch.isnan(currentMS_SSIM)): currentMS_SSIM.zero_() loss = - currentMS_SSIM.item() print('%.3f' % currentMS_SSIM.item(), '%.3f' % currentMSEL.item()) if (currentMS_SSIM.item() > -0.7): currentMSEL.backward() else: currentMS_SSIM.backward() if (i == 0): minLoss = loss else:
# Load the pretrained encoder/decoder pair onto GPU 0.
encNet = torch.load('./models/encNet_16.pkl', map_location='cuda:0').cuda()
decNet = torch.load('./models/decNet_16.pkl', map_location='cuda:0').cuda()
MSELoss = nn.MSELoss()
# Read the test image as single-channel and shape it to NCHW (1, 1, 256, 256).
img = Image.open('./test.bmp').convert('L')
inputData = torch.from_numpy(
    numpy.asarray(img).astype(float).reshape([1, 1, 256, 256])).float().cuda()
encData = encNet(inputData)
# quantize() is a project helper; presumably 4 is the bit depth / number of
# levels, and /3 rescales the quantized latent for the decoder — TODO confirm.
qEncData = quantize(encData, 4)
decData = decNet(qEncData / 3)
MSEL = MSELoss(inputData, decData)
img1 = inputData.clone()
img2 = decData.clone()
img1.detach_()
img2.detach_()
# Clamp the reconstruction into the valid 8-bit range before scoring/saving.
img2[img2 < 0] = 0
img2[img2 > 255] = 255
MS_SSIM = pytorch_msssim.ms_ssim(img1, img2, data_range=255,
                                 size_average=True)
print('MSEL=', '%.3f' % MSEL.item(), 'MS_SSIM=', '%.3f' % MS_SSIM)
# Save the reconstruction, the input copy, and the raw quantized latent.
img2 = img2.cpu().numpy().astype(int).reshape([256, 256])
img2 = Image.fromarray(img2.astype('uint8')).convert('L')
img2.save('./output/output.bmp')
img1 = img1.cpu().numpy().astype(int).reshape([256, 256])
img1 = Image.fromarray(img1.astype('uint8')).convert('L')
img1.save('./output/input.bmp')
numpy.save('./output/output.npy', qEncData.detach().cpu().numpy().astype(int))
def pytorch_ms_ssim(preds, target, data_range, kernel_size):
    """Adapter mapping torchmetrics-style argument names onto
    pytorch_msssim's ms_ssim (kernel_size -> win_size)."""
    score = ms_ssim(preds, target, data_range=data_range,
                    win_size=kernel_size)
    return score
def msssim(pred_batch, gt_batch):
    """MS-SSIM dissimilarity loss (1 - MS-SSIM) with a 7-pixel window over
    data in [0, 1]."""
    similarity = ms_ssim(pred_batch, gt_batch, data_range=1, win_size=7)
    return 1 - similarity
def compute_msssim(image1, image2):
    """Batch-averaged MS-SSIM between two image batches in [0, 1]."""
    score = ms_ssim(image1, image2, data_range=1, size_average=True)
    return score
def rec_loss(attr_images, generated_images, a):
    """Blended reconstruction loss: a * (1 - MS-SSIM) + (1 - a) * L1."""
    dissimilarity = 1 - ms_ssim(
        attr_images, generated_images, data_range=1, size_average=True)
    pixel_term = l1_criterion(attr_images, generated_images)
    return a * dissimilarity + (1 - a) * pixel_term
defMaxLossOfTrainData = 0  # flag: 0 until the worst-batch stats are seeded
for j in range(16):  # treat every 16 batches as one training unit and track its stats
    for k in range(batchSize):
        trainData[k] = torch.from_numpy(dReader.readImg()).float().cuda()
    optimizer.zero_grad()
    encData = encNet(trainData)
    qEncData = quantize(encData)
    decData = decNet(qEncData)
    currentMSEL = MSELoss(trainData, decData)
    currentMS_SSIM = pytorch_msssim.ms_ssim(trainData, decData,
                                            data_range=255,
                                            size_average=True)
    #currentEdgeMSEL = extendMSE.EdgeMSELoss(trainData, decData)
    # Optimize MSE while reconstruction is still poor; once MSE drops below
    # the threshold, switch to maximizing MS-SSIM (hence the negation).
    if (currentMSEL > 500):
        loss = currentMSEL
    else:
        loss = -currentMS_SSIM
    if (defMaxLossOfTrainData == 0):
        # Seed the running worst-batch statistics from the first batch.
        maxLossOfTrainData = loss
        maxLossTrainMSEL = currentMSEL
        maxLossTrainMS_SSIM = currentMS_SSIM
        #maxLossTrainEdgeMSEL = currentEdgeMSEL
        defMaxLossOfTrainData = 1
optimizer.zero_grad() trainGrayData = rgbCompress.RGB444DetachToRGGBTensor( trainData, rgbKernelList, True) encData = encNet(trainGrayData / 255) qEncData = torch.zeros_like(encData) for ii in range(encData.shape[0]): qEncData[ii] = softToHardQuantize.sthQuantize(encData[ii], C, sigma) decData = decNet(qEncData) * 255 currentMSEL = F.mse_loss(trainGrayData, decData) currentMS_SSIM = pytorch_msssim.ms_ssim(trainGrayData, decData, win_size=7, data_range=255, size_average=True) if (torch.isnan(currentMS_SSIM)): currentMS_SSIM.zero_() loss = -currentMS_SSIM.item() print('%.3f' % currentMS_SSIM.item()) if (currentMS_SSIM.item() > -0.7): currentMSEL.backward() else: currentMS_SSIM.backward() if (i == 0):
def compute_msssim(a, b):
    """Scalar MS-SSIM between two batches whose pixel values lie in [0, 1]."""
    value = ms_ssim(a, b, data_range=1.)
    return value.item()