Code example #1
def infer(args, model, criterion, val_loader, logger, path):
    OtherVal = BinaryIndicatorsMetric()
    val_loss = AverageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, target, name) in tqdm(enumerate(val_loader)):
            input = input.to(args.device)
            target = target.to(args.device)
            pred = model(input)
            if args.deepsupervision:
                pred = pred[-1].clone()
            # save the predicted mask
            file_masks = pred.clone()
            file_masks = torch.sigmoid(file_masks).data.cpu().numpy()
            n, c, h, w = file_masks.shape
            assert n == len(file_masks)
            for i in range(len(file_masks)):
                file_index = int(name[i].split('.')[0])
                file_mask = (file_masks[i][0] > 0.5).astype(np.uint8)
                file_mask[file_mask >= 1] = 255
                file_mask = Image.fromarray(file_mask)
                file_mask.save(os.path.join(path, str(file_index) + ".png"))

            # compute loss
            pred = pred.view(pred.size(0), -1)
            target = target.view(target.size(0), -1)
            v_loss = criterion(pred, target)
            val_loss.update(v_loss.item(), 1)
            OtherVal.update(labels=target, preds=pred, n=1)
        vmr, vms, vmp, vmf, vmjc, vmd, vmacc = OtherVal.get_avg
        logger.info("Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}".format(val_loss.avg, vmacc, vmd, vmjc))
Code example #2
File: retrain_chao.py  Project: lswzjuer/NAS-WDAN
def val(args, model, criterion, val_loader, epoch, logger):
    OtherVal = BinaryIndicatorsMetric()
    tloss_r = AverageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, target, _) in enumerate(val_loader):
            # input: n,1,h,w   target: n,h,w
            input = input.to(args.device)
            target = target.to(args.device)
            preds = model(input)
            if args.deepsupervision:
                assert isinstance(preds, (list, tuple))
                tloss = 0
                for index in range(len(preds)):
                    subtloss = criterion(
                        preds[index].view(preds[index].size(0), -1),
                        target.view(target.size(0), -1))
                    tloss += subtloss
                tloss = tloss * preds[-1].size(0)
                tloss /= len(preds)
            else:
                tloss = criterion(preds[-1].view(preds[-1].size(0), -1),
                                  target.view(target.size(0), -1))
                tloss = tloss * preds[-1].size(0)  # preds is a list of head outputs
            tloss_r.update(tloss.item(), 1)
            OtherVal.update(labels=target.view(target.size(0), -1),
                            preds=preds[-1].view(preds[-1].size(0), -1))

    logger.info("Epoch:{} Val   T-loss:{:.5f} ".format(epoch, tloss_r.avg))
    mr, ms, mp, mf, mjc, md, macc = OtherVal.get_avg
    logger.info(" Val     dice:{:.3f} iou:{:.3f} acc:{:.3f}  ".format(
        md, mjc, macc))
    return tloss_r.avg, md
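
BinaryIndicatorsMetric is likewise external to these excerpts. The unpacking mr, ms, mp, mf, mjc, md, macc = OtherVal.get_avg, read against the log line in code example #8, suggests get_avg yields recall, specificity, precision, F1, Jaccard (IoU), Dice, and accuracy in that order. A plausible sketch that pools confusion counts, assuming preds are logits thresholded at 0.5 after a sigmoid (the real class may instead average per batch):

import torch

class BinaryIndicatorsMetric:
    """Pools binary confusion-matrix counts; `get_avg` is a property (sketch)."""

    def __init__(self):
        self.tp = self.fp = self.tn = self.fn = 0

    def update(self, labels, preds, n=1):
        # preds: logits of shape (B, H*W); labels: {0,1} tensor of the same shape
        hard = (torch.sigmoid(preds) > 0.5).float()
        labels = labels.float()
        self.tp += ((hard == 1) & (labels == 1)).sum().item()
        self.fp += ((hard == 1) & (labels == 0)).sum().item()
        self.tn += ((hard == 0) & (labels == 0)).sum().item()
        self.fn += ((hard == 0) & (labels == 1)).sum().item()

    @property
    def get_avg(self):
        eps = 1e-8
        r = self.tp / (self.tp + self.fn + eps)             # recall
        s = self.tn / (self.tn + self.fp + eps)             # specificity
        p = self.tp / (self.tp + self.fp + eps)             # precision
        f = 2 * p * r / (p + r + eps)                       # F1
        jc = self.tp / (self.tp + self.fp + self.fn + eps)  # Jaccard / IoU
        d = 2 * self.tp / (2 * self.tp + self.fp + self.fn + eps)  # Dice
        acc = (self.tp + self.tn) / (self.tp + self.fp + self.tn + self.fn + eps)
        return r, s, p, f, jc, d, acc
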
Code example #3
def infer(args, model, criterion, val_loader):
    OtherVal = BinaryIndicatorsMetric()
    DeepOtherVal = BinaryIndicatorsMetric()
    val_loss = AverageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, target, _) in tqdm(enumerate(val_loader)):
            input = input.to(args.device)
            target = target.to(args.device)
            preds_list = model(input)
            preds_list = [pred.view(pred.size(0), -1) for pred in preds_list]
            target = target.view(target.size(0), -1)
            if args.deepsupervision:
                for i in range(len(preds_list)):
                    if i == 0:
                        v_loss = criterion(preds_list[i], target)
                    else:
                        v_loss += criterion(preds_list[i], target)
            else:
                v_loss = criterion(preds_list[-1], target)
            val_loss.update(v_loss.item(), 1)
            # get the deepsupervision
            if args.deepsupervision:
                avg_pred = 0
                for pred in preds_list:
                    avg_pred += pred
                avg_pred /= len(preds_list)
                DeepOtherVal.update(labels=target, preds=avg_pred, n=1)
            OtherVal.update(labels=target, preds=preds_list[-1], n=1)
        if args.deepsupervision:
            return val_loss.avg, OtherVal.get_avg, DeepOtherVal.get_avg
        else:
            return val_loss.avg, OtherVal.get_avg
Code example #4
def infer(args, model, criterion, val_loader, logger, path):
    OtherVal_v1 = RefugeIndicatorsMetricBinary()
    OtherVal_v2 = BinaryIndicatorsMetric()
    tloss_r = AverageMeter()
    model.eval()

    with torch.no_grad():
        for step, (input, target, name) in tqdm(enumerate(val_loader)):
            # input: n,1,h,w   target: n,1,h,w   name: list of length n
            input = input.to(args.device)
            target = target.to(args.device)
            pred = model(input)
            if args.deepsupervision:
                pred = pred[-1].clone()

            # save the mask
            file_masks = pred.clone()
            file_masks = torch.sigmoid(file_masks).data.cpu().numpy()
            n, c, h, w = file_masks.shape
            assert n == len(file_masks)
            for i in range(len(file_masks)):
                subdir = name[0][i]
                file_index = name[1][i]
                if not os.path.exists(os.path.join(path, subdir)):
                    os.mkdir(os.path.join(path, subdir))
                file_mask = (file_masks[i][0] > 0.5).astype(np.uint8)
                file_mask[file_mask >= 1] = 255
                file_mask = Image.fromarray(file_mask)
                file_mask.save(os.path.join(path, subdir, file_index + ".png"))

            # batchsize=8
            OtherVal_v1.update(pred, target)
            target = target.view(target.size(0), -1)
            v_loss = criterion(pred.view(pred.size(0), -1), target)
            v_loss = v_loss * n
            tloss_r.update(v_loss.item(), 1)
            OtherVal_v2.update(labels=target,
                               preds=pred.view(pred.size(0), -1))

        per1 = OtherVal_v1.avg()
        mr, ms, mp, mf, mjc, md, macc = OtherVal_v2.get_avg
        logger.info(" V1     dice:{:.3f} iou:{:.3f} acc:{:.3f}  ".format(
            per1[0], per1[1], per1[2]))
        logger.info(" V2     dice:{:.3f} iou:{:.3f} acc:{:.3f}  ".format(
            md, mjc, macc))
        return tloss_r.avg, per1
Code example #5
File: l1filter_pruning.py  Project: lswzjuer/NAS-WDAN
def infer(args, model, criterion, val_loader, logger):
    OtherVal = BinaryIndicatorsMetric()
    val_loss = AverageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, target, name) in tqdm(enumerate(val_loader)):
            input = input.to(args.device)
            target = target.to(args.device)
            pred = model(input)
            if args.deepsupervision:
                pred = pred[-1].clone()
            pred = pred.view(pred.size(0), -1)
            target = target.view(target.size(0), -1)
            v_loss = criterion(pred, target)
            val_loss.update(v_loss.item(), 1)
            OtherVal.update(labels=target, preds=pred, n=1)
        vmr, vms, vmp, vmf, vmjc, vmd, vmacc = OtherVal.get_avg
        logger.info("Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}".format(val_loss.avg, vmacc, vmd, vmjc))
        return vmr, vms, vmp, vmf, vmjc, vmd, vmacc, val_loss.avg
Code example #6
def infer(args, model, criterion, val_loader, logger, path):
    OtherVal = BinaryIndicatorsMetric()
    val_loss = AverageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, target, name) in tqdm(enumerate(val_loader)):
            input = input.to(args.device)
            target = target.to(args.device)
            preds_list = model(input)

            # save images  n,c,h,w
            file_masks = preds_list[-1].clone()
            file_masks = torch.sigmoid(file_masks).data.cpu().numpy()
            n, c, h, w = file_masks.shape
            assert n == len(file_masks)
            for i in range(len(file_masks)):
                file_name = 'ISIC_' + name[i] + '_segmentation.png'
                file_mask = (file_masks[i][0] > 0.5).astype(np.uint8)
                file_mask[file_mask >= 1] = 255
                file_mask = Image.fromarray(file_mask)
                file_mask.save(os.path.join(path, file_name))

            preds_list = [pred.view(pred.size(0), -1) for pred in preds_list]
            target = target.view(target.size(0), -1)
            v_loss = 0
            if args.deepsupervision:
                for pred in preds_list:
                    subloss = criterion(pred, target)
                    v_loss += subloss
            else:
                v_loss = criterion(preds_list[-1], target)
            val_loss.update(v_loss.item(), 1)
            OtherVal.update(labels=target, preds=preds_list[-1], n=1)

            # only the first three batches are evaluated (debug shortcut)
            if step > 1:
                break

        vmr, vms, vmp, vmf, vmjc, vmd, vmacc = OtherVal.get_avg
        logger.info("Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}".format(
            val_loss.avg, vmacc, vmd, vmjc))
Code example #7
def infer(args, model, val_queue, criterion):
    OtherVal = BinaryIndicatorsMetric()
    val_loss = AverageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, target, _) in tqdm(enumerate(val_queue)):
            input = input.to(args.device)
            target = target.to(args.device)
            preds = model(input)
            preds = [pred.view(pred.size(0), -1) for pred in preds]
            target = target.view(target.size(0), -1)
            if args.deepsupervision:
                for i in range(len(preds)):
                    if i == 0:
                        loss = criterion(preds[i], target)
                    else:
                        loss += criterion(preds[i], target)
            else:
                loss = criterion(preds[-1], target)
            val_loss.update(loss.item(), 1)
            OtherVal.update(labels=target, preds=preds[-1], n=1)
        return val_loss.avg, OtherVal.get_avg
Code example #8
def val(args, model, criterion, val_loader, epoch, logger):
    OtherVal = BinaryIndicatorsMetric()
    val_loss = AverageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, target, _) in enumerate(val_loader):
            input = input.to(args.device)
            target = target.to(args.device)
            preds = model(input)
            if args.deepsupervision:
                assert isinstance(preds, (list, tuple))
                preds = [pred.view(pred.size(0), -1) for pred in preds]
                target = target.view(target.size(0), -1)
                # mean loss over all deep-supervision heads, each counted once
                for index in range(len(preds)):
                    if index == 0:
                        loss = criterion(preds[index], target)
                    else:
                        loss += criterion(preds[index], target)
                loss /= len(preds)
            else:
                preds = preds.view(preds.size(0), -1)
                target = target.view(target.size(0), -1)
                loss = criterion(preds, target)
            val_loss.update(loss.item(), 1)
            if args.deepsupervision:
                OtherVal.update(labels=target, preds=preds[-1], n=1)
            else:
                OtherVal.update(labels=target, preds=preds, n=1)

    # update the best metrics and record early-stopping state
    mr, ms, mp, mf, mjc, md, macc = OtherVal.get_avg
    mean_loss = val_loss.avg
    logger.info("Epoch:{}  Val Loss:{:.3f} ".format(epoch, mean_loss))
    logger.info("Acc:{:.3f} Rec:{:.3f} Spe:{:.3f} Pre:{:.3f} F1:{:.3f} Jc:{:.3f} Dice:{:.3f}".format(macc, mr, ms, mp, mf, mjc, md))
    return mr, ms, mp, mf, mjc, md, macc, mean_loss
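
Because val returns the full metric tuple plus the mean loss, a caller can checkpoint on Dice and stop early. A hypothetical driver loop; train_one_epoch, args.epochs, and max_patience are illustrative stand-ins, not names from the project:

# hypothetical outer loop; not part of the original source
best_dice, patience, max_patience = 0.0, 0, 20
for epoch in range(args.epochs):
    train_one_epoch(args, model, criterion, train_loader, optimizer, epoch, logger)
    mr, ms, mp, mf, mjc, md, macc, mean_loss = val(
        args, model, criterion, val_loader, epoch, logger)
    if md > best_dice:
        best_dice, patience = md, 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience += 1
        if patience >= max_patience:
            logger.info('Early stopping at epoch {}'.format(epoch))
            break
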
Code example #9
File: l1filter_pruning.py  Project: lswzjuer/NAS-WDAN
def train(args, model, criterion, train_loader, optimizer, epoch, logger):
    OtherTrain = BinaryIndicatorsMetric()
    model.train()
    train_loss = AverageMeter()
    for step, (input, target, _) in enumerate(train_loader):
        input = input.to(args.device)
        target = target.to(args.device)
        # input is B C H W   target is B,1,H,W  preds: B,1,H,W
        preds = model(input)
        if args.deepsupervision:
            assert isinstance(preds, (list, tuple))
            preds = [pred.view(pred.size(0), -1) for pred in preds]
            target = target.view(target.size(0), -1)
            # mean loss over all deep-supervision heads, each counted once
            for index in range(len(preds)):
                if index == 0:
                    loss = criterion(preds[index], target)
                else:
                    loss += criterion(preds[index], target)
            loss /= len(preds)
        else:
            preds = preds.view(preds.size(0), -1)
            target = target.view(target.size(0), -1)
            loss = criterion(preds, target)
        train_loss.update(loss.item(), 1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # get all the indicators
        if args.deepsupervision:
            OtherTrain.update(labels=target, preds=preds[-1], n=1)
        else:
            OtherTrain.update(labels=target, preds=preds, n=1)
    logger.info("Epoch:{} Train Loss:{:.3f}".format(epoch, train_loss.avg))
Code example #10
File: retrain_isic.py  Project: lswzjuer/NAS-WDAN
def train(args, model_alphas, model, criterion, train_loader, optimizer):
    Train_recoder = BinaryIndicatorsMetric()
    Deep_Train_recoder = BinaryIndicatorsMetric()
    loss_recoder = AverageMeter()
    model.train()
    for step, (input, target, _) in tqdm(enumerate(train_loader)):
        input = input.to(args.device)
        target = target.to(args.device)
        # input is B C H W   target is B,1,H,W  preds: B,1,H,W
        optimizer.zero_grad()
        # [output1,...]
        if model_alphas is not None:
            preds_list = model(model_alphas, input)
        else:
            preds_list = model(input)
        preds_list = [pred.view(pred.size(0), -1) for pred in preds_list]
        target = target.view(target.size(0), -1)
        if args.deepsupervision:
            for i in range(len(preds_list)):
                if i == 0:
                    w_loss = criterion(preds_list[i], target)
                else:
                    w_loss += criterion(preds_list[i], target)
        else:
            w_loss = criterion(preds_list[-1], target)
        w_loss.backward()
        optimizer.step()
        loss_recoder.update(w_loss.item(), 1)
        # get all the indicators
        # deep-supervision recorder scores the average of the head logits
        if args.deepsupervision:
            avg_pred = 0
            for pred in preds_list:
                avg_pred += pred
            avg_pred /= len(preds_list)
            Deep_Train_recoder.update(labels=target, preds=avg_pred, n=1)
        Train_recoder.update(labels=target, preds=preds_list[-1], n=1)
    if args.deepsupervision:
        return loss_recoder.avg, Train_recoder.get_avg, Deep_Train_recoder.get_avg
    else:
        return loss_recoder.avg, Train_recoder.get_avg
Code example #11
def train(args, train_queue, val_queue, model, criterion, optimizer_weight,
          optimizer_arch, train_arch):
    Train_recoder = BinaryIndicatorsMetric()
    w_loss_recoder = AverageMeter()
    a_loss_recoder = AverageMeter()
    model.train()
    for step, (input, target, _) in tqdm(enumerate(train_queue)):
        input = input.to(args.device)
        target = target.to(args.device)
        # input is B C H W   target is B,1,H,W  preds: B,1,H,W
        optimizer_weight.zero_grad()
        preds = model(input)
        assert isinstance(preds, list)
        preds = [pred.view(pred.size(0), -1) for pred in preds]
        target = target.view(target.size(0), -1)
        torch.cuda.empty_cache()
        if args.deepsupervision:
            for i in range(len(preds)):
                if i == 0:
                    w_loss = criterion(preds[i], target)
                else:
                    w_loss += criterion(preds[i], target)
        else:
            w_loss = criterion(preds[-1], target)
        w_loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.weight_parameters(), args.grad_clip)
        optimizer_weight.step()
        w_loss_recoder.update(w_loss.item(), 1)
        # get all the indicators
        Train_recoder.update(labels=target, preds=preds[-1], n=1)
        # update network arch parameters
        if train_arch:
            # The original DARTS implementation uses
            # input_search, target_search = next(iter(valid_queue)), which
            # slows down training on PyTorch 0.4 and above.
            try:
                input_search, target_search, _ = next(valid_queue_iter)
            except (NameError, StopIteration):
                # first pass (iterator not created yet) or queue exhausted
                valid_queue_iter = iter(val_queue)
                input_search, target_search, _ = next(valid_queue_iter)
            input_search = input_search.to(args.device)
            target_search = target_search.to(args.device)
            optimizer_arch.zero_grad()
            archs_preds = model(input_search)
            archs_preds = [pred.view(pred.size(0), -1) for pred in archs_preds]
            target_search = target_search.view(target_search.size(0), -1)
            torch.cuda.empty_cache()
            if args.deepsupervision:
                for i in range(len(archs_preds)):
                    if i == 0:
                        a_loss = criterion(archs_preds[i], target_search)
                    else:
                        a_loss += criterion(archs_preds[i], target_search)
            else:
                a_loss = criterion(archs_preds[-1], target_search)
            a_loss.backward()
            if args.grad_clip:
                nn.utils.clip_grad_norm_(model.arch_parameters(),
                                         args.grad_clip)
            optimizer_arch.step()
            a_loss_recoder.update(a_loss.item(), 1)

    weight_loss_avg = w_loss_recoder.avg
    if train_arch:
        arch_loss_avg = a_loss_recoder.avg
    else:
        arch_loss_avg = 0
    mr, ms, mp, mf, mjc, md, macc = Train_recoder.get_avg
    return weight_loss_avg, arch_loss_avg, mr, ms, mp, mf, mjc, md, macc
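
This is the first-order DARTS alternation: one optimizer step on the network weights with a training batch, then one step on the architecture parameters with a validation batch. The function assumes two optimizers built over the disjoint groups exposed by model.weight_parameters() and model.arch_parameters(); a sketch using the hyperparameters from the DARTS paper (the project's actual values may differ):

import torch

optimizer_weight = torch.optim.SGD(
    model.weight_parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)
optimizer_arch = torch.optim.Adam(
    model.arch_parameters(), lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)
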
Code example #12
def train(args, train_queue, val_queue, model, criterion, optimizer_weight,
          optimizer_arch, train_arch):
    Train_recoder = BinaryIndicatorsMetric()
    w_loss_recoder = AverageMeter()
    a_loss_recoder = AverageMeter()
    model.train()
    for step, (input, target, _) in tqdm(enumerate(train_queue)):
        input = input.to(args.device)
        target = target.to(args.device)
        # alpha=1 enables mixup; alpha=0 disables it
        mixup_images, target, perm_target, lam = mixup_data(
            input, target, alpha=args.alpha, use_cuda=args.use_cuda)
        # input is B C H W   target is B,1,H,W  preds: B,1,H,W
        optimizer_weight.zero_grad()
        preds = model(mixup_images)
        assert isinstance(preds, list)
        preds = [pred.view(pred.size(0), -1) for pred in preds]
        target = target.view(target.size(0), -1)
        perm_target = perm_target.view(perm_target.size(0), -1)
        torch.cuda.empty_cache()

        if args.deepsupervision:
            # sum each loss over all heads, counting each head once
            for i in range(len(preds)):
                if i == 0:
                    target1_loss = criterion(preds[i], target)
                    target2_loss = criterion(preds[i], perm_target)
                else:
                    target1_loss += criterion(preds[i], target)
                    target2_loss += criterion(preds[i], perm_target)
        else:
            target1_loss = criterion(preds[-1], target)
            target2_loss = criterion(preds[-1], perm_target)
        w_loss = lam * target1_loss + (1 - lam) * target2_loss
        w_loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.weight_parameters(), args.grad_clip)
        optimizer_weight.step()
        w_loss_recoder.update(w_loss.item(), 1)
        # get all the indicators
        if (step + 1) % args.compute_freq == 0:
            Train_recoder.update(labels=target, preds=preds[-1], n=1)

        # update network arch parameters
        if train_arch:
            # The original DARTS implementation uses
            # input_search, target_search = next(iter(valid_queue)), which
            # slows down training on PyTorch 0.4 and above.
            try:
                input_search, target_search, _ = next(valid_queue_iter)
            except (NameError, StopIteration):
                # first pass (iterator not created yet) or queue exhausted
                valid_queue_iter = iter(val_queue)
                input_search, target_search, _ = next(valid_queue_iter)
            input_search = input_search.to(args.device)
            target_search = target_search.to(args.device)
            mixup_input_search, target_search, perm_target_search, alam = mixup_data(
                input_search,
                target_search,
                alpha=args.alpha,
                use_cuda=args.use_cuda)
            optimizer_arch.zero_grad()

            archs_preds = model(mixup_input_search)
            archs_preds = [pred.view(pred.size(0), -1) for pred in archs_preds]
            target_search = target_search.view(target_search.size(0), -1)
            perm_target_search = perm_target_search.view(
                perm_target_search.size(0), -1)
            torch.cuda.empty_cache()

            if args.deepsupervision:
                # sum each loss over all heads, counting each head once
                for i in range(len(archs_preds)):
                    if i == 0:
                        a_loss1 = criterion(archs_preds[i], target_search)
                        a_loss2 = criterion(archs_preds[i], perm_target_search)
                    else:
                        a_loss1 += criterion(archs_preds[i], target_search)
                        a_loss2 += criterion(archs_preds[i], perm_target_search)
            else:
                a_loss1 = criterion(archs_preds[-1], target_search)
                a_loss2 = criterion(archs_preds[-1], perm_target_search)
            a_loss = alam * a_loss1 + (1 - alam) * a_loss2
            a_loss.backward()
            if args.grad_clip:
                nn.utils.clip_grad_norm_(model.arch_parameters(),
                                         args.grad_clip)
            optimizer_arch.step()
            a_loss_recoder.update(a_loss.item(), 1)

    weight_loss_avg = w_loss_recoder.avg
    if train_arch:
        arch_loss_avg = a_loss_recoder.avg
    else:
        arch_loss_avg = 0
    mr, ms, mp, mf, mjc, md, macc = Train_recoder.get_avg
    return weight_loss_avg, arch_loss_avg, mr, ms, mp, mf, mjc, md, macc
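
mixup_data is not shown in these excerpts. The call site expects it to return the mixed images, the original targets, the permuted targets, and the mixing coefficient, which matches the standard mixup recipe of Zhang et al.; a minimal sketch under that assumption:

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Convexly combine a batch with a shuffled copy of itself (sketch).

    `use_cuda` is kept only for signature compatibility; the device is
    taken from `x`. Returns (mixed_x, y, permuted_y, lam).
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam
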
Code example #13
def infer(args, model, criterion, val_loader, logger, path):
    OtherVal8 = BinaryIndicatorsMetric()
    OtherVal6 = BinaryIndicatorsMetric()
    OtherVal4 = BinaryIndicatorsMetric()
    val_loss = AverageMeter()
    model.eval()
    with torch.no_grad():
        for step, (input, target, name) in tqdm(enumerate(val_loader)):
            # input: n,1,h,w   target: n,1,h,w   name: list of length n
            input = input.to(args.device)
            target = target.to(args.device)
            preds_list = model(input)

            # save the mask
            file_masks = preds_list[-1].clone()
            file_masks = torch.sigmoid(file_masks).data.cpu().numpy()
            n, c, h, w = file_masks.shape
            assert n == len(file_masks)
            for i in range(len(file_masks)):
                file_index = int(name[i].split('.')[0])
                file_mask = (file_masks[i][0] > 0.5).astype(np.uint8)
                file_mask[file_mask >= 1] = 255
                file_mask = Image.fromarray(file_mask)
                file_mask.save(os.path.join(path, str(file_index) + "_8.png"))

            file_masks = preds_list[-2].clone()
            file_masks = torch.sigmoid(file_masks).data.cpu().numpy()
            n, c, h, w = file_masks.shape
            assert n == len(file_masks)
            for i in range(len(file_masks)):
                file_index = int(name[i].split('.')[0])
                file_mask = (file_masks[i][0] > 0.5).astype(np.uint8)
                file_mask[file_mask >= 1] = 255
                file_mask = Image.fromarray(file_mask)
                file_mask.save(os.path.join(path, str(file_index) + "_6.png"))

            file_masks = preds_list[-3].clone()
            file_masks = torch.sigmoid(file_masks).data.cpu().numpy()
            n, c, h, w = file_masks.shape
            assert n == len(file_masks)
            for i in range(len(file_masks)):
                file_index = int(name[i].split('.')[0])
                file_mask = (file_masks[i][0] > 0.5).astype(np.uint8)
                file_mask[file_mask >= 1] = 255
                file_mask = Image.fromarray(file_mask)
                file_mask.save(os.path.join(path, str(file_index) + "_4.png"))
            # batchsize=8
            preds_list = [pred.view(pred.size(0), -1) for pred in preds_list]
            target = target.view(target.size(0), -1)
            v_loss = criterion(preds_list[-1], target)
            val_loss.update(v_loss.item(), 1)
            OtherVal8.update(labels=target, preds=preds_list[-1], n=1)
            OtherVal6.update(labels=target, preds=preds_list[-2], n=1)
            OtherVal4.update(labels=target, preds=preds_list[-3], n=1)

        vmr, vms, vmp, vmf, vmjc, vmd, vmacc = OtherVal8.get_avg
        logger.info(
            "8:Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}".format(
                val_loss.avg, vmacc, vmd, vmjc))
        vmr, vms, vmp, vmf, vmjc, vmd, vmacc = OtherVal6.get_avg
        logger.info(
            "6:Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}".format(
                val_loss.avg, vmacc, vmd, vmjc))
        vmr, vms, vmp, vmf, vmjc, vmd, vmacc = OtherVal4.get_avg
        logger.info(
            "4:Val_Loss:{:.5f} Acc:{:.5f} Dice:{:.5f} Jc:{:.5f}".format(
                val_loss.avg, vmacc, vmd, vmjc))
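
The three mask-saving blocks above differ only in which head they read and in the filename suffix, so they could be collapsed into a single loop. A behavior-equivalent sketch:

# equivalent to the three inline save blocks in code example #13
for head_offset, suffix in ((-1, '_8'), (-2, '_6'), (-3, '_4')):
    masks = torch.sigmoid(preds_list[head_offset]).data.cpu().numpy()
    for i in range(masks.shape[0]):
        file_index = int(name[i].split('.')[0])
        mask = (masks[i][0] > 0.5).astype(np.uint8)
        mask[mask >= 1] = 255
        Image.fromarray(mask).save(
            os.path.join(path, str(file_index) + suffix + '.png'))
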
Code example #14
def inference_isic(model1, model2, img_dir, mask_dir):
    OtherVal = BinaryIndicatorsMetric()
    model1.eval()
    model2.eval()
    filenames = os.listdir(img_dir)
    data_list = []
    gt_list = []
    img_ids = []
    filenames = sorted(filenames,
                       key=lambda x: int(x.split('_')[-1][:-len('.jpg')]))
    for filename in filenames:
        ext = os.path.splitext(filename)[-1]
        if ext == '.jpg':
            filename = filename.split('_')[-1][:-len('.jpg')]
            img_ids.append(filename)
            data_list.append('ISIC_' + filename + '.jpg')
            gt_list.append('ISIC_' + filename + '_segmentation.png')

    assert (len(data_list) == len(gt_list))
    data_list = [os.path.join(img_dir, i) for i in data_list]
    gt_list = [os.path.join(mask_dir, i) for i in gt_list]

    hard_filenames = [
        '18', '24', '26', '31', '42', '49',
        '56', '62', '73', '81', '91', '97', '113', '153', '184', '16071',
        '246', '288', '311', '319', '324', '358', '387', '393', '395', '499',
        '504', '520', '529', 531, 547, 549, 1140, 1148, 1152, 1184, 1442, 2829,
        3346, 4115, 5555, 6612, 6914, 7557, 8913, 9873, 9875, 9934, 10093,
        11107, 11110, 11168, 11349, 12090, 12136, 12149, 12167, 12187, 12212,
        12216, 12290, 12329, 12512, 12516, 12713, 12773, 12876, 12999, 13000,
        13010, 13063, 13120, 13164, 13227, 13242, 13393, 13493, 13516, 13518,
        13549, 13709, 13813, 13832, 13988, 14132, 14189, 14221, 14639, 14693,
        14912, 15102, 15176, 15237, 15330, 155417, 15443, 16068
    ]

    better_filenames = [
        '16', '63', '75', '101', '105', '131', '148', '164', '184', '198',
        '252', '276', '330', '397', '433', '458', '476', '480', 1119, 1212,
        1262, 1306, 1374, 3346, 6671, 9504, 9895, 9992, 10041, 10044, 10175,
        10183, 10213, 10382, 10452, 10456, 11079, 11130, 11159, 12318, 12495,
        12897, 12961, 13146, 13340, 13371, 13411, 13807, 13910, 13918, 14090,
        14693, 14697, 14850, 14898, 14904, 15062, 15166, 15207, 15483, 15563
    ]

    easy_filenames = [
        '34', '39', '52', '57', '117', '164', '165', '182', '207', '213',
        '222', '225', '232'
    ]

    dataset_wrong_case = [
        9800, 9934, 9951, 10021, 10361, 10584, 11227, 13310, 13600, 13673,
        13680, 15132, 15152, 15251, 16036
    ]

    all_bad_case = [
        10320, 10361, 10445, 10457, 10477, 11081, 11084, 11121, 12369, 12484,
        12726, 12740, 12768, 12786, 12789, 12876, 12877, 13120, 13310, 13393,
        13552, 13832, 13975, 14222, 14328, 14372, 14385, 14434, 14454, 14480,
        14503, 14506, 14580, 14628, 14786, 14931, 14932, 14963, 14982, 14985,
        15020, 15021, 15062, 15309, 15537, 15947, 15966, 15969, 15983, 156008,
        16034, 16037, 16058, 16068
    ]

    for i in range(len(data_list)):
        file_name = img_ids[i]
        print("Filename:{}".format(file_name))
        img, flipimg, original_img, mask = isic_transform(
            data_list[i], gt_list[i])
        output = model1(img)
        #flip_output=model1(flipimg)
        #output=torch.sigmoid(output).data.cpu().numpy()[0,0,:,:]
        output2 = model2(img)

        nas_output = output[-1].clone()
        nas_output = nas_output.view(nas_output.size(0), -1)
        target = torch.from_numpy(np.asarray(mask))
        target = target.unsqueeze(0).unsqueeze(0)
        target = target.view(target.size(0), -1)
        OtherVal.update(labels=target, preds=nas_output, n=1)
        # OtherVal.update(labels=target, preds=outputs_original[0].view(outputs_original[0].size(0), -1), n=1)
        #
        # output2=torch.sigmoid(output2[-1]).data.cpu().numpy()[0,0,:,:]
        # # flip_output=torch.sigmoid(flip_output).data.cpu().numpy()[0,0,:,:]
        # # flip_output=np.fliplr(flip_output)
        # # 可视化
        # oimage=np.asarray(original_img).astype(np.uint8)
        # mask=np.asarray(mask).astype(np.uint8)
        # output=(output>0.5).astype(np.uint8)
        # output2=(output2>0.5).astype(np.uint8)
        # # flip_output=(flip_output>0.5).astype(np.uint8)
        # mask[mask>=1]=255
        # output[output>=1]=255
        # output2[output2>=1]=255
        # # flip_output[flip_output>=1]=255
        #
        # #rgb
        # #img[..., 2] = np.where(mask == 1, 255, img[..., 2])
        # contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # cv2.drawContours(oimage, contours, -1, (0, 0, 255), 1,lineType=cv2.LINE_AA)
        #
        # output_contours, _ = cv2.findContours(output, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # cv2.drawContours(oimage, output_contours, -1, (0, 255,0), 1,lineType=cv2.LINE_AA)
        #
        # output_contours, _ = cv2.findContours(output2, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # cv2.drawContours(oimage, output_contours, -1, (255, 0,0), 1,lineType=cv2.LINE_AA)
        #
        # show_images(oimage.copy(),mask.copy(),output.copy(),output2.copy(),file_name)
        # cv2.imwrite(os.path.join(image_save_dir,'{}.png'.format(filename)),oimage)

        # flip_output_contours, _ = cv2.findContours(flip_output, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # cv2.drawContours(oimage, flip_output_contours, -1, (255, 0,0), 1,lineType=cv2.LINE_AA)
        # oimage=oimage[:,:,[2,1,0]]
        # cv2.imread(os.path.join())
        # cv2.imshow('original_mask',oimage)
        # cv2.imshow('mask',mask)
        # cv2.imshow('output',output)
        # cv2.imshow('output2',output2)
        # # cv2.imshow('flip_output',flip_output)
        # cv2.waitKey()
        # # [75,101]
        # # easy: [78,107,129]
    value = OtherVal.get_avg
    mr, ms, mp, mf, mjc, md, macc = value
    print("Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(macc, md, mjc))
Code example #15
File: inference_cvc.py  Project: lswzjuer/NAS-WDAN
def inference_isic(models_list, img_dir, mask_dir):
    OtherVal = BinaryIndicatorsMetric()
    for model in models_list:
        model.eval()
    filenames = os.listdir(img_dir)
    data_list = []
    gt_list = []
    img_ids = []
    for filename in filenames:
        data_list.append(filename)
        gt_list.append(filename)
        img_ids.append(filename)
        assert os.path.splitext(filename)[-1] == '.tif'

    assert (len(data_list) == len(gt_list))
    data_list = [os.path.join(img_dir, i) for i in data_list]
    gt_list = [os.path.join(mask_dir, i) for i in gt_list]

    hard_filenames = []
    better_filenames = []
    easy_filenames = []
    dataset_wrong_case = []
    all_bad_case = []

    model_name_list = [
        'unet', 'unet++', 'multires_unet', 'attention_unet_v1', 'nas_search'
    ]
    for i in range(len(data_list)):
        file_name = img_ids[i].split('.')[0]
        print("Filename:{}".format(file_name))
        img, original_img, mask = isic_transform(data_list[i], gt_list[i])
        outputs_original = [model(img) for model in models_list]
        nas_output = outputs_original[-1][-1].clone()
        nas_output = nas_output.view(nas_output.size(0), -1)
        target = torch.from_numpy(np.asarray(mask))
        target = target.unsqueeze(0).unsqueeze(0)
        target = target.view(target.size(0), -1)
        # OtherVal.update(labels=target, preds=nas_output, n=1)
        OtherVal.update(labels=target,
                        preds=outputs_original[0].view(
                            outputs_original[0].size(0), -1),
                        n=1)

        # outputs=[]
        # for index,output in enumerate(outputs_original):
        #     if isinstance(output,list):
        #         print("Index:{} is nas search model!".format(index))
        #         outputs.append(torch.sigmoid(output[-1]).data.cpu().numpy()[0,0,:,:])
        #     else:
        #         outputs.append(torch.sigmoid(output).data.cpu().numpy()[0,0,:,:])
        # outputs=[(output>0.5).astype(np.uint8) for output in outputs]
        # for output in outputs:
        #     output[output>=1]=255
        # #flip_output=model1(flipimg)
        # # flip_output=torch.sigmoid(flip_output).data.cpu().numpy()[0,0,:,:]
        # # flip_output=np.fliplr(flip_output)
        # # 可视化
        # oimage=np.asarray(original_img).astype(np.uint8)
        # mask=np.asarray(mask).astype(np.uint8)
        # # flip_output=(flip_output>0.5).astype(np.uint8)
        # mask[mask>=1]=255
        #
        # # flip_output[flip_output>=1]=255
        # #img[..., 2] = np.where(mask == 1, 255, img[..., 2])
        # unet=outputs[0]
        # # unetpp=outputs[1]
        # # multires_unet=outputs[2]
        # # attention_unet_v1=outputs[3]
        # nas_search_output=outputs[-1]
        #
        # # rgb
        # contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # cv2.drawContours(oimage, contours, -1, (0, 0,255), 1,lineType=cv2.LINE_AA)
        # output_contours, _ = cv2.findContours(unet, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # cv2.drawContours(oimage, output_contours, -1, (0, 255,0), 1,lineType=cv2.LINE_AA)
        #
        # # output_contours, _ = cv2.findContours(unetpp, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # # cv2.drawContours(oimage, output_contours, -1, (0, 0,255), 1,lineType=cv2.LINE_AA)
        #
        # # output_contours, _ = cv2.findContours(multires_unet, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # # cv2.drawContours(oimage, output_contours, -1, (0, 255,255), 1,lineType=cv2.LINE_AA)
        # #
        # #
        # # output_contours, _ = cv2.findContours(attention_unet_v1, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # # cv2.drawContours(oimage, output_contours, -1, (255, 0,255), 1,lineType=cv2.LINE_AA)
        #
        #
        # output_contours, _ = cv2.findContours(nas_search_output, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # cv2.drawContours(oimage, output_contours, -1, (255, 0,0), 1,lineType=cv2.LINE_AA)
        #
        # # #show_images(oimage.copy(),mask.copy(),unet.copy(),unetpp.copy(),multires_unet.copy(),attention_unet_v1.copy(),file_name)
        # show_images(oimage.copy(), mask.copy(), unet.copy(),nas_search_output.copy(), file_name)
    value = OtherVal.get_avg
    mr, ms, mp, mf, mjc, md, macc = value
    print("Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(macc, md, mjc))