Пример #1
0
    def AugmentDS3D(self,Augentries):
        """Cut random 3D patches from the training stack, transform and save them.

        ``Augentries`` is parsed by ``self.arggen`` into
        ``(data_path, fext, num, cr)``; ``data_path`` and ``fext`` are stored in
        the module-level ``config``. ``cr`` is parsed but not used here.

        Steps:
        1. ``self.Savefiles`` rebuilds the training hdf5 files and the image and
           label stacks are read back (the trailing channel axis is squeezed).
        2. ``num`` cubic patches of side ``patch_size`` are extracted at random
           (depth, row, col) origins with ``extract_3D_patches``, and every
           image/label patch pair is passed through ``transform_3Dpatch``.
           Single-label data uses ``scale_deviation=0.15``. No transform is
           applied when the label maximum is below 1.
        3. Every slice of every patch is written as a tif under
           ``<image_p>/3D/<i>`` and ``<label_p>/3D/<i>`` (relative to
           ``data_path``). The image and label patches are then passed to
           ``self.Savefiles3D``.
        """
        args = self.arggen(Augentries)
        config['data_path'], config['fext'], num, cr = args[0], args[1], int(args[2]), float(args[3])
        fext = config['fext']
        data_path = config['data_path']
        fold3D = '3D'
        self.Savefiles(data_path, fext, 'trlab')

        # Load both stacks fully into memory; close the hdf5 handles even if
        # the read fails.
        imhdf5 = open_hdf5_file(config['image_hdf5_path'])
        try:
            trimgs = np.squeeze(imhdf5.root.data, axis=3)
        finally:
            imhdf5.close()
        labhdf5 = open_hdf5_file(config['label_hdf5_path'])
        try:
            labimgs = np.squeeze(labhdf5.root.truth, axis=3)
        finally:
            labhdf5.close()

        nLabels = np.max(labimgs)
        print('estimated number of lables:', nLabels)
        print('stack shape', trimgs.shape)
        (nd, nr, nc) = trimgs.shape

        # Random patch origins, one column per patch; rows are (depth, row, col).
        patch_size = 64
        origin_row = np.random.randint(0, nr - patch_size, num)
        origin_col = np.random.randint(0, nc - patch_size, num)
        origin_dep = np.random.randint(0, nd - patch_size, num)
        origins = np.array((origin_dep, origin_row, origin_col))
        trimgs_patches = extract_3D_patches(trimgs, patch_size, origins)
        labs_patches = extract_3D_patches(labimgs, patch_size, origins)

        if nLabels > 1:
            for i in range(trimgs_patches.shape[0]):
                trimgs_patches[i], labs_patches[i] = transform_3Dpatch(trimgs_patches[i], labs_patches[i])
        elif nLabels == 1:
            for i in range(trimgs_patches.shape[0]):
                trimgs_patches[i], labs_patches[i] = transform_3Dpatch(trimgs_patches[i], labs_patches[i], scale_deviation=0.15)
        print('patches shape', trimgs_patches.shape)

        # Start from empty output trees. Both are removed unconditionally so a
        # leftover label tree cannot make os.makedirs fail.
        img_root = os.path.join(data_path, image_p, fold3D)
        lab_root = os.path.join(data_path, label_p, fold3D)
        shutil.rmtree(img_root, ignore_errors=True)
        shutil.rmtree(lab_root, ignore_errors=True)
        os.makedirs(img_root)
        os.makedirs(lab_root)

        # One folder per patch, one tif per slice of the patch.
        for i in range(trimgs_patches.shape[0]):
            os.makedirs(os.path.join(img_root, str(i)))
            os.makedirs(os.path.join(lab_root, str(i)))
            for k in range(trimgs_patches.shape[1]):
                self.save_tif(data_path, os.path.join(image_p, fold3D, str(i)), 'img', trimgs_patches[i, k], k, fext)
                self.save_tif(data_path, os.path.join(label_p, fold3D, str(i)), 'img', labs_patches[i, k], k, fext)

        imglabpatches = np.array((trimgs_patches, labs_patches))
        self.Savefiles3D(data_path, imglabpatches, fext)
        return
