def compileValFunction(self):
    """Build and compile the Theano validation function.

    The compiled callable takes ``(inputs, targets)`` and returns
    ``[loss, accuracy]``:

    * loss: mean categorical cross-entropy between the network output and the
      flattened targets, plus ``self.weightDecay`` times the L2 penalty over
      all network parameters;
    * accuracy: fraction of rows whose argmax prediction equals the target.

    Side effect: sets ``self.flattenedTargetVar``.
    """
    self.logger.info(logMessage('+', 'Compiling the Validation Function'))
    tic = time.time()

    # Deterministic forward pass. Batch norm is told not to update and not to
    # use its running averages, i.e. it normalises with batch statistics.
    probs = get_output(self.outputLayer,
                       deterministic=True,
                       batch_norm_update_averages=False,
                       batch_norm_use_averages=False)

    # NOTE(review): assumes T.flatten(targetVar) walks voxels in the same
    # order as the rows of the network output (Dimshuffle + Reshape) -- confirm.
    self.flattenedTargetVar = T.flatten(self.targetVar)
    target = self.flattenedTargetVar

    dataLoss = categorical_crossentropy(probs, target).mean()
    l2Penalty = regularize_network_params(self.outputLayer,
                                          lasagne.regularization.l2)
    loss = dataLoss + self.weightDecay * l2Penalty

    hits = T.eq(T.argmax(probs, axis=1), target)
    accuracy = T.mean(hits, dtype=theano.config.floatX)

    compiled = theano.function([self.inputVar, self.targetVar],
                               [loss, accuracy])
    elapsed = time.time() - tic
    self.logger.info(logMessage(
        '+', 'Compiled the Validation Function, spent {:.2f}s'.format(elapsed)))
    return compiled
def complieTrainFunction(self):
    """Build and compile the Theano training function.

    The method name is misspelled ('complie'); it is kept unchanged because it
    is part of the public interface.

    The compiled callable takes ``(inputs, targets)``, applies one optimizer
    step (``self.optimizer`` with ``learning_rate=self.learningRate``) to all
    trainable parameters and returns ``[loss, accuracy]``, where the loss is
    the mean categorical cross-entropy plus ``self.weightDecay`` times the L2
    penalty over all network parameters.

    Side effect: sets ``self.flattenedTargetVar``.
    """
    self.logger.info(logMessage('+', 'Compiling the Training Function'))
    tic = time.time()

    # Stochastic forward pass; batch norm neither updates nor uses its
    # running averages (batch statistics only).
    probs = get_output(self.outputLayer,
                       deterministic=False,
                       batch_norm_update_averages=False,
                       batch_norm_use_averages=False)

    # NOTE(review): assumes T.flatten(targetVar) matches the row order of the
    # network output -- confirm.
    self.flattenedTargetVar = T.flatten(self.targetVar)
    target = self.flattenedTargetVar

    dataLoss = categorical_crossentropy(probs, target).mean()
    l2Penalty = regularize_network_params(self.outputLayer,
                                          lasagne.regularization.l2)
    loss = dataLoss + self.weightDecay * l2Penalty

    hits = T.eq(T.argmax(probs, axis=1), target)
    accuracy = T.mean(hits, dtype=theano.config.floatX)

    trainable = get_all_params(self.outputLayer, trainable=True)
    stepUpdates = self.optimizer(loss, trainable,
                                 learning_rate=self.learningRate)
    compiled = theano.function([self.inputVar, self.targetVar],
                               [loss, accuracy],
                               updates=stepUpdates)
    elapsed = time.time() - tic
    self.logger.info(logMessage(
        '+', 'Compiled the Training Function, spent {:.2f}s'.format(elapsed)))
    return compiled
def saveWeights(self, fileName):
    """Write every parameter value of the network to an ``.npz`` archive.

    ``fileName`` is joined onto ``self.weightsFolder`` with ``os.path.join``
    (an absolute ``fileName`` therefore ignores ``self.weightsFolder``). The
    arrays are stored positionally (``arr_0``, ``arr_1``, ...) in the order
    returned by ``get_all_param_values``.
    """
    self.logger.info(logMessage('+', 'Save Weights in {}'.format(fileName)))
    targetPath = os.path.join(self.weightsFolder, fileName)
    paramValues = get_all_param_values(self.outputLayer)
    np.savez(targetPath, *paramValues)
