示例#1
0
def main():    
    print opt    

    # load/define networks
    netG = Generator(opt.finalSize*opt.finalSize, opt.ngf, gpu_ids=opt.gpuID)
#     netG = define_G(opt.fineSize*opt.fineSize, opt.ngf,opt.init_type, opt.gpuID)
    optimizerG = optim.Adam(netG.parameters(),lr=opt.lr_G)                            
    netG.apply(weights_init)
    netG.cuda()
    
    if opt.isAdLoss:
        netD = Discriminator(opt.ndf, opt.gpuID)
        netD.apply(weights_init)
        netD.cuda()
        optimizerD = optim.Adam(netD.parameters(),lr=opt.lr_D)
        
    if opt.isWDist:
        netD = Discriminator(opt.ndf, opt.gpuID)
        netD.apply(weights_init)
        netD.cuda()
        optimizerD = optim.Adam(netD.parameters(),lr=opt.lr_D)
        
    criterion_bce=nn.BCELoss()
    criterion_bce.cuda()
    
    params = list(netG.parameters())
    print('len of params is ')
    print(len(params))
    print('size of params is ')
    print(params[0].size())

    
    
    path_test ='/shenlab/lab_stor5/dongnie/3T7T/3t7tHistData/'
    path_patients_h5 = '/shenlab/lab_stor5/dongnie/3T7T/histH5Data'
    path_patients_h5_test ='/shenlab/lab_stor5/dongnie/3T7T/histH5Data'


    data_generator = Generator_2D_slices_oneKey(path_patients_h5,opt.batchSize, inputKey='data7T')
    data_generator_test = Generator_2D_slices_oneKey(path_patients_h5_test,opt.batchSize, inputKey='data7T')
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            netG.load_state_dict(checkpoint['model'])
            opt.start_epoch = 100000
            opt.start_epoch = checkpoint["epoch"] + 1
            # net.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
########### We'd better use dataloader to load a lot of data,and we also should train several epoches############### 
########### We'd better use dataloader to load a lot of data,and we also should train several epoches############### 
    running_loss = 0.0
    start = time.time()
    for iter in range(opt.start_epoch, opt.numofIters+1):
        #print('iter %d'%iter)
        
        # we use a 128-dim vector as the noise of the input
        noise = torch.randn(opt.batchSize, 128)
        if opt.gpuID!=None:
            noise = noise.cuda()

        labels = data_generator.next() # labels means real images
        labels = np.squeeze(labels) #64x64
        labels = np.resize(labels, [opt.batchSize, opt.outputSizeOfG,opt.outputSizeOfG]) # here, we take 32 as output size
#         labels = labels.astype(int)


        inputs = noise

        labels = labels.astype(float)
        labels = torch.from_numpy(labels)
        labels = labels.float()

        source = inputs

        source = source.cuda()
#         residual_source = residual_source.cuda()
        labels = labels.cuda()
        #we should consider different data to train
        
        source, labels = Variable(source), Variable(labels)

        
        ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z)))
        if opt.isAdLoss:
#             outputG = net(source,residual_source) #5x64x64->1*64x64

            if len(labels.size())==3:
                labels = labels.unsqueeze(1)
            # print 'labels.shape: ',labels.shape
            outputD_real = netD(labels)
            outputD_real = F.sigmoid(outputD_real)

            outputG = netG(source) #1x64x64->1*64x64
            # print 'outputG.shape: ',outputG.shape
            if len(outputG.size())==3:
                outputG = outputG.unsqueeze(1)
            outputD_fake = netD(outputG)
            outputD_fake = F.sigmoid(outputD_fake)

            ## update D network
            netD.zero_grad()
            batch_size = inputs.size(0)
            real_label = torch.ones(batch_size,1)
            real_label = real_label.cuda()
            #print(real_label.size())
            real_label = Variable(real_label)
            # print 'outputD_real.shape: ',outputD_real.shape,' outputD_fake.size(): ',outputD_fake.shape
            loss_real = criterion_bce(outputD_real,real_label)
            loss_real.backward()
            #train with fake data
            fake_label = torch.zeros(batch_size,1)

            fake_label = fake_label.cuda()
            fake_label = Variable(fake_label)
            loss_fake = criterion_bce(outputD_fake,fake_label)
            loss_fake.backward()
            
            lossD = loss_real + loss_fake

            optimizerD.step()
            
        if opt.isWDist:
            one = torch.FloatTensor([1])
            mone = one * -1
            one = one.cuda()
            mone = mone.cuda()
            
            netD.zero_grad()
            
            outputG = netG(source) #5x64x64->1*64x64
            
            if len(labels.size())==3:
                labels = labels.unsqueeze(1)
                
            outputD_real = netD(labels)
            
            if len(outputG.size())==3:
                outputG = outputG.unsqueeze(1)
                
            outputD_fake = netD(outputG)

            
            batch_size = inputs.size(0)
            
            D_real = outputD_real.mean()
            # print D_real
            D_real.backward(mone)
        
        
            D_fake = outputD_fake.mean()
            D_fake.backward(one)
        
            gradient_penalty = opt.lambda_D_WGAN_GP*calc_gradient_penalty(netD, labels.data, outputG.data)
            gradient_penalty.backward()
            
            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            
            optimizerD.step()
        
        
        ## (2) update G network: maximize the D(G(x))
        netG.zero_grad()
        
        if opt.isAdLoss:
            #we want to fool the discriminator, thus we pretend the label here to be real. Actually, we can explain from the 
            #angel of equation (note the max and min difference for generator and discriminator)
            outputG = netG(source) #5x64x64->1*64x64
            
            if len(outputG.size())==3:
                outputG = outputG.unsqueeze(1)
            
            outputD = netD(outputG)
            outputD = F.sigmoid(outputD)
            lossG_D = opt.lambda_AD*criterion_bce(outputD,real_label) #note, for generator, the label for outputG is real, because the G wants to confuse D
            lossG_D.backward()
            
        if opt.isWDist:
            #we want to fool the discriminator, thus we pretend the label here to be real. Actually, we can explain from the 
            #angel of equation (note the max and min difference for generator and discriminator)
            #outputG = net(inputs)
            outputG = netG(source) #5x64x64->1*64x64
            
            if len(outputG.size())==3:
                outputG = outputG.unsqueeze(1)
            
            outputD_fake = netD(outputG)

            outputD_fake = outputD_fake.mean()
            
            lossG_D = opt.lambda_AD*outputD_fake.mean() #note, for generator, the label for outputG is real, because the G wants to confuse D
            lossG_D.backward(mone)
        
        #for other losses, we can define the loss function following the pytorch tutorial
        
        optimizerG.step() #update network parameters

        
        if iter%opt.showTrainLossEvery==0: #print every 2000 mini-batches
            print '************************************************'
            print 'time now is: ' + time.asctime(time.localtime(time.time()))
#             print 'running loss is ',running_loss
#             print 'average running loss for generator between iter [%d, %d] is: %.5f'%(iter - 100 + 1,iter,running_loss/100)

            if opt.isAdLoss:
                print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',torch.mean(outputD_real).data[0], 'outputD_fake is',torch.mean(outputD_fake).data[0]
                print 'lossG_D is ', lossG_D.data[0]
                print 'loss for discriminator is %f'%lossD.data[0]
                
            if opt.isWDist:
                print 'loss_real is ',D_real.data[0],'loss_fake is ',D_fake.data[0], 'outputD_real is',torch.mean(outputD_real).data[0], 'outputD_fake is',torch.mean(outputD_fake).data[0]
                print 'lossG_D is ', lossG_D.data[0]
                print 'loss for discriminator is %f'%Wasserstein_D.data[0], ' D cost is %f'%D_cost
            
            print 'cost time for iter [%d, %d] is %.2f'%(iter - 100 + 1,iter, time.time()-start)
            print '************************************************'
            running_loss = 0.0
            start = time.time()
        if iter%opt.saveModelEvery==0: #save the model
            state = {
                'epoch': iter+1,
                'model': netG.state_dict()
            }
            torch.save(state, opt.prefixModelName+'%d.pt'%iter)
            print 'save model: '+opt.prefixModelName+'%d.pt'%iter

            if opt.isAdLoss or opt.isWDist:
                torch.save(netD.state_dict(), opt.prefixModelName+'netD_%d.pt'%iter)
        if iter%opt.decLREvery==0:
            opt.lr_G = opt.lr_G*0.5
            adjust_learning_rate(optimizerG, opt.lr_G)
            if opt.isAdLoss or opt.isWDist:
                opt.lr_D = opt.lr_D*0.2
                adjust_learning_rate(optimizerD, opt.lr_D)
                
        if iter%opt.showValPerformanceEvery==0: #test one subject
            # to test on the validation dataset in the format of h5
            
            # we use a 128-dim vector as the noise of the input
            noise = torch.randn(opt.batchSize, 128)
            if opt.gpuID!=None:
                noise = noise.cuda()
            inputs = noise
             
