def infer(args, model, criterion, val_loader):
    """Evaluate ``model`` on ``val_loader`` and return loss and metric averages.

    The model output is treated as a sequence of prediction tensors. Every
    prediction and the target are flattened to ``(batch, -1)`` before being
    passed to ``criterion`` and to the metric objects.

    Args:
        args: Namespace providing ``device`` and ``deepsupervision``.
        model: Network invoked as ``model(input)``.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        val_loader: Iterable yielding ``(input, target, _)`` batches.

    Returns:
        ``(val_loss.avg, final_metric.get_avg)``; when
        ``args.deepsupervision`` is truthy a third element is appended with
        the ``get_avg`` of the metric fed with the mean over all heads.
    """
    final_metric = BinaryIndicatorsMetric()
    deep_metric = BinaryIndicatorsMetric()
    val_loss = AverageMeter()

    model.eval()
    with torch.no_grad():
        for _step, (images, target, _) in tqdm(enumerate(val_loader)):
            images = images.to(args.device)
            target = target.to(args.device)

            preds_list = [pred.view(pred.size(0), -1) for pred in model(images)]
            target = target.view(target.size(0), -1)

            if args.deepsupervision:
                # Sum every head exactly once. The previous loop initialised
                # the loss with head 0 and then added head 0 again, so the
                # first head was weighted twice.
                v_loss = sum(criterion(pred, target) for pred in preds_list)
            else:
                v_loss = criterion(preds_list[-1], target)
            val_loss.update(v_loss.item(), 1)

            if args.deepsupervision:
                # Metrics of the element-wise mean over all heads.
                avg_pred = sum(preds_list) / len(preds_list)
                deep_metric.update(labels=target, preds=avg_pred, n=1)
            final_metric.update(labels=target, preds=preds_list[-1], n=1)

    if args.deepsupervision:
        return val_loss.avg, final_metric.get_avg, deep_metric.get_avg
    return val_loss.avg, final_metric.get_avg
def val(args, model, criterion, val_loader, epoch, logger):
    """Run one validation pass, log the averaged loss and metrics, return them.

    Args:
        args: Namespace providing ``device`` and ``deepsupervision``.
        model: Network invoked as ``model(input)``. With deep supervision it
            must return a list or tuple of tensors, otherwise a tensor.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        val_loader: Iterable yielding ``(input, target, _)`` batches.
        epoch: Value interpolated into the log line.
        logger: Object exposing ``info``.

    Returns:
        ``(mr, ms, mp, mf, mjc, md, macc, mean_loss)`` where the first seven
        values come from ``BinaryIndicatorsMetric.get_avg`` (logged as Rec,
        Spe, Pre, F1, Jc, Dice, Acc) and ``mean_loss`` is the averaged loss.
    """
    metric = BinaryIndicatorsMetric()
    val_loss = AverageMeter()

    model.eval()
    with torch.no_grad():
        for _step, (images, target, _) in enumerate(val_loader):
            images = images.to(args.device)
            target = target.to(args.device)
            preds = model(images)

            if args.deepsupervision:
                assert isinstance(preds, (list, tuple))
                preds = [pred.view(pred.size(0), -1) for pred in preds]
                target = target.view(target.size(0), -1)
                # Mean over heads with each head counted once. The previous
                # loop added head 0 twice before dividing by len(preds).
                loss = sum(criterion(pred, target) for pred in preds)
                loss = loss / len(preds)
                final_pred = preds[-1]
            else:
                final_pred = preds.view(preds.size(0), -1)
                target = target.view(target.size(0), -1)
                loss = criterion(final_pred, target)

            val_loss.update(loss.item(), 1)
            # Metrics are always computed on the last (or only) prediction.
            metric.update(labels=target, preds=final_pred, n=1)

    mr, ms, mp, mf, mjc, md, macc = metric.get_avg
    mean_loss = val_loss.avg
    logger.info("Epoch:{} Val Loss:{:.3f} ".format(epoch, mean_loss))
    logger.info("Acc:{:.3f} Rec:{:.3f} Spe:{:.3f} Pre:{:.3f} F1:{:.3f} Jc:{:.3f} Dice:{:.3f}".format(macc, mr, ms, mp, mf, mjc, md))
    return mr, ms, mp, mf, mjc, md, macc, mean_loss