Пример #2
0
    def Predict3D(self,TMentries):
        """Segment the hdf5 test volume with the most recent saved 3D weights.

        ``TMentries`` is parsed by ``self.arggen``. Positions 0 and 2-7 hold
        ``data_path``, the model name, epochs, validation split, class count,
        learning rate and batch size; only ``data_path``, the model name and
        the class count are used here.

        The test stack is read from ``config['test_npy']`` inside ``data_path``.
        The newest ``*.hdf5`` file in ``<data_path>/<model_p>`` (by ctime)
        provides the weights. One ``<i>_pred.tif`` per slice is written to
        ``<data_path>/preds_<model>_<text after 'weights' in the file name>``.

        Raises:
            ValueError: if the model folder contains no ``*.hdf5`` file.
        """
        args = self.arggen(TMentries)
        data_path, mName, epchs, vspl, nCl, lr, bs = args[0], args[2], int(args[3]), float(args[4]), int(args[5]), float(args[6]), int(args[7])
        # Multi-class models get one extra output channel (presumably background).
        nLab = nCl + 1 if nCl > 1 else nCl
        print('-'*30)
        print('Loading and preprocessing test data...')
        print('-'*30)

        test_open = open_hdf5_file(os.path.join(data_path, config['test_npy']))
        try:
            imgs_test_tot = np.squeeze(test_open.root.test, axis=3)
        finally:
            test_open.close()
        img_depth, img_rows, img_cols = imgs_test_tot.shape
        print('-'*30)
        print('Loading saved weights...')
        print('-'*30)
        weights_dir = os.path.join(data_path, model_p)
        weight_files = glob.glob(os.path.join(weights_dir, '*.hdf5'))
        if not weight_files:
            raise ValueError('No .hdf5 weight files found in ' + weights_dir)
        latest = max(weight_files, key=os.path.getctime)
        print(latest)
        pred_dir = os.path.join(data_path, 'preds_' + mName + '_' + latest.split('weights', 1)[1])

        CPU = False
        GPU = True
        # os.cpu_count() may return None, and "cores - 2" is not positive on
        # 1-2 core machines; 0 lets TensorFlow pick the thread count itself.
        num_cores = max(0, (os.cpu_count() or 0) - 2)
        if GPU:
            num_GPU = 1
            num_CPU = 1
        if CPU:
            num_CPU = 1
            num_GPU = 0

        configK = tf.ConfigProto(intra_op_parallelism_threads=num_cores,
                                 inter_op_parallelism_threads=num_cores, allow_soft_placement=True,
                                 device_count={'CPU': num_CPU, 'GPU': num_GPU})
        session = tf.Session(config=configK)
        K.set_session(session)
        if not GPU:
            # Whole-volume prediction, standardized with the volume's mean/std.
            vol_test = preprocess3D(imgs_test_tot.reshape(1, img_depth, img_rows, img_cols), img_depth, img_rows, img_cols)
            vol_test = vol_test.astype('float32')
            mean = np.mean(vol_test)  # mean for data centering
            std = np.std(vol_test)  # std for data normalization
            vol_test -= mean
            vol_test /= std
            voltmp = vol_test.reshape(img_depth, img_rows, img_cols, 1)
            model = getattr(Nmodels3D, mName)(nLab, voltmp.shape, latest)
            print('-'*30)
            print('Predicting labels on test data...')
            print('-'*30)
            imgs_mask_test = model.predict(vol_test, verbose=1)
            if nLab > 1:
                imgs_mask_test = imgs_mask_test.reshape((img_depth, img_rows, img_cols, nLab))
                vol_mask = imgs_mask_test.argmax(axis=3).astype(np.uint8)
            else:
                vol_mask = np.squeeze((imgs_mask_test * 255.).astype(np.uint8), axis=4)
        else:
            # Patch-wise prediction twice: once on the volume and once on a copy
            # rolled by half a patch along every axis, so the two patch grids
            # are offset. The second result is rolled back before combining.
            p_size = 64
            half = p_size // 2
            imgs_test_totshft = np.roll(imgs_test_tot, (half, half, half), axis=(0, 1, 2))
            volmask1 = self.pred_vol(imgs_test_tot, nLab, mName, p_size, latest)
            volmask2 = self.pred_vol(imgs_test_totshft, nLab, mName, p_size, latest)
            volmask2shft = np.roll(volmask2, (-half, -half, -half), axis=(0, 1, 2))
            # Sum in float: pred_vol's dtype is not visible here, and adding two
            # uint8 masks could wrap around past 255.
            vol_mask = volmask1.astype(np.float32) + volmask2shft
            vol_mask[vol_mask > 0.1] = 255

        if not os.path.exists(pred_dir):
            os.mkdir(pred_dir)
        for i, plane in enumerate(vol_mask):
            imsave(os.path.join(pred_dir, str(i) + '_pred.tif'), plane.astype(np.uint8))
