Exemplo n.º 1
 def cal_dice(thres_seg, size_seg, thres_after = 0.2):
     """Mean Dice over the cached predictions for one post-processing setting.

     Reads ``preds``, ``trues`` and ``self`` from the enclosing scope. Each
     ground-truth RLE is decoded with ``rle2mask`` at
     ``(self.args.width, self.args.height)``, each prediction is post-processed
     with ``post_process_segment(pred, thres_seg, size_seg, thres_after)``, and
     the ``dice_metric`` scores are summed.

     Returns:
         The summed Dice divided by ``len(preds)``.
     """
     total = 0.0
     for pred, true_rle in zip(preds, trues):
         # post process
         true = rle2mask(true_rle, self.args.width, self.args.height)
         pred = post_process_segment(pred, thres_seg, size_seg, thres_after)
         total += dice_metric(true, pred)
     return total/len(preds)
Exemplo n.º 2
        def cal_dice(thres_seg, size_seg, thres_after, thres_oth=-float('inf'), size_oth=0):
            """Score one post-processing configuration on the cached predictions.

            Reads ``preds``, ``others``, ``trues`` and ``self`` from the enclosing
            scope. Each ground-truth RLE is decoded with ``rle2mask`` and each
            prediction is post-processed with ``post_process_single`` using the
            given thresholds/sizes before ``dice_metric`` is applied.

            Returns:
                When ``self.args.use_weight`` is set, the sum of
                ``dice * self.weight[i]`` divided by ``self.nweight``; otherwise the
                plain sum of Dice scores divided by ``len(self.dataloader.dataset)``.
            """
            use_weight = self.args.use_weight
            # The normaliser must match the accumulation mode: previously the
            # weighted normaliser was unconditionally overwritten (and undefined
            # when weighting was off) because the ``else`` branch was missing.
            if use_weight:
                nnormal = self.nweight
            else:
                nnormal = len(self.dataloader.dataset)

            dice = 0.0
            for ipos, (pred, other, true_rle) in enumerate(zip(preds, others, trues)):
                # post process
                true = rle2mask(true_rle, self.args.width, self.args.height)
                pred = post_process_single(pred, other, thres_seg, size_seg, thres_after, thres_oth, size_oth)
                score = dice_metric(true, pred)
                # Add each sample exactly once: weighted OR unweighted, never both.
                if use_weight:
                    dice += score*self.weight[ipos]
                else:
                    dice += score

            return dice/nnormal
Exemplo n.º 3
def evaluate_batch(data, outputs, args, threshold=0.5):
    """Score one batch of network outputs against its ground truth.

    Args:
        data: batch sequence; ``data[1]`` holds the target masks and, when
            ``args.output`` is 1 or 2, ``data[2]`` holds the target labels.
        outputs: mask logits when ``args.output == 0``; otherwise a pair
            ``(mask_logits, label_output)``.
        args: namespace whose ``output`` field (0, 1 or 2) selects the layout.
        threshold: cut-off applied after the sigmoid to the mask logits and,
            for ``args.output == 2``, to the label logits.

    Returns:
        ``(dice, label_score)`` where ``dice`` is ``dice_metric`` of the target
        masks and the thresholded predicted masks, and ``label_score`` is
        0.0 for ``args.output == 0``, the summed absolute error of the raw label
        outputs for ``args.output == 1``, and the number of label entries that
        match after thresholding for ``args.output == 2``.

    Raises:
        ValueError: if ``args.output`` is not 0, 1 or 2 (previously the
            function fell through and returned ``None``).
    """
    if args.output not in (0, 1, 2):
        raise ValueError('Unsupported args.output: {}'.format(args.output))

    masks = data[1].detach().cpu().numpy()
    # Mode 0 returns bare mask logits; modes 1 and 2 return (masks, labels).
    mask_logits = outputs if args.output == 0 else outputs[0]
    pred_masks = (torch.sigmoid(mask_logits).detach().cpu().numpy() > threshold)
    dice = dice_metric(masks, pred_masks)

    if args.output == 0:
        return dice, 0.0

    labels = data[2].detach().cpu().numpy()
    if args.output == 1:
        # Regression head: sqrt(d**2) == |d|, and abs cannot overflow/underflow.
        pred_labels = outputs[1].detach().cpu().numpy()
        return dice, np.sum(np.abs(pred_labels - labels))

    # args.output == 2: classification head, count correct entries.
    pred_labels = (torch.sigmoid(outputs[1]).detach().cpu().numpy() > threshold)
    return dice, np.sum((pred_labels == labels).astype(int))
