Esempio n. 1
0
def eval_net(net, loader, device):
    """Run one validation pass and return mean Dice scores.

    The network is put in eval mode for the pass and switched back to
    train mode before returning. Each batch is fed through
    ``net(imgs, true_masks, 'test')``; the fifth of the seven returned
    values is thresholded at 0.5 and compared with the ground truth.

    Returns:
        A 4-tuple of per-batch averages: Dice over channels 0-2 together,
        then Dice for channel 0, 1 and 2 individually (named lv, myo and
        rv in the original code).
    """
    net.eval()
    num_batches = len(loader)
    # One selector per reported score: channels 0:3 jointly, then each alone.
    selectors = (slice(0, 3), 0, 1, 2)
    sums = [0, 0, 0, 0]

    with tqdm(total=num_batches, desc='Validation round', unit='batch',
              leave=False) as pbar:
        for imgs, true_masks in loader:
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=torch.float32)

            with torch.no_grad():
                _, _, _, _, mask_pred, _, _ = net(imgs, true_masks, 'test')

            hard_pred = (mask_pred > 0.5).float()
            for slot, sel in enumerate(selectors):
                sums[slot] += dice_coeff(hard_pred[:, sel, :, :],
                                         true_masks[:, sel, :, :],
                                         device).item()
            pbar.update()

    net.train()
    return tuple(total / num_batches for total in sums)
Esempio n. 2
0
def eval_net(net, loader, device):
    """Validate a two-headed network (centerlines + points).

    For ``net.n_classes > 1`` the per-batch value is the sum of the two
    cross-entropy losses; otherwise it is the sum of the two Dice
    coefficients computed on sigmoid outputs thresholded at 0.5.

    The network is set to eval mode for the pass and back to train mode
    afterwards. Returns the per-batch average of the accumulated value.
    """
    net.eval()
    target_dtype = torch.float32 if net.n_classes == 1 else torch.long
    num_batches = len(loader)
    running = 0

    with tqdm(total=num_batches, desc='Validation round', unit='batch', leave=False) as pbar:
        for batch in loader:
            imgs = batch['image']
            gt_lines = batch['centerlines']
            gt_points = batch['points']

            imgs = imgs.to(device=device, dtype=torch.float32)
            gt_lines = gt_lines.to(device=device, dtype=target_dtype)
            gt_points = gt_points.to(device=device, dtype=target_dtype)

            with torch.no_grad():
                out_lines, out_points = net(imgs)

            if net.n_classes > 1:
                lines_term = F.cross_entropy(out_lines, gt_lines).item()
                points_term = F.cross_entropy(out_points, gt_points).item()
            else:
                # Binarise the sigmoid probabilities before scoring.
                hard_lines = (torch.sigmoid(out_lines) > 0.5).float()
                hard_points = (torch.sigmoid(out_points) > 0.5).float()
                lines_term = dice_coeff(hard_lines, gt_lines).item()
                points_term = dice_coeff(hard_points, gt_points).item()

            running += lines_term + points_term
            pbar.update()

    net.train()
    return running / num_batches
Esempio n. 3
0
def eval_net(net, loader, device):
    """Return the mean per-batch Dice coefficient over ``loader``.

    Masks are always converted to float32, whatever ``net.n_classes`` is.
    Multi-class outputs are reduced to the index of the largest channel;
    single-class outputs go through a sigmoid and a 0.5 threshold.

    Note: the network is left in eval mode when this function returns.
    """
    net.eval()
    num_batches = len(loader)
    dice_sum = 0

    for batch in loader:
        imgs = batch['image']
        true_masks = batch['mask']
        imgs = imgs.to(device=device, dtype=torch.float32)
        true_masks = true_masks.to(device=device, dtype=torch.float32)

        with torch.no_grad():
            logits = net(imgs)

        if net.n_classes > 1:
            # Second element of max() is the winning channel index.
            _, class_idx = logits.max(dim=1)
            hard_pred = class_idx.float()
        else:
            hard_pred = (torch.sigmoid(logits) > 0.5).float()

        dice_sum += dice_coeff(hard_pred, true_masks).item()

    return dice_sum / num_batches
Esempio n. 4
0
def eval_model(model, eval_loader):
    """Evaluate ``model`` tile by tile and collect Dice scores and images.

    Relies on names defined outside this function: ``device``,
    ``image_size``, ``net_name``, ``softmax``, ``logdir``, ``dice_coeff``
    and ``generate_log_images_full``.

    Each input is cut into tiles of at most ``image_size`` x ``image_size``
    pixels; the model output for every tile is copied into a CPU tensor
    shaped like the ground truth. Soft masks (softmax channels 1..-2) and
    the matching ground-truth channels are saved as ``.npy`` files in
    ``logdir`` for every batch.

    Returns:
        (soft Dice / number of batches, hard Dice / number of batches,
        list of visualisation images).
    """
    model.to(device=device)
    model.eval()
    num_batches = len(eval_loader)
    soft_totals = np.zeros(4)
    hard_totals = np.zeros(4)
    vis_images = []

    with torch.no_grad():
        for batch_id, (inputs, true_masks) in enumerate(tqdm(eval_loader)):
            inputs = inputs.to(device=device, dtype=torch.float)
            true_masks = true_masks.to(device=device, dtype=torch.float)
            _, _, h, w = inputs.shape
            n_rows = (h - 1) // image_size + 1
            n_cols = (w - 1) // image_size + 1

            # Predictions are assembled on the CPU, one tile at a time.
            masks_pred = torch.zeros(true_masks.shape).to(dtype=torch.float)
            for row in range(n_rows):
                top = row * image_size
                bottom = min(h, (row + 1) * image_size)
                for col in range(n_cols):
                    left = col * image_size
                    right = min(w, (col + 1) * image_size)
                    tile = inputs[:, :, top:bottom, left:right]
                    if net_name == 'unet':
                        masks_pred[:, :, top:bottom, left:right] = \
                            model(tile).to("cpu")
                    elif net_name == 'hednet':
                        # Only the last element of the returned sequence is used.
                        masks_pred[:, :, top:bottom, left:right] = \
                            model(tile)[-1].to("cpu")

            probs = softmax(masks_pred)
            max_probs, _ = torch.max(probs, 1)
            masks_soft = probs[:, 1:-1, :, :]

            soft_path = os.path.join(logdir, f'mask_soft_{batch_id}.npy')
            true_path = os.path.join(logdir, f'mask_true_{batch_id}.npy')
            np.save(soft_path, masks_soft.numpy())
            np.save(true_path, true_masks[:, 1:-1].cpu().numpy())

            # Hard mask: 1 where a channel equals the per-pixel maximum.
            is_max = (probs == max_probs[:, None, :, :]).to(dtype=torch.float)
            masks_hard = is_max[:, 1:-1, :, :]

            soft_totals += dice_coeff(masks_soft,
                                      true_masks[:, 1:-1, :, :].to("cpu"))
            hard_totals += dice_coeff(masks_hard,
                                      true_masks[:, 1:-1, :, :].to("cpu"))

            images_batch = generate_log_images_full(
                inputs, true_masks[:, 1:-1], masks_soft, masks_hard)
            vis_images.extend(images_batch.to("cpu").numpy())

    return soft_totals / num_batches, hard_totals / num_batches, vis_images
