Ejemplo n.º 1
0
    def validation_step(self, batch, batch_idx):
        """Run inference on one validation batch and compute losses/metrics.

        Args:
            batch: dict of tensors (e.g. 'target', 'source', and optionally
                'source_seg' / 'target_original'). NOTE: the dict is modified
                in place: every tensor is replaced by its transposed version
                and the step outputs and warped results are added to it.
            batch_idx: index of the batch in the validation loop; a
                visualisation figure is logged only when this is 0.

        Returns:
            dict of validation values: each loss returned by ``self._step``
            converted to a Python float, updated with the result of
            ``measure_metrics`` for ``self.hparams.meta.metric_groups``.
        """
        # reshape data for inference by swapping the first two axes
        # 2d: (N=1, num_slices, H, W) -> (num_slices, N=1, H, W)
        # 3d: (N=1, 1, H, W, D) -> (1, N=1, H, W, D)
        for k, x in batch.items():
            # only existing keys are re-assigned (dict size is unchanged),
            # so writing while iterating is safe
            batch[k] = x.transpose(0, 1)

        # run inference, compute losses and outputs
        val_losses, step_outputs = self._step(batch)

        # collect data for measuring metrics and validation visualisation;
        # val_data deliberately aliases batch (no copy is made)
        val_data = batch
        val_data.update(step_outputs)
        if 'source_seg' in val_data:
            # nearest-neighbour interpolation keeps label values discrete
            val_data['warped_source_seg'] = warp(val_data['source_seg'],
                                                 val_data['disp_pred'],
                                                 interp_mode='nearest')
        if 'target_original' in val_data:
            val_data['target_pred'] = warp(val_data['target_original'],
                                           val_data['disp_pred'])

        # calculate validation metrics
        val_metrics = {k: float(loss.cpu()) for k, loss in val_losses.items()}
        val_metrics.update(
            measure_metrics(val_data, self.hparams.meta.metric_groups))

        # log visualisation figure to Tensorboard (first batch only)
        if batch_idx == 0:
            val_fig = visualise_result(val_data, axis=2)
            self.logger.experiment.add_figure('val_fig',
                                              val_fig,
                                              global_step=self.global_step,
                                              close=True)
        return val_metrics
Ejemplo n.º 2
0
def inference(model, dataloader, output_dir, device=torch.device('cpu')):
    """Run a registration model over a dataloader and save all outputs.

    For every batch the model is called as ``model(target, source)``. The
    source image is warped with the predicted displacement, and so are the
    source segmentation and 'target_original' when present. Every entry of
    the (updated) batch dict is then written to
    ``<output_dir>/<subject_id>/<key>.nii.gz``.

    Args:
        model: callable ``model(target, source)`` returning either the
            displacement field alone or a ``(flow, disp)`` tuple/list.
        dataloader: iterable of dict batches exposing
            ``dataloader.dataset.subject_list``. The batch index is used to
            look up the subject id, so this assumes batch_size=1 and no
            shuffling (NOTE(review): confirm with callers).
        output_dir: root output directory (str or path-like).
        device: device the batch tensors are moved to before inference.

    Note:
        The model is not switched to eval mode here; callers are expected
        to do that if required.
    """
    # Inference only: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader)):
            # reshape data for inference
            # 2d: (N=1, num_slices, H, W) -> (num_slices, N=1, H, W)
            # 3d: (N=1, 1, H, W, D) -> (1, N=1, H, W, D)
            for k, x in batch.items():
                batch[k] = x.transpose(0, 1).to(device=device)

            # model inference: output is either (flow, disp) or disp alone.
            # Check the container type rather than len(): len() of a bare
            # disp tensor is its first dimension (num_slices in 2D), so a
            # 2-slice volume would otherwise be mis-sliced as out[1].
            out = model(batch['target'], batch['source'])
            if isinstance(out, (tuple, list)):
                _, disp = out
            else:
                disp = out
            batch['disp_pred'] = disp

            # warp images and segmentation using predicted disp
            batch['warped_source'] = warp(batch['source'], batch['disp_pred'])
            if 'source_seg' in batch:
                # nearest-neighbour interpolation keeps label values discrete
                batch['warped_source_seg'] = warp(batch['source_seg'],
                                                  batch['disp_pred'],
                                                  interp_mode='nearest')
            if 'target_original' in batch:
                batch['target_pred'] = warp(batch['target_original'],
                                            batch['disp_pred'])

            # save the outputs
            subj_id = dataloader.dataset.subject_list[idx]
            output_id_dir = setup_dir(f'{output_dir}/{subj_id}')
            for k, x in batch.items():
                # reshape for saving:
                # 2D: img (N=num_slice, 1, H, W) -> (H, W, N);
                #     disp (N=num_slice, 2, H, W) -> (H, W, N, 2)
                # 3D: img (N=1, 1, H, W, D) -> (H, W, D);
                #     disp (N=1, 3, H, W, D) -> (H, W, D, 3)
                arr = x.detach().cpu().numpy()
                arr = np.moveaxis(arr, [0, 1], [-2, -1]).squeeze()
                save_nifti(arr, path=f'{output_id_dir}/{k}.nii.gz')
Ejemplo n.º 3
0
    def _step(self, batch):
        """Forward pass inference + compute loss.

        Args:
            batch: dict holding at least the 'target' and 'source' tensors.

        Returns:
            ``(losses, step_outputs)``. ``losses`` is whatever
            ``self.loss_fn`` returns. ``step_outputs`` is a dict with the
            predicted displacement ('disp_pred') and the warped source
            image ('warped_source').
        """
        target = batch['target']
        source = batch['source']
        prediction = self.forward(target, source)

        # Choose the field handed to the loss. With the svf option the
        # network outputs (flow, disp) and the flow field is used.
        # Otherwise it outputs only disp, which is used directly.
        if self.hparams.transformation.config.svf:
            loss_field, disp = prediction
        else:
            disp = loss_field = prediction

        warped_source = warp(source, disp)
        losses = self.loss_fn(target, warped_source, loss_field)
        return losses, {'disp_pred': disp, 'warped_source': warped_source}