Example #1
def main():
    ids = [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
    ids = range(0, 30)
    for id in ids:
        #         datafn = os.path.join(path,'Case%02d.mhd'%id)
        #         outdatafn = os.path.join(path,'Case%02d.nii.gz'%id)
        #
        #         dataOrg = sitk.ReadImage(datafn)
        #         dataMat = sitk.GetArrayFromImage(dataOrg)
        #         #gtMat=np.transpose(gtMat,(2,1,0))
        #         dataVol = sitk.GetImageFromArray(dataMat)
        #         sitk.WriteImage(dataVol,outdatafn)

        datafn = os.path.join(path, 'TestData_mhd/Case%02d.mhd' % id)
        dataOrg = sitk.ReadImage(datafn)
        spacing = dataOrg.GetSpacing()
        origin = dataOrg.GetOrigin()
        direction = dataOrg.GetDirection()
        dataMat = sitk.GetArrayFromImage(dataOrg)

        gtfn = os.path.join(
            path,
            'submission_niigz/preTestCha_model0110_iter14w_sub%02d.nii.gz' %
            id)
        gtOrg = sitk.ReadImage(gtfn)
        gtMat = sitk.GetArrayFromImage(gtOrg)
        #gtMat=np.transpose(gtMat,(2,1,0))

        gtMat1 = denoiseImg_closing(gtMat, kernel=np.ones((20, 20, 20)))
        gtMat2 = gtMat + gtMat1
        gtMat2[np.where(gtMat2 > 1)] = 1
        gtMat = gtMat2
        gtMat = denoiseImg_isolation(gtMat, struct=np.ones((3, 3, 3)))

        gtMat = gtMat.astype(np.uint8)

        outgtfn = os.path.join(path,
                               'submission_mhd/Case%02d_segmentation.mhd' % id)
        gtVol = sitk.GetImageFromArray(gtMat)
        gtVol.SetSpacing(spacing)
        gtVol.SetOrigin(origin)
        gtVol.SetDirection(direction)
        sitk.WriteImage(gtVol, outgtfn)
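# The denoiseImg_closing / denoiseImg_isolation helpers are not included in these
# snippets. A minimal sketch of what they might look like, assuming they perform a
# binary morphological closing and drop small isolated components (scipy-based
# stand-ins for illustration, not the original implementation):
import numpy as np
from scipy import ndimage

def denoiseImg_closing(img, kernel):
    # fill small holes/gaps with a binary morphological closing
    closed = ndimage.binary_closing(img > 0, structure=kernel)
    return closed.astype(img.dtype)

def denoiseImg_isolation(img, struct):
    # keep only the largest connected component of each label, removing isolated islands
    out = np.zeros_like(img)
    for lab in np.unique(img):
        if lab == 0:
            continue
        cc, n = ndimage.label(img == lab, structure=struct)
        if n == 0:
            continue
        sizes = ndimage.sum(img == lab, cc, range(1, n + 1))
        out[cc == (np.argmax(sizes) + 1)] = lab
    return out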
Example #2
def main():
    global opt
    opt = parser.parse_args()

    print opt
    ids = range(0,30)

 
    for ind in ids:
        datafilename = 'preTestCha_model0110_iter14w_prob_Prostate_sub%02d.nii.gz'%ind #provide a sample name of your filename of data here
        datafn = os.path.join(opt.basePath+opt.dataFolder, datafilename)
#         labelfilename='Case%02d_segmentation.nii.gz'%ind  # provide a sample name of your filename of ground truth here
#         labelfn=os.path.join(opt.basePath+opt.dataFolder, labelfilename)
        imgOrg = sitk.ReadImage(datafn)
        mrimg = sitk.GetArrayFromImage(imgOrg)
        
#         labelOrg=sitk.ReadImage(labelfn)
#         labelimg=sitk.GetArrayFromImage(labelOrg)
        
        inds = np.where(mrimg>opt.threshold)
        
        tmat_prob = np.zeros(mrimg.shape)
        tmat_prob[inds] = mrimg[inds]
        
        tmat = np.zeros(mrimg.shape)
        tmat[inds] = 1
        tmat = denoiseImg_closing(tmat, kernel=np.ones((20,20,20))) 
        tmat = denoiseImg_isolation(tmat, struct=np.ones((3,3,3)))   
#         tmat = dice(tmat,ctnp,1)       
#         diceBladder = dice(tmat,ctnp,1)
#         print 'sub%d'%ind,'dice1 = ',diceBladder
                               
        volout = sitk.GetImageFromArray(tmat)
        sitk.WriteImage(volout, opt.basePath+opt.dataFolder+'threshold_seg_model0110_sub{:02d}'.format(ind)+'.nii.gz')  

        volout = sitk.GetImageFromArray(tmat_prob)
        sitk.WriteImage(volout, opt.basePath+opt.dataFolder+'threshold_prob_model0110_sub{:02d}'.format(ind)+'.nii.gz')  
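# This example assumes a module-level argparse parser providing opt.basePath,
# opt.dataFolder and opt.threshold. A minimal sketch (flag names taken from the
# usage above; the default values are only illustrative assumptions):
import argparse

parser = argparse.ArgumentParser(description='Threshold probability maps into binary segmentations')
parser.add_argument('--basePath', type=str, default='/path/to/experiments/')
parser.add_argument('--dataFolder', type=str, default='submission_niigz/')
parser.add_argument('--threshold', type=float, default=0.9)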
Example #3
def main():
    global opt
    opt = parser.parse_args()
    print opt
    
    path_test = '/home/dongnie/warehouse/pelvicSeg/prostateChallenge/data/'
    path_test = '/shenlab/lab_stor5/dongnie/challengeData/testdata/'
    
    if opt.isSegReg:
        netG = ResSegRegNet(opt.in_channels, opt.out_channels, nd=opt.NDim)
    elif opt.isContourLoss:
        netG = ResSegContourNet(opt.in_channels, opt.out_channels, nd=opt.NDim, isRandomConnection=opt.isResidualEnhancement,isSmallDilation=opt.isViewExpansion, isSpatialDropOut=opt.isSpatialDropOut,dropoutRate=opt.dropoutRate)
    elif opt.isDeeplySupervised and opt.isHighResolution:
        netG = HRResSegNet_DS(opt.in_channels, opt.out_channels, nd=opt.NDim, isRandomConnection=opt.isResidualEnhancement,isSmallDilation=opt.isViewExpansion, isSpatialDropOut=opt.isSpatialDropOut,dropoutRate=opt.dropoutRate)
    elif opt.isDeeplySupervised:
        netG = ResSegNet_DS(opt.in_channels, opt.out_channels, nd=opt.NDim, isRandomConnection=opt.isResidualEnhancement,isSmallDilation=opt.isViewExpansion, isSpatialDropOut=opt.isSpatialDropOut,dropoutRate=opt.dropoutRate)
    elif opt.isHighResolution:
        netG = HRResSegNet(opt.in_channels, opt.out_channels, nd=opt.NDim, isRandomConnection=opt.isResidualEnhancement, isSmallDilation=opt.isViewExpansion, isSpatialDropOut=opt.isSpatialDropOut,dropoutRate=opt.dropoutRate)
    else:
        netG = ResSegNet(opt.in_channels, opt.out_channels, nd=opt.NDim, isRandomConnection=opt.isResidualEnhancement,isSmallDilation=opt.isViewExpansion, isSpatialDropOut=opt.isSpatialDropOut,dropoutRate=opt.dropoutRate)
    #netG.apply(weights_init)
    netG = netG.cuda()
    
    checkpoint = torch.load(opt.modelPath)
    #print checkpoint.items()
    netG.load_state_dict(checkpoint["state_dict"])
    #netG.load_state_dict(torch.load(opt.modelPath))
    
    
    ids = [1,2,3,4,6,7,8,10,11,12,13]
    ids = [45,46,47,48,49]
    ids = range(0,30) 
#     ids = [6,20,23]
#     ids = [23]
    for ind in ids:
        print 'come to ind: ',ind
        print 'time now is: ' + time.asctime(time.localtime(time.time()))
        mr_test_itk=sitk.ReadImage(os.path.join(path_test,'Case%02d.nii.gz'%ind))
        #ct_test_itk=sitk.ReadImage(os.path.join(path_test,'Case%02d_segmentation.nii.gz'%ind))
        
        mrnp=sitk.GetArrayFromImage(mr_test_itk)
        mu=np.mean(mrnp)
    
        #ctnp=sitk.GetArrayFromImage(ct_test_itk)
        
        #for training data in pelvicSeg
        if opt.how2normalize == 1:
            maxV, minV=np.percentile(mrnp, [99 ,1])
            print 'maxV,',maxV,' minV, ',minV
            mrnp=(mrnp-mu)/(maxV-minV)
            #print 'unique value: ',np.unique(ctnp)
    
        #for training data in pelvicSeg
        elif opt.how2normalize == 2:
            maxV, minV = np.percentile(mrnp, [99 ,1])
            print 'maxV,',maxV,' minV, ',minV
            mrnp = (mrnp-mu)/(maxV-minV)
            #print 'unique value: ',np.unique(ctnp)
        
        #for training data in pelvicSegRegH5
        elif opt.how2normalize == 3:
            std = np.std(mrnp)
            mrnp = (mrnp - mu)/std
            print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)

        elif opt.how2normalize == 4:
            maxV, minV = np.percentile(mrnp, [99.2, 1])
            print 'maxV is: ',np.ndarray.max(mrnp)
            mrnp[np.where(mrnp>maxV)] = maxV
            print 'maxV is: ',np.ndarray.max(mrnp)
            mu = np.mean(mrnp)
            std = np.std(mrnp)
            mrnp = (mrnp - mu)/std
            print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)
    
    #             full image version with average over the overlapping regions
    #             ct_estimated = testOneSubject(mrnp,ctnp,[3,168,112],[1,168,112],[1,8,8],netG,'Segmentor_model_%d.pt'%iter)
        
        # the attention regions
        row,col,leng = mrnp.shape
        y1 = int(leng * 0.25)
        y2 = int(leng * 0.75)
        x1 = int(col * 0.25)
        x2 = int(col * 0.75)
#         x1 = 120
#         x2 = 350
#         y1 = 120
#         y2 = 350
        matFA = mrnp[:,y1:y2,x1:x2] #note, matFA and matFAOut same size 
#         matGT = ctnp[:,y1:y2,x1:x2]
#         matFA = mrnp
        #matGT = ctnp
        
        if opt.resType==2:
            matOut, matProb, _ = testOneSubject(matFA,matFA,opt.out_channels,opt.input_sz,opt.output_sz,opt.test_step_sz,netG,opt.modelPath,resType=opt.resType, nd = opt.NDim)
        else:
            matOut,_ = testOneSubject(matFA,matFA,opt.out_channels,opt.input_sz,opt.output_sz,opt.test_step_sz,netG,opt.modelPath,resType=opt.resType, nd = opt.NDim)
                                      
        #matOut,_ = testOneSubject(matFA,matGT,opt.out_channels,opt.input_sz, opt.output_sz, opt.test_step_sz,netG,opt.modelPath, nd = opt.NDim)
        ct_estimated = np.zeros([mrnp.shape[0],mrnp.shape[1],mrnp.shape[2]])
        ct_prob = np.zeros([opt.out_channels, mrnp.shape[0],mrnp.shape[1],mrnp.shape[2]])
        
#         print 'matOut shape: ',matOut.shape
        ct_estimated[:,y1:y2,x1:x2] = matOut
#         ct_estimated = matOut

        ct_prob[:,:,y1:y2,x1:x2] = matProb
#         ct_prob = matProb
        matProb_Bladder = np.squeeze(ct_prob[1,:,:,:])
                
#         volout = sitk.GetImageFromArray(matProb_Bladder)
#         sitk.WriteImage(volout,opt.prefixPredictedFN+'prob_Prostate_sub%02d'%ind+'.nii.gz')  
        
        threshold = 0.9
        inds = np.where(matProb_Bladder>threshold) 
        tmat = np.zeros(matProb_Bladder.shape)
        tmat[inds] = 1
        tmat = denoiseImg_closing(tmat, kernel=np.ones((20,20,20))) 
        tmat = denoiseImg_isolation(tmat, struct=np.ones((3,3,3)))   
#         diceBladder = dice(tmat,ctnp,1)       
#         diceBladder = dice(tmat,ctnp,1)
#         print 'sub%d'%ind,'dice1 = ',diceBladder
        volout = sitk.GetImageFromArray(tmat)
        sitk.WriteImage(volout,opt.prefixPredictedFN+'threshSeg_sub{:02d}'.format(ind)+'.nii.gz')        
        
        ct_estimated = np.rint(ct_estimated) 
        ct_estimated = denoiseImg_closing(ct_estimated, kernel=np.ones((20,20,20))) 
        ct_estimated = denoiseImg_isolation(ct_estimated, struct=np.ones((3,3,3)))   
        #diceBladder = dice(ct_estimated,ctnp,1)
    
#         diceProstate = dice(ct_estimated,ctnp,2)
#         diceRectumm = dice(ct_estimated,ctnp,3)
        
#         print 'pred: ',ct_estimated.dtype, ' shape: ',ct_estimated.shape
#         print 'gt: ',ctnp.dtype,' shape: ',ctnp.shape
#         print 'sub%d'%ind,'dice1 = ',diceBladder,' dice2= ',diceProstate,' dice3= ',diceRectumm
        #print 'sub%d'%ind,'dice1 = ',diceBladder
        volout = sitk.GetImageFromArray(ct_estimated)
        sitk.WriteImage(volout,opt.prefixPredictedFN+'sub%02d'%ind+'.nii.gz')
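# The how2normalize branches above are repeated in several of these examples.
# A sketch that factors them into one helper, assuming it should mirror the inline
# code exactly (modes 1/2: percentile-range scaling, 3: z-score, 4: clip then z-score):
def normalize_mr(mrnp, how2normalize):
    mu = np.mean(mrnp)
    if how2normalize in (1, 2):
        maxV, minV = np.percentile(mrnp, [99, 1])
        return (mrnp - mu) / (maxV - minV)
    elif how2normalize == 3:
        return (mrnp - mu) / np.std(mrnp)
    elif how2normalize == 4:
        maxV, _ = np.percentile(mrnp, [99.2, 1])
        mrnp = np.clip(mrnp, None, maxV)  # clip the top intensities, then z-score
        return (mrnp - np.mean(mrnp)) / np.std(mrnp)
    return mrnp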
Example #4
def main():
    global opt
    opt = parser.parse_args()
    print opt

    path_test = '/home/dongnie/warehouse/pelvicSeg/prostateChallenge/data/'
    path_test = '/shenlab/lab_stor5/dongnie/challengeData/data/'

    if opt.isSegReg:
        netG = ResSegRegNet(opt.in_channels, opt.out_channels, nd=opt.NDim)
    elif opt.isContourLoss and opt.isDeeplySupervised and opt.isHighResolution:
        netG = HRResSegContourNet_DS(opt.in_channels,
                                     opt.out_channels,
                                     isRandomConnection=opt.isResidualEnhancement,
                                     isSmallDilation=opt.isViewExpansion,
                                     isSpatialDropOut=opt.isSpatialDropOut,
                                     dropoutRate=opt.dropoutRate,
                                     TModule='HR',
                                     FModule=opt.isAttentionConcat,
                                     nd=opt.NDim)
    elif opt.isContourLoss:
        netG = ResSegContourNet(opt.in_channels,
                                opt.out_channels,
                                nd=opt.NDim,
                                isRandomConnection=opt.isResidualEnhancement,
                                isSmallDilation=opt.isViewExpansion,
                                isSpatialDropOut=opt.isSpatialDropOut,
                                dropoutRate=opt.dropoutRate)
    elif opt.isLongConcatConnection and opt.isDeeplySupervised and opt.isHighResolution:
        netG = HRResUNet_DS(opt.in_channels,
                            opt.out_channels,
                            isRandomConnection=opt.isResidualEnhancement,
                            isSmallDilation=opt.isViewExpansion,
                            isSpatialDropOut=opt.isSpatialDropOut,
                            dropoutRate=opt.dropoutRate,
                            TModule='HR',
                            FModule=opt.isAttentionConcat,
                            nd=opt.NDim)
    elif opt.isDeeplySupervised and opt.isHighResolution:
        netG = HRResSegNet_DS(opt.in_channels,
                              opt.out_channels,
                              nd=opt.NDim,
                              isRandomConnection=opt.isResidualEnhancement,
                              isSmallDilation=opt.isViewExpansion,
                              isSpatialDropOut=opt.isSpatialDropOut,
                              dropoutRate=opt.dropoutRate)
    elif opt.isLongConcatConnection and opt.isHighResolution:
        netG = UNet(opt.in_channels,
                    opt.out_channels,
                    TModule='HR',
                    FModule=opt.isAttentionConcat,
                    nd=opt.NDim)
    elif opt.isDeeplySupervised:
        netG = ResSegNet_DS(opt.in_channels,
                            opt.out_channels,
                            nd=opt.NDim,
                            isRandomConnection=opt.isResidualEnhancement,
                            isSmallDilation=opt.isViewExpansion,
                            isSpatialDropOut=opt.isSpatialDropOut,
                            dropoutRate=opt.dropoutRate)
    elif opt.isHighResolution:
        netG = HRResSegNet(opt.in_channels,
                           opt.out_channels,
                           nd=opt.NDim,
                           isRandomConnection=opt.isResidualEnhancement,
                           isSmallDilation=opt.isViewExpansion,
                           isSpatialDropOut=opt.isSpatialDropOut,
                           dropoutRate=opt.dropoutRate)
    elif opt.isLongConcatConnection:
        netG = UNet(opt.in_channels,
                    opt.out_channels,
                    TModule=None,
                    FModule=opt.isAttentionConcat,
                    nd=opt.NDim)
    else:
        netG = ResSegNet(opt.in_channels,
                         opt.out_channels,
                         nd=opt.NDim,
                         isRandomConnection=opt.isResidualEnhancement,
                         isSmallDilation=opt.isViewExpansion,
                         isSpatialDropOut=opt.isSpatialDropOut,
                         dropoutRate=opt.dropoutRate)
        # netG.apply(weights_init)
    netG = netG.cuda()
    checkpoint = torch.load(opt.modelPath)
    #     netG.load_state_dict(checkpoint["model"].state_dict())
    netG.load_state_dict(checkpoint["model"])
    # netG.load_state_dict(torch.load(opt.modelPath)["state_dict"])

    if opt.isAdLoss:
        if opt.isNetDFullyConv:
            # netD = Discriminator_my23DLRResFCN(opt.in_channels_netD, opt.out_channels_netD, nd=opt.NDim)
            netD = Discriminator_my23dFCNv4(opt.in_channels_netD,
                                            opt.out_channels_netD,
                                            nd=opt.NDim)
        else:
            netD = Discriminator(opt.in_channels_netD,
                                 opt.out_channels_netD,
                                 nd=opt.NDim)
        netD = netD.cuda()
        netD.load_state_dict(torch.load(opt.netDModelPath))

    ids = [1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13]
    ids = [45, 46, 47, 48, 49]
    # ids = range(0,50)
    for ind in ids:
        mr_test_itk = sitk.ReadImage(
            os.path.join(path_test, 'Case%02d.nii.gz' % ind))
        ct_test_itk = sitk.ReadImage(
            os.path.join(path_test, 'Case%02d_segmentation.nii.gz' % ind))

        mrnp = sitk.GetArrayFromImage(mr_test_itk)
        mu = np.mean(mrnp)

        ctnp = sitk.GetArrayFromImage(ct_test_itk)

        #for training data in pelvicSeg
        if opt.how2normalize == 1:
            maxV, minV = np.percentile(mrnp, [99, 1])
            print 'maxV,', maxV, ' minV, ', minV
            mrnp = (mrnp - mu) / (maxV - minV)
            print 'unique value: ', np.unique(ctnp)

        #for training data in pelvicSeg
        elif opt.how2normalize == 2:
            maxV, minV = np.percentile(mrnp, [99, 1])
            print 'maxV,', maxV, ' minV, ', minV
            mrnp = (mrnp - mu) / (maxV - minV)
            print 'unique value: ', np.unique(ctnp)

        #for training data in pelvicSegRegH5
        elif opt.how2normalize == 3:
            std = np.std(mrnp)
            mrnp = (mrnp - mu) / std
            print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(
                mrnp)

        elif opt.how2normalize == 4:
            maxV, minV = np.percentile(mrnp, [99.2, 1])
            print 'maxV is: ', np.ndarray.max(mrnp)
            mrnp[np.where(mrnp > maxV)] = maxV
            print 'maxV is: ', np.ndarray.max(mrnp)
            mu = np.mean(mrnp)
            std = np.std(mrnp)
            mrnp = (mrnp - mu) / std
            print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(
                mrnp)

    #             full image version with average over the overlapping regions
    #             ct_estimated = testOneSubject(mrnp,ctnp,[3,168,112],[1,168,112],[1,8,8],netG,'Segmentor_model_%d.pt'%iter)

    # the attention regions
        row, col, leng = mrnp.shape
        if opt.isTestonAttentionRegion:
            # the attention regions
            y1 = int(leng * 0.25)
            y2 = int(leng * 0.75)
            x1 = int(col * 0.25)
            x2 = int(col * 0.75)
            matFA = mrnp[:, y1:y2, x1:x2]  # note, matFA and matFAOut same size
            matGT = ctnp[:, y1:y2, x1:x2]
        else:
            matFA = mrnp
            matGT = ctnp

        if opt.resType == 2:
            matOut, matProb, _ = testOneSubject(matFA,
                                                matFA,
                                                opt.out_channels,
                                                opt.input_sz,
                                                opt.output_sz,
                                                opt.test_step_sz,
                                                netG,
                                                opt.modelPath,
                                                resType=opt.resType,
                                                nd=opt.NDim)
        else:
            matOut, _ = testOneSubject(matFA,
                                       matFA,
                                       opt.out_channels,
                                       opt.input_sz,
                                       opt.output_sz,
                                       opt.test_step_sz,
                                       netG,
                                       opt.modelPath,
                                       resType=opt.resType,
                                       nd=opt.NDim)

        #matOut,_ = testOneSubject(matFA,matGT,opt.out_channels,opt.input_sz, opt.output_sz, opt.test_step_sz,netG,opt.modelPath, nd = opt.NDim)
        ct_estimated = np.zeros([mrnp.shape[0], mrnp.shape[1], mrnp.shape[2]])
        ct_prob = np.zeros(
            [opt.out_channels, mrnp.shape[0], mrnp.shape[1], mrnp.shape[2]])

        if opt.isTestonAttentionRegion:
            ct_estimated[:, y1:y2, x1:x2] = matOut
            ct_prob[:, :, y1:y2, x1:x2] = matProb
        else:
            ct_estimated = matOut
            ct_prob = matProb

        matProb_Bladder = np.squeeze(ct_prob[1, :, :, :])

        #         volout = sitk.GetImageFromArray(matProb_Bladder)
        #         sitk.WriteImage(volout,opt.prefixPredictedFN+'prob_Prostate_sub%02d'%ind+'.nii.gz')

        threshold = 0.9
        inds = np.where(matProb_Bladder > threshold)
        tmat = np.zeros(matProb_Bladder.shape)
        tmat[inds] = 1
        tmat = denoiseImg_closing(tmat, kernel=np.ones((20, 20, 20)))
        tmat = denoiseImg_isolation(tmat, struct=np.ones((3, 3, 3)))
        diceBladder = dice(tmat, ctnp, 1)
        #         diceBladder = dice(tmat,ctnp,1)
        print 'sub%d' % ind, 'dice1 = ', diceBladder
        volout = sitk.GetImageFromArray(tmat)
        sitk.WriteImage(
            volout, opt.prefixPredictedFN + 'threshSeg_sub{:02d}'.format(ind) +
            '.nii.gz')

        ct_estimated = np.rint(ct_estimated)
        ct_estimated = denoiseImg_closing(ct_estimated,
                                          kernel=np.ones((20, 20, 20)))
        ct_estimated = denoiseImg_isolation(ct_estimated,
                                            struct=np.ones((3, 3, 3)))
        diceBladder = dice(ct_estimated, ctnp, 1)

        #         diceProstate = dice(ct_estimated,ctnp,2)
        #         diceRectumm = dice(ct_estimated,ctnp,3)

        #         print 'pred: ',ct_estimated.dtype, ' shape: ',ct_estimated.shape
        #         print 'gt: ',ctnp.dtype,' shape: ',ctnp.shape
        #         print 'sub%d'%ind,'dice1 = ',diceBladder,' dice2= ',diceProstate,' dice3= ',diceRectumm
        diceBladder = dice(ct_estimated, ctnp, 1)
        print 'sub%d' % ind, 'dice1 = ', diceBladder
        volout = sitk.GetImageFromArray(ct_estimated)
        sitk.WriteImage(
            volout,
            opt.prefixPredictedFN + 'sub{:02d}'.format(ind) + '.nii.gz')
        volgt = sitk.GetImageFromArray(ctnp)
        sitk.WriteImage(volgt, 'gt_sub{:02d}'.format(ind) + '.nii.gz')

        ### for Discriminator network
        if opt.isNetDInputIncludeSource:
            input = np.expand_dims(mrnp, axis=0)
            matInput_netD = np.concatenate((input, ct_prob), axis=0)

        else:
            matInput_netD = matProb

        if opt.resType == 2:
            matConfLabel, matConfProb, _ = testOneSubjectWith4DInput(
                matInput_netD,
                matInput_netD,
                opt.out_channels_netD,
                opt.input_sz,
                opt.output_sz,
                opt.test_step_sz,
                netD,
                opt.netDModelPath,
                resType=opt.resType,
                nd=opt.NDim)
        else:
            matConfLabel, _ = testOneSubjectWith4DInput(matInput_netD,
                                                        matInput_netD,
                                                        opt.out_channels_netD,
                                                        opt.input_sz,
                                                        opt.output_sz,
                                                        opt.test_step_sz,
                                                        netD,
                                                        opt.netDModelPath,
                                                        resType=opt.resType,
                                                        nd=opt.NDim)

        matConfFGProb = np.squeeze(matConfProb[1, ...])
        matConfBGProb = np.squeeze(matConfProb[0, ...])
        volProb = sitk.GetImageFromArray(matConfFGProb)
        sitk.WriteImage(
            volProb, opt.prefixPredictedFN + 'confProb_sub{:02d}'.format(ind) +
            '.nii.gz')
        volProb = sitk.GetImageFromArray(matConfBGProb)
        sitk.WriteImage(
            volProb, opt.prefixPredictedFN +
            'confProb1_sub{:02d}'.format(ind) + '.nii.gz')
        volOut = sitk.GetImageFromArray(matConfLabel)
        sitk.WriteImage(
            volOut, opt.prefixPredictedFN + 'confLabel_sub{:02d}'.format(ind) +
            '.nii.gz')
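# The dice() helper used for evaluation is not included in these snippets. A minimal
# sketch, assuming it computes the Dice coefficient of one label value between a
# predicted and a ground-truth label map:
def dice(pred, gt, label):
    p = (pred == label)
    g = (gt == label)
    denom = p.sum() + g.sum()
    if denom == 0:
        return 1.0  # both empty: treat as perfect overlap
    return 2.0 * np.logical_and(p, g).sum() / float(denom)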
Example #5
def main():
    global opt
    opt = parser.parse_args()
    print opt

    path_test = '/home/dongnie/warehouse/pelvicSeg/prostateChallenge/data/'
    path_test = '/shenlab/lab_stor5/dongnie/challengeData/data/'
    path_test = '/shenlab/lab_stor5/dongnie/pelvic/'
    if opt.isSegReg:
        netG = ResSegRegNet(opt.in_channels, opt.out_channels, nd=opt.NDim)
    elif opt.isContourLoss:
        netG = ResSegContourNet(opt.in_channels,
                                opt.out_channels,
                                nd=opt.NDim,
                                isRandomConnection=opt.isResidualEnhancement,
                                isSmallDilation=opt.isViewExpansion,
                                isSpatialDropOut=opt.isSpatialDropOut,
                                dropoutRate=opt.dropoutRate)
    else:
        netG = ResSegNet(opt.in_channels,
                         opt.out_channels,
                         nd=opt.NDim,
                         isRandomConnection=opt.isResidualEnhancement,
                         isSmallDilation=opt.isViewExpansion,
                         isSpatialDropOut=opt.isSpatialDropOut,
                         dropoutRate=opt.dropoutRate)
    #netG.apply(weights_init)
    netG.cuda()

    checkpoint = torch.load(opt.modelPath)
    #     netG.load_state_dict(checkpoint["model"].state_dict())
    netG.load_state_dict(checkpoint["model"])
    #     netG.load_state_dict(torch.load(opt.modelPath))

    ids = [1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13]
    #     ids = [45,46,47,48,49]
    ids = [1, 2, 3, 4, 13, 29]
    for ind in ids:
        #         mr_test_itk=sitk.ReadImage(os.path.join(path_test,'Case%d.nii.gz'%ind))
        #         ct_test_itk=sitk.ReadImage(os.path.join(path_test,'Case%d_segmentation.nii.gz'%ind))
        mr_test_itk = sitk.ReadImage(
            os.path.join(path_test, 'img%d_nocrop.nii.gz' % ind))
        ct_test_itk = sitk.ReadImage(
            os.path.join(path_test, 'img%d_label_nie_nocrop.nii.gz' % ind))

        mrnp = sitk.GetArrayFromImage(mr_test_itk)
        mu = np.mean(mrnp)

        ctnp = sitk.GetArrayFromImage(ct_test_itk)

        #for training data in pelvicSeg
        if opt.how2normalize == 1:
            maxV, minV = np.percentile(mrnp, [99, 1])
            print 'maxV,', maxV, ' minV, ', minV
            mrnp = (mrnp - mu) / (maxV - minV)
            print 'unique value: ', np.unique(ctnp)

        #for training data in pelvicSeg
        elif opt.how2normalize == 2:
            maxV, minV = np.percentile(mrnp, [99, 1])
            print 'maxV,', maxV, ' minV, ', minV
            mrnp = (mrnp - mu) / (maxV - minV)
            print 'unique value: ', np.unique(ctnp)

        #for training data in pelvicSegRegH5
        elif opt.how2normalize == 3:
            std = np.std(mrnp)
            mrnp = (mrnp - mu) / std
            print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(
                mrnp)

        elif opt.how2normalize == 4:
            maxV, minV = np.percentile(mrnp, [99.2, 1])
            print 'maxV is: ', np.ndarray.max(mrnp)
            mrnp[np.where(mrnp > maxV)] = maxV
            print 'maxV is: ', np.ndarray.max(mrnp)
            mu = np.mean(mrnp)
            std = np.std(mrnp)
            mrnp = (mrnp - mu) / std
            print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(
                mrnp)

    #             full image version with average over the overlapping regions
    #             ct_estimated = testOneSubject(mrnp,ctnp,[3,168,112],[1,168,112],[1,8,8],netG,'Segmentor_model_%d.pt'%iter)

        # the attention regions
        # x1 = 80
        # x2 = 192
        # y1 = 35
        # y2 = 235
        # matFA = mrnp[:,y1:y2,x1:x2]  # note, matFA and matFAOut are the same size
        # matGT = ctnp[:,y1:y2,x1:x2]
        matFA = mrnp
        matGT = ctnp

        if opt.resType == 2:
            matOut, matProb, _ = testOneSubject(matFA,
                                                matGT,
                                                opt.out_channels,
                                                opt.input_sz,
                                                opt.output_sz,
                                                opt.test_step_sz,
                                                netG,
                                                opt.modelPath,
                                                resType=opt.resType,
                                                nd=opt.NDim)
        else:
            matOut, _ = testOneSubject(matFA,
                                       matGT,
                                       opt.out_channels,
                                       opt.input_sz,
                                       opt.output_sz,
                                       opt.test_step_sz,
                                       netG,
                                       opt.modelPath,
                                       resType=opt.resType,
                                       nd=opt.NDim)

        #matOut,_ = testOneSubject(matFA,matGT,opt.out_channels,opt.input_sz, opt.output_sz, opt.test_step_sz,netG,opt.modelPath, nd = opt.NDim)
        ct_estimated = np.zeros([ctnp.shape[0], ctnp.shape[1], ctnp.shape[2]])
        ct_prob = np.zeros(
            [opt.out_channels, ctnp.shape[0], ctnp.shape[1], ctnp.shape[2]])
        #         print 'matOut shape: ',matOut.shape
        #         ct_estimated[:,y1:y2,x1:x2] = matOut
        ct_estimated = matOut
        ct_prob = matProb
        matProb_Bladder = np.squeeze(ct_prob[1, :, :, :])
        matProb_Prostate = np.squeeze(ct_prob[2, :, :, :])
        matProb_Rectum = np.squeeze(ct_prob[3, :, :, :])

        threshold = 0.9
        tmat_prob = np.zeros(matProb_Bladder.shape)
        tmat = np.zeros(matProb_Bladder.shape)

        #for bladder
        inds1 = np.where(matProb_Bladder > threshold)
        tmat_prob[inds1] = matProb_Bladder[inds1]
        tmat[inds1] = 1
        #for prostate
        inds2 = np.where(matProb_Prostate > threshold)
        tmat_prob[inds2] = matProb_Prostate[inds2]
        tmat[inds2] = 2
        #for rectum
        inds3 = np.where(matProb_Rectum > threshold)
        tmat_prob[inds3] = matProb_Rectum[inds3]
        tmat[inds3] = 3

        tmat = denoiseImg_closing(tmat, kernel=np.ones((20, 20, 20)))
        tmat = denoiseImg_isolation(tmat, struct=np.ones((3, 3, 3)))
        diceBladder = dice(tmat, ctnp, 1)
        diceProstate = dice(tmat, ctnp, 2)
        diceRectum = dice(tmat, ctnp, 3)
        print 'sub%d' % ind, 'dice1 = ', diceBladder, ' dice2= ', diceProstate, ' dice3= ', diceRectum

        #         volout = sitk.GetImageFromArray(matProb_Prostate)
        #         sitk.WriteImage(volout,opt.prefixPredictedFN+'probmap_Prostate_sub{:02d}'.format(ind)+'.nii.gz')

        volout = sitk.GetImageFromArray(tmat_prob)
        sitk.WriteImage(
            volout, opt.prefixPredictedFN +
            'threshold_probmap_sub{:02d}'.format(ind) + '.nii.gz')

        volout = sitk.GetImageFromArray(tmat)
        sitk.WriteImage(
            volout, opt.prefixPredictedFN +
            'threshold_segmap_sub{:02d}'.format(ind) + '.nii.gz')

        ct_estimated = np.rint(ct_estimated)
        ct_estimated = denoiseImg_closing(ct_estimated,
                                          kernel=np.ones((20, 20, 20)))
        ct_estimated = denoiseImg_isolation(ct_estimated,
                                            struct=np.ones((3, 3, 3)))

        #         diceProstate = dice(ct_estimated,ctnp,2)
        #         diceRectumm = dice(ct_estimated,ctnp,3)

        #         print 'pred: ',ct_estimated.dtype, ' shape: ',ct_estimated.shape
        #         print 'gt: ',ctnp.dtype,' shape: ',ctnp.shape
        #         print 'sub%d'%ind,'dice1 = ',diceBladder,' dice2= ',diceProstate,' dice3= ',diceRectumm
        diceBladder = dice(ct_estimated, ctnp, 1)
        diceProstate = dice(ct_estimated, ctnp, 2)
        diceRectum = dice(ct_estimated, ctnp, 3)
        print 'sub%d' % ind, 'dice1 = ', diceBladder, ' dice2= ', diceProstate, ' dice3= ', diceRectum
        volout = sitk.GetImageFromArray(ct_estimated)
        sitk.WriteImage(
            volout,
            opt.prefixPredictedFN + 'sub{:02d}'.format(ind) + '.nii.gz')
        volgt = sitk.GetImageFromArray(ctnp)
        sitk.WriteImage(volgt, 'gt_sub{:02d}'.format(ind) + '.nii.gz')
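# Unlike Example #1, the volumes written above are created with GetImageFromArray and
# therefore carry default spacing/origin/direction. If the output geometry should match
# the input scan, SimpleITK can copy it from the reference image, e.g. (sketch):
def write_with_geometry(arr, ref_itk_image, out_fn):
    vol = sitk.GetImageFromArray(arr)
    vol.CopyInformation(ref_itk_image)  # copies origin, spacing and direction
    sitk.WriteImage(vol, out_fn)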
Example #6
def main():    

########################################configs####################################
    global opt, model 
    opt = parser.parse_args()
    print opt
    print 'test my list, opt.input_sz: ',opt.input_sz
    given_weight = torch.FloatTensor([1,15]) #note, weights for each organ
    given_ids = torch.FloatTensor([0,1])
    given_weight = given_weight.cuda()
    given_ids = given_ids.cuda()
    path_test = '/home/dongnie/warehouse/mrs_data'
    path_test = '/shenlab/lab_stor5/dongnie/challengeData/data'
    path_test = '/home/dongnie/warehouse/pelvicSeg/prostateChallenge/data'
#     path_test = '/shenlab/lab_stor5/dongnie/pelvic'
    path_patients_h5 = '/home/dongnie/warehouse/BrainEstimation/brainH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegContourBatchH5'
    path_patients_h5 = '/shenlab/lab_stor5/dongnie/challengeData/pelvicSegRegContourBatchH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/prostateChallenge/pelvicSegRegContourBatchH5'
    path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/prostateChallenge/pelvic3DSegRegContourBatchH5'

#     path_patients_h5 = '/shenlab/lab_stor5/dongnie/pelvic/pelvicSegRegH5'
#     path_patients_h5 = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegPartH5/' #only contains 1-15
    path_patients_h5_test = '/home/dongnie/warehouse/pelvicSeg/pelvicSegRegContourH5Test'
    path_patients_h5_test = '/shenlab/lab_stor5/dongnie/challengeData/pelvicSegRegContourH5Test'
    path_patients_h5_test = '/home/dongnie/warehouse/pelvicSeg/prostateChallenge/pelvicSegRegContourH5Test'
    path_patients_h5_test = '/home/dongnie/warehouse/pelvicSeg/prostateChallenge/pelvic3DSegRegContourH5Test'
#     path_patients_h5_test ='/shenlab/lab_stor5/dongnie/pelvic/pelvicSegRegH5Test'
########################################configs####################################



    if opt.isSegReg:
        netG = ResSegRegNet(opt.in_channels, opt.out_channels, nd=opt.NDim)
    elif opt.isContourLoss:
        netG = ResSegContourNet(opt.in_channels, opt.out_channels, nd=opt.NDim, isRandomConnection=opt.isResidualEnhancement,isSmallDilation=opt.isViewExpansion, isSpatialDropOut=opt.isSpatialDropOut,dropoutRate=opt.dropoutRate)
    else:
        netG = ResSegNet(opt.in_channels, opt.out_channels, nd=opt.NDim, isRandomConnection=opt.isResidualEnhancement,isSmallDilation=opt.isViewExpansion, isSpatialDropOut=opt.isSpatialDropOut,dropoutRate=opt.dropoutRate)
    #netG.apply(weights_init)
    netG.cuda()
    
    if opt.isAdLoss:
        netD = Discriminator(1, nd=opt.NDim)
        netD.apply(weights_init)
        netD.cuda()
        optimizerD =optim.Adam(netD.parameters(),lr=opt.lr)
    
    params = list(netG.parameters())
    print('len of params is ')
    print(len(params))
    print('size of params is ')
    print(params[0].size())
    
    optimizerG =optim.Adam(netG.parameters(),lr=opt.lr)

    
    
    criterion_MSE = nn.MSELoss()
    
#     criterion_NLL2D = nn.NLLLoss2d(weight=given_weight)
    if opt.NDim==2:
        criterion_CEND = CrossEntropy2d(weight=given_weight)
    elif opt.NDim==3: 
        criterion_CEND = CrossEntropy3d(weight=given_weight)
    
    criterion_BCE2D = CrossEntropy2d()#for contours
    
#     criterion_dice = DiceLoss4Organs(organIDs=[1,2,3], organWeights=[1,1,1])
#     criterion_dice = WeightedDiceLoss4Organs()
    criterion_dice = myWeightedDiceLoss4Organs(organIDs=given_ids, organWeights = given_weight)
    
    criterion_focal = myFocalLoss(4, alpha=given_weight, gamma=2)
    
    criterion = nn.BCELoss()
    criterion = criterion.cuda()
    criterion_dice = criterion_dice.cuda()
    criterion_MSE = criterion_MSE.cuda()
    criterion_CEND = criterion_CEND.cuda()
    criterion_BCE2D = criterion_BCE2D.cuda()
    criterion_focal = criterion_focal.cuda()
    softmax2d = nn.Softmax2d()
#     inputs=Variable(torch.randn(1000,1,32,32)) #here should be tensor instead of variable
#     targets=Variable(torch.randn(1000,10,1,1)) #here should be tensor instead of variable
#     trainset=data_utils.TensorDataset(inputs, targets)
#     trainloader = data_utils.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
#     inputs=torch.randn(1000,1,32,32)
#     targets=torch.LongTensor(1000)
    


#     batch_size = 10
    if opt.NDim == 3:
        data_generator = Generator_3D_patches(path_patients_h5,opt.batchSize,inputKey='dataMR',outputKey='dataSeg')
        data_generator_test = Generator_3D_patches(path_patients_h5_test,opt.batchSize,inputKey='dataMR',outputKey='dataSeg')

    else:
        data_generator_test = Generator_2D_slices(path_patients_h5_test,opt.batchSize,inputKey='dataMR2D',outputKey='dataSeg2D')
        if opt.isSegReg:
            data_generator = Generator_2D_slices_variousKeys(path_patients_h5,opt.batchSize,inputKey='dataMR2D',outputKey='dataSeg2D',regKey1='dataBladder2D',regKey2='dataProstate2D',regKey3='dataRectum2D')
        elif opt.isContourLoss:
            data_generator = Generator_2D_slicesV1(path_patients_h5,opt.batchSize,inputKey='dataMR2D',segKey='dataSeg2D',contourKey='dataContour2D')
        else:
            data_generator = Generator_2D_slices(path_patients_h5,opt.batchSize,inputKey='dataMR2D',outputKey='dataSeg2D')
    
    
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            netG.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    
########### Better to use a dataloader to load large datasets, and to train for several epochs ###############
    running_loss = 0.0
    start = time.time()
    for iter in range(opt.start_epoch, opt.numofIters+1):
        #print('iter %d'%iter)
        
        if opt.isSegReg:
            inputs,labels, regGT1, regGT2, regGT3 = data_generator.next()
        elif opt.isContourLoss:
            inputs,labels,contours = data_generator.next()
        else:
            inputs,labels = data_generator.next()
            #print inputs.size,labels.size

        labels = np.squeeze(labels)
        labels = labels.astype(int)
        
        if opt.isContourLoss:
            contours = np.squeeze(contours)
            contours = contours.astype(int)
            contours = torch.from_numpy(contours)
            contours = contours.cuda()
            contours = Variable(contours)
        
        inputs = torch.from_numpy(inputs)
        labels = torch.from_numpy(labels)
        inputs = inputs.cuda()
        labels = labels.cuda()
        #we should consider different data to train
        
        #wrap them into Variable
        inputs,labels = Variable(inputs),Variable(labels)
        
        if opt.isAdLoss:
            #zero the parameter gradients
            #netD.zero_grad()
            #forward + backward +optimizer
            if opt.isSegReg:
                outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
            elif opt.isContourLoss:    
                outputG,_ = netG(inputs)
            else:
                outputG = netG(inputs)
                
                
            if opt.NDim==2:
                outputG = softmax2d(outputG) #batach
            elif opt.NDim==3:
                outputG = F.softmax(outputG, dim=1)
    
    
            outputG = outputG.data.max(1)[1]
            #outputG = torch.squeeze(outputG) #[N,C,W,H]
            labels = labels.unsqueeze(1) #expand the 1st dim
    #         print 'outputG: ',outputG.size(),'labels: ',labels.size()
            outputR = labels.type(torch.FloatTensor).cuda() #output_Real
            outputG = Variable(outputG.type(torch.FloatTensor).cuda())
            outputD_real = netD(outputR)
    #         print 'size outputG: ',outputG.unsqueeze(1).size()
            outputD_fake = netD(outputG.unsqueeze(1))
    
            
            ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z)))
            netD.zero_grad()
            batch_size = inputs.size(0)
            #print(inputs.size())
            #train with real data
    #         real_label = torch.FloatTensor(batch_size)
    #         real_label.data.resize_(batch_size).fill_(1)
            real_label = torch.ones(batch_size,1)
            real_label = real_label.cuda()
            #print(real_label.size())
            real_label = Variable(real_label)
            #print(outputD_real.size())
            loss_real = criterion(outputD_real,real_label)
            loss_real.backward()
            #train with fake data
            fake_label=torch.zeros(batch_size,1)
    #         fake_label = torch.FloatTensor(batch_size)
    #         fake_label.data.resize_(batch_size).fill_(0)
            fake_label = fake_label.cuda()
            fake_label = Variable(fake_label)
            loss_fake = criterion(outputD_fake,fake_label)
            loss_fake.backward()
            
            lossD = loss_real + loss_fake
    
            optimizerD.step()
        
        ## (2) update G network: minimize the L1/L2 loss, maximize D(G(x))

        # We want to fool the discriminator, so the generator output is labeled as "real" here.
        # Equivalently, this follows from the min-max objective (note the opposite roles of the
        # generator and discriminator in the same equation).
        if opt.isAdLoss:
            if opt.isSegReg:
                outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
            elif opt.isContourLoss:
                outputG,_ = netG(inputs)
            else:
                outputG = netG(inputs)
            outputG = outputG.data.max(1)[1]
            outputG = Variable(outputG.type(torch.FloatTensor).cuda())
#         print 'outputG shape, ',outputG.size()

            outputD = netD(outputG.unsqueeze(1))
            outputD = F.sigmoid(outputD)
            averProb = outputD.data.cpu().mean()
#             print 'prob: ',averProb
#             adImportance = computeAttentionWeight(averProb)
            adImportance = computeSampleAttentionWeight(averProb)       
            lossG_D = opt.lambdaAD * criterion(outputD, real_label) #note, for generator, the label for outputG is real
            lossG_D.backward(retain_graph=True)
        
        if opt.isSegReg:
            outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
        elif opt.isContourLoss: 
            outputG,outputContour = netG(inputs)
        else:
            outputG = netG(inputs) # not sure whether the forward pass needs to be run twice here
        netG.zero_grad()
        
        if opt.isFocalLoss:
            lossG_focal = criterion_focal(outputG,torch.squeeze(labels))
            lossG_focal.backward(retain_graph=True) #compute gradients
        
        if opt.isSoftmaxLoss:
            if opt.isSampleImportanceFromAd:
                lossG_G = (1+adImportance) * criterion_CEND(outputG,torch.squeeze(labels)) 
            else:
                lossG_G = criterion_CEND(outputG,torch.squeeze(labels)) 
                
            lossG_G.backward(retain_graph=True) #compute gradients
        
        if opt.isContourLoss:
            lossG_contour = criterion_BCE2D(outputContour,contours)
            lossG_contour.backward(retain_graph=True)
#         criterion_dice(outputG,torch.squeeze(labels))
#         print 'hahaN'
        if opt.isSegReg:
            lossG_Reg1 = criterion_MSE(outputReg1, regGT1)
            lossG_Reg2 = criterion_MSE(outputReg2, regGT2)
            lossG_Reg3 = criterion_MSE(outputReg3, regGT3)
            lossG_Reg = lossG_Reg1 + lossG_Reg2 + lossG_Reg3
            lossG_Reg.backward()

        if opt.isDiceLoss:
#             print 'isDiceLoss line278'
#             criterion_dice = myWeightedDiceLoss4Organs(organIDs=[0,1,2,3], organWeights=[1,4,8,6])
            if opt.isSampleImportanceFromAd:
                loss_dice = (1+adImportance) * criterion_dice(outputG,torch.squeeze(labels))
            else:
                loss_dice = criterion_dice(outputG,torch.squeeze(labels))
#             loss_dice = myDiceLoss4Organs(outputG,torch.squeeze(labels)) #succeed
#             loss_dice.backward(retain_graph=True) #compute gradients for dice loss
            loss_dice.backward() #compute gradients for dice loss
        
        #lossG_D.backward()

        
        #for other losses, we can define the loss function following the pytorch tutorial
        
        optimizerG.step() #update network parameters
#         print 'gradients of parameters****************************'
#         [x.grad.data for x in netG.parameters()]
#         print x.grad.data[0]
#         print '****************************'
        if opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg and opt.isFocalLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_G+loss_dice.data[0] + lossG_Reg + lossG_focal
        elif opt.isDiceLoss and opt.isSoftmaxLoss and opt.isSegReg and opt.isFocalLoss:
            lossG = lossG_G+loss_dice.data[0] + lossG_Reg + lossG_focal
        elif opt.isDiceLoss and opt.isFocalLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + lossG_focal + loss_dice.data[0] + lossG_Reg
        elif opt.isDiceLoss and opt.isFocalLoss and opt.isSegReg:
            lossG = lossG_focal + loss_dice.data[0] + lossG_Reg
        elif opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + lossG_G+loss_dice.data[0] + lossG_Reg
        elif opt.isDiceLoss and opt.isSoftmaxLoss and opt.isSegReg:
            lossG =  lossG_G+loss_dice.data[0] + lossG_Reg
        elif opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + lossG_G + lossG_Reg
        elif opt.isSoftmaxLoss and opt.isSegReg:
            lossG = lossG_G + lossG_Reg
        elif opt.isDiceLoss and opt.isAdLoss and opt.isSegReg:
            lossG = opt.lambdaAD * lossG_D + loss_dice.data[0] + lossG_Reg    
        elif opt.isDiceLoss and opt.isSegReg:
            lossG = loss_dice.data[0] + lossG_Reg    
        elif opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_G + loss_dice.data[0]
        elif opt.isDiceLoss and opt.isSoftmaxLoss:
            lossG = lossG_G + loss_dice.data[0]
        elif opt.isDiceLoss and opt.isFocalLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_focal + loss_dice.data[0]
        elif opt.isDiceLoss and opt.isFocalLoss:
            lossG = lossG_focal + loss_dice.data[0]      
        elif opt.isSoftmaxLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_G
        elif opt.isSoftmaxLoss:
            lossG = lossG_G
        elif opt.isFocalLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + lossG_focal  
        elif opt.isFocalLoss:
            lossG = lossG_focal      
        elif opt.isDiceLoss and opt.isAdLoss:
            lossG = opt.lambdaAD * lossG_D + loss_dice.data[0]
        elif opt.isDiceLoss:
            lossG = loss_dice.data[0]

        #print('loss for generator is %f'%lossG.data[0])
        #print statistics
        running_loss = running_loss + lossG.data[0]
#         print 'running_loss is ',running_loss,' type: ',type(running_loss)
        
#         print type(outputD_fake.cpu().data[0].numpy())
        
        if iter%opt.showTrainLossEvery==0: #print every 2000 mini-batches
            print '************************************************'
            print 'time now is: ' + time.asctime(time.localtime(time.time()))
            if opt.isAdLoss:
                print 'the outputD_real for iter {}'.format(iter), ' is ',outputD_real.cpu().data[0].numpy()[0]
                print 'the outputD_fake for iter {}'.format(iter), ' is ',outputD_fake.cpu().data[0].numpy()[0]
#             print 'running loss is ',running_loss
            print 'average running loss for generator between iter [%d, %d] is: %.3f'%(iter - 100 + 1,iter,running_loss/100)
            if opt.isAdLoss:
                print 'loss for discriminator at iter ',iter, ' is %f'%lossD.data[0]
            print 'total loss for generator at iter ',iter, ' is %f'%lossG.data[0]
            if opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss and opt.isSegReg:
                print 'lossG_D, lossG_G, loss_dice and lossG_Reg are %.2f, %.2f, %.2f and %.2f respectively.'%(lossG_D.data[0], lossG_G.data[0], loss_dice.data[0], lossG_Reg.data[0])
            elif opt.isDiceLoss and opt.isSoftmaxLoss and opt.isAdLoss:
                print 'lossG_D, lossG_G and loss_dice are %.2f, %.2f and %.2f respectively.'%(lossG_D.data[0], lossG_G.data[0], loss_dice.data[0])
            elif opt.isDiceLoss and opt.isSoftmaxLoss:
                print 'lossG_G and loss_dice are %.2f and %.2f respectively.'%(lossG_G.data[0], loss_dice.data[0])
            elif opt.isDiceLoss and opt.isFocalLoss and opt.isAdLoss:
                print 'lossG_D, lossG_focal and loss_dice are %.2f, %.2f and %.2f respectively.'%(lossG_D.data[0], lossG_focal.data[0], loss_dice.data[0])    
            elif opt.isSoftmaxLoss and opt.isAdLoss:
                print 'lossG_D and lossG_G are %.2f and %.2f respectively.'%(lossG_D.data[0], lossG_G.data[0])
            elif opt.isFocalLoss and opt.isAdLoss:
                print 'lossG_D and lossG_focal are %.2f and %.2f respectively.'%(lossG_D.data[0], lossG_focal.data[0])    
            elif opt.isDiceLoss and opt.isAdLoss:
                print 'lossG_D and loss_dice are %.2f and %.2f respectively.'%(lossG_D.data[0], loss_dice.data[0])
            
            if opt.isContourLoss:
                print 'lossG_contour is {}'.format(lossG_contour.data[0])

            print 'cost time for iter [%d, %d] is %.2f'%(iter - 100 + 1,iter, time.time()-start)
            print '************************************************'
            running_loss = 0.0
            start = time.time()
        if iter%opt.saveModelEvery==0: #save the model
            torch.save(netG.state_dict(), opt.prefixModelName+'%d.pt'%iter)
            print 'save model: '+opt.prefixModelName+'%d.pt'%iter
        
        if iter%opt.decLREvery==0 and iter>0:
            opt.lr = opt.lr*0.1
            adjust_learning_rate(optimizerG, opt.lr)
            print 'now the learning rate is {}'.format(opt.lr)
        
        if iter%opt.showTestPerformanceEvery==0: #test one subject  
            # to test on the validation dataset in the format of h5 
            inputs,labels = data_generator_test.next()
            labels = np.squeeze(labels)
            labels = labels.astype(int)
            inputs = torch.from_numpy(inputs)
            labels = torch.from_numpy(labels)
            inputs = inputs.cuda()
            labels = labels.cuda()
            inputs,labels = Variable(inputs),Variable(labels)
            if opt.isSegReg:
                outputG, outputReg1, outputReg2, outputReg3 = netG(inputs)
            elif opt.isContourLoss: 
                outputG,_ = netG(inputs)
            else:
                outputG = netG(inputs) # not sure whether the forward pass needs to be run twice here
            
            lossG_G = criterion_CEND(outputG,torch.squeeze(labels))
            loss_dice = criterion_dice(outputG,torch.squeeze(labels))
            print '.......come to validation stage: iter {}'.format(iter),'........'
            print 'lossG_G and loss_dice are %.2f and %.2f respectively.'%(lossG_G.data[0], loss_dice.data[0])
            
            ####release all the unoccupied memory####
            torch.cuda.empty_cache()

            mr_test_itk=sitk.ReadImage(os.path.join(path_test,'Case45.nii.gz'))
            ct_test_itk=sitk.ReadImage(os.path.join(path_test,'Case45_segmentation.nii.gz'))
            
            mrnp=sitk.GetArrayFromImage(mr_test_itk)
            mu=np.mean(mrnp)
    
            ctnp=sitk.GetArrayFromImage(ct_test_itk)
            
            #for training data in pelvicSeg
            if opt.how2normalize == 1:
                maxV, minV=np.percentile(mrnp, [99 ,1])
                print 'maxV,',maxV,' minV, ',minV
                mrnp=(mrnp-mu)/(maxV-minV)
                print 'unique value: ',np.unique(ctnp)

            #for training data in pelvicSeg
            elif opt.how2normalize == 2:
                maxV, minV = np.percentile(mrnp, [99 ,1])
                print 'maxV,',maxV,' minV, ',minV
                mrnp = (mrnp-mu)/(maxV-minV)
                print 'unique value: ',np.unique(ctnp)
            
            #for training data in pelvicSegRegH5
            elif opt.how2normalize== 3:
                std = np.std(mrnp)
                mrnp = (mrnp - mu)/std
                print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)
                
            elif opt.how2normalize== 4:
                maxV, minV = np.percentile(mrnp, [99.2 ,1])
                print 'maxV is: ',np.ndarray.max(mrnp)
                mrnp[np.where(mrnp>maxV)] = maxV
                print 'maxV is: ',np.ndarray.max(mrnp)
                mu=np.mean(mrnp)
                std = np.std(mrnp)
                mrnp = (mrnp - mu)/std
                print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)
    
#             full image version with average over the overlapping regions
#             ct_estimated = testOneSubject(mrnp,ctnp,[3,168,112],[1,168,112],[1,8,8],netG,'Segmentor_model_%d.pt'%iter)
            
            # the attention regions
#             x1=80
#             x2=192
#             y1=35
#             y2=235
#             matFA = mrnp[:,y1:y2,x1:x2] #note, matFA and matFAOut same size 
#             matGT = ctnp[:,y1:y2,x1:x2]
            matFA = mrnp
            matGT = ctnp
#                 volFA = sitk.GetImageFromArray(matFA)
#                 sitk.WriteImage(volFA,'volFA'+'.nii.gz')
#                 volGT = sitk.GetImageFromArray(matGT)
#                 sitk.WriteImage(volGT,'volGT'+'.nii.gz')
            
            matOut,_ = testOneSubject(matFA,matGT, opt.out_channels, opt.input_sz, opt.output_sz, opt.test_step_sz, netG,opt.prefixModelName+'%d.pt'%iter, nd=opt.NDim)
            ct_estimated = np.zeros([ctnp.shape[0],ctnp.shape[1],ctnp.shape[2]])
            print 'matOut shape: ',matOut.shape
#             ct_estimated[:,y1:y2,x1:x2] = matOut
            ct_estimated = matOut

            ct_estimated = np.rint(ct_estimated) 
            ct_estimated = denoiseImg_closing(ct_estimated, kernel=np.ones((20,20,20)))   
            ct_estimated = denoiseImg_isolation(ct_estimated, struct=np.ones((3,3,3)))
            diceBladder = dice(ct_estimated,ctnp,1)
#             diceProstate = dice(ct_estimated,ctnp,2)
#             diceRectumm = dice(ct_estimated,ctnp,3)
            
            print 'pred: ',ct_estimated.dtype, ' shape: ',ct_estimated.shape
            print 'gt: ',ctnp.dtype,' shape: ',ctnp.shape
            #print 'dice1 = ',diceBladder,' dice2= ',diceProstate,' dice3= ',diceRectumm
            print 'dice1 = ',diceBladder
            volout = sitk.GetImageFromArray(ct_estimated)
            sitk.WriteImage(volout,opt.prefixPredictedFN+'{}'.format(iter)+'.nii.gz')    
#             netG.save_state_dict('Segmentor_model_%d.pt'%iter)
#             netD.save_state_dic('Discriminator_model_%d.pt'%iter)
        
    print('Finished Training')
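# adjust_learning_rate() is called above but not defined in this snippet. A minimal
# sketch of the usual PyTorch pattern, assuming it simply writes the new rate into
# every parameter group of the optimizer:
def adjust_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr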