def wrap(i_image=None):
     """Invoke i_fn on i_image and log the elapsed wall-clock time in milliseconds."""
     tic = time.time_ns()
     result = i_fn(i_image)  # NOTE(review): i_fn is captured from an enclosing scope not visible here
     toc = time.time_ns()
     elapsed_ms = (toc - tic) / 1e6  # time_ns returns nanoseconds; /1e6 => milliseconds
     Logs.log('Processing time X = {}'.format(elapsed_ms))
     return result
 def seg_predict(self,i_image=None):
     """Block-wise semantic segmentation of a single uint8 gray/RGB image.

     Note that: the segmentation problem has no overlapped objects (which can
     happen in the detection problem).
     @param i_image: np.ndarray of shape (H, W) or (H, W, 1|3), dtype uint8.
     @return: mask of shape (height, width, 1) with gray levels from 0 to
              (self.segnet.vnum_classes - 1).
     """
     assert isinstance(i_image, np.ndarray), 'Got type: {}'.format(type(i_image))
     assert len(i_image.shape) in (2, 3), 'Got shape: {}'.format(i_image.shape)
     if i_image.ndim == 2:
         # Promote gray (H, W) images to (H, W, 1)
         i_image = np.expand_dims(i_image, -1)
     assert i_image.ndim == 3
     assert i_image.shape[-1] in (1, 3)  # Only gray or color images are accepted
     assert i_image.dtype in (np.uint8,)  # Only accept normal RGB image (0~255)
     # Block extraction: tile the image (with a dummy mask) into overlapping blocks
     img_height, img_width, _ = i_image.shape
     dummy_mask = np.zeros(shape=(img_height, img_width, 1))
     blocks, _ = self.get_blks(i_image=i_image, i_mask=dummy_mask, i_blk_sizes=self.vseg_isize, i_blk_strides=self.vseg_strides)
     n_blk_rows, n_blk_cols = len(blocks), len(blocks[0])
     flat_blocks = np.array(self.forward_block_convert(blocks))
     assert flat_blocks.ndim == 4, '{}'.format(flat_blocks.shape)  # (None, blk_height, blk_width, nchannels)
     # Prediction. DO NOT normalize the data
     raw_preds = self.segnet.predict(i_image=flat_blocks)  # e.g. N-by-256-by-256-by-1
     # Scale each predicted block mask back to the block input size
     scaled_preds = [SupFns.scale_mask(i_mask=item, i_tsize=self.vseg_isize) for item in raw_preds]
     blk_grid = self.backward_block_convert(scaled_preds, n_blk_rows, n_blk_cols)
     pred_image = self.join_blks(i_blks=blk_grid, i_steps=self.vseg_strides, i_overlapped_adjust=True)
     Logs.log('seg pred_image shape = {}'.format(pred_image.shape))
     return pred_image
 def cls_predict(self,i_image=None):
     """Classify every block of a single gray/RGB uint8 image as with/without object.

     Note that: clsnet here is used to classify an image block into two classess
     of with or without objects.
     @param i_image: np.ndarray of shape (H, W) or (H, W, 1|3), dtype uint8.
     @return: (preds, pred_image) where preds has shape (None, 2) (per-block
              scores) and pred_image is a binary (0/1) mask of shape (height, width, 1).
     """
     assert isinstance(i_image, np.ndarray), 'Got type: {}'.format(type(i_image))
     """Support function for making prediction block image"""
     def make_pred_image(i_pred_label=0.):
         # Build a constant block-sized mask: all ones for a positive label, all zeros otherwise.
         assert isinstance(i_pred_label,(float,int)), 'Got type: {}'.format(type(i_pred_label))
         # FIX: np.int was removed in NumPy 1.24; builtin int yields the same default integer dtype.
         if i_pred_label>0:
             return np.ones(shape=(self.vcls_isize[0],self.vcls_isize[1],1),dtype=int)
         else:
             return np.zeros(shape=(self.vcls_isize[0],self.vcls_isize[1],1),dtype=int)
     assert len(i_image.shape) in (2,3), 'Got shape: {}'.format(i_image.shape)
     if len(i_image.shape) == 2: #Gray image with shape (height, width)
         i_image = np.expand_dims(i_image,-1)
     else:#RGB image with shape (height, width, depth)
         assert len(i_image.shape)==3
     assert i_image.shape[-1] in (1,3), 'Got shape: {}'.format(i_image.shape)   #Only gray or color images are accepted
     assert i_image.dtype in (np.uint8,), 'Got dtype: {}'.format(i_image.dtype) #Only accept normal RGB image (0~255)
     """Block extraction"""
     height, width, depth  = i_image.shape
     mask = np.zeros(shape =(height, width, 1))
     blocks, masks  = self.get_blks(i_image=i_image, i_mask=mask, i_blk_sizes=self.vcls_isize,i_blk_strides=self.vcls_strides)
     num_blk_height = len(blocks)
     num_blk_width  = len(blocks[0])
     blocks = self.forward_block_convert(blocks)
     blocks = np.array(blocks)         # Shape: (None, blk_height, blk_width, nchannels)
     assert len(blocks.shape)== 4 ,'Got shape: {}'.format(blocks.shape)      # Shape: (None, blk_height, blk_width, nchannels)
     assert blocks.shape[-1] == depth, 'Got shape: {}'.format(blocks.shape)  # Shape: (None, blk_height, blk_width, nchannels)
     """Preclassify based on block gray level"""
     block_means     = [np.mean(blk) for blk in blocks]
     pre_pred_labels = [x > self.vcls_sgray_level for x in block_means]
     # FIX: np.float was removed in NumPy 1.24; builtin float is the same (float64) dtype.
     pre_pred_labels = np.array(pre_pred_labels, dtype=float)
     """Prediction. DONOT Normalize data"""
     preds  = self.clsnet.predict(i_image=blocks) #Shape: (None, num_classes) where None is number of blocks extracted from image
     """As my design, preds has shape of (None, 2) as indicates the with/without existance of object in block"""
     pred_labels = (preds[:, 1] - preds[:, 0]) > self.vcls_th #self.vcls_th = 0 for conventional case (i. e. argmax)
     assert isinstance(pred_labels,np.ndarray)
     assert len(pred_labels.shape) == 1 #Shape (None, )
     pred_labels = pred_labels.astype(float)  # FIX: np.float removed in NumPy 1.24
     pred_labels = pred_labels * pre_pred_labels  # Veto positive predictions on too-dark blocks
     pred_blocks = [make_pred_image(i_pred_label=i) for i in pred_labels]
     pred_blocks = self.backward_block_convert(i_blocks=pred_blocks,i_height=num_blk_height,i_width=num_blk_width)
     pred_image  = self.join_blks(i_blks=pred_blocks, i_steps=self.vcls_strides) #Binary mask image with value of 0s and 1s
     Logs.log('cls pred_image shape = {}'.format(pred_image.shape))
     return preds, pred_image
 def on_epoch_end(self, epoch, logs=None):
     """Save the model when accuracy improves and log an epoch summary.

     @param epoch: Current epoch index
     @type epoch: Integer
     @param logs: include acc and loss, and optionally include val_loss (if validation is enabled), and val_acc (if validation and accuracy monitoring are enabled)
     @type logs: Dictionary
     @return: None
     @rtype: NoneType
     """
     logs = logs if logs is not None else {}  # FIX: Keras may pass None; avoid TypeError below
     """Save model to disk after every epoch"""
     current_accuracy = logs.get('accuracy')  # FIX: missing key (metric disabled/renamed) => skip checkpointing
     if current_accuracy is not None and self.accuracy < current_accuracy:
         Logs.log('Accuracy improved from {} to {}'.format(self.accuracy, current_accuracy))
         self.accuracy = current_accuracy
         self.model.save(filepath=self.model_path)
     summary_str = 'Epoch: {} '.format(epoch)  # FIX: redundant second format argument removed
     for key in logs.keys():
         summary_str = '{} {} = {:3.6f}'.format(summary_str, key, logs[key])
     summary_str += '\n'
     Logs.log(summary_str)
