Exemplo n.º 1
0
def main():
    """Train the 3T -> 7T UNet regressor, optionally with an adversarial loss.

    Parses hyper-parameters with the module-level ``parser`` and stores the
    result in the module-level ``opt``.  Every iteration draws one mini-batch
    from the training H5 generator, optionally updates the discriminator
    (``opt.isAdLoss``), then updates the generator with an L1 loss (plus the
    adversarial BCE term when enabled).  Periodically it:

    * prints the running generator loss (every ``opt.showTrainLossEvery``),
    * saves ``net.state_dict()`` to ``opt.prefixModelName + '<iter>.pt'``
      (every ``opt.saveModelEvery``),
    * multiplies the learning rate by 0.1 (every ``opt.decLREvery``),
    * evaluates one validation batch and one full test volume, writing the
      prediction to ``opt.prefixPredictedFN + '<iter>.nii.gz'``
      (every ``opt.showTestPerformanceEvery``).
    """
    global opt
    opt = parser.parse_args()
    print(opt)

    # Discriminator; only trained when opt.isAdLoss is set.
    netD = Discriminator()
    netD.apply(weights_init)
    netD.cuda()

    optimizerD = optim.Adam(netD.parameters(), lr=1e-3)
    criterion_bce = nn.BCELoss()
    criterion_bce.cuda()

    # Generator: 5 input channels -> 1 output channel.
    net = UNet(in_channel=5, n_classes=1)
    net.cuda()
    params = list(net.parameters())
    print('len of params is ')
    print(len(params))
    print('size of params is ')
    print(params[0].size())

    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    criterion_L1 = nn.L1Loss()
    criterion_L1 = criterion_L1.cuda()

    path_test = '/shenlab/lab_stor5/dongnie/3T7T'
    path_patients_h5 = '/shenlab/lab_stor5/dongnie/3T7T/histH5Data_64to64'
    path_patients_h5_test = '/shenlab/lab_stor5/dongnie/3T7T/histH5DataTest_64to64'
    data_generator = Generator_2D_slices(path_patients_h5,
                                         opt.batchSize,
                                         inputKey='data3T',
                                         outputKey='data7T')
    data_generator_test = Generator_2D_slices(path_patients_h5_test,
                                              opt.batchSize,
                                              inputKey='data3T',
                                              outputKey='data7T')

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            if isinstance(checkpoint, dict) and 'model' in checkpoint:
                # Wrapped checkpoint: {'epoch': ..., 'model': module or state_dict}.
                model_state = checkpoint['model']
                if hasattr(model_state, 'state_dict'):
                    model_state = model_state.state_dict()
                net.load_state_dict(model_state)
                if 'epoch' in checkpoint:
                    opt.start_epoch = checkpoint['epoch'] + 1
            else:
                # Bare state_dict, as written by the periodic save in this loop;
                # it carries no iteration number, so opt.start_epoch is kept.
                net.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # TODO: switch to a DataLoader and train for several epochs over the data.
    running_loss = 0.0
    running_count = 0  # iterations accumulated into running_loss since the last report
    start = time.time()
    for iteration in range(opt.start_epoch, opt.numofIters + 1):
        inputs, labels = data_generator.next()
        # Move axis 3 to position 1 (channels-last -> channels-first, assuming the
        # generator yields NHWC); np.squeeze also drops any singleton axes.
        inputs = np.squeeze(np.transpose(inputs, (0, 3, 1, 2)))
        labels = np.squeeze(labels)

        inputs = Variable(torch.from_numpy(inputs).cuda())
        labels = Variable(torch.from_numpy(labels).cuda())

        # (1) Discriminator step: maximize log(D(x)) + log(1 - D(G(z))).
        if opt.isAdLoss:
            outputG = net(inputs)

            if len(labels.size()) == 3:
                labels = labels.unsqueeze(1)
            outputD_real = netD(labels)

            if len(outputG.size()) == 3:
                outputG = outputG.unsqueeze(1)
            outputD_fake = netD(outputG)

            netD.zero_grad()
            batch_size = inputs.size(0)
            real_label = Variable(torch.ones(batch_size, 1).cuda())
            loss_real = criterion_bce(outputD_real, real_label)
            loss_real.backward()

            fake_label = Variable(torch.zeros(batch_size, 1).cuda())
            loss_fake = criterion_bce(outputD_fake, fake_label)
            loss_fake.backward()

            lossD = loss_real + loss_fake
            optimizerD.step()

        # (2) Generator step: minimize the L1 loss and, if enabled, fool D.
        outputG = net(inputs)
        net.zero_grad()
        # Squeeze both sides so a (B,1,H,W) output is not broadcast against
        # (B,H,W) labels.
        lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels))
        lossG_G.backward()

        if opt.isAdLoss:
            # Fresh forward pass: the graph above was freed by backward().
            outputG = net(inputs)
            if len(outputG.size()) == 3:
                outputG = outputG.unsqueeze(1)
            outputD = netD(outputG)
            # The generator wants D to label its output as real.
            lossG_D = criterion_bce(outputD, real_label)
            lossG_D.backward()

        optimizer.step()

        running_loss += lossG_G.data[0]
        running_count += 1

        if iteration % opt.showTrainLossEvery == 0:
            first = iteration - running_count + 1
            print('************************************************')
            print('time now is: ' + time.asctime(time.localtime(time.time())))
            print('average running loss for generator between iter [%d, %d] is: %.3f'
                  % (first, iteration, running_loss / running_count))
            print('lossG_G is %.2f respectively.' % (lossG_G.data[0]))
            if opt.isAdLoss:
                print('loss_real is  %s loss_fake is  %s outputD_real is %s'
                      % (loss_real.data[0], loss_fake.data[0], outputD_real.data[0]))
                print('loss for discriminator is %f' % lossD.data[0])
            print('cost time for iter [%d, %d] is %.2f'
                  % (first, iteration, time.time() - start))
            print('************************************************')
            running_loss = 0.0
            running_count = 0
            start = time.time()

        if iteration % opt.saveModelEvery == 0:
            torch.save(net.state_dict(), opt.prefixModelName + '%d.pt' % iteration)
            print('save model: ' + opt.prefixModelName + '%d.pt' % iteration)

        if iteration % opt.decLREvery == 0:
            opt.lr = opt.lr * 0.1
            adjust_learning_rate(optimizer, opt.lr)

        if iteration % opt.showTestPerformanceEvery == 0:
            # Validation: one batch from the held-out H5 generator.
            inputs, labels = data_generator_test.next()
            inputs = np.squeeze(np.transpose(inputs, (0, 3, 1, 2)))
            labels = np.squeeze(labels)
            inputs = Variable(torch.from_numpy(inputs).cuda())
            labels = Variable(torch.from_numpy(labels).cuda())
            outputG = net(inputs)
            lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels))
            print('.......come to validation stage: iter {} ........'.format(iteration))
            print('lossG_G is %.2f.' % (lossG_G.data[0]))

            # Test: predict one full subject volume.
            mr_test_itk = sitk.ReadImage(os.path.join(path_test, 'S1to1_3t.nii.gz'))
            ct_test_itk = sitk.ReadImage(os.path.join(path_test, 'S1to1_7t.nii.gz'))

            mrnp = sitk.GetArrayFromImage(mr_test_itk)
            ctnp = sitk.GetArrayFromImage(ct_test_itk)

            # Percentile rescale, always applied: [25th, 99th] percentile -> [0, 1].
            mu = np.mean(mrnp)
            maxV, minV = np.percentile(mrnp, [99, 25])
            mrnp = (mrnp - minV) / (maxV - minV)

            # Optional extra normalization.
            # NOTE(review): mu is the mean of the raw volume but is subtracted
            # after the rescale above -- confirm this is intended.
            if opt.how2normalize in (1, 2):
                maxV, minV = np.percentile(mrnp, [99, 1])
                print('maxV, %s  minV, %s' % (maxV, minV))
                mrnp = (mrnp - mu) / (maxV - minV)
                print('unique value: %s' % (np.unique(ctnp),))
            elif opt.how2normalize == 3:
                std = np.std(mrnp)
                mrnp = (mrnp - mu) / std
                print('maxV, %s  minV, %s' % (np.ndarray.max(mrnp), np.ndarray.min(mrnp)))

            matFA = mrnp
            matGT = ctnp
            matOut = test_1_subject(matFA, matGT, [64, 64, 5], [64, 64, 1],
                                    [32, 32, 1], net,
                                    opt.prefixModelName + '%d.pt' % iteration)
            print('matOut shape: %s' % (matOut.shape,))
            ct_estimated = matOut

            itspsnr = psnr(ct_estimated, matGT)
            print('pred: %s  shape: %s' % (ct_estimated.dtype, ct_estimated.shape))
            print('gt: %s  shape: %s' % (matGT.dtype, matGT.shape))
            print('psnr = %s' % (itspsnr,))

            # Write the prediction on the reference image's physical grid.
            volout = sitk.GetImageFromArray(ct_estimated)
            volout.SetSpacing(ct_test_itk.GetSpacing())
            volout.SetOrigin(ct_test_itk.GetOrigin())
            volout.SetDirection(ct_test_itk.GetDirection())
            sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(iteration) + '.nii.gz')

    print('Finished Training')