Пример #3
0
    def Predict(self,TMentries):
        """Segment every 2D test image with the most recent saved weights.

        ``TMentries`` is parsed by ``self.arggen``. Positions 0 and 2-7 hold
        ``data_path``, the model name, epochs, validation split, class count,
        learning rate and batch size; ``data_path``, the model name, the class
        count and the learning rate are used here.

        Test images come from ``config['test_npy']`` and their ids from the
        module-level ``test_id_npy`` file, both inside ``data_path``. Each
        prediction is saved as ``<id>_pred.tif`` in
        ``<data_path>/preds_<model>_<text after 'weights' in the file name>``.

        Raises:
            ValueError: if the model folder contains no ``*.hdf5`` file.
        """
        args = self.arggen(TMentries)
        data_path, mName, epchs, vspl, nCl, lr, bs = args[0], args[2], int(args[3]), float(args[4]), int(args[5]), float(args[6]), int(args[7])
        # Multi-class models get one extra output channel (presumably background).
        nLab = nCl + 1 if nCl > 1 else nCl
        print('-'*30)
        print('Loading and preprocessing test data...')
        print('-'*30)

        test_open = open_hdf5_file(os.path.join(data_path, config['test_npy']))
        try:
            imgs_test_tot = np.squeeze(test_open.root.test, axis=3)
        finally:
            test_open.close()
        imgs_id_test_tot = np.load(os.path.join(data_path, test_id_npy))
        img_rows, img_cols = imgs_test_tot[0].shape
        print('-'*30)
        print('Loading saved weights...')
        print('-'*30)
        weights_dir = os.path.join(data_path, model_p)
        weight_files = glob.glob(os.path.join(weights_dir, '*.hdf5'))
        if not weight_files:
            raise ValueError('No .hdf5 weight files found in ' + weights_dir)
        latest = max(weight_files, key=os.path.getctime)
        print(latest)
        imgtmp = imgs_test_tot[0].reshape(img_rows, img_cols, 1)
        model = getattr(Nmodels, mName)(nLab, imgtmp.shape, latest, lr)

        # The output folder is the same for every image; create it once.
        pred_dir = os.path.join(data_path, 'preds_' + mName + '_' + latest.split('weights', 1)[1])
        if not os.path.exists(pred_dir):
            os.mkdir(pred_dir)

        n_test = len(imgs_test_tot)
        for i in range(n_test):
            print('image', i + 1, 'of', n_test)
            imgs_test = preprocess(imgs_test_tot[i].reshape(1, img_rows, img_cols), img_rows, img_cols)
            imgs_id_test = imgs_id_test_tot[i]
            print('processing image', imgs_id_test)

            # Per-image standardization.
            imgs_test = imgs_test.astype('float32')
            mean = np.mean(imgs_test)  # mean for data centering
            std = np.std(imgs_test)  # std for data normalization
            imgs_test -= mean
            # A constant image has zero std; skip the division rather than
            # feeding NaNs to the model.
            if std > 0:
                imgs_test /= std

            print('-'*30)
            print('Predicting labels on test data...')
            print('-'*30)
            imgs_mask_test = model.predict(imgs_test, verbose=1)

            print('-' * 30)
            print('Saving predicted labels to files...')
            print('-' * 30)
            if nLab > 1:
                # Per-pixel class = channel with the highest score.
                imgs_mask_test = imgs_mask_test.reshape((img_rows, img_cols, nLab))
                imgs_mask_test = imgs_mask_test.argmax(axis=2).astype(np.uint8)
            else:
                imgs_mask_test = (imgs_mask_test * 255.).astype(np.uint8)
            imsave(os.path.join(pred_dir, str(imgs_id_test) + '_pred.tif'), imgs_mask_test)
