Example 1
def main(data_path):
    from pathlib import Path
    from dataset.transform import MedicalTransform
    from torch.utils.data import DataLoader, SequentialSampler
    from utils.vis import imshow

    root = Path(data_path)
    transform = MedicalTransform(output_size=512,
                                 roi_error_range=15,
                                 use_roi=True)
    transform.eval()
    dataset = KiTS19(root,
                     stack_num=3,
                     spec_classes=[0, 1, 2],
                     img_size=(512, 512),
                     use_roi=True,
                     roi_file='roi.json',
                     roi_error_range=5,
                     train_transform=transform,
                     valid_transform=None,
                     test_transform=None)

    subset = dataset.train_dataset
    sampler = SequentialSampler(subset)
    data_loader = DataLoader(subset, batch_size=1, sampler=sampler)

    for batch_idx, data in enumerate(data_loader):
        data = dataset.vis_transform(data)
        imgs, labels = data['image'], data['label']
        imshow(title='KiTS19', imgs=(imgs[0][1], labels[0]))
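The snippet assumes the KiTS19 dataset class is importable at module level. A minimal, hypothetical way to invoke it (the import path and the 'data' directory are assumptions, not taken from the original file):

from dataset import KiTS19  # assumed module path; adjust to the repo's actual layout

if __name__ == '__main__':
    main('data')  # directory containing the preprocessed KiTS19 slices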
Example 2
def training(net, dataset, criterion, optimizer, scheduler, epoch, batch_size,
             num_workers, vis_intvl, logger):
    sampler = RandomSampler(dataset.train_dataset)

    train_loader = DataLoader(dataset.train_dataset,
                              batch_size=batch_size,
                              sampler=sampler,
                              num_workers=num_workers,
                              pin_memory=True)

    tbar = tqdm(train_loader, ascii=True, desc='train', dynamic_ncols=True)
    for batch_idx, data in enumerate(tbar):
        imgs, labels = data['image'].cuda(), data['label'].cuda()
        outputs = net(imgs)

        losses = {}
        for key, up_outputs in outputs.items():
            b, c, h, w = up_outputs.shape
            up_labels = torch.unsqueeze(labels.float(), dim=1)
            up_labels = F.interpolate(up_labels, size=(h, w), mode='bilinear')
            up_labels = torch.squeeze(up_labels, dim=1).long()

            up_labels_onehot = class2one_hot(up_labels, 3)
            up_outputs = F.softmax(up_outputs, dim=1)
            up_loss = criterion(up_outputs, up_labels_onehot)
            losses[key] = up_loss

        loss = sum(losses.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if vis_intvl > 0 and batch_idx % vis_intvl == 0:
            data['predict'] = outputs['output']
            data = dataset.vis_transform(data)
            imgs, labels, predicts = data['image'], data['label'], data['predict']
            imshow(title='Train',
                   imgs=(imgs[0, dataset.img_channels // 2], labels[0],
                         predicts[0]),
                   shape=(1, 3),
                   subtitle=('image', 'label', 'predict'))

        losses['total'] = loss
        for k in losses.keys():
            losses[k] = losses[k].item()
        tbar.set_postfix(losses)

    scheduler.step(loss.item())

    for k, v in losses.items():
        logger.add_scalar(f'loss/{k}', v, epoch)

    return loss.item()
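class2one_hot is used above but defined elsewhere in the repo. A minimal sketch of what such a helper typically does in PyTorch; the name and the exact dtype/shape handling are assumptions, not the repo's implementation:

import torch


def class2one_hot_sketch(seg, num_classes):
    # seg: (B, H, W) integer label map -> (B, C, H, W) one-hot float tensor.
    one_hot = torch.zeros((seg.shape[0], num_classes, *seg.shape[1:]),
                          dtype=torch.float32, device=seg.device)
    return one_hot.scatter_(1, seg.long().unsqueeze(dim=1), 1.0)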
Example 3
def training(net, dataset, criterion, optimizer, scheduler, epoch, batch_size,
             num_workers, vis_intvl, logger):
    sampler = RandomSampler(dataset.train_dataset)

    train_loader = DataLoader(dataset.train_dataset,
                              batch_size=batch_size,
                              sampler=sampler,
                              num_workers=num_workers,
                              pin_memory=True)

    tbar = tqdm(train_loader, ascii=True, desc='train', dynamic_ncols=True)
    for batch_idx, data in enumerate(tbar):
        imgs, labels = data['image'].cuda(), data['label'].cuda()
        outputs = net(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if vis_intvl > 0 and batch_idx % vis_intvl == 0:
            data['predict'] = outputs
            data = dataset.vis_transform(data)
            imgs, labels, predicts = data['image'], data['label'], data['predict']
            imshow(title='Train',
                   imgs=(imgs[0, dataset.img_channels // 2], labels[0],
                         predicts[0]),
                   shape=(1, 3),
                   subtitle=('image', 'label', 'predict'))

        tbar.set_postfix(loss=f'{loss.item():.5f}')

    scheduler.step(loss.item())

    logger.add_scalar('loss', loss.item(), epoch)
    return loss.item()
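A hedged sketch of the epoch loop that usually drives a training() function like the one above; the network, criterion, and hyperparameters are placeholders borrowed from the other examples on this page, not the repo's actual entry point:

import torch
from torch.utils.tensorboard import SummaryWriter

net = ResUNet(in_ch=dataset.img_channels, out_ch=dataset.num_classes, base_ch=64).cuda()
criterion = torch.nn.CrossEntropyLoss()  # assumed; the repo may use a Dice-style loss
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
logger = SummaryWriter('runs/kits19')

for epoch in range(1, 51):
    net.train()
    torch.set_grad_enabled(True)
    training(net, dataset, criterion, optimizer, scheduler, epoch,
             batch_size=16, num_workers=4, vis_intvl=0, logger=logger)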
Example 4
def evaluation(net, dataset, epoch, batch_size, num_workers, vis_intvl, logger,
               type):
    type = type.lower()
    if type == 'train':
        subset = dataset.train_dataset
        case_slice_indices = dataset.train_case_slice_indices
    elif type == 'valid':
        subset = dataset.valid_dataset
        case_slice_indices = dataset.valid_case_slice_indices

    sampler = SequentialSampler(subset)
    data_loader = DataLoader(subset,
                             batch_size=batch_size,
                             sampler=sampler,
                             num_workers=num_workers,
                             pin_memory=True)
    evaluator = Evaluator(dataset.num_classes)

    case = 0
    vol_label = []
    vol_output = []

    with tqdm(total=len(case_slice_indices) - 1,
              ascii=True,
              desc=f'eval/{type:5}',
              dynamic_ncols=True) as pbar:
        for batch_idx, data in enumerate(data_loader):
            imgs, labels, idx = data['image'].cuda(), data['label'], data['index']

            outputs = net(imgs)
            outputs = outputs.argmax(dim=1)

            labels = labels.cpu().detach().numpy()
            outputs = outputs.cpu().detach().numpy()
            idx = idx.numpy()

            vol_label.append(labels)
            vol_output.append(outputs)

            # A full case volume is available once the last slice index in this
            # batch reaches the end of the current case; pop it and evaluate.
            while case < len(case_slice_indices) - 1 and idx[-1] >= case_slice_indices[case + 1] - 1:
                vol_output = np.concatenate(vol_output, axis=0)
                vol_label = np.concatenate(vol_label, axis=0)

                vol_num_slice = case_slice_indices[case + 1] - case_slice_indices[case]
                evaluator.add(vol_output[:vol_num_slice],
                              vol_label[:vol_num_slice])

                vol_output = [vol_output[vol_num_slice:]]
                vol_label = [vol_label[vol_num_slice:]]
                case += 1
                pbar.update(1)

            if vis_intvl > 0 and batch_idx % vis_intvl == 0:
                data['predict'] = outputs
                data = dataset.vis_transform(data)
                imgs, labels, predicts = data['image'], data['label'], data['predict']
                imshow(title=f'eval/{type:5}',
                       imgs=(imgs[0, dataset.img_channels // 2], labels[0],
                             predicts[0]),
                       shape=(1, 3),
                       subtitle=('image', 'label', 'predict'))
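The Evaluator above accumulates one whole case volume at a time; its API is not shown in this snippet. A self-contained sketch of the per-class Dice metric such an evaluator typically computes per case (names are hypothetical):

import numpy as np


def dice_per_class(pred, label, num_classes):
    # pred, label: (num_slices, H, W) integer volumes for one case.
    scores = []
    for c in range(1, num_classes):  # skip background (class 0)
        p = pred == c
        l = label == c
        denom = p.sum() + l.sum()
        scores.append(2.0 * np.logical_and(p, l).sum() / denom if denom > 0 else 1.0)
    return scores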
Example 5
def main(batch_size, num_gpu, img_size, data_path, resume, output_path, vis_intvl, num_workers):
    data_path = Path(data_path)
    output_path = Path(output_path)
    if not output_path.exists():
        output_path.mkdir(parents=True)
    
    roi_error_range = 15
    transform = MedicalTransform(output_size=img_size, roi_error_range=roi_error_range, use_roi=True)
    
    dataset = KiTS19(data_path, stack_num=3, spec_classes=[0, 1, 2], img_size=img_size,
                     use_roi=True, roi_file='roi.json', roi_error_range=5, test_transform=transform)
    
    net = DenseUNet(in_ch=dataset.img_channels, out_ch=dataset.num_classes)
    
    if resume:
        data = {'net': net}
        cp_file = Path(resume)
        cp.load_params(data, cp_file, device='cpu')
    
    gpu_ids = [i for i in range(num_gpu)]
    
    print(f'{" Start evaluation ":-^40s}\n')
    msg = f'Net: {net.__class__.__name__}\n' + \
          f'Dataset: {dataset.__class__.__name__}\n' + \
          f'Batch size: {batch_size}\n' + \
          f'Device: cuda{str(gpu_ids)}\n'
    print(msg)
    
    torch.cuda.empty_cache()
    
    net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()
    
    net.eval()
    torch.set_grad_enabled(False)
    transform.eval()
    
    subset = dataset.test_dataset
    case_slice_indices = dataset.test_case_slice_indices
    
    sampler = SequentialSampler(subset)
    data_loader = DataLoader(subset, batch_size=batch_size, sampler=sampler,
                             num_workers=num_workers, pin_memory=True)
    
    case = 0
    vol_output = []
    
    with tqdm(total=len(case_slice_indices) - 1, ascii=True, desc=f'eval/test', dynamic_ncols=True) as pbar:
        for batch_idx, data in enumerate(data_loader):
            imgs, idx = data['image'].cuda(), data['index']
            
            outputs = net(imgs)
            predicts = outputs['output']
            predicts = predicts.argmax(dim=1)
            
            predicts = predicts.cpu().detach().numpy()
            idx = idx.numpy()
            
            vol_output.append(predicts)
            
            while case < len(case_slice_indices) - 1 and idx[-1] >= case_slice_indices[case + 1] - 1:
                vol_output = np.concatenate(vol_output, axis=0)
                vol_num_slice = case_slice_indices[case + 1] - case_slice_indices[case]
                
                roi = dataset.get_roi(case, type='test')
                vol = vol_output[:vol_num_slice]
                vol_ = reverse_transform(vol, roi, dataset, transform)
                vol_ = vol_.astype(np.uint8)
                
                case_id = dataset.case_idx_to_case_id(case, type='test')
                affine = np.load(data_path / f'case_{case_id:05d}' / 'affine.npy')
                vol_nii = nib.Nifti1Image(vol_, affine)
                vol_nii_filename = output_path / f'prediction_{case_id:05d}.nii.gz'
                vol_nii.to_filename(str(vol_nii_filename))
                
                vol_output = [vol_output[vol_num_slice:]]
                case += 1
                pbar.update(1)
            
            if vis_intvl > 0 and batch_idx % vis_intvl == 0:
                data['predict'] = predicts
                data = dataset.vis_transform(data)
                imgs, predicts = data['image'], data['predict']
                imshow(title=f'eval/test', imgs=(imgs[0, 1], predicts[0]), shape=(1, 2),
                       subtitle=('image', 'predict'))
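A quick sanity check on one of the exported predictions with nibabel; the output directory and case id below are placeholders:

import nibabel as nib
import numpy as np

pred_nii = nib.load('out/prediction_00000.nii.gz')
pred = np.asarray(pred_nii.dataobj).astype(np.uint8)
print(pred.shape, np.unique(pred))  # expect labels 0 (background), 1 (kidney), 2 (tumor)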
Example 6
def get_roi_from_resunet(batch_size, num_gpu, img_size, data_path, resume, roi_file, vis_intvl, num_workers):
    with open(roi_file, 'r') as f:
        rois = json.load(f)
    
    data_path = Path(data_path)
    
    transform = MedicalTransform(output_size=img_size, roi_error_range=15, use_roi=False)
    
    dataset = KiTS19(data_path, stack_num=5, spec_classes=[0, 1, 1], img_size=img_size,
                     use_roi=False, train_transform=transform, valid_transform=transform)
    
    net = ResUNet(in_ch=dataset.img_channels, out_ch=dataset.num_classes, base_ch=64)
    
    if resume:
        data = {'net': net}
        cp_file = Path(resume)
        cp.load_params(data, cp_file, device='cpu')
    
    gpu_ids = [i for i in range(num_gpu)]
    
    torch.cuda.empty_cache()
    
    net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()
    
    net.eval()
    torch.set_grad_enabled(False)
    transform.eval()
    
    subset = dataset.test_dataset
    case_slice_indices = dataset.test_case_slice_indices
    
    sampler = SequentialSampler(subset)
    data_loader = DataLoader(subset, batch_size=batch_size, sampler=sampler,
                             num_workers=num_workers, pin_memory=True)
    
    case = 0
    vol_output = []
    
    with tqdm(total=len(case_slice_indices) - 1, ascii=True, desc=f'eval/test', dynamic_ncols=True) as pbar:
        for batch_idx, data in enumerate(data_loader):
            imgs, idx = data['image'].cuda(), data['index']
            
            predicts = net(imgs)
            predicts = predicts.argmax(dim=1)
            
            predicts = predicts.cpu().detach().numpy()
            idx = idx.numpy()
            
            vol_output.append(predicts)
            
            while case < len(case_slice_indices) - 1 and idx[-1] >= case_slice_indices[case + 1] - 1:
                vol_output = np.concatenate(vol_output, axis=0)
                vol_num_slice = case_slice_indices[case + 1] - case_slice_indices[case]
                
                vol = vol_output[:vol_num_slice]
                kidney = calc(vol, idx=1)
                case_roi = {'kidney': kidney}
                case_id = dataset.case_idx_to_case_id(case, 'test')
                rois[f'case_{case_id:05d}'].update(case_roi)
                with open(roi_file, 'w') as f:
                    json.dump(rois, f, indent=4, separators=(',', ': '))
                
                vol_output = [vol_output[vol_num_slice:]]
                case += 1
                pbar.update(1)
            
            if vis_intvl > 0 and batch_idx % vis_intvl == 0:
                data['predict'] = predicts
                data = dataset.vis_transform(data)
                imgs, predicts = data['image'], data['predict']
                imshow(title=f'eval/test', imgs=(imgs[0, 1], predicts[0]), shape=(1, 2),
                       subtitle=('image', 'predict'))
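calc() above reduces a predicted volume to the kidney ROI entry written into roi.json. A sketch of the bounding-box computation such a helper could perform; the key names and the absence of any margin handling are assumptions:

import numpy as np


def roi_bbox_sketch(vol, idx=1):
    # vol: (num_slices, H, W) integer volume; returns min/max indices of voxels == idx.
    zs, ys, xs = np.where(vol == idx)
    if zs.size == 0:
        return {}
    return {'min_z': int(zs.min()), 'max_z': int(zs.max()),
            'min_y': int(ys.min()), 'max_y': int(ys.max()),
            'min_x': int(xs.min()), 'max_x': int(xs.max())}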
Example 7
        # overlay = np.zeros_like(label)
        # overlay[np.logical_and(label_boundary, pred_boundary)] = 255

        img_boundary = img.copy()
        img_boundary[pred_boundary == 255] = [255, 0, 0]
        img_boundary[label_boundary == 255] = [0, 255, 0]
        # img_boundary[overlay == 255] = [0, 0, 255]

        imgs.append(img_boundary)
    
    return imgs


if __name__ == '__main__':
    from utils.vis import imshow
    
    root = Path('data')
    cases = sorted([d for d in root.iterdir() if d.is_dir()])
    for case in cases:
        img_path = case / 'imaging'
        seg_path = case / 'segmentation'
        num_slice = len(list(img_path.glob('*.npy')))
        for i in range(num_slice // 2, num_slice // 3 * 2):
            img_file = img_path / f'{i:03d}.npy'
            seg_file = seg_path / f'{i:03d}.npy'
            img = np.load(str(img_file))
            seg = np.load(str(seg_file))
            vis_img = vis_boundary(img, seg, seg, 3)
            imshow('vis1', vis_img[0])
            imshow('vis2', vis_img[1])
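pred_boundary and label_boundary in the fragment above are 255/0 boundary images. A sketch of how such a mask can be produced with a morphological erosion; the repo's vis_boundary may compute its boundaries differently:

import numpy as np
from scipy import ndimage


def boundary_mask_sketch(mask, width=3):
    # mask: binary (H, W) array; returns a uint8 image with 255 on the boundary band.
    mask = mask > 0
    eroded = ndimage.binary_erosion(mask, iterations=width)
    return ((mask & ~eroded) * 255).astype(np.uint8)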
Example 8
        if posmask.any():
            negmask = ~posmask
            res[:, c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    
    return res


if __name__ == '__main__':
    a = np.load('../data/case_00000/segmentation/296.npy')
    b = np_class2one_hot(a, 3)
    c = one_hot2dist(b)[0]
    d = b[0].copy()
    d[0, ...] = 0
    d[1, ...] = 0
    d[2, ...] = 1
    
    multipled = np.einsum('cwh,cwh->cwh', d, c)
    
    loss = multipled.mean()
    
    d = (c + 437) / 3.44 / 255
    from utils.vis import imshow
    
    # imshow('b', b[0].transpose((1, 2, 0)) * 128, (1, 1))
    # imshow('c', c.transpose((1, 2, 0)), (1, 1))
    imshow('c0', c[0], (1, 1))
    imshow('c1', c[1], (1, 1))
    imshow('c2', c[2], (1, 1))
    # imshow('d', d.transpose((1, 2, 0)), (1, 1))
    ...
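np_class2one_hot and the distance function used inside one_hot2dist are defined elsewhere in the repo (distance is presumably scipy.ndimage.distance_transform_edt). A sketch of a NumPy one-hot helper consistent with how b is indexed above; the shape and dtype handling are assumptions:

import numpy as np


def np_class2one_hot_sketch(seg, num_classes):
    # seg: (H, W) integer label map -> (1, C, H, W) one-hot array.
    one_hot = np.stack([seg == c for c in range(num_classes)], axis=0)
    return one_hot[np.newaxis, ...].astype(np.int64)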