#             inputs,exinputs,labels = data_generator_test.next()
            labels = data_generator_test.next()
            labels = np.squeeze(labels)

            labels = torch.from_numpy(labels)
            labels = labels.float()

            source = inputs
            source = source.cuda()
            source = Variable(source)
#             residual_source = residual_source.cuda()
            labels = labels.cuda()
            labels = Variable(labels)
            
            if len(labels.size())==3:
                labels = labels.unsqueeze(1)
            outputD_real = netD(labels)
            
            outputG = netG(source) #noise -> img
            if len(outputG.size())==3:
                outputG = outputG.unsqueeze(1)
            outputD_fake = netD(outputG)
            print 'outputD_real is ', torch.mean(outputD_real).data[0], ' outputD_real is ', torch.mean(outputD_real).data[0]


        if iter % opt.showTestPerformanceEvery == 0:  # test one subject
            mr_test_itk=sitk.ReadImage(os.path.join(path_test,'S10to1_3t.nii.gz'))


            spacing = mr_test_itk.GetSpacing()
            origin = mr_test_itk.GetOrigin()
            direction = mr_test_itk.GetDirection()

            mrnp=sitk.GetArrayFromImage(mr_test_itk)

            ##### specific normalization #####
            mu = np.mean(mrnp)

            #for training data in pelvicSeg
            if opt.how2normalize == 1:
                maxV, minV = np.percentile(mrnp, [99 ,1])
                print 'maxV,',maxV,' minV, ',minV
                mrnp = (mrnp-mu)/(maxV-minV)
                print 'unique value: ',np.unique(mrnp)

            #for training data in pelvicSeg
            if opt.how2normalize == 2:
                maxV, minV = np.percentile(mrnp, [99 ,1])
                print 'maxV,',maxV,' minV, ',minV
                mrnp = (mrnp-mu)/(maxV-minV)
                print 'unique value: ',np.unique(mrnp)

            #for training data in pelvicSegRegH5
            if opt.how2normalize== 3:
                std = np.std(mrnp)
                mrnp = (mrnp - mu)/std
                print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)

            if opt.how2normalize == 4:
                maxLPET = 149.366742
                maxPercentLPET = 7.76
                minLPET = 0.00055037
                meanLPET = 0.27593288
                stdLPET = 0.75747500

                matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET)

            if opt.how2normalize == 5:
                # for rsCT
                maxCT = 27279
                maxPercentCT = 1320
                minCT = -1023
                meanCT = -601.1929
                stdCT = 475.034

                print 'ct, max: ', np.amax(mrnp), ' ct, min: ', np.amin(mrnp)

                matLPET = mrnp


            if opt.how2normalize == 6:
                maxPercentPET, minPercentPET = np.percentile(mrnp, [99.5, 0])
                matLPET = (mrnp - minPercentPET) / (maxPercentPET - minPercentPET)
 
            matFA = matLPET
#                 matGT = hpetnp

            print 'matFA shape: ',matFA.shape
            pred = testOneSubject4Cla(matFA,[1,64,64],[1,32,32],netD, opt.prefixModelName+'%d.pt'%iter)
            print 'predicted result is ', pred

        
    print('Finished Training')
