コード例 #1
0
ファイル: utils.py プロジェクト: anonyr7/Sinogram-Inpainting
class Operators():
    def __init__(self, image_size, n_angles, sample_ratio, device, circle=False):
        self.device = device
        self.image_size = image_size
        self.sample_ratio = sample_ratio
        self.n_angles = n_angles
        
        angles = np.linspace(0, np.pi, self.n_angles, endpoint=False)
        self.radon = Radon(self.image_size, angles, clip_to_circle=circle)
        self.radon_sparse = Radon(self.image_size, angles[::sample_ratio], clip_to_circle=circle)
        self.n_angles_sparse = len(angles[::sample_ratio])
        self.landweber = Landweber(self.radon)
        
        self.mask = torch.zeros((1,1,1,180)).to(device)
        self.mask[:,:,:,::sample_ratio].fill_(1)
        
        
    # $X^\T ()$ inverse radon
    def forward_adjoint(self, input):
        # check dimension
        if input.size()[3] == self.n_angles:
            return self.radon.backprojection(input.permute(0,1,3,2))
        elif input.size()[3] == self.n_angles_sparse:
            return self.radon_sparse.backprojection(input.permute(0,1,3,2))/self.n_angles_sparse*self.n_angles  # scale the angles
        else:
            raise Exception(f'forward_adjoint input dimension wrong! received {input.size()}.')
            
        
    # $X^\T X ()$
    def forward_gramian(self, input):
        # check dimension
        if input.size()[2] != self.image_size:
            raise Exception(f'forward_gramian input dimension wrong! received {input.size()}.')
        
        sinogram = self.radon.forward(input)
        return self.radon.backprojection(sinogram)
    

    # Corruption model: undersample sinogram by 8
    def undersample_model(self, input):
        return input[:,:,:,::self.sample_ratio]
    
    
    # Filtered Backprojection. Input siogram range = (0,1)
    def FBP(self, input):
        # check dimension
        if input.size()[2] != self.image_size or input.size()[3] != self.n_angles:
            raise Exception(f'FBP input dimension wrong! received {input.size()}.')
        filtered_sinogram = self.radon.filter_sinogram(input.permute(0,1,3,2))
        return self.radon.backprojection(filtered_sinogram)
    
    
    # estimate step size eta
    def estimate_eta(self):
        eta = self.landweber.estimate_alpha(self.image_size, self.device)
        return torch.tensor(eta, dtype=torch.float32, device=self.device)
コード例 #2
0
class Predict():
    def __init__(self, args, dataloader):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.dataloader = dataloader

        self.net = UNet(input_nc=1, output_nc=1).to(self.device)
        self.net = nn.DataParallel(self.net)

        pathG = os.path.join(args.ckpt)
        self.net.load_state_dict(torch.load(pathG, map_location=self.device))
        self.net.eval()

        self.gen_mask()

        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)

    def gen_mask(self):
        self.mask = torch.zeros(180).to(self.device)
        self.mask[::8].fill_(1)  # 180

    def gen_x(self, y):
        return self.mask * y

    def crop_sinogram(self, x):
        return x[:, :, :, 6:-6]

    def overlay(self, Gx, x):
        result = self.mask * x + (1 - self.mask) * Gx
        return result

    def inpaint(self):
        for i, data in enumerate(self.dataloader):
            y = data[0].to(self.device)  # 320 x 180

            x = self.gen_x(y)  # input, 320 x 23
            Gx = self.net(x)

            Gx = self.overlay(Gx, y)

            # FBP
            Gx = normalize(Gx)  # 0~1
            fbp_Gx = self.radon.backprojection(
                self.radon.filter_sinogram(Gx.permute(0, 1, 3, 2)))

            print(f'Saving images for batch {i}')

            for j in range(y.size()[0]):
                #                 vutils.save_image(Gx[j,0], f'{self.args.outdir}/{class_name}/{fnames[i*self.args.bs+j]}', normalize=True)  # to 0~255
                vutils.save_image(
                    fbp_Gx[j, 0],
                    f'{self.args.outdir}/{class_name}/{fnames[i*self.args.bs+j]}',
                    normalize=True)
コード例 #3
0
ファイル: fbp.py プロジェクト: knxie/torch-radon
from utils import show_images

batch_size = 1
n_angles = 512
image_size = 512

img = np.load("phantom.npy")
device = torch.device('cuda')

# instantiate Radon transform
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles)

