Example #1
0
    def compileValFunction(self):
        """Compile and return the Theano function used for validation.

        The compiled function takes values for ``self.inputVar`` and
        ``self.targetVar`` and returns ``[loss, accuracy]``. The loss is the
        mean categorical cross-entropy of the network output plus the L2
        norm of the network parameters scaled by ``self.weightDecay``.

        Side effect: ``self.flattenedTargetVar`` is reassigned to the
        flattened ``self.targetVar``.
        """
        self.logger.info(logMessage('+', 'Compiling the Validation Function'))
        tic = time.time()

        # Deterministic forward pass; both batch-norm average switches are
        # explicitly turned off.
        probs = get_output(self.outputLayer,
                           deterministic=True,
                           batch_norm_update_averages=False,
                           batch_norm_use_averages=False)

        # TODO: check whether targetVar and the output are flattened in the
        # same order.
        flatTargets = T.flatten(self.targetVar)
        self.flattenedTargetVar = flatTargets

        # Data term first, then the weight-decay penalty on top.
        dataLoss = categorical_crossentropy(probs, flatTargets).mean()
        l2Penalty = regularize_network_params(self.outputLayer,
                                              lasagne.regularization.l2)
        totalLoss = dataLoss + self.weightDecay * l2Penalty

        # Accuracy: fraction of positions whose arg-max class equals the target.
        hits = T.eq(T.argmax(probs, axis=1), flatTargets)
        accuracy = T.mean(hits, dtype=theano.config.floatX)

        valFunc = theano.function([self.inputVar, self.targetVar],
                                  [totalLoss, accuracy])

        doneMessage = 'Compiled the Validation Function, spent {:.2f}s'.format(
            time.time() - tic)
        self.logger.info(logMessage('+', doneMessage))

        return valFunc
Example #2
0
    def complieTrainFunction(self):
        """Compile and return the Theano function used for training.

        The compiled function takes values for ``self.inputVar`` and
        ``self.targetVar``, applies the parameter updates produced by
        ``self.optimizer`` and returns ``[loss, accuracy]``. The loss is the
        mean categorical cross-entropy of the non-deterministic network
        output plus the L2 norm of the network parameters scaled by
        ``self.weightDecay``.

        Side effect: ``self.flattenedTargetVar`` is reassigned to the
        flattened ``self.targetVar``.
        """
        self.logger.info(logMessage('+', 'Compiling the Training Function'))
        tic = time.time()

        # Stochastic forward pass; both batch-norm average switches are
        # explicitly turned off.
        probs = get_output(self.outputLayer,
                           deterministic=False,
                           batch_norm_update_averages=False,
                           batch_norm_use_averages=False)

        # TODO: check whether targetVar and the output are flattened in the
        # same order.
        flatTargets = T.flatten(self.targetVar)
        self.flattenedTargetVar = flatTargets

        # Data term first, then the weight-decay penalty on top.
        dataLoss = categorical_crossentropy(probs, flatTargets).mean()
        l2Penalty = regularize_network_params(self.outputLayer,
                                              lasagne.regularization.l2)
        objective = dataLoss + self.weightDecay * l2Penalty

        # Accuracy: fraction of positions whose arg-max class equals the target.
        hits = T.eq(T.argmax(probs, axis=1), flatTargets)
        accuracy = T.mean(hits, dtype=theano.config.floatX)

        # Update rules for every trainable parameter of the network.
        trainableParams = get_all_params(self.outputLayer, trainable=True)
        paramUpdates = self.optimizer(objective, trainableParams,
                                      learning_rate=self.learningRate)

        trainFunc = theano.function([self.inputVar, self.targetVar],
                                    [objective, accuracy],
                                    updates=paramUpdates)

        doneMessage = 'Compiled the Training Function, spent {:.2f}s'.format(
            time.time() - tic)
        self.logger.info(logMessage('+', doneMessage))

        return trainFunc
Example #3
0
    def saveWeights(self, fileName):
        """Store the current network parameter values in an ``.npz`` archive.

        Args:
            fileName: name of the archive; it is joined onto
                ``self.weightsFolder`` to form the destination path.

        The values are passed to ``np.savez`` as positional arrays.
        """
        self.logger.info(logMessage('+', 'Save Weights in {}'.format(fileName)))

        destination = os.path.join(self.weightsFolder, fileName)
        paramValues = get_all_param_values(self.outputLayer)
        np.savez(destination, *paramValues)
Example #4
0
    def compileTestFunction(self):
        """Compile and return the Theano function used for prediction.

        The compiled function takes a value for ``self.inputVar`` and returns
        a one-element list holding the arg-max (along axis 1) of the
        deterministic network output.
        """
        self.logger.info(logMessage('+', 'Compiling the Test Function'))
        tic = time.time()

        # Deterministic forward pass with batch-norm averages switched off.
        probs = get_output(self.outputLayer,
                           deterministic=True,
                           batch_norm_use_averages=False)
        predictedLabels = T.argmax(probs, axis=1)

        testFunc = theano.function([self.inputVar], [predictedLabels])

        doneMessage = 'Compiled the Test Function, spent {:.2f}s'.format(
            time.time() - tic)
        self.logger.info(logMessage('+', doneMessage))

        return testFunc