Esempio n. 5
0
def eval_net(net, loader, device, pathFile = None, saveImage = False):
    """Return the mean per-batch Dice coefficient, optionally saving masks.

    Args:
        net: network exposing ``n_classes``; put in eval mode and left there.
        loader: iterable of dicts with ``'image'``, ``'mask'`` and
            ``'filename'`` entries.
        device: device the tensors are moved to.
        pathFile: output directory for predicted masks; required when
            ``saveImage`` is true.
        saveImage: when true, every predicted mask is written to
            ``pathFile`` as ``<filename>.png``.

    Returns:
        Sum of per-batch Dice coefficients divided by ``len(loader)``.

    Raises:
        ValueError: if ``saveImage`` is true and ``pathFile`` is None.
    """
    net.eval()

    mask_type = torch.float32

    n_val = len(loader)
    tot = 0

    if saveImage:
        if pathFile is None:
            raise ValueError("pathFile must be provided when saveImage is True")
        # Create the output directory once instead of checking per batch.
        os.makedirs(pathFile, exist_ok=True)

    for batch in loader:
        imgs, true_masks, filename = batch['image'], batch['mask'], batch['filename']
        imgs = imgs.to(device=device, dtype=torch.float32)
        true_masks = true_masks.to(device=device, dtype=mask_type)

        with torch.no_grad():
            mask_pred = net(imgs)

        if net.n_classes > 1:
            class_idx = mask_pred.max(dim=1)[1]
            pred = class_idx.float()
            # Keep integer class indices for saving, as before.
            seg = class_idx.unsqueeze(1)
        else:
            pred = torch.sigmoid(mask_pred)
            pred = (pred > 0.5).float()
            # The argmax over a single channel is always 0, which produced
            # all-black images; save the thresholded prediction instead.
            seg = pred
        tot += dice_coeff(pred, true_masks).item()

        if saveImage:
            output_seg = seg.cpu().numpy()
            for i in range(output_seg.shape[0]):
                output_img = output_seg[i, 0, :, :] * 255
                filePath = os.path.join(pathFile, filename[i] + ".png")
                Image.fromarray(output_img.astype(np.uint8)).save(filePath)

    return tot / n_val
Esempio n. 6
0
def eval_net(net, loader, device, criterion):
    """Validate ``net`` and return ``(mean Dice, mean loss)`` per batch.

    ``criterion`` is applied to the raw network output; Dice is computed
    on the sigmoid output thresholded at 0.5. The progress bar counts
    images as ``len(loader) * loader.batch_size``. The network is set to
    eval mode for the pass and back to train mode afterwards.
    """
    net.eval()
    batch_size = loader.batch_size
    total_images = len(loader) * batch_size  # progress-bar total, in images
    dice_sum = 0
    loss_sum = 0

    with tqdm(total=total_images, desc='Validation round', unit='img',
              leave=False) as pbar:
        for batch in loader:
            imgs = batch['image']
            true_masks = batch['mask']
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=torch.float32)

            with torch.no_grad():
                logits = net(imgs)
                loss_sum += criterion(logits, true_masks).item()

                hard_pred = (torch.sigmoid(logits) > 0.5).float()
                dice_sum += dice_coeff(hard_pred, true_masks).item()
            pbar.update(batch_size)

    net.train()
    num_batches = len(loader)
    return dice_sum / num_batches, loss_sum / num_batches