with torch.no_grad():
    x = torch.FloatTensor(img).reshape(1, 1, image_size, image_size).to(device)

    sinogram = radon.forward(x)
    filtered_sinogram = radon.filter_sinogram(sinogram)
    fbp = radon.backprojection(filtered_sinogram,
                               extend=False) * np.pi / n_angles

print("FBP Error", torch.norm(x - fbp).item())

titles = [
    "Original Image", "Sinogram", "Filtered Sinogram",
    "Filtered Backprojection"
]
show_images([x, sinogram, filtered_sinogram, fbp], titles, keep_range=False)

plt.show()
コード例 #4
0
class Predict():
    def __init__(self, args, image):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        if args.twoends:
            factor = 192 / (args.angles + 2)  # 7.68
        else:
            factor = 180 / args.angles  # 7.826086956521739

        self.net = UNet(input_nc=1, output_nc=1,
                        scale_factor=factor).to(self.device)
        self.net = nn.DataParallel(self.net)
        pathG = os.path.join(args.ckpt)
        self.net.load_state_dict(torch.load(pathG, map_location=self.device))
        self.net.eval()

        self.image = image.to(self.device)
        self.twoends = args.twoends
        self.mask = self.gen_mask().to(self.device)

        # Radon Operator
        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)

    def gen_mask(self):
        mask = torch.zeros(180)
        mask[::8].fill_(1)  # 180
        if self.twoends:
            mask = torch.cat((mask[-6:], mask, mask[:6]), 0)  # 192
        return mask

    def append_twoends(self, y):
        front = torch.flip(y[:, :, :, :6], [2])
        back = torch.flip(y[:, :, :, -6:], [2])
        return torch.cat((back, y, front), 3)

    def gen_sparse(self, y):
        return y[:, :, :, self.mask == 1]

    def crop_sinogram(self, x):
        return x[:, :, :, 6:-6]

    def inpaint(self):
        y = self.image  # 320 x 180

        # Two-Ends Preprocessing
        if self.twoends:
            y = self.append_twoends(y)  # 320 x 192

        # Generate Sparse-view Image, forward model
        x = self.gen_sparse(y)
        Gx = self.net(x)

        # Crop Two-Ends
        if self.twoends:
            Gx = self.crop_sinogram(Gx)

        # FBP Reconstruction
        Gx = normalize(Gx)  # 0~1
        fbp_Gx = self.radon.backprojection(
            self.radon.filter_sinogram(Gx.permute(0, 1, 3, 2)))

        # Save Results
        vutils.save_image(fbp_Gx, 'result_reconstruction.png', normalize=True)
        vutils.save_image(Gx, 'result_sinogram.png', normalize=True)