Example #5
0
def generateNetwork(configFile):
    """Instantiate the network class selected by a configuration file.

    The configuration file is executed with ``execfile`` and must define
    ``networkType``, ``preTrainedWeights`` and ``networkName``.

    Args:
        configFile: path of the configuration file; it is also handed
            unchanged to the network class constructor.

    Returns:
        The constructed network object.

    Raises:
        KeyError: if a required configuration entry is missing or
            ``networkType`` is not one of the known types.
        AssertionError: if the created network reports a different
            ``networkType`` than the one requested.
    """
    logger = logging.getLogger(__name__)

    # Load the configuration values into a plain dict.
    settings = {}
    execfile(configFile, settings)

    # Known network types; the entries must correspond to the network
    # classes imported by this module.
    availableNetworks = {
        'sectorNet': SectorNet,
        'baseNet': BaseNet,
        'dilated3DNet': Dilated3DNet,
    }
    requestedType = settings['networkType']
    networkCls = availableNetworks[requestedType]

    # Tell the user whether pre-trained weights are configured.
    if settings['preTrainedWeights'] == '':
        originNote = 'We will create a new network'
    else:
        originNote = 'We will use a pre trained network'
    logger.info(logMessage(' ', originNote))

    # Build the network and check it is of the requested type.
    logger.info(logMessage('#', 'Creating {}'.format(settings['networkName'])))
    network = networkCls(configFile)
    assert network.networkType == requestedType
    logger.info(logMessage('#', 'Created {}'.format(requestedType)))

    return network
Example #6
0
    def restoreWeights(self):
        """Load pre-trained parameter values into the network, if configured.

        Does nothing when ``self.preTrainedWeights`` is the empty string.
        Otherwise the archive at that path is opened with ``np.load`` and its
        ``arr_0 .. arr_{n-1}`` entries are read in index order and assigned
        to the parameters of ``self.outputLayer``.
        """
        weightsPath = self.preTrainedWeights
        if weightsPath == '':
            return

        self.logger.info(
            logMessage('+', 'Load Weights from {}'.format(weightsPath)))

        # Read the arrays by their numbered keys so the parameter order is
        # kept.
        paramValues = []
        with np.load(weightsPath) as archive:
            for position in range(len(archive.files)):
                paramValues.append(archive['arr_{:d}'.format(position)])

        set_all_param_values(self.outputLayer, paramValues)