def train(args, model, criterion, train_loader, optimizer, epoch, logger):
    """Train ``model`` for one pass over ``train_loader`` and log the mean loss.

    Args:
        args: Namespace providing ``device`` and ``deepsupervision``.
        model: Network invoked as ``model(input)``. With deep supervision it
            must return a list or tuple of tensors, otherwise a tensor.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        train_loader: Iterable yielding ``(input, target, _)`` batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Value interpolated into the log line.
        logger: Object exposing ``info``.

    Returns:
        None. Only the averaged training loss is logged; the metric object
        is updated per batch but its averages are not reported here.
    """
    metric = BinaryIndicatorsMetric()
    model.train()
    train_loss = AverageMeter()

    for _step, (images, target, _) in enumerate(train_loader):
        images = images.to(args.device)
        target = target.to(args.device)
        # Original note: input is B,C,H,W; target and preds are B,1,H,W.
        preds = model(images)

        if args.deepsupervision:
            assert isinstance(preds, (list, tuple))
            preds = [pred.view(pred.size(0), -1) for pred in preds]
            target = target.view(target.size(0), -1)
            # Mean over heads with each head counted once. The previous loop
            # added head 0 twice before dividing by len(preds).
            loss = sum(criterion(pred, target) for pred in preds)
            loss = loss / len(preds)
            final_pred = preds[-1]
        else:
            final_pred = preds.view(preds.size(0), -1)
            target = target.view(target.size(0), -1)
            loss = criterion(final_pred, target)

        train_loss.update(loss.item(), 1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics are always computed on the last (or only) prediction.
        metric.update(labels=target, preds=final_pred, n=1)

    logger.info("Epoch:{} Train Loss:{:.3f}".format(epoch, train_loss.avg))
def infer(args, model, criterion, val_loader,logger,path):
    """Save thresholded masks for every sample and log loss / Acc / Dice / Jc.

    For each batch the prediction (the last head when
    ``args.deepsupervision`` is set) is passed through a sigmoid, thresholded
    at 0.5, scaled to 0/255 and written to ``path`` as ``<index>.png``, where
    ``<index>`` is the integer parsed from the part of the sample name before
    the first dot. Loss and metrics are then computed on the flattened
    prediction and target. Nothing is returned.
    """
    metrics = BinaryIndicatorsMetric()
    loss_meter = AverageMeter()

    model.eval()
    with torch.no_grad():
        for _step, (images, labels, names) in tqdm(enumerate(val_loader)):
            images = images.to(args.device)
            labels = labels.to(args.device)

            output = model(images)
            if args.deepsupervision:
                output = output[-1].clone()

            # Write one binary mask image per sample in the batch.
            probabilities = torch.sigmoid(output.clone()).data.cpu().numpy()
            batch_size, _channels, _height, _width = probabilities.shape
            for idx in range(batch_size):
                sample_id = int(names[idx].split('.')[0])
                binary = np.where(probabilities[idx][0] > 0.5, 255, 0).astype(np.uint8)
                target_file = os.path.join(path, str(sample_id) + ".png")
                Image.fromarray(binary).save(target_file)

            # Loss and metrics on flattened tensors.
            flat_output = output.view(output.size(0), -1)
            flat_labels = labels.view(labels.size(0), -1)
            batch_loss = criterion(flat_output, flat_labels)
            loss_meter.update(batch_loss.item(), 1)
            metrics.update(labels=flat_labels, preds=flat_output, n=1)

    _rec, _spe, _pre, _f1, jaccard, dice, accuracy = metrics.get_avg
    logger.info("Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}".format(loss_meter.avg, accuracy, dice, jaccard))
def val(args, model, criterion, val_loader, epoch, logger):
    """Validate ``model`` and return the averaged batch-scaled loss and Dice.

    The per-batch loss is multiplied by the batch size of the final
    prediction before being recorded. With deep supervision the losses of
    all heads are summed and divided by the number of heads.

    Args:
        args: Namespace providing ``device`` and ``deepsupervision``.
        model: Network invoked as ``model(input)``.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        val_loader: Iterable yielding ``(input, target, _)`` batches.
        epoch: Value interpolated into the log line.
        logger: Object exposing ``info``.

    Returns:
        ``(tloss_r.avg, md)``: the averaged recorded loss and the value
        logged as "dice".
    """
    metric = BinaryIndicatorsMetric()
    tloss_r = AverageMeter()

    model.eval()
    with torch.no_grad():
        for _step, (images, target, _) in enumerate(val_loader):
            images = images.to(args.device)
            target = target.to(args.device)
            preds = model(images)
            flat_target = target.view(target.size(0), -1)

            if args.deepsupervision:
                assert isinstance(preds, (list, tuple))
                final_pred = preds[-1]
                tloss = sum(
                    criterion(pred.view(pred.size(0), -1), flat_target)
                    for pred in preds)
                tloss = tloss * final_pred.size(0)
                tloss = tloss / len(preds)
            else:
                # The old code indexed preds[-1] for the loss but then called
                # preds.size(0), which fails when preds is a list or tuple.
                # Resolve the final head once and use it for both.
                if isinstance(preds, (list, tuple)):
                    final_pred = preds[-1]
                else:
                    final_pred = preds
                tloss = criterion(
                    final_pred.view(final_pred.size(0), -1), flat_target)
                tloss = tloss * final_pred.size(0)

            tloss_r.update(tloss.item(), 1)
            metric.update(labels=flat_target,
                          preds=final_pred.view(final_pred.size(0), -1))

    logger.info("Epoch:{} Val T-loss:{:.5f} ".format(epoch, tloss_r.avg))
    mr, ms, mp, mf, mjc, md, macc = metric.get_avg
    logger.info(" Val dice:{:.3f} iou:{:.3f} acc:{:.3f} ".format(
        md, mjc, macc))
    return tloss_r.avg, md
def train(args, model_alphas, model, criterion, train_loader, optimizer):
    """Train ``model`` for one pass and return loss and metric averages.

    Args:
        args: Namespace providing ``device`` and ``deepsupervision``.
        model_alphas: Optional first argument for the model. When not
            ``None`` the model is invoked as ``model(model_alphas, input)``,
            otherwise as ``model(input)``.
        model: Network returning a sequence of prediction tensors.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        train_loader: Iterable yielding ``(input, target, _)`` batches.
        optimizer: Optimizer zeroed and stepped once per batch.

    Returns:
        ``(loss_avg, final_metric.get_avg)``; when ``args.deepsupervision``
        is truthy a third element is appended with the ``get_avg`` of the
        metric fed with the mean over all heads.
    """
    final_metric = BinaryIndicatorsMetric()
    deep_metric = BinaryIndicatorsMetric()
    loss_recoder = AverageMeter()

    model.train()
    for _step, (images, target, _) in tqdm(enumerate(train_loader)):
        images = images.to(args.device)
        target = target.to(args.device)
        # Original note: input is B,C,H,W; target and preds are B,1,H,W.
        optimizer.zero_grad()

        if model_alphas is not None:
            raw_preds = model(model_alphas, images)
        else:
            raw_preds = model(images)
        preds_list = [pred.view(pred.size(0), -1) for pred in raw_preds]
        target = target.view(target.size(0), -1)

        if args.deepsupervision:
            # Sum every head exactly once. The previous loop initialised the
            # loss with head 0 and then added head 0 again.
            w_loss = sum(criterion(pred, target) for pred in preds_list)
        else:
            w_loss = criterion(preds_list[-1], target)

        w_loss.backward()
        optimizer.step()
        loss_recoder.update(w_loss.item(), 1)

        if args.deepsupervision:
            # Metrics of the element-wise mean over all heads.
            avg_pred = sum(preds_list) / len(preds_list)
            deep_metric.update(labels=target, preds=avg_pred, n=1)
        final_metric.update(labels=target, preds=preds_list[-1], n=1)

    if args.deepsupervision:
        return loss_recoder.avg, final_metric.get_avg, deep_metric.get_avg
    return loss_recoder.avg, final_metric.get_avg
def infer(args, model, criterion, val_loader, logger, path):
    """Save predicted masks per sub-directory and report two metric sets.

    For every sample the sigmoid of the prediction is thresholded at 0.5,
    scaled to 0/255 and saved as ``<path>/<subdir>/<file_index>.png``. The
    sample name is read as ``name[0][i]`` (sub-directory) and ``name[1][i]``
    (file stem).

    Args:
        args: Namespace providing ``device`` and ``deepsupervision``.
        model: Network invoked as ``model(input)``. With deep supervision
            the last element of its output is used.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        val_loader: Iterable yielding ``(input, target, name)`` batches.
        logger: Object exposing ``info``.
        path: Root directory for the saved masks.

    Returns:
        ``(tloss_r.avg, per1)`` where ``per1`` is the result of
        ``RefugeIndicatorsMetricBinary.avg()``.
    """
    metric_v1 = RefugeIndicatorsMetricBinary()
    metric_v2 = BinaryIndicatorsMetric()
    tloss_r = AverageMeter()

    model.eval()
    with torch.no_grad():
        for _step, (images, target, name) in tqdm(enumerate(val_loader)):
            images = images.to(args.device)
            target = target.to(args.device)

            pred = model(images)
            if args.deepsupervision:
                pred = pred[-1].clone()

            # Save one binary mask per sample.
            file_masks = torch.sigmoid(pred.clone()).data.cpu().numpy()
            n, _c, _h, _w = file_masks.shape
            for i in range(n):
                subdir = name[0][i]
                file_index = name[1][i]
                out_dir = os.path.join(path, subdir)
                # exist_ok avoids the check-then-create race of
                # os.path.exists + os.mkdir and also creates missing parents.
                os.makedirs(out_dir, exist_ok=True)
                file_mask = (file_masks[i][0] > 0.5).astype(np.uint8)
                file_mask[file_mask >= 1] = 255
                Image.fromarray(file_mask).save(
                    os.path.join(out_dir, file_index + ".png"))

            # The first metric receives the unflattened prediction and
            # target; it must be updated before the target is reshaped.
            metric_v1.update(pred, target)

            target = target.view(target.size(0), -1)
            flat_pred = pred.view(pred.size(0), -1)
            # Scale the loss by the batch size before recording it.
            v_loss = criterion(flat_pred, target) * n
            tloss_r.update(v_loss.item(), 1)
            metric_v2.update(labels=target, preds=flat_pred)

    per1 = metric_v1.avg()
    mr, ms, mp, mf, mjc, md, macc = metric_v2.get_avg
    logger.info(" V1 dice:{:.3f} iou:{:.3f} acc:{:.3f} ".format(
        per1[0], per1[1], per1[2]))
    logger.info(" V2 dice:{:.3f} iou:{:.3f} acc:{:.3f} ".format(
        md, mjc, macc))
    return tloss_r.avg, per1
def infer(args, model, criterion, val_loader,logger):
    """Evaluate ``model`` on ``val_loader``, log and return averaged metrics.

    The prediction (the last head when ``args.deepsupervision`` is set) and
    the target are flattened to ``(batch, -1)`` before the loss and metric
    update.

    Returns:
        The seven values of ``BinaryIndicatorsMetric.get_avg`` followed by
        the averaged validation loss.
    """
    metrics = BinaryIndicatorsMetric()
    loss_meter = AverageMeter()

    model.eval()
    with torch.no_grad():
        for _step, (images, labels, _names) in tqdm(enumerate(val_loader)):
            images = images.to(args.device)
            labels = labels.to(args.device)

            output = model(images)
            if args.deepsupervision:
                output = output[-1].clone()

            flat_output = output.view(output.size(0), -1)
            flat_labels = labels.view(labels.size(0), -1)
            batch_loss = criterion(flat_output, flat_labels)
            loss_meter.update(batch_loss.item(), 1)
            metrics.update(labels=flat_labels, preds=flat_output, n=1)

    averages = metrics.get_avg
    vmr, vms, vmp, vmf, vmjc, vmd, vmacc = averages
    logger.info("Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}".format(loss_meter.avg, vmacc, vmd, vmjc))
    return vmr, vms, vmp, vmf, vmjc, vmd, vmacc, loss_meter.avg
def infer(args, model, criterion, val_loader, logger, path, max_batches=3):
    """Save ISIC-style masks for the evaluated batches and log loss / metrics.

    Each sample's last-head prediction is passed through a sigmoid,
    thresholded at 0.5, scaled to 0/255 and written to ``path`` as
    ``ISIC_<name>_segmentation.png``.

    Args:
        args: Namespace providing ``device`` and ``deepsupervision``.
        model: Network returning a sequence of prediction tensors.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        val_loader: Iterable yielding ``(input, target, name)`` batches.
        logger: Object exposing ``info``.
        path: Directory the mask images are written to.
        max_batches: Maximum number of batches to evaluate; ``None``
            evaluates the whole loader. The default of 3 reproduces the
            previously hard-coded ``if step > 1: break``.
            NOTE(review): that early stop looks like a debugging leftover;
            confirm whether callers should pass ``max_batches=None``.

    Returns:
        None. Results are only logged.
    """
    from itertools import islice

    metric = BinaryIndicatorsMetric()
    val_loss = AverageMeter()

    model.eval()
    with torch.no_grad():
        # islice(..., None) iterates everything; islice(..., k) stops after
        # k batches without requesting another one from the loader.
        batches = islice(val_loader, max_batches)
        for _step, (images, target, name) in tqdm(enumerate(batches)):
            images = images.to(args.device)
            target = target.to(args.device)
            preds_list = model(images)

            # Save one binary mask per sample, taken from the last head.
            file_masks = torch.sigmoid(preds_list[-1].clone()).data.cpu().numpy()
            n, _c, _h, _w = file_masks.shape
            for i in range(n):
                file_name = 'ISIC_' + name[i] + '_segmentation.png'
                file_mask = (file_masks[i][0] > 0.5).astype(np.uint8)
                file_mask[file_mask >= 1] = 255
                Image.fromarray(file_mask).save(os.path.join(path, file_name))

            preds_list = [pred.view(pred.size(0), -1) for pred in preds_list]
            target = target.view(target.size(0), -1)
            if args.deepsupervision:
                v_loss = sum(criterion(pred, target) for pred in preds_list)
            else:
                v_loss = criterion(preds_list[-1], target)
            val_loss.update(v_loss.item(), 1)
            metric.update(labels=target, preds=preds_list[-1], n=1)

    vmr, vms, vmp, vmf, vmjc, vmd, vmacc = metric.get_avg
    logger.info("Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}".format(
        val_loss.avg, vmacc, vmd, vmjc))
def infer(args, model, val_queue, criterion):
    """Evaluate ``model`` on ``val_queue`` and return loss and metric averages.

    Args:
        args: Namespace providing ``device`` and ``deepsupervision``.
        model: Network returning a sequence of prediction tensors.
        val_queue: Iterable yielding ``(input, target, _)`` batches.
        criterion: Loss callable invoked as ``criterion(pred, target)``.

    Returns:
        ``(val_loss.avg, metric.get_avg)``; metrics are computed on the last
        head only.
    """
    metric = BinaryIndicatorsMetric()
    val_loss = AverageMeter()

    model.eval()
    with torch.no_grad():
        for _step, (images, target, _) in tqdm(enumerate(val_queue)):
            images = images.to(args.device)
            target = target.to(args.device)

            preds = [pred.view(pred.size(0), -1) for pred in model(images)]
            target = target.view(target.size(0), -1)

            if args.deepsupervision:
                # Sum every head exactly once. The previous loop initialised
                # the loss with head 0 and then added head 0 again.
                loss = sum(criterion(pred, target) for pred in preds)
            else:
                loss = criterion(preds[-1], target)

            val_loss.update(loss.item(), 1)
            metric.update(labels=target, preds=preds[-1], n=1)

    return val_loss.avg, metric.get_avg
def train(args, train_queue, val_queue, model, criterion, optimizer_weight,
          optimizer_arch, train_arch):
    """Run one epoch of alternating weight and architecture updates.

    For every training batch the network weights are updated with
    ``optimizer_weight``. When ``train_arch`` is truthy, one batch is then
    drawn from ``val_queue`` and the architecture parameters are updated
    with ``optimizer_arch``.

    Args:
        args: Namespace providing ``device``, ``deepsupervision`` and
            ``grad_clip``.
        train_queue: Iterable yielding ``(input, target, _)`` batches.
        val_queue: Iterable yielding ``(input, target, _)`` batches for the
            architecture step; restarted whenever it is exhausted.
        model: Network returning a list of prediction tensors and exposing
            ``weight_parameters()`` and ``arch_parameters()``.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        optimizer_weight: Optimizer for the network weights.
        optimizer_arch: Optimizer for the architecture parameters.
        train_arch: Whether to perform the architecture step.

    Returns:
        ``(weight_loss_avg, arch_loss_avg, mr, ms, mp, mf, mjc, md, macc)``;
        ``arch_loss_avg`` is 0 when ``train_arch`` is falsy.
    """

    def _supervised_loss(outputs, labels):
        # With deep supervision every head contributes exactly once. The
        # previous loop initialised the loss with head 0 and then added
        # head 0 again, weighting it twice.
        if args.deepsupervision:
            return sum(criterion(out, labels) for out in outputs)
        return criterion(outputs[-1], labels)

    Train_recoder = BinaryIndicatorsMetric()
    w_loss_recoder = AverageMeter()
    a_loss_recoder = AverageMeter()
    # Created lazily on the first architecture step. The old code relied on
    # a bare ``except`` catching the UnboundLocalError of the first access.
    valid_queue_iter = None

    model.train()
    for _step, (images, target, _) in tqdm(enumerate(train_queue)):
        images = images.to(args.device)
        target = target.to(args.device)
        # Original note: input is B,C,H,W; target and preds are B,1,H,W.
        optimizer_weight.zero_grad()
        preds = model(images)
        assert isinstance(preds, list)
        preds = [pred.view(pred.size(0), -1) for pred in preds]
        target = target.view(target.size(0), -1)
        torch.cuda.empty_cache()

        w_loss = _supervised_loss(preds, target)
        w_loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.weight_parameters(),
                                     args.grad_clip)
        optimizer_weight.step()
        w_loss_recoder.update(w_loss.item(), 1)
        Train_recoder.update(labels=target, preds=preds[-1], n=1)

        # Update the network architecture parameters.
        if train_arch:
            # The original DARTS code calls next(iter(valid_queue)) on every
            # step, which is slow on PyTorch 0.4 and above; keep a single
            # iterator and restart it only when exhausted.
            if valid_queue_iter is None:
                valid_queue_iter = iter(val_queue)
            try:
                input_search, target_search, _ = next(valid_queue_iter)
            except StopIteration:
                valid_queue_iter = iter(val_queue)
                input_search, target_search, _ = next(valid_queue_iter)
            input_search = input_search.to(args.device)
            target_search = target_search.to(args.device)

            optimizer_arch.zero_grad()
            archs_preds = model(input_search)
            archs_preds = [pred.view(pred.size(0), -1)
                           for pred in archs_preds]
            target_search = target_search.view(target_search.size(0), -1)
            torch.cuda.empty_cache()

            a_loss = _supervised_loss(archs_preds, target_search)
            a_loss.backward()
            if args.grad_clip:
                nn.utils.clip_grad_norm_(model.arch_parameters(),
                                         args.grad_clip)
            optimizer_arch.step()
            a_loss_recoder.update(a_loss.item(), 1)

    weight_loss_avg = w_loss_recoder.avg
    arch_loss_avg = a_loss_recoder.avg if train_arch else 0
    mr, ms, mp, mf, mjc, md, macc = Train_recoder.get_avg
    return weight_loss_avg, arch_loss_avg, mr, ms, mp, mf, mjc, md, macc