def compileTestFunction(self):
    """Compile the Theano inference function used for segmentation.

    The compiled callable takes a batch of inputs and returns a one-element
    list holding the argmax class index of every output row.
    """
    self.logger.info(logMessage('+', 'Compiling the Test Function'))
    tic = time.time()

    # Deterministic pass; batch norm keeps using batch statistics, matching
    # the validation function.
    probs = get_output(self.outputLayer,
                       deterministic=True,
                       batch_norm_use_averages=False)
    labels = T.argmax(probs, axis=1)

    compiled = theano.function([self.inputVar], [labels])
    elapsed = time.time() - tic
    self.logger.info(logMessage(
        '+', 'Compiled the Test Function, spent {:.2f}s'.format(elapsed)))
    return compiled
def generateNetwork(configFile):
    """Instantiate the network class named by ``networkType`` in a config file.

    The config file is executed with ``execfile`` and must define
    ``networkType`` (one of ``'sectorNet'``, ``'baseNet'``,
    ``'dilated3DNet'``), ``preTrainedWeights`` and ``networkName``.
    ``preTrainedWeights`` is only inspected here to log whether a new or a
    pre-trained network is expected; the class is constructed from the same
    ``configFile``.

    Returns:
        The constructed network instance.
    """
    logger = logging.getLogger(__name__)

    configInfo = {}
    execfile(configFile, configInfo)

    # Must stay in sync with the network classes imported by this module.
    availableNetworks = dict(sectorNet=SectorNet,
                             baseNet=BaseNet,
                             dilated3DNet=Dilated3DNet)
    networkType = configInfo['networkType']
    networkClass = availableNetworks[networkType]

    usePreTrained = configInfo['preTrainedWeights'] != ''
    if usePreTrained:
        logger.info(logMessage(' ', 'We will use a pre trained network'))
    else:
        logger.info(logMessage(' ', 'We will create a new network'))

    logger.info(logMessage('#', 'Creating {}'.format(configInfo['networkName'])))
    network = networkClass(configFile)
    assert network.networkType == networkType
    logger.info(logMessage('#', 'Created {}'.format(networkType)))
    return network
def restoreWeights(self):
    """Load parameter values from ``self.preTrainedWeights`` into the network.

    Does nothing when ``self.preTrainedWeights`` is the empty string.
    Otherwise the archive is expected to hold positional arrays
    ``arr_0 .. arr_{n-1}`` (as written by ``np.savez(path, *values)``), which
    are assigned in that order with ``set_all_param_values``.
    """
    weightsPath = self.preTrainedWeights
    if weightsPath == '':
        return

    self.logger.info(logMessage('+', 'Load Weights from {}'.format(weightsPath)))
    with np.load(weightsPath) as archive:
        orderedValues = [archive['arr_{:d}'.format(k)]
                         for k in range(len(archive.files))]
    set_all_param_values(self.outputLayer, orderedValues)