コード例 #5
0
class Predict():
    def __init__(self, args, dataloader):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.dataloader = dataloader

        if args.twoends:
            factor = 192 / (args.angles + 2)  # 7.68
        else:
            factor = 180 / args.angles  # 7.826086956521739

        self.net = UNet(input_nc=1, output_nc=1,
                        scale_factor=factor).to(self.device)
        self.net = nn.DataParallel(self.net)
        pathG = os.path.join(args.ckpt)
        self.net.load_state_dict(torch.load(pathG, map_location=self.device))
        self.net.eval()

        self.gen_mask()

        # Radon Operator for different downsampling factors
        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)
        self.radon23 = Radon(args.height, angles[::8], clip_to_circle=True)
        self.radon45 = Radon(args.height, angles[::4], clip_to_circle=True)
        self.radon90 = Radon(args.height, angles[::2], clip_to_circle=True)

    def gen_mask(self):
        mask = torch.zeros(180)
        mask[::8].fill_(1)  # 180
        if self.args.twoends:
            self.mask = torch.cat((mask[-6:], mask, mask[:6]),
                                  0).to(self.device)  # 192
        self.mask_sparse = mask

    def append_twoends(self, y):
        front = torch.flip(y[:, :, :, :6], [2])
        back = torch.flip(y[:, :, :, -6:], [2])
        return torch.cat((back, y, front), 3)

    def gen_input(self, y, mask):
        return y[:, :, :, mask == 1]

    def crop_sinogram(self, x):
        return x[:, :, :, 6:-6]

    def inpaint(self):
        for i, data in enumerate(self.dataloader):
            y = data[0].to(self.device)  # 320 x 180

            # Two-Ends Preprocessing
            if self.args.twoends:
                y_TE = self.append_twoends(y)  # 320 x 192

            # Forward Model
            x = self.gen_input(y_TE, self.mask)  # input, 320 x 25
            Gx = self.net(x)  # 320 x 192

            # Crop Two-Ends
            if self.args.twoends:
                Gx = self.crop_sinogram(Gx)  # 320 x 180

            # FBP Reconstruction
            Gx = normalize(Gx)  # 0~1
            fbp_Gx = self.radon.backprojection(
                self.radon.filter_sinogram(Gx.permute(0, 1, 3,
                                                      2)))  # 320 x 320

            # FBP for downsampled sinograms
            Gx1 = Gx[:, :, :, ::2]  # 320 x 90
            Gx1 = normalize(Gx1)  # 0~1
            fbp_Gx1 = self.radon90.backprojection(
                self.radon90.filter_sinogram(Gx1.permute(0, 1, 3, 2)))

            Gx2 = Gx[:, :, :, ::4]  # 320 x 45
            Gx2 = normalize(Gx2)  # 0~1
            fbp_Gx2 = self.radon45.backprojection(
                self.radon45.filter_sinogram(Gx2.permute(0, 1, 3, 2)))

            sparse = y[:, :, :, ::8]  # 320 x 23, original sparse-view sinogram
            sparse = normalize(sparse)  # 0~1
            fbp_sparse = self.radon23.backprojection(
                self.radon23.filter_sinogram(sparse.permute(0, 1, 3, 2)))

            print(f'Saving images for batch {i}')

            for j in range(y.size()[0]):
                #                 vutils.save_image(Gx[j,0], f'{self.args.outdir}/{class_name}/{fnames[i*self.args.bs+j]}', normalize=True)
                vutils.save_image(
                    fbp_Gx[j, 0],
                    f'{self.args.outdir}/{class_name}/{fnames[i*self.args.bs+j]}',
                    normalize=True)
                vutils.save_image(
                    fbp_Gx1[j, 0],
                    f'{self.args.outdir}_90/{class_name}/{fnames[i*self.args.bs+j]}',
                    normalize=True)
                vutils.save_image(
                    fbp_Gx2[j, 0],
                    f'{self.args.outdir}_45/{class_name}/{fnames[i*self.args.bs+j]}',
                    normalize=True)
                vutils.save_image(
                    fbp_sparse[j, 0],
                    f'{self.args.outdir}_23/{class_name}/{fnames[i*self.args.bs+j]}',
                    normalize=True)