Esempio n. 7
0
def eval_net(net, loader, device, n_classes):
    """Return the sample-weighted mean validation score.

    Ground-truth masks are binarised with ``mask > 0`` before scoring.
    For ``n_classes > 1`` the score is the cross-entropy loss; otherwise
    it is the Dice coefficient of the sigmoid output thresholded at 0.5.
    Each batch score is weighted by the batch size and the total is
    divided by the number of samples seen.

    Uses the module-level ``CROP_FLAG`` to choose between
    ``crop_tile(net, imgs)`` and a plain forward pass. The network is set
    to eval mode for the pass and back to train mode afterwards.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    net.eval()
    mask_type = torch.float32 if n_classes == 1 else torch.long
    n_val = len(loader)  # number of batches, used for the progress bar
    tot = 0

    sum_num = 0
    with tqdm(total=n_val, desc='Validation round', unit='batch',
              leave=False) as pbar:
        for batch in loader:
            imgs, true_masks = batch['image'], batch['mask']
            true_biMasks = (true_masks > 0).int()
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_biMasks = true_biMasks.to(device=device, dtype=mask_type)
            batch_len = imgs.shape[0]

            with torch.no_grad():
                if CROP_FLAG:
                    mask_pred = crop_tile(net, imgs)
                else:
                    mask_pred = net(imgs)

            if n_classes > 1:
                batch_score = F.cross_entropy(mask_pred, true_biMasks).item()
            else:
                pred = torch.sigmoid(mask_pred)
                pred = (pred > 0.5).float()
                batch_score = dice_coeff(pred, true_biMasks).item()

            # Count samples in BOTH branches: previously sum_num stayed 0
            # for n_classes > 1 and the final division always failed.
            tot += batch_score * batch_len
            sum_num += batch_len
            pbar.update()

    net.train()
    return tot / sum_num
Esempio n. 8
0
def eval_net(net, validation_loader, gpu=False):
    """Return the mean Dice coefficient over ``validation_loader``.

    Args:
        net: model called as ``net(image)``.
        validation_loader: iterable yielding 4-tuples
            ``(id, z, image, true_mask)``; only the last two are used.
        gpu: when true, images and masks are moved to CUDA.

    Returns:
        Sum of per-batch Dice coefficients divided by the number of
        batches. A tiny epsilon in the divisor makes an empty loader
        return 0.0 instead of raising.

    NOTE(review): this function does not call ``net.eval()``; it appears
    the caller is expected to set the mode - confirm.
    """
    total_loss = 0
    num = 0
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for _sample_id, _z, image, true_mask in validation_loader:
            if gpu:
                image = image.cuda()
                true_mask = true_mask.cuda()

            masks_pred = net(image)

            masks_probs = torch.sigmoid(masks_pred)
            # Threshold probabilities into a hard mask, flattened to 1-D.
            masks_probs_flat = (masks_probs.reshape(-1) > 0.5).float()
            true_mask_flat = true_mask.reshape(-1)

            total_loss += dice_coeff(masks_probs_flat, true_mask_flat).item()
            num += 1
    return total_loss / (num + 0.1e-10)
Esempio n. 9
0
def eval_net(args, net, loader, device):
    """Save predicted probability maps and, when validating, score them.

    For every batch the sigmoid output is squeezed on its two leading
    dimensions, rescaled so its maximum becomes 255 and written as an RGB
    PNG named after ``batch['name'][0]``. Images go to
    ``./records/test/segmentation`` when ``args.test`` is truthy and to
    ``./records/valid/segmentation`` otherwise.

    Dice (sigmoid output thresholded at 0.1) is accumulated only when not
    testing, so in test mode the returned value is 0.

    Returns the accumulated Dice divided by ``len(loader)``.
    """
    net.eval()
    num_batches = len(loader)
    dice_sum = 0
    testing = args.test
    if testing:
        out_dir = './records/test/segmentation'
    else:
        out_dir = './records/valid/segmentation'

    with tqdm(total=num_batches,
              desc='Testing' if testing else 'Validation',
              unit='img',
              leave=False) as pbar:
        for batch in loader:
            img, mask, name = batch['image'], batch['mask'], batch['name']
            img = img.to(device=device, dtype=torch.float32)
            mask = mask.to(device=device, dtype=torch.float32)
            with torch.no_grad():
                mask_pred = net(img)
                probs = torch.sigmoid(mask_pred)
                prob_map = probs.squeeze(0).squeeze(0).cpu().detach().numpy()

                scaled = prob_map / np.max(prob_map) * 255
                out_path = os.path.join(out_dir, name[0] + '.png')
                Image.fromarray(scaled).convert('RGB').save(out_path)

                if not testing:
                    hard_pred = (probs > 0.1).float()
                    dice_sum += dice_coeff(hard_pred, mask, args).item()
            pbar.update()
    return dice_sum / num_batches
Esempio n. 10
0
def eval_net(net, loader, device):
    """Return ``(n_scored, mean Dice)`` over batches with a non-empty mask.

    Batches whose ground-truth mask sums to zero are skipped and do not
    count towards the mean. Predictions are the sigmoid output
    thresholded at 0.5.

    Args:
        net: model to evaluate; put in eval mode and left there.
        loader: iterable of dicts; the first element of ``batch['img']``
            and ``batch['mask']`` is used.
        device: device used for both the input and the mask.

    Returns:
        Tuple ``(n_val, mean_dice)`` where ``n_val`` is the number of
        batches that were scored. ``mean_dice`` is NaN when no batch had
        a non-empty mask.
    """
    n_val = len(loader)
    net.eval()
    dice = 0
    with tqdm(total=n_val, desc='Evaluation round', unit='batch',
              leave=False) as pbar:
        for batch in loader:
            img, mask = batch['img'][0], batch['mask'][0]
            img = img.to(device=device, dtype=torch.float32)
            mask = mask.to(device=device, dtype=torch.float32)

            with torch.no_grad():
                # Tensors already live on ``device``; forcing .cuda() here
                # broke CPU runs and non-default GPU indices.
                mask_pred = net(img)

            pred = torch.sigmoid(mask_pred)
            pred = (pred > 0.5).float()
            if torch.sum(mask) == 0:
                n_val -= 1
            else:
                dice += dice_coeff(pred, mask).item()
            pbar.update()

    # Avoid ZeroDivisionError when every mask was empty (or loader is empty).
    mean_dice = dice / n_val if n_val > 0 else float('nan')
    print(n_val)
    print(f'Dice: {mean_dice}')
    return (n_val, mean_dice)
Esempio n. 11
0
def eval_net_AP_Power(net, loader, device):
    """Validate a network whose output is a mapping with optional heads.

    If the output contains ``"AP"``, that head is scored with
    cross-entropy (``net.n_classes > 1``) or with Dice on the sigmoid
    output thresholded at 0.5. If it contains ``"power"``, that head is
    scored with mean squared error against ``batch['power']``.

    Returns:
        ``(mean AP score, mean power MSE)`` as Python floats, both
        averaged over the number of batches. A head missing from the
        output contributes 0.

    Note: the network is left in eval mode when this function returns.
    NOTE(review): power targets are cast to ``torch.long`` before the MSE
    is computed - confirm this is intended for the power head.
    """
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    n_val = len(loader)  # the number of batches
    totAP = 0
    totPower = 0
    mseLoss = nn.MSELoss()

    with tqdm(total=n_val, desc='Validation round', unit='batch',
              leave=False) as pbar:
        for batch in loader:
            imgs, true_masks, true_powers = \
                batch['image'], batch['mask'], batch['power']
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=mask_type)
            true_powers = true_powers.to(device=device, dtype=torch.long)
            with torch.no_grad():
                mask_pred = net(imgs)

            if "AP" in mask_pred:
                if net.n_classes > 1:
                    totAP += F.cross_entropy(mask_pred["AP"],
                                             true_masks).item()
                else:
                    pred = torch.sigmoid(mask_pred["AP"])
                    pred = (pred > 0.5).float()
                    totAP += dice_coeff(pred, true_masks).item()
            if "power" in mask_pred:
                # .item() keeps the accumulator a plain float, like totAP,
                # instead of a device tensor.
                totPower += mseLoss(mask_pred["power"], true_powers).item()
            pbar.update()

    return totAP / n_val, totPower / n_val
Esempio n. 12
0
def eval_net(net, dataset, dir_img, dir_mask, args):
    """Score ``net`` on the validation split and track the best score.

    The mean Dice coefficient over ``dataset['val']`` is compared with
    the module-level ``best_score``. When it improves, ``best_score`` is
    updated and, if ``args.save_cp`` is set, the weights are saved to
    ``./checkpoint/unet.pth``.

    Args:
        net: model to evaluate; put in eval mode and left there.
        dataset: mapping whose ``'val'`` entry is passed to
            ``get_imgs_and_masks``.
        dir_img, dir_mask: forwarded to ``get_imgs_and_masks``.
        args: needs ``gpu``, ``device`` and ``save_cp`` attributes.

    Returns:
        The (possibly updated) global ``best_score``.

    Raises:
        ValueError: if the validation split yields no samples.
    """
    global best_score

    net.eval()
    if args.gpu:
        net.to(args.device)

    total = 0
    n_samples = 0
    val = get_imgs_and_masks(dataset['val'], dir_img, dir_mask)
    # Evaluation needs no gradients; avoid building autograd graphs.
    with torch.no_grad():
        for b in val:
            img = np.array(b[0]).astype(np.float32)
            mask = np.array(b[1]).astype(np.float32)

            img = torch.from_numpy(img)[None, None, :, :]
            mask = torch.from_numpy(mask).unsqueeze(0)

            if args.gpu:
                img = img.to(args.device)
                mask = mask.to(args.device)
            mask_pred = net(img)
            # Obtain the predicted segmentation map.
            mask_pred = (mask_pred > 0.5).float()

            total += dice_coeff(mask_pred, mask, args.device).cpu().item()
            n_samples += 1

    if n_samples == 0:
        raise ValueError('validation set is empty; cannot compute a score')

    current_score = total / n_samples
    print('current score is %f' % current_score)
    print('best score is %f' % best_score)
    if current_score > best_score:
        best_score = current_score
        print('current best score is {}'.format(best_score))
        if args.save_cp:
            print('saving checkpoint')
            mkdir_p('checkpoint')
            torch.save(net.state_dict(), './checkpoint/unet.pth')

    return best_score
Esempio n. 13
0
def eval_net(net, dataset, lendata, gpu=False, batch_size=8, is_loss=False):
    """Return the mean per-batch BCE loss or Dice coefficient.

    Args:
        net: model to evaluate; its output is indexed with ``[0]``.
            Put in eval mode and left there.
        dataset: samples, grouped into batches by ``batch(dataset,
            batch_size)``; each sample is indexed as ``(image, mask)``.
        lendata: total shown on the progress bar.
        gpu: when true, tensors are moved to CUDA.
        batch_size: number of samples per batch.
        is_loss: when true the BCE loss is averaged; otherwise the Dice
            coefficient of the output thresholded at 0.5.

    Returns:
        The accumulated value divided by the number of batches.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    net.eval()
    tot = 0
    n_batches = 0
    criterion = nn.BCELoss()
    with torch.no_grad(), tqdm(total=lendata) as progress_bar:
        for b in batch(dataset, batch_size):
            imgs = np.array([sample[0] for sample in b]).astype(np.float32)
            true_masks = np.array([sample[1] for sample in b])

            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)

            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

            masks_pred = net(imgs)[0]

            if is_loss:
                loss = criterion(masks_pred, true_masks)
                tot += loss.item()
                postfix = {'BCE': loss.item()}
            else:
                masks_pred = (masks_pred > 0.5).float()
                dice = dice_coeff(masks_pred, true_masks).item()
                tot += dice
                postfix = {'DICE': dice}

            n_batches += 1
            # Advance by the real batch length; the last batch may be short.
            progress_bar.update(imgs.shape[0])
            progress_bar.set_postfix(**postfix)

    if n_batches == 0:
        raise ValueError('dataset produced no batches; nothing to evaluate')

    # Divide by the batch COUNT. The old code divided by the last
    # zero-based index, which inflated the mean and failed for one batch.
    return tot / n_batches