示例#2
0
def main():
    """Train an image-to-image generator, optionally with an adversarial loss.

    Everything is driven by the module-level ``opt`` namespace. The generator
    is one of four U-Net variants chosen by ``opt.whichNet``; variants 3 and 4
    additionally receive the middle input slice as a residual source. Each
    iteration:

    1. optionally updates the discriminator ``netD`` (``opt.isAdLoss`` for
       sigmoid + BCE, ``opt.isWDist`` for a Wasserstein loss with the
       gradient penalty from ``calc_gradient_penalty``);
    2. updates the generator on a reconstruction loss selected by
       ``opt.whichLoss`` (1: L1, 2: relative-threshold L1, otherwise L2),
       plus the optional GDL term (``opt.isGDL``) and adversarial term.

    Periodically it logs the losses, saves checkpoints, decays the learning
    rates, evaluates one validation batch, and runs
    ``testOneSubject_aver_res`` on one test image whose prediction is written
    to ``opt.prefixPredictedFN + '<iter>.tiff'`` and ``'.npy'``.

    Raises:
        ValueError: if ``opt.whichNet`` is not 1, 2, 3 or 4.
    """
    print(opt)

    netD = Discriminator()
    netD.apply(weights_init)
    netD.cuda()

    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_netD)
    criterion_bce = nn.BCELoss()
    criterion_bce.cuda()

    if opt.whichNet == 1:
        net = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
    elif opt.whichNet == 2:
        net = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
    elif opt.whichNet == 3:
        net = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1)
    elif opt.whichNet == 4:
        net = ResUNet_LRes(in_channel=opt.numOfChannel_allSource,
                           n_classes=1,
                           dp_prob=opt.dropout_rate)
    else:
        # Previously this fell through to a NameError on ``net``.
        raise ValueError('unsupported opt.whichNet: {!r} (expected 1, 2, 3 '
                         'or 4)'.format(opt.whichNet))
    net.cuda()
    # Variants 3 and 4 take the residual source as a second input.
    uses_residual = opt.whichNet in (3, 4)

    params = list(net.parameters())
    print('len of params is ')
    print(len(params))
    print('size of params is ')
    print(params[0].size())

    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    criterion_L2 = nn.MSELoss().cuda()
    criterion_L1 = nn.L1Loss().cuda()
    criterion_RTL1 = RelativeThreshold_RegLoss(opt.RT_th).cuda()
    criterion_gdl = gdl_loss(opt.gdlNorm).cuda()

    if opt.whichLoss == 1:
        criterion_rec = criterion_L1
    elif opt.whichLoss == 2:
        criterion_rec = criterion_RTL1
    else:
        criterion_rec = criterion_L2

    def run_generator(source, residual_source):
        # Single place that knows which variants take the residual input.
        if uses_residual:
            return net(source, residual_source)
        return net(source)

    def add_channel_dim(tensor):
        # Insert a channel axis into 3-D batches; leave others untouched.
        if len(tensor.size()) == 3:
            return tensor.unsqueeze(1)
        return tensor

    def to_float_tensor(array):
        # Squeeze singleton axes and convert a numpy batch to a float32 tensor.
        return torch.from_numpy(np.squeeze(array).astype(float)).float()

    def reconstruction_loss(output, target):
        return opt.lossBase * criterion_rec(torch.squeeze(output),
                                            torch.squeeze(target))

    def gdl_target(target):
        # Normalise the target to an explicit channel axis at dim 1.
        return torch.unsqueeze(torch.squeeze(target, 1), 1)

    path_test = opt.path_test
    path_train = opt.path_train
    path_dev = opt.path_dev
    data_generator = Generator_2D_slices(path_train,
                                         opt.batchSize,
                                         inputKey='noisy',
                                         outputKey='clear')
    # NOTE(review): the dev generator uses 'noisy' as its output key, so the
    # validation target equals the input; the training generator uses 'clear'.
    # Confirm whether this is intentional.
    data_generator_dev = Generator_2D_slices(path_dev,
                                             opt.batchSize,
                                             inputKey='noisy',
                                             outputKey='noisy')

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            net.load_state_dict(checkpoint['model'])
            opt.start_epoch = checkpoint["epoch"] + 1
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # TODO: use a DataLoader and train for several epochs over the data.
    running_loss = 0.0
    start = time.time()
    for iteration in range(opt.start_epoch, opt.numofIters + 1):
        inputs, labels = next(data_generator)
        inputs = to_float_tensor(inputs)
        labels = to_float_tensor(labels)

        mid_slice = opt.numOfChannel_singleSource // 2
        residual_source = inputs[:, mid_slice, ...].cuda()
        source = inputs.cuda()
        labels = add_channel_dim(labels.cuda())

        ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z)))
        if opt.isAdLoss:
            netD.zero_grad()

            # detach(): the discriminator step must not back-propagate into
            # the generator (its gradients are zeroed before the G step
            # anyway, so this only saves work).
            outputG = add_channel_dim(
                run_generator(source, residual_source)).detach()

            outputD_real = torch.sigmoid(netD(labels))
            outputD_fake = torch.sigmoid(netD(outputG))

            batch_size = inputs.size(0)
            real_label = torch.ones(batch_size, 1).cuda()
            fake_label = torch.zeros(batch_size, 1).cuda()

            # train with real data
            loss_real = criterion_bce(outputD_real, real_label)
            loss_real.backward()
            # train with fake data
            loss_fake = criterion_bce(outputD_fake, fake_label)
            loss_fake.backward()

            lossD = loss_real + loss_fake
            optimizerD.step()

        if opt.isWDist:
            netD.zero_grad()

            outputG = add_channel_dim(
                run_generator(source, residual_source)).detach()

            outputD_real = netD(labels)
            outputD_fake = netD(outputG)

            # Minimise D(fake) - D(real). The sign is applied to the scalar
            # itself: ``.mean()`` yields a 0-dim tensor, so the former
            # ``backward(mone)`` with a shape-[1] gradient is a shape
            # mismatch on current PyTorch.
            D_real = outputD_real.mean()
            (-D_real).backward()

            D_fake = outputD_fake.mean()
            D_fake.backward()

            gradient_penalty = opt.lambda_D_WGAN_GP * calc_gradient_penalty(
                netD, labels.data, outputG.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake

            optimizerD.step()

        ## (2) update G network: minimize the L1/L2 loss, maximize D(G(x))
        outputG = run_generator(source, residual_source)
        net.zero_grad()
        lossG_G = reconstruction_loss(outputG, labels)
        # The graph only has to survive if the GDL term back-propagates
        # through the same forward pass.
        lossG_G.backward(retain_graph=bool(opt.isGDL))

        if opt.isGDL:
            lossG_gdl = opt.lambda_gdl * criterion_gdl(outputG,
                                                       gdl_target(labels))
            lossG_gdl.backward()

        if opt.isAdLoss:
            # The generator wants the discriminator to call its output real,
            # so the fake batch is scored against the "real" target.
            outputG = add_channel_dim(run_generator(source, residual_source))
            outputD = torch.sigmoid(netD(outputG))
            lossG_D = opt.lambda_AD * criterion_bce(outputD, real_label)
            lossG_D.backward()

        if opt.isWDist:
            # Same idea for the Wasserstein critic: push D(G(x)) up.
            outputG = add_channel_dim(run_generator(source, residual_source))
            outputD_fake = netD(outputG).mean()
            lossG_D = opt.lambda_AD * outputD_fake
            (-lossG_D).backward()

        optimizer.step()  # update generator parameters

        running_loss = running_loss + lossG_G.item()

        if iteration % opt.showTrainLossEvery == 0:
            window = opt.showTrainLossEvery
            print('************************************************')
            print('time now is: ' + time.asctime(time.localtime(time.time())))
            # Average over the real logging window, not a hard-coded 100.
            print(
                'average running loss for generator between iter [%d, %d] is: %.5f'
                % (iteration - window + 1, iteration, running_loss / window))

            print('lossG_G is %.5f respectively.' % (lossG_G.item()))

            if opt.isGDL:
                print('loss for GDL loss is %f' % lossG_gdl.item())

            if opt.isAdLoss:
                print('loss for discriminator is %f' % lossD.item())
                print('lossG_D for discriminator is %f' % lossG_D.item())

            if opt.isWDist:
                print('loss_real is ', D_real.item(), 'loss_fake is ',
                      D_fake.item())
                print('loss for discriminator is %f' % Wasserstein_D.item(),
                      ' D cost is %f' % D_cost.item())
                print('lossG_D for discriminator is %f' % lossG_D.item())

            print('cost time for iter [%d, %d] is %.2f' %
                  (iteration - window + 1, iteration, time.time() - start))
            print('************************************************')
            running_loss = 0.0
            start = time.time()

        if iteration % opt.saveModelEvery == 0:  # save the model
            state = {'epoch': iteration + 1, 'model': net.state_dict()}
            torch.save(state, opt.prefixModelName + '%d.pt' % iteration)
            print('save model: ' + opt.prefixModelName + '%d.pt' % iteration)

            if opt.isAdLoss or opt.isWDist:
                torch.save(netD.state_dict(),
                           opt.prefixModelName + '_net_D%d.pt' % iteration)

        if iteration % opt.decLREvery == 0:
            opt.lr = opt.lr * opt.lrDecRate
            adjust_learning_rate(optimizer, opt.lr)
            if opt.isAdLoss or opt.isWDist:
                opt.lr_netD = opt.lr_netD * opt.lrDecRate_netD
                adjust_learning_rate(optimizerD, opt.lr_netD)

        if iteration % opt.showValPerformanceEvery == 0:
            # Evaluate one batch of the h5 validation set. No gradients are
            # needed, so skip building the autograd graph.
            inputs, labels = next(data_generator_dev)
            inputs = to_float_tensor(inputs)
            labels = to_float_tensor(labels)

            mid_slice = opt.numOfChannel_singleSource // 2
            residual_source = inputs[:, mid_slice, ...].cuda()
            source = inputs.cuda()
            labels = labels.cuda()

            with torch.no_grad():
                outputG = run_generator(source, residual_source)
                lossG_G = reconstruction_loss(outputG, labels)
                print('.......come to validation stage: iter {}'.format(
                    iteration), '........')
                print('lossG_G is %.5f.' % (lossG_G.item()))

                if opt.isGDL:
                    lossG_gdl = criterion_gdl(outputG, gdl_target(labels))
                    print('loss for GDL loss is %f' % lossG_gdl.item())

        if iteration % opt.showTestPerformanceEvery == 0:  # test one subject
            noisy_np = np.load(
                os.path.join(path_test, opt.test_input_file_name))[0]
            noisy_np = noisy_np.reshape(1, noisy_np.shape[0],
                                        noisy_np.shape[1])

            clear_np = np.load(
                os.path.join(path_test, opt.test_label_file_name))[0]
            clear_np = clear_np.reshape(1, clear_np.shape[0],
                                        clear_np.shape[1])

            # Every branch assigns matFA. Previously modes 1-3 referenced an
            # undefined ``mu`` and left the array handed to the test
            # undefined as well.
            mu = np.mean(noisy_np)
            mode = opt.how2normalize

            if mode in (1, 2):
                # mean-centre, scale by the 1st-99th percentile range
                maxV, minV = np.percentile(noisy_np, [99, 1])
                print('maxV,', maxV, ' minV, ', minV)
                matFA = (noisy_np - mu) / (maxV - minV)
                print('unique value: ', np.unique(clear_np))
            elif mode == 3:
                # z-score
                std = np.std(noisy_np)
                matFA = (noisy_np - mu) / std
                print('maxV,', np.ndarray.max(matFA), ' minV, ',
                      np.ndarray.min(matFA))
            elif mode == 4:
                # min-max scaling with fixed constants
                maxPercentLPET = 7.76
                minLPET = 0.00055037
                matFA = (noisy_np - minLPET) / (maxPercentLPET - minLPET)
            elif mode == 6:
                # min-max scaling with the input's own 0th/99.5th percentiles
                maxPercentPET, minPercentPET = np.percentile(
                    noisy_np, [99.5, 0])
                maxPercentCT, minPercentCT = np.percentile(clear_np, [99.5, 0])
                print('maxPercentPET: ', maxPercentPET, ' minPercentPET: ',
                      minPercentPET, ' maxPercentCT: ', maxPercentCT,
                      'minPercentCT: ', minPercentCT)
                matFA = (noisy_np - minPercentPET) / (maxPercentPET -
                                                      minPercentPET)
            else:
                # Mode 5 and any unlisted mode pass the input through as is.
                if mode == 5:
                    print('ct, max: ', np.amax(clear_np), ' ct, min: ',
                          np.amin(clear_np))
                matFA = noisy_np

            matGT = clear_np

            print('matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape)
            matOut = testOneSubject_aver_res(
                matFA, matGT, [2, 64, 64], [1, 64, 64], [1, 8, 8], net,
                opt.prefixModelName + '%d.pt' % iteration)
            print('matOut shape: ', matOut.shape)
            if mode == 6:
                # undo the min-max scaling applied to the input
                clear_estimated = matOut * (maxPercentPET -
                                            minPercentPET) + minPercentPET
            else:
                clear_estimated = matOut

            itspsnr = psnr(clear_estimated, matGT)
            clear_estimated = clear_estimated.reshape(clear_estimated.shape[1],
                                                      clear_estimated.shape[2])

            print('pred: ', clear_estimated.dtype, ' shape: ',
                  clear_estimated.shape)
            # Report the ground truth's own shape (this used to print the
            # prediction's shape under the "gt" label).
            print('gt: ', clear_np.dtype, ' shape: ', clear_np.shape)
            print('psnr = ', itspsnr)
            volout = sitk.GetImageFromArray(clear_estimated)
            volout = sitk.Cast(
                sitk.RescaleIntensity(volout,
                                      outputMinimum=0,
                                      outputMaximum=65535), sitk.sitkUInt16)
            sitk.WriteImage(
                volout,
                opt.prefixPredictedFN + '{}'.format(iteration) + '.tiff')
            np.save(opt.prefixPredictedFN + '{}'.format(iteration) + '.npy',
                    clear_estimated)

    print('Finished Training')
示例#3
0
def main():    
    print opt    
        
#     prefixModelName = 'Regressor_1112_'
#     prefixPredictedFN = 'preSub1_1112_'
#     showTrainLossEvery = 100
#     lr = 1e-4
#     showTestPerformanceEvery = 2000
#     saveModelEvery = 2000
#     decLREvery = 40000
#     numofIters = 200000
#     how2normalize = 0


    netD = Discriminator()
    netD.apply(weights_init)
    netD.cuda()
    
    optimizerD = optim.Adam(netD.parameters(),lr=1e-3)
    criterion_bce=nn.BCELoss()
    criterion_bce.cuda()
    
    #net=UNet()
    if opt.whichNet==1:
        net = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
    elif opt.whichNet==2:
        net = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
    elif opt.whichNet==3:
        net = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1)
    elif opt.whichNet==4:
        net = ResUNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1, dp_prob = opt.dropout_rate)
    #net.apply(weights_init)
    net.cuda()
    params = list(net.parameters())
    print('len of params is ')
    print(len(params))
    print('size of params is ')
    print(params[0].size())
    
 
    
    optimizer = optim.Adam(net.parameters(),lr=opt.lr)
    criterion_L2 = nn.MSELoss()
    criterion_L1 = nn.L1Loss()
    criterion_RTL1 = RelativeThreshold_RegLoss(opt.RT_th)
    criterion_gdl = gdl_loss(opt.gdlNorm)
    #criterion = nn.CrossEntropyLoss()
