Example #1
    def forward(self, output, target1,target2,need_bpp=False):
        N, _, H, W = target1.size()
        out = {}
        num_pixels = N * H * W

        # compute the error terms
        # out['bpp_loss'] = sum(
        #     (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
        #     for likelihoods in output['likelihoods'].values())
        out['mse_loss'] = self.mse(output['x1_hat'], target1) + self.mse(output['x2_hat'], target2)        #end to end
        if need_bpp:
            out['bpp1'] = (torch.log(output['likelihoods']['y1']).sum() / (-math.log(2) * num_pixels)) + (
                    torch.log(output['likelihoods']['z1']).sum() / (-math.log(2) * num_pixels))
            out['bpp2'] = (torch.log(output['likelihoods']['y2']).sum() / (-math.log(2) * num_pixels)) + (
                    torch.log(output['likelihoods']['z2']).sum() / (-math.log(2) * num_pixels))

        out['loss'] = self.lmbda * 255**2 * out['mse_loss'] #+ out['bpp_loss']

        out['ms_ssim1'] = ms_ssim(output['x1_hat'], target1, data_range=1, size_average=False)[0]  # (N,)
        out['ms_ssim2'] = ms_ssim(output['x2_hat'], target2, data_range=1, size_average=False)[0]
        out['ms_ssim'] = (out['ms_ssim1']+out['ms_ssim2'])/2
        out['psnr1'] = mse2psnr(self.mse(output['x1_hat'], target1))
        out['psnr2'] = mse2psnr(self.mse(output['x2_hat'], target2))

        return out
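The `mse2psnr` helper called above is not shown on this page; a minimal sketch, assuming the reconstructions are scaled to [0, 1] so that PSNR reduces to -10 * log10(MSE):

import torch

def mse2psnr(mse: torch.Tensor) -> torch.Tensor:
    # Assumes images in [0, 1], so PSNR = 10 * log10(1 / MSE) = -10 * log10(MSE).
    return -10 * torch.log10(mse)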
Example #2
    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.model(inputs)

        loss = criterion(outputs, targets)

        if self.current_epoch is not None and batch_idx + 1 == len(val_loader):
            baseline = F.interpolate(inputs,
                                     scale_factor=scale_factor,
                                     mode='bicubic')
            grid = make_grid(torch.cat([targets, baseline, outputs], dim=-1),
                             nrow=2).clamp(0, 1)
            val_writer.add_image('Images From Last Batch', grid,
                                 self.current_epoch)

        ssim = pytorch_msssim.ms_ssim(outputs, targets, data_range=1)
        psnr = utils.psnr(outputs, targets)

        size = len(targets)
        self.val_loss.update(loss.item(), size)
        self.val_psnr.update(psnr.item(), size)
        self.val_ssim.update(ssim.item(), size)

        return {
            'val_loss': f'{self.val_loss.compute():.4g}',
            'psnr': f'{self.val_psnr.compute():.2f}',
            'ssim': f'{self.val_ssim.compute():.4f}'
        }
Example #3
    def optimize_parameters(self, step):
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_g_total = 0
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_g_total += l_pix
        if self.cri_CX:
            real_fea = self.netF(self.real_H)
            fake_fea = self.netF(self.fake_H)
            l_CX = self.l_CX_w * self.cri_CX(real_fea, fake_fea)
            l_g_total += l_CX
        if self.cri_ssim:
            if self.cri_ssim == 'ssim':
                ssim_val = ssim(self.fake_H, self.real_H, win_size=self.ssim_window, data_range=1.0, size_average=True)
            elif self.cri_ssim == 'ms-ssim':
                weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363]).to(self.fake_H.device, dtype=self.fake_H.dtype)
                ssim_val = ms_ssim(self.fake_H, self.real_H, win_size=self.ssim_window, data_range=1.0, size_average=True, weights=weights)
            l_ssim = self.l_ssim_w * (1 - ssim_val)
            l_g_total += l_ssim

        l_g_total.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()
        if self.cri_CX:
            self.log_dict['l_CX'] = l_CX.item()
        if self.cri_ssim:
            self.log_dict['l_ssim'] = l_ssim.item()
Example #4
    def forward(self, reconst, original):
        MSSSIM = pytorch_msssim.ms_ssim(original,
                                        reconst,
                                        nonnegative_ssim=True).to(device)
        MSE = (torch.nn.functional.mse_loss(original, reconst)).to(device)
        loss = MSE - MSSSIM + 1
        return loss, MSSSIM
Example #5
    def train_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.model(inputs)

        loss = criterion(outputs, targets)

        with torch.no_grad():
            if batch_idx + 1 == len(train_loader):
                indices = torch.randperm(inputs.shape[0])[:2]
                baseline = F.interpolate(inputs[indices],
                                         scale_factor=scale_factor,
                                         mode='bicubic')
                example = targets[indices], baseline, outputs[indices]
                grid = make_grid(torch.cat(example, dim=-1)).clamp(0, 1)
                train_writer.add_image('Random Images From Last Batch', grid,
                                       self.current_epoch)
            ssim = pytorch_msssim.ms_ssim(outputs, targets, data_range=1)
            psnr = utils.psnr(outputs, targets)

            size = len(targets)
            self.loss.update(loss.item(), size)
            self.psnr.update(psnr.item(), size)
            self.ssim.update(ssim.item(), size)

        return loss, {
            'loss': f'{self.loss.compute():.4g}',
            'psnr': f'{self.psnr.compute():.2f}',
            'ssim': f'{self.ssim.compute():.4f}'
        }
Example #6
    def inference(self, label, geo, image=None):

        # Encode Inputs
        image = Variable(image) if image is not None else None
        geometry = Variable(geo) if geo is not None else None
        concat_input, real_image = self.encode_input(label,
                                                     image,
                                                     geometry=geometry,
                                                     infer=True)

        print(concat_input.data.device)
        # Fake Generation

        if torch.__version__.startswith('0.4'):
            with torch.no_grad():
                fake_image = self.netG.forward(concat_input)
        else:
            fake_image = self.netG.forward(concat_input)

        metrics = {}
        if image is not None:  # metrics
            GT_image = Variable((real_image + 1.) / 2.)
            fake_normalized = (fake_image + 1.) / 2.
            metrics['ssim'] = ssim(fake_normalized,
                                   GT_image,
                                   data_range=1,
                                   size_average=False,
                                   nonnegative_ssim=True)
            metrics['ms_ssim'] = ms_ssim(fake_normalized,
                                         GT_image,
                                         data_range=1,
                                         size_average=False)

        return fake_image, metrics
Example #7
def compute_metrics_for_frame(
    org_frame: Frame,
    dec_frame: Frame,
    bitdepth: int = 8,
) -> Dict[str, Any]:
    org_frame = tuple(p.unsqueeze(0).unsqueeze(0) for p in org_frame)  # type: ignore
    dec_frame = tuple(p.unsqueeze(0).unsqueeze(0) for p in dec_frame)  # type:ignore
    out: Dict[str, Any] = {}

    max_val = 2**bitdepth - 1

    # YCbCr metrics
    for i, component in enumerate("yuv"):
        out[f"mse-{component}"] = (org_frame[i] - dec_frame[i]).pow(2).mean()

    org_rgb = ycbcr2rgb(yuv_420_to_444(org_frame, mode="bicubic").true_divide(max_val))  # type: ignore
    dec_rgb = ycbcr2rgb(yuv_420_to_444(dec_frame, mode="bicubic").true_divide(max_val))  # type: ignore

    org_rgb = (org_rgb * max_val).clamp(0, max_val).round()
    dec_rgb = (dec_rgb * max_val).clamp(0, max_val).round()
    mse_rgb = (org_rgb - dec_rgb).pow(2).mean()

    ms_ssim_rgb = ms_ssim(org_rgb, dec_rgb, data_range=max_val)
    out.update({"ms-ssim-rgb": ms_ssim_rgb, "mse-rgb": mse_rgb})
    return out
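`ycbcr2rgb` and `yuv_420_to_444` are not defined on this page; they look like CompressAI's color/chroma utilities, so a hedged import sketch (assuming CompressAI is indeed the source) would be:

from compressai.transforms.functional import ycbcr2rgb, yuv_420_to_444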
Example #8
def compute_metrics_for_frame(
    org_frame: Frame,
    rec_frame: Tensor,
    device: str = "cpu",
    max_val: int = 255,
) -> Dict[str, Any]:
    out: Dict[str, Any] = {}

    # YCbCr metrics
    org_yuv = to_tensors(org_frame, device=str(device), max_value=max_val)
    org_yuv = tuple(p.unsqueeze(0).unsqueeze(0) for p in org_yuv)  # type: ignore
    rec_yuv = convert_rgb_to_yuv420(rec_frame)
    for i, component in enumerate("yuv"):
        org = (org_yuv[i] * max_val).clamp(0, max_val).round()
        rec = (rec_yuv[i] * max_val).clamp(0, max_val).round()
        out[f"psnr-{component}"] = 20 * np.log10(max_val) - 10 * torch.log10(
            (org - rec).pow(2).mean()
        )
    out["psnr-yuv"] = (4 * out["psnr-y"] + out["psnr-u"] + out["psnr-v"]) / 6

    # RGB metrics
    org_rgb = convert_yuv420_to_rgb(
        org_frame, device, max_val
    )  # ycbcr2rgb(yuv_420_to_444(org_frame, mode="bicubic"))  # type: ignore
    org_rgb = (org_rgb * max_val).clamp(0, max_val).round()
    rec_frame = (rec_frame * max_val).clamp(0, max_val).round()
    mse_rgb = (org_rgb - rec_frame).pow(2).mean()
    psnr_rgb = 20 * np.log10(max_val) - 10 * torch.log10(mse_rgb)

    ms_ssim_rgb = ms_ssim(org_rgb, rec_frame, data_range=max_val)
    out.update({"ms-ssim-rgb": ms_ssim_rgb, "mse-rgb": mse_rgb, "psnr-rgb": psnr_rgb})

    return out
Example #9
def calculate_mssim(minibatch, reconstr_image, size_average=True):
    """ compute the ms-ssim between an image and its reconstruction

    :param minibatch: the input minibatch
    :param reconstr_image: the reconstructed image
    :returns: 1 - the ms-ssim score (i.e. a dissimilarity value)
    :rtype: float

    """
    if minibatch.dim() == 5 and reconstr_image.dim() == 5:  # special case where we have a temporal dim
        msssim = torch.cat([
            calculate_mssim(minibatch[:, i, ::],
                            reconstr_image[:, i, ::],
                            size_average=size_average).unsqueeze(-1)
            for i in range(minibatch.shape[1])
        ], -1)
        return torch.mean(msssim, -1)

    smallest_dim = min(minibatch.shape[-1], minibatch.shape[-2])
    if minibatch.dtype != reconstr_image.dtype:
        minibatch = minibatch.type(reconstr_image.dtype)

    if smallest_dim < 160:  # Limitation of ms-ssim library due to 4x downsample
        return 1 - ssim(X=minibatch,
                        Y=reconstr_image,
                        data_range=1,
                        size_average=size_average,
                        nonnegative_ssim=True)

    return 1 - ms_ssim(
        X=minibatch, Y=reconstr_image, data_range=1, size_average=size_average)
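A minimal usage sketch for `calculate_mssim` on a hypothetical 5-D video batch (N, T, C, H, W); the shapes are illustrative, with the spatial size kept above 160 so the MS-SSIM branch is taken:

import torch

video = torch.rand(2, 3, 3, 192, 192)   # (N, T, C, H, W), values in [0, 1]
recon = torch.rand(2, 3, 3, 192, 192)
loss = calculate_mssim(video, recon, size_average=True)  # 1 - MS-SSIM, averaged over the T frames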
Example #10
def calc_baseline(loader,
                  scale_factor=2,
                  use_cuda=torch.cuda.is_available(),
                  method='bicubic',
                  criterion=nn.L1Loss()):
    loss = WeightedAveragedMetric()
    psnr_ = WeightedAveragedMetric()
    ssim = WeightedAveragedMetric()
    with tqdm(loader, desc='Baseline') as bar:
        for inputs, targets in bar:
            if use_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()
            resized = F.interpolate(inputs,
                                    scale_factor=scale_factor,
                                    mode=method)

            size = len(targets)
            loss.update(criterion(resized, targets).item(), size)
            psnr_.update(psnr(resized, targets).item(), size)
            ssim.update(
                pytorch_msssim.ms_ssim(resized, targets, data_range=1).item(),
                size)

            bar.set_postfix({
                'loss': f'{loss.compute():.4g}',
                'psnr': f'{psnr_.compute():.2f}',
                'ssim': f'{ssim.compute():.4f}'
            })
Example #11
def norm(img1, img2):
    criterion = nn.MSELoss()
    loss_r = criterion(img1,img2)
    print("loss_r:" + str(loss_r.item()))
    SSIM_h=1-pytorch_msssim.ssim(img1,img2)
    print("SSIM_h:"+str(SSIM_h.item()))
    MS_SSIM_h=1-pytorch_msssim.ms_ssim(img1,img2)
    print("MS_SSIM_h:"+str(MS_SSIM_h.item()))
Example #12
def recons_loss(recon_x, x):
    # MS-SSIM (multi-scale structural similarity) loss: an SSIM loss computed over
    # several scales (the image is repeatedly downscaled), so it also accounts for
    # resolution.
    msssim = (1 - pytorch_msssim.ms_ssim(x, recon_x)) / 2  # an optimized variant of SSIM
    # Drawing on neuroscience, the authors argue that when humans judge how close two
    # images are, they weigh structural similarity more heavily than per-pixel
    # differences, so a structural-similarity-based metric reflects the human visual
    # system better than MSE does.
    f1 = F.l1_loss(recon_x, x)  # L1: per-pixel absolute difference; L2: per-pixel squared difference
    # An L2 loss widens the gap between the largest and smallest errors
    # (e.g. 2*2 vs 0.1*0.1) and is also sensitive to outliers.
    # The paper finds MS-SSIM + L1 to be the best combination: MS-SSIM tends to shift
    # brightness and color but preserves high-frequency content (edges and detail),
    # while L1 keeps brightness and color stable. The blend is
    # Lmix = alpha * Lmsssim + (1 - alpha) * G * L1, with alpha = 0.84 (found
    # empirically) and G the Gaussian window also used inside MS-SSIM.
    return msssim + f1
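The comments above describe the mixed objective Lmix = alpha * L_MS-SSIM + (1 - alpha) * G * L1 with alpha = 0.84; a minimal sketch of that blend (the Gaussian weighting G of the L1 term is omitted here, since it needs the window used inside MS-SSIM):

import torch.nn.functional as F
import pytorch_msssim

def mix_loss(recon_x, x, alpha=0.84):
    # Mixed MS-SSIM + L1 loss for images in [0, 1]; the per-pixel Gaussian
    # weighting of the L1 term from the paper is left out for simplicity.
    l_msssim = 1 - pytorch_msssim.ms_ssim(x, recon_x, data_range=1.0)
    l1 = F.l1_loss(recon_x, x)
    return alpha * l_msssim + (1 - alpha) * l1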
Example #13
def evaluate_ms_ssim(conf):
    # create DataLoader
    loader = create_databunch(conf["data_path"], conf["fourier"],
                              conf["source_list"], conf["batch_size"])
    model_path = conf["model_path"]
    out_path = Path(model_path).parent / "evaluation"
    out_path.mkdir(parents=True, exist_ok=True)

    img_size = loader.dataset[0][0][0].shape[-1]
    model = load_pretrained_model(conf["arch_name"], conf["model_path"],
                                  img_size)
    if conf["model_path_2"] != "none":
        model_2 = load_pretrained_model(conf["arch_name_2"],
                                        conf["model_path_2"], img_size)

    vals = []

    if img_size < 160:
        click.echo("\nThis is only a placeholder!\
                Images too small for meaningful ms ssim calculations.\n")

    # iterate trough DataLoader
    for i, (img_test, img_true) in enumerate(tqdm(loader)):

        pred = eval_model(img_test, model)
        if conf["model_path_2"] != "none":
            pred_2 = eval_model(img_test, model_2)
            pred = torch.cat((pred, pred_2), dim=1)

        ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"])
        ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"])

        if img_size < 160:
            ifft_truth = pad_unsqueeze(torch.tensor(ifft_truth))
            ifft_pred = pad_unsqueeze(torch.tensor(ifft_pred))

        vals.extend([
            ms_ssim(pred.unsqueeze(0),
                    truth.unsqueeze(0),
                    data_range=truth.max())
            for pred, truth in zip(ifft_pred, ifft_truth)
        ])

    click.echo("\nCreating ms-ssim histogram.\n")
    vals = torch.tensor(vals)
    histogram_ms_ssim(
        vals,
        out_path,
        plot_format=conf["format"],
    )

    click.echo(f"\nThe mean ms-ssim value is {vals.mean()}.\n")
Example #14
def run_tests(
  num_samples=32,
):
  l1_losses = []
  l2_losses = []
  psnr_losses = []
  gots = []
  num=100
  with torch.no_grad():
    for i, (c2w, lp) in enumerate(zip(tqdm(cam_to_worlds[:num]), light_locs)):
      exp = exp_imgs[i].clamp(min=0, max=1)
      cameras = NeRFCamera(cam_to_world=c2w.unsqueeze(0), focal=focal, device=device)
      lights = PointLights(intensity=[1,1,1], location=lp[None,...], scale=300, device=device)
      got = None
      for _ in range(num_samples):
        sample = pt.pathtrace(
          density_field,
          size=SIZE, chunk_size=min(SIZE, 100), bundle_size=1, bsdf=learned_bsdf,
          integrator=integrator,
          # 0 is for comparison, 1 is for display
          background=0,
          cameras=cameras, lights=lights, device=device, silent=True,
          w_isect=True,
        )[0]
        if got is None: got = sample
        else: got += sample
      got /= num_samples
      got = got.clamp(min=0, max=1)
      save_plot(
        exp ** (1/2.2), got ** (1/2.2),
        f"outputs/path_nerv_armadillo_{i:03}.png",
      )
      l1_losses.append(F.l1_loss(exp,got).item())
      mse = F.mse_loss(exp,got)
      l2_losses.append(mse.item())
      psnr = mse2psnr(mse).item()
      psnr_losses.append(psnr)
      gots.append(got)
  print("Avg l1 loss", np.mean(l1_losses))
  print("Avg l2 loss", np.mean(l2_losses))
  print("Avg PSNR loss", np.mean(psnr_losses))
  with torch.no_grad():
    gots = torch.stack(gots, dim=0).permute(0, 3, 1, 2)
    exps = torch.stack(exp_imgs[:num], dim=0).permute(0, 3, 1, 2)
    # takes a lot of memory
    torch.cuda.empty_cache()
    ssim_loss = ms_ssim(gots, exps, data_range=1, size_average=True).item()
    print("MS-SSIM loss", ssim_loss)

    ssim_loss = ssim(gots, exps, data_range=1, size_average=True).item()
    print("SSIM loss", ssim_loss)
  return
Example #15
def compute_metrics(a: Union[np.ndarray, Image.Image],
                    b: Union[np.ndarray, Image.Image],
                    max_val: float = 255.) -> Tuple[float, float]:
    """Returns PSNR and MS-SSIM between images `a` and `b`. """
    if isinstance(a, Image.Image):
        a = np.asarray(a)
    if isinstance(b, Image.Image):
        b = np.asarray(b)

    a = torch.from_numpy(a.copy()).float().unsqueeze(0)
    if a.size(3) == 3:
        a = a.permute(0, 3, 1, 2)
    b = torch.from_numpy(b.copy()).float().unsqueeze(0)
    if b.size(3) == 3:
        b = b.permute(0, 3, 1, 2)

    mse = torch.mean((a - b)**2).item()
    p = 20 * np.log10(max_val) - 10 * np.log10(mse)
    m = ms_ssim(a, b, data_range=max_val).item()
    return p, m
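A short usage sketch for `compute_metrics`, assuming two equally sized RGB images on disk (the file names are placeholders); note that MS-SSIM needs the short side to be larger than 160 px with the default window:

from PIL import Image

a = Image.open("original.png").convert("RGB")
b = Image.open("reconstruction.png").convert("RGB")
psnr_db, msssim_val = compute_metrics(a, b, max_val=255.)
print(f"PSNR: {psnr_db:.2f} dB, MS-SSIM: {msssim_val:.4f}")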
Example #16
def calc_metrics(img_gt, img_out, hfen):
    ''' compute vif, ms-ssim, ssim, psnr, and hfen of img_out using img_gt as the ground-truth reference
        note: for ms-ssim, made a slight mod to the source code in line 200 of
              /home/vanveen/heck/lib/python3.8/site-packages/pytorch_msssim/ssim.py
              to compute ms-ssim over images w/ smallest dim >= 160 '''

    img_gt, img_out = norm_imgs(img_gt, img_out)
    img_gt, img_out = np.array(img_gt), np.array(img_out)

    vif_ = vifp_mscale(img_gt, img_out, sigma_nsq=img_out.mean())
    ssim_ = ssim(img_gt, img_out)
    psnr_ = psnr(img_gt, img_out)

    img_out_ = torch.from_numpy(np.array([[img_out]]))
    img_gt_ = torch.from_numpy(np.array([[img_gt]]))
    msssim_ = float(ms_ssim(img_out_, img_gt_, data_range=img_gt_.max()))

    hfen_ = 10000 * float(hfen(img_gt_.float(), img_out_.float()))

    return vif_, msssim_, ssim_, psnr_, hfen_
Example #17
def inference(model, x):
    x = x.unsqueeze(0)

    h, w = x.size(2), x.size(3)
    p = 64  # maximum 6 strides of 2
    new_h = (h + p - 1) // p * p
    new_w = (w + p - 1) // p * p
    padding_left = (new_w - w) // 2
    padding_right = new_w - w - padding_left
    padding_top = (new_h - h) // 2
    padding_bottom = new_h - h - padding_top
    x_padded = F.pad(
        x,
        (padding_left, padding_right, padding_top, padding_bottom),
        mode="constant",
        value=0,
    )

    with torch.no_grad():
        start = time.time()
        out_enc = model.compress(x_padded)
        enc_time = time.time() - start

        start = time.time()
        out_dec = model.decompress(out_enc["strings"], out_enc["shape"])
        dec_time = time.time() - start

    out_dec["x_hat"] = F.pad(
        out_dec["x_hat"],
        (-padding_left, -padding_right, -padding_top, -padding_bottom))

    num_pixels = x.size(0) * x.size(2) * x.size(3)
    bpp = sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels

    return {
        "psnr": psnr(x, out_dec["x_hat"]),
        "ms-ssim": ms_ssim(x, out_dec["x_hat"], data_range=1.0).item(),
        "bpp": bpp,
        "encoding_time": enc_time,
        "decoding_time": dec_time,
    }
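The `psnr` helper used in the return dict is not defined on this page; a minimal sketch, assuming inputs scaled to [0, 1]:

import math
import torch.nn.functional as F

def psnr(a, b):
    # PSNR in dB for tensors in [0, 1].
    mse = F.mse_loss(a, b).item()
    return -10 * math.log10(mse)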
Example #18
def inference(model, x):
    x = x.unsqueeze(0)

    h, w = x.size(2), x.size(3)
    p = 64  # maximum 6 strides of 2
    new_h = (h + p - 1) // p * p
    new_w = (w + p - 1) // p * p
    padding_left = (new_w - w) // 2
    padding_right = new_w - w - padding_left
    padding_top = (new_h - h) // 2
    padding_bottom = new_h - h - padding_top
    x_padded = F.pad(
        x, (padding_left, padding_right, padding_top, padding_bottom),
        mode='constant',
        value=0)

    with torch.no_grad():
        start = time.time()
        out_enc = model.compress(x_padded)
        enc_time = time.time() - start

        start = time.time()
        out_dec = model.decompress(out_enc['strings'], out_enc['shape'])
        dec_time = time.time() - start

    out_dec['x_hat'] = F.pad(
        out_dec['x_hat'],
        (-padding_left, -padding_right, -padding_top, -padding_bottom))

    num_pixels = x.size(0) * x.size(2) * x.size(3)
    bpp = sum(len(s[0]) for s in out_enc['strings']) * 8. / num_pixels

    return {
        'psnr': psnr(x, out_dec['x_hat']),
        'msssim': ms_ssim(x, out_dec['x_hat'], data_range=1.).item(),
        'bpp': bpp,
        'encoding_time': enc_time,
        'decoding_time': dec_time,
    }
Example #19
    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a CFPN model. input dicts must contain an 'image' key

            outputs: the outputs of a CFPN model. It is a list of dicts with key
                "instances" that contains :class:`Instances`.
                The :class:`Instances` object needs to have `densepose` field.
        """
        assert (len(inputs) == 1)
        # reshape the input image to have a maximum length of 512 as the model preprocesses
        # Much shuffling of data, but also the dataset is only 24 images
        orig_image = inputs[0]['image'].permute(1, 2, 0)
        orig_image = self.transform.get_transform(orig_image).apply_image(
            orig_image.numpy())
        orig_image = torch.tensor(orig_image).permute(2, 0, 1)
        orig_image = torch.unsqueeze(orig_image, dim=0).float().to(
            outputs[self.eval_img].get_device())

        reconstruct_image = outputs[self.eval_img].float()
        reconstruct_image = reconstruct_image[:, :, 0:orig_image.shape[2],
                                              0:orig_image.shape[3]]
        assert orig_image.shape[1] == 3, "original image must have 3 channels"
        assert reconstruct_image.shape[
            1] == 3, "reconstructed image must have 3 channels"
        with torch.no_grad():
            ssim_val = ssim(reconstruct_image,
                            orig_image,
                            data_range=255,
                            size_average=False)
            ms_ssim_val = ms_ssim(reconstruct_image,
                                  orig_image,
                                  data_range=255,
                                  size_average=False)
            self.ssim_vals.extend(ssim_val)
            self.ms_ssim_vals.extend(ms_ssim_val)
Example #20
def validation(NORMAL_NUM, iteration, generator, discriminator, real_data,
               fake_data, is_end, AUC_LIST, END_ITER):
    generator.eval()
    discriminator.eval()
    # resnet.eval()
    y = []
    score = []
    normal_gsvdd = []
    abnormal_gsvdd = []
    normal_recon = []
    abnormal_recon = []
    test_root = '/home/user/Documents/Public_Dataset/MVTec_AD/{}/test/'.format(
        NORMAL_NUM)
    list_test = os.listdir(test_root)

    with torch.no_grad():
        for i in range(len(list_test)):
            current_defect = list_test[i]
            # print(current_defect)
            test_path = test_root + "{}".format(current_defect)
            valid_dataset_loader = load_test(test_path, sample_rate=1.)

            for index, (images, label) in enumerate(valid_dataset_loader):
                # img = five_crop_ready(images)
                img_tmp = images.to(device)
                img_tmp = img_tmp.unfold(2, 32,
                                         p_stride).unfold(3, 32, p_stride)
                # img     = extract_patch(img_tmp)
                img_tmp = img_tmp.permute(0, 2, 3, 1, 4, 5)
                img = img_tmp.reshape(-1, 3, 32, 32)

                latent_z = generator.encoder(img)
                generate_result = generator(img)

                weight = 0.85

                ms_ssim_batch_wise = 1 - ms_ssim(img,
                                                 generate_result,
                                                 data_range=data_range,
                                                 size_average=False,
                                                 win_size=3,
                                                 weights=ssim_weights)
                l1_batch_wise = (img - generate_result) / data_range
                l1_batch_wise = l1_batch_wise.mean(1).mean(1).mean(1)
                ms_ssim_l1 = weight * ms_ssim_batch_wise + (
                    1 - weight) * l1_batch_wise

                diff = (latent_z - generator.c)**2
                dist = -1 * torch.sum(diff, dim=1) / generator.sigma
                guass_svdd_loss = 1 - torch.exp(dist)

                anormaly_score = (
                    (0.5 * ms_ssim_l1 +
                     0.5 * guass_svdd_loss).max()).cpu().detach().numpy()
                score.append(float(anormaly_score))

                if label[0] == "good":
                    # if is_end:
                    # normal_gsvdd.append(float(guass_svdd_loss.max().cpu().detach().numpy()))
                    # normal_recon.append(float(ms_ssim_l1.max().cpu().detach().numpy()))
                    y.append(0)
                else:
                    # if is_end:
                    # abnormal_gsvdd.append(float(guass_svdd_loss.max().cpu().detach().numpy()))
                    # abnormal_recon.append(float(ms_ssim_l1.max().cpu().detach().numpy()))
                    y.append(1)

            ###################################################
    # if is_end:
    #     Helper.plot_2d_chart(x1=numpy.arange(0, len(normal_gsvdd)), y1=normal_gsvdd, label1='normal_loss',
    #                          x2=numpy.arange(len(normal_gsvdd), len(normal_gsvdd) + len(abnormal_gsvdd)),
    #                          y2=abnormal_gsvdd, label2='abnormal_loss',
    #                          title="{}: {}".format(NORMAL_NUM, "gsvdd loss"),
    #                          save_path="./plot/{}_gsvdd".format(NORMAL_NUM))
    # # if True:
    #     Helper.plot_2d_chart(x1=numpy.arange(0, len(normal_recon)), y1=normal_recon, label1='normal_loss',
    #                          x2=numpy.arange(len(normal_recon), len(normal_recon) + len(abnormal_recon)),
    #                          y2=abnormal_recon, label2='abnormals_loss',
    #                          title="{}: {}".format(NORMAL_NUM, "recon loss"),
    #                          save_path="./plot/{}_gsvdd_{}".format(NORMAL_NUM, iteration))
    fpr, tpr, thresholds = metrics.roc_curve(y, score, pos_label=1)
    auc_result = auc(fpr, tpr)
    AUC_LIST.append(auc_result)
    # tqdm.write(str(auc_result), end='.....................')
    auc_file = open(ckpt_path + "/auc.txt", "a")
    auc_file.write('Iter {}:            {}\r\n'.format(str(iteration),
                                                       str(auc_result)))
    auc_file.close()
    if iteration == END_ITER - 1:
        auc_file = open(ckpt_path + "/auc.txt", "a")
        auc_file.write('BEST AUC -> {}\r\n'.format(max(AUC_LIST)))
        auc_file.close()
    return auc_result, AUC_LIST
Example #21
def train(args, NORMAL_NUM, generator, discriminator, optimizer_g,
          optimizer_d):
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    generator.train()
    discriminator.train()

    AUC_LIST = []
    global test_auc
    test_auc = 0
    BEST_AUC = 0
    train_path = '/home/user/Documents/Public_Dataset/MVTec_AD/{}/train/good'.format(
        NORMAL_NUM)
    START_ITER = 0
    train_size = len(os.listdir(train_path))
    END_ITER = int((train_size / BATCH_SIZE) * MAX_EPOCH)
    # END_ITER = 40500

    train_dataset_loader, train_size = load_train(train_path, args.sample_rate)

    generator.c = None
    generator.sigma = None
    generator.c = init_c(train_dataset_loader, generator)
    generator.sigma = init_sigma(train_dataset_loader, generator)
    print("gsvdd_sigma: {}".format(generator.sigma))

    train_data = iter(train_dataset_loader)
    process = tqdm(range(START_ITER, END_ITER), desc='{AUC: }')

    for iteration in process:
        poly_lr_scheduler(optimizer_d,
                          init_lr=LR,
                          iter=iteration,
                          max_iter=END_ITER)
        poly_lr_scheduler(optimizer_g,
                          init_lr=LR,
                          iter=iteration,
                          max_iter=END_ITER)
        # --------------------- Loader ------------------------
        batch = next(train_data, None)
        if batch is None:
            # train_dataset_loader = load_train(train_path)
            train_data = iter(train_dataset_loader)
            batch = next(train_data)
        batch = batch[0]  # batch[1] contains labels
        batch_data = batch.to(device)
        data_tmp = batch_data.unfold(2, 32, p_stride).unfold(3, 32, p_stride)
        data_tmp = data_tmp.permute(0, 2, 3, 1, 4, 5)
        real_data = data_tmp.reshape(-1, 3, 32, 32)

        # --------------------- TRAIN E ------------------------
        optimizer_g.zero_grad()
        latent_z = generator.encoder(real_data)
        fake_data = generator(real_data)

        b, _ = latent_z.shape
        # Reconstruction loss
        weight = 0.85
        ms_ssim_batch_wise = 1 - ms_ssim(real_data,
                                         fake_data,
                                         data_range=data_range,
                                         size_average=True,
                                         win_size=3,
                                         weights=ssim_weights)
        l1_batch_wise = l1_criterion(real_data, fake_data) / data_range
        ms_ssim_l1 = weight * ms_ssim_batch_wise + (1 - weight) * l1_batch_wise

        ############ Interplote ############
        e1 = torch.flip(latent_z, dims=[0])
        alpha = torch.FloatTensor(b, 1).uniform_(0, 0.5).to(device)
        e2 = alpha * latent_z + (1 - alpha) * e1
        g2 = generator.generate(e2)
        reg_inter = torch.mean(discriminator(g2)**2)

        ############ DSVDD ############
        diff = (latent_z - generator.c)**2
        dist = -1 * (torch.sum(diff, dim=1) / generator.sigma)
        svdd_loss = torch.mean(1 - torch.exp(dist))

        encoder_loss = ms_ssim_l1 + svdd_loss + 0.1 * reg_inter

        encoder_loss.backward()
        optimizer_g.step()

        ############ Discriminator ############
        optimizer_d.zero_grad()
        g2 = generator.generate(e2).detach()
        fake_data = generator(real_data).detach()
        d_loss_front = torch.mean((discriminator(g2) - alpha)**2)
        gamma = 0.2
        tmp = fake_data + gamma * (real_data - fake_data)
        d_loss_back = torch.mean(discriminator(tmp)**2)
        d_loss = d_loss_front + d_loss_back
        d_loss.backward()
        optimizer_d.step()

        if iteration % int(
            (train_size / BATCH_SIZE) * 10) == 0 and iteration != 0:
            generator.sigma = init_sigma(train_dataset_loader, generator)
            generator.c = init_c(train_dataset_loader, generator)

        if recorder is not None:
            recorder.record(loss=svdd_loss,
                            epoch=int(iteration / BATCH_SIZE),
                            num_batches=len(train_data),
                            n_batch=iteration,
                            loss_name='DSVDD')

            recorder.record(loss=torch.mean(dist),
                            epoch=int(iteration / BATCH_SIZE),
                            num_batches=len(train_data),
                            n_batch=iteration,
                            loss_name='DIST')

            recorder.record(loss=ms_ssim_batch_wise,
                            epoch=int(iteration / BATCH_SIZE),
                            num_batches=len(train_data),
                            n_batch=iteration,
                            loss_name='MS-SSIM')

            recorder.record(loss=l1_batch_wise,
                            epoch=int(iteration / BATCH_SIZE),
                            num_batches=len(train_data),
                            n_batch=iteration,
                            loss_name='L1')

        if iteration % int(
            (train_size / BATCH_SIZE) * 10) == 0 or iteration == END_ITER - 1:
            is_end = True if iteration == END_ITER - 1 else False
            test_auc, AUC_LIST = validation(NORMAL_NUM, iteration, generator,
                                            discriminator, real_data,
                                            fake_data, is_end, AUC_LIST,
                                            END_ITER)
            process.set_description("{AUC: %.5f}" % test_auc)
            torch.save(optimizer_g.state_dict(),
                       ckpt_path + '/optimizer/g_opt.pth')
            torch.save(optimizer_d.state_dict(),
                       ckpt_path + '/optimizer/d_opt.pth')
Example #22
for i in range(999999999):

    for k in range(batchSize):
        trainData[k] = torch.from_numpy(dReader.readImg()).float().cuda()

    if (i == 0):
        trainDataRgbFilter = rgbCompress.createRGGB121Filter(trainData).cuda()
        testDataRgbFilter = rgbCompress.createRGGB121Filter(testData).cuda()


    optimizer.zero_grad()
    decData = recNet(trainData * trainDataRgbFilter)

    currentMSEL = F.mse_loss(decData, trainData)

    currentMS_SSIM = pytorch_msssim.ms_ssim(trainData, decData, data_range=1, size_average=True)

    if (torch.isnan(currentMS_SSIM)):
        currentMS_SSIM.zero_()
    loss = - currentMS_SSIM.item()

    print('%.3f' % currentMS_SSIM.item(), '%.3f' % currentMSEL.item())

    if (currentMS_SSIM.item() > -0.7):
        currentMSEL.backward()
    else:
        currentMS_SSIM.backward()

    if (i == 0):
        minLoss = loss
    else:
Example #23
encNet = torch.load('./models/encNet_16.pkl', map_location='cuda:0').cuda()
decNet = torch.load('./models/decNet_16.pkl', map_location='cuda:0').cuda()

MSELoss = nn.MSELoss()

img = Image.open('./test.bmp').convert('L')
inputData = torch.from_numpy(
    numpy.asarray(img).astype(float).reshape([1, 1, 256, 256])).float().cuda()

encData = encNet(inputData)
qEncData = quantize(encData, 4)
decData = decNet(qEncData / 3)

MSEL = MSELoss(inputData, decData)
img1 = inputData.clone()
img2 = decData.clone()
img1.detach_()
img2.detach_()
img2[img2 < 0] = 0
img2[img2 > 255] = 255
MS_SSIM = pytorch_msssim.ms_ssim(img1, img2, data_range=255, size_average=True)
print('MSEL=', '%.3f' % MSEL.item(), 'MS_SSIM=', '%.3f' % MS_SSIM)
img2 = img2.cpu().numpy().astype(int).reshape([256, 256])
img2 = Image.fromarray(img2.astype('uint8')).convert('L')
img2.save('./output/output.bmp')
img1 = img1.cpu().numpy().astype(int).reshape([256, 256])
img1 = Image.fromarray(img1.astype('uint8')).convert('L')
img1.save('./output/input.bmp')

numpy.save('./output/output.npy', qEncData.detach().cpu().numpy().astype(int))
Example #24
def pytorch_ms_ssim(preds, target, data_range, kernel_size):
    return ms_ssim(preds, target, data_range=data_range, win_size=kernel_size)
Example #25
def msssim(pred_batch, gt_batch):
    ms_ssim_loss = 1 - ms_ssim(pred_batch, gt_batch, data_range=1, win_size=7)
    return ms_ssim_loss
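`win_size=7` lowers the minimum image size MS-SSIM accepts: pytorch_msssim downsamples four times, so the short spatial side must exceed (win_size - 1) * 2**4, i.e. 160 for the default win_size=11 but only 96 here. A hypothetical guard sketch:

def ms_ssim_size_ok(img, win_size=7):
    # True if the image is large enough for MS-SSIM's 4 downsampling steps.
    return min(img.shape[-2], img.shape[-1]) > (win_size - 1) * 2 ** 4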
Example #26
def compute_msssim(image1, image2):
    return ms_ssim(image1, image2, data_range=1, size_average=True)
Example #27
def rec_loss(attr_images, generated_images, a):
    ms_ssim_loss = 1 - ms_ssim(
        attr_images, generated_images, data_range=1, size_average=True)
    l1_loss_value = l1_criterion(attr_images, generated_images)
    return a * ms_ssim_loss + (1 - a) * l1_loss_value
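A usage sketch for `rec_loss` in a training step; `l1_criterion` is assumed to be `nn.L1Loss()`, the tensor shapes are illustrative, and the blend factor follows the alpha = 0.84 convention mentioned in example #12:

import torch
import torch.nn as nn

l1_criterion = nn.L1Loss()

attr_images = torch.rand(4, 3, 192, 192)                           # targets in [0, 1]
generated_images = torch.rand(4, 3, 192, 192, requires_grad=True)  # generator output
loss = rec_loss(attr_images, generated_images, a=0.84)
loss.backward()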
Example #28
    defMaxLossOfTrainData = 0

    for j in range(16):  # treat every 16 batches as one training unit and track how those 16 batches perform
        for k in range(batchSize):
            trainData[k] = torch.from_numpy(dReader.readImg()).float().cuda()

        optimizer.zero_grad()
        encData = encNet(trainData)
        qEncData = quantize(encData)
        decData = decNet(qEncData)

        currentMSEL = MSELoss(trainData, decData)

        currentMS_SSIM = pytorch_msssim.ms_ssim(trainData,
                                                decData,
                                                data_range=255,
                                                size_average=True)

        #currentEdgeMSEL = extendMSE.EdgeMSELoss(trainData, decData)

        if (currentMSEL > 500):
            loss = currentMSEL
        else:
            loss = -currentMS_SSIM

        if (defMaxLossOfTrainData == 0):
            maxLossOfTrainData = loss
            maxLossTrainMSEL = currentMSEL
            maxLossTrainMS_SSIM = currentMS_SSIM
            #maxLossTrainEdgeMSEL = currentEdgeMSEL
            defMaxLossOfTrainData = 1
Example #29
    optimizer.zero_grad()
    trainGrayData = rgbCompress.RGB444DetachToRGGBTensor(
        trainData, rgbKernelList, True)
    encData = encNet(trainGrayData / 255)

    qEncData = torch.zeros_like(encData)
    for ii in range(encData.shape[0]):
        qEncData[ii] = softToHardQuantize.sthQuantize(encData[ii], C, sigma)

    decData = decNet(qEncData) * 255
    currentMSEL = F.mse_loss(trainGrayData, decData)

    currentMS_SSIM = pytorch_msssim.ms_ssim(trainGrayData,
                                            decData,
                                            win_size=7,
                                            data_range=255,
                                            size_average=True)

    if (torch.isnan(currentMS_SSIM)):
        currentMS_SSIM.zero_()
    loss = -currentMS_SSIM.item()

    print('%.3f' % currentMS_SSIM.item())

    if (currentMS_SSIM.item() > -0.7):
        currentMSEL.backward()
    else:
        currentMS_SSIM.backward()

    if (i == 0):
Example #30
def compute_msssim(a, b):
    return ms_ssim(a, b, data_range=1.).item()