def main(epoch_num, batch_size, lr, num_gpu, img_size, data_path, log_path,
         resume, eval_intvl, cp_intvl, vis_intvl, num_workers):
    """Train a ResUNet on KiTS19, logging to TensorBoard and saving checkpoints.

    Args:
        epoch_num: Total number of epochs; epochs ``start_epoch .. epoch_num - 1``
            are run (``start_epoch`` is 0 unless resuming).
        batch_size: Mini-batch size forwarded to ``training`` / ``evaluation``.
        lr: Initial Adam learning rate. When resuming, ``cp.load_params`` may
            restore the optimizer state (including its lr) — TODO confirm.
        num_gpu: Number of GPUs; device ids ``0 .. num_gpu - 1`` are used.
        img_size: Output size passed to ``MedicalTransform`` and ``KiTS19``.
        data_path: Root directory of the preprocessed dataset.
        log_path: TensorBoard log directory; checkpoints are written to
            ``log_path / 'checkpoint'``.
        resume: Path of a checkpoint to resume from, or a falsy value to start
            from scratch.
        eval_intvl: Evaluate every ``eval_intvl`` epochs; ``<= 0`` disables it.
        cp_intvl: Save ``cp_NNN.pth`` every ``cp_intvl`` epochs; ``<= 0``
            disables periodic checkpoints.
        vis_intvl: Visualization interval forwarded to training/evaluation.
        num_workers: DataLoader worker count forwarded to training/evaluation.

    When not resuming and ``log_path`` already contains files, prints an error
    and exits the process with status -1. On KeyboardInterrupt during an epoch,
    saves ``INTERRUPTED.pth`` and returns.
    """
    data_path = Path(data_path)
    log_path = Path(log_path)
    cp_path = log_path / 'checkpoint'

    # Refuse to mix a fresh run's logs with files from an earlier run.
    if not resume and log_path.exists() and any(log_path.glob('*')):
        print(f'log path "{str(log_path)}" has old file', file=sys.stderr)
        sys.exit(-1)
    if not cp_path.exists():
        cp_path.mkdir(parents=True)

    transform = MedicalTransform(output_size=img_size,
                                 roi_error_range=15,
                                 use_roi=False)

    dataset = KiTS19(data_path,
                     stack_num=5,
                     spec_classes=[0, 1, 1],
                     img_size=img_size,
                     use_roi=False,
                     train_transform=transform,
                     valid_transform=transform)

    net = ResUNet(in_ch=dataset.img_channels,
                  out_ch=dataset.num_classes,
                  base_ch=64)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    start_epoch = 0
    if resume:
        # Load onto the CPU first; tensors are moved to the GPU further below.
        data = {'net': net, 'optimizer': optimizer, 'epoch': 0}
        cp_file = Path(resume)
        cp.load_params(data, cp_file, device='cpu')
        # presumably load_params overwrites data['epoch'] with the saved epoch — confirm
        start_epoch = data['epoch'] + 1

    criterion = nn.CrossEntropyLoss()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5,
        verbose=True,
        threshold=0.0001,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        eps=1e-08)

    logger = SummaryWriter(str(log_path))
    # Ensure the event file is flushed and closed on every exit path,
    # including the KeyboardInterrupt early return and unexpected errors.
    try:
        gpu_ids = list(range(num_gpu))

        print(f'{" Start training ":-^40s}\n')
        msg = f'Net: {net.__class__.__name__}\n' + \
              f'Dataset: {dataset.__class__.__name__}\n' + \
              f'Epochs: {epoch_num}\n' + \
              f'Learning rate: {optimizer.param_groups[0]["lr"]}\n' + \
              f'Batch size: {batch_size}\n' + \
              f'Device: cuda{str(gpu_ids)}\n'
        print(msg)

        torch.cuda.empty_cache()

        # Move the model, the loss, and any optimizer state restored from a
        # CPU checkpoint onto the GPU.
        net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()
        criterion = criterion.cuda()
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

        # start training
        valid_score = 0.0
        best_score = 0.0
        best_epoch = 0

        for epoch in range(start_epoch, epoch_num):
            epoch_str = f' Epoch {epoch + 1}/{epoch_num} '
            print(f'{epoch_str:-^40s}')
            print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')

            net.train()
            torch.set_grad_enabled(True)
            transform.train()
            try:
                training(net, dataset, criterion, optimizer, scheduler,
                         epoch, batch_size, num_workers, vis_intvl, logger)

                if eval_intvl > 0 and (epoch + 1) % eval_intvl == 0:
                    net.eval()
                    torch.set_grad_enabled(False)
                    transform.eval()

                    train_score = evaluation(net,
                                             dataset,
                                             epoch,
                                             batch_size,
                                             num_workers,
                                             vis_intvl,
                                             logger,
                                             type='train')
                    valid_score = evaluation(net,
                                             dataset,
                                             epoch,
                                             batch_size,
                                             num_workers,
                                             vis_intvl,
                                             logger,
                                             type='valid')

                    print(f'Train data score: {train_score:.5f}')
                    print(f'Valid data score: {valid_score:.5f}')

                # On epochs without evaluation valid_score keeps its last
                # value, which can never exceed best_score again.
                if valid_score > best_score:
                    best_score = valid_score
                    best_epoch = epoch
                    cp_file = cp_path / 'best.pth'
                    cp.save(epoch, net.module, optimizer, str(cp_file))
                    print('Update best acc!')
                    logger.add_scalar('best/epoch', best_epoch + 1, 0)
                    logger.add_scalar('best/score', best_score, 0)

                # Guard like eval_intvl: a zero interval would otherwise
                # raise ZeroDivisionError.
                if cp_intvl > 0 and (epoch + 1) % cp_intvl == 0:
                    cp_file = cp_path / f'cp_{epoch + 1:03d}.pth'
                    cp.save(epoch, net.module, optimizer, str(cp_file))

                print(f'Best epoch: {best_epoch + 1}')
                print(f'Best score: {best_score:.5f}')

            except KeyboardInterrupt:
                cp_file = cp_path / 'INTERRUPTED.pth'
                cp.save(epoch, net.module, optimizer, str(cp_file))
                return
    finally:
        logger.close()