def train(args, train_queue, val_queue, model, criterion, optimizer_weight,
          optimizer_arch, train_arch):
    """Run one epoch of mixup training with optional architecture updates.

    Each batch is passed through ``mixup_data``. The loss is the
    ``lam``-weighted combination of the loss against the returned target
    and the loss against the permuted target. When ``train_arch`` is truthy
    the same scheme is applied to one ``val_queue`` batch to update the
    architecture parameters.

    Args:
        args: Namespace providing ``device``, ``deepsupervision``,
            ``grad_clip``, ``alpha``, ``use_cuda`` and ``compute_freq``.
        train_queue: Iterable yielding ``(input, target, _)`` batches.
        val_queue: Iterable yielding ``(input, target, _)`` batches for the
            architecture step; restarted whenever it is exhausted.
        model: Network returning a list of prediction tensors and exposing
            ``weight_parameters()`` and ``arch_parameters()``.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        optimizer_weight: Optimizer for the network weights.
        optimizer_arch: Optimizer for the architecture parameters.
        train_arch: Whether to perform the architecture step.

    Returns:
        ``(weight_loss_avg, arch_loss_avg, mr, ms, mp, mf, mjc, md, macc)``;
        ``arch_loss_avg`` is 0 when ``train_arch`` is falsy.
    """

    def _supervised_loss(outputs, labels):
        # With deep supervision every head contributes exactly once. The
        # previous loops initialised the loss with head 0 and then added
        # head 0 again, weighting it twice.
        if args.deepsupervision:
            return sum(criterion(out, labels) for out in outputs)
        return criterion(outputs[-1], labels)

    def _mixup_loss(outputs, labels, permuted_labels, weight):
        # Convex combination of the losses against both mixup targets.
        first = _supervised_loss(outputs, labels)
        second = _supervised_loss(outputs, permuted_labels)
        return weight * first + (1 - weight) * second

    Train_recoder = BinaryIndicatorsMetric()
    w_loss_recoder = AverageMeter()
    a_loss_recoder = AverageMeter()
    # Created lazily on the first architecture step. The old code relied on
    # a bare ``except`` catching the UnboundLocalError of the first access.
    valid_queue_iter = None

    model.train()
    for step, (images, target, _) in tqdm(enumerate(train_queue)):
        images = images.to(args.device)
        target = target.to(args.device)
        # Original note: alpha=1 enables mixup, alpha=0 disables it.
        mixup_images, target, perm_target, lam = mixup_data(
            images, target, alpha=args.alpha, use_cuda=args.use_cuda)
        # Original note: input is B,C,H,W; target and preds are B,1,H,W.
        optimizer_weight.zero_grad()
        preds = model(mixup_images)
        assert isinstance(preds, list)
        preds = [pred.view(pred.size(0), -1) for pred in preds]
        target = target.view(target.size(0), -1)
        perm_target = perm_target.view(perm_target.size(0), -1)
        torch.cuda.empty_cache()

        w_loss = _mixup_loss(preds, target, perm_target, lam)
        w_loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.weight_parameters(),
                                     args.grad_clip)
        optimizer_weight.step()
        w_loss_recoder.update(w_loss.item(), 1)

        # Metrics are only sampled every ``compute_freq`` steps.
        if (step + 1) % args.compute_freq == 0:
            Train_recoder.update(labels=target, preds=preds[-1], n=1)

        # Update the network architecture parameters.
        if train_arch:
            # The original DARTS code calls next(iter(valid_queue)) on every
            # step, which is slow on PyTorch 0.4 and above; keep a single
            # iterator and restart it only when exhausted.
            if valid_queue_iter is None:
                valid_queue_iter = iter(val_queue)
            try:
                input_search, target_search, _ = next(valid_queue_iter)
            except StopIteration:
                valid_queue_iter = iter(val_queue)
                input_search, target_search, _ = next(valid_queue_iter)
            input_search = input_search.to(args.device)
            target_search = target_search.to(args.device)
            mixup_input_search, target_search, perm_target_search, alam = \
                mixup_data(input_search, target_search, alpha=args.alpha,
                           use_cuda=args.use_cuda)

            optimizer_arch.zero_grad()
            archs_preds = model(mixup_input_search)
            archs_preds = [pred.view(pred.size(0), -1)
                           for pred in archs_preds]
            target_search = target_search.view(target_search.size(0), -1)
            perm_target_search = perm_target_search.view(
                perm_target_search.size(0), -1)
            torch.cuda.empty_cache()

            a_loss = _mixup_loss(archs_preds, target_search,
                                 perm_target_search, alam)
            a_loss.backward()
            if args.grad_clip:
                nn.utils.clip_grad_norm_(model.arch_parameters(),
                                         args.grad_clip)
            optimizer_arch.step()
            a_loss_recoder.update(a_loss.item(), 1)

    weight_loss_avg = w_loss_recoder.avg
    arch_loss_avg = a_loss_recoder.avg if train_arch else 0
    mr, ms, mp, mf, mjc, md, macc = Train_recoder.get_avg
    return weight_loss_avg, arch_loss_avg, mr, ms, mp, mf, mjc, md, macc