#     criterion = nn.NLLLoss2d()
    
    given_weight = torch.cuda.FloatTensor([1,4,4,2])
    
    criterion_3d = CrossEntropy3d(weight=given_weight)
    
    criterion_3d = criterion_3d.cuda()
    criterion_L2 = criterion_L2.cuda()
    criterion_L1 = criterion_L1.cuda()
    criterion_RTL1 = criterion_RTL1.cuda()
    criterion_gdl = criterion_gdl.cuda()
    
    #inputs=Variable(torch.randn(1000,1,32,32)) #here should be tensor instead of variable
    #targets=Variable(torch.randn(1000,10,1,1)) #here should be tensor instead of variable
#     trainset=data_utils.TensorDataset(inputs, targets)
#     trainloader = data_utils.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
#     inputs=torch.randn(1000,1,32,32)
#     targets=torch.LongTensor(1000)
    
    path_test ='/home/niedong/DataCT/data_niigz/'
    path_patients_h5 = '/home/niedong/DataCT/h5Data_snorm/trainBatch2D_H5'
    path_patients_h5_test ='/home/niedong/DataCT/h5Data_snorm/valBatch2D_H5'
#     batch_size=10
    #data_generator = Generator_2D_slices(path_patients_h5,opt.batchSize,inputKey='data3T',outputKey='data7T')
    #data_generator_test = Generator_2D_slices(path_patients_h5_test,opt.batchSize,inputKey='data3T',outputKey='data7T')

    data_generator = Generator_2D_slicesV1(path_patients_h5,opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET')
    data_generator_test = Generator_2D_slicesV1(path_patients_h5_test,opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET')
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            net.load_state_dict(checkpoint['model'])
            opt.start_epoch = 100000
            opt.start_epoch = checkpoint["epoch"] + 1
            # net.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
########### We'd better use dataloader to load a lot of data,and we also should train several epoches############### 
########### We'd better use dataloader to load a lot of data,and we also should train several epoches############### 
    running_loss = 0.0
    start = time.time()
    for iter in range(opt.start_epoch, opt.numofIters+1):
        #print('iter %d'%iter)
        
        inputs, exinputs, labels = data_generator.next()
#         xx = np.transpose(inputs,(5,64,64))
#         inputs = np.transpose(inputs,(0,3,1,2))
        inputs = np.squeeze(inputs) #5x64x64
        # exinputs = np.transpose(exinputs,(0,3,1,2))
        exinputs = np.squeeze(exinputs) #5x64x64
#         print 'shape is ....',inputs.shape
        labels = np.squeeze(labels) #64x64
#         labels = labels.astype(int)

        inputs = inputs.astype(float)
        inputs = torch.from_numpy(inputs)
        inputs = inputs.float()
        exinputs = exinputs.astype(float)
        exinputs = torch.from_numpy(exinputs)
        exinputs = exinputs.float()
        labels = labels.astype(float)
        labels = torch.from_numpy(labels)
        labels = labels.float()
        #print type(inputs), type(exinputs)
        if opt.isMultiSource:
            source = torch.cat((inputs, exinputs),dim=1)
        else:
            source = inputs
        #source = inputs
        mid_slice = opt.numOfChannel_singleSource//2
        residual_source = inputs[:, mid_slice, ...]
        #inputs = inputs.cuda()
        #exinputs = exinputs.cuda()
        source = source.cuda()
        residual_source = residual_source.cuda()
        labels = labels.cuda()
        #we should consider different data to train
        
        #wrap them into Variable
        source, residual_source, labels = Variable(source),Variable(residual_source), Variable(labels)
        #inputs, exinputs, labels = Variable(inputs),Variable(exinputs), Variable(labels)
        
        ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z)))
        if opt.isAdLoss:
            outputG = net(source,residual_source) #5x64x64->1*64x64
            
            if len(labels.size())==3:
                labels = labels.unsqueeze(1)
                
            outputD_real = netD(labels)
            
            if len(outputG.size())==3:
                outputG = outputG.unsqueeze(1)
                
            outputD_fake = netD(outputG)
            netD.zero_grad()
            batch_size = inputs.size(0)
            real_label = torch.ones(batch_size,1)
            real_label = real_label.cuda()
            #print(real_label.size())
            real_label = Variable(real_label)
            #print(outputD_real.size())
            loss_real = criterion_bce(outputD_real,real_label)
            loss_real.backward()
            #train with fake data
            fake_label = torch.zeros(batch_size,1)
    #         fake_label = torch.FloatTensor(batch_size)
    #         fake_label.data.resize_(batch_size).fill_(0)
            fake_label = fake_label.cuda()
            fake_label = Variable(fake_label)
            loss_fake = criterion_bce(outputD_fake,fake_label)
            loss_fake.backward()
            
            lossD = loss_real + loss_fake