Example #7
0
    def buildBaseNet(self, inputShape=(None, 4, 25, 25, 25), forSummary=False):
        """Assemble the BaseNet layer graph and record a textual summary.

        The network is a stack of 3D convolutions, each followed by batch
        normalisation and PReLU and then concatenated (along axis 1) with its
        own input, finished by a 1x1x1 classification convolution whose
        output is reshaped to a matrix and passed through a softmax.

        Args:
            inputShape: shape used for the shape columns of the summary and
                for deriving the receptive field. The ``InputLayer`` itself
                is created from ``self.inputShape``, not from this argument.
            forSummary: when True, the progress log messages are suppressed.

        Returns:
            The final softmax ``NonlinearityLayer``.

        Side effects:
            Sets ``self.receptiveField`` and ``self._summary``.

        Raises:
            AssertionError: if the last kernel shape is not ``[1, 1, 1]`` or
                the receptive field differs between the last three axes.
        """

        if not forSummary:
            message = 'Building the Architecture of BaseNet'
            self.logger.info(logMessage('+', message))

        baseNet = InputLayer(self.inputShape, self.inputVar)

        if not forSummary:
            message = 'Building the convolution layers'
            self.logger.info(logMessage('-', message))

        # Number of convolution layers, including the final classification
        # layer (one kernel count per layer).
        kernelShapeListLen = len(self.kernelNumList)

        # Header of the summary table.
        summary = '\n' + '.' * 130 + '\n'
        summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
            'Layer', 'Input shape', 'W shape', 'Output shape')
        summary += '.' * 130 + '\n'

        summary += '{:<3} {:<15} {:<50} {:<29} {:<29}\n'.format(
            1, 'Input', inputShape, '',
            get_output_shape(baseNet, input_shapes=inputShape))

        # All layers except the last one are feature-extraction blocks.
        for i in xrange(kernelShapeListLen - 1):

            kernelShape = self.kernelShapeList[i]
            kernelNum = self.kernelNumList[i]

            conv3D = Conv3DLayer(incoming=baseNet,
                                 num_filters=kernelNum,
                                 filter_size=kernelShape,
                                 W=HeNormal(gain='relu'),
                                 nonlinearity=linear,
                                 name='Conv3D{}'.format(i))

            # Only needed to report the filter shape in the summary.
            WShape = conv3D.W.get_value().shape

            summary += '{:<3} {:<15} {:<50} {:<29} {:<29}\n'.format(
                i + 2, 'Conv3D',
                get_output_shape(baseNet, input_shapes=inputShape), WShape,
                get_output_shape(conv3D, input_shapes=inputShape))

            batchNormLayer = BatchNormLayer(conv3D)
            preluLayer = prelu(batchNormLayer)

            concatLayerInputShape = '{:<25}{:<25}'.format(
                get_output_shape(conv3D, input_shapes=inputShape),
                get_output_shape(baseNet, input_shapes=inputShape))

            # Concatenate the block output with the block input along axis 1.
            # Axes 0, 2, 3 and 4 are center-cropped to the smaller input;
            # axis 1 is left uncropped. The entry for axis 1 must be the None
            # sentinel: the string 'None' is not a valid Lasagne crop mode.
            baseNet = ConcatLayer(
                [preluLayer, baseNet],
                1,
                cropping=['center', None, 'center', 'center', 'center'])

            summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
                'Concat', concatLayerInputShape, '',
                get_output_shape(baseNet, input_shapes=inputShape))
        if not forSummary:
            message = 'Finish Built the convolution layers'
            self.logger.info(logMessage('-', message))

            message = 'Building the last classfication layers'
            self.logger.info(logMessage('-', message))

        # The classification layer must be a 1x1x1 convolution.
        assert self.kernelShapeList[-1] == [1, 1, 1]

        kernelShape = self.kernelShapeList[-1]
        kernelNum = self.kernelNumList[-1]

        conv3D = Conv3DLayer(incoming=baseNet,
                             num_filters=kernelNum,
                             filter_size=kernelShape,
                             W=HeNormal(gain='relu'),
                             nonlinearity=linear,
                             name='Classfication Layer')

        # Receptive field per axis: input size minus output size plus one,
        # evaluated on the last three axes. It must be the same on all three.
        receptiveFieldList = [
            inputShape[idx] -
            get_output_shape(conv3D, input_shapes=inputShape)[idx] + 1
            for idx in xrange(-3, 0)
        ]
        assert receptiveFieldList != []
        receptiveFieldSet = set(receptiveFieldList)
        assert len(receptiveFieldSet) == 1, (receptiveFieldSet, inputShape,
                                             get_output_shape(
                                                 conv3D,
                                                 input_shapes=inputShape))
        self.receptiveField = list(receptiveFieldSet)[0]

        # Only needed to report the filter shape in the summary.
        WShape = conv3D.W.get_value().shape

        summary += '{:<3} {:<15} {:<50} {:<29} {:<29}\n'.format(
            kernelShapeListLen + 1, 'Conv3D',
            get_output_shape(baseNet, input_shapes=inputShape), WShape,
            get_output_shape(conv3D, input_shapes=inputShape))

        # The output shape should be (batchSize, numOfClasses, zSize, xSize, ySize).
        # We will reshape it to (batchSize * zSize * xSize * ySize, numOfClasses),
        # because the softmax in theano can only receive a matrix.

        baseNet = DimshuffleLayer(conv3D, (0, 2, 3, 4, 1))
        summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
            'Dimshuffle', get_output_shape(conv3D, input_shapes=inputShape),
            '', get_output_shape(baseNet, input_shapes=inputShape))

        # Symbolic sizes of the dim-shuffled output, used for the reshape.
        batchSize, zSize, xSize, ySize, _ = get_output(baseNet).shape
        reshapeLayerInputShape = get_output_shape(baseNet,
                                                  input_shapes=inputShape)
        baseNet = ReshapeLayer(baseNet,
                               (batchSize * zSize * xSize * ySize, kernelNum))
        summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
            'Reshape', reshapeLayerInputShape, '',
            get_output_shape(baseNet, input_shapes=inputShape))

        nonlinearityLayerInputShape = get_output_shape(baseNet,
                                                       input_shapes=inputShape)
        baseNet = NonlinearityLayer(baseNet, softmax)
        summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
            'Nonlinearity', nonlinearityLayerInputShape, '',
            get_output_shape(baseNet, input_shapes=inputShape))

        if not forSummary:
            message = 'Finish Built the last classfication layers'
            self.logger.info(logMessage('-', message))

            message = 'The Receptivr Field of BaseNet equal {}'.format(
                self.receptiveField)
            self.logger.info(logMessage('*', message))

            message = 'Finish Building the Architecture of BaseNet'
            self.logger.info(logMessage('+', message))

        summary += '.' * 130 + '\n'
        self._summary = summary

        return baseNet
