Пример #1
0
def main():
    """Show CT reconstructions for the test samples listed in ``--sample_to_show``.

    Each selected sample is drawn as a 1x4 figure: the simulated measurement
    ``y = forw.A(x)``, the pseudo-inverse ``forw.A_dagger(y)`` (titled FBP),
    the network output computed from that pseudo-inverse, and the ground
    truth ``x``. The FBP and network titles carry ``cal_psnr`` against ``x``.
    """
    args = parser.parse_args()

    device = f'cuda:{args.gpu}'

    unet = UNet(in_channels=1, out_channels=1, compact=4, residual=True,
                circular_padding=True, cat=True).to(device)
    forw = CT(img_width=128, radon_view=50, circle=False, device=device)
    dataloader = torch.utils.data.DataLoader(dataset=CTData(mode='test'),
                                             batch_size=1, shuffle=False)

    # Load the trained weights once, up front, instead of re-reading the
    # checkpoint file from disk for every displayed sample.
    checkpoint = torch.load(args.ckp_net, map_location=device)
    unet.load_state_dict(checkpoint['state_dict'])
    unet.eval()

    def to_image(t):
        # First item of the batch, channel axis moved last for plt.imshow.
        return t[0].detach().permute(1, 2, 0).cpu().numpy()

    for i, x in enumerate(dataloader):
        if i not in args.sample_to_show:
            continue
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(torch.float).to(device)

        # Pure inference: no autograd graph is needed for visualisation.
        with torch.no_grad():
            y = forw.A(x)
            fbp = forw.A_dagger(y)
            x_hat = unet(fbp)

        plt.subplot(1, 4, 1)
        plt.imshow(to_image(y))
        plt.title('y')

        plt.subplot(1, 4, 2)
        plt.imshow(to_image(fbp))
        plt.title('FBP ({:.2f})'.format(cal_psnr(x, fbp)))

        plt.subplot(1, 4, 3)
        plt.imshow(to_image(x_hat))
        plt.title('{} ({:.2f})'.format(args.model_name, cal_psnr(x, x_hat)))

        plt.subplot(1, 4, 4)
        plt.imshow(to_image(x))
        plt.title('x (GT)')

        ax = plt.gca()
        ax.set_xticks([]), ax.set_yticks([])
        plt.subplots_adjust(left=0.1, bottom=0.1, top=0.9, right=0.9,
                            hspace=0.02, wspace=0.02)
        plt.show()
Пример #2
0
def closure_sup_ei(net,
                   dataloader,
                   physics,
                   transform,
                   optimizer,
                   criterion_fc,
                   criterion_ei,
                   alpha,
                   dtype,
                   device,
                   reportpsnr=False):
    """Run one epoch of supervised training regularised by an EI loss.

    For each batch ``x``: ``x1 = net(A^+(A(x)))``, ``x2 = transform.apply(x1)``
    and ``x3 = net(A^+(A(x2)))``. The loss is
    ``criterion_fc(x1, x) + alpha['ei'] * criterion_ei(x3, x2)``, and one
    optimizer step is taken per batch.

    Args:
        net: Reconstruction network being trained.
        dataloader: Iterable of tensors, or of lists whose first item is the
            tensor. 3-D tensors get a channel axis inserted at dim 1.
        physics: Object providing ``A`` (forward operator) and ``A_dagger``
            (pseudo-inverse).
        transform: Object whose ``apply`` produces the transformed images.
        optimizer: Optimizer over ``net``'s parameters.
        criterion_fc: Supervised loss between ``x1`` and the ground truth.
        criterion_ei: Equivariance loss between ``x3`` and ``x2``.
        alpha: Mapping with the EI weight under key ``'ei'``.
        dtype: Target tensor type for the data.
        device: Target device for the data.
        reportpsnr: If True, also record ``cal_psnr``/``cal_mse`` of ``x1``.

    Returns:
        list: ``[mean supervised loss, mean EI loss, mean total loss]``,
        followed by the mean PSNR and mean MSE when ``reportpsnr`` is True.
    """
    loss_x_seq, loss_ei_seq, loss_seq, psnr_seq, mse_seq = [], [], [], [], []
    for x in dataloader:
        x = x[0] if isinstance(x, list) else x
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(dtype).to(device)  # ground truth, already cast and moved

        y0 = physics.A(x)          # measurement
        x0 = physics.A_dagger(y0)  # range input A^+ y

        x1 = net(x0)
        # Measurement consistency of x1 is not part of this objective, so
        # A(x1) is deliberately not evaluated (saves a forward-operator pass
        # and its autograd graph per batch).

        # EI: x2, x3
        x2 = transform.apply(x1)
        x3 = net(physics.A_dagger(physics.A(x2)))

        loss_x = criterion_fc(x1, x)
        loss_ei = criterion_ei(x3, x2)

        loss = loss_x + alpha['ei'] * loss_ei

        loss_x_seq.append(loss_x.item())
        loss_ei_seq.append(loss_ei.item())
        loss_seq.append(loss.item())

        if reportpsnr:
            psnr_seq.append(cal_psnr(x1, x))
            mse_seq.append(cal_mse(x1, x))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_closure = [
        np.mean(loss_x_seq),
        np.mean(loss_ei_seq),
        np.mean(loss_seq)
    ]

    if reportpsnr:
        loss_closure.append(np.mean(psnr_seq))
        loss_closure.append(np.mean(mse_seq))

    return loss_closure