Beispiel #5
0
    # NOTE(review): fragment of a training/evaluation driver; its enclosing `def`
    # header is not visible here, and num_folds, fold_index and ckpts are defined
    # in that missing scope — confirm against the original script.
    StackedCnnUNets.vcls_isize = (
        64, 64
    )  # Size of block for clsnet                                     (*)
    StackedCnnUNets.vseg_isize = (
        128, 128
    )  # Size of block for segnet                                      (*)
    StackedCnnUNets.vcls_strides = (
        32, 32
    )  # Stride for taking blocks for clsnet                          (*)
    StackedCnnUNets.vseg_strides = (
        64, 64
    )  # Stride for taking blocks for segnet                          (*)
    StackedCnnUNets.vcls_sgray_level = 30  # Threshold for removing dark blocks                           (Fixed)
    StackedCnnUNets.vcls_object_size = 10  # Threshold for deciding blocks with/without gnd object        (Fixed)
    StackedCnnUNets.vseg_object_size = 10  # Threshold for deciding blocks with/without gnd object        (Fixed)
    StackedCnnUNets.vcls_th = 0  # Threshold for making decision (extension)                    (Fixed)
    StackedCnnUNets.vcls_lsm_factor = 0  # Label smoothing factor. A number from 0 to 1                 (Fixed)
    StackedCnnUNets.vseg_lsm_factor = 0  # Label smoothing factor. A number from 0 to 1                 (Fixed)
    StackedCnnUNets.vdebug = False  # Debug flag                                                   (Fixed)
    trainer = StackedCnnUNets()  #(Fixed)
    """Get sample dataset"""
    ich_dber = ICH_DB(i_tsize=(512, 512), i_num_folds=num_folds)
    train_db, val_db = ich_dber.call(i_fold_index=fold_index)
    trainer.train(i_train_db=train_db, i_val_db=val_db)  #(Fixed)
    """2D Evaluation"""
    #trainer.eval(i_db=val_db)                                               #(Fixed)
    """3D Evaluation"""
    # 3D evaluation runs on per-patient volumes rather than individual slices.
    val_db = ich_dber.get_val_patient(i_fold_index=fold_index)
    trainer.eval3d(i_db=val_db)
    Logs.move_log(i_dst_path=ckpts)