Exemplo n.º 2
0
def main():
    """Train a 2D multi-source regression network, optionally with GAN or WGAN-GP.

    Uses the module-level ``opt``.  The generator is chosen by ``opt.whichNet``
    (1: UNet, 2: ResUNet, 3: UNet_LRes, 4: ResUNet_LRes).  It receives
    ``source``, which is the primary input, concatenated with the extra input
    along dim 1 when ``opt.isMultiSource`` is set.  It also receives
    ``residual_source``, the middle channel of the primary input.

    Per iteration:

    * an optional BCE discriminator step (``opt.isAdLoss``),
    * an optional WGAN-GP critic step (``opt.isWDist``),
    * a generator step on the pixel loss chosen by ``opt.whichLoss``, scaled
      by ``opt.lossBase``.  When enabled, a gradient-difference loss
      (``opt.isGDL``) and adversarial terms are added.

    Periodically it reports losses, saves ``{'epoch': iter + 1, 'model':
    state_dict}`` checkpoints, halves the learning rate, runs a validation
    batch, and predicts one full test subject.  The prediction is written to
    ``opt.prefixPredictedFN + '<iter>.nii.gz'``.
    """
    print(opt)

    netD = Discriminator()
    netD.apply(weights_init)
    netD.cuda()

    optimizerD = optim.Adam(netD.parameters(), lr=1e-3)
    criterion_bce = nn.BCELoss()
    criterion_bce.cuda()

    in_channel = opt.numOfChannel_allSource
    if opt.whichNet == 1:
        net = UNet(in_channel=in_channel, n_classes=1)
    elif opt.whichNet == 2:
        net = ResUNet(in_channel=in_channel, n_classes=1)
    elif opt.whichNet == 3:
        net = UNet_LRes(in_channel=in_channel, n_classes=1)
    elif opt.whichNet == 4:
        net = ResUNet_LRes(in_channel=in_channel, n_classes=1, dp_prob=opt.dropout_rate)
    else:
        # Previously fell through and failed later with an unbound 'net'.
        raise ValueError('unsupported opt.whichNet: %r (expected 1-4)' % (opt.whichNet,))
    net.cuda()
    params = list(net.parameters())
    print('len of params is ')
    print(len(params))
    print('size of params is ')
    print(params[0].size())

    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    criterion_L2 = nn.MSELoss()
    criterion_L1 = nn.L1Loss()
    criterion_RTL1 = RelativeThreshold_RegLoss(opt.RT_th)
    criterion_gdl = gdl_loss(opt.gdlNorm)

    criterion_L2 = criterion_L2.cuda()
    criterion_L1 = criterion_L1.cuda()
    criterion_RTL1 = criterion_RTL1.cuda()
    criterion_gdl = criterion_gdl.cuda()

    path_test = '/home/niedong/DataCT/data_niigz/'
    path_patients_h5 = '/home/niedong/DataCT/h5Data_snorm/trainBatch2D_H5'
    path_patients_h5_test = '/home/niedong/DataCT/h5Data_snorm/valBatch2D_H5'

    data_generator = Generator_2D_slicesV1(path_patients_h5, opt.batchSize, inputKey='dataLPET',
                                           segKey='dataCT', contourKey='dataHPET')
    data_generator_test = Generator_2D_slicesV1(path_patients_h5_test, opt.batchSize,
                                                inputKey='dataLPET', segKey='dataCT',
                                                contourKey='dataHPET')

    mid_slice = opt.numOfChannel_singleSource // 2

    def _prepare_batch(inputs, exinputs, labels):
        """Turn one numpy batch into CUDA Variables (source, residual_source, labels)."""
        inputs = torch.from_numpy(np.squeeze(inputs).astype(float)).float()
        exinputs = torch.from_numpy(np.squeeze(exinputs).astype(float)).float()
        labels = torch.from_numpy(np.squeeze(labels).astype(float)).float()
        if opt.isMultiSource:
            source = torch.cat((inputs, exinputs), dim=1)
        else:
            source = inputs
        # Middle channel of the primary input, passed to the net as a residual.
        residual_source = inputs[:, mid_slice, ...]
        return (Variable(source.cuda()), Variable(residual_source.cuda()),
                Variable(labels.cuda()))

    def _regression_loss(output, target):
        """Pixel loss chosen by opt.whichLoss (1: L1, 2: RT-L1, else L2), times opt.lossBase."""
        output, target = torch.squeeze(output), torch.squeeze(target)
        if opt.whichLoss == 1:
            loss = criterion_L1(output, target)
        elif opt.whichLoss == 2:
            loss = criterion_RTL1(output, target)
        else:
            loss = criterion_L2(output, target)
        return opt.lossBase * loss

    def _weighted_gdl(output, target):
        """Gradient-difference loss times opt.lambda_gdl, with the target given a channel axis."""
        # labels may already be (B,1,H,W) after an adversarial branch unsqueezed them;
        # unsqueezing again would produce a 5-D target.
        if len(target.size()) == 3:
            target = torch.unsqueeze(target, 1)
        return opt.lambda_gdl * criterion_gdl(output, target)

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            net.load_state_dict(checkpoint['model'])
            # 'epoch' is saved as iter + 1, i.e. already the next iteration to run.
            opt.start_epoch = checkpoint['epoch']
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # WGAN gradient signs (loop-invariant).
    one = torch.FloatTensor([1]).cuda()
    mone = one * -1

    # TODO: switch to a DataLoader and train for several epochs over the data.
    running_loss = 0.0
    running_count = 0  # iterations accumulated into running_loss since the last report
    start = time.time()
    for iteration in range(opt.start_epoch, opt.numofIters + 1):
        source, residual_source, labels = _prepare_batch(*data_generator.next())

        # (1a) BCE discriminator step: maximize log(D(x)) + log(1 - D(G(z))).
        if opt.isAdLoss:
            outputG = net(source, residual_source)

            if len(labels.size()) == 3:
                labels = labels.unsqueeze(1)
            outputD_real = netD(labels)

            if len(outputG.size()) == 3:
                outputG = outputG.unsqueeze(1)
            outputD_fake = netD(outputG)

            netD.zero_grad()
            batch_size = source.size(0)
            real_label = Variable(torch.ones(batch_size, 1).cuda())
            loss_real = criterion_bce(outputD_real, real_label)
            loss_real.backward()

            fake_label = Variable(torch.zeros(batch_size, 1).cuda())
            loss_fake = criterion_bce(outputD_fake, fake_label)
            loss_fake.backward()

            lossD = loss_real + loss_fake
            optimizerD.step()

        # (1b) WGAN-GP critic step.
        if opt.isWDist:
            netD.zero_grad()

            outputG = net(source, residual_source)

            if len(labels.size()) == 3:
                labels = labels.unsqueeze(1)
            outputD_real = netD(labels)

            if len(outputG.size()) == 3:
                outputG = outputG.unsqueeze(1)
            outputD_fake = netD(outputG)

            D_real = outputD_real.mean()
            D_real.backward(mone)

            D_fake = outputD_fake.mean()
            D_fake.backward(one)

            gradient_penalty = opt.lambda_D_WGAN_GP * calc_gradient_penalty(
                netD, labels.data, outputG.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake

            optimizerD.step()

        # (2) Generator step: pixel loss (+ GDL), then adversarial terms.
        outputG = net(source, residual_source)
        net.zero_grad()
        lossG_G = _regression_loss(outputG, labels)
        lossG = lossG_G
        if opt.isGDL:
            lossG_gdl = _weighted_gdl(outputG, labels)
            lossG = lossG + lossG_gdl
        # One backward over the sum: both terms share outputG's graph, and a
        # second backward() through it would fail once buffers are freed.
        lossG.backward()

        if opt.isAdLoss:
            # The generator wants D to label its output as real.
            outputG = net(source, residual_source)
            if len(outputG.size()) == 3:
                outputG = outputG.unsqueeze(1)
            outputD = netD(outputG)
            lossG_D = opt.lambda_AD * criterion_bce(outputD, real_label)
            lossG_D.backward()

        if opt.isWDist:
            # Maximize the critic score of generated images.
            outputG = net(source, residual_source)
            if len(outputG.size()) == 3:
                outputG = outputG.unsqueeze(1)
            lossG_D = opt.lambda_AD * netD(outputG).mean()
            lossG_D.backward(mone)

        optimizer.step()

        running_loss += lossG_G.data[0]
        running_count += 1

        if iteration % opt.showTrainLossEvery == 0:
            first = iteration - running_count + 1
            print('************************************************')
            print('time now is: ' + time.asctime(time.localtime(time.time())))
            print('average running loss for generator between iter [%d, %d] is: %.5f'
                  % (first, iteration, running_loss / running_count))
            print('lossG_G is %.5f respectively.' % (lossG_G.data[0]))

            if opt.isGDL:
                print('loss for GDL loss is %f' % lossG_gdl.data[0])

            if opt.isAdLoss:
                print('loss_real is  %s loss_fake is  %s outputD_real is %s'
                      % (loss_real.data[0], loss_fake.data[0], outputD_real.data[0]))
                print('loss for discriminator is %f' % lossD.data[0])

            if opt.isWDist:
                print('loss_real is  %s loss_fake is  %s' % (D_real.data[0], D_fake.data[0]))
                print('loss for discriminator is %f  D cost is %f'
                      % (Wasserstein_D.data[0], D_cost.data[0]))

            print('cost time for iter [%d, %d] is %.2f'
                  % (first, iteration, time.time() - start))
            print('************************************************')
            running_loss = 0.0
            running_count = 0
            start = time.time()

        if iteration % opt.saveModelEvery == 0:
            state = {
                'epoch': iteration + 1,
                'model': net.state_dict()
            }
            torch.save(state, opt.prefixModelName + '%d.pt' % iteration)
            print('save model: ' + opt.prefixModelName + '%d.pt' % iteration)

            if opt.isAdLoss:
                torch.save(netD.state_dict(), opt.prefixModelName + '_net_D%d.pt' % iteration)

        if iteration % opt.decLREvery == 0:
            opt.lr = opt.lr * 0.5
            adjust_learning_rate(optimizer, opt.lr)

        if iteration % opt.showValPerformanceEvery == 0:
            # Validation: one batch from the held-out H5 generator.
            source, residual_source, labels = _prepare_batch(*data_generator_test.next())
            outputG = net(source, residual_source)
            lossG_G = _regression_loss(outputG, labels)
            print('.......come to validation stage: iter {} ........'.format(iteration))
            print('lossG_G is %.5f.' % (lossG_G.data[0]))

            if opt.isGDL:
                lossG_gdl = _weighted_gdl(outputG, labels)
                print('loss for GDL loss is %f' % lossG_gdl.data[0])

        if iteration % opt.showTestPerformanceEvery == 0:
            # Test: predict one full subject volume.
            mr_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub1_sourceCT.nii.gz'))
            ct_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub1_extraCT.nii.gz'))
            hpet_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub1_targetCT.nii.gz'))

            spacing = hpet_test_itk.GetSpacing()
            origin = hpet_test_itk.GetOrigin()
            direction = hpet_test_itk.GetDirection()

            mrnp = sitk.GetArrayFromImage(mr_test_itk)
            ctnp = sitk.GetArrayFromImage(ct_test_itk)
            hpetnp = sitk.GetArrayFromImage(hpet_test_itk)

            # Network inputs default to the raw volumes; each how2normalize
            # branch overrides them.  (Previously, modes other than 4-6 left
            # matLPET / matCT undefined, and modes 1-3 used an undefined mu.)
            matLPET = mrnp
            matCT = ctnp
            mu = np.mean(mrnp)

            if opt.how2normalize in (1, 2):
                maxV, minV = np.percentile(mrnp, [99, 1])
                print('maxV, %s  minV, %s' % (maxV, minV))
                matLPET = (mrnp - mu) / (maxV - minV)
                print('unique value: %s' % (np.unique(ctnp),))
            elif opt.how2normalize == 3:
                std = np.std(mrnp)
                matLPET = (mrnp - mu) / std
                print('maxV, %s  minV, %s' % (np.ndarray.max(matLPET), np.ndarray.min(matLPET)))
            elif opt.how2normalize == 4:
                # Fixed dataset statistics.  Other recorded values, unused here:
                # LPET max 149.366742, mean 0.27593288, std 0.757475;
                # CT max 27279, 99%-ile 1320, min -1023.
                maxPercentLPET = 7.76
                minLPET = 0.00055037
                meanCT = -601.1929
                stdCT = 475.034
                matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET)
                matCT = (ctnp - meanCT) / stdCT
            elif opt.how2normalize == 5:
                meanCT = -601.1929
                stdCT = 475.034
                print('ct, max: %s  ct, min: %s' % (np.amax(ctnp), np.amin(ctnp)))
                matCT = (ctnp - meanCT) / stdCT
            elif opt.how2normalize == 6:
                maxPercentPET, minPercentPET = np.percentile(mrnp, [99.5, 0])
                maxPercentCT, minPercentCT = np.percentile(ctnp, [99.5, 0])
                print('maxPercentPET: %s  minPercentPET: %s  maxPercentCT: %s minPercentCT: %s'
                      % (maxPercentPET, minPercentPET, maxPercentCT, minPercentCT))
                matLPET = (mrnp - minPercentPET) / (maxPercentPET - minPercentPET)
                matCT = (ctnp - minPercentCT) / (maxPercentCT - minPercentCT)

            matFA = matLPET
            matGT = hpetnp
            model_fn = opt.prefixModelName + '%d.pt' % iteration
            print('matFA shape: %s  matGT shape: %s' % (matFA.shape, matGT.shape))
            if opt.isMultiSource:
                matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64],
                                                            [1, 64, 64], [1, 32, 32], net,
                                                            model_fn)
            else:
                matOut = testOneSubject_aver_res(matFA, matGT, [5, 64, 64], [1, 64, 64],
                                                 [1, 32, 32], net, model_fn)
            print('matOut shape: %s' % (matOut.shape,))

            if opt.how2normalize == 6:
                # Undo the PET percentile scaling applied to the input.
                ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET
            else:
                ct_estimated = matOut

            itspsnr = psnr(ct_estimated, matGT)
            print('pred: %s  shape: %s' % (ct_estimated.dtype, ct_estimated.shape))
            print('gt: %s  shape: %s' % (matGT.dtype, matGT.shape))
            print('psnr = %s' % (itspsnr,))

            volout = sitk.GetImageFromArray(ct_estimated)
            volout.SetSpacing(spacing)
            volout.SetOrigin(origin)
            volout.SetDirection(direction)
            sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(iteration) + '.nii.gz')

    print('Finished Training')
