def eval(args, eval_data_loader, model, result_path, logger):
    """Evaluate ``model`` on ``eval_data_loader`` and append the scores to ``logger``.

    Every sample is forwarded through the model in small slice batches to
    limit GPU memory. Per-sample Dice scores are logged through the
    module-level ``logger_vis``; the aggregated segmentation scores and the
    classification AUCs are printed and appended to ``logger``.

    Args:
        args: Options object; only ``args.vis`` is read. When truthy, a
            visualisation of each sample is written under
            ``<result_path>/vis/<sample index>``.
        eval_data_loader: Iterable yielding ``(image, label, imt)`` tensors.
            The loader is expected to use batch size 1 (the leading
            dimension is squeezed away) -- TODO confirm against the caller.
        model: Network returning ``(output_seg, output_cls)`` for a batch of
            slices; it is run on CUDA.
        result_path: Root directory for visualisation output.
        logger: Object with an ``append(list)`` method that receives
            ``[epoch, final_seg, seg_1, seg_2, seg_3, auc, auc_1, auc_2,
            auc_3]`` (``epoch`` is always 0 here).

    Returns:
        None.
    """
    model.eval()
    batch_time = AverageMeter()
    dice = AverageMeter()
    Dice_1 = AverageMeter()
    Dice_2 = AverageMeter()
    Dice_3 = AverageMeter()
    end = time.time()
    ret_segmentation = []
    # Per-sample classification targets / predictions, concatenated once
    # after the loop instead of re-copying the growing array every iteration.
    detection_refs = []
    detection_preds = []
    # Number of slices forwarded per model call (keeps GPU memory low).
    batch = 8

    for sample_idx, (image, label, imt) in enumerate(eval_data_loader):
        # batch size is 1, so drop the leading batch dimension
        image = image.squeeze(dim=0)
        label = label.squeeze(dim=0)
        target_seg = label.numpy()
        target_cls = target_seg2target_cls(target_seg)
        num_slices = image.shape[0]

        with torch.no_grad():
            pred_seg = torch.zeros(num_slices, image.shape[2], image.shape[3])
            pred_cls = torch.zeros(num_slices, 3)
            for start_id in range(0, num_slices, batch):
                end_id = min(start_id + batch, num_slices)
                # Plain tensors suffice inside no_grad(); the deprecated
                # Variable wrapper is not needed.
                image_var = image[start_id:end_id, :, :, :].cuda()
                # model forward
                output_seg, output_cls = model(image_var)
                _, pred_batch = torch.max(output_seg, 1)
                pred_seg[start_id:end_id, :, :] = pred_batch.cpu()
                pred_cls[start_id:end_id, :] = output_cls.cpu()

        # Convert exactly once: the result is an ndarray, and calling
        # .numpy() on it a second time would raise AttributeError.
        pred_seg = pred_seg.numpy().astype('uint8')

        if args.vis:
            imt = (imt.squeeze().numpy()).astype('uint8')
            ant = label.numpy().astype('uint8')
            save_dir = osp.join(result_path, 'vis', '%04d' % sample_idx)
            os.makedirs(save_dir, exist_ok=True)
            vis_multi_class(imt, ant, pred_seg, save_dir)
            print('save vis, finished!')

        batch_time.update(time.time() - end)

        # Dice metrics for segmentation
        label_seg = label.numpy().astype('uint8')
        ret = aic_fundus_lesion_segmentation(label_seg, pred_seg)
        ret_segmentation.append(ret)
        dice_score = compute_single_segment_score(ret)
        dice.update(dice_score)
        Dice_1.update(ret[1])
        Dice_2.update(ret[2])
        Dice_3.update(ret[3])

        # Collect classification targets / predictions for the AUC
        detection_refs.append(target_cls.numpy().astype('float32'))
        detection_preds.append(pred_cls.numpy().astype('float32'))

        end = time.time()
        logger_vis.info(
            'Eval: [{0}/{1}]\t'
            'Dice {dice.val:.3f} ({dice.avg:.3f})\t'
            'Dice_1 {dice_1.val:.3f} ({dice_1.avg:.3f})\t'
            'Dice_2 {dice_2.val:.3f} ({dice_2.avg:.3f})\t'
            'Dice_3 {dice_3.val:.3f} ({dice_3.avg:.3f})\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                sample_idx, len(eval_data_loader), dice=dice, dice_1=Dice_1,
                dice_2=Dice_2, dice_3=Dice_3, batch_time=batch_time))

    # Aggregate segmentation scores over all samples
    final_seg, seg_1, seg_2, seg_3 = compute_segment_score(ret_segmentation)
    print('### Seg ###')
    print('Final Seg Score:{}'.format(final_seg))
    print('Final Seg_1 Score:{}'.format(seg_1))
    print('Final Seg_2 Score:{}'.format(seg_2))
    print('Final Seg_3 Score:{}'.format(seg_3))

    # Aggregate classification AUC over all samples. With no samples,
    # np.concatenate raises ValueError here.
    detection_ref_all = np.concatenate(detection_refs, axis=0)
    detection_pre_all = np.concatenate(detection_preds, axis=0)
    # NOTE(review): 128 looks like an assumed slices-per-sample count --
    # confirm against the dataset.
    ret_detection = aic_fundus_lesion_classification(
        detection_ref_all, detection_pre_all,
        num_samples=len(eval_data_loader) * 128)
    auc = np.array(ret_detection).mean()
    print('AUC :', auc)
    auc_1 = ret_detection[0]
    auc_2 = ret_detection[1]
    auc_3 = ret_detection[2]

    epoch = 0
    logger.append(
        [epoch, final_seg, seg_1, seg_2, seg_3, auc, auc_1, auc_2, auc_3])
