def validation_step(self, batch, batch_idx):
    """Run one validation step and return scalar metrics for this batch.

    The input dict ``batch`` is modified in place: each tensor has its first two
    axes swapped, and the step outputs plus any derived warps are added to it.

    Args:
        batch: dict of tensors. Must contain at least 'target' and 'source'.
            May also contain 'source_seg' and 'target_original'.
        batch_idx: index of the batch in the validation epoch. A visualisation
            figure is logged only when this is 0.

    Returns:
        dict mapping metric name to a Python float. It holds every loss
        returned by ``self._step`` plus the output of ``measure_metrics``.
    """
    # Swap the first two axes of every tensor to get the inference layout:
    #   2d: (N=1, num_slices, H, W) -> (num_slices, N=1, H, W)
    #   3d: (N=1, 1, H, W, D)       -> (1, N=1, H, W, D)
    batch.update({name: tensor.transpose(0, 1) for name, tensor in batch.items()})

    # Forward pass plus loss computation.
    losses, outputs = self._step(batch)

    # The same dict object collects inputs, outputs and derived warps. It is
    # used both for metric computation and for the visualisation.
    val_data = batch
    val_data.update(outputs)

    if 'source_seg' in val_data:
        # Segmentations are label maps, so warp them with nearest-neighbour
        # interpolation.
        val_data['warped_source_seg'] = warp(
            val_data['source_seg'], val_data['disp_pred'], interp_mode='nearest')
    if 'target_original' in val_data:
        val_data['target_pred'] = warp(val_data['target_original'], val_data['disp_pred'])

    # Losses are converted to host floats, then merged with the image-based metrics.
    metrics = {name: float(value.cpu()) for name, value in losses.items()}
    metrics.update(measure_metrics(val_data, self.hparams.meta.metric_groups))

    if batch_idx != 0:
        return metrics

    # Only the first validation batch is visualised (Tensorboard, per the
    # original comment).
    figure = visualise_result(val_data, axis=2)
    self.logger.experiment.add_figure('val_fig', figure,
                                      global_step=self.global_step,
                                      close=True)
    return metrics
def inference(model, dataloader, output_dir, device=torch.device('cpu')):
    """Run a registration model over ``dataloader`` and save every output as NIfTI.

    For each batch, the model predicts a displacement field. That field is
    used to warp the source image and, when present, the source segmentation
    and the 'target_original' image. Every entry of the batch dict, inputs and
    outputs alike, is then saved to ``<output_dir>/<subj_id>/<key>.nii.gz``.

    Args:
        model: callable invoked as ``model(target, source)``. It returns either
            the displacement field alone, or a 2-element tuple/list
            ``(flow, disp)``.
        dataloader: iterable of dict batches. It must expose
            ``dataset.subject_list``. NOTE(review): the batch index is used
            directly as the subject index, which presumably assumes
            batch_size=1 and no shuffling; confirm with the caller.
        output_dir: output root, concatenated with '/<subj_id>' as a string.
        device: device that every batch tensor is moved to before inference.
    """
    for idx, batch in enumerate(tqdm(dataloader)):
        # reshape data for inference
        # 2d: (N=1, num_slices, H, W) -> (num_slices, N=1, H, W)
        # 3d: (N=1, 1, H, W, D) -> (1, N=1, H, W, D)
        for k, x in batch.items():
            batch[k] = x.transpose(0, 1).to(device=device)

        # Pure inference: skip autograd graph construction to save memory.
        with torch.no_grad():
            out = model(batch['target'], batch['source'])
            # (flow, disp) or disp. Check the container type: len() of a bare
            # disp tensor is its first dim (num_slices in 2D), so a 2-slice
            # volume would otherwise be misread as a (flow, disp) pair.
            if isinstance(out, (tuple, list)) and len(out) == 2:
                batch['disp_pred'] = out[1]
            else:
                batch['disp_pred'] = out

            # warp images and segmentation using predicted disp
            batch['warped_source'] = warp(batch['source'], batch['disp_pred'])
            if 'source_seg' in batch:
                # label maps are warped with nearest-neighbour interpolation
                batch['warped_source_seg'] = warp(batch['source_seg'],
                                                  batch['disp_pred'],
                                                  interp_mode='nearest')
            if 'target_original' in batch:
                batch['target_pred'] = warp(batch['target_original'],
                                            batch['disp_pred'])

        # save the outputs
        subj_id = dataloader.dataset.subject_list[idx]
        output_id_dir = setup_dir(output_dir + f'/{subj_id}')
        for k, x in batch.items():
            x = x.detach().cpu().numpy()
            # reshape for saving:
            # 2D: img (N=num_slice, 1, H, W) -> (H, W, N);
            #     disp (N=num_slice, 2, H, W) -> (H, W, N, 2)
            # 3D: img (N=1, 1, H, W, D) -> (H, W, D);
            #     disp (N=1, 3, H, W, D) -> (H, W, D, 3)
            x = np.moveaxis(x, [0, 1], [-2, -1]).squeeze()
            save_nifti(x, path=output_id_dir + f'/{k}.nii.gz')
def _step(self, batch):
    """Forward pass inference + compute loss.

    Registers ``batch['source']`` to ``batch['target']`` with ``self.forward``.
    The source is warped with the predicted displacement, and ``self.loss_fn``
    is evaluated on (target, warped source, regularised field).

    When ``self.hparams.transformation.config.svf`` is truthy, the network
    output is unpacked as ``(flow, disp)`` and the flow field is passed to the
    loss. Otherwise the output is the displacement field itself and is passed
    to the loss directly.

    Returns:
        (losses, step_outputs), where losses is whatever ``self.loss_fn``
        returns and step_outputs is a dict with keys 'disp_pred' and
        'warped_source'.
    """
    target, source = batch['target'], batch['source']
    prediction = self.forward(target, source)

    # Choose which field the loss regularises: flow for SVF, disp otherwise.
    if self.hparams.transformation.config.svf:
        reg_field, disp = prediction
    else:
        reg_field = disp = prediction

    warped = warp(source, disp)
    losses = self.loss_fn(target, warped, reg_field)
    return losses, {'disp_pred': disp, 'warped_source': warped}