Exemple #1
0
def compute_metrics(net, valLoader, upscaling, multiscale):
    """Average PSNR and MS-SSIM over a validation loader.

    For every batch, both the network output and a bicubic upscale of the
    low-resolution input are scored against the high-resolution target.

    Args:
        net: model; put into eval mode and called as ``net(lowres, upscaling)``
            when ``multiscale`` is truthy, otherwise as ``net(lowres)``.
            It is expected to already live on the device chosen below.
        valLoader: iterable of ``(highres, lowres)`` pairs that supports
            ``len()``. Assumes batch size 1, since the leading axis is
            dropped with ``view`` before the PIL conversion - TODO confirm.
        upscaling: factor applied to the width and height of the low-res
            image for the bicubic baseline (presumably an integer).
        multiscale: selects which call signature of ``net`` is used.

    Returns:
        Tuple ``(psnr_low, psnr_super, ssim_low, ssim_super)``; each entry is
        the per-batch sum divided by ``len(valLoader)``.

    Raises:
        ZeroDivisionError: if ``valLoader`` is empty.
    """
    net.eval()
    psnr_low = 0
    psnr_super = 0
    ssim_low = 0
    ssim_super = 0
    cpu = torch.device("cpu")
    # Same device as before whenever CUDA is present; fall back to the CPU
    # instead of crashing on hosts without a GPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        for highres, lowres in valLoader:
            lowres = lowres.to(device)
            highres = highres.to(device)
            if multiscale:
                superres = net(lowres, upscaling)
            else:
                superres = net(lowres)

            # Bicubic baseline: go through PIL on the CPU, then restore the
            # batch axis and move the result back next to the target.
            lowres = lowres.to(cpu)
            lowres = ToImage()(lowres.view(lowres.size()[1:]))
            lowres = lowres.resize(
                (lowres.size[0] * upscaling, lowres.size[1] * upscaling),
                Image.BICUBIC)
            lowres = ToTensor()(lowres)
            lowres = lowres.view([1] + list(lowres.size()))
            lowres = lowres.to(device)

            psnr_super += compute_psnr(superres, highres)
            psnr_low += compute_psnr(lowres, highres)
            ssim_super += compute_msssim(superres, highres)
            print("|")  # one tick per batch, kept from the original output
            ssim_low += compute_msssim(lowres, highres)

    n_batches = len(valLoader)
    return psnr_low / n_batches, psnr_super / n_batches, \
           ssim_low / n_batches, ssim_super / n_batches
Exemple #2
0
def normalized_dicom_pixels(ds):
    """Return the dataset's pixel data as a rescaled ``(1, 512, 512)`` tensor.

    The stored values are mapped through ``value * RescaleSlope +
    RescaleIntercept``. Unsigned 12-bit data with an intercept above -100
    first gets the wrap-around correction from the linked Kaggle notebook.
    Images that do not hold exactly 512*512 values are resized to 512x512.

    Args:
        ds: object exposing ``PixelRepresentation``, ``RescaleSlope``,
            ``RescaleIntercept``, ``BitsStored``, ``Rows``, ``Columns`` and
            ``pixel_array`` (presumably a pydicom dataset with 16-bit
            pixel data - verify against caller). It is not modified.

    Returns:
        Float tensor of shape ``(1, 512, 512)``.

    Raises:
        ValueError: if ``Rows * Columns`` disagrees with the number of
            pixel values.
    """
    signed = ds.PixelRepresentation == 1
    slope = float(ds.RescaleSlope)
    intercept = float(ds.RescaleIntercept)
    x = ds.pixel_array
    if ds.BitsStored == 12 and not signed and int(intercept) > -100:
        # see: https://www.kaggle.com/jhoward/cleaning-the-data-for-rapid-prototyping-fastai
        # Work on a copy: the array belongs to the caller (pydicom caches
        # it), so an in-place shift would corrupt it and make a second call
        # on the same dataset return different values.
        x = x.copy()
        x += 1000
        px_mode = 4096
        wrapped = x >= px_mode
        x[wrapped] = x[wrapped] - px_mode
        intercept -= 1000
    # Reinterpret the raw buffer as flat 16-bit values, then rescale in float.
    x = np.frombuffer(x, dtype='int16' if signed else 'uint16')
    x = np.array(x, dtype='float32')
    x = x * slope + intercept
    x = torch.Tensor(x)
    if x.numel() != 512 * 512:
        if ds.Columns * ds.Rows != x.numel():
            raise ValueError(
                f'dimensions {ds.Rows}x{ds.Columns} does not match numel {x.numel()}'
            )
        # The flat data is row-major, so the layout is (rows, columns);
        # swapping the two scrambles every non-square image.
        x = x.view(1, int(ds.Rows), int(ds.Columns))
        # mode='F' keeps the float values; without it ToPILImage multiplies
        # a float tensor by 255 and casts to uint8, destroying the rescaled
        # intensities (and ToTensor would then squash them into [0, 1]).
        x = ToPILImage(mode='F')(x)
        x = Resize((512, 512))(x)
        x = ToTensor()(x)
    x = x.view(1, 512, 512)
    return x