コード例 #6
0
class Inpaint():
    def __init__(self, net, args, dataloader, device):
        self.netG = net[0]
        self.netDG = net[1]
        self.netDL = net[2]
        if args.mode == 'vgg':
            self.netLoss = net[3]
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=args.lr, betas=(0.5, 0.999))
        self.optimizerDG = optim.Adam(self.netDG.parameters(), lr=args.lr, betas=(0.5, 0.999))
        self.optimizerDL = optim.Adam(self.netDL.parameters(), lr=args.lr, betas=(0.5, 0.999))
        
        self.dataloader = dataloader
        self.device = device
        self.args = args
        self.save_cp = True
        self.start_epoch = args.load+1 if args.load>=0 else 0
        self.mask = self.gen_mask().to(self.device)
        
        self.criterionL1 = torch.nn.L1Loss().to(self.device)
        self.criterionL2 = torch.nn.MSELoss().to(self.device)
        self.criterionGAN = GANLoss('vanilla').to(self.device)
        
        err_list = ["errDG", "errDL", 
                    "errGG_GAN", "errGG_C", "errGG_F", "errGG_P",
                    "errGL_GAN", "errGL_C", "errGL_F", "errGL_P"]
        self.err = dict.fromkeys(err_list, None) 
                    
        if self.save_cp:
            try:
                if not os.path.exists(os.path.join(args.outdir, 'ckpt')):
                    os.makedirs(os.path.join(args.outdir, 'ckpt'))
                    print('Created checkpoint directory')
                if args.load < 0:  # New log file
                    with open(os.path.join(args.outdir, args.log_fn+'.csv'), 'w', newline='') as f:
                        csvwriter = writer(f)
                        csvwriter.writerow(["epoch", "runtime"] + err_list)
            except OSError:
                pass
        
        angles = np.linspace(0, np.pi, 180, endpoint=False)
        self.radon = Radon(args.height, angles, clip_to_circle=True)

        
    def gen_mask(self):
        mask = torch.zeros(180)
        mask[::8].fill_(1)  # 180/23
        if self.args.twoends:
            mask = torch.cat((mask[-6:], mask, mask[:6]), 0) # 192/25
        return mask
            
    
    def gen_sparse(self, y):
        return y[:,:,:,self.mask==1]
    
    
    def append_twoends(self, y):
        front = torch.flip(y[:,:,:,:6], [2])
        back = torch.flip(y[:,:,:,-6:], [2])
        return torch.cat((back, y, front), 3)
    
    
    def ramp_module(self, sinogram):
        '''
            Sinogram has dimension: bs x c x height x angle. 
            Ramp is 1D but angle number affects normalization for filter_sinogram. Use with caution.
        '''
        normalized_sinogram = normalize(sinogram, rto=(0,1))
        if sinogram.size()[2] == self.args.height:
            filtered_sinogram = self.radon.filter_sinogram(normalized_sinogram.permute(0,1,3,2)).permute(0,1,3,2)  # 320 x 192
        else:
            print('sinogram dimension wrong for filter!')
        return normalize(filtered_sinogram, rto=(-1,1))
    
    
    def criterionP(self, Gx, y):
        # calculate feature loss
        y_features = self.netLoss(y)
        Gx_features = self.netLoss(Gx)
        
        loss = 0.0
        for j in range(len(y_features)):
            loss += self.criterionL2(Gx_features[j], y_features[j][:y.shape[0]])
        return loss
    
    
    def criterionDP(self, Gx_features, y_features):
        loss = 0.0
        for j in range(len(y_features)):
            loss += self.criterionL2(Gx_features[j], y_features[j])
        return loss
        
        
    def train_D(self, Gx, y, mode):
        '''
            mode is G/L.
        '''
        if mode == 'G':
            netD = self.netDG
            optimizer = self.optimizerDG
        elif mode == 'L':
            netD = self.netDL
            optimizer = self.optimizerDL
        else:
            print('wrong mode!')
            
        netD.zero_grad()
        
        ############################
        # Loss_D: L_D = -(log(D(y) + log(1 - D(G(x))))
        ###########################
        # train with fake
        D_Gx = netD(Gx.detach())[-1]
        errD_fake = self.criterionGAN(D_Gx, False)

        # train with real
        D_y = netD(y)[-1]
        errD_real = self.criterionGAN(D_y, True)

        # backprop
        errD = (errD_real + errD_fake) * 0.5
        errD.backward()
        optimizer.step()
        self.err['errD'+mode] = errD.item()
    
    
    def train_G(self, Gx, y, filtered_Gx, filtered_y, mode):
        '''
            mode is G/L.
        '''
        if mode == 'G':
            netD = self.netDG
        elif mode == 'L':
            netD = self.netDL
        else:
            print('wrong mode!')
            
        self.netG.zero_grad()
        
        ############################
        # Loss_G_GAN: L_G = -log(D(G(x))  # Fake the D
        ###########################
        Gx_features = netD(Gx)
        errG_GAN = self.criterionGAN(Gx_features[-1], True)
        
        ############################
        # Loss_G_C: L_C = ||y - G(x)||_1
        ###########################
        errG_C = self.criterionL1(Gx, y)*50

        ############################
        # Loss_G_DP: Discriminator perceptual feature loss
        ###########################
        if self.args.mode == 'vgg':
            errG_P = self.criterionP(Gx, y)*20
        elif self.args.mode == 'DP':
            y_features = netD(y)
            errG_P = self.criterionDP(Gx_features[:-1], y_features[:-1])*20
