def val_dataloader(self):
    """Build the validation loader: deterministic order, no augmentation.

    Validation must be reproducible, so augmentation is forced to 'none'
    and shuffling is disabled regardless of the training settings.
    """
    ds_cfg = self.config.dataset
    return get_data_loader(
        dataset_name=ds_cfg.name,
        modalities=ds_cfg.modalities,
        root_dir_paths=ds_cfg.root_dir_paths,
        augmentation_type='none',
        use_shuffle=False,
        batch_size=ds_cfg.batch_size,
        num_workers=ds_cfg.num_workers,
    )
def train_dataloader(self):
    """Build the training loader from the dataset section of the config.

    Unlike the validation loader, augmentation and shuffling are taken
    from the config so they can be tuned per experiment.
    """
    ds_cfg = self.config.dataset
    loader = get_data_loader(
        dataset_name=ds_cfg.name,
        modalities=ds_cfg.modalities,
        root_dir_paths=ds_cfg.root_dir_paths,
        augmentation_type=ds_cfg.augmentation_type,
        use_shuffle=ds_cfg.use_shuffle,
        batch_size=ds_cfg.batch_size,
        num_workers=ds_cfg.num_workers,
    )
    return loader
def main(config, needs_save):
    """Train a VAE (Encoder/Decoder) with pytorch-ignite.

    Builds the data loader, models, and optimizers from ``config``, then runs
    an ignite ``Engine`` whose update step performs one VAE optimization step
    (reconstruction loss + KL regularization). Logging, image dumps, and
    checkpointing are attached as event handlers; they are active only when
    ``needs_save`` is True (except console logging, which always runs).
    """
    # Restrict visible GPUs before any CUDA context is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = config.training.visible_devices
    seed = check_manual_seed(config.training.seed)
    print('Using manual seed: {}'.format(seed))

    # The config names one of two module-level patient-ID lists.
    if config.dataset.patient_ids == 'TRAIN_PATIENT_IDS':
        patient_ids = TRAIN_PATIENT_IDS
    elif config.dataset.patient_ids == 'TEST_PATIENT_IDS':
        patient_ids = TEST_PATIENT_IDS
    else:
        raise NotImplementedError

    data_loader = get_data_loader(
        mode=config.dataset.mode,
        dataset_name=config.dataset.name,
        patient_ids=patient_ids,
        root_dir_path=config.dataset.root_dir_path,
        use_augmentation=config.dataset.use_augmentation,
        batch_size=config.dataset.batch_size,
        num_workers=config.dataset.num_workers,
        image_size=config.dataset.image_size)

    # Models are cast to float32 explicitly.
    E = Encoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.enc_filters,
                activation=config.model.enc_activation).float()
    D = Decoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.dec_filters,
                activation=config.model.dec_activation,
                final_activation=config.model.dec_final_activation).float()

    if config.model.enc_spectral_norm:
        apply_spectral_norm(E)
    if config.model.dec_spectral_norm:
        apply_spectral_norm(D)

    # Move to GPU and wrap in DataParallel. NOTE(review): checkpoints saved
    # from DataParallel-wrapped modules carry a 'module.' prefix, so the
    # saved_E/saved_D files below must have been saved the same way.
    if config.training.use_cuda:
        E.cuda()
        D.cuda()
        E = nn.DataParallel(E)
        D = nn.DataParallel(D)

    # Optionally resume from saved weights.
    if config.model.saved_E:
        print(config.model.saved_E)
        E.load_state_dict(torch.load(config.model.saved_E))
    if config.model.saved_D:
        print(config.model.saved_D)
        D.load_state_dict(torch.load(config.model.saved_D))

    print(E)
    print(D)

    # Separate Adam optimizers for encoder and decoder; the second positional
    # argument is lr, the third is the betas pair.
    e_optim = optim.Adam(filter(lambda p: p.requires_grad, E.parameters()),
                         config.optimizer.enc_lr, [0.9, 0.9999])
    d_optim = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()),
                         config.optimizer.dec_lr, [0.9, 0.9999])

    # NOTE(review): alpha and margin are read from the config but never used
    # in this function — possibly leftovers from a GAN/margin variant; verify.
    alpha = config.training.alpha
    beta = config.training.beta  # weight on the reconstruction term
    margin = config.training.margin
    batch_size = config.dataset.batch_size

    # Fixed latent noise reused every epoch so sampled images are comparable.
    fixed_z = torch.randn(calc_latent_dim(config))

    # SSIM module is only instantiated when an 'ssim' loss variant is chosen;
    # l_recon below relies on this closure variable existing in that case.
    if 'ssim' in config.training.loss:
        ssim_loss = pytorch_ssim.SSIM(window_size=11)

    def l_recon(recon: torch.Tensor, target: torch.Tensor):
        """Per-batch reconstruction loss selected by config.training.loss,
        scaled by beta and averaged over the batch."""
        if config.training.loss == 'l2':
            loss = F.mse_loss(recon, target, reduction='sum')
        elif config.training.loss == 'l1':
            loss = F.l1_loss(recon, target, reduction='sum')
        elif config.training.loss == 'ssim':
            # SSIM returns a similarity in [0, 1]; convert to a dissimilarity
            # and rescale by element count to match the 'sum' reductions above.
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon)
        elif config.training.loss == 'ssim+l1':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                + F.l1_loss(recon, target, reduction='sum')
        elif config.training.loss == 'ssim+l2':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                + F.mse_loss(recon, target, reduction='sum')
        else:
            raise NotImplementedError
        return beta * loss / batch_size

    def l_reg(mu: torch.Tensor, log_var: torch.Tensor):
        """KL divergence between N(mu, exp(log_var)) and N(0, I), batch-averaged."""
        loss = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var))
        return loss / batch_size

    def update(engine, batch):
        """One ignite training step: encode, decode, backprop both nets."""
        E.train()
        D.train()
        image = norm(batch['image'])
        if config.training.use_cuda:
            image = image.cuda(non_blocking=True).float()
        else:
            image = image.float()
        e_optim.zero_grad()
        d_optim.zero_grad()
        # Encoder returns (sampled z, mu, log_var).
        z, z_mu, z_logvar = E(image)
        x_r = D(z)
        l_vae_reg = l_reg(z_mu, z_logvar)
        l_vae_recon = l_recon(x_r, image)
        l_vae_total = l_vae_reg + l_vae_recon
        l_vae_total.backward()
        e_optim.step()
        d_optim.step()
        if config.training.use_cuda:
            # Ensure kernels finish so the Timer measures real step time.
            torch.cuda.synchronize()
        return {
            'TotalLoss': l_vae_total.item(),
            'EncodeLoss': l_vae_reg.item(),
            'ReconLoss': l_vae_recon.item(),
        }

    output_dir = get_output_dir_path(config)
    trainer = Engine(update)
    timer = Timer(average=True)

    # Attach an exponential running average for each scalar in update()'s
    # output dict. partial(...) pins the metric name per iteration of the loop
    # (avoids the late-binding-lambda pitfall).
    monitoring_metrics = ['TotalLoss', 'EncodeLoss', 'ReconLoss']
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(trainer, metric)
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def save_config(engine):
        """Flatten the (namedtuple-of-namedtuples) config and dump it as JSON."""
        config_to_save = defaultdict(dict)
        for key, child in config._asdict().items():
            for k, v in child._asdict().items():
                config_to_save[key][k] = v
        config_to_save['seed'] = seed
        config_to_save['output_dir'] = output_dir
        print('Training starts by the following configuration: ', config_to_save)
        if needs_save:
            save_path = os.path.join(output_dir, 'config.json')
            with open(save_path, 'w') as f:
                json.dump(config_to_save, f)

    @trainer.on(Events.ITERATION_COMPLETED)
    def show_logs(engine):
        """Print current metrics to the progress bar every log_iter_interval iterations."""
        if (engine.state.iteration - 1) % config.save.log_iter_interval == 0:
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                + [str(value) for value in engine.state.metrics.values()]
            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=config.training.n_epochs,
                i=engine.state.iteration,
                max_i=len(data_loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)
            pbar.log_message(message)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_logs(engine):
        """Append per-epoch metrics to logs.tsv; write the header on first use."""
        if needs_save:
            fname = os.path.join(output_dir, 'logs.tsv')
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                + [str(value) for value in engine.state.metrics.values()]
            with open(fname, 'a') as f:
                # f.tell() == 0 means the file is new/empty → emit header row.
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        """Report the average per-batch time for the finished epoch."""
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_images(engine):
        """Save a grid of [inputs, reconstructions, samples from fixed_z]."""
        if needs_save:
            if engine.state.epoch % config.save.save_epoch_interval == 0:
                # NOTE(review): unlike update(), the batch here is not moved
                # to CUDA or cast to float before E(image) — presumably
                # DataParallel scatters CPU input; confirm this works when
                # use_cuda is True.
                image = norm(engine.state.batch['image'])
                with torch.no_grad():
                    z, _, _ = E(image)
                    x_r = D(z)
                    x_p = D(fixed_z)
                image = denorm(image).detach().cpu()
                x_r = denorm(x_r).detach().cpu()
                x_p = denorm(x_p).detach().cpu()
                image = image[:config.save.n_save_images, ...]
                x_r = x_r[:config.save.n_save_images, ...]
                x_p = x_p[:config.save.n_save_images, ...]
                save_path = os.path.join(
                    output_dir, 'result_{}.png'.format(engine.state.epoch))
                save_image(torch.cat([image, x_r, x_p]).data, save_path)

    if needs_save:
        # NOTE(review): ModelCheckpoint(save_interval=...) is the legacy
        # ignite API (removed in ignite >= 0.3); this pins the ignite version.
        checkpoint_handler = ModelCheckpoint(
            output_dir,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.save.n_saved,
            create_dir=True,
        )
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={
                                      'E': E,
                                      'D': D
                                  })

    # Time only the iteration bodies, averaged per step within each epoch.
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.training.n_epochs, config.training.n_epochs * len(data_loader)))

    trainer.run(data_loader, config.training.n_epochs)
