def Predict3D(self, weights):
    """Predict a binary mask for every case in the test image folder.

    The model is built with ``self.createModel3D([128, 128, 48])`` and the
    weights are loaded from ``self.save_folder``.  For each non-hidden entry
    of the test image directory the volume is loaded, resampled with
    ``interp3D`` by ``[0.25, 0.25, 1]``, arranged with
    ``arrange3DtestImage`` and fed to the model with a batch size of 1.
    The channel selected by ``self.sC`` of the first predicted sub-volume is
    resampled back by ``[4, 4, 1]`` and thresholded at 0.5.

    What happens to the resulting mask depends on three instance flags:

    * ``self.testLabelFlag``  -- ``self.computeTestMetrics`` is called.
    * ``self.testMetricFlag`` -- ``self.saveTestMetrics`` is called.
    * ``self.savePredMask``   -- the mask is written as a NIfTI file named
      after the case into ``self.save_folder``.

    Args:
        weights: File name of the weight file, relative to
            ``self.save_folder``.

    Returns:
        None.
    """
    from datetime import datetime

    test_images_path, test_labels_path = self.arrangeDataPath(
        self.root_folder, self.image_folder, self.mask_folder)
    print('-'*30)
    print('Loading saved weights...')
    print('-'*30)
    model = self.createModel3D([128, 128, 48])
    model.load_weights(os.path.join(self.save_folder, weights))
    saveFolder = self.save_folder

    for each in os.listdir(test_images_path):
        # Skip hidden entries (names starting with '.') before logging, so
        # they never show up as a "case" in the output.
        if each.startswith('.'):
            continue
        print('case: ', each)
        startTime = datetime.now()

        testImages, otestLabels, affine = self.load3DtestData(
            test_images_path, test_labels_path, each)
        testImages = interp3D(testImages, [0.25, 0.25, 1], cval=-1024)
        testImages = arrange3DtestImage(testImages, 48, self.dtype)
        # The unpacking requires a 5-D array; presumably the layout is
        # (sub-volumes, rows, cols, depth, channels) -- TODO confirm.
        numImgs, img_rows, img_cols, img_dep, _ = testImages.shape
        print('training image shape:', testImages.shape)

        if self.testLabelFlag:
            # NOTE(review): the down-sampled, one-hot encoded labels built
            # here are not read again in this method -- the metrics below
            # use the original ``otestLabels``.  Kept as-is; confirm whether
            # this computation can be dropped.
            testLabels = interp3D(otestLabels, [0.25, 0.25, 1], cval=0)
            testLabels = testLabels.reshape(
                (numImgs, img_rows*img_cols*img_dep))
            testLabels = np_utils.to_categorical(testLabels, self.nb_classes)
            testLabels = testLabels.reshape(
                (numImgs, img_rows*img_cols*img_dep, 1, self.nb_classes))
            testLabels = testLabels.astype(self.dtype)
            testLabels = arrange3DtestLabel(testLabels, 48, self.dtype)
            testLabels = testLabels.reshape(
                (numImgs, img_rows, img_cols, img_dep, self.nb_classes)
            )[0, :, :, :, self.sC-1:self.sC]

        predImage = model.predict(testImages, batch_size=1, verbose=1)
        print('-'*30)
        print('Predicting masks on test data...')
        print('-'*30)

        # Keep only the channel selected by self.sC, which is the class the
        # metrics are computed for (the "hyper dense" class according to the
        # original comment).
        predImage = predImage.reshape(
            (numImgs, img_rows, img_cols, img_dep, self.nb_classes)
        )[:, :, :, :, self.sC-1:self.sC]
        print(predImage.shape)
        # Only the first sub-volume is used from here on.
        predImage = predImage[0, :, :, :, 0]
        predImage = interp3D(predImage, [4, 4, 1], cval=0)
        predImage = (predImage > 0.5).astype(self.dtype)
        print('test labels shape: ', predImage.shape)
        print(datetime.now() - startTime)

        if self.testLabelFlag:
            self.computeTestMetrics(otestLabels, predImage)
        if self.testMetricFlag:
            self.saveTestMetrics(saveFolder, otestLabels, predImage, each)
        if self.savePredMask:
            predImage = predImage.astype('uint8')
            # NOTE(review): the 512x512 in-plane size is hard-coded;
            # presumably it is 4x the 128x128 model input -- confirm before
            # using a different input size.
            predImage = nb.Nifti1Image(
                predImage.reshape(512, 512, img_dep), affine)
            nb.save(predImage, os.path.join(saveFolder, each))
    return