Example #8
0
def trainNetwork(network, configFile):

    logger = logging.getLogger(__name__)

    message = 'Training {}'.format(network.networkType)
    logger.info(logMessage('#', message))

    # Get config infomation
    configInfo = {}
    execfile(configFile, configInfo)

    # Network information
    # ==============================================================================
    # Just for rebuild the network than we can get the network summary conresponding
    # the trainSampleSize
    # Read network information
    trainSampleSize = configInfo['trainSampleSize']
    networkType = network.networkType
    receptiveField = network.receptiveField
    networkSummary = network.summary(trainSampleSize)
    # ------------------------------------------------------------------------------
    # Logger network summary
    message = 'Network Summary'
    logger.info(logMessage('*', message))
    logger.info(networkSummary)
    logger.info(logMessage('-', '-'))
    tableRowList = []
    tableRowList.append(['-', '-'])
    tableRowList.append(['Network Type', networkType])
    tableRowList.append(['Receptive Field', receptiveField])
    tableRowList.append(['-', '-'])

    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))
    # =============================================================================

    # Training and validation data information
    # =============================================================================
    # Read training and validation data information
    imageFolder = configInfo['imageFolder']
    imageGrades = configInfo['imageGrades']
    numOfPatients = configInfo['numOfPatients']
    modals = configInfo['modals']
    useROI = configInfo['useROI']
    normType = configInfo['normType']
    weightMapType = configInfo['weightMapType']
    # ----------------------------------------------------------------------------
    # Logger training and validation data information
    message = 'Training and Validation Data Summary'
    logger.info(logMessage('*', message))

    tableRowList = []
    tableRowList.append(['-', '-'])
    tableRowList.append(['Image Folder', imageFolder])
    tableRowList.append(['Image Grades', imageGrades])
    tableRowList.append(['Number of Patients', numOfPatients])
    tableRowList.append(['Modals', modals])
    tableRowList.append(['Use ROI', useROI])
    tableRowList.append(['Normalization Type', normType])
    tableRowList.append(['Weight Map Type', weightMapType])
    tableRowList.append(['-', '-'])

    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))
    # ===========================================================================

    # Training and validation setting infomation
    # ===========================================================================
    # Read training and validation setting information
    trainValRatio = configInfo['trainValRatio']
    memoryThreshold = configInfo['memoryThreshold']
    usePoolToSample = configInfo['usePoolToSample']
    numOfEpochs = configInfo['numOfEpochs']
    numOfSubEpochs = configInfo['numOfSubEpochs']
    batchSize = configInfo['batchSize']
    trainSampleSize = configInfo['trainSampleSize']
    valSampleSize = configInfo['valSampleSize']
    numOfTrainSamplesPerSubEpoch = configInfo['numOfTrainSamplesPerSubEpoch']
    numOfValSamplesPerSubEpoch = int(
        float(numOfTrainSamplesPerSubEpoch) / trainValRatio)
    weightsFolder = configInfo['weightsFolder']
    assert batchSize < numOfTrainSamplesPerSubEpoch
    assert batchSize < numOfValSamplesPerSubEpoch
    # ---------------------------------------------------------------------------
    # Logger training and validation setting infomation
    message = 'Training and Validation setting Summary'
    logger.info(logMessage('*', message))

    tableRowList = []
    tableRowList.append(['-', '-'])
    tableRowList.append(['Training / validation', trainValRatio])
    tableRowList.append(
        ['Memory Threshold for Subepoch', '{}G'.format(memoryThreshold)])
    tableRowList.append(
        ['Wheather Use MultiProcess to Sample', usePoolToSample])
    tableRowList.append(['Number of Epochs', numOfEpochs])
    tableRowList.append(['Number of Subepochs', numOfSubEpochs])
    tableRowList.append(['Batch Size', batchSize])
    tableRowList.append(['Training Samples Size', trainSampleSize])
    tableRowList.append(['Validation Samples Size', valSampleSize])
    tableRowList.append([
        'Number of Training Samples for Subepoch', numOfTrainSamplesPerSubEpoch
    ])
    tableRowList.append([
        'Number of Validation Samples for Subepoch', numOfValSamplesPerSubEpoch
    ])
    tableRowList.append(
        ['Folder to Store Weights During Training', weightsFolder])
    tableRowList.append(['-', '-'])

    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))
    # ===========================================================================

    # Prepare patients file dir list
    # ===========================================================================
    # Read patients dir list
    patientsDirList = []
    gradeDirList = [os.path.join(imageFolder, grade) for grade in imageGrades]
    for gradeDir in gradeDirList:
        patientsDirList += [
            os.path.join(gradeDir, patient) for patient in os.listdir(gradeDir)
        ]

    # Make sure there are no same elements in the patientsDirList
    assert len(patientsDirList) == len(set(patientsDirList))
    random.shuffle(patientsDirList)
    patientsDirList = patientsDirList[:numOfPatients]
    # ---------------------------------------------------------------------------
    # Divide patients dir in two part according to trainValRatio
    patsDirForValList = patientsDirList[::trainValRatio + 1]
    patsDirForTrainList = [
        patsDir for patsDir in patientsDirList
        if patsDir not in patsDirForValList
    ]
    assert len(patsDirForValList) + len(patsDirForTrainList) == len(
        patientsDirList)

    message = 'We get {} patients files. '.format(len(patientsDirList))
    message += 'Randomly choose {} for training '.format(
        len(patsDirForTrainList))
    message += 'and {} for validation'.format(len(patsDirForValList))
    logger.info(logMessage(' ', message))
    # ---------------------------------------------------------------------------
    # Prepare training patients dir list for each subepoch, because the
    # training data may be so large and need to be split for each subepoch
    numOfModals = len(modals)
    memoryNeededPerPatData = 0.035 * numOfModals + int(useROI) * 0.020
    # memortThreshold should large than a single patient need memory
    assert memoryThreshold > memoryNeededPerPatData

    maxPatNumPerSubEpoch = math.floor(memoryThreshold / memoryNeededPerPatData)
    maxPatNumPerSubEpoch = int(maxPatNumPerSubEpoch)
    patDirPerSubEpochDict = {}
    # If len(patsDirForTrainList) < maxPatNumPerSubEpoch,
    # each will use same patients dir
    for subEpIdx in xrange(numOfSubEpochs):
        chosenPatsDir = patsDirForTrainList[:maxPatNumPerSubEpoch]
        patDirPerSubEpochDict[subEpIdx] = chosenPatsDir
        random.shuffle(patsDirForTrainList)
    # ===========================================================================

    # Prepare a table to record and show training and validation results
    # ===========================================================================
    trainTRowList = []
    trainTRowList.append(['-', '-', '-', '-', '-', '-'])
    trainTRowList.append([
        'EPOCH', 'SUBEPOCH', 'Train Time', 'Train Loss', 'Train ACC',
        'Sampling Time'
    ])
    trainTRowList.append(['-', '-', '-', '-', '-', '-'])
    # ***************************************************************************
    valTRowList = []
    valTRowList.append(
        ['-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'])
    valTRowList.append([
        'EPOCH',
        'SUBEPOCH',
        'Val Time',
        'CT Dice',
        'CT Sens',
        'CT Spec',
        'Core Dice',
        'Core Sens',
        'Core Spec',
        'Eh Dice',
        'Eh Sens',
        'Eh Spec',
    ])
    valTRowList.append(
        ['-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'])
    # ===========================================================================

    # Prepare folder to store network weights during training
    # ===========================================================================
    storeTime = time.strftime('%y-%m-%d_%H:%M:%S')
    weightsDir = os.path.join(weightsFolder, str(storeTime))
    os.mkdir(weightsDir)
    # ===========================================================================

    # Train and Val
    # ##################################################################################################
    for epIdx in xrange(numOfEpochs):

        message = 'EPOCH: {}/{}'.format(epIdx + 1, numOfEpochs)
        logger.info(logMessage('+', message))

        # Initial some epoch recordor
        # ==============================================================================================
        trainEpLoss = 0
        trainEpACC = 0
        trainEpBatchNum = 0
        # **********************************************************************************************
        valEpCTDice = 0
        valEpCTSens = 0
        valEpCTSpec = 0

        valEpCoreDice = 0
        valEpCoreSens = 0
        valEpCoreSpec = 0

        valEpEnDice = 0
        valEpEnSens = 0
        valEpEnSpec = 0

        valEpPatsNum = 0
        # **********************************************************************************************
        epTrainSampleTime = 0
        epTrainTime = 0
        epValTime = 0
        # ==============================================================================================

        for subEpIdx in xrange(numOfSubEpochs):

            message = 'SUBEPOCH: {}/{}'.format(subEpIdx + 1, numOfSubEpochs)
            logger.info(logMessage('-', message))

            # Training
            # ==========================================================================================
            # Train sample training data
            trainSampleTime = time.time()
            message = 'Sampling Training Data'
            logger.info(logMessage('-', message))
            trainSampleAndLabelList = getSamplesForSubEpoch(
                numOfTrainSamplesPerSubEpoch, patDirPerSubEpochDict[subEpIdx],
                useROI, modals, normType, trainSampleSize, receptiveField,
                weightMapType, usePoolToSample)

            trainSamplesList, trainLabelsList = trainSampleAndLabelList
            trainSampleTime = time.time() - trainSampleTime
            epTrainSampleTime += trainSampleTime
            # -----------------------------------------------------------------------------------------
            # Prepare for train batch loop
            trainBatchIdxList = [
                trainBatchIdx for trainBatchIdx in xrange(
                    0, numOfTrainSamplesPerSubEpoch, batchSize)
            ]

            # For the last batch not to be too small.
            trainBatchIdxList[-1] = numOfTrainSamplesPerSubEpoch
            assert len(trainBatchIdxList) > 1
            trainBatchNum = len(trainBatchIdxList[:-1])
            # -----------------------------------------------------------------------------------------
            # Training batch loop
            trainSubEpLoss = 0
            trainSubEpACC = 0
            trainSubEpBatchNum = 0
            trainTime = time.time()
            message = 'Training'
            logger.info(logMessage(':', message))
            # ........................................................................................
            for trainBatchIdx in xrange(trainBatchNum):
                # Just for clear.
                trainStartIdx = trainBatchIdxList[trainBatchIdx]
                trainEndIdx = trainBatchIdxList[trainBatchIdx + 1]

                trainSamplesBatch = trainSamplesList[trainStartIdx:trainEndIdx]
                trainSamplesBatch = np.asarray(trainSamplesBatch,
                                               dtype=theano.config.floatX)

                trainLabelsBatch = trainLabelsList[trainStartIdx:trainEndIdx]
                trainLabelsBatch = np.asarray(trainLabelsBatch, dtype='int32')

                trainBatchLoss, trainBatchAcc = network.trainFunction(
                    trainSamplesBatch, trainLabelsBatch)
                # Record subepoch training results.
                trainSubEpLoss += trainBatchLoss
                trainSubEpACC += trainBatchAcc
                trainSubEpBatchNum += 1
            trainTime = time.time() - trainTime
            epTrainTime += trainTime
            # ........................................................................................
            # Release source
            del trainSamplesList[:], trainLabelsList[:]
            del trainSamplesList, trainLabelsList
            del trainSamplesBatch, trainLabelsBatch
            gc.collect()
            # ========================================================================================

            # Validation
            # ========================================================================================
            valSubEpCTDice = 0
            valSubEpCTSens = 0
            valSubEpCTSpec = 0

            valSubEpCoreDice = 0
            valSubEpCoreSens = 0
            valSubEpCoreSpec = 0

            valSubEpEhDice = 0
            valSubEpEhSens = 0
            valSubEpEhSpec = 0

            valSubEpPatsNum = 0
            valTime = time.time()
            message = 'Validation'
            logger.info(logMessage(':', message))
            for patIdx, patientDir in enumerate(patsDirForValList):
                logger.info('Val {}/{} patient'.format(patIdx + 1,
                                                       len(patsDirForValList)))
                segmentResult, segmentResultMask, gTArray = segmentWholeBrain(
                    network, patientDir, useROI, modals, normType,
                    valSampleSize, receptiveField, False, batchSize)
                assert gTArray != []
                cTDice, cTSens, cTSpeci = voxleWiseMetrics(
                    segmentResult, gTArray, [1, 2, 3, 4])
                coreDice, cTSens, cTSpec = voxleWiseMetrics(
                    segmentResult, gTArray, [1, 3, 4])
                ehDice, ehSens, ehSpec = voxleWiseMetrics(
                    segmentResult, gTArray, [4])
                valSubEpCTDice += cTDice
                valSubEpCTSens += cTSens
                valSubEpCTSpec += cTSpeci

                valSubEpCoreDice += coreDice
                valSubEpCoreSens += cTSens
                valSubEpCoreSpec += cTSpec

                valSubEpEhDice += ehDice
                valSubEpEhSens += ehSens
                valSubEpEhSpec += ehSpec

                del segmentResult, segmentResultMask, gTArray
                gc.collect()

            valSubEpPatsNum = len(patientDir)
            valTime = time.time() - valTime
            epValTime += valTime
            # =========================================================================================

            # Record epooch results and compute subepoch results
            # =========================================================================================
            # Record epooch training results and compute training subepoch results
            # -----------------------------------------------------------------------------------------
            # Record training epoch results
            trainEpLoss += trainSubEpLoss
            trainEpACC += trainSubEpACC
            trainEpBatchNum += trainSubEpBatchNum
            # -----------------------------------------------------------------------------------------
            # Compute training subepoch results
            trainSubEpLoss /= trainSubEpBatchNum
            trainSubEpACC /= trainSubEpBatchNum
            # *****************************************************************************************
            # Record epooch validation results and computer validation subepoch results
            # -----------------------------------------------------------------------------------------
            # Record validation epoch results
            valEpCTDice += valSubEpCTDice
            valEpCTSens += valSubEpCTSens
            valEpCTSpec += valSubEpCTSpec

            valEpCoreDice += valSubEpCoreDice
            valEpCoreSens += valSubEpCoreSens
            valEpCoreSpec += valSubEpCoreSpec

            valEpEnDice += valSubEpEhDice
            valEpEnSens += valSubEpEhSens
            valEpEnSpec += valSubEpEhSpec

            valEpPatsNum += valSubEpPatsNum
            # -----------------------------------------------------------------------------------------
            # Compute validation subepoch results
            valSubEpCTDice /= valSubEpPatsNum
            valSubEpCTSens /= valSubEpPatsNum
            valSubEpCTSpec /= valSubEpPatsNum

            valSubEpCoreDice /= valSubEpPatsNum
            valSubEpCoreSens /= valSubEpPatsNum
            valSubEpCoreSpec /= valSubEpPatsNum

            valSubEpEhDice /= valSubEpPatsNum
            valSubEpEhSens /= valSubEpPatsNum
            valSubEpEhSpec /= valSubEpPatsNum
            # =========================================================================================

            # Recording for subEpoch row of table
            # =========================================================================================
            indexColumn = epIdx + 1 if subEpIdx == 0 else ''
            # Recording for subEpoch row of training table
            # -----------------------------------------------------------------------------------------
            trainSampleTime = '{:.3}'.format(trainSampleTime)
            trainTime = '{:.3}'.format(trainTime)
            trainSubEpLoss = '{:.6f}'.format(trainSubEpLoss)
            trainSubEpACC = '{:.6f}'.format(trainSubEpACC)

            trainTRowList.append([
                indexColumn, subEpIdx + 1, trainTime, trainSubEpLoss,
                trainSubEpACC, trainSampleTime
            ])
            # *****************************************************************************************
            # Recording for subepoch row of validation table
            valTime = '{:.3}'.format(valTime)
            valSubEpCTDice = '{:.4f}'.format(valSubEpCTDice)
            valSubEpCTSens = '{:.4f}'.format(valSubEpCTSens)
            valSubEpCTSpec = '{:.4f}'.format(valSubEpCTSpec)

            valSubEpCoreDice = '{:.4f}'.format(valSubEpCoreDice)
            valSubEpCoreSens = '{:.4f}'.format(valSubEpCoreSens)
            valSubEpCoreSpec = '{:.4f}'.format(valSubEpCoreSpec)

            valSubEpEhDice = '{:.4f}'.format(valSubEpEhDice)
            valSubEpEhSens = '{:.4f}'.format(valSubEpEhSens)
            valSubEpEhSpec = '{:.4f}'.format(valSubEpEhSpec)

            valTRowList.append([
                indexColumn, subEpIdx + 1, valTime, valSubEpCTDice,
                valSubEpCTSens, valSubEpCTSpec, valSubEpCoreDice,
                valSubEpCoreSens, valSubEpCoreSpec, valSubEpEhDice,
                valSubEpEhSens, valSubEpEhSpec
            ])
            # =========================================================================================

            # Subepoch logger
            # =========================================================================================
            message = 'Subepoch: {}/{} '.format(subEpIdx + 1, numOfSubEpochs)
            message += ' Train Loss: {}, '.format(trainSubEpLoss)
            message += ' Train ACC: {}'.format(trainSubEpACC)
            logger.info(logMessage('-', message))
            message = 'Subepoch: {}/{} '.format(subEpIdx + 1, numOfSubEpochs)
            message += ' Val Core Dice: {}, '.format(valSubEpCoreDice)
            message += ' Val Core Sens: {}'.format(valSubEpCoreSens)
            message += ' Val Core Spec: {}'.format(valSubEpCoreSpec)
            logger.info(logMessage('-', message))
            # =========================================================================================

        # Compute epoch results
        # =============================================================================================
        # Compute epoch training results
        trainEpLoss /= trainEpBatchNum
        trainEpACC /= trainEpBatchNum
        # *********************************************************************************************
        # Compute epoch validation results
        valEpCTDice /= valEpPatsNum
        valEpCTSens /= valEpPatsNum
        valEpCTSpec /= valEpPatsNum

        valEpCoreDice /= valEpPatsNum
        valEpCoreSens /= valEpPatsNum
        valEpCoreSpec /= valEpPatsNum

        valEpEnDice /= valEpPatsNum
        valEpEnSens /= valEpPatsNum
        valEpEnSpec /= valEpPatsNum
        # =============================================================================================

        # Recording for subEpoch row of table
        # =============================================================================================
        epTrainSampleTime = '{:.3}'.format(epTrainSampleTime)
        epTrainTime = '{:.3}'.format(epTrainTime)
        epValTime = '{:.3}'.format(epValTime)
        # ---------------------------------------------------------------------------------------------
        trainEpLoss = '{:.6f}'.format(trainEpLoss)
        trainEpACC = '{:.6f}'.format(trainEpACC)
        # ---------------------------------------------------------------------------------------------
        trainTRowList.append(['-', '-', '-', '-', '-', '-'])
        trainTRowList.append(
            ['', '', epTrainSampleTime, trainEpLoss, trainEpACC, epTrainTime])
        trainTRowList.append(['-', '-', '-', '-', '-', '-'])
        # *********************************************************************************************
        valEpCTDice = '{:.4f}'.format(valEpCTDice)
        valEpCTSens = '{:.4f}'.format(valEpCTSens)
        valEpCTSpec = '{:.4f}'.format(valEpCTSpec)

        valEpCoreDice = '{:.4f}'.format(valEpCoreDice)
        valEpCoreSens = '{:.4f}'.format(valEpCoreSens)
        valEpCoreSpec = '{:.4f}'.format(valEpCoreSpec)

        valEpEnDice = '{:.4f}'.format(valEpEnDice)
        valEpEnSens = '{:.4f}'.format(valEpEnSens)
        valEpEnSpec = '{:.4f}'.format(valEpEnSpec)
        # ---------------------------------------------------------------------------------------------
        valTRowList.append(
            ['-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'])
        valTRowList.append([
            '', '', epValTime, valEpCTDice, valEpCTSens, valEpCTSpec,
            valEpCoreDice, valEpCoreSens, valEpCoreSpec, valEpEnDice,
            valEpEnSens, valEpEnSpec
        ])
        valTRowList.append(
            ['-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'])
        # =============================================================================================

        # Epoch logger
        # =============================================================================================
        message = 'Epoch: {}/{} '.format(epIdx + 1, numOfEpochs)
        message += ' Train Loss: {}, '.format(trainEpLoss)
        message += ' Train ACC: {}'.format(trainEpACC)
        logger.info(logMessage('+', message))
        message = 'Epoch: {}/{} '.format(epIdx + 1, numOfEpochs)
        message += ' Val Core Dice: {}, '.format(valEpCoreDice)
        message += ' Val Core Sens: {}'.format(valEpEnSens)
        message += ' Val Core Spec: {}'.format(valEpEnSpec)
        logger.info(logMessage('+', message))
        # =============================================================================================

        # Store network weights
        # =============================================================================================
        weightsFileName = '{}_{}'.format(epIdx, subEpIdx)
        weightsFileNameWithPath = os.path.join(weightsDir, weightsFileName)
        network.saveWeights(weightsFileNameWithPath)
        # =============================================================================================

        # Reset learning rate
        # =============================================================================================
        oldLeraningRate = network.learningRate.get_value()
        newLearningRate = oldLeraningRate * network.learningRateDecay
        network.learningRate.set_value(newLearningRate)
        message = 'Reset Learning Rate, From {} to {}'.format(
            oldLeraningRate, newLearningRate)
        logger.info(logMessage('~', message))
        # =============================================================================================

    # #################################################################################################

    # Logger table
    # =================================================================================================
    message = 'The Training Results'
    logger.info(logMessage('=', message))
    logger.info(logTable(trainTRowList))
    logger.info(logMessage('=', '*'))
    # *************************************************************************************************
    message = 'The Validation Results'
    logger.info(logMessage('=', message))
    logger.info(logTable(valTRowList))
    logger.info(logMessage('=', '*'))
    message = 'End Training Loops'
    logger.info(logMessage('#', message))
    # =================================================================================================

    return trainTRowList, valTRowList