def infer(args, model, criterion, val_loader, logger, path):
    """Save masks and log metrics for the last three heads of ``model``.

    The heads at positions -1, -2 and -3 of the model output are saved with
    the file suffixes ``_8.png``, ``_6.png`` and ``_4.png`` respectively.
    The loss is computed on the last head only, so the same averaged loss
    appears in all three log lines. Nothing is returned.
    """
    # One row per reported head:
    # (position in the model output, file suffix, log format, metric).
    heads = (
        (-1, "_8.png",
         "8:Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}",
         BinaryIndicatorsMetric()),
        (-2, "_6.png",
         "6:Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}",
         BinaryIndicatorsMetric()),
        (-3, "_4.png",
         "4:Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}",
         BinaryIndicatorsMetric()),
    )
    loss_meter = AverageMeter()

    model.eval()
    with torch.no_grad():
        for _step, (images, labels, names) in tqdm(enumerate(val_loader)):
            images = images.to(args.device)
            labels = labels.to(args.device)
            outputs = model(images)

            # Write the binary masks head by head, in the order 8, 6, 4.
            for position, suffix, _fmt, _metric in heads:
                probabilities = torch.sigmoid(
                    outputs[position].clone()).data.cpu().numpy()
                batch_size, _c, _h, _w = probabilities.shape
                for idx in range(batch_size):
                    sample_id = int(names[idx].split('.')[0])
                    binary = np.where(
                        probabilities[idx][0] > 0.5, 255, 0).astype(np.uint8)
                    Image.fromarray(binary).save(
                        os.path.join(path, str(sample_id) + suffix))

            flat_outputs = [out.view(out.size(0), -1) for out in outputs]
            flat_labels = labels.view(labels.size(0), -1)
            batch_loss = criterion(flat_outputs[-1], flat_labels)
            loss_meter.update(batch_loss.item(), 1)
            for position, _suffix, _fmt, metric in heads:
                metric.update(labels=flat_labels,
                              preds=flat_outputs[position], n=1)

    for _position, _suffix, fmt, metric in heads:
        _rec, _spe, _pre, _f1, jaccard, dice, accuracy = metric.get_avg
        logger.info(fmt.format(loss_meter.avg, accuracy, dice, jaccard))