Пример #2
0
def get_roi_from_resunet(batch_size, num_gpu, img_size, data_path, resume, roi_file, vis_intvl, num_workers):
    """Run a ResUNet over the KiTS19 test split and store per-case ROIs in ``roi_file``.

    ``roi_file`` must already be a JSON object containing a ``case_XXXXX``
    entry for every test case. Each entry is updated with a ``'kidney'`` key
    holding ``calc(vol, idx=1)`` computed on that case's argmax label volume,
    and the whole file is rewritten after each case so progress survives an
    interruption.

    Args:
        batch_size: Slices per inference batch.
        num_gpu: Number of GPUs; device ids ``0 .. num_gpu - 1`` are used.
        img_size: Output size for ``MedicalTransform`` / ``KiTS19``.
        data_path: Root directory of the preprocessed dataset.
        resume: Checkpoint whose ``'net'`` weights are loaded; falsy keeps the
            freshly initialised network.
        roi_file: JSON file read at start and rewritten after every case.
        vis_intvl: Show a visualization every ``vis_intvl`` batches;
            ``<= 0`` disables it.
        num_workers: DataLoader worker processes.

    Raises:
        KeyError: if ``roi_file`` has no entry for a processed test case.
    """
    import os

    with open(roi_file, 'r') as f:
        rois = json.load(f)

    def _save_rois():
        # Dump to a sibling temp file and atomically swap it in: opening
        # roi_file with 'w' truncates it first, so an interruption mid-write
        # would otherwise destroy the ROIs of every case.
        tmp_file = f'{roi_file}.tmp'
        with open(tmp_file, 'w') as tmp:
            json.dump(rois, tmp, indent=4, separators=(',', ': '))
        os.replace(tmp_file, roi_file)

    data_path = Path(data_path)

    transform = MedicalTransform(output_size=img_size, roi_error_range=15, use_roi=False)

    dataset = KiTS19(data_path, stack_num=5, spec_classes=[0, 1, 1], img_size=img_size,
                     use_roi=False, train_transform=transform, valid_transform=transform)

    net = ResUNet(in_ch=dataset.img_channels, out_ch=dataset.num_classes, base_ch=64)

    if resume:
        data = {'net': net}
        cp_file = Path(resume)
        cp.load_params(data, cp_file, device='cpu')

    gpu_ids = list(range(num_gpu))

    torch.cuda.empty_cache()

    net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()

    net.eval()
    torch.set_grad_enabled(False)
    transform.eval()

    subset = dataset.test_dataset
    case_slice_indices = dataset.test_case_slice_indices

    sampler = SequentialSampler(subset)
    data_loader = DataLoader(subset, batch_size=batch_size, sampler=sampler,
                             num_workers=num_workers, pin_memory=True)

    case = 0
    vol_output = []

    with tqdm(total=len(case_slice_indices) - 1, ascii=True, desc='eval/test', dynamic_ncols=True) as pbar:
        for batch_idx, data in enumerate(data_loader):
            imgs, idx = data['image'].cuda(), data['index']

            predicts = net(imgs)
            predicts = predicts.argmax(dim=1)

            predicts = predicts.cpu().detach().numpy()
            idx = idx.numpy()

            vol_output.append(predicts)

            # A batch may finish one or more cases: flush each completed case
            # and carry the leftover slices over to the next one.
            while case < len(case_slice_indices) - 1 and idx[-1] >= case_slice_indices[case + 1] - 1:
                vol_output = np.concatenate(vol_output, axis=0)
                vol_num_slice = case_slice_indices[case + 1] - case_slice_indices[case]

                vol = vol_output[:vol_num_slice]
                kidney = calc(vol, idx=1)
                case_roi = {'kidney': kidney}
                case_id = dataset.case_idx_to_case_id(case, 'test')
                rois[f'case_{case_id:05d}'].update(case_roi)
                _save_rois()

                vol_output = [vol_output[vol_num_slice:]]
                case += 1
                pbar.update(1)

            if vis_intvl > 0 and batch_idx % vis_intvl == 0:
                data['predict'] = predicts
                data = dataset.vis_transform(data)
                imgs, predicts = data['image'], data['predict']
                imshow(title='eval/test', imgs=(imgs[0, 1], predicts[0]), shape=(1, 2),
                       subtitle=('image', 'predict'))