def buildBaseNet(self, inputShape=(None, 4, 25, 25, 25), forSummary=False):
    """Build the BaseNet layer graph and a text summary of its layers.

    For every entry of ``self.kernelNumList`` / ``self.kernelShapeList``
    except the last, a Conv3D (linear) -> BatchNorm -> PReLU block is built
    and its output is concatenated with the block input along axis 1
    (channels), center-cropping the spatial axes. The last entry must be a
    ``[1, 1, 1]`` kernel and forms the classification Conv3D; its output is
    dimshuffled to channels-last, reshaped to a 2-D matrix with one row per
    output voxel, and passed through a softmax.

    Args:
        inputShape: shape used for the shape columns of the summary table
            and for computing the receptive field. The InputLayer itself is
            built from ``self.inputShape``.
        forSummary: when True, the progress log messages are suppressed.

    Returns:
        The final softmax ``NonlinearityLayer``.

    Side effects:
        Sets ``self.receptiveField`` and ``self._summary``.
    """
    if not forSummary:
        message = 'Building the Architecture of BaseNet'
        self.logger.info(logMessage('+', message))
    # The graph is built on self.inputShape; the inputShape argument is only
    # passed to get_output_shape() for the summary/receptive-field numbers.
    baseNet = InputLayer(self.inputShape, self.inputVar)
    if not forSummary:
        message = 'Building the convolution layers'
        self.logger.info(logMessage('-', message))
    kernelShapeListLen = len(self.kernelNumList)
    # NOTE(review): several summary rows format tuples with width specs such
    # as '{:<50}'; that works on Python 2 but raises TypeError on Python 3.
    summary = '\n' + '.' * 130 + '\n'
    summary += ' {:<15} {:<50} {:<29} {:<29}\n'.format(
        'Layer', 'Input shape', 'W shape', 'Output shape')
    summary += '.' * 130 + '\n'
    summary += '{:<3} {:<15} {:<50} {:<29} {:<29}\n'.format(
        1, 'Input', inputShape, '',
        get_output_shape(baseNet, input_shapes=inputShape))
    # All entries but the last are feature blocks; the last one is the
    # classification layer built after this loop.
    for i in xrange(kernelShapeListLen - 1):
        kernelShape = self.kernelShapeList[i]
        kernelNum = self.kernelNumList[i]
        conv3D = Conv3DLayer(incoming=baseNet,
                             num_filters=kernelNum,
                             filter_size=kernelShape,
                             W=HeNormal(gain='relu'),
                             nonlinearity=linear,
                             name='Conv3D{}'.format(i))
        # Only needed to show the filter shape in the summary.
        WShape = conv3D.W.get_value().shape
        summary += '{:<3} {:<15} {:<50} {:<29} {:<29}\n'.format(
            i + 2, 'Conv3D',
            get_output_shape(baseNet, input_shapes=inputShape), WShape,
            get_output_shape(conv3D, input_shapes=inputShape))
        batchNormLayer = BatchNormLayer(conv3D)
        preluLayer = prelu(batchNormLayer)
        concatLayerInputShape = '{:<25}{:<25}'.format(
            get_output_shape(conv3D, input_shapes=inputShape),
            get_output_shape(baseNet, input_shapes=inputShape))
        # Dense-style skip: concatenate the new features with the block input
        # on axis 1, center-cropping the spatial axes so they line up.
        # NOTE(review): 'None' below is a string, not None. Lasagne's
        # ConcatLayer is documented to disable cropping on the concat axis
        # (axis 1 here), so it appears harmless; None was presumably intended.
        baseNet = ConcatLayer(
            [preluLayer, baseNet], 1,
            cropping=['center', 'None', 'center', 'center', 'center'])
        summary += ' {:<15} {:<50} {:<29} {:<29}\n'.format(
            'Concat', concatLayerInputShape, '',
            get_output_shape(baseNet, input_shapes=inputShape))
    if not forSummary:
        message = 'Finish Built the convolution layers'
        self.logger.info(logMessage('-', message))
        message = 'Building the last classfication layers'
        self.logger.info(logMessage('-', message))
    # The classification layer must be a 1x1x1 convolution.
    assert self.kernelShapeList[-1] == [1, 1, 1]
    kernelShape = self.kernelShapeList[-1]
    kernelNum = self.kernelNumList[-1]
    conv3D = Conv3DLayer(incoming=baseNet,
                         num_filters=kernelNum,
                         filter_size=kernelShape,
                         W=HeNormal(gain='relu'),
                         nonlinearity=linear,
                         name='Classfication Layer')
    # Receptive field per spatial axis (last three axes): in - out + 1.
    # All three must agree, i.e. the receptive field is the same on each axis.
    receptiveFieldList = [
        inputShape[idx] -
        get_output_shape(conv3D, input_shapes=inputShape)[idx] + 1
        for idx in xrange(-3, 0)
    ]
    assert receptiveFieldList != []
    receptiveFieldSet = set(receptiveFieldList)
    assert len(receptiveFieldSet) == 1, (receptiveFieldSet, inputShape,
                                         get_output_shape(
                                             conv3D,
                                             input_shapes=inputShape))
    self.receptiveField = list(receptiveFieldSet)[0]
    # Only needed to show the filter shape in the summary.
    WShape = conv3D.W.get_value().shape
    summary += '{:<3} {:<15} {:<50} {:<29} {:<29}\n'.format(
        kernelShapeListLen + 1, 'Conv3D',
        get_output_shape(baseNet, input_shapes=inputShape), WShape,
        get_output_shape(conv3D, input_shapes=inputShape))
    # The conv output is (batchSize, numOfClasses, zSize, xSize, ySize).
    # Move the class axis last and reshape to
    # (batchSize * zSize * xSize * ySize, numOfClasses), because Theano's
    # softmax only accepts a matrix.
    baseNet = DimshuffleLayer(conv3D, (0, 2, 3, 4, 1))
    summary += ' {:<15} {:<50} {:<29} {:<29}\n'.format(
        'Dimshuffle', get_output_shape(conv3D, input_shapes=inputShape), '',
        get_output_shape(baseNet, input_shapes=inputShape))
    # Symbolic sizes, so the reshape works for any batch/spatial size.
    batchSize, zSize, xSize, ySize, _ = get_output(baseNet).shape
    reshapeLayerInputShape = get_output_shape(baseNet, input_shapes=inputShape)
    baseNet = ReshapeLayer(baseNet,
                           (batchSize * zSize * xSize * ySize, kernelNum))
    summary += ' {:<15} {:<50} {:<29} {:<29}\n'.format(
        'Reshape', reshapeLayerInputShape, '',
        get_output_shape(baseNet, input_shapes=inputShape))
    nonlinearityLayerInputShape = get_output_shape(baseNet,
                                                   input_shapes=inputShape)
    baseNet = NonlinearityLayer(baseNet, softmax)
    summary += ' {:<15} {:<50} {:<29} {:<29}\n'.format(
        'Nonlinearity', nonlinearityLayerInputShape, '',
        get_output_shape(baseNet, input_shapes=inputShape))
    if not forSummary:
        message = 'Finish Built the last classfication layers'
        self.logger.info(logMessage('-', message))
        message = 'The Receptivr Field of BaseNet equal {}'.format(
            self.receptiveField)
        self.logger.info(logMessage('*', message))
        message = 'Finish Building the Architecture of BaseNet'
        self.logger.info(logMessage('+', message))
    summary += '.' * 130 + '\n'
    self._summary = summary
    return baseNet