def train3D(self):
    """Train the 3D model case by case for ``self.numEpochs`` epochs.

    The model is built with ``self.createModel3D([128, 128, 48])``.  In every
    epoch each non-hidden entry of the training image directory is loaded,
    image and label volumes are resampled with ``interp3D`` by
    ``[0.25, 0.25, 1]``, arranged with ``arrange3Ddata``, the labels are
    one-hot encoded, and the model is fitted on that case for one Keras
    epoch with per-voxel sample weights.

    Side effects:
        * Every 25th epoch the weights are saved to ``self.save_folder`` as
          ``<epoch>_<self.checkWeightFileName>``.
        * After every epoch the per-case mean of ``acc`` and ``loss`` is
          appended to a history dict, which is written to ``history.npy``
          inside ``self.save_folder``.  (The path is built with
          ``os.path.join``; plain string concatenation placed the file
          next to the folder whenever ``save_folder`` had no trailing
          separator.)

    Returns:
        None.

    Raises:
        ValueError: If the training image directory holds no non-hidden
            entries (this previously surfaced as a ``ZeroDivisionError``
            when averaging).
    """
    model = self.createModel3D([128, 128, 48])
    print('-'*30)
    print('Loading training data...')
    print('-'*30)
    train_images_path, train_labels_path = self.arrangeDataPath(
        self.root_folder, self.image_folder, self.mask_folder)
    hist = {'acc': [], 'loss': []}
    history_path = os.path.join(self.save_folder, 'history.npy')

    for epochs in range(self.numEpochs):
        print('epochs: ', epochs)
        # List the folder once per epoch and skip hidden entries (names
        # starting with '.'), the same filter Predict3D applies.  The same
        # list is used for averaging, so skipped entries do not dilute it.
        cases = [name for name in os.listdir(train_images_path)
                 if not name.startswith('.')]
        if not cases:
            raise ValueError(
                'No training cases found in %s' % train_images_path)

        acc = 0
        loss = 0
        for each in cases:
            print('case: ', each)
            trainImages, trainLabels, affine = self.load3DtrainingData(
                train_images_path, train_labels_path, each)
            trainImages = interp3D(trainImages, [0.25, 0.25, 1], cval=-1024)
            trainLabels = interp3D(trainLabels, [0.25, 0.25, 1], cval=0)
            trainImages, trainLabels = arrange3Ddata(
                trainImages, trainLabels, 48, self.dtype)
            # The unpacking requires a 5-D array; presumably the layout is
            # (sub-volumes, rows, cols, depth, channels) -- TODO confirm.
            numImgs, img_rows, img_cols, img_dep, _ = trainImages.shape
            print('training image shape:', trainImages.shape)

            # Flatten each sub-volume so the weights and the one-hot labels
            # are indexed per voxel.
            trainLabels = trainLabels.reshape(
                (numImgs, img_rows*img_cols*img_dep))
            if self.wType == 'slice':
                wImg = self.sliceBasedWeighting3D(trainLabels)
            else:
                # Any other weighting type uses uniform weights (a
                # volume-based alternative was commented out here).
                wImg = np.ones(trainLabels.shape)

            trainLabels = np_utils.to_categorical(trainLabels, self.nb_classes)
            trainLabels = trainLabels.reshape(
                (numImgs, img_rows*img_cols*img_dep, 1, self.nb_classes))
            trainLabels = trainLabels.astype(self.dtype)

            print('-'*30)
            print('Training model...')
            print('-'*30)
            history = model.fit(trainImages, trainLabels, batch_size=self.bs,
                                epochs=1, verbose=1, sample_weight=wImg)
            acc = acc + history.history['acc'][0]
            loss = loss + history.history['loss'][0]

        # Checkpoint every 25th epoch; the file name carries the 1-based
        # epoch number.
        if (epochs + 1) % 25 == 0:
            model.save_weights(os.path.join(
                self.save_folder,
                str(epochs+1) + '_' + self.checkWeightFileName))

        hist['acc'].append(acc / len(cases))
        hist['loss'].append(loss / len(cases))
        # Rewritten every epoch so the history on disk is always current.
        np.save(history_path, hist)
    return