Пример #3
0
def closure_ei(net, dataloader, physics, transform,
               optimizer, criterion_mc, criterion_ei,
               alpha, dtype, device, reportpsnr=False):
    """Train ``net`` for one epoch with the Equivariant Imaging (EI) objective.

    The losses never see the ground truth. Per batch they are:

    * measurement consistency ``criterion_mc(A(x_rec), y)``, with
      ``x_rec = net(A^+ y)`` and ``y = A(x)``;
    * equivariance ``criterion_ei(x_retrans, x_trans)``, with
      ``x_trans = transform.apply(x_rec)`` and
      ``x_retrans = net(A^+(A(x_trans)))``.

    The total ``mc + alpha['ei'] * ei`` is minimised with one optimizer step
    per batch. The ground truth is used only for the optional
    ``cal_psnr``/``cal_mse`` report.

    Returns:
        list: mean MC loss, mean EI loss, mean total loss; the mean PSNR and
        mean MSE follow when ``reportpsnr`` is True.
    """
    history = {'mc': [], 'ei': [], 'total': [], 'psnr': [], 'mse': []}

    for batch in dataloader:
        if isinstance(batch, list):
            batch = batch[0]
        if len(batch.shape) == 3:
            batch = batch.unsqueeze(1)
        x_gt = batch.type(dtype).to(device)  # ground-truth signal

        y_meas = physics.A(x_gt)             # simulated measurement y
        x_range = physics.A_dagger(y_meas)   # range input A^+ y

        x_rec = net(x_range)
        y_rec = physics.A(x_rec)

        # equivariant imaging: transform the estimate, re-measure, reconstruct
        x_trans = transform.apply(x_rec)
        x_retrans = net(physics.A_dagger(physics.A(x_trans)))

        loss_mc = criterion_mc(y_rec, y_meas)
        loss_ei = criterion_ei(x_retrans, x_trans)
        total = loss_mc + alpha['ei'] * loss_ei

        history['mc'].append(loss_mc.item())
        history['ei'].append(loss_ei.item())
        history['total'].append(total.item())

        if reportpsnr:
            history['psnr'].append(cal_psnr(x_rec, x_gt))
            history['mse'].append(cal_mse(x_rec, x_gt))

        optimizer.zero_grad()
        total.backward()
        optimizer.step()

    summary = [np.mean(history[key]) for key in ('mc', 'ei', 'total')]
    if reportpsnr:
        summary.extend([np.mean(history['psnr']), np.mean(history['mse'])])
    return summary
Пример #4
0
def closure_dip(net,
                dataloader,
                z,
                physics,
                optimizer,
                criterion_mc,
                dtype,
                device,
                reportpsnr=False):
    """Run one epoch of training with ``net(z)`` as the reconstruction.

    For each batch the estimate is ``x1 = net(z)`` and the only training
    signal is measurement consistency ``criterion_mc(A(x1), A(x))``; one
    optimizer step is taken per batch.

    Args:
        net: Network being trained.
        dataloader: Iterable of tensors, or of lists whose first item is the
            tensor. 3-D tensors get a channel axis inserted at dim 1.
        z: Input passed unchanged to ``net`` for every batch.
        physics: Object providing the forward operator ``A``.
        optimizer: Optimizer over ``net``'s parameters.
        criterion_mc: Loss between ``A(x1)`` and the measurement ``A(x)``.
        dtype: Target tensor type for the data.
        device: Target device for the data.
        reportpsnr: If True, also record ``cal_psnr`` and the MSE of ``x1``
            against the ground truth.

    Returns:
        list: ``[mean MC loss]``, followed by the mean PSNR and mean MSE over
        batches when ``reportpsnr`` is True (previously only the last batch
        was reported; identical for a single-batch loader).
    """
    loss_dip_seq, psnr_seq, mse_seq = [], [], []
    for x in dataloader:
        x = x[0] if isinstance(x, list) else x
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x = x.type(dtype).to(device)  # ground truth, already cast and moved

        y0 = physics.A(x)  # measurement

        x1 = net(z)
        y1 = physics.A(x1)

        if reportpsnr:
            psnr_seq.append(cal_psnr(x1, x))
            mse_seq.append(torch.nn.MSELoss()(x1, x).item())

        loss_mc = criterion_mc(y1, y0)
        loss_dip_seq.append(loss_mc.item())

        optimizer.zero_grad()
        # Back-propagate the loss actually computed above (the original
        # called ``loss_fc.backward()``, an undefined name -> NameError).
        loss_mc.backward()
        optimizer.step()

    loss_closure = [np.mean(loss_dip_seq)]

    if reportpsnr:
        loss_closure.append(np.mean(psnr_seq))
        loss_closure.append(np.mean(mse_seq))

    return loss_closure