def trainNetwork(network, configFile):
    """Train ``network`` and validate it after every sub-epoch.

    The config file is executed with ``execfile`` and must define
    ``trainSampleSize``, ``imageFolder``, ``imageGrades``, ``numOfPatients``,
    ``modals``, ``useROI``, ``normType``, ``weightMapType``,
    ``trainValRatio`` (an int, used as a slice step), ``memoryThreshold``,
    ``usePoolToSample``, ``numOfEpochs``, ``numOfSubEpochs``, ``batchSize``,
    ``valSampleSize``, ``numOfTrainSamplesPerSubEpoch`` and
    ``weightsFolder``.

    The shuffled patients are split so that every ``trainValRatio + 1``-th
    one is used for validation. Each sub-epoch samples training patches
    (``getSamplesForSubEpoch``), trains batch by batch with
    ``network.trainFunction`` and then segments every validation patient,
    computing Dice / Sens / Spec for the label groups 'CT' ([1, 2, 3, 4]),
    'Core' ([1, 3, 4]) and 'Eh' ([4]). Weights are saved once per epoch into
    a time-stamped folder below ``weightsFolder``, and the learning rate is
    multiplied by ``network.learningRateDecay`` after every epoch.

    Returns:
        ``(trainTRowList, valTRowList)``: the rows of the training and
        validation result tables (also logged via ``logTable``).
    """
    logger = logging.getLogger(__name__)
    message = 'Training {}'.format(network.networkType)
    logger.info(logMessage('#', message))

    # Get config information
    configInfo = {}
    execfile(configFile, configInfo)

    # ==========================================================================
    # Network information
    # ==========================================================================
    # The summary rebuilds the network so that the shapes shown correspond to
    # trainSampleSize.
    trainSampleSize = configInfo['trainSampleSize']
    networkType = network.networkType
    receptiveField = network.receptiveField
    networkSummary = network.summary(trainSampleSize)

    message = 'Network Summary'
    logger.info(logMessage('*', message))
    logger.info(networkSummary)
    logger.info(logMessage('-', '-'))
    tableRowList = [
        ['-', '-'],
        ['Network Type', networkType],
        ['Receptive Field', receptiveField],
        ['-', '-'],
    ]
    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))

    # ==========================================================================
    # Training and validation data information
    # ==========================================================================
    imageFolder = configInfo['imageFolder']
    imageGrades = configInfo['imageGrades']
    numOfPatients = configInfo['numOfPatients']
    modals = configInfo['modals']
    useROI = configInfo['useROI']
    normType = configInfo['normType']
    weightMapType = configInfo['weightMapType']

    message = 'Training and Validation Data Summary'
    logger.info(logMessage('*', message))
    tableRowList = [
        ['-', '-'],
        ['Image Folder', imageFolder],
        ['Image Grades', imageGrades],
        ['Number of Patients', numOfPatients],
        ['Modals', modals],
        ['Use ROI', useROI],
        ['Normalization Type', normType],
        ['Weight Map Type', weightMapType],
        ['-', '-'],
    ]
    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))

    # ==========================================================================
    # Training and validation settings
    # ==========================================================================
    trainValRatio = configInfo['trainValRatio']
    memoryThreshold = configInfo['memoryThreshold']
    usePoolToSample = configInfo['usePoolToSample']
    numOfEpochs = configInfo['numOfEpochs']
    numOfSubEpochs = configInfo['numOfSubEpochs']
    batchSize = configInfo['batchSize']
    valSampleSize = configInfo['valSampleSize']
    numOfTrainSamplesPerSubEpoch = configInfo['numOfTrainSamplesPerSubEpoch']
    numOfValSamplesPerSubEpoch = int(
        float(numOfTrainSamplesPerSubEpoch) / trainValRatio)
    weightsFolder = configInfo['weightsFolder']

    assert batchSize < numOfTrainSamplesPerSubEpoch
    assert batchSize < numOfValSamplesPerSubEpoch

    message = 'Training and Validation setting Summary'
    logger.info(logMessage('*', message))
    tableRowList = [
        ['-', '-'],
        ['Training / validation', trainValRatio],
        ['Memory Threshold for Subepoch', '{}G'.format(memoryThreshold)],
        ['Wheather Use MultiProcess to Sample', usePoolToSample],
        ['Number of Epochs', numOfEpochs],
        ['Number of Subepochs', numOfSubEpochs],
        ['Batch Size', batchSize],
        ['Training Samples Size', trainSampleSize],
        ['Validation Samples Size', valSampleSize],
        ['Number of Training Samples for Subepoch',
         numOfTrainSamplesPerSubEpoch],
        ['Number of Validation Samples for Subepoch',
         numOfValSamplesPerSubEpoch],
        ['Folder to Store Weights During Training', weightsFolder],
        ['-', '-'],
    ]
    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))

    # ==========================================================================
    # Patient directories and the train / validation split
    # ==========================================================================
    patientsDirList = []
    for grade in imageGrades:
        gradeDir = os.path.join(imageFolder, grade)
        patientsDirList += [
            os.path.join(gradeDir, patient) for patient in os.listdir(gradeDir)
        ]
    # Make sure no patient directory is listed twice.
    assert len(patientsDirList) == len(set(patientsDirList))
    random.shuffle(patientsDirList)
    patientsDirList = patientsDirList[:numOfPatients]

    # Every (trainValRatio + 1)-th patient is used for validation.
    patsDirForValList = patientsDirList[::trainValRatio + 1]
    valPatsDirSet = set(patsDirForValList)
    patsDirForTrainList = [
        patsDir for patsDir in patientsDirList if patsDir not in valPatsDirSet
    ]
    assert len(patsDirForValList) + len(patsDirForTrainList) == len(
        patientsDirList)

    message = 'We get {} patients files. '.format(len(patientsDirList))
    message += 'Randomly choose {} for training '.format(
        len(patsDirForTrainList))
    message += 'and {} for validation'.format(len(patsDirForValList))
    logger.info(logMessage(' ', message))

    # The training data may not fit in memory at once, so each sub-epoch only
    # samples from a subset of the training patients. The constants are a
    # rough per-patient memory estimate, presumably in the same unit as
    # memoryThreshold (logged as 'G') -- TODO confirm.
    numOfModals = len(modals)
    memoryNeededPerPatData = 0.035 * numOfModals + int(useROI) * 0.020
    # memoryThreshold must exceed what a single patient needs.
    assert memoryThreshold > memoryNeededPerPatData
    maxPatNumPerSubEpoch = int(
        math.floor(memoryThreshold / memoryNeededPerPatData))
    patDirPerSubEpochDict = {}
    # If there are fewer training patients than maxPatNumPerSubEpoch, every
    # sub-epoch uses all of them (in a different order).
    for subEpIdx in xrange(numOfSubEpochs):
        patDirPerSubEpochDict[subEpIdx] = \
            patsDirForTrainList[:maxPatNumPerSubEpoch]
        random.shuffle(patsDirForTrainList)

    # ==========================================================================
    # Result tables
    # ==========================================================================
    trainTRowList = [
        ['-', '-', '-', '-', '-', '-'],
        ['EPOCH', 'SUBEPOCH', 'Train Time', 'Train Loss', 'Train ACC',
         'Sampling Time'],
        ['-', '-', '-', '-', '-', '-'],
    ]
    valTRowList = [
        ['-'] * 12,
        ['EPOCH', 'SUBEPOCH', 'Val Time', 'CT Dice', 'CT Sens', 'CT Spec',
         'Core Dice', 'Core Sens', 'Core Spec', 'Eh Dice', 'Eh Sens',
         'Eh Spec'],
        ['-'] * 12,
    ]
    # Label groups evaluated per validation patient, in table column order
    # (CT, Core, Eh). Each contributes (Dice, Sens, Spec), so the metric
    # lists below line up with the 'CT Dice' .. 'Eh Spec' columns.
    valLabelGroups = [[1, 2, 3, 4], [1, 3, 4], [4]]
    numOfValMetrics = 3 * len(valLabelGroups)
    coreDiceIdx, coreSensIdx, coreSpecIdx = 3, 4, 5

    # ==========================================================================
    # Folder for the network weights saved during training
    # ==========================================================================
    storeTime = time.strftime('%y-%m-%d_%H:%M:%S')
    weightsDir = os.path.join(weightsFolder, str(storeTime))
    os.mkdir(weightsDir)

    # ##########################################################################
    # Train and validate
    # ##########################################################################
    for epIdx in xrange(numOfEpochs):
        message = 'EPOCH: {}/{}'.format(epIdx + 1, numOfEpochs)
        logger.info(logMessage('+', message))

        trainEpLoss = 0
        trainEpACC = 0
        trainEpBatchNum = 0
        valEpMetrics = [0] * numOfValMetrics
        valEpPatsNum = 0
        epTrainSampleTime = 0
        epTrainTime = 0
        epValTime = 0

        for subEpIdx in xrange(numOfSubEpochs):
            message = 'SUBEPOCH: {}/{}'.format(subEpIdx + 1, numOfSubEpochs)
            logger.info(logMessage('-', message))

            # ------------------------------------------------------------------
            # Sample training data
            # ------------------------------------------------------------------
            trainSampleTime = time.time()
            message = 'Sampling Training Data'
            logger.info(logMessage('-', message))
            trainSamplesList, trainLabelsList = getSamplesForSubEpoch(
                numOfTrainSamplesPerSubEpoch, patDirPerSubEpochDict[subEpIdx],
                useROI, modals, normType, trainSampleSize, receptiveField,
                weightMapType, usePoolToSample)
            trainSampleTime = time.time() - trainSampleTime
            epTrainSampleTime += trainSampleTime

            # Batch boundaries. The final boundary is moved to the end so the
            # last batch absorbs the remainder and is never smaller than
            # batchSize.
            trainBatchIdxList = list(
                xrange(0, numOfTrainSamplesPerSubEpoch, batchSize))
            trainBatchIdxList[-1] = numOfTrainSamplesPerSubEpoch
            assert len(trainBatchIdxList) > 1
            trainBatchNum = len(trainBatchIdxList) - 1

            # ------------------------------------------------------------------
            # Training batch loop
            # ------------------------------------------------------------------
            trainSubEpLoss = 0
            trainSubEpACC = 0
            trainSubEpBatchNum = 0
            trainTime = time.time()
            message = 'Training'
            logger.info(logMessage(':', message))
            for trainBatchIdx in xrange(trainBatchNum):
                trainStartIdx = trainBatchIdxList[trainBatchIdx]
                trainEndIdx = trainBatchIdxList[trainBatchIdx + 1]
                trainSamplesBatch = np.asarray(
                    trainSamplesList[trainStartIdx:trainEndIdx],
                    dtype=theano.config.floatX)
                trainLabelsBatch = np.asarray(
                    trainLabelsList[trainStartIdx:trainEndIdx], dtype='int32')
                trainBatchLoss, trainBatchAcc = network.trainFunction(
                    trainSamplesBatch, trainLabelsBatch)
                trainSubEpLoss += trainBatchLoss
                trainSubEpACC += trainBatchAcc
                trainSubEpBatchNum += 1
            trainTime = time.time() - trainTime
            epTrainTime += trainTime

            # Release the sampled training data before validation.
            del trainSamplesList[:], trainLabelsList[:]
            del trainSamplesList, trainLabelsList
            del trainSamplesBatch, trainLabelsBatch
            gc.collect()

            # ------------------------------------------------------------------
            # Validation
            # ------------------------------------------------------------------
            valSubEpMetrics = [0] * numOfValMetrics
            valSubEpPatsNum = 0
            valTime = time.time()
            message = 'Validation'
            logger.info(logMessage(':', message))
            for patIdx, patientDir in enumerate(patsDirForValList):
                logger.info('Val {}/{} patient'.format(
                    patIdx + 1, len(patsDirForValList)))
                segmentResult, segmentResultMask, gTArray = segmentWholeBrain(
                    network, patientDir, useROI, modals, normType,
                    valSampleSize, receptiveField, False, batchSize)
                # Validation patients must come with ground truth. np.size
                # works for both an empty list and a numpy array, unlike
                # comparing an array against [].
                assert np.size(gTArray) > 0
                patMetrics = []
                for labels in valLabelGroups:
                    dice, sens, spec = voxleWiseMetrics(
                        segmentResult, gTArray, labels)
                    patMetrics += [dice, sens, spec]
                valSubEpMetrics = [
                    total + value
                    for total, value in zip(valSubEpMetrics, patMetrics)
                ]
                # Count patients (not characters of the path string).
                valSubEpPatsNum += 1
                del segmentResult, segmentResultMask, gTArray
                gc.collect()
            valTime = time.time() - valTime
            epValTime += valTime

            # ------------------------------------------------------------------
            # Accumulate epoch totals, then average the sub-epoch results
            # ------------------------------------------------------------------
            trainEpLoss += trainSubEpLoss
            trainEpACC += trainSubEpACC
            trainEpBatchNum += trainSubEpBatchNum
            trainSubEpLoss /= trainSubEpBatchNum
            trainSubEpACC /= trainSubEpBatchNum

            valEpMetrics = [
                total + value
                for total, value in zip(valEpMetrics, valSubEpMetrics)
            ]
            valEpPatsNum += valSubEpPatsNum
            valSubEpMetrics = [
                total / valSubEpPatsNum for total in valSubEpMetrics
            ]

            # ------------------------------------------------------------------
            # Sub-epoch table rows
            # ------------------------------------------------------------------
            indexColumn = epIdx + 1 if subEpIdx == 0 else ''
            trainSubEpLossStr = '{:.6f}'.format(trainSubEpLoss)
            trainSubEpACCStr = '{:.6f}'.format(trainSubEpACC)
            trainTRowList.append([
                indexColumn, subEpIdx + 1, '{:.3}'.format(trainTime),
                trainSubEpLossStr, trainSubEpACCStr,
                '{:.3}'.format(trainSampleTime)
            ])
            valSubEpMetricStrs = [
                '{:.4f}'.format(value) for value in valSubEpMetrics
            ]
            valTRowList.append(
                [indexColumn, subEpIdx + 1, '{:.3}'.format(valTime)] +
                valSubEpMetricStrs)

            # ------------------------------------------------------------------
            # Sub-epoch log
            # ------------------------------------------------------------------
            message = 'Subepoch: {}/{} '.format(subEpIdx + 1, numOfSubEpochs)
            message += ' Train Loss: {}, '.format(trainSubEpLossStr)
            message += ' Train ACC: {}'.format(trainSubEpACCStr)
            logger.info(logMessage('-', message))
            message = 'Subepoch: {}/{} '.format(subEpIdx + 1, numOfSubEpochs)
            message += ' Val Core Dice: {}, '.format(
                valSubEpMetricStrs[coreDiceIdx])
            message += ' Val Core Sens: {}'.format(
                valSubEpMetricStrs[coreSensIdx])
            message += ' Val Core Spec: {}'.format(
                valSubEpMetricStrs[coreSpecIdx])
            logger.info(logMessage('-', message))

        # ----------------------------------------------------------------------
        # Epoch results
        # ----------------------------------------------------------------------
        trainEpLoss /= trainEpBatchNum
        trainEpACC /= trainEpBatchNum
        valEpMetrics = [total / valEpPatsNum for total in valEpMetrics]

        epTrainSampleTimeStr = '{:.3}'.format(epTrainSampleTime)
        epTrainTimeStr = '{:.3}'.format(epTrainTime)
        epValTimeStr = '{:.3}'.format(epValTime)
        trainEpLossStr = '{:.6f}'.format(trainEpLoss)
        trainEpACCStr = '{:.6f}'.format(trainEpACC)

        # Columns: EPOCH, SUBEPOCH, Train Time, Train Loss, Train ACC,
        # Sampling Time.
        trainTRowList.append(['-', '-', '-', '-', '-', '-'])
        trainTRowList.append(['', '', epTrainTimeStr, trainEpLossStr,
                              trainEpACCStr, epTrainSampleTimeStr])
        trainTRowList.append(['-', '-', '-', '-', '-', '-'])

        valEpMetricStrs = ['{:.4f}'.format(value) for value in valEpMetrics]
        valTRowList.append(['-'] * 12)
        valTRowList.append(['', '', epValTimeStr] + valEpMetricStrs)
        valTRowList.append(['-'] * 12)

        # ----------------------------------------------------------------------
        # Epoch log
        # ----------------------------------------------------------------------
        message = 'Epoch: {}/{} '.format(epIdx + 1, numOfEpochs)
        message += ' Train Loss: {}, '.format(trainEpLossStr)
        message += ' Train ACC: {}'.format(trainEpACCStr)
        logger.info(logMessage('+', message))
        message = 'Epoch: {}/{} '.format(epIdx + 1, numOfEpochs)
        message += ' Val Core Dice: {}, '.format(valEpMetricStrs[coreDiceIdx])
        message += ' Val Core Sens: {}'.format(valEpMetricStrs[coreSensIdx])
        message += ' Val Core Spec: {}'.format(valEpMetricStrs[coreSpecIdx])
        logger.info(logMessage('+', message))

        # ----------------------------------------------------------------------
        # Store network weights
        # ----------------------------------------------------------------------
        weightsFileName = '{}_{}'.format(epIdx, subEpIdx)
        # saveWeights() joins its argument onto the network's weightsFolder;
        # an absolute path makes that join a no-op, so the file lands in
        # weightsDir even when weightsFolder is a relative path.
        weightsFileNameWithPath = os.path.abspath(
            os.path.join(weightsDir, weightsFileName))
        network.saveWeights(weightsFileNameWithPath)

        # ----------------------------------------------------------------------
        # Decay the learning rate
        # ----------------------------------------------------------------------
        oldLeraningRate = network.learningRate.get_value()
        newLearningRate = oldLeraningRate * network.learningRateDecay
        network.learningRate.set_value(newLearningRate)
        message = 'Reset Learning Rate, From {} to {}'.format(
            oldLeraningRate, newLearningRate)
        logger.info(logMessage('~', message))

    # ##########################################################################
    # Log result tables
    # ##########################################################################
    message = 'The Training Results'
    logger.info(logMessage('=', message))
    logger.info(logTable(trainTRowList))
    logger.info(logMessage('=', '*'))
    message = 'The Validation Results'
    logger.info(logMessage('=', message))
    logger.info(logTable(valTRowList))
    logger.info(logMessage('=', '*'))
    message = 'End Training Loops'
    logger.info(logMessage('#', message))
    return trainTRowList, valTRowList