Пример #3
0
def main(batch_size, num_gpu, img_size, data_path, resume, output_path, vis_intvl, num_workers):
    """Run DenseUNet inference on the KiTS19 test split and export NIfTI label maps.

    For each test case the argmax predictions of all its slices are gathered,
    passed through ``reverse_transform`` together with the case ROI, cast to
    ``uint8`` and written to ``output_path / 'prediction_XXXXX.nii.gz'`` using
    the affine loaded from ``data_path / 'case_XXXXX' / 'affine.npy'``.

    Args:
        batch_size: Slices per inference batch.
        num_gpu: Number of GPUs; device ids ``0 .. num_gpu - 1`` are used.
        img_size: Output size for ``MedicalTransform`` / ``KiTS19``.
        data_path: Root directory of the preprocessed dataset.
        resume: Checkpoint whose ``'net'`` weights are loaded; falsy keeps the
            freshly initialised network.
        output_path: Directory for the exported predictions (created if missing).
        vis_intvl: Show a visualization every ``vis_intvl`` batches;
            ``<= 0`` disables it.
        num_workers: DataLoader worker processes.
    """
    data_path = Path(data_path)
    output_path = Path(output_path)
    if not output_path.exists():
        output_path.mkdir(parents=True)

    # roi_error_range for the transform; note KiTS19 below is given a
    # different value (5).
    transform_roi_range = 15
    transform = MedicalTransform(output_size=img_size, roi_error_range=transform_roi_range, use_roi=True)

    dataset = KiTS19(data_path, stack_num=3, spec_classes=[0, 1, 2], img_size=img_size,
                     use_roi=True, roi_file='roi.json', roi_error_range=5, test_transform=transform)

    net = DenseUNet(in_ch=dataset.img_channels, out_ch=dataset.num_classes)

    if resume:
        checkpoint = {'net': net}
        cp.load_params(checkpoint, Path(resume), device='cpu')

    gpu_ids = list(range(num_gpu))

    print(f'{" Start evaluation ":-^40s}\n')
    summary_lines = [
        f'Net: {net.__class__.__name__}',
        f'Dataset: {dataset.__class__.__name__}',
        f'Batch size: {batch_size}',
        f'Device: cuda{str(gpu_ids)}',
    ]
    print('\n'.join(summary_lines) + '\n')

    torch.cuda.empty_cache()

    net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()
    net.eval()
    torch.set_grad_enabled(False)
    transform.eval()

    subset = dataset.test_dataset
    case_slice_indices = dataset.test_case_slice_indices
    num_cases = len(case_slice_indices) - 1
    data_loader = DataLoader(subset, batch_size=batch_size, sampler=SequentialSampler(subset),
                             num_workers=num_workers, pin_memory=True)

    def _export_case(case_idx, case_vol):
        # Map one case's predictions back via reverse_transform (presumably
        # undoing the ROI crop/resize — confirm) and save it as NIfTI.
        roi = dataset.get_roi(case_idx, type='test')
        labels = reverse_transform(case_vol, roi, dataset, transform).astype(np.uint8)
        case_id = dataset.case_idx_to_case_id(case_idx, type='test')
        affine = np.load(data_path / f'case_{case_id:05d}' / 'affine.npy')
        target = output_path / f'prediction_{case_id:05d}.nii.gz'
        nib.Nifti1Image(labels, affine).to_filename(str(target))

    case = 0
    pending = []

    with tqdm(total=num_cases, ascii=True, desc='eval/test', dynamic_ncols=True) as pbar:
        for batch_idx, batch in enumerate(data_loader):
            imgs = batch['image'].cuda()
            idx = batch['index'].numpy()

            predicts = net(imgs)['output'].argmax(dim=1).cpu().detach().numpy()
            pending.append(predicts)

            # Flush every case whose final slice has now been predicted; any
            # surplus slices belong to the next case and stay pending.
            while case < num_cases and idx[-1] >= case_slice_indices[case + 1] - 1:
                stacked = np.concatenate(pending, axis=0)
                n_slices = case_slice_indices[case + 1] - case_slice_indices[case]
                _export_case(case, stacked[:n_slices])
                pending = [stacked[n_slices:]]
                case += 1
                pbar.update(1)

            if vis_intvl > 0 and batch_idx % vis_intvl == 0:
                batch['predict'] = predicts
                shown = dataset.vis_transform(batch)
                imshow(title='eval/test', imgs=(shown['image'][0, 1], shown['predict'][0]), shape=(1, 2),
                       subtitle=('image', 'predict'))