#             print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0]
#             print('loss for discriminator is %f'%lossD.data[0])
            #update network parameters
            optimizerD.step()
            
        if opt.isWDist:
            one = torch.FloatTensor([1])
            mone = one * -1
            one = one.cuda()
            mone = mone.cuda()
            
            netD.zero_grad()
            
            outputG = net(source,residual_source) #5x64x64->1*64x64
            
            if len(labels.size())==3:
                labels = labels.unsqueeze(1)
                
            outputD_real = netD(labels)
            
            if len(outputG.size())==3:
                outputG = outputG.unsqueeze(1)
                
            outputD_fake = netD(outputG)

            
            batch_size = inputs.size(0)
            
            D_real = outputD_real.mean()
            # print D_real
            D_real.backward(mone)
        
        
            D_fake = outputD_fake.mean()
            D_fake.backward(one)
        
            gradient_penalty = opt.lambda_D_WGAN_GP*calc_gradient_penalty(netD, labels.data, outputG.data)
            gradient_penalty.backward()
            
            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            
            optimizerD.step()
        
        
        ## (2) update G network: minimize the L1/L2 loss, maximize the D(G(x))
        
#         print inputs.data.shape
        #outputG = net(source) #here I am not sure whether we should use twice or not
        outputG = net(source,residual_source) #5x64x64->1*64x64
        net.zero_grad()
        if opt.whichLoss==1:
            lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels))
        elif opt.whichLoss==2:
            lossG_G = criterion_RTL1(torch.squeeze(outputG), torch.squeeze(labels))
        else:
            lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels))
        lossG_G = opt.lossBase * lossG_G
        lossG_G.backward() #compute gradients

        if opt.isGDL:
            lossG_gdl = opt.lambda_gdl * criterion_gdl(outputG,torch.unsqueeze(labels,1))
            lossG_gdl.backward() #compute gradients

        if opt.isAdLoss:
            #we want to fool the discriminator, thus we pretend the label here to be real. Actually, we can explain from the 
            #angel of equation (note the max and min difference for generator and discriminator)
            #outputG = net(inputs)
            outputG = net(source,residual_source) #5x64x64->1*64x64
            
            if len(outputG.size())==3:
                outputG = outputG.unsqueeze(1)
            
            outputD = netD(outputG)
            lossG_D = opt.lambda_AD*criterion_bce(outputD,real_label) #note, for generator, the label for outputG is real, because the G wants to confuse D
            lossG_D.backward()
            
        if opt.isWDist:
            #we want to fool the discriminator, thus we pretend the label here to be real. Actually, we can explain from the 
            #angel of equation (note the max and min difference for generator and discriminator)
            #outputG = net(inputs)
            outputG = net(source,residual_source) #5x64x64->1*64x64
            
            if len(outputG.size())==3:
                outputG = outputG.unsqueeze(1)
            
            outputD_fake = netD(outputG)

            outputD_fake = outputD_fake.mean()
            
            lossG_D = opt.lambda_AD*outputD_fake.mean() #note, for generator, the label for outputG is real, because the G wants to confuse D
            lossG_D.backward(mone)
        
        #for other losses, we can define the loss function following the pytorch tutorial
        
        optimizer.step() #update network parameters

        #print('loss for generator is %f'%lossG.data[0])
        #print statistics
        running_loss = running_loss + lossG_G.data[0]

        
        if iter%opt.showTrainLossEvery==0: #print every 2000 mini-batches
            print '************************************************'
            print 'time now is: ' + time.asctime(time.localtime(time.time()))
#             print 'running loss is ',running_loss
            print 'average running loss for generator between iter [%d, %d] is: %.5f'%(iter - 100 + 1,iter,running_loss/100)
            
            print 'lossG_G is %.5f respectively.'%(lossG_G.data[0])

            if opt.isGDL:
                print('loss for GDL loss is %f'%lossG_gdl.data[0])

            if opt.isAdLoss:
                print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0]
                print('loss for discriminator is %f'%lossD.data[0])  
                
            if opt.isWDist:
                print 'loss_real is ',D_real.data[0],'loss_fake is ',D_fake.data[0]
                print('loss for discriminator is %f'%Wasserstein_D.data[0], ' D cost is %f'%D_cost)                
            
            print 'cost time for iter [%d, %d] is %.2f'%(iter - 100 + 1,iter, time.time()-start)
            print '************************************************'
            running_loss = 0.0
            start = time.time()
        if iter%opt.saveModelEvery==0: #save the model
            state = {
                'epoch': iter+1,
                'model': net.state_dict()
            }
            torch.save(state, opt.prefixModelName+'%d.pt'%iter)
            print 'save model: '+opt.prefixModelName+'%d.pt'%iter

            if opt.isAdLoss:
                torch.save(netD.state_dict(), opt.prefixModelName+'_net_D%d.pt'%iter)
        if iter%opt.decLREvery==0:
            opt.lr = opt.lr*0.5
            adjust_learning_rate(optimizer, opt.lr)
                
        if iter%opt.showValPerformanceEvery==0: #test one subject
            # to test on the validation dataset in the format of h5 
            inputs,exinputs,labels = data_generator_test.next()

            # inputs = np.transpose(inputs,(0,3,1,2))
            inputs = np.squeeze(inputs)

            # exinputs = np.transpose(exinputs, (0, 3, 1, 2))
            exinputs = np.squeeze(exinputs)  # 5x64x64

            labels = np.squeeze(labels)

            inputs = torch.from_numpy(inputs)
            inputs = inputs.float()
            exinputs = torch.from_numpy(exinputs)
            exinputs = exinputs.float()
            labels = torch.from_numpy(labels)
            labels = labels.float()
            mid_slice = opt.numOfChannel_singleSource // 2
            residual_source = inputs[:, mid_slice, ...]
            if opt.isMultiSource:
                source = torch.cat((inputs, exinputs), dim=1)
            else:
                source = inputs
            source = source.cuda()
            residual_source = residual_source.cuda()
            labels = labels.cuda()
            source,residual_source,labels = Variable(source),Variable(residual_source), Variable(labels)

            # source = inputs
            #outputG = net(inputs)
            outputG = net(source,residual_source) #5x64x64->1*64x64

            if opt.whichLoss == 1:
                lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels))
            elif opt.whichLoss == 2:
                lossG_G = criterion_RTL1(torch.squeeze(outputG), torch.squeeze(labels))
            else:
                lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels))
            lossG_G = opt.lossBase * lossG_G
            print '.......come to validation stage: iter {}'.format(iter),'........'
            print 'lossG_G is %.5f.'%(lossG_G.data[0])

            if opt.isGDL:
                lossG_gdl = criterion_gdl(outputG, labels)
                print('loss for GDL loss is %f'%lossG_gdl.data[0])

        if iter % opt.showTestPerformanceEvery == 0:  # test one subject
            mr_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_sourceCT.nii.gz'))
            ct_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_extraCT.nii.gz'))
            hpet_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub1_targetCT.nii.gz'))

            spacing = hpet_test_itk.GetSpacing()
            origin = hpet_test_itk.GetOrigin()
            direction = hpet_test_itk.GetDirection()

            mrnp=sitk.GetArrayFromImage(mr_test_itk)
            ctnp=sitk.GetArrayFromImage(ct_test_itk)
            hpetnp=sitk.GetArrayFromImage(hpet_test_itk)

            ##### specific normalization #####
            # mu = np.mean(mrnp)
            # maxV, minV = np.percentile(mrnp, [99 ,25])
            # #mrimg=mrimg
            # mrnp = (mrnp-minV)/(maxV-minV)



            #for training data in pelvicSeg
            if opt.how2normalize == 1:
                maxV, minV = np.percentile(mrnp, [99 ,1])
                print 'maxV,',maxV,' minV, ',minV
                mrnp = (mrnp-mu)/(maxV-minV)
                print 'unique value: ',np.unique(ctnp)

            #for training data in pelvicSeg
            if opt.how2normalize == 2:
                maxV, minV = np.percentile(mrnp, [99 ,1])
                print 'maxV,',maxV,' minV, ',minV
                mrnp = (mrnp-mu)/(maxV-minV)
                print 'unique value: ',np.unique(ctnp)

            #for training data in pelvicSegRegH5
            if opt.how2normalize== 3:
                std = np.std(mrnp)
                mrnp = (mrnp - mu)/std
                print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)

            if opt.how2normalize == 4:
                maxLPET = 149.366742
                maxPercentLPET = 7.76
                minLPET = 0.00055037
                meanLPET = 0.27593288
                stdLPET = 0.75747500

                # for rsCT
                maxCT = 27279
                maxPercentCT = 1320
                minCT = -1023
                meanCT = -601.1929
                stdCT = 475.034

                # for s-pet
                maxSPET = 156.675962
                maxPercentSPET = 7.79
                minSPET = 0.00055037
                meanSPET = 0.284224789
                stdSPET = 0.7642257

                #matLPET = (mrnp - meanLPET) / (stdLPET)
                matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET)
                matCT = (ctnp - meanCT) / stdCT
                matSPET = (hpetnp - minSPET) / (maxPercentSPET - minSPET)

            if opt.how2normalize == 5:
                # for rsCT
                maxCT = 27279
                maxPercentCT = 1320
                minCT = -1023
                meanCT = -601.1929
                stdCT = 475.034

                print
                'ct, max: ', np.amax(ctnp), ' ct, min: ', np.amin(ctnp)

                # matLPET = (mrnp - meanLPET) / (stdLPET)
                matLPET = mrnp
                matCT = (ctnp - meanCT) / stdCT
                matSPET = hpetnp

            if opt.how2normalize == 6:
                maxPercentPET, minPercentPET = np.percentile(mrnp, [99.5, 0])
                maxPercentCT, minPercentCT = np.percentile(ctnp, [99.5, 0])
                print 'maxPercentPET: ', maxPercentPET, ' minPercentPET: ', minPercentPET, ' maxPercentCT: ', maxPercentCT, 'minPercentCT: ', minPercentCT

                matLPET = (mrnp - minPercentPET) / (maxPercentPET - minPercentPET)
                matSPET = (hpetnp - minPercentPET) / (maxPercentPET - minPercentPET)

                matCT = (ctnp - minPercentCT) / (maxPercentCT - minPercentCT)


            if not opt.isMultiSource:
                matFA = matLPET
                matGT = hpetnp

                print 'matFA shape: ',matFA.shape, ' matGT shape: ', matGT.shape
                matOut = testOneSubject_aver_res(matFA,matGT,[5,64,64],[1,64,64],[1,32,32],net,opt.prefixModelName+'%d.pt'%iter)
                print 'matOut shape: ',matOut.shape
                if opt.how2normalize==6:
                    ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET
                else:
                    ct_estimated = matOut


                itspsnr = psnr(ct_estimated, matGT)

                print 'pred: ',ct_estimated.dtype, ' shape: ',ct_estimated.shape
                print 'gt: ',ctnp.dtype,' shape: ',ct_estimated.shape
                print 'psnr = ',itspsnr
                volout = sitk.GetImageFromArray(ct_estimated)
                volout.SetSpacing(spacing)
                volout.SetOrigin(origin)
                volout.SetDirection(direction)
                sitk.WriteImage(volout,opt.prefixPredictedFN+'{}'.format(iter)+'.nii.gz')
            else:
                matFA = matLPET
                matGT = hpetnp
                print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape
                matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], net,
                                                            opt.prefixModelName + '%d.pt' % iter)
                print 'matOut shape: ', matOut.shape
                if opt.how2normalize==6:
                    ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET
                else:
                    ct_estimated = matOut

                itspsnr = psnr(ct_estimated, matGT)

                print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape
                print 'gt: ', ctnp.dtype, ' shape: ', ct_estimated.shape
                print 'psnr = ', itspsnr
                volout = sitk.GetImageFromArray(ct_estimated)
                volout.SetSpacing(spacing)
                volout.SetOrigin(origin)
                volout.SetDirection(direction)
                sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(iter) + '.nii.gz')
        
    print('Finished Training')