Exemplo n.º 4
    def predict_dataloader(self, to_rle = False, fnames = None):
        """Evaluate the network on ``self.dataloader`` with the stored post-processing parameters.

        Each sample's output mask is post-processed with ``post_process_segment``
        using ``self.dicPara['thres_seg']``, ``['size_seg']`` and ``['thres_after']``.
        Unless ``self.isTest`` is set, ``dice_metric`` is computed against the label.
        The parameters, the total Dice and the positive/negative counts and
        averages are printed.

        Args:
            to_rle: when True, ``fnames`` must be given (``ValueError`` otherwise).
                NOTE(review): nothing in the visible body encodes RLEs or reads
                ``fnames`` after that check.
            fnames: file names in dataloader order.

        Returns:
            ``dice_total.avg()``.

        NOTE(review): this body is incomplete as shown. The
        ``if self.dicPara is None:`` and ``if label_raw.sum() > 0:`` branches have
        no statements, so the definition does not parse. ``dice_total``,
        ``dice_pos`` and ``dice_neg`` are never updated in the visible code;
        restore the missing lines from upstream.
        """
        if self.dicPara is None:
            # NOTE(review): branch body missing from the visible source.

        if to_rle and fnames is None:
            raise ValueError('File names are not given.')
        # evaluate the net
        dice_total, dice_pos, dice_neg = Meter(), Meter(), Meter()
        # NOTE(review): dicSubmit, preds, ipos and area_ratio are not used anywhere
        # in the visible body.
        dicSubmit = {'ImageId_ClassId':[], 'EncodedPixels':[]}
        preds = []
        ipos = 0
        def area_ratio(mask):
            return mask.sum()/self.args.height/self.args.width
        with torch.no_grad():
            for data in tqdm(self.dataloader):
                # load the data
                images, labels = data[0].to(self.device), data[1].to(self.device)
                # Axis order (0, 3, 1, 2) -- presumably NHWC -> NCHW; confirm.
                images = images.permute(0, 3, 1, 2)

                # presumably flip test-time augmentation, per the name -- confirm
                output_masks, output_labels = self.predict_flip_batch(images)
                for output_mask, output_label, label_raw in zip(output_masks, output_labels, labels):
                    # using simple threshold and output the result
                    output_thres = post_process_segment(output_mask, thres_seg = self.dicPara['thres_seg'], \
                                                             size_seg = self.dicPara['size_seg'], \
                                                             thres_after = self.dicPara['thres_after'])

                    # calculate the dice if it is not a test dataloader
                    if not self.isTest:
                        dice = dice_metric(label_raw.detach().cpu().numpy(), output_thres)
                        if label_raw.sum() > 0:
                            # NOTE(review): branch body missing from the visible source.
                            # Given the printed meters, it presumably added `dice` to
                            # dice_pos here, dice_neg otherwise, and dice_total -- confirm.
        # print information
        print('Parameters: ',','.join(['"{:s}":{:.4f}'.format(key, val) for key,val in self.dicPara.items()]))
        print('Dice total {:.3f}'.format(dice_total.avg()))
        print('Positive Data {:d}, {:.3f}'.format(dice_pos.num, dice_pos.avg()))
        print('Negative Data {:d}, {:.3f}'.format(dice_neg.num, dice_neg.avg())) 
        return dice_total.avg()
Exemplo n.º 5
 def dice(self, logit, truth, threshold=0.5):
     """Apply a sigmoid to ``logit`` and score it against ``truth`` with ``dice_metric``.

     ``threshold`` is forwarded unchanged and ``is_average=True`` is passed;
     the metric's value is returned as-is.
     """
     return dice_metric(F.sigmoid(logit), truth,
                        threshold=threshold, is_average=True)