def main(config, needs_save, i):
    """Train a ResUNet segmentation model on cross-validation fold ``i``.

    Builds train/val loaders from a CV split, optimizes with the sum of
    SoftDice + Focal + ActiveContour losses, and attaches ignite handlers for
    config/log/image saving (gated by ``needs_save``), timing, per-epoch
    validation, and checkpointing of both model and optimizer.
    """
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    # Train and val must come from the same data root, split by patient.
    assert config.train_dataset.root_dir_path == config.val_dataset.root_dir_path
    # train_patient_ids, val_patient_ids = divide_patients(config.train_dataset.root_dir_path)
    train_patient_ids, val_patient_ids = get_cv_splits(
        config.train_dataset.root_dir_path, i)

    seed = check_manual_seed()
    print('Using seed: {}'.format(seed))

    # label_to_id maps class name -> integer id; build the inverse for metrics.
    class_name_to_index = config.label_to_id._asdict()
    index_to_class_name = {v: k for k, v in class_name_to_index.items()}

    train_data_loader = get_data_loader(
        mode='train',
        dataset_name=config.train_dataset.dataset_name,
        root_dir_path=config.train_dataset.root_dir_path,
        patient_ids=train_patient_ids,
        batch_size=config.train_dataset.batch_size,
        num_workers=config.train_dataset.num_workers,
        volume_size=config.train_dataset.volume_size,
    )
    val_data_loader = get_data_loader(
        mode='val',
        dataset_name=config.val_dataset.dataset_name,
        root_dir_path=config.val_dataset.root_dir_path,
        patient_ids=val_patient_ids,
        batch_size=config.val_dataset.batch_size,
        num_workers=config.val_dataset.num_workers,
        volume_size=config.val_dataset.volume_size,
    )

    model = ResUNet(
        input_dim=config.model.input_dim,
        output_dim=config.model.output_dim,
        filters=config.model.filters,
    )
    print(model)

    if config.run.use_cuda:
        model.cuda()
        model = nn.DataParallel(model)

    # Either resume from a checkpoint or initialize weights from scratch.
    if config.model.saved_model:
        print('Loading saved model: {}'.format(config.model.saved_model))
        model.load_state_dict(torch.load(config.model.saved_model))
    else:
        print('Initializing weights.')
        init_weights(model, init_type=config.model.init_type)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=config.optimizer.lr,
                           betas=config.optimizer.betas,
                           weight_decay=config.optimizer.weight_decay)

    # Three complementary segmentation losses, summed with equal weight below.
    dice_loss = SoftDiceLoss()
    focal_loss = FocalLoss(
        gamma=config.focal_loss.gamma,
        alpha=config.focal_loss.alpha,
    )
    active_contour_loss = ActiveContourLoss(
        weight=config.active_contour_loss.weight,
    )
    dice_coeff = DiceCoefficient(
        n_classes=config.metric.n_classes,
        index_to_class_name=index_to_class_name,
    )
    # Bound method used as a plain callable: label volume -> one-hot tensor.
    one_hot_encoder = OneHotEncoder(
        n_classes=config.metric.n_classes,
    ).forward

    def train(engine, batch):
        """One training iteration: LR schedule, forward, combined loss, step."""
        # Epoch-based learning-rate decay applied at every iteration.
        adjust_learning_rate(optimizer,
                             engine.state.epoch,
                             initial_lr=config.optimizer.lr,
                             n_epochs=config.run.n_epochs,
                             gamma=config.optimizer.gamma)
        model.train()
        image = batch['image']
        label = batch['label']
        if config.run.use_cuda:
            image = image.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).long()
        else:
            image = image.float()
            label = label.long()
        optimizer.zero_grad()
        output = model(image)
        # Drop channel 0 (background) — losses are computed on foreground
        # classes only, so model output presumably has n_classes-1 channels.
        target = one_hot_encoder(label)[:, 1:, ...]
        l_dice = dice_loss(output, target)
        l_focal = focal_loss(output, target)
        l_active_contour = active_contour_loss(output, target)
        l_total = l_dice + l_focal + l_active_contour
        l_total.backward()
        optimizer.step()
        # Per-class dice metrics keyed by class name.
        m_dice = dice_coeff.update(output.detach(), label)
        measures = {
            'SoftDiceLoss': l_dice.item(),
            'FocalLoss': l_focal.item(),
            'ActiveContourLoss': l_active_contour.item(),
        }
        measures.update(m_dice)
        if config.run.use_cuda:
            # Make the step timing below reflect actual GPU work.
            torch.cuda.synchronize()
        return measures

    def evaluate(engine, batch):
        """One validation iteration: same losses/metrics, no gradient step."""
        model.eval()
        image = batch['image']
        label = batch['label']
        if config.run.use_cuda:
            image = image.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).long()
        else:
            image = image.float()
            label = label.long()
        with torch.no_grad():
            output = model(image)
            target = one_hot_encoder(label)[:, 1:, ...]
            l_dice = dice_loss(output, target)
            l_focal = focal_loss(output, target)
            l_active_contour = active_contour_loss(output, target)
            m_dice = dice_coeff.update(output.detach(), label)
            measures = {
                'SoftDiceLoss': l_dice.item(),
                'FocalLoss': l_focal.item(),
                'ActiveContourLoss': l_active_contour.item(),
            }
            measures.update(m_dice)
            if config.run.use_cuda:
                torch.cuda.synchronize()
            return measures

    output_dir_path = get_output_dir_path(config, i)
    trainer = Engine(train)
    evaluator = Engine(evaluate)
    timer = Timer(average=True)

    if needs_save:
        # NOTE(review): save_interval is the legacy ignite ModelCheckpoint
        # API (pre-0.3). n_saved = n_epochs + 1 effectively keeps everything.
        checkpoint_handler = ModelCheckpoint(
            output_dir_path,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.run.n_epochs + 1,
            create_dir=True,
        )

    # Running averages for the three losses plus one dice metric per class
    # name. partial(...) pins the metric key (late-binding-lambda pitfall).
    monitoring_metrics = ['SoftDiceLoss', 'FocalLoss', 'ActiveContourLoss']
    monitoring_metrics += class_name_to_index.keys()
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(trainer, metric)
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(evaluator, metric)

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=monitoring_metrics)
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    # Thin event-handler wrappers delegating to shared module-level helpers.
    @trainer.on(Events.STARTED)
    def call_save_config(engine):
        if needs_save:
            return save_config(engine, config, seed, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_save_logs(engine):
        if needs_save:
            return save_logs('train', engine, config, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_print_times(engine):
        return print_times(engine, config, pbar, timer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        """Run one pass over the validation set after every training epoch."""
        evaluator.run(val_data_loader, 1)
        if needs_save:
            save_logs('val', evaluator, config, output_dir_path)
            # Resolves to the local save_images defined below (closure
            # lookup happens at call time, i.e. during trainer.run).
            save_images(evaluator, trainer.state.epoch)

    def save_images(evaluator, epoch):
        """Render the middle z-slice of the last validation batch
        (image / ground truth / prediction) to a PNG."""
        batch = evaluator.state.batch
        image = batch['image']
        label = batch['label']
        if config.run.use_cuda:
            image = image.cuda(non_blocking=True).float()
            label = label.cuda(non_blocking=True).long()
        else:
            image = image.float()
            label = label.long()
        with torch.no_grad():
            pred = model(image)
        # Voxels where every foreground channel is below 0.5 are background
        # (class 0); otherwise class = argmax + 1 (channels exclude background).
        output = torch.ones_like(label)
        mask_0 = pred[:, 0, ...] < 0.5
        mask_1 = pred[:, 1, ...] < 0.5
        mask_2 = pred[:, 2, ...] < 0.5
        mask = mask_0 * mask_1 * mask_2
        pred = pred.argmax(1)
        output += pred
        output[mask] = 0
        image = image.detach().cpu().float()
        label = label.detach().cpu().unsqueeze(1).float()
        output = output.detach().cpu().unsqueeze(1).float()
        # Take the middle slice along the last (z) axis for a 2D preview.
        z_middle = image.shape[-1] // 2
        image = image[:, 0, ..., z_middle]
        label = label[:, 0, ..., z_middle]
        output = output[:, 0, ..., z_middle]
        # Window the intensity range (config override, else data min/max)
        # then rescale to [0, 255].
        if config.save.image_vmax is not None:
            vmax = config.save.image_vmax
        else:
            vmax = image.max()
        if config.save.image_vmin is not None:
            vmin = config.save.image_vmin
        else:
            vmin = image.min()
        image = np.clip(image, vmin, vmax)
        image -= vmin
        image /= (vmax - vmin)
        image *= 255.0
        save_path = os.path.join(output_dir_path,
                                 'result_{}.png'.format(epoch))
        save_images_via_plt(image, label, output,
                            config.save.n_save_images, config, save_path)

    if needs_save:
        # Checkpoints both the model and the optimizer state each epoch.
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={
                                      'model': model,
                                      'optim': optimizer
                                  })

    # Time only iteration bodies, averaged per step within each epoch.
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.run.n_epochs, config.run.n_epochs * len(train_data_loader)))

    trainer.run(train_data_loader, config.run.n_epochs)