示例#4
0
def main():

    global opt, model
    opt = parser.parse_args()
    print opt

    ##########configs########
    #     isSegReg = False
    #     isDiceLoss = True # for dice loss
    #     isSoftmaxLoss = True #for softmax loss
    #     isResidualEnhancement = False #using ensemble learning to enhance residual learning
    #     isViewExpansion = True #using dilation to expand receptive filed
    #     isAdLoss = True #for adverarail loss
    #     isSpatialDropOut = False
    #     isFocalLoss = False #we use focal loss to escale training dominated by easy examples
    #     isSampleImportanceFromAd = False #we set batch sample importance using the loss from adversarial training
    #     dropoutRate = 0.25
    #     lambdaAD = 0 # the coefficients before the Adversarial training
    #     adImportance = 0 # attention from adversarial training
    #     how2normalize = 3 #1. mu/(max-min); 2. mu/(percent_99 - percent_1); 3. mu/std
    #     lr = 1e-4 #one of the most important hyper-parameters:
    #     prefixModelName = 'Segmentor_wdice_wce_lrdce_viewExpansion_1111_'
    #     prefixPredictedFN = 'preSub_wdice_wce_lrdce_viewExpansion_1111_'
    #     showTrainLossEvery = 100
    #     showTestPerformanceEvery = 2000
    #     decLREvery = 25000 #decrease learning rate every xxx iterations
    #     saveModelEvery = 2000
    #     numofIters = 200000

    ##########configs########

    if opt.isSegReg:
        netG = ResSegRegNet()
    elif opt.isContourLoss:
        netG = ResSegContourNet(isRandomConnection=opt.isResidualEnhancement,
                                isSmallDilation=opt.isViewExpansion,
                                isSpatialDropOut=opt.isSpatialDropOut,
                                dropoutRate=opt.dropoutRate)
    else:
        netG = ResSegNet(isRandomConnection=opt.isResidualEnhancement,
                         isSmallDilation=opt.isViewExpansion,
                         isSpatialDropOut=opt.isSpatialDropOut,
                         dropoutRate=opt.dropoutRate)

#     netG =PSPNet(num_classes=4)
    netG = RefineNet4Cascade(input_shape=(3, 64),
                             num_classes=4,
                             features=256,
                             pretrained=False)
    #netG.apply(weights_init)
    netG = netG.cuda()

    netD = Discriminator()
    netD.apply(weights_init)
    netD.cuda()

    params = list(netG.parameters())
    print('len of params is ')
    print(len(params))
    print('size of params is ')
    print(params[0].size())

    #     optimizerG =optim.SGD(netG.parameters(),lr=1e-2)
    optimizerG = optim.Adam(filter(lambda p: p.requires_grad,
                                   netG.parameters()),
                            lr=opt.lr)

    #     optimizerD =optim.SGD(netD.parameters(),lr=1e-4)
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr)

    criterion_MSE = nn.MSELoss()
    given_weight = torch.FloatTensor([1, 4, 8, 8])
    given_weight = given_weight.cuda()
    #     criterion_NLL2D = nn.NLLLoss2d(weight=given_weight)
    criterion_CE2D = CrossEntropy2d(weight=given_weight)

    criterion_BCE2D = CrossEntropy2d()  #for contours

    #     criterion_dice = DiceLoss4Organs(organIDs=[1,2,3], organWeights=[1,1,1])
    #     criterion_dice = WeightedDiceLoss4Organs()
    criterion_dice = myWeightedDiceLoss4Organs(organIDs=[0, 1, 2, 3],
                                               organWeights=given_weight)

    criterion_focal = myFocalLoss(4, alpha=given_weight, gamma=2)

    criterion = nn.BCELoss()
    criterion = criterion.cuda()
    criterion_dice = criterion_dice.cuda()
    criterion_MSE = criterion_MSE.cuda()
    criterion_CE2D = criterion_CE2D.cuda()
    criterion_BCE2D = criterion_BCE2D.cuda()
    criterion_focal = criterion_focal.cuda()
    softmax2d = nn.Softmax2d()
    #     inputs=Variable(torch.randn(1000,1,32,32)) #here should be tensor instead of variable
    #     targets=Variable(torch.randn(1000,10,1,1)) #here should be tensor instead of variable
    #     trainset=data_utils.TensorDataset(inputs, targets)
    #     trainloader = data_utils.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
    #     inputs=torch.randn(1000,1,32,32)
    #     targets=torch.LongTensor(1000)

    path_test = '/home/dongnie/warehouse/mrs_data'
    path_test = '/shenlab/lab_stor5/dongnie/pelvic'
    path_patients_h5 = '/home/dongnie/warehouse/BrainEstimation/brainH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegContourBatchH5'
    path_patients_h5 = '/shenlab/lab_stor5/dongnie/pelvic/pelvicSeg2D64H5'
    #     path_patients_h5 = '/shenlab/lab_stor5/dongnie/pelvic/pelvicSegRegH5'
    #     path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegPartH5/' #only contains 1-15
    path_patients_h5_test = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegContourH5Test'
    path_patients_h5_test = '/shenlab/lab_stor5/dongnie/pelvic/pelvicSeg2D64H5Test'
    #     path_patients_h5_test ='/shenlab/lab_stor5/dongnie/pelvic/pelvicSegRegH5Test'

    #     batch_size = 10
    if opt.isSegReg:
        data_generator = Generator_2D_slices_variousKeys(
            path_patients_h5,
            opt.batchSize,
            inputKey='dataMR2D',
            outputKey='dataSeg2D',
            regKey1='dataBladder2D',
            regKey2='dataProstate2D',
            regKey3='dataRectum2D')
    elif opt.isContourLoss:
        data_generator = Generator_2D_slicesV1(path_patients_h5,
                                               opt.batchSize,
                                               inputKey='dataMR2D',
                                               segKey='dataSeg2D',
                                               contourKey='dataContour2D')
    else:
        data_generator = Generator_2D_slices(path_patients_h5,
                                             opt.batchSize,
                                             inputKey='dataMR2D',
                                             outputKey='dataSeg2D')

    data_generator_test = Generator_2D_slices(path_patients_h5_test,
                                              opt.batchSize,
                                              inputKey='dataMR2D',
                                              outputKey='dataSeg2D')

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            #             opt.start_epoch = checkpoint["epoch"] + 1
            #             netG.load_state_dict(checkpoint["model"].state_dict())
            netG.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