Exemplo n.º 6
    def predict_dataloader(self, to_rle = False, fnames = None):
        """Predict every sample in ``self.dataloader`` and collect per-class results.

        Each output is post-processed with ``post_process(mask, label, self.dicPara)``.
        For every category the predicted area ratio is recorded. Unless
        ``self.isTest`` is set, the per-category Dice and true area ratio are
        recorded too, and both an unweighted and a ``self.weight``-weighted Dice
        are accumulated and printed.

        Args:
            to_rle: when True, each per-category predicted mask is encoded with
                ``mask2rle`` and appended to the returned submission dict.
            fnames: file paths in dataloader order; required when ``to_rle``.

        Returns:
            ``(dice, dicPred, dicSubmit)``:
            ``dice`` is the summed per-category Dice divided by
            ``len(self.dataloader.dataset) * self.args.category`` (0.0 in test mode).
            ``dicPred`` maps ``'Class k'``/``'Dice k'``/``'True k'`` to per-sample
            lists, with empty lists removed.
            ``dicSubmit`` has ``'ImageId_ClassId'`` and ``'EncodedPixels'`` lists
            (filled only when ``to_rle``).

        Raises:
            ValueError: if ``self.dicPara`` is unset, or ``to_rle`` is True
                without ``fnames``.
        """
        if self.dicPara is None:
            # NOTE(review): the original body of this branch is missing from the
            # visible source (an empty ``if`` body does not parse). Raising mirrors
            # the fnames check below; confirm against upstream, which may instead
            # have computed the parameters here.
            raise ValueError('Post-processing parameters are not given.')
        if to_rle and fnames is None:
            raise ValueError('File names are not given.')

        n_category = self.args.category
        # evaluate the net
        dicPred = dict()
        for classid in range(1, n_category + 1):
            dicPred['Class '+str(classid)] = []
            dicPred['Dice '+str(classid)] = []
            dicPred['True '+str(classid)] = []
        dicSubmit = {'ImageId_ClassId':[], 'EncodedPixels':[]}
        dice, diceW = 0.0, 0.0
        ipos = 0  # running sample index into fnames and self.weight

        def area_ratio(mask):
            # Fraction of the height x width image covered by the mask.
            return mask.sum()/self.args.height/self.args.width

        with torch.no_grad():
            for data in tqdm(self.dataloader):
                # load the data
                images, labels = data[0].to(self.device), data[1].to(self.device)
                # Axis order (0, 3, 1, 2) -- presumably NHWC -> NCHW; confirm.
                images = images.permute(0, 3, 1, 2)

                output_masks, output_labels = self.predict_flip_batch(images)
                for output_mask, output_label, label_raw in zip(output_masks, output_labels, labels):
                    output_thres = post_process(output_mask, output_label, self.dicPara)
                    if not self.isTest:
                        # One device->host copy per sample, not two per category.
                        label_np = label_raw.detach().cpu().numpy()
                    if to_rle:
                        fname_short = fnames[ipos].split('/')[-1]
                    for category in range(n_category):
                        pred_cat = output_thres[:,:,category]
                        if to_rle:
                            # The RLE used to be computed and then discarded,
                            # leaving the submission empty; record the row.
                            dicSubmit['ImageId_ClassId'].append(fname_short+'_{:d}'.format(category+1))
                            dicSubmit['EncodedPixels'].append(mask2rle(pred_cat))
                        dicPred['Class {:d}'.format(category+1)].append(area_ratio(pred_cat))
                        if not self.isTest:
                            true_cat = label_np[:,:,category]
                            dice_cat = dice_metric(true_cat, pred_cat)
                            dicPred['Dice {:d}'.format(category+1)].append(dice_cat)
                            dicPred['True {:d}'.format(category+1)].append(area_ratio(true_cat))
                            # add to the final dice
                            dice  += dice_cat
                            diceW += dice_cat*self.weight[ipos]
                    ipos += 1

        # Drop columns that received no values (e.g. Dice/True in test mode).
        dicPred = {key: val for key, val in dicPred.items() if len(val) > 0}

        # regularize result
        diceW = diceW/self.nweight/n_category
        dice  = dice/len(self.dataloader.dataset)/n_category
        print("Weighted Dice {:.4f}\t Unweighted Dice {:.4f}".format(diceW, dice))

        return dice, dicPred, dicSubmit