Esempio n. 14
0
def eval_net(net, dataset, gpu=False):
    """Return the mean per-sample Dice coefficient over ``dataset``.

    Args:
        net: model to evaluate; its output is indexed with ``[0]``.
            Put in eval mode and left there.
        dataset: iterable of samples indexed as ``(image, mask)``, both
            NumPy arrays. The mask gets a leading batch axis and is then
            reordered with axes ``(0, 3, 1, 2)``.
        gpu: when true, tensors are moved to CUDA.

    Returns:
        Sum of Dice coefficients divided by the number of samples
        (0.0 for an empty dataset).
    """
    net.eval()
    tot = 0
    n_samples = 0
    # Evaluation needs no gradients; avoid building autograd graphs.
    with torch.no_grad():
        for b in dataset:
            img = torch.from_numpy(b[0]).unsqueeze(0)
            true_mask = torch.from_numpy(b[1]).unsqueeze(0)
            # Use the tensor's own permute rather than np.transpose, which
            # only worked on a tensor through NumPy's fallback path.
            true_mask = true_mask.permute(0, 3, 1, 2).float()

            if gpu:
                img = img.cuda()
                true_mask = true_mask.cuda()

            mask_pred = net(img)[0]
            mask_pred = (mask_pred > 0.5).float()
            mask_pred = mask_pred.unsqueeze(0)

            tot += dice_coeff(mask_pred, true_mask).item()
            n_samples += 1
    # max(..., 1) keeps the original result of 0.0 for an empty dataset.
    return tot / max(n_samples, 1)