########### We'd better use dataloader to load a lot of data,and we also should train several epoches###############
    running_loss = 0.0
    start = time.time()
    for iter in range(opt.start_epoch, opt.numofIters + 1):
        #print('iter %d'%iter)

        if opt.isSegReg:
            inputs, labels, regGT1, regGT2, regGT3 = data_generator.next()
        elif opt.isContourLoss:
            inputs, labels, contours = data_generator.next()
        else:
            inputs, labels = data_generator.next()

        labels = np.squeeze(labels)
        labels = zoomImages(labels, rate=opt.zoomRate)

        labels = labels.astype(int)

        if opt.isContourLoss:
            contours = np.squeeze(contours)
            contours = contours.astype(int)
            contours = torch.from_numpy(contours)
            contours = contours.cuda()
            contours = Variable(contours)

        inputs = torch.from_numpy(inputs)
        labels = torch.from_numpy(labels)
        inputs = inputs.cuda()
        labels = labels.cuda()
        #we should consider different data to train

        #wrap them into Variable
        inputs, labels = Variable(inputs), Variable(labels)

        #zero the parameter gradients
        #netD.zero_grad()
        if opt.isAdLoss:
            #forward + backward +optimizer
            if opt.isSegReg:
                outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
            elif opt.isContourLoss:
                outputG, _ = netG(inputs)
            else:
                outputG = netG(inputs)
            outputG = softmax2d(outputG)  #batach
            #         print 'outputG: ',outputG.size(),'labels: ',labels.size()
            #         print 'outputG: ', outputG.data[0].size()
            outputG = outputG.data.max(1)[1]
            #outputG = torch.squeeze(outputG) #[N,C,W,H]
            labels = labels.unsqueeze(1)  #expand the 1st dim
            #         print 'outputG: ',outputG.size(),'labels: ',labels.size()
            outputR = labels.type(torch.FloatTensor).cuda()  #output_Real
            outputG = Variable(outputG.type(torch.FloatTensor).cuda())
            outputD_real = netD(outputR)
            #         print 'size outputG: ',outputG.unsqueeze(1).size()
            outputD_fake = netD(outputG.unsqueeze(1))

            ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z)))
            netD.zero_grad()
            batch_size = inputs.size(0)
            #print(inputs.size())
            #train with real data
            #         real_label = torch.FloatTensor(batch_size)
            #         real_label.data.resize_(batch_size).fill_(1)
            real_label = torch.ones(batch_size, 1)
            real_label = real_label.cuda()
            #print(real_label.size())
            real_label = Variable(real_label)
            #print(outputD_real.size())
            loss_real = criterion(outputD_real, real_label)
            loss_real.backward()
            #train with fake data
            fake_label = torch.zeros(batch_size, 1)
            #         fake_label = torch.FloatTensor(batch_size)
            #         fake_label.data.resize_(batch_size).fill_(0)
            fake_label = fake_label.cuda()
            fake_label = Variable(fake_label)
            loss_fake = criterion(outputD_fake, fake_label)
            loss_fake.backward()

            lossD = loss_real + loss_fake

            optimizerD.step()

        ## (2) update G network: minimize the L1/L2 loss, maximize the D(G(x))

        #we want to fool the discriminator, thus we pretend the label here to be real. Actually, we can explain from the
        #angel of equation (note the max and min difference for generator and discriminator)
        if opt.isAdLoss:
            if opt.isSegReg:
                outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
            elif opt.isContourLoss:
                outputG, _ = netG(inputs)
            else:
                outputG = netG(inputs)
            outputG = outputG.data.max(1)[1]
            outputG = Variable(outputG.type(torch.FloatTensor).cuda())
            #         print 'outputG shape, ',outputG.size()

            outputD = netD(outputG.unsqueeze(1))
            averProb = outputD.data.cpu().mean()
            #             print 'prob: ',averProb
            #             adImportance = computeAttentionWeight(averProb)
            adImportance = computeSampleAttentionWeight(averProb)
            lossG_D = opt.lambdaAD * criterion(
                outputD, real_label
            )  #note, for generator, the label for outputG is real
            lossG_D.backward(retain_graph=True)

        if opt.isSegReg:
            outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
        elif opt.isContourLoss:
            outputG, outputContour = netG(inputs)
        else:
            outputG = netG(
                inputs)  #here I am not sure whether we should use twice or not
        netG.zero_grad()

        if opt.isFocalLoss:
            lossG_focal = criterion_focal(outputG, torch.squeeze(labels))
            lossG_focal.backward(retain_graph=True)  #compute gradients

        if opt.isSoftmaxLoss:
            if opt.isSampleImportanceFromAd:
                lossG_G = (1 + adImportance) * criterion_CE2D(
                    outputG, torch.squeeze(labels))
            else:
                lossG_G = criterion_CE2D(outputG, torch.squeeze(labels))

            lossG_G.backward(retain_graph=True)  #compute gradients

        if opt.isContourLoss:
            lossG_contour = criterion_BCE2D(outputContour, contours)
            lossG_contour.backward(retain_graph=True)
#         criterion_dice(outputG,torch.squeeze(labels))
#         print 'hahaN'
        if opt.isSegReg:
            lossG_Reg1 = criterion_MSE(outputReg1, regGT1)
            lossG_Reg2 = criterion_MSE(outputReg2, regGT2)
            lossG_Reg3 = criterion_MSE(outputReg3, regGT3)
            lossG_Reg = lossG_Reg1 + lossG_Reg2 + lossG_Reg3
            lossG_Reg.backward()

        if opt.isDiceLoss:
            #             print 'isDiceLoss line278'
            #             criterion_dice = myWeightedDiceLoss4Organs(organIDs=[0,1,2,3], organWeights=[1,4,8,6])
            if opt.isSampleImportanceFromAd:
                loss_dice = (1 + adImportance) * criterion_dice(
                    outputG, torch.squeeze(labels))
            else:
                loss_dice = criterion_dice(outputG, torch.squeeze(labels))