def val(args, eval_data_loader, model):
    """Validate ``model`` on ``eval_data_loader`` and return the scores.

    Every sample is forwarded through the model in small slice batches to
    limit GPU memory. Per-sample Dice scores are logged through the
    module-level ``logger_vis``; the aggregated segmentation scores and the
    classification AUCs are printed and returned.

    Args:
        args: Options object; not read by this function.
        eval_data_loader: Iterable yielding ``(image, label, _)`` tensors.
            The loader is expected to use batch size 1 (the leading
            dimension is squeezed away) -- TODO confirm against the caller.
        model: Network returning ``(output_seg, output_cls)`` for a batch of
            slices; it is run on CUDA.

    Returns:
        Tuple ``(final_seg, seg_1, seg_2, seg_3, dice_list, auc, auc_1,
        auc_2, auc_3)`` where ``dice_list`` holds one Dice score per sample
        and the ``auc_*`` values are the first three entries returned by
        ``aic_fundus_lesion_classification``.
    """
    model.eval()
    batch_time = AverageMeter()
    dice = AverageMeter()
    Dice_1 = AverageMeter()
    Dice_2 = AverageMeter()
    Dice_3 = AverageMeter()
    end = time.time()
    dice_list = []
    ret_segmentation = []
    # Per-sample classification targets / predictions, concatenated once
    # after the loop instead of re-copying the growing array every iteration.
    detection_refs = []
    detection_preds = []
    # Number of slices forwarded per model call (keeps GPU memory low).
    batch = 16

    for sample_idx, (image, label, _) in enumerate(eval_data_loader):
        # batch size is 1, so drop the leading batch dimension
        image = image.squeeze(dim=0)
        label = label.squeeze(dim=0)
        target_seg = label.numpy()
        target_cls = target_seg2target_cls(target_seg)
        num_slices = image.shape[0]

        with torch.no_grad():
            pred_seg = torch.zeros(num_slices, image.shape[2], image.shape[3])
            pred_cls = torch.zeros(num_slices, 3)
            for start_id in range(0, num_slices, batch):
                end_id = min(start_id + batch, num_slices)
                # Plain tensors suffice inside no_grad(); the deprecated
                # Variable wrapper is not needed.
                image_var = image[start_id:end_id, :, :, :].cuda()
                # model forward
                output_seg, output_cls = model(image_var)
                _, pred_batch = torch.max(output_seg, 1)
                pred_seg[start_id:end_id, :, :] = pred_batch.cpu()
                pred_cls[start_id:end_id, :] = output_cls.cpu()

        # Dice metrics for segmentation
        pred_seg = pred_seg.numpy().astype('uint8')
        batch_time.update(time.time() - end)
        label_seg = label.numpy().astype('uint8')
        ret = aic_fundus_lesion_segmentation(label_seg, pred_seg)
        ret_segmentation.append(ret)
        dice_score = compute_single_segment_score(ret)
        dice_list.append(dice_score)
        dice.update(dice_score)
        Dice_1.update(ret[1])
        Dice_2.update(ret[2])
        Dice_3.update(ret[3])

        # Collect classification targets / predictions for the AUC
        detection_refs.append(target_cls.numpy().astype('float32'))
        detection_preds.append(pred_cls.numpy().astype('float32'))

        end = time.time()
        logger_vis.info(
            'Eval: [{0}/{1}]\t'
            'Dice {dice.val:.3f} ({dice.avg:.3f})\t'
            'Dice_1 {dice_1.val:.3f} ({dice_1.avg:.3f})\t'
            'Dice_2 {dice_2.val:.3f} ({dice_2.avg:.3f})\t'
            'Dice_3 {dice_3.val:.3f} ({dice_3.avg:.3f})\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(
                sample_idx, len(eval_data_loader), dice=dice, dice_1=Dice_1,
                dice_2=Dice_2, dice_3=Dice_3, batch_time=batch_time))

    # compute average dice for seg
    final_seg, seg_1, seg_2, seg_3 = compute_segment_score(ret_segmentation)
    print('### Seg ###')
    print('Final Seg Score:{}'.format(final_seg))
    print('Final Seg_1 Score:{}'.format(seg_1))
    print('Final Seg_2 Score:{}'.format(seg_2))
    print('Final Seg_3 Score:{}'.format(seg_3))

    # compute average auc for cls. With no samples, np.concatenate raises
    # ValueError here.
    detection_ref_all = np.concatenate(detection_refs, axis=0)
    detection_pre_all = np.concatenate(detection_preds, axis=0)
    # NOTE(review): 128 looks like an assumed slices-per-sample count --
    # confirm against the dataset.
    ret_detection = aic_fundus_lesion_classification(
        detection_ref_all, detection_pre_all,
        num_samples=len(eval_data_loader) * 128)
    auc = np.array(ret_detection).mean()
    print('AUC :', auc)
    auc_1 = ret_detection[0]
    auc_2 = ret_detection[1]
    auc_3 = ret_detection[2]
    return final_seg, seg_1, seg_2, seg_3, dice_list, auc, auc_1, auc_2, auc_3