Пример #4
0
def conversion_all(data, output):
    """Convert every case directory under ``data`` in parallel with ``conversion``.

    Args:
        data: Directory whose immediate sub-directories are the case folders;
            they are processed in sorted order.
        output: Root output directory, passed unchanged to every worker.
    """
    data = Path(data)
    output = Path(output)

    cases = sorted(d for d in data.iterdir() if d.is_dir())
    # ``conversion`` takes a single ``(case_dir, output_dir)`` tuple.
    jobs = [(case, output) for case in cases]
    # The context manager terminates the workers even when ``map`` re-raises
    # a worker's exception; on success close()/join() shut them down cleanly.
    with mp.Pool() as pool:
        pool.map(conversion, jobs)
        pool.close()
        pool.join()


def conversion(data):
    case, output = data
    vol_nii = nib.load(str(case / 'imaging.nii.gz'))
    vol = vol_nii.get_data()
    vol = KiTS19.normalize(vol)

    imaging_dir = output / case.name / 'imaging'
    if not imaging_dir.exists():
        imaging_dir.mkdir(parents=True)
    if len(list(imaging_dir.glob('*.npy'))) != vol.shape[0]:
        for i in range(vol.shape[0]):
            np.save(str(imaging_dir / f'{i:03}.npy'), vol[i])

    segmentation_file = case / 'segmentation.nii.gz'
    if segmentation_file.exists():
        seg = nib.load(str(case / 'segmentation.nii.gz')).get_data()
        segmentation_dir = output / case.name / 'segmentation'
        if not segmentation_dir.exists():
            segmentation_dir.mkdir(parents=True)
        if len(list(segmentation_dir.glob('*.npy'))) != seg.shape[0]: