def train3D(self):
    """Train the 3D model case-by-case for ``self.numEpochs`` epochs.

    For every epoch, each case found in the training image folder is loaded
    with ``self.load3DtrainingData``. It is resampled by ``interp3D`` with
    factors ``[0.25, 0.25, 1]`` and arranged by ``arrange3Ddata`` (depth
    argument 48). It is then fitted for a single Keras epoch. The model input
    size is ``[128, 128, 48]`` as passed to ``self.createModel3D``.

    Per-epoch mean accuracy and loss over all cases are collected. They are
    written with ``np.save`` to ``<save_folder>/history.npy`` after training.
    Weights are checkpointed to
    ``<save_folder>/<epoch>_<checkWeightFileName>`` every 25 epochs.

    Raises:
        ValueError: if the training image folder contains no cases.
    """
    model = self.createModel3D([128, 128, 48])
    print('-' * 30)
    print('Loading training data...')
    print('-' * 30)
    train_images_path, train_labels_path = self.arrangeDataPath(
        self.root_folder, self.image_folder, self.mask_folder)

    # List the cases once so the per-epoch averages are taken over exactly
    # the cases that were trained on (and fail early on an empty folder
    # instead of dividing by zero after a wasted epoch).
    cases = os.listdir(train_images_path)
    if not cases:
        raise ValueError('no training cases found in %s' % train_images_path)
    num_cases = len(cases)

    hist = {'acc': [], 'loss': []}
    for epochs in range(self.numEpochs):
        print('epochs: ', epochs)
        acc = 0
        loss = 0
        for each in cases:
            print('case: ', each)
            trainImages, trainLabels, _affine = self.load3DtrainingData(
                train_images_path, train_labels_path, each)
            # Resample with factors [0.25, 0.25, 1]; images are filled with
            # -1024 (presumably air in HU -- confirm) and labels with 0.
            trainImages = interp3D(trainImages, [0.25, 0.25, 1], cval=-1024)
            trainLabels = interp3D(trainLabels, [0.25, 0.25, 1], cval=0)
            trainImages, trainLabels = arrange3Ddata(
                trainImages, trainLabels, 48, self.dtype)
            numImgs, img_rows, img_cols, img_dep, _ = trainImages.shape
            print('training image shape:', trainImages.shape)

            # Flatten labels to one row of voxels per arranged volume so the
            # per-voxel sample weights line up with the model output.
            nVoxels = img_rows * img_cols * img_dep
            trainLabels = trainLabels.reshape((numImgs, nVoxels))
            if self.wType == 'slice':
                wImg = self.sliceBasedWeighting3D(trainLabels)
            else:
                wImg = np.ones(trainLabels.shape)
            trainLabels = np_utils.to_categorical(trainLabels, self.nb_classes)
            trainLabels = trainLabels.reshape(
                (numImgs, nVoxels, 1, self.nb_classes)).astype(self.dtype)

            print('-' * 30)
            print('Training model...')
            print('-' * 30)
            history = model.fit(trainImages, trainLabels, batch_size=self.bs,
                                epochs=1, verbose=1, sample_weight=wImg)
            metrics = history.history
            # Older Keras reports 'acc'; newer releases report 'accuracy'.
            acc_key = 'acc' if 'acc' in metrics else 'accuracy'
            acc += metrics[acc_key][0]
            loss += metrics['loss'][0]

        # Checkpoint once per qualifying epoch, after every case was seen.
        if (epochs + 1) % 25 == 0:
            model.save_weights(os.path.join(
                self.save_folder,
                str(epochs + 1) + '_' + self.checkWeightFileName))
        hist['acc'].append(acc / num_cases)
        hist['loss'].append(loss / num_cases)

    # Join the path like the weight checkpoints do, so the history file is
    # written inside save_folder even without a trailing separator.
    np.save(os.path.join(self.save_folder, 'history.npy'), hist)
    return
def Predict3D(self, weights):
    """Predict a mask for every case in the test folder using saved weights.

    Args:
        weights: weight file name, resolved relative to ``self.save_folder``.

    For each case the image is resampled with ``interp3D`` factors
    ``[0.25, 0.25, 1]`` and arranged by ``arrange3Ddata`` (depth argument 48).
    It is then passed through the model. The probability channel of class
    ``self.sC`` (1-based) is resampled back with factors ``[4, 4, 1]`` and
    thresholded at 0.5.

    Depending on instance flags:
      * ``testLabelFlag``: ``self.computeTestMetrics`` is called with the
        original-resolution labels.
      * ``testMetricFlag``: ``self.saveTestMetrics`` is called.
      * ``savePredMask``: the mask is saved as NIfTI (uint8, with the case's
        affine) to ``<save_folder>/<pred_folder>/<case>``.
    """
    from datetime import datetime

    test_images_path, test_labels_path = self.arrangeDataPath(
        self.root_folder, self.image_folder, self.mask_folder)
    print('-' * 30)
    print('Loading saved weights...')
    print('-' * 30)
    model = self.createModel3D([128, 128, 48])
    model.load_weights(os.path.join(self.save_folder, weights))

    saveFolder = os.path.join(self.save_folder, self.pred_folder)
    if self.savePredMask:
        # nb.save writes straight into saveFolder; make sure it exists.
        os.makedirs(saveFolder, exist_ok=True)

    for each in os.listdir(test_images_path):
        print('case: ', each)
        startTime = datetime.now()
        testImages, otestLabels, affine = self.load3DtrainingData(
            test_images_path, test_labels_path, each)
        testImages = interp3D(testImages, [0.25, 0.25, 1], cval=-1024)
        testLabels = interp3D(otestLabels, [0.25, 0.25, 1], cval=0)
        # arrange3Ddata takes the labels too, but only the original-resolution
        # labels (otestLabels) are used for metrics below.
        testImages, _ = arrange3Ddata(testImages, testLabels, 48, self.dtype)
        numImgs, img_rows, img_cols, img_dep, _ = testImages.shape
        print('test image shape:', testImages.shape)

        print('-' * 30)
        print('Predicting masks on test data...')
        print('-' * 30)
        predImage = model.predict(testImages, batch_size=1, verbose=1)
        # Keep only the channel of the class of interest (self.sC is 1-based).
        predImage = predImage.reshape(
            (numImgs, img_rows, img_cols, img_dep, self.nb_classes)
        )[:, :, :, :, self.sC - 1:self.sC]
        print(predImage.shape)
        # NOTE(review): only the first arranged volume is kept; if
        # arrange3Ddata yields numImgs > 1 for deep cases, the remaining
        # sub-volumes are dropped -- confirm against arrange3Ddata.
        predImage = predImage[0, :, :, :, 0]
        predImage = interp3D(predImage, [4, 4, 1], cval=0)
        predImage = (predImage > 0.5).astype(self.dtype)
        print('predicted mask shape: ', predImage.shape)
        print(datetime.now() - startTime)

        if self.testLabelFlag:
            self.computeTestMetrics(otestLabels, predImage)
        if self.testMetricFlag:
            self.saveTestMetrics(saveFolder, otestLabels, predImage, each)
        if self.savePredMask:
            # Upsampling by [4, 4, 1] restores 4x the in-plane size (512 for
            # the 128x128 model input) instead of hard-coding 512.
            mask = predImage.astype('uint8').reshape(
                4 * img_rows, 4 * img_cols, img_dep)
            nb.save(nb.Nifti1Image(mask, affine),
                    os.path.join(saveFolder, each))
    return