Пример #5
0
def closure_mc(net,
               dataloader,
               physics,
               optimizer,
               criterion_mc,
               dtype,
               device,
               reportpsnr=False):
    """Train ``net`` for one epoch using measurement consistency only.

    Each batch is measured (``y = A(x)``) and reconstructed from its range
    input (``x_est = net(A^+ y)``). The loss is
    ``criterion_mc(A(x_est), y)``, with one optimizer step per batch.

    Returns:
        list: ``[mean MC loss]``, extended with the mean ``cal_psnr`` and
        mean ``cal_mse`` against the ground truth when ``reportpsnr`` is set.
    """
    mc_losses, psnrs, mses = [], [], []

    for sample in dataloader:
        if isinstance(sample, list):
            sample = sample[0]
        if len(sample.shape) == 3:
            sample = sample.unsqueeze(1)
        x_true = sample.type(dtype).to(device)  # ground truth

        y_true = physics.A(x_true)                # measurement
        x_est = net(physics.A_dagger(y_true))     # reconstruct from A^+ y
        loss = criterion_mc(physics.A(x_est), y_true)

        mc_losses.append(loss.item())
        if reportpsnr:
            psnrs.append(cal_psnr(x_est, x_true))
            mses.append(cal_mse(x_est, x_true))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    summary = [np.mean(mc_losses)]
    if reportpsnr:
        summary += [np.mean(psnrs), np.mean(mses)]
    return summary