Exemplo n.º 3
0
def _build_generator(options):
    """Instantiate the generator selected by ``options.whichNet``.

    1: UNet, 2: ResUNet, 3: UNet_LRes, 4: ResUNet_LRes (the only variant that
    takes ``options.dropout_rate``). Every variant is built with
    ``in_channel=options.numOfChannel_allSource`` and ``n_classes=1``.

    Raises:
        ValueError: for any other ``whichNet`` (previously this fell through and
            failed later with an unbound ``net``).
    """
    in_channel = options.numOfChannel_allSource
    if options.whichNet == 1:
        return UNet(in_channel=in_channel, n_classes=1)
    if options.whichNet == 2:
        return ResUNet(in_channel=in_channel, n_classes=1)
    if options.whichNet == 3:
        return UNet_LRes(in_channel=in_channel, n_classes=1)
    if options.whichNet == 4:
        return ResUNet_LRes(in_channel=in_channel,
                            n_classes=1,
                            dp_prob=options.dropout_rate)
    raise ValueError('unsupported whichNet: {}'.format(options.whichNet))


def _run_generator(net, use_residual, source, residual_source):
    """Forward pass of the generator.

    The residual variants (whichNet 3/4) additionally receive the middle input
    slice; the others only receive the stacked source.
    """
    if use_residual:
        return net(source, residual_source)
    return net(source)


