Example #1
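
# Imports required by the evaluation routines below. The project-specific
# helpers (AverageMeter, target_seg2target_cls, vis_multi_class, logger_vis,
# aic_fundus_lesion_segmentation, aic_fundus_lesion_classification,
# compute_single_segment_score, compute_segment_score) come from the project's
# own modules; the commented import paths are placeholders, not confirmed APIs.
import os
import os.path as osp
import time

import numpy as np
import torch

# from utils import AverageMeter, logger_vis, target_seg2target_cls, vis_multi_class
# from metrics import (aic_fundus_lesion_classification, aic_fundus_lesion_segmentation,
#                      compute_segment_score, compute_single_segment_score)

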
def eval(args, eval_data_loader, model, result_path, logger):
    model.eval()
    batch_time = AverageMeter()
    dice = AverageMeter()
    end = time.time()
    dice_list = []
    Dice_1 = AverageMeter()
    Dice_2 = AverageMeter()
    Dice_3 = AverageMeter()
    ret_segmentation = []

    for iter, (image, label, imt) in enumerate(eval_data_loader):
        # batch size is 1, so drop the leading batch dimension
        image = image.squeeze(dim=0)
        label = label.squeeze(dim=0)

        target_seg = label.numpy()
        target_cls = target_seg2target_cls(target_seg)

        with torch.no_grad():
            # run the forward pass in sub-batches to limit GPU memory usage
            batch = 8
            pred_seg = torch.zeros(image.shape[0], image.shape[2],
                                   image.shape[3])
            pred_cls = torch.zeros(image.shape[0], 3)
            for i in range(0, image.shape[0], batch):
                start_id = i
                end_id = min(i + batch, image.shape[0])
                image_batch = image[start_id:end_id, :, :, :]
                # Variable is deprecated; a plain tensor works under torch.no_grad()
                image_var = image_batch.cuda()
                # model forward
                output_seg, output_cls = model(image_var)
                _, pred_batch = torch.max(output_seg, 1)
                pred_seg[start_id:end_id, :, :] = pred_batch.cpu().data
                pred_cls[start_id:end_id, :] = output_cls.cpu().data

            pred_seg = pred_seg.numpy().astype('uint8')

            if args.vis:
                imt = (imt.squeeze().numpy()).astype('uint8')
                ant = label.numpy().astype('uint8')
                model_name = args.seg_path.split('/')[-3]
                save_dir = osp.join(result_path, 'vis', '%04d' % iter)
                os.makedirs(save_dir, exist_ok=True)
                vis_multi_class(imt, ant, pred_seg, save_dir)
                print('Saved visualization results.')

            batch_time.update(time.time() - end)
            label_seg = label.numpy().astype('uint8')

            ret = aic_fundus_lesion_segmentation(label_seg, pred_seg)
            ret_segmentation.append(ret)
            dice_score = compute_single_segment_score(ret)
            dice_list.append(dice_score)
            dice.update(dice_score)
            Dice_1.update(ret[1])
            Dice_2.update(ret[2])
            Dice_3.update(ret[3])

            ground_truth = target_cls.numpy().astype('float32')
            prediction = pred_cls.numpy().astype('float32')  # predicted classification scores

            if iter == 0:
                detection_ref_all = ground_truth
                detection_pre_all = prediction
            else:
                detection_ref_all = np.concatenate(
                    (detection_ref_all, ground_truth), axis=0)
                detection_pre_all = np.concatenate(
                    (detection_pre_all, prediction), axis=0)

        end = time.time()
        logger_vis.info(
            'Eval: [{0}/{1}]\t'
            'Dice {dice.val:.3f} ({dice.avg:.3f})\t'
            'Dice_1 {dice_1.val:.3f} ({dice_1.avg:.3f})\t'
            'Dice_2 {dice_2.val:.3f} ({dice_2.avg:.3f})\t'
            'Dice_3 {dice_3.val:.3f} ({dice_3.avg:.3f})\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                iter,
                len(eval_data_loader),
                dice=dice,
                dice_1=Dice_1,
                dice_2=Dice_2,
                dice_3=Dice_3,
                batch_time=batch_time))

    final_seg, seg_1, seg_2, seg_3 = compute_segment_score(ret_segmentation)
    print('### Seg ###')
    print('Final Seg Score:{}'.format(final_seg))
    print('Final Seg_1 Score:{}'.format(seg_1))
    print('Final Seg_2 Score:{}'.format(seg_2))
    print('Final Seg_3 Score:{}'.format(seg_3))

    ret_detection = aic_fundus_lesion_classification(
        detection_ref_all,
        detection_pre_all,
        num_samples=len(eval_data_loader) * 128)
    auc = np.array(ret_detection).mean()
    print('AUC :', auc)
    auc_1 = ret_detection[0]
    auc_2 = ret_detection[1]
    auc_3 = ret_detection[2]

    # standalone evaluation (no training loop), so record results under epoch 0
    epoch = 0
    logger.append(
        [epoch, final_seg, seg_1, seg_2, seg_3, auc, auc_1, auc_2, auc_3])
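
# A minimal usage sketch for eval(). SegClsNet, EvalDataset, Logger and the
# checkpoint path are hypothetical placeholders for whatever the surrounding
# project provides; the loader must yield (image, label, imt) tuples with
# batch_size=1, as the function assumes.
#
#     model = SegClsNet().cuda()
#     model.load_state_dict(torch.load('path/to/checkpoint.pth'))
#     eval_loader = torch.utils.data.DataLoader(EvalDataset(args), batch_size=1, shuffle=False)
#     logger = Logger(osp.join(result_path, 'eval_log.txt'))
#     eval(args, eval_loader, model, result_path, logger)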


def val(args, eval_data_loader, model):
    model.eval()
    batch_time = AverageMeter()
    dice = AverageMeter()
    end = time.time()
    dice_list = []
    Dice_1 = AverageMeter()
    Dice_2 = AverageMeter()
    Dice_3 = AverageMeter()
    ret_segmentation = []
    
    for iter, (image, label, _) in enumerate(eval_data_loader):
        # batch size is 1, so drop the leading batch dimension
        image = image.squeeze(dim=0)
        label = label.squeeze(dim=0)

        target_seg = label.numpy()
        target_cls = target_seg2target_cls(target_seg)
        
        with torch.no_grad():
            # run the forward pass in sub-batches to limit GPU memory usage
            batch = 16
            pred_seg = torch.zeros(image.shape[0], image.shape[2], image.shape[3])
            pred_cls = torch.zeros(image.shape[0], 3)
            for i in range(0, image.shape[0], batch):
                start_id = i
                end_id = min(i + batch, image.shape[0])
                image_batch = image[start_id:end_id, :, :, :]
                # Variable is deprecated; a plain tensor works under torch.no_grad()
                image_var = image_batch.cuda()
                # model forward
                output_seg, output_cls = model(image_var)
                _, pred_batch = torch.max(output_seg, 1)
                pred_seg[start_id:end_id, :, :] = pred_batch.cpu().data
                pred_cls[start_id:end_id, :] = output_cls.cpu().data
            # Dice metrics for segmentation
            pred_seg = pred_seg.numpy().astype('uint8')
            batch_time.update(time.time() - end)
            label_seg = label.numpy().astype('uint8')
            ret = aic_fundus_lesion_segmentation(label_seg, pred_seg)
            ret_segmentation.append(ret)
            dice_score = compute_single_segment_score(ret)
            dice_list.append(dice_score)
            dice.update(dice_score)
            Dice_1.update(ret[1])
            Dice_2.update(ret[2])
            Dice_3.update(ret[3])
            # AUC metrics for classification
            ground_truth = target_cls.numpy().astype('float32')
            prediction = pred_cls.numpy().astype('float32')  # predicted classification scores
            if iter == 0:
                detection_ref_all = ground_truth
                detection_pre_all = prediction
            else:
                detection_ref_all = np.concatenate((detection_ref_all, ground_truth), axis=0)
                detection_pre_all = np.concatenate((detection_pre_all, prediction), axis=0)
            
        end = time.time()
        logger_vis.info(
            'Eval: [{0}/{1}]\t'
            'Dice {dice.val:.3f} ({dice.avg:.3f})\t'
            'Dice_1 {dice_1.val:.3f} ({dice_1.avg:.3f})\t'
            'Dice_2 {dice_2.val:.3f} ({dice_2.avg:.3f})\t'
            'Dice_3 {dice_3.val:.3f} ({dice_3.avg:.3f})\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                iter,
                len(eval_data_loader),
                dice=dice,
                dice_1=Dice_1,
                dice_2=Dice_2,
                dice_3=Dice_3,
                batch_time=batch_time))

    # compute average dice for seg
    final_seg, seg_1, seg_2, seg_3 = compute_segment_score(ret_segmentation)
    print('### Seg ###')
    print('Final Seg Score:{}'.format(final_seg))
    print('Final Seg_1 Score:{}'.format(seg_1))
    print('Final Seg_2 Score:{}'.format(seg_2))
    print('Final Seg_3 Score:{}'.format(seg_3))
    # compute average auc for cls
    ret_detection = aic_fundus_lesion_classification(
        detection_ref_all,
        detection_pre_all,
        num_samples=len(eval_data_loader) * 128)
    auc = np.array(ret_detection).mean()
    print('AUC :', auc)
    auc_1 = ret_detection[0]
    auc_2 = ret_detection[1]
    auc_3 = ret_detection[2]

    return final_seg, seg_1, seg_2, seg_3, dice_list, auc, auc_1, auc_2, auc_3
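

# Self-contained sketch of the sub-batched inference pattern used by both
# functions above, with a dummy two-head network and random data (CPU only).
# DummyNet and the shapes below are illustrative assumptions, not part of the
# original project.
class DummyNet(torch.nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        self.seg_head = torch.nn.Conv2d(3, num_classes, kernel_size=1)
        self.cls_head = torch.nn.Linear(3, 3)

    def forward(self, x):
        seg = self.seg_head(x)                   # (N, num_classes, H, W) logits
        cls = self.cls_head(x.mean(dim=(2, 3)))  # (N, 3) per-slice lesion scores
        return seg, cls


if __name__ == '__main__':
    net = DummyNet().eval()
    volume = torch.randn(20, 3, 64, 64)          # 20 slices of a fake volume
    batch = 8                                    # slices processed per forward pass
    pred_seg = torch.zeros(volume.shape[0], 64, 64)
    pred_cls = torch.zeros(volume.shape[0], 3)
    with torch.no_grad():
        for i in range(0, volume.shape[0], batch):
            j = min(i + batch, volume.shape[0])
            seg, cls = net(volume[i:j])
            pred_seg[i:j] = seg.argmax(dim=1)    # per-pixel class indices
            pred_cls[i:j] = cls
    print(pred_seg.shape, pred_cls.shape)        # torch.Size([20, 64, 64]) torch.Size([20, 3])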