Пример #6
0
def main():
    """Evaluate an inpainting network on the test split and plot chosen samples.

    For every test image the measurement ``y = forw.A(x)`` (mask rate 0.3)
    and its pseudo-inverse ``forw.A_dagger(y)`` are computed. The network
    reconstructs from the pseudo-inverse, and ``cal_psnr`` of both the
    pseudo-inverse and the network output against ``x`` is collected.
    Indices listed in ``--sample_to_show`` are plotted, and the average PSNRs
    are printed at the end.
    """
    args = parser.parse_args()

    device = f'cuda:{args.gpu}'

    # define the dataloader (i.e. 'urban100', first 90 imgs for training, last 10 for testing)
    dataloader = CVDB_ICCV(dataset_name=args.dataset_name,
                           mode='test',
                           batch_size=1,
                           shuffle=False)

    # define the forward operator (i.e. physics)
    forw = Inpainting(img_heigth=256,
                      img_width=256,
                      mask_rate=0.3,
                      device=device)

    # define the network G (i.e. residual unet in the paper)
    unet = UNet(in_channels=3,
                out_channels=3,
                compact=4,
                residual=True,
                circular_padding=True,
                cat=True).to(device)

    # Load the trained weights once instead of re-reading the checkpoint
    # from disk for every test image.
    checkpoint = torch.load(args.ckp, map_location=device)
    unet.load_state_dict(checkpoint['state_dict'])
    unet.eval()

    psnr_fbp, psnr_net = [], []

    def to_image(t):
        # Drop singleton axes and move channels last for plt.imshow.
        return t.squeeze().detach().permute(1, 2, 0).cpu().numpy()

    for i, x in enumerate(dataloader):
        x = x[0] if isinstance(x, list) else x
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        # groundtruth
        x = x.type(torch.float).to(device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            # compute measurement
            y = forw.A(x)
            # compute the A^+y or FBP
            fbp = forw.A_dagger(y)
            x_hat = unet(fbp)

        # Accumulate per-image PSNRs; these lists were previously never
        # filled, so the printed averages were NaN.
        psnr_fbp.append(cal_psnr(x, fbp))
        psnr_net.append(cal_psnr(x, x_hat))

        if i in args.sample_to_show:
            plt.subplot(1, 4, 1)
            plt.imshow(to_image(y))
            plt.title('y')

            plt.subplot(1, 4, 2)
            plt.imshow(to_image(fbp))
            plt.title('FBP ({:.2f})'.format(psnr_fbp[-1]))

            plt.subplot(1, 4, 3)
            plt.imshow(to_image(x_hat))
            # Report the network's PSNR (previously showed the FBP's).
            plt.title('{} ({:.2f})'.format(args.model_name, psnr_net[-1]))

            plt.subplot(1, 4, 4)
            plt.imshow(to_image(x))
            plt.title('x (GT)')

            ax = plt.gca()
            ax.set_xticks([]), ax.set_yticks([])
            plt.subplots_adjust(left=0.1,
                                bottom=0.1,
                                top=0.9,
                                right=0.9,
                                hspace=0.02,
                                wspace=0.02)
            plt.show()

    print('Inpainting (0.3) AVG-PSNR: A^+y={:.2f}\t{}={:.2f}'.format(
        np.mean(psnr_fbp), args.model_name, np.mean(psnr_net)))
Пример #7
0
def closure_ei_adv(generator,
                   discriminator,
                   dataloader,
                   physics,
                   transform,
                   optimizer_G,
                   optimizer_D,
                   criterion_mc,
                   criterion_ei,
                   criterion_gan,
                   alpha,
                   dtype,
                   device,
                   reportpsnr=False):
    """Run one epoch of adversarial Equivariant Imaging (EI) training.

    Each batch performs two updates, in this order:

    1. Generator: minimise ``mc + alpha['ei'] * ei + alpha['adv'] * g``, where
       ``mc = criterion_mc(A(x1), y0)`` with ``x1 = generator(A^+ y0)``;
       ``ei = criterion_ei(x3, x2)`` with ``x2 = transform.apply(x1)`` and
       ``x3 = generator(A^+(A(x2)))``; and ``g`` asks the discriminator to
       label ``x2`` as valid (ones).
    2. Discriminator: minimise ``0.5 * alpha['adv'] * (real + fake)``, where
       ``x1`` is labelled valid (ones) and ``x2`` fake (zeros).

    Label tensors have shape ``(rows, *discriminator.output_shape)``. For the
    transformed batch, ``rows`` is ``batch_size * transform.n_trans``.

    Returns:
        list: mean MC, EI, generator-adversarial, total-G and total-D losses;
        the mean ``cal_psnr`` and ``cal_mse`` of ``x1`` follow when
        ``reportpsnr`` is True.
    """
    loss_keys = ('mc', 'ei', 'g', 'G', 'D')
    log = {key: [] for key in loss_keys + ('psnr', 'mse')}

    def targets(rows, fill):
        # Constant discriminator targets, cast and moved like the data.
        make = torch.ones if fill else torch.zeros
        return make(rows, *discriminator.output_shape).type(dtype).to(device)

    for batch in dataloader:
        if isinstance(batch, list):
            batch = batch[0]
        if len(batch.shape) == 3:
            batch = batch.unsqueeze(1)
        x = batch.type(dtype).to(device)

        y0 = physics.A(x)                   # measurements
        x0 = physics.A_dagger(y0).detach()  # range input A^+ y, outside the graph

        rows = x.shape[0]
        rows_ei = rows * transform.n_trans
        valid = targets(rows, True)
        valid_ei = targets(rows_ei, True)
        fake_ei = targets(rows_ei, False)

        # ---- generator update ----
        optimizer_G.zero_grad()

        x1 = generator(x0)
        y1 = physics.A(x1)

        x2 = transform.apply(x1)
        x3 = generator(physics.A_dagger(physics.A(x2)))

        loss_mc = criterion_mc(y1, y0)
        loss_ei = criterion_ei(x3, x2)
        loss_g = criterion_gan(discriminator(x2), valid_ei)  # fool D on x2
        loss_G = loss_mc + alpha['ei'] * loss_ei + alpha['adv'] * loss_g

        loss_G.backward()
        optimizer_G.step()

        # ---- discriminator update ----
        optimizer_D.zero_grad()

        real_loss = criterion_gan(discriminator(x1.detach()), valid)
        fake_loss = criterion_gan(discriminator(x2.detach()), fake_ei)
        loss_D = 0.5 * alpha['adv'] * (real_loss + fake_loss)

        loss_D.backward()
        optimizer_D.step()

        if reportpsnr:
            log['psnr'].append(cal_psnr(x1, x))
            log['mse'].append(cal_mse(x1, x))

        for key, value in zip(loss_keys, (loss_mc, loss_ei, loss_g, loss_G, loss_D)):
            log[key].append(value.item())

    loss_closure = [np.mean(log[key]) for key in loss_keys]
    if reportpsnr:
        loss_closure.append(np.mean(log['psnr']))
        loss_closure.append(np.mean(log['mse']))
    return loss_closure