def inference_isic(model1, model2, img_dir, mask_dir):
    """Score ``model1`` on the ISIC images in ``img_dir`` and print metrics.

    Every ``*.jpg`` entry of ``img_dir`` is reduced to the id after its last
    underscore. The image is then read from ``<img_dir>/ISIC_<id>.jpg`` and
    the mask from ``<mask_dir>/ISIC_<id>_segmentation.png``, in ascending
    numeric id order.

    Args:
        model1: Network whose last output (``model1(img)[-1]``) is scored.
        model2: Second network. It is only switched to eval mode; its
            forward pass fed nothing but disabled visualisation code and is
            no longer executed.
        img_dir: Directory containing the ``.jpg`` images.
        mask_dir: Directory containing the segmentation masks.

    Returns:
        None. Prints one line per file and a final Acc / Dice / Jc summary.
    """
    metric = BinaryIndicatorsMetric()
    model1.eval()
    model2.eval()

    # Filter to .jpg *before* sorting: the numeric sort key used to be
    # applied to every directory entry, so any non-image file made it raise.
    img_ids = sorted(
        (filename.split('_')[-1][:-len('.jpg')]
         for filename in os.listdir(img_dir)
         if os.path.splitext(filename)[-1] == '.jpg'),
        key=int)
    data_list = [os.path.join(img_dir, 'ISIC_' + img_id + '.jpg')
                 for img_id in img_ids]
    gt_list = [os.path.join(mask_dir, 'ISIC_' + img_id + '_segmentation.png')
               for img_id in img_ids]

    # Hand-curated case ids kept for reference only; nothing below reads
    # them. A missing comma between '49' and '56' (which silently produced
    # the single entry '4956') has been fixed.
    # NOTE(review): 155417 and 156008 are longer than their neighbours and
    # may be typos - confirm against the dataset.
    hard_filenames = [
        '18', '24', '26', '31', '42', '49', '56', '62', '73', '81', '91',
        '97', '113', '153', '184', '16071', '246', '288', '311', '319',
        '324', '358', '387', '393', '395', '499', '504', '520', '529', 531,
        547, 549, 1140, 1148, 1152, 1184, 1442, 2829, 3346, 4115, 5555,
        6612, 6914, 7557, 8913, 9873, 9875, 9934, 10093, 11107, 11110,
        11168, 11349, 12090, 12136, 12149, 12167, 12187, 12212, 12216,
        12290, 12329, 12512, 12516, 12713, 12773, 12876, 12999, 13000,
        13010, 13063, 13120, 13164, 13227, 13242, 13393, 13493, 13516,
        13518, 13549, 13709, 13813, 13832, 13988, 14132, 14189, 14221,
        14639, 14693, 14912, 15102, 15176, 15237, 15330, 155417, 15443,
        16068
    ]
    better_filenames = [
        '16', '63', '75', '101', '105', '131', '148', '164', '184', '198',
        '252', '276', '330', '397', '433', '458', '476', '480', 1119, 1212,
        1262, 1306, 1374, 3346, 6671, 9504, 9895, 9992, 10041, 10044, 10175,
        10183, 10213, 10382, 10452, 10456, 11079, 11130, 11159, 12318,
        12495, 12897, 12961, 13146, 13340, 13371, 13411, 13807, 13910,
        13918, 14090, 14693, 14697, 14850, 14898, 14904, 15062, 15166,
        15207, 15483, 15563,
    ]
    easy_filenames = [
        '34', '39', '52', '57', '117', '164', '165', '182', '207', '213',
        '222', '225', '232'
    ]
    dataset_wrong_case = [
        9800, 9934, 9951, 10021, 10361, 10584, 11227, 13310, 13600, 13673,
        13680, 15132, 15152, 15251, 16036,
    ]
    all_bad_case = [
        10320, 10361, 10445, 10457, 10477, 11081, 11084, 11121, 12369,
        12484, 12726, 12740, 12768, 12786, 12789, 12876, 12877, 13120,
        13310, 13393, 13552, 13832, 13975, 14222, 14328, 14372, 14385,
        14434, 14454, 14480, 14503, 14506, 14580, 14628, 14786, 14931,
        14932, 14963, 14982, 14985, 15020, 15021, 15062, 15309, 15537,
        15947, 15966, 15969, 15983, 156008, 16034, 16037, 16058, 16068,
    ]

    # Pure inference: disable autograd so no graph is kept per image.
    with torch.no_grad():
        for file_name, image_path, mask_path in zip(img_ids, data_list,
                                                    gt_list):
            print("Filename:{}".format(file_name))
            img, flipimg, original_img, mask = isic_transform(
                image_path, mask_path)

            output = model1(img)
            nas_output = output[-1].clone()
            nas_output = nas_output.view(nas_output.size(0), -1)

            # Turn the mask into a (1, H*W)-style row to match the
            # flattened prediction.
            target = torch.from_numpy(np.asarray(mask))
            target = target.unsqueeze(0).unsqueeze(0)
            target = target.view(target.size(0), -1)

            metric.update(labels=target, preds=nas_output, n=1)

    mr, ms, mp, mf, mjc, md, macc = metric.get_avg
    print("Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(macc, md, mjc))
def inference_isic(models_list, img_dir, mask_dir):
    """Run every model on each ``.tif`` image and print metrics of the first.

    Image and mask share the same file name and are read from ``img_dir``
    and ``mask_dir`` respectively. All models are evaluated on every image,
    but only the output of ``models_list[0]`` is scored.

    Args:
        models_list: Sequence of networks invoked as ``model(img)``. The
            first one must return a tensor.
        img_dir: Directory containing only ``.tif`` images.
        mask_dir: Directory containing masks with the same file names.

    Returns:
        None. Prints one line per file and a final Acc / Dice / Jc summary.

    Raises:
        AssertionError: If ``img_dir`` contains an entry whose extension is
            not ``.tif``.
    """
    metric = BinaryIndicatorsMetric()
    for model in models_list:
        model.eval()

    filenames = os.listdir(img_dir)
    for filename in filenames:
        assert os.path.splitext(filename)[-1] == '.tif'
    data_list = [os.path.join(img_dir, filename) for filename in filenames]
    gt_list = [os.path.join(mask_dir, filename) for filename in filenames]

    # Pure inference: disable autograd so no graph is kept per image.
    with torch.no_grad():
        for filename, image_path, mask_path in zip(filenames, data_list,
                                                   gt_list):
            file_name = filename.split('.')[0]
            print("Filename:{}".format(file_name))
            img, original_img, mask = isic_transform(image_path, mask_path)
            outputs_original = [model(img) for model in models_list]

            # Turn the mask into a (1, H*W)-style row to match the
            # flattened prediction.
            target = torch.from_numpy(np.asarray(mask))
            target = target.unsqueeze(0).unsqueeze(0)
            target = target.view(target.size(0), -1)

            first_output = outputs_original[0]
            metric.update(
                labels=target,
                preds=first_output.view(first_output.size(0), -1),
                n=1)

    mr, ms, mp, mf, mjc, md, macc = metric.get_avg
    print("Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(macc, md, mjc))