Example 1
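The evaluation functions below are shown without their enclosing module. A hedged sketch of the imports they appear to rely on, judging only from the names used in the code; project-specific helpers (AverageMeter, logger_vis, vis_multi_class, target_seg2target_cls, aic_fundus_lesion_segmentation, aic_fundus_lesion_classification, compute_single_segment_score, compute_segment_score) come from the original repository and are not reproduced here:

import os
import os.path as osp
from os.path import exists
import time

import numpy as np
import torch
from torch.autograd import Variable
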
def val(args, eval_data_loader, model):
    model.eval()
    batch_time = AverageMeter()
    dice = AverageMeter()
    end = time.time()
    dice_list = []
    Dice_1 = AverageMeter()
    Dice_2 = AverageMeter()
    Dice_3 = AverageMeter()
    Dice_4 = AverageMeter()
    Dice_5 = AverageMeter()
    Dice_6 = AverageMeter()
    Dice_7 = AverageMeter()
    Dice_8 = AverageMeter()
    Dice_9 = AverageMeter()
    ret_segmentation = []
    
    for iter, (image, label, _) in enumerate(eval_data_loader):
        # batch size is 1, so squeeze dim 0
        image = image.squeeze(dim=0)
        label = label.squeeze(dim=0)

        target_seg = label.numpy()
        #target_cls = target_seg2target_cls(target_seg)
        
        with torch.no_grad():
            # run inference in small batches to limit GPU memory
            batch = 1
            pred_seg = torch.zeros(image.shape[0],image.shape[2],image.shape[3])
            pred_cls = torch.zeros(image.shape[0],9)
            for i in range(0,image.shape[0],batch):
                start_id = i
                end_id = i + batch
                if end_id > image.shape[0]:
                    end_id = image.shape[0]
                image_batch = image[start_id:end_id,:,:,:]
                image_var = Variable(image_batch).cuda()
                # model forward
                output_seg = model(image_var)
                _, pred_batch = torch.max(output_seg, 1)
                pred_seg[start_id:end_id,:,:] = pred_batch.cpu().data
                #pred_cls[start_id:end_id,:] = output_cls.cpu().data
            # metric: Dice for segmentation
            pred_seg = pred_seg.numpy().astype('uint8')
            batch_time.update(time.time() - end)
            label_seg = label.numpy().astype('uint8')
            ret = aic_fundus_lesion_segmentation(label_seg,pred_seg)
            ret_segmentation.append(ret)
            dice_score = compute_single_segment_score(ret)
            dice_list.append(dice_score)
            dice.update(dice_score)
            Dice_1.update(ret[1])
            Dice_2.update(ret[2])
            Dice_3.update(ret[3])
            Dice_4.update(ret[4])
            Dice_5.update(ret[5])
            Dice_6.update(ret[6])
            Dice_7.update(ret[7])
            Dice_8.update(ret[8])
            Dice_9.update(ret[9])
            # metric: AUC for classification
            #ground_truth = target_cls.numpy().astype('float32')
            #prediction = pred_cls.numpy().astype('float32')
            #if iter == 0:
             #   detection_ref_all = ground_truth
             #   detection_pre_all = prediction
            #else:
             #   detection_ref_all = np.concatenate((detection_ref_all, ground_truth), axis=0)
             #   detection_pre_all = np.concatenate((detection_pre_all, prediction), axis=0)
            
        end = time.time()
        logger_vis.info('Eval: [{0}/{1}]\t'
                    'Dice {dice.val:.3f} ({dice.avg:.3f})\t'
                    'Dice_1 {dice_1.val:.3f} ({dice_1.avg:.3f})\t'
                    'Dice_2 {dice_2.val:.3f} ({dice_2.avg:.3f})\t'
                    'Dice_3 {dice_3.val:.3f} ({dice_3.avg:.3f})\t'
                    'Dice_4 {dice_4.val:.6f} ({dice_4.avg:.4f})\t'
                    'Dice_5 {dice_5.val:.6f} ({dice_5.avg:.4f})\t'
                    'Dice_6 {dice_6.val:.6f} ({dice_6.avg:.4f})\t'
                    'Dice_7 {dice_7.val:.6f} ({dice_7.avg:.4f})\t'
                    'Dice_8 {dice_8.val:.6f} ({dice_8.avg:.4f})\t'
                    'Dice_9 {dice_9.val:.6f} ({dice_9.avg:.4f})\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    .format(iter, len(eval_data_loader), dice=dice, dice_1=Dice_1, dice_2=Dice_2,
                            dice_3=Dice_3, dice_4=Dice_4, dice_5=Dice_5, dice_6=Dice_6,
                            dice_7=Dice_7, dice_8=Dice_8, dice_9=Dice_9, batch_time=batch_time))

    # compute average dice for seg
    final_seg, seg_1, seg_2, seg_3, seg_4, seg_5, seg_6, seg_7, seg_8, seg_9 = compute_segment_score(ret_segmentation)
    print('### Seg ###')
    print('Final Seg Score:{}'.format(final_seg))
    print('Final Seg_1 Score:{}'.format(seg_1))
    print('Final Seg_2 Score:{}'.format(seg_2))
    print('Final Seg_3 Score:{}'.format(seg_3))
    print('Final Seg_4 Score:{}'.format(seg_4))
    print('Final Seg_5 Score:{}'.format(seg_5))
    print('Final Seg_6 Score:{}'.format(seg_6))
    print('Final Seg_7 Score:{}'.format(seg_7))
    print('Final Seg_8 Score:{}'.format(seg_8))
    print('Final Seg_9 Score:{}'.format(seg_9))
    # compute average auc for cls

    #ret_detection = aic_fundus_lesion_classification(detection_ref_all, detection_pre_all, num_samples=len(eval_data_loader)*56) ###############
    #print(ret_detection)
    #auc = np.array(ret_detection).mean()
    #print('AUC :',auc)
    #auc_1 = ret_detection[0]
    #auc_2 = ret_detection[1]
    #auc_3 = ret_detection[2]
    #auc_4 = ret_detection[3]
    #auc_5 = ret_detection[4]
    #auc_6 = ret_detection[5]
    #auc_7 = ret_detection[6]
    #auc_8 = ret_detection[7]
    #auc_9 = ret_detection[8]
    return final_seg, seg_1, seg_2, seg_3, seg_4, seg_5, seg_6, seg_7, seg_8, seg_9, dice_list
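Each example tracks running statistics with an AverageMeter that is not defined in these snippets; the logging calls above only touch .val, .avg and .update(). A minimal sketch assuming exactly that interface:

class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
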
Example 2
def val(args, eval_data_loader, model):
    model.eval()
    batch_time = AverageMeter()
    dice = AverageMeter()
    end = time.time()
    dice_list = []
    Dice_1 = AverageMeter()
    Dice_2 = AverageMeter()
    Dice_3 = AverageMeter()
    ret_segmentation = []
    
    for iter, (image, label, _) in enumerate(eval_data_loader):
        # batch size is 1, so squeeze dim 0
        image = image.squeeze(dim=0)
        label = label.squeeze(dim=0)

        target_seg = label.numpy()
        target_cls = target_seg2target_cls(target_seg)
        
        with torch.no_grad():
            # run inference in small batches to limit GPU memory
            batch = 16
            pred_seg = torch.zeros(image.shape[0],image.shape[2],image.shape[3])
            pred_cls = torch.zeros(image.shape[0],3)
            for i in range(0,image.shape[0],batch):
                start_id = i
                end_id = i + batch
                if end_id > image.shape[0]:
                    end_id = image.shape[0]
                image_batch = image[start_id:end_id,:,:,:]
                image_var = Variable(image_batch).cuda()
                # model forward
                output_seg,output_cls = model(image_var)
                _, pred_batch = torch.max(output_seg, 1)
                pred_seg[start_id:end_id,:,:] = pred_batch.cpu().data
                pred_cls[start_id:end_id,:] = output_cls.cpu().data
            # metric: Dice for segmentation
            pred_seg = pred_seg.numpy().astype('uint8')
            batch_time.update(time.time() - end)
            label_seg = label.numpy().astype('uint8')
            ret = aic_fundus_lesion_segmentation(label_seg,pred_seg)
            ret_segmentation.append(ret)
            dice_score = compute_single_segment_score(ret)
            dice_list.append(dice_score)
            dice.update(dice_score)
            Dice_1.update(ret[1])
            Dice_2.update(ret[2])
            Dice_3.update(ret[3])
            # metric: AUC for classification
            ground_truth = target_cls.numpy().astype('float32')
            prediction = pred_cls.numpy().astype('float32') 
            if iter == 0:
                detection_ref_all = ground_truth
                detection_pre_all = prediction
            else:
                detection_ref_all = np.concatenate((detection_ref_all, ground_truth), axis=0)
                detection_pre_all = np.concatenate((detection_pre_all, prediction), axis=0)
            
        end = time.time()
        logger_vis.info('Eval: [{0}/{1}]\t'
                    'Dice {dice.val:.3f} ({dice.avg:.3f})\t'
                    'Dice_1 {dice_1.val:.3f} ({dice_1.avg:.3f})\t'
                    'Dice_2 {dice_2.val:.3f} ({dice_2.avg:.3f})\t'
                    'Dice_3 {dice_3.val:.3f} ({dice_3.avg:.3f})\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    .format(iter, len(eval_data_loader), dice=dice, dice_1=Dice_1, dice_2=Dice_2,
                            dice_3=Dice_3, batch_time=batch_time))

    # compute average dice for seg
    final_seg, seg_1, seg_2, seg_3 = compute_segment_score(ret_segmentation)
    print('### Seg ###')
    print('Final Seg Score:{}'.format(final_seg))
    print('Final Seg_1 Score:{}'.format(seg_1))
    print('Final Seg_2 Score:{}'.format(seg_2))
    print('Final Seg_3 Score:{}'.format(seg_3))
    # compute average auc for cls
    ret_detection = aic_fundus_lesion_classification(detection_ref_all, detection_pre_all, num_samples=len(eval_data_loader)*128)
    auc = np.array(ret_detection).mean()
    print('AUC :',auc)
    auc_1 = ret_detection[0]
    auc_2 = ret_detection[1]
    auc_3 = ret_detection[2]

    return final_seg, seg_1, seg_2, seg_3, dice_list, auc, auc_1, auc_2, auc_3
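A hypothetical call site for the variant above, assuming a dataset that yields one (volume, label, name) triple per item, a model that returns (segmentation logits, classification logits), and an args namespace from argparse; eval_dataset and SegClsNet are illustrative names, not from the source:

# Illustrative wiring only; eval_dataset, SegClsNet and args are placeholders.
eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=1, shuffle=False)
model = SegClsNet().cuda()
final_seg, seg_1, seg_2, seg_3, dice_list, auc, auc_1, auc_2, auc_3 = val(args, eval_loader, model)
print('mean Dice: {:.3f}  mean AUC: {:.3f}'.format(final_seg, auc))
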
Example 3
def eval(args, eval_data_loader, model, result_path, logger):
    model.eval()
    batch_time = AverageMeter()
    dice = AverageMeter()
    end = time.time()
    dice_list = []
    Dice_1 = AverageMeter()
    Dice_2 = AverageMeter()
    Dice_3 = AverageMeter()
    ret_segmentation = []

    for iter, (image, label, imt) in enumerate(eval_data_loader):
        # if iter > 1:
        #     break
        # batch size is 1, so squeeze dim 0
        image = image.squeeze(dim=0)
        label = label.squeeze(dim=0)

        target_seg = label.numpy()
        target_cls = target_seg2target_cls(target_seg)

        with torch.no_grad():
            # run inference in small batches to limit GPU memory
            batch = 8
            pred_seg = torch.zeros(image.shape[0], image.shape[2],
                                   image.shape[3])
            pred_cls = torch.zeros(image.shape[0], 3)
            for i in range(0, image.shape[0], batch):
                start_id = i
                end_id = i + batch
                if end_id > image.shape[0]:
                    end_id = image.shape[0]
                image_batch = image[start_id:end_id, :, :, :]
                image_var = Variable(image_batch).cuda()
                # wangshen model forward
                output_seg, output_cls = model(image_var)
                _, pred_batch = torch.max(output_seg, 1)
                pred_seg[start_id:end_id, :, :] = pred_batch.cpu().data
                pred_cls[start_id:end_id, :] = output_cls.cpu().data

            pred_seg = pred_seg.numpy().astype('uint8')

            if args.vis:
                imt = (imt.squeeze().numpy()).astype('uint8')
                ant = label.numpy().astype('uint8')
                model_name = args.seg_path.split('/')[-3]
                save_dir = osp.join(result_path, 'vis', '%04d' % iter)
                if not exists(save_dir): os.makedirs(save_dir)
                vis_multi_class(imt, ant, pred_seg, save_dir)
                print('save vis, finished!')

            batch_time.update(time.time() - end)
            label_seg = label.numpy().astype('uint8')

            ret = aic_fundus_lesion_segmentation(label_seg, pred_seg)
            ret_segmentation.append(ret)
            dice_score = compute_single_segment_score(ret)
            dice_list.append(dice_score)
            dice.update(dice_score)
            Dice_1.update(ret[1])
            Dice_2.update(ret[2])
            Dice_3.update(ret[3])

            ground_truth = target_cls.numpy().astype('float32')
            prediction = pred_cls.numpy().astype('float32')  # predict label

            if iter == 0:
                detection_ref_all = ground_truth
                detection_pre_all = prediction
            else:
                detection_ref_all = np.concatenate(
                    (detection_ref_all, ground_truth), axis=0)
                detection_pre_all = np.concatenate(
                    (detection_pre_all, prediction), axis=0)

        end = time.time()
        logger_vis.info(
            'Eval: [{0}/{1}]\t'
            'Dice {dice.val:.3f} ({dice.avg:.3f})\t'
            'Dice_1 {dice_1.val:.3f} ({dice_1.avg:.3f})\t'
            'Dice_2 {dice_2.val:.3f} ({dice_2.avg:.3f})\t'
            'Dice_3 {dice_3.val:.3f} ({dice_3.avg:.3f})\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                iter,
                len(eval_data_loader),
                dice=dice,
                dice_1=Dice_1,
                dice_2=Dice_2,
                dice_3=Dice_3,
                batch_time=batch_time))

    final_seg, seg_1, seg_2, seg_3 = compute_segment_score(ret_segmentation)
    print('### Seg ###')
    print('Final Seg Score:{}'.format(final_seg))
    print('Final Seg_1 Score:{}'.format(seg_1))
    print('Final Seg_2 Score:{}'.format(seg_2))
    print('Final Seg_3 Score:{}'.format(seg_3))

    ret_detection = aic_fundus_lesion_classification(
        detection_ref_all,
        detection_pre_all,
        num_samples=len(eval_data_loader) * 128)
    auc = np.array(ret_detection).mean()
    print('AUC :', auc)
    auc_1 = ret_detection[0]
    auc_2 = ret_detection[1]
    auc_3 = ret_detection[2]

    epoch = 0
    logger.append(
        [epoch, final_seg, seg_1, seg_2, seg_3, auc, auc_1, auc_2, auc_3])
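Examples 3 and 4 finish by appending one row of metrics to a logger object passed in by the caller. The real Logger belongs to the original repository and may differ; a minimal stand-in that accepts the same append() call and writes one tab-separated row per epoch:

class Logger(object):
    """Minimal stand-in: writes one tab-separated row of metrics per append()."""

    def __init__(self, path, names):
        self.file = open(path, 'a')
        self.file.write('\t'.join(names) + '\n')

    def append(self, values):
        self.file.write('\t'.join(str(v) for v in values) + '\n')
        self.file.flush()
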
Example 4
def eval(args, eval_data_loader, model, result_path, logger):
    model.eval()
    batch_time = AverageMeter()
    dice = AverageMeter()
    end = time.time()
    dice_list = []
    Dice_1 = AverageMeter()
    Dice_2 = AverageMeter()
    Dice_3 = AverageMeter()
    Dice_4 = AverageMeter()
    Dice_5 = AverageMeter()
    Dice_6 = AverageMeter()
    Dice_7 = AverageMeter()
    Dice_8 = AverageMeter()
    Dice_9 = AverageMeter()
    ret_segmentation = []

    for iter, (image, label, imt) in enumerate(eval_data_loader):
        # if iter > 1:
        #     break
        # batch size is 1, so squeeze dim 0
        image = image.squeeze(dim=0)
        label = label.squeeze(dim=0)

        target_seg = label.numpy()
        #target_cls = target_seg2target_cls(target_seg)

        with torch.no_grad():
            # run inference in small batches to limit GPU memory
            batch = 1
            pred_seg = torch.zeros(image.shape[0], image.shape[2],
                                   image.shape[3])
            #pred_cls = torch.zeros(image.shape[0],10)
            for i in range(0, image.shape[0], batch):
                start_id = i
                end_id = i + batch
                if end_id > image.shape[0]:
                    end_id = image.shape[0]
                image_batch = image[start_id:end_id, :, :, :]
                image_var = Variable(image_batch).cuda()
                # wangshen model forward
                output_seg = model(image_var)
                #print(output_seg.shape)
                #print(torch.max(output_seg, 1))
                _, pred_batch = torch.max(output_seg, 1)

                pred_seg[start_id:end_id, :, :] = pred_batch.cpu().data

                #pred_cls[start_id:end_id,:] = output_cls.cpu().data

            pred_seg = pred_seg.numpy().astype('uint8')

            if args.vis:
                imt = (imt.squeeze().numpy()).astype('uint8')
                ant = label.numpy().astype('uint8')
                model_name = args.seg_path.split('/')[-3]
                save_dir = osp.join(result_path, 'vis', '%04d' % iter)
                if not exists(save_dir): os.makedirs(save_dir)
                vis_multi_class(imt, ant, pred_seg, save_dir)
                print('save vis, finished!')

            batch_time.update(time.time() - end)
            label_seg = label.numpy().astype('uint8')

            ret = aic_fundus_lesion_segmentation(label_seg, pred_seg)
            ret_segmentation.append(ret)
            dice_score = compute_single_segment_score(ret)
            dice_list.append(dice_score)
            dice.update(dice_score)
            Dice_1.update(ret[1])
            Dice_2.update(ret[2])
            Dice_3.update(ret[3])
            Dice_4.update(ret[4])
            Dice_5.update(ret[5])
            Dice_6.update(ret[6])
            Dice_7.update(ret[7])
            Dice_8.update(ret[8])
            Dice_9.update(ret[9])

            #ground_truth = target_cls.numpy().astype('float32')
            #prediction = pred_cls.numpy().astype('float32') # predict label

            #if iter == 0:
            #    detection_ref_all = ground_truth
            #    detection_pre_all = prediction
            #else:
            #    detection_ref_all = np.concatenate((detection_ref_all, ground_truth), axis=0)
            #    detection_pre_all = np.concatenate((detection_pre_all, prediction), axis=0)

        end = time.time()
        logger_vis.info(
            'Eval: [{0}/{1}]\t'
            'Dice {dice.val:.3f} ({dice.avg:.3f})\t'
            'Dice_1 {dice_1.val:.6f} ({dice_1.avg:.4f})\t'
            'Dice_2 {dice_2.val:.6f} ({dice_2.avg:.4f})\t'
            'Dice_3 {dice_3.val:.6f} ({dice_3.avg:.4f})\t'
            'Dice_4 {dice_4.val:.6f} ({dice_4.avg:.4f})\t'
            'Dice_5 {dice_5.val:.6f} ({dice_5.avg:.4f})\t'
            'Dice_6 {dice_6.val:.6f} ({dice_6.avg:.4f})\t'
            'Dice_7 {dice_7.val:.6f} ({dice_7.avg:.4f})\t'
            'Dice_8 {dice_8.val:.6f} ({dice_8.avg:.4f})\t'
            'Dice_9 {dice_9.val:.6f} ({dice_9.avg:.4f})\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                iter,
                len(eval_data_loader),
                dice=dice,
                dice_1=Dice_1,
                dice_2=Dice_2,
                dice_3=Dice_3,
                dice_4=Dice_4,
                dice_5=Dice_5,
                dice_6=Dice_6,
                dice_7=Dice_7,
                dice_8=Dice_8,
                dice_9=Dice_9,
                batch_time=batch_time))

    final_seg, seg_1, seg_2, seg_3, seg_4, seg_5, seg_6, seg_7, seg_8, seg_9 = compute_segment_score(
        ret_segmentation)
    print('### Seg ###')
    print('Final Seg Score:{}'.format(final_seg))
    print('Final Seg_1 Score:{}'.format(seg_1))
    print('Final Seg_2 Score:{}'.format(seg_2))
    print('Final Seg_3 Score:{}'.format(seg_3))
    print('Final Seg_4 Score:{}'.format(seg_4))
    print('Final Seg_5 Score:{}'.format(seg_5))
    print('Final Seg_6 Score:{}'.format(seg_6))
    print('Final Seg_7 Score:{}'.format(seg_7))
    print('Final Seg_8 Score:{}'.format(seg_8))
    print('Final Seg_9 Score:{}'.format(seg_9))

    #ret_detection = aic_fundus_lesion_classification( detection_ref_all, detection_pre_all, num_samples=len(eval_data_loader)*42)   #######val number
    #auc = np.array(ret_detection).mean()
    #print('AUC :',auc)
    #auc_1 = ret_detection[0]
    #auc_2 = ret_detection[1]
    #auc_3 = ret_detection[2]
    #auc_4 = ret_detection[3]
    #auc_5 = ret_detection[4]
    #auc_6 = ret_detection[5]
    #auc_7 = ret_detection[6]
    #auc_8 = ret_detection[7]
    #auc_9 = ret_detection[8]
    epoch = 0
    logger.append([
        epoch, final_seg, seg_1, seg_2, seg_3, seg_4, seg_5, seg_6, seg_7,
        seg_8, seg_9
    ])
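All four examples share the same core pattern: one volume arrives as a single batch, is squeezed to (N, C, H, W), and its N slices are pushed through the network in small chunks so the whole volume never has to sit on the GPU at once. The same idea, isolated as a standalone sketch (the function name and default chunk size are illustrative, not from the source):

def predict_volume_in_chunks(model, volume, chunk=8):
    """Argmax predictions for a (N, C, H, W) volume, computed chunk slices at a time.

    Assumes the model returns per-class segmentation logits, or a tuple whose
    first element is the logits (as in the variants with a classification head).
    Returns a (N, H, W) uint8 array of predicted class indices.
    """
    preds = []
    with torch.no_grad():
        for start in range(0, volume.shape[0], chunk):
            batch = volume[start:start + chunk].cuda()
            out = model(batch)
            if isinstance(out, tuple):
                out = out[0]
            preds.append(out.argmax(dim=1).cpu())
    return torch.cat(preds, dim=0).numpy().astype('uint8')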