Пример #4
0
    def TrainModel3D(self,TMentries):
        """Train (or continue training) a 3D model on the hdf5 training patches.

        ``TMentries`` is parsed by ``self.arggen``. Positions 0 and 2-7 hold
        ``data_path``, the model name (an attribute of ``Nmodels3D``), epochs,
        validation split, class count, learning rate and batch size.

        The image/label hdf5 paths are recorded in the module-level ``config``.
        Images are standardized with the global mean/std. Labels are one-hot
        encoded with ``getSegmentationArr`` for multi-class data and binarized
        otherwise. If ``<data_path>/<model_p>`` already contains ``.hdf5``
        weights, the newest one (by ctime) is loaded. Checkpoints with the
        best training loss are written back to that folder.
        """
        args = self.arggen(TMentries)
        data_path, mName, epchs, vspl, nCl, lr, bs = args[0], args[2], int(args[3]), float(args[4]), int(args[5]), float(args[6]), int(args[7])
        # Multi-class models get one extra output channel (presumably background).
        nLab = nCl + 1 if nCl > 1 else nCl
        config['data_path'] = data_path
        config['image_path'] = os.path.join(data_path, image_p)
        config['label_path'] = os.path.join(data_path, label_p)
        config['image_hdf5_path'] = os.path.join(config['data_path'], config['trimg_npy'])
        config['label_hdf5_path'] = os.path.join(config['data_path'], config['trlab_npy'])
        print('-'*30)
        print('Loading and preprocessing train data...')
        print('-'*30)
        imgsopen = open_hdf5_file(config['image_hdf5_path'])
        try:
            imgs_train = np.squeeze(imgsopen.root.data, axis=4)
        finally:
            imgsopen.close()
        labopen = open_hdf5_file(config['label_hdf5_path'])
        try:
            imgs_mask_train = np.squeeze(labopen.root.truth, axis=4)
        finally:
            labopen.close()
        (img_depth, img_rows, img_cols) = imgs_train[0].shape

        imgs_train2 = [preprocess(patch, img_rows, img_cols) for patch in imgs_train]
        imgs_mask_train2 = [preprocess(patch, img_rows, img_cols) for patch in imgs_mask_train]

        # Standardize with statistics over the whole training set.
        imgs_train = np.array(imgs_train2).astype('float32')
        mean = np.mean(imgs_train)  # mean for data centering
        std = np.std(imgs_train)  # std for data normalization
        imgs_train -= mean
        imgs_train /= std

        imgs_mask_train = np.array(imgs_mask_train2).astype('uint8')
        print('Size of the training data:', imgs_mask_train.shape)
        if np.max(imgs_mask_train) != nCl:
            print('Warning: the number of classes does not match the intesities of the label images')
        if nLab > 1:
            imgs_mask_train = getSegmentationArr(imgs_mask_train, nLab)
        else:
            # Binary masks: any non-zero label becomes 1.
            imgs_mask_train[imgs_mask_train > 0.5] = 1
            imgs_mask_train[imgs_mask_train <= 0.5] = 0
        print(imgs_mask_train.shape)
        print('-'*30)
        print('Creating and compiling model...')
        print('-'*30)

        weights_dir = os.path.join(data_path, model_p)
        if os.path.exists(weights_dir) and any('.hdf5' in x for x in os.listdir(weights_dir)):
            print('loading weights and compiling the model')
            latest = max(glob.glob(os.path.join(weights_dir, '*.hdf5')), key=os.path.getctime)
            model = getattr(Nmodels3D, mName)(nLab, imgs_train[0].shape, latest, lr)
        else:
            if not os.path.exists(weights_dir):
                os.makedirs(weights_dir)
            model = getattr(Nmodels3D, mName)(nLab, imgs_train[0].shape, '', lr)

        # Keras fills the checkpoint path from the epoch logs. 'val_loss' is only
        # logged when a validation split is used, so referencing it with
        # vspl == 0 would raise KeyError at the first checkpoint.
        ckpt_name = mName + 'weights.ep{epoch:02d}-il{loss:.3f}'
        if vspl > 0:
            ckpt_name += '-vl{val_loss:.3f}'
        ckpt_name += '.hdf5'
        model_checkpoint = ModelCheckpoint(os.path.join(weights_dir, ckpt_name), monitor='loss', verbose=1, save_best_only=True)

        print('-'*30)
        print('Fitting model...')
        print('-'*30)
        model.fit(imgs_train, imgs_mask_train, batch_size=bs, epochs=epchs, verbose=1,
                  validation_split=vspl,
                  shuffle=True,
                  callbacks=[model_checkpoint])
        return