Example #9
0
def testNetwork(network, configFile):
    """Segment every patient in the configured test folder and save the results.

    The function logs three summaries (network, test data, test settings),
    creates a time-stamped sub-folder inside the configured output folder,
    then runs ``segmentWholeBrain`` on each entry of the test image folder and
    stores the two returned arrays with ``np.save``.

    Args:
        network: Network object. This function reads ``networkType`` and
            ``receptiveField`` from it, calls ``summary(testSampleSize)`` and
            hands the object to ``segmentWholeBrain``.
        configFile: Path of a Python file run with ``execfile``. Its top-level
            names must include ``testSampleSize``, ``testImageFolder``,
            ``useROITest``, ``modals``, ``normType``, ``useTestData``,
            ``batchSize`` and ``outputFolder``.

    Raises:
        KeyError: If one of the settings listed above is missing.
        OSError: If the time-stamped output folder cannot be created.
        AssertionError: If ``segmentWholeBrain`` returns ground truth that
            does not compare equal to ``[]``.
    """

    logger = logging.getLogger(__name__)

    message = 'Testing {}'.format(network.networkType)
    logger.info(logMessage('#', message))

    # Get config information
    # =================================================================================================
    # NOTE(review): execfile runs the config file as Python code, so the file
    # must come from a trusted source.
    configInfo = {}
    execfile(configFile, configInfo)
    # =================================================================================================

    # Network summary
    # =================================================================================================
    # Read network summary
    testSampleSize = configInfo['testSampleSize']
    networkType = network.networkType
    receptiveField = network.receptiveField
    networkSummary = network.summary(testSampleSize)
    # -------------------------------------------------------------------------------------------------
    # Logger network summary
    message = 'Network Summary'
    logger.info(logMessage('*', message))
    logger.info(networkSummary)

    tableRowList = [
        ['-', '-'],
        ['Network Type', networkType],
        ['Receptive Field', receptiveField],
        ['-', '-'],
    ]
    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))
    # =================================================================================================

    # Test data summary
    # =================================================================================================
    message = 'Test Data Summary'
    logger.info(logMessage('*', message))

    testImageFolder = configInfo['testImageFolder']
    useROITest = configInfo['useROITest']
    modals = configInfo['modals']
    normType = configInfo['normType']
    useTestData = configInfo['useTestData']
    # Every entry of the test folder is counted (and later processed) as one
    # patient.
    numOfPatients = len(os.listdir(testImageFolder))
    # -------------------------------------------------------------------------------------------------
    # Logger test data summary
    tableRowList = [
        ['Test Image Folder', testImageFolder],
        ['Number of Patients', numOfPatients],
        ['Use ROI To Test Network', useROITest],
        ['Modals', modals],
        ['Normalization Type in Test Process', normType],
        ['Using Test Data', useTestData],
    ]

    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))
    # =================================================================================================

    # Test setting summary
    # =================================================================================================
    message = 'Test Setting Summary'
    logger.info(logMessage('*', message))
    # testSampleSize was already read for the network summary above.
    batchSize = configInfo['batchSize']
    outputFolder = configInfo['outputFolder']
    # -------------------------------------------------------------------------------------------------
    # Logger test setting summary
    tableRowList = [
        ['Test Samples Size', testSampleSize],
        ['Test Batch Size', batchSize],
        ['Folder to Store Test Results', outputFolder],
    ]

    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))
    # =================================================================================================

    # Prepare output folder
    # =================================================================================================
    # One sub-folder per run, named after the start time of the run.
    storeTime = time.strftime('%y-%m-%d_%H:%M:%S')
    outputDir = os.path.join(outputFolder, str(storeTime))
    os.mkdir(outputDir)
    # =================================================================================================

    # Test
    # =================================================================================================
    for patient in os.listdir(testImageFolder):

        patientDir = os.path.join(testImageFolder, patient)
        # ---------------------------------------------------------------------------------------------
        # Sample test data
        # NOTE(review): the literal True is presumably the "test mode" switch
        # of segmentWholeBrain (no ground truth expected) -- confirm against
        # its signature.
        segmentResult, segmentResultMask, gTArray = segmentWholeBrain(
            network, patientDir, useROITest, modals, normType, testSampleSize,
            receptiveField, True, batchSize)

        assert gTArray == []
        # ---------------------------------------------------------------------------------------------
        # Save segment results for each patient
        # The save prefix was previously never assigned, which raised a
        # NameError on the first patient. It is now built from the run's
        # output folder and the patient entry name, so the files are
        # '<outputDir>/<patient>result.npy' and
        # '<outputDir>/<patient>resultMask.npy'.
        segmentResultNameWithPath = os.path.join(outputDir, patient)
        np.save(segmentResultNameWithPath + 'result', segmentResult)
        np.save(segmentResultNameWithPath + 'resultMask', segmentResultMask)
        message = 'Saved results of {}'.format(patient)
        logger.info(logMessage('-', message))