Esempio n. 15
0
def brats2018_test_inference(args, model, test_dataset):
    """Placeholder for test-set inference; not implemented yet.

    The previous body raised ``NotImplementedError`` on its first line and
    was followed by unreachable code that referenced names not defined in
    this function (``n_val``, ``loader``, ``device``, ``net``). That dead
    code has been removed; behaviour is unchanged.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        'brats2018_test_inference is not implemented')
Esempio n. 16
0
def eval_net(net, loader, device):
    """Validate a wrapped network (accessed through ``net.module``).

    For ``net.module.n_classes > 2`` the per-batch value is
    ``BCEWithLogitsLoss`` on the raw output; otherwise it is the Dice
    coefficient of the sigmoid output thresholded at 0.5.

    ``net.module`` is set to eval mode for the pass and back to train
    mode afterwards. Returns the per-batch average.
    """
    net.module.eval()
    num_batches = len(loader)
    running = 0
    bce_with_logits = nn.BCEWithLogitsLoss()

    with tqdm(total=num_batches, desc='Validation round', unit='batch', leave=False) as pbar:
        for batch in loader:
            imgs = batch['image']
            true_masks = batch['mask']
            imgs = imgs.to(device=device).float()
            true_masks = true_masks.to(device=device).float()

            with torch.no_grad():
                logits = net(imgs)

            if net.module.n_classes > 2:
                batch_value = bce_with_logits(logits, true_masks).item()
            else:
                hard_pred = (torch.sigmoid(logits) > 0.5).float()
                batch_value = dice_coeff(hard_pred, true_masks).item()
            running += batch_value
            pbar.update()

    net.module.train()
    return running / num_batches
Esempio n. 17
0
def eval_net(net, loader, device, threshold=0.5):
    """Validate ``net`` and return ``(score, precision, recall)`` means.

    For ``net.n_classes > 1`` only the cross-entropy loss is accumulated,
    so precision and recall stay 0. Otherwise the sigmoid output is
    thresholded at ``threshold`` and ``dice_coeff`` is asked for Dice,
    precision and recall. When the batch has more than two entries its
    ``'mask_widen'`` entry is forwarded to ``dice_coeff`` as well.

    The network is set to eval mode for the pass and back to train mode
    afterwards. All three values are averaged over the number of batches.
    """
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    num_batches = len(loader)
    score_sum = 0
    precision_sum = 0
    recall_sum = 0

    with tqdm(total=num_batches, desc='Validation round', unit='batch', leave=False) as pbar:
        for batch in loader:
            imgs = batch['image']
            true_masks = batch['mask']
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=mask_type)

            with torch.no_grad():
                logits = net(imgs)

            if net.n_classes > 1:
                score_sum += F.cross_entropy(logits, true_masks).item()
            else:
                hard_pred = (torch.sigmoid(logits) > threshold).float()

                # A batch with more than two entries carries a widened mask.
                widened = None
                if len(batch) > 2:
                    widened = batch['mask_widen']
                    widened = widened.to(device=device, dtype=mask_type)

                dice, precision, recall = dice_coeff(
                    hard_pred, true_masks, widened, True)
                score_sum += dice.item()
                precision_sum += precision.item()
                recall_sum += recall.item()
            pbar.update()

    net.train()
    return (score_sum / num_batches,
            precision_sum / num_batches,
            recall_sum / num_batches)
Esempio n. 18
0
def eval_net(loader, device, dict, color_map):
    """Compute Dice between each batch's ``'image'`` and ``'mask'`` tensors.

    No network is involved: the two tensors from the loader are compared
    directly. Besides the per-batch score, one Dice value per sample is
    computed with ``DiceCoeff().forward`` and appended to ``dict`` as a
    single-element list (so ``dict`` must support ``append``).
    ``color_map`` is not used.

    Prints and returns the per-batch mean Dice.
    """
    num_batches = len(loader)
    running = 0

    for batch in loader:
        imgs = batch['image']
        true_masks = batch['mask']
        _name = batch['name']  # read but not used
        imgs = imgs.to(device=device, dtype=torch.float32)
        true_masks = true_masks.to(device=device, dtype=torch.float32)

        # Score for the whole batch.
        running += dice_coeff(imgs, true_masks).item()

        # Individual score for every sample in the batch.
        for idx, img in enumerate(imgs):
            sample_score = DiceCoeff().forward(img, true_masks[idx]).item()
            dict.append([sample_score])

    mean_score = running / num_batches
    print("total:: " + str(mean_score))
    return mean_score
Esempio n. 19
0
def test(args, ckpt_file):
    """Run a saved UNet checkpoint over the test set.

    Builds a ``BasicDataset`` from ``args["TESTIMAGEDATA_DIR"]`` and
    ``args["TESTLABEL_DIRECTORY"]``, loads the weights found at
    ``args["EXPT_DIR"] + ckpt_file`` into a single-class UNet and runs it
    over every batch. Relies on the module-level names ``img_scale``,
    ``batch_size`` and ``device``.

    Returns:
        ``{"predictions": pred, "labels": true_masks}`` holding the
        thresholded predictions and the labels of the LAST batch only.

    Raises:
        ValueError: if the loader yields no batches (possible because
            ``drop_last=True`` discards an incomplete batch).
    """
    testdataset = BasicDataset(args["TESTIMAGEDATA_DIR"],
                               args["TESTLABEL_DIRECTORY"], img_scale)
    val_loader = DataLoader(testdataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)
    net = UNet(n_channels=3, n_classes=1, bilinear=True)
    net.to(device=device)
    # NOTE(review): directory and file name are concatenated, so EXPT_DIR
    # must already end with a path separator - confirm.
    net.load_state_dict(torch.load(os.path.join(args["EXPT_DIR"] + ckpt_file)))
    net.eval()

    # These were previously used without being defined in this function;
    # ``tot`` in particular raised UnboundLocalError on the first batch.
    n_val = len(val_loader)
    mask_type = torch.float32  # the net above is built with n_classes=1
    tot = 0  # accumulated Dice; computed as before but not returned
    pred = None
    true_masks = None

    with tqdm(total=n_val, desc='Validation round', unit='batch',
              leave=False) as pbar:
        for batch in val_loader:
            imgs, true_masks = batch['image'], batch['mask']
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=mask_type)

            with torch.no_grad():
                mask_pred = net(imgs)

            # The net is always single-class here, so the former
            # cross-entropy branch could never run and has been removed.
            pred_sig = torch.sigmoid(mask_pred)
            pred = (pred_sig > 0.5).float()
            tot += dice_coeff(pred, true_masks).item()
            pbar.update()

    if pred is None:
        raise ValueError('test loader produced no batches')

    return {"predictions": pred, "labels": true_masks}
Esempio n. 20
0
def eval_net(net, loader, device, n_val):
    """Return the mean per-image validation score.

    Every image in every batch is scored individually after thresholding
    the raw network output at 0.5: cross-entropy for
    ``net.n_classes > 1``, Dice coefficient otherwise.

    Args:
        net: model exposing ``n_classes``; put in eval mode and left there.
        loader: iterable of dicts with ``'image'`` and ``'mask'`` entries.
        device: device the tensors are moved to.
        n_val: progress-bar total and divisor of the accumulated score;
            since scores are summed per image this is presumably the
            number of validation images - confirm against the caller.

    Returns:
        Accumulated score divided by ``n_val``.
    """
    net.eval()
    # Loop-invariant, so computed once rather than for every batch.
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    tot = 0

    with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
        for batch in loader:
            imgs = batch['image']
            true_masks = batch['mask']

            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=mask_type)

            # Evaluation needs no gradients; the forward pass previously
            # built an autograd graph for every batch.
            with torch.no_grad():
                mask_pred = net(imgs)

                for true_mask, pred in zip(true_masks, mask_pred):
                    pred = (pred > 0.5).float()
                    if net.n_classes > 1:
                        tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
                    else:
                        tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
            pbar.update(imgs.shape[0])

    return tot / n_val
Esempio n. 21
0
def eval_net(net, loader, device, verbose=False):
    """Validate ``net`` and return the result as a dictionary.

    The per-batch value is cross-entropy for ``net.n_classes > 1`` and
    the Dice coefficient of the sigmoid output thresholded at 0.5
    otherwise. The network is set to eval mode for the pass and back to
    train mode afterwards.

    Returns:
        ``{'val_score': mean per-batch value}``. With ``verbose`` the
        dictionary also holds ``'imgs'`` and ``'preds'``: the inputs and
        raw network outputs of the last batch, moved to the CPU.
    """
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    num_batches = len(loader)
    running = 0
    last_imgs = None
    last_output = None

    print ('\nStarting validation...\n')

    for batch in loader:
        last_imgs = batch['image']
        true_masks = batch['mask']
        last_imgs = last_imgs.to(device=device, dtype=torch.float32)
        true_masks = true_masks.to(device=device, dtype=mask_type)

        with torch.no_grad():
            last_output = net(last_imgs)

        if net.n_classes > 1:
            batch_value = F.cross_entropy(last_output, true_masks).item()
        else:
            hard_pred = (torch.sigmoid(last_output) > 0.5).float()
            batch_value = dice_coeff(hard_pred, true_masks).item()
        running += batch_value

    net.train()

    result = {'val_score': running / num_batches}
    if verbose:
        result.update(imgs=last_imgs.cpu(), preds=last_output.cpu())

    return result
Esempio n. 22
0
def eval_model(model, eval_loader, criterion):
    """Return ``(mean Dice, mean criterion loss)`` over ``eval_loader``.

    Relies on the module-level names ``device``, ``net_name``,
    ``softmax`` and ``dice_coeff``. For ``net_name == 'hednet'`` only the
    last element of the model output is used.

    The loss is computed on outputs flattened to ``(pixels, channels)``
    against the argmax over dimension 1 of the ground truth. Dice is
    computed on the softmax output with channel 0 dropped from both
    prediction and ground truth.

    Both sums are divided by the number of batches; the loss sum is
    accumulated from the criterion's return values without ``.item()``.
    """
    model.eval()
    num_batches = len(eval_loader)
    dice_total = np.zeros(1)
    ce_total = 0.

    with torch.no_grad():
        for inputs, true_masks in eval_loader:
            inputs = inputs.to(device=device, dtype=torch.float)
            true_masks = true_masks.to(device=device, dtype=torch.float)
            if net_name == 'unet':
                masks_pred = model(inputs)
            elif net_name == 'hednet':
                masks_pred = model(inputs)[-1]

            # Move channels last, then collapse every pixel into one row.
            channels_last = masks_pred.permute(0, 2, 3, 1)
            pred_rows = channels_last.reshape(-1, channels_last.shape[-1])
            target_rows = torch.argmax(true_masks, 1).reshape(-1).long()
            ce_total += criterion(pred_rows, target_rows)

            probs = softmax(masks_pred)
            dice_total += dice_coeff(probs[:, 1:, :, :],
                                     true_masks[:, 1:, :, :])

    return dice_total / num_batches, ce_total / num_batches
Esempio n. 23
0
def eval_net(net, loader, device):
    """Validate ``net`` and return the mean per-batch value.

    For ``net.n_classes > 2`` the value is the cross-entropy loss.
    Otherwise channel 0 of the raw network output (no sigmoid applied) is
    thresholded at 0.5 and compared, via Dice, with channel 0 of the
    ground truth.

    The network is set to eval mode for the pass and back to train mode
    afterwards.
    """
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    num_batches = len(loader)
    running = 0

    with tqdm(total=num_batches, desc='Validation round', unit='batch',
              leave=False) as pbar:
        for batch in loader:
            imgs = batch['image']
            true_masks = batch['mask']
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=mask_type)

            with torch.no_grad():
                output = net(imgs)

            if net.n_classes > 2:
                batch_value = F.cross_entropy(output, true_masks).item()
            else:
                first_channel = (output[:, 0] > 0.5).float()
                target = true_masks[:, 0].float()
                batch_value = dice_coeff(first_channel, target).item()
            running += batch_value
            pbar.update()

    net.train()
    return running / num_batches
def assessment(path1, path2):
    """Compare two sets of masks and return summary metrics.

    A ``BasicDataset(path1, path2, scale=0.5)`` is iterated one sample at
    a time on the CPU. Both tensors are passed through ``binary`` before
    pixel counts (TP/TN/FN/FP) are accumulated over the whole dataset.

    NOTE(review): the ``'image'`` entry is treated as ground truth and the
    ``'mask'`` entry as the prediction - confirm that ``path1``/``path2``
    are passed in the matching order.

    Returns:
        ``(precision, recall, f1, accuracy, mean_dice)``. Ratios whose
        denominator is zero are NaN. ``mean_dice`` is the summed Dice
        divided by the number of batches (batch size is 1).

    Raises:
        ValueError: if the dataset is empty.
    """
    def _ratio(num, den):
        # 0/0 yields NaN, the value the previous NumPy arithmetic produced.
        return num / den if den != 0 else float('nan')

    ds = BasicDataset(path1, path2, scale=0.5)
    loader = DataLoader(ds, batch_size=1, shuffle=False, pin_memory=True)
    tp = tn = fn = fp = 0
    dice = 0
    count = 0
    for batch in loader:
        count += 1
        # Both arms of the old device conditional were 'cpu'.
        true_mask = batch['image'].to(device='cpu', dtype=torch.float32)
        pred_mask = batch['mask'].to(device='cpu', dtype=torch.float32)
        for GT, pred in zip(true_mask, pred_mask):
            pred = binary(pred).float()
            GT = binary(GT).float()
            tp += ((pred == 1) & (GT == 1)).sum().item()
            tn += ((pred == 0) & (GT == 0)).sum().item()
            fn += ((pred == 0) & (GT == 1)).sum().item()
            fp += ((pred == 1) & (GT == 0)).sum().item()
            dice += dice_coeff(pred, GT.squeeze(dim=1)).item()

    if count == 0:
        raise ValueError('dataset is empty; nothing to assess')

    # Metrics depend only on the final counts, so compute them once here
    # instead of on every iteration.
    p = _ratio(tp, tp + fp)
    r = _ratio(tp, tp + fn)
    f1 = _ratio(2 * r * p, r + p)
    acc = _ratio(tp + tn, tp + tn + fp + fn)
    return p, r, f1, acc, dice / count
    def eval(self, imgs, true_masks, masks_pred):
        """Return Dice scores for one batch as a NumPy array.

        With ``self.net.n_classes == 1`` the array holds a single score:
        Dice between ``masks_pred`` thresholded at 0.5 and ``true_masks``.
        Otherwise the prediction is turned into a one-hot encoding of the
        most probable class and one score is produced for every class
        index from 1 upwards (index 0 is skipped), comparing that class's
        one-hot channel with ``true_masks == k``.

        ``imgs`` is accepted but not used.
        """
        if self.net.n_classes == 1:
            hard_pred = (masks_pred > 0.5).float()
            return np.array([dice_coeff(hard_pred, true_masks).item()])

        probs = F.softmax(masks_pred, dim=1).data
        winners = torch.argmax(probs, 1, keepdim=True)

        # One-hot encode the winning class along dimension 1.
        one_hot = torch.zeros(probs.shape, dtype=torch.float32,
                              device=self.device)
        one_hot.scatter_(1, winners, 1)

        scores = []
        for k in range(1, one_hot.shape[1]):
            class_pred = one_hot[:, k, :, :]
            class_true = (true_masks == k).float().squeeze(1)
            scores.append(dice_coeff(class_pred, class_true).item())

        return np.array(scores)
Esempio n. 26
0
def test(args, model, device, test_loader, best_dice, epochs):
    """Evaluate ``model`` on ``test_loader`` and print loss / Dice statistics.

    For every batch the model output is squeezed, passed through a sigmoid
    and thresholded at 0.5 to obtain a binary mask.  Binary cross-entropy is
    computed on the flattened probabilities and the Dice coefficient (via
    ``dice_coeff``) on the thresholded mask.  The averages over all batches
    are appended to the externally defined ``test_loss`` and ``dice_loss``
    sequences, and ``save_model`` is called when the average Dice exceeds
    ``best_dice``.

    Args:
        args: Not used by this function; kept for interface compatibility.
        model: Network to evaluate; invoked as ``model(data)``.
        device: Device every batch is moved to before the forward pass.
        test_loader: Iterable yielding ``(data, target)`` batches.
        best_dice: Best average Dice seen so far.
        epochs: Forwarded to ``save_model`` when a new best is reached.

    Returns:
        The best average Dice after this evaluation (``best_dice`` itself
        when no improvement was made).  Earlier versions returned ``None``,
        which left the caller unable to observe the updated value.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no batches.
    """
    model.eval()
    # The criterion is stateless, so build it once rather than per batch.
    criterion = nn.BCELoss()
    total_loss = 0
    total_dice = 0
    all_dice_coeffs = []
    count = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data).squeeze()
            target = target.squeeze()
            output_probs = torch.sigmoid(output)
            output_mask = (output_probs > 0.5).float()
            total_loss += criterion(output_probs.view(-1).float(),
                                    target.view(-1).float()).item()
            # Compute the Dice score once and reuse it for both the running
            # sum and the per-batch list used for median / std below.
            batch_dice = dice_coeff(output_mask, target.float()).item()
            total_dice += batch_dice
            all_dice_coeffs.append(batch_dice)
            count += 1

    avg_loss = total_loss / count
    avg_dice = total_dice / count
    std_dice = np.std(all_dice_coeffs)
    med_dice = np.median(all_dice_coeffs)
    test_loss.append(avg_loss)
    dice_loss.append(avg_dice)
    if avg_dice > best_dice:
        best_dice = avg_dice
        save_model(epochs, model, best_dice, avg_loss)

    print('\nTest set statistics:')
    print('Average loss: {:.4f}'.format(avg_loss))
    print('Dice Average:')
    print(float(avg_dice))
    print('Dice Median:')
    print(float(med_dice))
    print('Dice Standard Deviation:')
    print(float(std_dice))
    return best_dice
Esempio n. 27
0
def calculate_metrix(x):
    """Compare two images and return ``[rmse, dice]``.

    ``x`` is indexed with the keys ``"input_img"`` and ``"compare_img"``;
    each value is handed to ``read_img`` and the result is divided by 255
    (presumably to bring 8-bit intensities into [0, 1] -- confirm against
    ``read_img``).

    Returns:
        A two-element list: the root of the mean squared error between the
        two scaled images as a Python ``float``, followed by the value of
        ``dice_loss.dice_coeff`` for the same pair.
    """
    first, second = [
        read_img(x[key]) / 255 for key in ("input_img", "compare_img")
    ]

    dice_coefficient = dice_loss.dice_coeff(first, second).item()

    # RMSE = sqrt(mean squared error) with MSELoss's default reduction.
    mean_squared = torch.nn.MSELoss()(first, second)
    rmse = torch.sqrt(mean_squared)

    return [float(rmse.data.cpu().numpy()), dice_coefficient]
Esempio n. 28
0
    def eval(self, imgs, true_masks, masks_pred):
        """Return Dice scores for a batch of predictions as a NumPy array.

        With a single-class network (``self.net.n_classes == 1``) the
        prediction is thresholded at 0.5 and one Dice score is returned.
        Otherwise the prediction is soft-maxed over dim 1, converted to a
        hard one-hot assignment, and one Dice score is returned for every
        channel index ``k >= 1`` (channel 0 is skipped), comparing that
        channel against ``true_masks == k``.

        ``imgs`` is accepted but not used by this method.
        """
        if self.net.n_classes == 1:
            binary_pred = (masks_pred > 0.5).float()
            return np.array([dice_coeff(binary_pred, true_masks).item()])

        class_probs = F.softmax(masks_pred, dim=1).data
        winners = torch.argmax(class_probs, 1, keepdim=True)

        # Hard one-hot encoding: 1 in the winning channel, 0 everywhere else.
        hard_pred = torch.zeros(class_probs.shape, dtype=torch.float32,
                                device=self.device)
        hard_pred.scatter_(1, winners, 1)

        per_class = [
            dice_coeff(hard_pred[:, k, :, :],
                       (true_masks == k).float().squeeze(1)).item()
            for k in range(1, hard_pred.shape[1])
        ]
        return np.array(per_class)
Esempio n. 29
0
def eval_net(net, loader, device, nickname, epoch=None, writer=None):
    """Evaluate ``net`` on ``loader`` and dump a debug overlay per batch.

    For a single-class network each batch contributes its
    ``BCEWithLogitsLoss`` and the Dice coefficient of the sigmoid output
    thresholded at 0.5.  For the first sample of every batch an image is
    written to ``debug/<nickname>/<epoch>/<seri>`` showing the input next to
    a copy whose channel 1 holds the prediction and channel 2 the ground
    truth.  When ``epoch`` is a multiple of 50, roughly every tenth batch is
    also sent to ``writer.add_images``.

    Multi-class networks (``net.n_classes > 1``) are not evaluated: their
    batches are run through the network but contribute nothing to the
    metrics and produce no debug output, so the returned metrics are 0.

    Args:
        net: Network exposing ``n_classes``; it is moved to CUDA and left in
            training mode on return.
        loader: Sized iterable of dict batches with the keys ``'image'``,
            ``'mask'`` and ``'seri'`` (``batch['seri'][0]`` is used as the
            output file name).
        device: Device passed to ``Tensor.to`` for every batch.
        nickname: Sub-directory of ``debug/`` for the overlays.
        epoch: Current epoch; part of the output path and gate for the
            TensorBoard image logging.
        writer: Optional object with an ``add_images`` method.

    Returns:
        ``(mean_dice, mean_loss, writer)`` where the means are taken over
        ``len(loader)`` batches (0 for an empty loader).
    """
    # NOTE(review): the network and every batch are forced onto CUDA even
    # though a `device` argument is supplied; kept as-is so device placement
    # seen by callers does not change -- confirm whether `device` should win.
    net.cuda()
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    n_val = len(loader)  # the number of batch
    criterion = nn.BCEWithLogitsLoss()
    tot_dice = 0
    tot_loss = 0

    # Log about ten batches per round; clamp to 1 so loaders with fewer than
    # ten batches do not trigger a modulo-by-zero.
    log_every = max(n_val // 10, 1)
    log_images = epoch is not None and epoch % 50 == 0

    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for sample_idx, batch in enumerate(loader):
            imgs, true_masks, seri = batch['image'], batch['mask'], batch['seri'][0]
            print(seri)
            imgs = imgs.to(device=device, dtype=torch.float32).cuda()
            true_masks = true_masks.to(device=device, dtype=mask_type).cuda()

            with torch.no_grad():
                mask_pred = net(imgs)
            pbar.update()

            if net.n_classes > 1:
                # No multi-class metric is implemented.  Skip the overlay and
                # logging too: both need a binary `pred`, which previously
                # was undefined on this path and raised NameError.
                continue

            tot_loss += criterion(mask_pred, true_masks).item()
            pred = (torch.sigmoid(mask_pred) > 0.5).float()
            tot_dice += dice_coeff(pred, true_masks).item()

            out_path = 'debug/{}/{}/{}'.format(nickname, epoch, seri)
            out_dir = os.path.dirname(out_path)
            print(out_dir)
            os.makedirs(out_dir, exist_ok=True)
            # First sample of the batch, channels moved last for OpenCV.
            canvas_out = imgs.detach().cpu().numpy()[0, :, :, :].transpose(1, 2, 0) * 255
            overlap = canvas_out.copy()
            overlap[:, :, 1] = pred.detach().cpu().numpy()[0, 0, :, :] * 255
            overlap[:, :, 2] = true_masks.detach().cpu().numpy()[0, 0, :, :] * 255
            cv2.imwrite(out_path, np.hstack([canvas_out, overlap]))

            if log_images and sample_idx % log_every == 0:
                if writer is None:
                    print('No writer.')
                else:
                    # Image logging is best-effort: report the real failure
                    # instead of hiding it behind a bare `except`.
                    try:
                        writer.add_images('Images/Val/Image', imgs, sample_idx)
                        writer.add_images('Images/Val/Mask', true_masks, sample_idx)
                        # `pred` is already a 0/1 mask; no second sigmoid.
                        writer.add_images('Images/Val/Pred', pred, sample_idx)
                    except Exception as exc:
                        print('Could not log validation images: {}'.format(exc))

    net.train()
    # Avoid dividing by zero when the loader is empty.
    n_batches = max(n_val, 1)
    return tot_dice / n_batches, tot_loss / n_batches, writer
Esempio n. 30
0
def eval_net(net, dataset, gpu=False):
    """Return the mean per-batch Dice coefficient of ``net`` over ``dataset``.

    Each item of ``dataset`` is unpacked as ``(data, target, _)``.  The
    prediction is the arg-max over dim 1 of the network output, and it is
    compared with the target through ``dice_coeff`` after both are cast to
    float.

    Args:
        net: Network to evaluate; switched to eval mode and left there.
        dataset: Iterable of ``(data, target, extra)`` batches.
        gpu: Not consulted; kept for interface compatibility.

    Returns:
        The sum of the per-batch Dice values divided by the number of
        batches, or 0 when ``dataset`` yields nothing.
    """
    net.eval()
    tot = 0
    n_batches = 0
    # Pure inference: skip building an autograd graph for every batch.
    with torch.no_grad():
        for data, target, _ in dataset:
            # NOTE(review): tensors are always moved to CUDA regardless of
            # `gpu`; kept so existing callers behave the same -- confirm.
            data, true_mask = data.cuda().float(), target.cuda().long()

            # Arg-max over the raw scores.  Thresholding at 0.5 first (as was
            # done before) collapses the scores to 0/1, so arg-max could only
            # pick the first channel above the threshold rather than the
            # highest-scoring one.
            mask_pred = net(data).argmax(dim=1)
            tot += dice_coeff(mask_pred.float(), true_mask.float()).item()
            n_batches += 1

    if n_batches == 0:
        # Previously an empty dataset failed on the unbound loop index.
        return tot
    return tot / n_batches