def compute_metrics(net, valLoader, upscaling, multiscale, device="cuda:0"):
    """Evaluate ``net`` against a bicubic baseline over a validation loader.

    For every ``(highres, lowres)`` pair, ``net`` produces a super-resolved
    image from ``lowres``. A bicubic upscale of the same ``lowres`` serves as
    the baseline. Both are scored against ``highres`` with the project's
    ``compute_psnr`` and ``compute_msssim`` helpers. The scores are averaged
    over the number of batches.

    Args:
        net: Model to evaluate. It is switched to eval mode. It is called as
            ``net(lowres, upscaling)`` when ``multiscale`` is true, and as
            ``net(lowres)`` otherwise.
        valLoader: Iterable of ``(highres, lowres)`` tensor pairs. The bicubic
            baseline drops the leading dimension with ``view``, so each batch
            is assumed to hold exactly one image (TODO confirm batch_size=1).
        upscaling: Integer scale factor used for the bicubic baseline (and
            passed to ``net`` when ``multiscale`` is true).
        multiscale: Whether ``net`` takes the scale factor as a second argument.
        device: Device that the model inputs and metrics run on. Defaults to
            ``"cuda:0"``, the value that was previously hard-coded.

    Returns:
        ``(psnr_low, psnr_super, ssim_low, ssim_super)``, where ``*_low`` are
        the bicubic-baseline scores and ``*_super`` are the scores of
        ``net``'s output. Each is the mean over batches.

    Raises:
        ValueError: If ``valLoader`` yields no batches.
    """
    net.eval()
    device = torch.device(device)
    psnr_low = psnr_super = ssim_low = ssim_super = 0
    n_batches = 0
    with torch.no_grad():
        for highres, lowres in valLoader:
            lowres = lowres.to(device)
            highres = highres.to(device)
            superres = net(lowres, upscaling) if multiscale else net(lowres)
            baseline = _bicubic_upscale(lowres, upscaling, device)
            psnr_super += compute_psnr(superres, highres)
            psnr_low += compute_psnr(baseline, highres)
            ssim_super += compute_msssim(superres, highres)
            ssim_low += compute_msssim(baseline, highres)
            n_batches += 1
    if n_batches == 0:
        # Dividing by zero batches would otherwise raise an opaque ZeroDivisionError.
        raise ValueError("valLoader yielded no batches")
    return (psnr_low / n_batches, psnr_super / n_batches,
            ssim_low / n_batches, ssim_super / n_batches)


def _bicubic_upscale(lowres, upscaling, device):
    """Bicubic-upscale a single-image batch by ``upscaling``; return it as a batch on ``device``."""
    # ToImage presumably converts a (C, H, W) tensor to a PIL image (it must
    # expose PIL's ``.size`` and ``.resize``). TODO confirm.
    lowres_cpu = lowres.to("cpu")
    image = ToImage()(lowres_cpu.view(lowres_cpu.size()[1:]))
    width, height = image.size  # PIL reports (width, height)
    image = image.resize((width * upscaling, height * upscaling), Image.BICUBIC)
    # Restore the leading batch dimension removed above.
    return ToTensor()(image).unsqueeze(0).to(device)
def normalized_dicom_pixels(ds, size=512):
    """Return a DICOM dataset's rescaled pixels as a float tensor of shape (1, size, size).

    Stored pixel values are linearly rescaled as
    ``value * RescaleSlope + RescaleIntercept``. For CT this is typically
    Hounsfield units; TODO confirm for other modalities. Images whose header
    dimensions are not ``size`` x ``size`` are resized with torchvision's
    ``Resize``. The float values are kept (PIL mode 'F') rather than being
    squashed to 8 bits.

    Args:
        ds: Dataset exposing ``pixel_array``, ``PixelRepresentation``,
            ``BitsStored``, ``RescaleSlope``, ``RescaleIntercept``, ``Rows`` and
            ``Columns``. Presumably a ``pydicom`` Dataset; a single-frame image
            is assumed. ``ds`` is not modified.
        size: Side length of the square output. Defaults to 512.

    Returns:
        A float tensor of shape ``(1, size, size)``.

    Raises:
        ValueError: If the number of decoded pixels does not equal
            ``Rows * Columns``.
    """
    signed = ds.PixelRepresentation == 1
    slope = float(ds.RescaleSlope)
    intercept = float(ds.RescaleIntercept)
    x = ds.pixel_array
    if ds.BitsStored == 12 and not signed and int(intercept) > -100:
        # see: https://www.kaggle.com/jhoward/cleaning-the-data-for-rapid-prototyping-fastai
        # Shift by 1000, wrap values at 4096, and compensate in the intercept.
        # Out-of-place add: pydicom hands back the same cached array on every
        # ``pixel_array`` access, so ``+=`` would corrupt ``ds`` and a second
        # call would apply the correction twice.
        x = x + 1000
        px_mode = 4096
        x[x >= px_mode] -= px_mode
        intercept -= 1000
    # Reinterpret the raw buffer using the signedness declared in the header.
    x = np.frombuffer(x, dtype='int16' if signed else 'uint16')
    x = np.array(x, dtype='float32') * slope + intercept
    x = torch.from_numpy(x)
    rows, cols = int(ds.Rows), int(ds.Columns)
    if x.numel() != rows * cols:
        raise ValueError(
            f'dimensions {ds.Rows}x{ds.Columns} does not match numel {x.numel()}'
        )
    # Pixel data is row-major: Rows (height) first, then Columns (width).
    x = x.view(1, rows, cols)
    if (rows, cols) != (size, size):
        # mode='F' keeps float32 values. Without it, ToPILImage multiplies by
        # 255 and casts to uint8, which destroys the rescaled values.
        x = ToPILImage(mode='F')(x)
        x = Resize((size, size))(x)
        x = ToTensor()(x)  # 'F' images come back as float32 without /255
    return x.view(1, size, size)