def testNetwork(network, configFile):
    """Segment every patient in ``testImageFolder`` and save the results.

    The config file is executed with ``execfile`` and must define
    ``testSampleSize``, ``testImageFolder``, ``useROITest``, ``modals``,
    ``normType``, ``useTestData``, ``batchSize`` and ``outputFolder``.

    For each patient directory, ``segmentWholeBrain`` is run in test mode (no
    ground truth expected). The segmentation and its mask are saved with
    ``np.save`` as ``<patient>_result.npy`` and ``<patient>_resultMask.npy``
    in a new time-stamped folder below ``outputFolder``.
    """
    logger = logging.getLogger(__name__)
    message = 'Testing {}'.format(network.networkType)
    logger.info(logMessage('#', message))

    # Get config information
    configInfo = {}
    execfile(configFile, configInfo)

    # ==========================================================================
    # Network summary
    # ==========================================================================
    testSampleSize = configInfo['testSampleSize']
    networkType = network.networkType
    receptiveField = network.receptiveField
    networkSummary = network.summary(testSampleSize)

    message = 'Network Summary'
    logger.info(logMessage('*', message))
    logger.info(networkSummary)
    tableRowList = [
        ['-', '-'],
        ['Network Type', networkType],
        ['Receptive Field', receptiveField],
        ['-', '-'],
    ]
    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))

    # ==========================================================================
    # Test data summary
    # ==========================================================================
    message = 'Test Data Summary'
    logger.info(logMessage('*', message))
    testImageFolder = configInfo['testImageFolder']
    useROITest = configInfo['useROITest']
    modals = configInfo['modals']
    normType = configInfo['normType']
    useTestData = configInfo['useTestData']
    numOfPatients = len(os.listdir(testImageFolder))

    tableRowList = [
        ['Test Image Folder', testImageFolder],
        ['Number of Patients', numOfPatients],
        ['Use ROI To Test Network', useROITest],
        ['Modals', modals],
        ['Normalization Type in Test Process', normType],
        ['Using Test Data', useTestData],
    ]
    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))

    # ==========================================================================
    # Test setting summary
    # ==========================================================================
    message = 'Test Setting Summary'
    logger.info(logMessage('*', message))
    batchSize = configInfo['batchSize']
    outputFolder = configInfo['outputFolder']

    tableRowList = [
        ['Test Samples Size', testSampleSize],
        ['Test Batch Size', batchSize],
        ['Folder to Store Test Results', outputFolder],
    ]
    logger.info(logTable(tableRowList))
    logger.info(logMessage('*', '*'))

    # ==========================================================================
    # Prepare output folder
    # ==========================================================================
    storeTime = time.strftime('%y-%m-%d_%H:%M:%S')
    outputDir = os.path.join(outputFolder, str(storeTime))
    os.mkdir(outputDir)

    # ==========================================================================
    # Test
    # ==========================================================================
    for patient in os.listdir(testImageFolder):
        patientDir = os.path.join(testImageFolder, patient)
        segmentResult, segmentResultMask, gTArray = segmentWholeBrain(
            network, patientDir, useROITest, modals, normType, testSampleSize,
            receptiveField, True, batchSize)
        # In test mode no ground truth is returned. np.size works for an
        # empty list and for numpy arrays alike.
        assert np.size(gTArray) == 0

        # Previously this name was never defined (NameError on the first
        # patient); results now go into the folder created above.
        segmentResultNameWithPath = os.path.join(outputDir, patient + '_')
        np.save(segmentResultNameWithPath + 'result', segmentResult)
        np.save(segmentResultNameWithPath + 'resultMask', segmentResultMask)
        message = 'Saved results of {}'.format(patient)
        logger.info(logMessage('-', message))