Пример #5
0
    def AugmentDS(self,Augentries):
        """Augment the 2D training set, then keep random crops of the results.

        ``Augentries`` is parsed by ``self.arggen`` into
        ``(data_path, fext, num, cr)``; ``data_path`` and ``fext`` are stored in
        the module-level ``config``.

        Steps:
        1. ``self.Savefiles`` rebuilds the training hdf5 files and the image and
           label stacks are read back.
        2. Each image gets ``2*num`` randomly transformed copies from a Keras
           ``ImageDataGenerator``, saved under ``<image_p>/<tmpf>``. Labels go
           through the generator with the same per-frame seed. Multi-label
           masks are split into one 0/1 mask per label value, augmented
           separately and recombined as ``sum(k * mask_k)``.
        3. The temporary set is written to hdf5. ``crop_no_black`` is applied to
           each pair with crop size ``cr`` times the image size (rounded down
           to a multiple of 16); pairs for which it returns None are skipped.
           Up to ``len(trimgs)*(num+1)+1`` accepted pairs are saved under
           ``<image_p>/<selfold>`` and ``<label_p>/<selfold>``.
        4. The temporary folders are deleted and the hdf5 files are rebuilt
           via ``self.Savefiles(..., subtask='augm')``.

        Augmented label frames are also collected in the module-level list
        ``batchdata``.
        """
        datagen = ImageDataGenerator(
            rotation_range=3,
            width_shift_range=0.08,
            height_shift_range=0.08,
            shear_range=0.07,
            zoom_range=0.07,
            horizontal_flip=True,
            vertical_flip=True,
            fill_mode='constant',
            cval=0.
            )
        args = self.arggen(Augentries)
        config['data_path'], config['fext'], num, cr = args[0], args[1], int(args[2]), float(args[3])
        fext = config['fext']
        data_path = config['data_path']
        self.Savefiles(data_path, fext, 'trlab')
        imhdf5 = open_hdf5_file(config['image_hdf5_path'])
        try:
            trimgs = np.squeeze(imhdf5.root.data, axis=3)
        finally:
            imhdf5.close()
        labhdf5 = open_hdf5_file(config['label_hdf5_path'])
        try:
            labimgs = np.squeeze(labhdf5.root.truth, axis=3)
        finally:
            labhdf5.close()
        nLabels = np.max(labimgs)
        print('estimated number of lables:', nLabels)
        if nLabels > 1:
            # One 0/1 mask per label value, so each label is augmented on its own.
            labimgstmp = np.array([np.ma.masked_not_equal(labimgs, k).filled(0) / k
                                   for k in range(1, nLabels + 1)])

        imgshape = trimgs[0].shape
        print(imgshape)
        print('-'*30)
        print('Augmenting train and labels dataset: ', num, 'replica per image...')
        print('-'*30)

        n_frames = 2 * num  # augmented copies generated per source frame
        n_seeds = 2 * len(trimgs) * num
        # The seeds also serve as file ids (the recombination step reads ids
        # back from the last four characters of the file name). Duplicate seeds
        # would make later frames overwrite earlier ones, so draw them without
        # replacement whenever the 0..9999 range allows it.
        if n_seeds <= 10000:
            seed = np.random.choice(10000, size=n_seeds, replace=False)
        else:
            seed = np.random.randint(10000, size=n_seeds)

        img_tmp = os.path.join(config['image_path'], tmpf)
        lab_tmp = os.path.join(config['label_path'], tmpf)
        # Remove both trees unconditionally so a leftover label tree cannot
        # make os.makedirs fail.
        shutil.rmtree(img_tmp, ignore_errors=True)
        shutil.rmtree(lab_tmp, ignore_errors=True)
        os.makedirs(img_tmp)
        os.makedirs(lab_tmp)
        global batchdata
        batchdata = []

        def _flow_save(stack, subdir, collect):
            """Save ``n_frames`` augmented copies of every 2D frame in ``stack``.

            Frame j is fed to the generator with ``seed[j]``; copy i of frame j
            is saved with id ``seed[i + j*n_frames]``. When ``collect`` is true
            the augmented float frames are also appended to ``batchdata``.
            """
            for j, frame in enumerate(stack):
                frame = frame.reshape((1,) + frame.shape + (1,))
                for i, batch in enumerate(datagen.flow(frame, batch_size=1, seed=seed[j])):
                    aug = batch[0, :, :, 0]
                    self.save_tif(data_path, subdir, 'img', aug.astype('uint8'), seed[i + j * n_frames], fext)
                    if collect:
                        batchdata.append(aug)
                    if i + 1 >= n_frames:
                        break  # otherwise the generator would loop indefinitely

        # Zero pixels in the images are set to 1 (in place) before augmenting.
        trimgs[trimgs == 0] = 1
        _flow_save(trimgs, os.path.join(image_p, tmpf), collect=False)

        if nLabels > 1:
            for k in range(1, nLabels + 1):
                os.makedirs(os.path.join(config['label_path'], tmpf, str(k)))
                _flow_save(labimgstmp[k - 1], os.path.join(label_p, tmpf, str(k)), collect=True)
            # Recombine the per-label masks into one label image per augmented frame.
            imglist = [f for f in sorted(os.listdir(img_tmp)) if fext in f]
            for name in imglist:
                tmp = sum(read_image(os.path.join(config['label_path'], tmpf, str(k), name)) * k
                          for k in range(1, nLabels + 1))
                self.save_tif(data_path, os.path.join(label_p, tmpf), 'img', tmp.astype('uint8'), name.split('.')[0][-4:], fext)
            for k in range(1, nLabels + 1):
                shutil.rmtree(os.path.join(config['label_path'], tmpf, str(k)), ignore_errors=True)
        else:
            _flow_save(labimgs, os.path.join(label_p, tmpf), collect=True)

        self.Savefiles(data_path, fext, 'trlab', subtask='augtmp')
        imhdf5 = open_hdf5_file(config['image_hdf5_path'])
        try:
            tmptr = np.squeeze(imhdf5.root.data, axis=3)
        finally:
            imhdf5.close()
        labhdf5 = open_hdf5_file(config['label_hdf5_path'])
        try:
            tmplab = np.squeeze(labhdf5.root.truth, axis=3)
        finally:
            labhdf5.close()
        print(imgshape, cr)
        # Crop size: cr times the image size, rounded down to a multiple of 16.
        lencrop = int(((imgshape[0]*cr)//16)*16), int(((imgshape[1]*cr)//16)*16)
        print(lencrop)

        max_kept = len(trimgs) * (num + 1)
        seltr = []
        sellab = []
        for i, img in enumerate(tmptr):
            tmpres = crop_no_black(img, tmplab[i], lencrop)
            if tmpres is not None:
                seltr.append(tmpres[0])
                sellab.append(tmpres[1])
                if len(seltr) > max_kept:
                    break
        seltr = np.array(seltr)
        sellab = np.array(sellab)
        print(seltr.shape, sellab.shape)

        img_sel = os.path.join(data_path, image_p, selfold)
        lab_sel = os.path.join(data_path, label_p, selfold)
        shutil.rmtree(img_sel, ignore_errors=True)
        shutil.rmtree(lab_sel, ignore_errors=True)
        os.makedirs(img_sel)
        os.makedirs(lab_sel)
        for i in range(len(seltr)):
            self.save_tif(data_path, os.path.join(image_p, selfold), 'img', seltr[i], i, fext)
            self.save_tif(data_path, os.path.join(label_p, selfold), 'img', sellab[i], i, fext)

        # The temporary augmentation folders are no longer needed.
        shutil.rmtree(os.path.join(data_path, image_p, tmpf), ignore_errors=True)
        shutil.rmtree(os.path.join(data_path, label_p, tmpf), ignore_errors=True)
        self.Savefiles(data_path, fext, 'trlab', subtask='augm')
        print('Done')
        return
Пример #6
0
    def Savefiles(self,data_path,fext,task,subtask=None):
        """Collect training or test images from disk and write them to hdf5.

        ``task == 'trlab'``: image/label files are read and written to
        ``config['trimg_npy']`` / ``config['trlab_npy']`` inside
        ``config['data_path']``. The image/label folders and hdf5 paths are
        recorded in ``config``. The source folder depends on ``subtask``:

        * ``None``: if a ``selfold`` folder exists it is announced, but the
          following checks decide what is finally saved. ``aug3Dfol`` data is
          used if present, else ``datasubf3D`` sub-folders, else the files
          directly in the image/label folders.
        * ``'augm'``: with ``datasubf3D`` present, any ``aug3Dfol`` folders are
          deleted and the ``datasubf3D`` data is used; otherwise ``selfold``.
        * ``'augtmp'``: the ``tmpf`` folders.

        Data is written in 3D mode only when both the image and the label
        folder contain ``datasubf3D``.

        ``task == 'test'``: the files in ``<data_path>/<test_p>`` are written to
        ``config['test_npy']``. Their numeric ids (last four characters of the
        file stem) are saved to ``imgs_id_test.npy`` in ``data_path``.
        """
        if task == 'trlab':
            config['image_path'] = os.path.join(data_path, image_p)
            config['label_path'] = os.path.join(data_path, label_p)
            outim = config['image_hdf5_path'] = os.path.join(config['data_path'], config['trimg_npy'])
            outlab = config['label_hdf5_path'] = os.path.join(config['data_path'], config['trlab_npy'])
            if subtask is None:
                if selfold in os.listdir(config['image_path']):
                    print('Augmented data found. Saving augmented data instead of original ones')
                    imgs = fetch_data_1dir(os.path.join(config['image_path'], selfold), fext)
                    lbls = fetch_data_1dir(os.path.join(config['label_path'], selfold), fext)
                # NOTE(review): this is a separate `if`, not `elif`, so when no
                # aug3Dfol folder exists the branches below replace the selfold
                # data loaded above. AugmentDS appears to rely on this to start
                # from the original data; confirm before turning it into `elif`.
                if aug3Dfol in os.listdir(config['image_path']):
                    print('3D Augmented data found. Saving augmented data instead of original ones')
                    imgs = list(fetch_augdata_subdirs(os.path.join(config['image_path'], aug3Dfol), fext))
                    lbls = list(fetch_augdata_subdirs(os.path.join(config['label_path'], aug3Dfol), fext))
                elif datasubf3D in os.listdir(config['image_path']):
                    imgs = fetch_data_subdirs(os.path.join(config['image_path'], datasubf3D), fext)
                    lbls = fetch_data_subdirs(os.path.join(config['label_path'], datasubf3D), fext)
                else:
                    imgs = fetch_data_1dir(config['image_path'], fext)
                    lbls = fetch_data_1dir(config['label_path'], fext)
            if subtask == 'augm':
                if datasubf3D in os.listdir(config['image_path']):
                    # Discard previous 3D augmentations and use the original 3D data.
                    if aug3Dfol in os.listdir(config['image_path']):
                        shutil.rmtree(os.path.join(config['image_path'], aug3Dfol), ignore_errors=True)
                    imgs = fetch_data_subdirs(os.path.join(config['image_path'], datasubf3D), fext)
                    if aug3Dfol in os.listdir(config['label_path']):
                        shutil.rmtree(os.path.join(config['label_path'], aug3Dfol), ignore_errors=True)
                    lbls = fetch_data_subdirs(os.path.join(config['label_path'], datasubf3D), fext)
                else:
                    imgs = fetch_data_1dir(os.path.join(config['image_path'], selfold), fext)
                    lbls = fetch_data_1dir(os.path.join(config['label_path'], selfold), fext)
            if subtask == 'augtmp':
                imgs = fetch_data_1dir(os.path.join(config['image_path'], tmpf), fext)
                lbls = fetch_data_1dir(os.path.join(config['label_path'], tmpf), fext)

            # Images and labels are written in the same mode. Checking the two
            # folders separately could write only the labels in 3D mode
            # (leaving a stale image file) or write the images twice.
            is3D = (datasubf3D in os.listdir(config['image_path'])
                    and datasubf3D in os.listdir(config['label_path']))
            if is3D:
                print('3D data found')
                config['img_comp'] = write_data_to_file(imgs, outim, 'data', is3D=True)
                config['lab_comp'] = write_data_to_file(lbls, outlab, 'truth', is3D=True)
            else:
                config['img_comp'] = write_data_to_file(imgs, outim, 'data')
                config['lab_comp'] = write_data_to_file(lbls, outlab, 'truth')

            print('Training images hdf5 file written to:', outim)
            print('Training labels hdf5 file written to:', outlab)
            imgsopen = open_hdf5_file(config['image_hdf5_path'])
            try:
                imgs_train = np.squeeze(imgsopen.root.data, axis=(-1))
                print('hdf image data shape:', imgs_train.shape)
            finally:
                imgsopen.close()

        if task == 'test':
            config['test_path'] = os.path.join(data_path, test_p)
            outtest = os.path.join(config['data_path'], config['test_npy'])
            test = fetch_data_1dir(config['test_path'], fext)
            config['test_comp'] = write_data_to_file(test, outtest, 'test')
            print('Test images hdf5 file written to:', outtest)
            # int() accepts leading zeros, so the id needs no stripping. Stripping
            # them first turned '0000' into '' and raised ValueError.
            imgs_id = np.array([int(image_name.split('.')[0][-4:]) for image_name in test],
                               dtype=np.int32)
            np.save(os.path.join(data_path, 'imgs_id_test.npy'), imgs_id)
Пример #7
0
    def AugmentDS3D(self,Augentries):
        """Cut random 3D patches from every training volume, transform and save them.

        ``Augentries`` is parsed by ``self.arggen`` into
        ``(data_path, fext, num, cr)``; ``data_path`` and ``fext`` are stored in
        the module-level ``config``. ``cr`` is parsed but not used here.

        Steps:
        1. ``self.Savefiles(..., subtask='augm')`` rebuilds the hdf5 files.
           The stacks are read back as ``(nvols, depth, rows, cols)`` after
           squeezing the trailing axis.
        2. ``num`` cubic patches of side ``patch_size`` are extracted per
           volume at random (depth, row, col) origins. Each image/label patch
           pair is passed through ``transform_3Dpatch`` with default parameters
           when the label maximum is at least 1.
        3. Each slice j of patch k of volume i is saved as a tif under
           ``<image_p>/<config['3Daugm']>/<i>/<k>`` and the matching label
           folder. ``self.Savefiles`` is then called again with the default
           subtask. NOTE(review): presumably ``config['3Daugm']`` names the
           same folder as ``aug3Dfol`` so that this call picks the patches
           up — confirm.
        """
        args = self.arggen(Augentries)
        config['data_path'], config['fext'], num, cr = args[0], args[1], int(args[2]), float(args[3])
        fext = config['fext']
        data_path = config['data_path']
        fold3D = config['3Daugm']
        self.Savefiles(data_path, fext, 'trlab', subtask='augm')

        # Load both stacks fully into memory; close the hdf5 handles even if
        # the read fails.
        imhdf5 = open_hdf5_file(config['image_hdf5_path'])
        try:
            trimgs = np.squeeze(imhdf5.root.data, axis=(-1))
        finally:
            imhdf5.close()
        labhdf5 = open_hdf5_file(config['label_hdf5_path'])
        try:
            labimgs = np.squeeze(labhdf5.root.truth, axis=(-1))
        finally:
            labhdf5.close()
        nLabels = np.max(labimgs)
        print('estimated number of lables:', nLabels)
        print('stack shape', trimgs.shape)
        (nvols, nd, nr, nc) = trimgs.shape

        patch_size = 64
        origin_row = np.random.randint(0, nr - patch_size, num * nvols)
        origin_col = np.random.randint(0, nc - patch_size, num * nvols)
        origin_dep = np.random.randint(0, nd - patch_size, num * nvols)
        # One (3, num) block of origins per volume; rows are (depth, row, col).
        origins = np.array([
            np.array((origin_dep[i*num:(i+1)*num], origin_row[i*num:(i+1)*num], origin_col[i*num:(i+1)*num]))
            for i in range(nvols)])

        trimgs_patches = []
        labs_patches = []
        for i in range(nvols):
            trimgs_patches.append(extract_3D_patches(trimgs[i], patch_size, origins[i]))
            labs_patches.append(extract_3D_patches(labimgs[i], patch_size, origins[i]))
        trimgs_patches = np.array(trimgs_patches)
        labs_patches = np.array(labs_patches)

        # Single- and multi-label data use the same (default) transform.
        if nLabels >= 1:
            for i in range(trimgs_patches.shape[0]):
                for j in range(trimgs_patches.shape[1]):
                    trimgs_patches[i, j], labs_patches[i, j] = transform_3Dpatch(trimgs_patches[i, j], labs_patches[i, j])
        print('final training patch shape', trimgs_patches.shape)

        # Start from empty output trees. Both are removed unconditionally so a
        # leftover label tree cannot make os.makedirs fail.
        img_root = os.path.join(data_path, image_p, fold3D)
        lab_root = os.path.join(data_path, label_p, fold3D)
        shutil.rmtree(img_root, ignore_errors=True)
        shutil.rmtree(lab_root, ignore_errors=True)
        os.makedirs(img_root)
        os.makedirs(lab_root)

        # <root>/<volume i>/<patch k>/ holds one tif per slice j.
        for i in range(trimgs_patches.shape[0]):
            os.makedirs(os.path.join(img_root, str(i)))
            os.makedirs(os.path.join(lab_root, str(i)))
            for k in range(trimgs_patches.shape[1]):
                os.makedirs(os.path.join(img_root, str(i), str(k)))
                os.makedirs(os.path.join(lab_root, str(i), str(k)))
                for j in range(trimgs_patches.shape[2]):
                    self.save_tif(data_path, os.path.join(image_p, fold3D, str(i), str(k)), 'img', trimgs_patches[i, k][j], j, fext)
                    self.save_tif(data_path, os.path.join(label_p, fold3D, str(i), str(k)), 'img', labs_patches[i, k][j], j, fext)
        self.Savefiles(data_path, fext, 'trlab')
        return