def inference(config):
    """Run 5-model ensemble inference and save predicted label volumes.

    For every test patient: pad the volume's z-axis, average the sigmoid-like
    per-class outputs of the five fold models, decode to an integer label map
    (0 = background where all foreground channels are < 0.5, else argmax + 1),
    crop the padding back off, and write the result as a NIfTI file next to a
    per-patient output directory.

    Fixes vs. previous revision:
    - ``output[mask] = 0`` previously indexed the CPU ``output`` tensor with a
      CUDA boolean mask (the masks were built before ``pred`` was moved to
      CPU), which raises a device-mismatch error; the mask is now moved to CPU
      first.
    - Removed the dead ``image = image.cpu().numpy()[0, 1, ...]`` assignment
      (the value was never used) and a redundant nested ``os.path.join``.
    """
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    # One sub-directory per patient under the test root.
    test_patient_ids = os.listdir(config.test_dataset.root_dir_path)

    seed = check_manual_seed()
    print('Using seed: {}'.format(seed))

    class_name_to_index = config.label_to_id._asdict()
    index_to_class_name = {v: k for k, v in class_name_to_index.items()}

    test_data_loader = get_data_loader(
        mode='test',
        dataset_name=config.test_dataset.dataset_name,
        root_dir_path=config.test_dataset.root_dir_path,
        patient_ids=test_patient_ids,
        batch_size=config.test_dataset.batch_size,
        num_workers=config.test_dataset.num_workers,
        volume_size=config.test_dataset.volume_size,
    )

    # One trained model per cross-validation fold; averaged below.
    models = [
        get_trained_model(config.model_1),
        get_trained_model(config.model_2),
        get_trained_model(config.model_3),
        get_trained_model(config.model_4),
        get_trained_model(config.model_5),
    ]
    for m in models:
        m.eval()

    for batch in tqdm(test_data_loader):
        image = batch['image'].cuda().float()
        # The decode logic below assumes exactly one volume per batch.
        assert image.size(0) == 1
        patient_id = batch['patient_id'][0]
        nii_path = batch['nii_path'][0]

        # Pad the last (z) axis by (2, 3); undone by the [..., 2:-3] crop
        # below — presumably to reach a depth the network's downsampling
        # requires. TODO confirm against the model architecture.
        image = F.pad(image, (2, 3, 0, 0, 0, 0, 0, 0, 0, 0), 'constant', 0)

        # Start at 1 so that adding argmax (0-based over foreground channels)
        # yields 1-based foreground labels; background is zeroed afterwards.
        output = torch.ones(
            (1, image.shape[2], image.shape[3], image.shape[4]))

        with torch.no_grad():
            pred = sum(m(image) for m in models) / 5.0

        # Background = all three foreground channels below 0.5.
        # Moved to CPU so it can index the CPU `output` tensor (bug fix).
        mask = ((pred[:, 0, ...] < 0.5)
                & (pred[:, 1, ...] < 0.5)
                & (pred[:, 2, ...] < 0.5)).cpu()

        pred = pred.argmax(1).cpu()
        output += pred
        output[mask] = 0

        # Remove the z-padding added above.
        output = output[..., 2:-3]

        save_dir_path = os.path.join(config.save.save_root_dir, patient_id)
        os.makedirs(save_dir_path, exist_ok=True)

        # Save with the original scan's affine so it overlays correctly.
        output = output.numpy()[0, ...].astype(np.int16)
        nii_image = nib.load(nii_path)
        nii_output = nib.Nifti1Image(output, affine=nii_image.affine)
        nib.save(nii_output,
                 os.path.join(save_dir_path, patient_id + '_output.nii.gz'))