#             errG_P = self.criterionDP(Gx_features[-2], y_features[-2])*50
        else:
            errG_P = torch.tensor(0).to(self.device)
            
        ############################
        # Loss_G_F: Ramp filtered sinogram loss
        ###########################
        errG_F = self.criterionL1(filtered_Gx, filtered_y)*50

        # backprop
        errG = errG_GAN + errG_C + errG_F + errG_P
        errG.backward()
        self.optimizerG.step()
        
        self.err['errG'+mode+'_GAN'] = errG_GAN.item()
        self.err['errG'+mode+'_C'] = errG_C.item()
        self.err['errG'+mode+'_F'] = errG_F.item()
        self.err['errG'+mode+'_P'] = errG_P.item()
    
    
    def log(self, epoch, i):
        print(f'[{epoch}/{self.args.epochs}][{i}/{len(self.dataloader)}] ' \
              f'LossDG: {self.err["errDG"]:.4f} ' \
              f'LossGG_GAN: {self.err["errGG_GAN"]:.4f} ' \
              f'LossGG_C: {self.err["errGG_C"]:.4f} ' \
              f'LossGG_F: {self.err["errGG_F"]:.4f} ' \
              f'LossGG_P: {self.err["errGG_P"]:.4f} ' \
              
              f'LossDL: {self.err["errDL"]:.4f} ' \
              f'LossGL_GAN: {self.err["errGL_GAN"]:.4f} ' \
              f'LossGL_C: {self.err["errGL_C"]:.4f} ' \
              f'LossGL_F: {self.err["errGL_F"]:.4f} ' \
              f'LossGL_P: {self.err["errGL_P"]:.4f} ' \
        )
    
    
    def log2file(self, fn, epoch, runtime):
        new_row = [epoch, runtime]+ list[self.err.values()]
        with open(fn, 'a+', newline='') as write_obj:
            csv_writer = writer(write_obj)
            csv_writer.writerow(new_row)
        
        
    def train(self):
        print(f'''Starting training:
            Epochs:          {self.args.epochs}
            Batch size:      {self.args.batchSize}
            Learning rate:   {self.args.lr}
            Checkpoints:     {self.save_cp}
            Device:          {self.device.type}
        ''')
        
        for epoch in range(self.start_epoch, self.args.epochs):
            self.D_epochs = 1 # Adjust if you want
            print('D is trained ', str(self.D_epochs), 'times in this epoch.')
            
            start = time.time()  # log start time
            for i, data in enumerate(self.dataloader):
                y = data[0].to(self.device)  # 320 x 180

                # forward
                if self.args.twoends:
                    y = self.append_twoends(y)  # 320 x 192
                
                filtered_y = self.ramp_module(y)  # 320 x 192, normalized to -1~1
                x = self.gen_sparse(y)  # 320 x 25
                
                # Train Global
                Gx = self.netG(x)
                filtered_Gx = self.ramp_module(Gx)  # 320 x 192, normalized to -1~1

                ###### Train D
                set_requires_grad(self.netDG, True)
                for _ in range(self.D_epochs):  # increase D epoch gradually. FOR DP LOSS training
                    self.train_D(Gx, y, mode='G')
                ###### Train G
                set_requires_grad(self.netDG, False)  # D requires no gradients when optimizing G
                self.train_G(Gx, y, filtered_Gx, filtered_y, mode='G')
                
                # Train Local
                Gx = self.netG(x)
                filtered_Gx = self.ramp_module(Gx)  # 320 x 192, normalized to -1~1
                patch_area = gen_hole_area((y.shape[3]//4, y.shape[2]//4), (y.shape[3], y.shape[2]))
                Gx_patch = crop(Gx, patch_area)
                y_patch = crop(y, patch_area)
                filtered_y_patch = crop(filtered_y, patch_area)
                filtered_Gx_patch = crop(filtered_Gx, patch_area)
                
                ###### Train D
                set_requires_grad(self.netDL, True)
                for _ in range(self.D_epochs):  # increase D epoch gradually. FOR DP LOSS training
                    self.train_D(Gx_patch, y_patch, mode='L')
                ###### Train G
                set_requires_grad(self.netDL, False)  # D requires no gradients when optimizing G
                self.train_G(Gx_patch, y_patch, filtered_Gx_patch, filtered_y_patch, mode='L')
                
                if i % 100 == 0:
                    self.log(epoch, i)
                    
            end = time.time()  # log end time
#             self.log2file(os.path.join(self.args.outdir, self.args.log_fn+'.csv'), epoch , str(end-start))
            
            # Log
            self.log(epoch, i)
            if self.save_cp:
                torch.save(self.netG.state_dict(), f'{self.args.outdir}/ckpt/G_epoch{epoch}.pth')
                torch.save(self.netDG.state_dict(), f'{self.args.outdir}/ckpt/DG_epoch{epoch}.pth')
                torch.save(self.netDL.state_dict(), f'{self.args.outdir}/ckpt/DL_epoch{epoch}.pth')
            vutils.save_image(Gx.detach(), '%s/impainted_samples_epoch_%03d.png' % (self.args.outdir, epoch), normalize=True)