decompose_data = torch.cat([y_, High_noise, Low_noise], dim=1)
                # Evaluation step: build the input for `decompose_model` by concatenating
                # y_, High_noise and Low_noise along dim 1. Fuse its output with the
                # output of `model` (applied to High_noise only), take the magnitude of
                # the resulting complex tensor, then time it and score it with
                # PSNR/SSIM against `x`.
                # x, y_, High_noise, Low_noise, start_time, model and decompose_model
                # are all defined outside this fragment.
                # x = torch.cat([High_origin.unsqueeze(1), Low_origin.unsqueeze(1)], dim=1)
                # NOTE(review): this runs on every loop iteration. If `model` is an
                # nn.Module, `.cpu()` moves it in place and returns it, so repeats are
                # no-ops. Hoisting it out of the loop would be clearer. The inputs fed
                # to `model` below must also be on the CPU; that is not visible here, so
                # confirm.
                model = model.cpu()
                # squeeze(0) removes dim 0 only if its size is 1 (batch of one).
                decom_output = decompose_model(decompose_data.float()).squeeze(
                    0)  # inference

                # Only channel 0 of High_noise goes through `model`; unsqueeze(1)
                # restores a singleton channel dimension after that indexing.
                dncnn_output = model(High_noise[:, 0].unsqueeze(1)).squeeze(0)
                # dncnn_high, dncnn_low = Decomposition(dncnn_output, 0.10)

                # x_ = output[0].cpu().detach().numpy().astype(np.float32)

                # Real part: dncnn_output[1] + decom_output[3]
                # Imaginary part: decom_output[2] + decom_output[4]
                # The result is the element-wise magnitude (.abs()).
                # NOTE(review): dncnn_output[1] needs at least 2 entries along dim 0
                # after squeeze(0). If `model` returns (1, 1, H, W), squeeze(0) yields
                # (1, H, W) and [1] raises IndexError. Confirm the intended index
                # (the commented alternative below uses dncnn_high[0, 0]).
                output = ComplexTensor(dncnn_output[1] + decom_output[3],
                                       decom_output[2] +
                                       decom_output[4]).abs()
                # output = ComplexTensor(dncnn_high[0,0] + decom_output[3], decom_output[2] + decom_output[4]).abs()
                # Convert to a float32 NumPy array for the metric functions.
                x_ = output.cpu().detach().numpy().astype(np.float32)

                # x_ = ComplexTensor(output[:, 0] + output[:, 2], Low_noise[:, 0] + Low_noise[:, 1]).abs().squeeze(0)
                # x_ = ComplexTensor(output[:, 0] + output[:, 2], output[:, 1] + output[:, 3]).abs().squeeze(0)
                # x_ = torch.add(output[:,0], output[:,1]).squeeze(0)

                # x_ = x_.cpu().detach().numpy().astype(np.float32)

                # x_ = x_.view(y.shape[0], y.shape[1])
                # x_ = x_.cpu()
                # x_ = x_.detach().numpy().astype(np.float32)
                # Wall-clock time since `start_time` (set outside this fragment).
                elapsed_time = time.time() - start_time

                # Quality metrics of the reconstruction x_ against the reference x.
                psnr_x_ = compare_psnr(x, x_)
                ssim_x_ = compare_ssim(x, x_)
示例#2
0
                # Evaluation step: split the noisy input y_ with Decomposition, run the
                # network on the stacked pair, rebuild the image as a complex magnitude,
                # then time it and score it with PSNR/SSIM against `x`.
                # y_, x, model, start_time, set_cur and im are defined outside this
                # fragment.
                # NOTE(review): 0.14 was noted here as an alternative to the 0.125
                # argument; confirm which value is intended.
                High_noise, Low_noise = Decomposition(y_.squeeze(0), 0.125)
                High_noise = High_noise.cuda()
                Low_noise = Low_noise.cuda()

                # The network input is High_noise and Low_noise stacked along dim 1.
                # Both are already on the GPU, so y_ is too. No separate transfer of
                # the old y_ is needed, since it is rebound here without being read.
                y_ = torch.cat([High_noise, Low_noise], dim=1)

                output = model(y_.float())  # inference

                # Real part: output[:, 0] + output[:, 2]
                # Imaginary part: output[:, 1] + output[:, 3]
                # Take the element-wise magnitude, then drop dim 0 if it has size 1.
                x_ = ComplexTensor(output[:, 0] + output[:, 2],
                                   output[:, 1] + output[:, 3]).abs().squeeze(0)
                # Convert to a float32 NumPy array for the metric functions.
                x_ = x_.cpu().detach().numpy().astype(np.float32)

                # Wait for all queued CUDA work before reading the clock, so the
                # measured time covers the GPU computation.
                torch.cuda.synchronize()
                elapsed_time = time.time() - start_time

                # Quality metrics of the reconstruction x_ against the reference x.
                psnr_x_ = compare_psnr(x, x_)
                ssim_x_ = compare_ssim(x, x_)

                print('%10s : %10s : %2.4f second %2.2f PSNR' % (set_cur, im, elapsed_time, psnr_x_))

                if args.save_result:
                    name, ext = os.path.splitext(im)
                    show(np.hstack((y, x_)))  # show the image