#             loss_dice = myDiceLoss4Organs(outputG,torch.squeeze(labels)) #succeed
#             loss_dice.backward(retain_graph=True) #compute gradients for dice loss
            loss_dice.backward()  #compute gradients for dice loss

        #lossG_D.backward()

        #for other losses, we can define the loss function following the pytorch tutorial

        optimizerG.step()  #update network parameters
        #         print 'gradients of parameters****************************'
        #         [x.grad.data for x in netG.parameters()]
        #         print x.grad.data[0]
        #         print '****************************'
        if opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg and opt.isFocalLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_G + loss_dice.data[
                0] + lossG_Reg + lossG_focal
        if opt.isDiceLoss and opt.isFocalLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + lossG_focal + loss_dice.data[
                0] + lossG_Reg
        if opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + lossG_G + loss_dice.data[
                0] + lossG_Reg
        elif opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + lossG_G + lossG_Reg
        elif opt.isDiceLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + loss_dice.data[0] + lossG_Reg
        elif opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_G + loss_dice.data[0]
        elif opt.isDiceLoss and opt.isFocalLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_focal + loss_dice.data[0]
        elif opt.isSoftmaxLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_G
        elif opt.isFocalLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_focal
        elif opt.isDiceLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + loss_dice.data[0]
        elif opt.isSoftmaxLoss:
            lossG = lossG_G
        #print('loss for generator is %f'%lossG.data[0])
        #print statistics
        running_loss = running_loss + lossG.data[0]
        #         print 'running_loss is ',running_loss,' type: ',type(running_loss)

        #         print type(outputD_fake.cpu().data[0].numpy())

        if iter % opt.showTrainLossEvery == 0:  #print every 2000 mini-batches
            print '************************************************'
            print 'time now is: ' + time.asctime(time.localtime(time.time()))
            if opt.isAdLoss:
                print 'the outputD_real for iter {}'.format(
                    iter), ' is ', outputD_real.cpu().data[0].numpy()[0]
                print 'the outputD_fake for iter {}'.format(
                    iter), ' is ', outputD_fake.cpu().data[0].numpy()[0]
                print 'loss for discriminator at iter ', iter, ' is %f' % lossD.data[
                    0]
#             print 'running loss is ',running_loss
            print 'average running loss for generator between iter [%d, %d] is: %.3f' % (
                iter - 100 + 1, iter, running_loss / 100)

            print 'total loss for generator at iter ', iter, ' is %f' % lossG.data[
                0]
            if opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg:
                print 'lossG_D, lossG_G and loss_dice loss_Reg are %.2f, %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_G.data[0], loss_dice.data[0],
                    lossG_Reg.data[0])
            elif opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss:
                print 'lossG_D, lossG_G and loss_dice are %.2f, %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_G.data[0], loss_dice.data[0])
            elif opt.isDiceLoss and opt.isFocalLoss and opt.isAdLoss:
                print 'lossG_D, lossG_focal and loss_dice are %.2f, %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_focal.data[0], loss_dice.data[0])
            elif opt.isSoftmaxLoss and opt.isAdLoss:
                print 'lossG_D and lossG_G are %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_G.data[0])
            elif opt.isFocalLoss and opt.isAdLoss:
                print 'lossG_D and lossG_focal are %.2f and %.2f respectively.' % (
                    lossG_D.data[0], lossG_focal.data[0])
            elif opt.isDiceLoss and opt.isAdLoss:
                print 'lossG_D and loss_dice are %.2f and %.2f respectively.' % (
                    lossG_D.data[0], loss_dice.data[0])
            elif opt.isSoftmaxLoss:
                print ' lossG_G are %.2f respectively.' % (lossG_G.data[0])

            if opt.isContourLoss:
                print 'lossG_contour is {}'.format(lossG_contour.data[0])

            print 'cost time for iter [%d, %d] is %.2f' % (
                iter - 100 + 1, iter, time.time() - start)
            print '************************************************'
            running_loss = 0.0
            start = time.time()
        if iter % opt.saveModelEvery == 0:  #save the model
            torch.save(netG.state_dict(), opt.prefixModelName + '%d.pt' % iter)
            print 'save model: ' + opt.prefixModelName + '%d.pt' % iter

        if iter % opt.decLREvery == 0 and iter > 0:
            opt.lr = opt.lr * 0.1
            adjust_learning_rate(optimizerG, opt.lr)
            print 'now the learning rate is {}'.format(opt.lr)

        if iter % opt.showTestPerformanceEvery == 0:  #test one subject
            # to test on the validation dataset in the format of h5
            inputs, labels = data_generator_test.next()
            labels = np.squeeze(labels)
            labels = zoomImages(labels, rate=opt.zoomRate)
            labels = labels.astype(int)
            inputs = torch.from_numpy(inputs)
            labels = torch.from_numpy(labels)
            inputs = inputs.cuda()
            labels = labels.cuda()
            inputs, labels = Variable(inputs), Variable(labels)
            if opt.isSegReg:
                outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
            elif opt.isContourLoss:
                outputG, _ = netG(inputs)
            else:
                outputG = netG(
                    inputs
                )  #here I am not sure whether we should use twice or not

            lossG_G = criterion_CE2D(outputG, torch.squeeze(labels))
            loss_dice = criterion_dice(outputG, torch.squeeze(labels))
            print '.......come to validation stage: iter {}'.format(
                iter), '........'
            print 'lossG_G and loss_dice are %.2f and %.2f respectively.' % (
                lossG_G.data[0], loss_dice.data[0])

            ####release all the unoccupied memory####
            torch.cuda.empty_cache()

            mr_test_itk = sitk.ReadImage(
                os.path.join(path_test, 'img50_nocrop.nii.gz'))
            ct_test_itk = sitk.ReadImage(
                os.path.join(path_test, 'img50_label_nie_nocrop.nii.gz'))

            mrnp = sitk.GetArrayFromImage(mr_test_itk)
            mu = np.mean(mrnp)

            ctnp = sitk.GetArrayFromImage(ct_test_itk)

            #for training data in pelvicSeg
            if opt.how2normalize == 1:
                maxV, minV = np.percentile(mrnp, [99, 1])
                print 'maxV,', maxV, ' minV, ', minV
                mrnp = (mrnp - mu) / (maxV - minV)
                print 'unique value: ', np.unique(ctnp)

            #for training data in pelvicSeg
            if opt.how2normalize == 2:
                maxV, minV = np.percentile(mrnp, [99, 1])
                print 'maxV,', maxV, ' minV, ', minV
                mrnp = (mrnp - mu) / (maxV - minV)
                print 'unique value: ', np.unique(ctnp)

            #for training data in pelvicSegRegH5
            if opt.how2normalize == 3:
                std = np.std(mrnp)
                mrnp = (mrnp - mu) / std
                print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(
                    mrnp)

#             full image version with average over the overlapping regions
#             ct_estimated = testOneSubject(mrnp,ctnp,[3,168,112],[1,168,112],[1,8,8],netG,'Segmentor_model_%d.pt'%iter)

            if opt.how2normalize == 4:
                maxV, minV = np.percentile(mrnp, [99.2, 1])
                print 'maxV is: ', np.ndarray.max(mrnp)
                mrnp[np.where(mrnp > maxV)] = maxV
                print 'maxV is: ', np.ndarray.max(mrnp)
                mu = np.mean(mrnp)
                std = np.std(mrnp)
                mrnp = (mrnp - mu) / std
                print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(
                    mrnp)

            matFA = mrnp
            matGT = ctnp
            matGT = zoomImages(matGT, rate=opt.zoomRate)

            matOut, _ = testOneSubject(matFA, matGT, 4, [3, 64, 64],
                                       [1, 16, 16], [1, 8, 8], netG,
                                       opt.prefixModelName + '%d.pt' % iter)
            ct_estimated = np.zeros(
                [ctnp.shape[0], ctnp.shape[1], ctnp.shape[2]])
            print 'matOut shape: ', matOut.shape
            #             ct_estimated[:,y1:y2,x1:x2] = matOut
            ct_estimated = matOut

            ct_estimated = np.rint(ct_estimated)
            ct_estimated = denoiseImg(ct_estimated,
                                      kernel=np.ones((20, 20, 20)))
            diceBladder = dice(ct_estimated, ctnp, 1)
            diceProstate = dice(ct_estimated, ctnp, 2)
            diceRectumm = dice(ct_estimated, ctnp, 3)

            print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape
            print 'gt: ', ctnp.dtype, ' shape: ', ct_estimated.shape
            print 'dice1 = ', diceBladder, ' dice2= ', diceProstate, ' dice3= ', diceRectumm
            volout = sitk.GetImageFromArray(ct_estimated)
            sitk.WriteImage(
                volout, opt.prefixPredictedFN + '{}'.format(iter) + '.nii.gz')


#             netG.save_state_dict('Segmentor_model_%d.pt'%iter)
#             netD.save_state_dic('Discriminator_model_%d.pt'%iter)

    print('Finished Training')