"""=================================================================================================================="""
 def on_train_end(self, logs=None):
     """Keras callback hook: announce the end of training on stdout and in the log file."""
     print('Train End: ',logs)
     Logs.log('Finished training our model!')
 def on_train_begin(self, logs=None):
     """Keras callback hook: announce the start of training on stdout and in the log file."""
     print('Train Begin: ',logs)
     Logs.log('Start training our model...')
Beispiel #8
0
 def measures(self, i_labels=None, i_preds=None, i_object_index=1):
     """Compute per-item and global segmentation measures for one object class.

     @param i_labels: list/tuple of ground-truth label arrays.
     @param i_preds: list/tuple of prediction arrays (same length as i_labels).
     @param i_object_index: positive class index to score (background is ignored).
     @return: (per-item measures, their mean, their std, global measures).
     """
     assert isinstance(i_labels, (list, tuple))
     assert isinstance(i_preds, (list, tuple))
     assert isinstance(i_object_index, int)
     assert i_object_index > 0  # Dont care background
     assert len(i_labels) == len(i_preds)
     # Per-item measures (accumulates the shared TP/FP/FN/TN indicators as a side effect)
     per_item = np.array([
         self.get_measures(i_labels=lbl, i_preds=prd, i_object_index=i_object_index)
         for lbl, prd in zip(i_labels, i_preds)])
     """Overall Measure (global)"""
     global_measures = np.array([
         SegMetrics_2D.get_dice(i_TP=self.TPs, i_FP=self.FPs, i_FN=self.FNs),          # Dice
         SegMetrics_2D.get_Jaccard(i_TP=self.TPs, i_FP=self.FPs, i_FN=self.FNs),       # Jaccard
         SegMetrics_2D.get_precision(i_TP=self.TPs, i_FP=self.FPs),                    # Precision
         SegMetrics_2D.get_recall(i_TP=self.TPs, i_FN=self.FNs),                       # Recall
         SegMetrics_2D.get_recall(i_TP=self.TPs, i_FN=self.FNs),                       # Sensitivity = Recall
         SegMetrics_2D.get_specificity(i_TN=self.TNs, i_FP=self.FPs),                  # Specificity
         SegMetrics_2D.get_accuracy(i_TP=self.TPs, i_TN=self.TNs, i_FP=self.FPs, i_FN=self.FNs)])  # Overall Accuracy
     """2D Performance measurement"""
     evaluer = SegMetrics_2D(i_num_classes=i_object_index + 1, i_care_background=False)
     # Report twice: over every image, then restricted to images that contain objects.
     reports = ((False, 'Using entire dataset'),
                (True, 'Using sub dataset that only consider images containing objects'))
     for object_care, banner in reports:
         Logs.log(banner)
         measures2d, measure_mean2d, measure_std2d = evaluer.eval(
             i_labels=self.labels, i_preds=self.preds, i_object_care=object_care)
         Logs.log('Measure shape (2D) = {}'.format(measures2d.shape))
         Logs.log('Measure mean  (2D) = {}'.format(measure_mean2d))
         Logs.log('Measure std   (2D) = {}'.format(measure_std2d))
     self.erase_indicators()
     return per_item, SegMetrics_2D.get_mean(per_item), SegMetrics_2D.get_std(per_item), global_measures
 def init_params(self,i_params):
     """Override the instance's global 'v*' variables from a SysParams object or a dict.

     As my design, all global variables of a class start with the 'v' letter.
     @param i_params: SysParams instance or plain dict of name -> value overrides.
     @return: True on completion.
     """
     assert isinstance(i_params, (SysParams,dict))
     Logs.log('-'*100)
     Logs.log('Init stacked cls-seg class global variables...')
     known_variables = [name for name in self.__dict__ if name.startswith('v')]  # As my design
     params = i_params.__dict__ if isinstance(i_params, SysParams) else i_params.copy()
     for key, val in params.items():
         if key in known_variables:
             Logs.log('Variable {} was changed from {} to : {}'.format(key,self.__dict__[key],val))
             self.__dict__.update({key: val})
             continue
         assert isinstance(key, str)
         if key.startswith('v'):  # Accept brand-new 'v*' variables as well
             self.__dict__.update({key: val})
             Logs.log('Init {} as : {}'.format(key, val))
     """Custom parameter adjustment"""
     Logs.log('Set-up parameters for segmentation nets:')
     Logs.log_cls_params(self)
     self.__init__()  # Re-run __init__ so dependent members pick up the new settings
     return True
    def prepare_data(self, i_db=None, i_cls_flag=True, i_train_flag=True,i_save_path=None):
        """Extract image blocks from (image, mask) pairs and write them to a TFRecord dataset.

        @param i_db: list/tuple of (image, mask) pairs of uint8 ndarrays.
        @param i_cls_flag: True => build data for the classification net (block -> 0/1 label);
                           False => build data for the segmentation net (block -> block mask).
        @param i_train_flag: True => filter/balance blocks for training; False => keep all blocks.
        @param i_save_path: path of the tfrecord file; a 'statistics.txt' is appended next to it.
        @return: dataset read back from the written tfrecord file (dictionary form).
        """
        assert isinstance(i_db,(list, tuple))  # List of (image, mask) pair => Used for segnet cases
        assert isinstance(i_cls_flag,bool)
        assert isinstance(i_train_flag,bool)
        # Pick block geometry and thresholds for the target network
        if i_cls_flag:
            blk_size    = self.vcls_isize
            blk_strides = self.vcls_strides
            object_size = self.vcls_object_size
            threshold   = self.vcls_sgray_level
        else:
            blk_size    = self.vseg_isize
            blk_strides = self.vseg_strides
            object_size = self.vseg_object_size
            threshold   = 0
        # For clsnet: positive_blks/negative_blks are images with/without objects.
        # For segnet: positive_blks holds block images, negative_blks holds their masks.
        positive_blks,negative_blks = [],[]
        tfrecord_size = 100000
        TFRecordDB.lossy = False
        tfwriter = TFRecordDB()
        for index, element in enumerate(i_db):
            image, mask = element                 # As my design
            assert isinstance(image, np.ndarray)  # Image
            assert isinstance(mask, np.ndarray)   # Mask for segmentation
            if len(image.shape)==2: # Gray image with shape (height, width)
                image = np.expand_dims(image,-1)
            else:
                assert len(image.shape)==3
            assert len(mask.shape) in (2, 3)
            if len(mask.shape)==2:
                mask = np.expand_dims(mask,-1)
            else:
                assert len(mask.shape)==3
                assert mask.shape[-1]==1 # Gray image as the meaning of mask
            """Start extracting blocks"""
            blocks, blk_masks = self.get_blks(i_image=image,i_mask=mask,i_blk_sizes=blk_size,i_blk_strides=blk_strides)
            """Flatten blocks. As my design of get and joint blocks funs"""
            blocks    = self.forward_block_convert(blocks)
            blk_masks = self.forward_block_convert(blk_masks)
            #mask_size = int(np.sum((mask>0).astype(np.int)))*0.75
            for blk_ind,blk in enumerate(blocks):
                blk_mask = blk_masks[blk_ind]
                assert isinstance(blk_mask,np.ndarray)
                if i_cls_flag:#Taking blocks for clsnets
                    if i_train_flag:
                        # Keep only bright-enough blocks; drop ambiguous ones whose
                        # object area is positive but below object_size.
                        if np.average(blk) >= threshold:
                            if np.sum(blk_mask)>=object_size:#Count number of object pixels
                                positive_blks.append(blk)
                            elif 0<np.sum(blk_mask)<object_size:
                                pass
                            else:
                                negative_blks.append(blk)
                        else:
                            pass
                    else:#Testing => Just keep original
                        if np.sum(blk_mask)>0:
                            positive_blks.append(blk)
                        else:
                            negative_blks.append(blk)
                else:#Taking blocks for segnets
                    if i_train_flag:
                        if np.sum(blk_mask)>=object_size: #Only taking blocks with objects
                            positive_blks.append(blk)        #Image
                            negative_blks.append(blk_mask)   #Mask
                        else:
                            pass
                    else:#Testing => Just keep original
                        if np.sum(blk_mask)>0:
                            positive_blks.append(blk)       # Image
                            negative_blks.append(blk_mask)  # Mask
                        else:
                            pass
            """Complement blocks"""
            # Training only: add extra object-centred blocks so objects are well represented.
            if i_train_flag:
                min_object_size      = object_size
                obj_blocks,obj_masks = self.get_obj_blks(i_image=image,i_mask=mask,i_blk_sizes=blk_size,i_object_size=min_object_size)
                for obj_index, obj_blk in enumerate(obj_blocks):
                    positive_blks.append(obj_blk)
                    if i_cls_flag:
                        pass
                    else:
                        negative_blks.append(obj_masks[obj_index])
                print('Images: {} => Additional blocks = {} vs {} => Sizes = P: {} and N: {} --- {}'.format(index, len(obj_blocks), len(obj_masks), len(positive_blks), len(negative_blks), object_size))
            else:
                pass
        """Save data to TFRecordDB"""
        blocks,labels = [],[]
        if i_cls_flag:#Classification network
            num_positive_blocks = len(positive_blks)
            num_negative_blocks = len(negative_blks)
            num_samples         = max(num_negative_blocks,num_positive_blocks)
            Logs.log('Num Pos = {}, Num Neg = {}, Num Samples = {}'.format(num_positive_blocks, num_negative_blocks,num_samples))
            if i_train_flag:
                # Balance classes: repeat the minority class' indices (plus a remainder
                # slice) so positive and negative samples can be paired 1:1 below.
                min_samples = min(num_negative_blocks,num_positive_blocks)
                ori_neg_indices = [i for i in range(num_negative_blocks)]
                ori_pos_indices = [i for i in range(num_positive_blocks)]
                neg_indices,pos_indices = [],[]
                if min_samples==num_negative_blocks:
                    ratio   = int(num_positive_blocks/num_negative_blocks)
                    remains = num_positive_blocks - ratio*num_negative_blocks
                    remains_indices = [i for i in range(remains)]
                    for ratio_index in range(ratio):
                        neg_indices += ori_neg_indices
                    neg_indices += remains_indices
                    pos_indices = ori_pos_indices
                    print(ratio,len(neg_indices),len(pos_indices))
                else:
                    ratio = int(num_negative_blocks / num_positive_blocks)
                    remains = num_negative_blocks - ratio * num_positive_blocks
                    remains_indices = [i for i in range(remains)]
                    for ratio_index in range(ratio):
                        pos_indices += ori_pos_indices
                    pos_indices += remains_indices
                    neg_indices = ori_neg_indices
                # Interleave one positive and one negative block per pair
                combine_indices = list(zip(neg_indices,pos_indices))
                for combine_item in combine_indices:
                    neg_index = combine_item[0]
                    pos_index = combine_item[1]
                    blocks.append(positive_blks[pos_index])
                    labels.append(1)
                    blocks.append(negative_blks[neg_index])
                    labels.append(0)
            else:
                blocks = positive_blks + negative_blks
                labels = [1 for _ in range(num_positive_blocks)] + [0 for _ in range(num_negative_blocks)]
            num_positives = np.sum(labels)
            num_negatives = len(labels) - num_positives
            num_blocks    = num_negatives + num_positives
            # Append class-balance statistics next to the tfrecord file
            log_path = os.path.split(i_save_path)[0]
            log_path = os.path.join(log_path, 'statistics.txt')
            with open(log_path, 'a+') as file:
                file.writelines('Statistics for classification db\n')
                file.writelines('Num Positive Blocks: {}/{} ~= {}(%)\n'.format(num_positives, num_blocks,num_positives * 100 / num_blocks))
                file.writelines('Num Negative Blocks: {}/{} ~= {}(%)\n'.format(num_negatives, num_blocks,num_negatives * 100 / num_blocks))
                file.writelines('-' * 100)
                file.writelines('\n')
        else:#Segmentation network
            blocks = positive_blks #List of unit8 images
            labels = negative_blks #List of uint8 images
            """Writing statistics"""
            num_blocks       = len(blocks)
            total_sizes      = num_blocks * blk_size[0] * blk_size[1]

            #labels_bin       = np.array(labels)
            #labels_bin       = (labels_bin>0).astype(np.int)
            #object_sizes     = np.sum(labels_bin)

            # NOTE(review): object_sizes sums raw mask values — matches pixel counts
            # only for binary 0/1 masks; verify against the mask encoding.
            object_sizes     = np.sum([np.sum(label) for label in labels])
            non_object_sizes = total_sizes - object_sizes
            object_sizes     = np.array(object_sizes,dtype=np.int64)
            non_object_sizes = np.array(non_object_sizes,dtype=np.int64)
            total_sizes      = np.array(total_sizes,dtype=np.int64)
            log_path = os.path.split(i_save_path)[0]
            log_path = os.path.join(log_path, 'statistics.txt')
            with open(log_path, 'a+') as file:
                file.writelines('-' * 100 + '\n')
                file.writelines('Statistics for segmentation db\n')
                file.writelines('Num blocks = {}\n'.format(num_blocks))
                file.writelines('Ratio Object    = {}/{} ~ {}(%)\n'.format(object_sizes, total_sizes, (100 * object_sizes) / total_sizes))
                file.writelines('Ratio NonObject = {}/{} ~ {}(%)\n'.format(non_object_sizes, total_sizes,(100 * non_object_sizes) / total_sizes))
                file.writelines('\n')
        """"Write to tfrecords"""
        db_fields = {'image': [], 'label': []}
        write_data = list(zip(blocks, labels))
        tfwriter.write(i_n_records=write_data, i_size=tfrecord_size, i_fields=db_fields, i_save_file=i_save_path)
        dataset = TFRecordDB.read(i_save_path,i_original=True)#Set i_original to True to return dictionary
        return dataset
 def eval3d(self,i_db=None):
     """Slice-by-slice prediction for each patient volume, followed by 3D metrics.

     As my design, i_db is in format ((images,masks),...,(images,masks)) where the
     last axis of images/masks indexes the 2D slices of one patient.
     @return: (labels, preds) — one stacked (num_slices, ...) array per patient.
     """
     assert isinstance(i_db,(list,tuple))
     Logs.log('Number of patients = {}'.format(len(i_db)))
     preds,labels = [],[]
     image_index = 0
     for patient in i_db:
         images, masks = patient
         slice_preds, slice_labels = [], []
         for slice_idx in range(images.shape[-1]):
             slice_image = images[:, :, slice_idx]
             slice_mask  = masks[:, :, slice_idx]
             tic = time.time_ns()
             pred_image = self.predict(i_image=slice_image)
             toc = time.time_ns()
             Logs.log('Processing time = {}'.format((toc - tic) / 1e6))  # Unit is ms
             slice_preds.append(pred_image)
             slice_labels.append(slice_mask)
             Logs.log('Image index = {} with pred_shape = {} and label_shape = {}'.format(image_index, pred_image.shape, slice_mask.shape))
             image_index += 1
         preds.append(np.array(slice_preds))
         labels.append(np.array(slice_labels))
     """Performance measurement"""
     evaluer = SegMetrics_3D()
     measures, measure_mean, measure_std, global_measures = evaluer.measures(i_labels=labels,i_preds=preds,i_object_index=1)
     Logs.log('Measure shape (3D) = {}'.format(measures.shape))
     Logs.log_matrix(i_str='Details',i_matrix=measures)
     Logs.log('Measure mean  (3D) = {}'.format(SupFns.round_matrix(measure_mean)))
     Logs.log('Measure std   (3D) = {}'.format(SupFns.round_matrix(measure_std)))
     Logs.log('Measuge Global(3D) = {}'.format(SupFns.round_matrix(global_measures)))
     return labels, preds
 def eval(self,i_db=None):
     """Evaluate 2D segmentation performance over a list of (image, mask) pairs.

     @param i_db: list/tuple of (image, mask) pairs.
     @return: (labels, preds) — lists in the same order as i_db.
     """
     assert isinstance(i_db,(list,tuple)),'Got type: {}'.format(type(i_db))
     labels, preds = [],[]
     for index, element in enumerate(i_db):
         print('Evaluating index = {}'.format(index))
         image, mask = element  # As my design
         preds.append(self.predict(i_image=image))
         labels.append(mask)
     """Performance measurement"""
     evaluer = SegMetrics_2D(i_num_classes=self.segnet.vnum_classes, i_care_background=self.segnet.vcare_background)
     # Report twice: over every image, then restricted to images that contain objects.
     reports = ((False, 'Using entire dataset'),
                (True, 'Using sub dataset that only consider images containing objects'))
     for object_care, banner in reports:
         Logs.log(banner)
         measures, measure_mean, measure_std = evaluer.eval(i_labels=labels, i_preds=preds, i_object_care=object_care)
         Logs.log('Measure shape = {}'.format(measures.shape))
         Logs.log('Measure mean  = {}'.format(measure_mean))
         Logs.log('Measure std   = {}'.format(measure_std))
     return labels,preds