def main(epoch_num, batch_size, lr, num_gpu, img_size, data_path, log_path, resume,
         eval_intvl, cp_intvl, vis_intvl, num_workers):
    data_path = Path(data_path)
    log_path = Path(log_path)
    cp_path = log_path / 'checkpoint'

    # Refuse to reuse a non-empty log directory unless we are resuming
    if not resume and log_path.exists() and len(list(log_path.glob('*'))) > 0:
        print(f'log path "{str(log_path)}" has old file', file=sys.stderr)
        sys.exit(-1)
    if not cp_path.exists():
        cp_path.mkdir(parents=True)

    transform = MedicalTransform(output_size=img_size, roi_error_range=15, use_roi=False)
    dataset = KiTS19(data_path, stack_num=5, spec_classes=[0, 1, 1], img_size=img_size,
                     use_roi=False, train_transform=transform, valid_transform=transform)

    net = ResUNet(in_ch=dataset.img_channels, out_ch=dataset.num_classes, base_ch=64)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    start_epoch = 0
    if resume:
        # Restore network and optimizer state from the given checkpoint
        data = {'net': net, 'optimizer': optimizer, 'epoch': 0}
        cp_file = Path(resume)
        cp.load_params(data, cp_file, device='cpu')
        start_epoch = data['epoch'] + 1

    criterion = nn.CrossEntropyLoss()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=5, verbose=True,
        threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

    logger = SummaryWriter(str(log_path))

    gpu_ids = [i for i in range(num_gpu)]

    print(f'{" Start training ":-^40s}\n')
    msg = f'Net: {net.__class__.__name__}\n' + \
          f'Dataset: {dataset.__class__.__name__}\n' + \
          f'Epochs: {epoch_num}\n' + \
          f'Learning rate: {optimizer.param_groups[0]["lr"]}\n' + \
          f'Batch size: {batch_size}\n' + \
          f'Device: cuda{str(gpu_ids)}\n'
    print(msg)

    torch.cuda.empty_cache()

    # to GPU device
    net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()
    criterion = criterion.cuda()
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.cuda()

    # start training
    valid_score = 0.0
    best_score = 0.0
    best_epoch = 0

    for epoch in range(start_epoch, epoch_num):
        epoch_str = f' Epoch {epoch + 1}/{epoch_num} '
        print(f'{epoch_str:-^40s}')
        print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')

        net.train()
        torch.set_grad_enabled(True)
        transform.train()
        try:
            loss = training(net, dataset, criterion, optimizer, scheduler,
                            epoch, batch_size, num_workers, vis_intvl, logger)

            if eval_intvl > 0 and (epoch + 1) % eval_intvl == 0:
                net.eval()
                torch.set_grad_enabled(False)
                transform.eval()

                train_score = evaluation(net, dataset, epoch, batch_size, num_workers,
                                         vis_intvl, logger, type='train')
                valid_score = evaluation(net, dataset, epoch, batch_size, num_workers,
                                         vis_intvl, logger, type='valid')

                print(f'Train data score: {train_score:.5f}')
                print(f'Valid data score: {valid_score:.5f}')

            # Track the best validation score seen so far and keep its checkpoint
            if valid_score > best_score:
                best_score = valid_score
                best_epoch = epoch
                cp_file = cp_path / 'best.pth'
                cp.save(epoch, net.module, optimizer, str(cp_file))
                print('Update best acc!')

            logger.add_scalar('best/epoch', best_epoch + 1, 0)
            logger.add_scalar('best/score', best_score, 0)

            # Periodic checkpoint regardless of score
            if (epoch + 1) % cp_intvl == 0:
                cp_file = cp_path / f'cp_{epoch + 1:03d}.pth'
                cp.save(epoch, net.module, optimizer, str(cp_file))

            print(f'Best epoch: {best_epoch + 1}')
            print(f'Best score: {best_score:.5f}')

        except KeyboardInterrupt:
            cp_file = cp_path / 'INTERRUPTED.pth'
            cp.save(epoch, net.module, optimizer, str(cp_file))
            return
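# If this training function lives in its own script, a minimal command-line entry point
# might look like the sketch below. It is not part of the original code: the flag names
# and default values are assumptions chosen only to mirror main()'s signature, and the
# real project may wire these arguments through its own CLI instead.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train ResUNet on KiTS19 (coarse stage)')
    parser.add_argument('--epoch', type=int, default=50, dest='epoch_num')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--num-gpu', type=int, default=1)
    parser.add_argument('--img-size', type=int, default=512)
    parser.add_argument('--data', default='data', dest='data_path')
    parser.add_argument('--log', default='runs', dest='log_path')
    parser.add_argument('--resume', default=None)
    parser.add_argument('--eval-intvl', type=int, default=1)
    parser.add_argument('--cp-intvl', type=int, default=1)
    parser.add_argument('--vis-intvl', type=int, default=0)
    parser.add_argument('--num-workers', type=int, default=4)
    args = parser.parse_args()

    main(**vars(args))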
def get_roi_from_resunet(batch_size, num_gpu, img_size, data_path, resume, roi_file,
                         vis_intvl, num_workers):
    with open(roi_file, 'r') as f:
        rois = json.load(f)

    data_path = Path(data_path)

    transform = MedicalTransform(output_size=img_size, roi_error_range=15, use_roi=False)
    dataset = KiTS19(data_path, stack_num=5, spec_classes=[0, 1, 1], img_size=img_size,
                     use_roi=False, train_transform=transform, valid_transform=transform)

    net = ResUNet(in_ch=dataset.img_channels, out_ch=dataset.num_classes, base_ch=64)

    if resume:
        data = {'net': net}
        cp_file = Path(resume)
        cp.load_params(data, cp_file, device='cpu')

    gpu_ids = [i for i in range(num_gpu)]

    torch.cuda.empty_cache()
    net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()

    net.eval()
    torch.set_grad_enabled(False)
    transform.eval()

    subset = dataset.test_dataset
    case_slice_indices = dataset.test_case_slice_indices

    sampler = SequentialSampler(subset)
    data_loader = DataLoader(subset, batch_size=batch_size, sampler=sampler,
                             num_workers=num_workers, pin_memory=True)

    case = 0
    vol_output = []

    with tqdm(total=len(case_slice_indices) - 1, ascii=True, desc=f'eval/test',
              dynamic_ncols=True) as pbar:
        for batch_idx, data in enumerate(data_loader):
            imgs, idx = data['image'].cuda(), data['index']
            predicts = net(imgs)
            predicts = predicts.argmax(dim=1)
            predicts = predicts.cpu().detach().numpy()
            idx = idx.numpy()

            vol_output.append(predicts)

            # When the last slice index of this batch reaches the end of the current case,
            # assemble the full predicted volume, compute its kidney ROI and flush it to the JSON file
            while case < len(case_slice_indices) - 1 and idx[-1] >= case_slice_indices[case + 1] - 1:
                vol_output = np.concatenate(vol_output, axis=0)
                vol_num_slice = case_slice_indices[case + 1] - case_slice_indices[case]

                vol = vol_output[:vol_num_slice]
                kidney = calc(vol, idx=1)
                case_roi = {'kidney': kidney}

                case_id = dataset.case_idx_to_case_id(case, 'test')
                rois[f'case_{case_id:05d}'].update(case_roi)
                with open(roi_file, 'w') as f:
                    json.dump(rois, f, indent=4, separators=(',', ': '))

                # Keep the slices that already belong to the next case
                vol_output = [vol_output[vol_num_slice:]]
                case += 1
                pbar.update(1)

            if vis_intvl > 0 and batch_idx % vis_intvl == 0:
                data['predict'] = predicts
                data = dataset.vis_transform(data)
                imgs, predicts = data['image'], data['predict']
                imshow(title=f'eval/test', imgs=(imgs[0, 1], predicts[0]), shape=(1, 2),
                       subtitle=('image', 'predict'))
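# For reference, a sketch of what a ROI helper like calc(vol, idx=1) typically computes in
# this kind of coarse-to-fine pipeline: the bounding box of a given label in the predicted
# volume. This is a hypothetical stand-in, not the repo's actual calc(); the real function
# may use a different return format or add safety margins.
def calc_roi_bbox(vol, idx=1):
    """Return the min/max slice, row and column containing label `idx`, or None if absent."""
    import numpy as np  # local import so the sketch is self-contained

    z, y, x = np.where(vol == idx)
    if z.size == 0:
        return None
    return {
        'min_z': int(z.min()), 'max_z': int(z.max()),
        'min_y': int(y.min()), 'max_y': int(y.max()),
        'min_x': int(x.min()), 'max_x': int(x.max()),
    }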
def main(batch_size, num_gpu, img_size, data_path, resume, output_path, vis_intvl, num_workers):
    data_path = Path(data_path)
    output_path = Path(output_path)
    if not output_path.exists():
        output_path.mkdir(parents=True)

    roi_error_range = 15
    transform = MedicalTransform(output_size=img_size, roi_error_range=roi_error_range, use_roi=True)

    dataset = KiTS19(data_path, stack_num=3, spec_classes=[0, 1, 2], img_size=img_size,
                     use_roi=True, roi_file='roi.json', roi_error_range=5,
                     test_transform=transform)

    net = DenseUNet(in_ch=dataset.img_channels, out_ch=dataset.num_classes)

    if resume:
        data = {'net': net}
        cp_file = Path(resume)
        cp.load_params(data, cp_file, device='cpu')

    gpu_ids = [i for i in range(num_gpu)]

    print(f'{" Start evaluation ":-^40s}\n')
    msg = f'Net: {net.__class__.__name__}\n' + \
          f'Dataset: {dataset.__class__.__name__}\n' + \
          f'Batch size: {batch_size}\n' + \
          f'Device: cuda{str(gpu_ids)}\n'
    print(msg)

    torch.cuda.empty_cache()
    net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()

    net.eval()
    torch.set_grad_enabled(False)
    transform.eval()

    subset = dataset.test_dataset
    case_slice_indices = dataset.test_case_slice_indices

    sampler = SequentialSampler(subset)
    data_loader = DataLoader(subset, batch_size=batch_size, sampler=sampler,
                             num_workers=num_workers, pin_memory=True)

    case = 0
    vol_output = []

    with tqdm(total=len(case_slice_indices) - 1, ascii=True, desc=f'eval/test',
              dynamic_ncols=True) as pbar:
        for batch_idx, data in enumerate(data_loader):
            imgs, idx = data['image'].cuda(), data['index']
            outputs = net(imgs)
            predicts = outputs['output']
            predicts = predicts.argmax(dim=1)
            predicts = predicts.cpu().detach().numpy()
            idx = idx.numpy()

            vol_output.append(predicts)

            # Once a whole case has been predicted, undo the ROI crop/resize and save it as NIfTI
            while case < len(case_slice_indices) - 1 and idx[-1] >= case_slice_indices[case + 1] - 1:
                vol_output = np.concatenate(vol_output, axis=0)
                vol_num_slice = case_slice_indices[case + 1] - case_slice_indices[case]

                roi = dataset.get_roi(case, type='test')
                vol = vol_output[:vol_num_slice]
                vol_ = reverse_transform(vol, roi, dataset, transform)
                vol_ = vol_.astype(np.uint8)

                case_id = dataset.case_idx_to_case_id(case, type='test')
                affine = np.load(data_path / f'case_{case_id:05d}' / 'affine.npy')
                vol_nii = nib.Nifti1Image(vol_, affine)
                vol_nii_filename = output_path / f'prediction_{case_id:05d}.nii.gz'
                vol_nii.to_filename(str(vol_nii_filename))

                # Keep the slices that already belong to the next case
                vol_output = [vol_output[vol_num_slice:]]
                case += 1
                pbar.update(1)

            if vis_intvl > 0 and batch_idx % vis_intvl == 0:
                data['predict'] = predicts
                data = dataset.vis_transform(data)
                imgs, predicts = data['image'], data['predict']
                imshow(title=f'eval/test', imgs=(imgs[0, 1], predicts[0]), shape=(1, 2),
                       subtitle=('image', 'predict'))
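# Quick sanity check (a sketch, not part of the original script): reload one of the
# exported NIfTI predictions and print its label counts. Only standard nibabel/numpy
# calls are used; the example case id in the comment is illustrative only.
def inspect_prediction(path):
    import nibabel as nib
    import numpy as np

    pred = np.asanyarray(nib.load(str(path)).dataobj)
    labels, counts = np.unique(pred, return_counts=True)
    print(f'{path}: shape={pred.shape}, labels={dict(zip(labels.tolist(), counts.tolist()))}')

# Example: inspect_prediction(output_path / 'prediction_00210.nii.gz')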
def conversion_all(data, output):
    data = Path(data)
    output = Path(output)

    cases = sorted([d for d in data.iterdir() if d.is_dir()])
    pool = mp.Pool()
    pool.map(conversion, zip(cases, [output] * len(cases)))
    pool.close()
    pool.join()


def conversion(data):
    case, output = data

    vol_nii = nib.load(str(case / 'imaging.nii.gz'))
    vol = vol_nii.get_data()  # get_data() is deprecated in newer nibabel; np.asanyarray(vol_nii.dataobj) is the modern equivalent
    vol = KiTS19.normalize(vol)

    # Save each axial slice of the normalized volume as a separate .npy file
    imaging_dir = output / case.name / 'imaging'
    if not imaging_dir.exists():
        imaging_dir.mkdir(parents=True)
    if len(list(imaging_dir.glob('*.npy'))) != vol.shape[0]:
        for i in range(vol.shape[0]):
            np.save(str(imaging_dir / f'{i:03}.npy'), vol[i])

    segmentation_file = case / 'segmentation.nii.gz'
    if segmentation_file.exists():
        seg = nib.load(str(case / 'segmentation.nii.gz')).get_data()
        segmentation_dir = output / case.name / 'segmentation'
        if not segmentation_dir.exists():
            segmentation_dir.mkdir(parents=True)
        if len(list(segmentation_dir.glob('*.npy'))) != seg.shape[0]:
            # The original excerpt was truncated here; this loop mirrors the imaging branch above
            for i in range(seg.shape[0]):
                np.save(str(segmentation_dir / f'{i:03}.npy'), seg[i])
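# Example invocation (a sketch; the original repo likely wires this through its own CLI,
# so the flag names and defaults below are assumptions): convert every case under the raw
# KiTS19 directory into per-slice .npy files.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Convert KiTS19 NIfTI volumes to per-slice .npy files')
    parser.add_argument('-d', '--data', default='kits19/data', help='directory containing the raw case_XXXXX folders')
    parser.add_argument('-o', '--output', default='data', help='destination directory for the converted slices')
    args = parser.parse_args()

    conversion_all(args.data, args.output)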