def _as_discriminator_input(x):
    """Insert a channel axis into 3-D tensors ((N, H, W) -> (N, 1, H, W)).

    Tensors of any other rank are returned unchanged.
    """
    if x.dim() == 3:
        return x.unsqueeze(1)
    return x


def _batch_to_cuda(inputs, labels, mid_slice):
    """Convert a numpy (inputs, labels) batch into float32 CUDA Variables.

    Singleton axes are removed with ``np.squeeze`` (note: that also drops the
    batch axis when the batch holds a single sample).

    Returns:
        ``(source, residual_source, labels)``. ``residual_source`` is channel
        ``mid_slice`` of ``source`` (``inputs[:, mid_slice, ...]``).
    """
    inputs = torch.from_numpy(np.squeeze(inputs).astype(np.float32))
    labels = torch.from_numpy(np.squeeze(labels).astype(np.float32))
    residual_source = inputs[:, mid_slice, ...]
    return (Variable(inputs.cuda()),
            Variable(residual_source.cuda()),
            Variable(labels.cuda()))


def _normalize_test_input(noisy_np, clear_np, how2normalize):
    """Normalize the noisy test slice according to the ``how2normalize`` scheme.

    Returns:
        ``(mat_input, pet_range)``. ``mat_input`` is the array fed to the
        network. ``pet_range`` is ``(minPercentPET, maxPercentPET)`` for scheme
        6, where the caller must undo the scaling on the prediction, and
        ``None`` for every other scheme. Scheme 5 and unknown schemes return
        the raw slice (previously they crashed with a NameError).
    """
    if how2normalize in (1, 2):
        maxV, minV = np.percentile(noisy_np, [99, 1])
        print('maxV,', maxV, ' minV, ', minV)
        # `mu` was never defined here before (NameError). Centering on the
        # input mean is assumed -- TODO confirm against the training pipeline.
        mu = np.mean(noisy_np)
        mat_input = (noisy_np - mu) / (maxV - minV)
        print('unique value: ', np.unique(clear_np))
        return mat_input, None
    if how2normalize == 3:
        mu = np.mean(noisy_np)  # see note above: assumed input mean
        std = np.std(noisy_np)
        mat_input = (noisy_np - mu) / std
        print('maxV,', np.ndarray.max(mat_input), ' minV, ',
              np.ndarray.min(mat_input))
        return mat_input, None
    if how2normalize == 4:
        # Fixed dataset constants for the low-dose PET input (from the original code).
        maxPercentLPET = 7.76
        minLPET = 0.00055037
        return (noisy_np - minLPET) / (maxPercentLPET - minLPET), None
    if how2normalize == 5:
        print('ct, max: ', np.amax(clear_np), ' ct, min: ', np.amin(clear_np))
        return noisy_np, None
    if how2normalize == 6:
        maxPercentPET, minPercentPET = np.percentile(noisy_np, [99.5, 0])
        maxPercentCT, minPercentCT = np.percentile(clear_np, [99.5, 0])
        print('maxPercentPET: ', maxPercentPET, ' minPercentPET: ',
              minPercentPET, ' maxPercentCT: ', maxPercentCT,
              'minPercentCT: ', minPercentCT)
        mat_input = (noisy_np - minPercentPET) / (maxPercentPET - minPercentPET)
        return mat_input, (minPercentPET, maxPercentPET)
    return noisy_np, None


def main():
    """Train the generator (optionally with a BCE or WGAN-GP discriminator).

    Every setting is read from the module-level ``opt`` (presumably parsed at
    import time elsewhere in this file). On a schedule driven by ``opt`` the
    loop prints training losses, saves checkpoints, decays learning rates,
    scores one dev batch and runs the model on one test slice, writing the
    prediction as ``.tiff`` and ``.npy``.
    """
    print(opt)

    netD = Discriminator()
    netD.apply(weights_init)
    netD.cuda()

    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_netD)
    criterion_bce = nn.BCELoss()
    criterion_bce.cuda()

    net = _build_generator(opt)
    net.cuda()
    params = list(net.parameters())
    print('len of params is ')
    print(len(params))
    print('size of params is ')
    print(params[0].size())

    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    criterion_L2 = nn.MSELoss().cuda()
    criterion_L1 = nn.L1Loss().cuda()
    criterion_RTL1 = RelativeThreshold_RegLoss(opt.RT_th).cuda()
    criterion_gdl = gdl_loss(opt.gdlNorm).cuda()

    # Pixel-wise generator loss: 1 -> L1, 2 -> RelativeThreshold_RegLoss, else -> L2.
    if opt.whichLoss == 1:
        criterion_rec = criterion_L1
    elif opt.whichLoss == 2:
        criterion_rec = criterion_RTL1
    else:
        criterion_rec = criterion_L2

    use_residual = opt.whichNet in (3, 4)
    mid_slice = opt.numOfChannel_singleSource // 2

    path_test = opt.path_test
    data_generator = Generator_2D_slices(opt.path_train,
                                         opt.batchSize,
                                         inputKey='noisy',
                                         outputKey='clear')
    # The dev targets must be the clean images too (this used to read
    # outputKey='noisy', which scored the prediction against its own input).
    data_generator_dev = Generator_2D_slices(opt.path_dev,
                                             opt.batchSize,
                                             inputKey='noisy',
                                             outputKey='clear')

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            net.load_state_dict(checkpoint['model'])
            # Checkpoints below store 'epoch' as iteration + 1, i.e. already the
            # next iteration to run; adding 1 again skipped an iteration.
            opt.start_epoch = checkpoint["epoch"]
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # 0-dim gradient seeds for the WGAN terms. They must match the 0-dim shape
    # of the .mean() outputs they seed; a shape-[1] tensor is rejected by
    # recent PyTorch versions.
    one = torch.tensor(1.0).cuda()
    mone = -one

    running_loss = 0.0
    start = time.time()
    for iteration in range(opt.start_epoch, opt.numofIters + 1):
        inputs, labels = next(data_generator)
        source, residual_source, labels = _batch_to_cuda(inputs, labels, mid_slice)
        batch_size = source.size(0)

        ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z)))
        if opt.isAdLoss:
            # detach(): this step trains D only; G's gradients are zeroed
            # before its own update, so backpropagating into G here was wasted work.
            outputG = _as_discriminator_input(
                _run_generator(net, use_residual, source, residual_source).detach())
            labels = _as_discriminator_input(labels)

            outputD_real = torch.sigmoid(netD(labels))
            outputD_fake = torch.sigmoid(netD(outputG))
            netD.zero_grad()

            real_label = Variable(torch.ones(batch_size, 1).cuda())
            loss_real = criterion_bce(outputD_real, real_label)
            loss_real.backward()
            # train with fake data
            fake_label = Variable(torch.zeros(batch_size, 1).cuda())
            loss_fake = criterion_bce(outputD_fake, fake_label)
            loss_fake.backward()

            lossD = loss_real + loss_fake
            optimizerD.step()

        if opt.isWDist:
            netD.zero_grad()

            outputG = _as_discriminator_input(
                _run_generator(net, use_residual, source, residual_source).detach())
            labels = _as_discriminator_input(labels)

            outputD_real = netD(labels)
            outputD_fake = netD(outputG)

            D_real = outputD_real.mean()
            D_real.backward(mone)

            D_fake = outputD_fake.mean()
            D_fake.backward(one)

            gradient_penalty = opt.lambda_D_WGAN_GP * calc_gradient_penalty(
                netD, labels.data, outputG.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake

            optimizerD.step()

        ## (2) update G network: minimize the L1/L2 loss, maximize the D(G(x))
        outputG = _run_generator(net, use_residual, source, residual_source)
        net.zero_grad()
        lossG_G = opt.lossBase * criterion_rec(torch.squeeze(outputG),
                                               torch.squeeze(labels))
        # Keep the graph only when the GDL term below backpropagates through
        # this same outputG.
        lossG_G.backward(retain_graph=bool(opt.isGDL))

        if opt.isGDL:
            lossG_gdl = opt.lambda_gdl * criterion_gdl(
                outputG, torch.unsqueeze(torch.squeeze(labels, 1), 1))
            lossG_gdl.backward()

        if opt.isAdLoss:
            # To fool the discriminator, the generated maps are labelled "real".
            outputG = _as_discriminator_input(
                _run_generator(net, use_residual, source, residual_source))
            outputD = torch.sigmoid(netD(outputG))
            lossG_D = opt.lambda_AD * criterion_bce(outputD, real_label)
            lossG_D.backward()

        if opt.isWDist:
            # Maximize D's score on the generated maps (seeded with -1 below).
            outputG = _as_discriminator_input(
                _run_generator(net, use_residual, source, residual_source))
            lossG_D = opt.lambda_AD * netD(outputG).mean()
            lossG_D.backward(mone)

        optimizer.step()

        running_loss = running_loss + lossG_G.data.item()

        if iteration % opt.showTrainLossEvery == 0:
            # Average over the actual reporting window (was hard-coded to 100).
            window = opt.showTrainLossEvery
            print('************************************************')
            print('time now is: ' + time.asctime(time.localtime(time.time())))
            print(
                'average running loss for generator between iter [%d, %d] is: %.5f'
                % (iteration - window + 1, iteration, running_loss / window))

            print('lossG_G is %.5f respectively.' % (lossG_G.data.item()))

            if opt.isGDL:
                print('loss for GDL loss is %f' % lossG_gdl.data.item())

            if opt.isAdLoss:
                print('loss for discriminator is %f' % lossD.data.item())
                print('lossG_D for discriminator is %f' % lossG_D.data.item())

            if opt.isWDist:
                print('loss_real is ',
                      torch.mean(D_real).data.item(), 'loss_fake is ',
                      torch.mean(D_fake).data.item())
                print(
                    'loss for discriminator is %f' % Wasserstein_D.data.item(),
                    ' D cost is %f' % D_cost.data.item())
                print('lossG_D for discriminator is %f' % lossG_D.data.item())

            print('cost time for iter [%d, %d] is %.2f' %
                  (iteration - window + 1, iteration, time.time() - start))
            print('************************************************')
            running_loss = 0.0
            start = time.time()

        if iteration % opt.saveModelEvery == 0:
            state = {'epoch': iteration + 1, 'model': net.state_dict()}
            torch.save(state, opt.prefixModelName + '%d.pt' % iteration)
            print('save model: ' + opt.prefixModelName + '%d.pt' % iteration)

            if opt.isAdLoss or opt.isWDist:
                torch.save(netD.state_dict(),
                           opt.prefixModelName + '_net_D%d.pt' % iteration)

        if iteration % opt.decLREvery == 0:
            opt.lr = opt.lr * opt.lrDecRate
            adjust_learning_rate(optimizer, opt.lr)
            if opt.isAdLoss or opt.isWDist:
                opt.lr_netD = opt.lr_netD * opt.lrDecRate_netD
                adjust_learning_rate(optimizerD, opt.lr_netD)

        if iteration % opt.showValPerformanceEvery == 0:
            # Score one batch of the h5 validation set.
            inputs, labels = next(data_generator_dev)
            source, residual_source, labels = _batch_to_cuda(inputs, labels, mid_slice)
            # Evaluation only: no autograd graph is needed. Note the network is
            # still in train mode (dropout/batch-norm behave as in training).
            with torch.no_grad():
                outputG = _run_generator(net, use_residual, source, residual_source)
                lossG_G = opt.lossBase * criterion_rec(torch.squeeze(outputG),
                                                       torch.squeeze(labels))
                print('.......come to validation stage: iter {}'.format(iteration),
                      '........')
                print('lossG_G is %.5f.' % (lossG_G.data.item()))

                if opt.isGDL:
                    lossG_gdl = criterion_gdl(
                        outputG, torch.unsqueeze(torch.squeeze(labels, 1), 1))
                    print('loss for GDL loss is %f' % lossG_gdl.data.item())

        if iteration % opt.showTestPerformanceEvery == 0:
            # Run the model on the first slice of the test arrays.
            noisy_np = np.load(os.path.join(path_test, opt.test_input_file_name))
            noisy_np = noisy_np[0]
            noisy_np = noisy_np.reshape(1, noisy_np.shape[0], noisy_np.shape[1])

            clear_np = np.load(os.path.join(path_test, opt.test_label_file_name))
            clear_np = clear_np[0]
            clear_np = clear_np.reshape(1, clear_np.shape[0], clear_np.shape[1])

            matFA, pet_range = _normalize_test_input(noisy_np, clear_np,
                                                     opt.how2normalize)
            matGT = clear_np

            print('matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape)
            # NOTE(review): the checkpoint path passed here only exists when
            # saveModelEvery divides this iteration -- confirm how
            # testOneSubject_aver_res uses it.
            matOut = testOneSubject_aver_res(
                matFA, matGT, [2, 64, 64], [1, 64, 64], [1, 8, 8], net,
                opt.prefixModelName + '%d.pt' % iteration)
            print('matOut shape: ', matOut.shape)
            if pet_range is None:
                clear_estimated = matOut
            else:
                # Undo the scheme-6 input scaling on the prediction.
                minPercentPET, maxPercentPET = pet_range
                clear_estimated = matOut * (maxPercentPET -
                                            minPercentPET) + minPercentPET

            itspsnr = psnr(clear_estimated, matGT)
            clear_estimated = clear_estimated.reshape(clear_estimated.shape[1],
                                                      clear_estimated.shape[2])

            print('pred: ', clear_estimated.dtype, ' shape: ',
                  clear_estimated.shape)
            print('gt: ', clear_np.dtype, ' shape: ', clear_np.shape)
            print('psnr = ', itspsnr)
            volout = sitk.GetImageFromArray(clear_estimated)
            volout = sitk.Cast(
                sitk.RescaleIntensity(volout,
                                      outputMinimum=0,
                                      outputMaximum=65535), sitk.sitkUInt16)
            sitk.WriteImage(
                volout, opt.prefixPredictedFN + '{}'.format(iteration) + '.tiff')
            np.save(opt.prefixPredictedFN + '{}'.format(iteration) + '.npy',
                    clear_estimated)

    print('Finished Training')