Exemplo n.º 1
0
 def eval(self, *metrics: TorchLoss, datasets: List[TorchDataset],
          atlas) -> Tuple[TorchLoss]:
     """Evaluate the network on ``datasets`` and accumulate ``metrics``.

     Args:
         *metrics: metric/loss objects; each is reset, then cumulated over
             every batch produced by the sequential data loader.
         datasets: datasets wrapped into a single ``MultiDataLoader``.
         atlas: atlas tensor forwarded to the network with each batch.

     Returns:
         The same metric objects, now holding the accumulated values.
     """
     val_data_loader = MultiDataLoader(*datasets,
                                       batch_size=self.config.batch_size,
                                       sampler_cls="sequential",
                                       num_workers=self.config.num_workers,
                                       pin_memory=self.config.pin_memory)
     self.network.eval()
     # NOTE: `metrics` is collected via *args and is therefore always a
     # tuple; the previous isinstance() re-wrapping branch was dead code
     # and has been removed.
     for metric in metrics:
         metric.reset()
     for idx, (inputs, outputs) in enumerate(val_data_loader):
         inputs = prepare_tensors(inputs, self.config.gpu,
                                  self.config.device)
         preds = self.network(inputs, atlas)
         # Move predictions/targets to the target device once per batch
         # instead of once per metric (the per-metric calls were redundant).
         preds = prepare_tensors(preds, self.config.gpu,
                                 self.config.device)
         outputs = prepare_tensors(outputs, self.config.gpu,
                                   self.config.device)
         for metric in metrics:
             metric.cumulate(preds, outputs)
     return metrics
Exemplo n.º 2
0
    def inference(self, epoch: int):
        """Run qualitative inference for the checkpoint of ``epoch``.

        Samples ``config.n_inference`` random images from every validation
        set and every extra validation set, runs ``self.inference_func`` on
        each, and writes results under
        ``experiment_dir/inference/CP_<epoch>/<dataset>/<case>/``.
        """
        output_dir = self.config.experiment_dir.joinpath("inference").joinpath("CP_{}".format(epoch))
        # Both dataset groups were previously handled by two byte-identical
        # loops; they are processed by a single helper now.
        for val in list(self.validation_sets) + list(self.extra_validation_sets):
            self._inference_on_dataset(val, output_dir)

    def _inference_on_dataset(self, val, output_dir):
        """Sample images from one dataset and run inference on each of them."""
        indices = np.random.choice(len(val.image_paths), self.config.n_inference)
        for idx in indices:
            image_path = val.image_paths[idx]
            case_dir = output_dir.joinpath(val.name, image_path.parent.stem)
            case_dir.mkdir(exist_ok=True, parents=True)
            image = val.get_image_tensor_from_index(idx)
            label = val.get_label_tensor_from_index(idx)
            # The dataset-level template acts as the atlas input to the
            # network; it is batched and moved to the configured device.
            template = torch.from_numpy(val.template).float()
            template = torch.unsqueeze(template, 0)
            template = prepare_tensors(template, self.config.gpu, self.config.device)

            image = torch.unsqueeze(image, 0)
            image = prepare_tensors(image, self.config.gpu, self.config.device)
            self.logger.info("Inferencing for {} dataset, image {}.".format(val.name, idx))

            self.inference_func(
                (image, template), label, image_path, self.network, case_dir
            )
Exemplo n.º 3
0
    def update_atlas(self,
                     train_data_loader,
                     atlas_label,
                     atlas,
                     eta: float = 0.01):
        """Recompute the atlas from network predictions over one data pass.

        Sums the warped images/labels predicted for every batch, averages
        them over the number of samples seen, and blends the means into
        ``atlas`` with step ``eta``.

        Args:
            train_data_loader: iterable of (inputs, outputs) batches.
            atlas_label: current atlas tensor passed to the network.
            atlas: atlas object updated in place via ``atlas.update``.
            eta: blending rate forwarded to ``atlas.update``.

        Returns:
            The updated atlas object (unchanged when the loader is empty).
        """
        pbar = tqdm(enumerate(train_data_loader))
        warped_labels = None
        warped_images = None
        n = 0  # number of samples accumulated so far
        for idx, (inputs, outputs) in pbar:
            inputs = prepare_tensors(inputs, self.config.gpu,
                                     self.config.device)
            predicted = self.network(inputs,
                                     atlas_label)  # pass updated atlas in
            # Sum over the batch dimension; averaging happens after the loop.
            warped_label = torch.sum(predicted[-4],
                                     dim=0).cpu().detach().numpy()
            image = np.squeeze(torch.sum(predicted[-5],
                                         dim=0).cpu().detach().numpy(),
                               axis=0)
            n += predicted[-5].shape[0]
            # Lazily allocate accumulators once the output shape is known.
            if warped_images is None:
                warped_images = np.zeros(image.shape)
            if warped_labels is None:
                warped_labels = np.zeros(warped_label.shape)
            warped_images += image
            warped_labels += warped_label
        # Guard against an empty data loader: previously this divided by
        # zero / raised on None accumulators (and leaked a debug print).
        if n == 0:
            return atlas
        mean_image = warped_images / n
        mean_label = warped_labels / n
        atlas.update(mean_image, mean_label, eta=eta)
        return atlas
Exemplo n.º 4
0
def main():
    """Load a trained UNet checkpoint and run inference on one dataset image."""
    args = parse_args()
    model_path = Path(args.model_path)

    # Resolve the training configuration, preferring an explicit path.
    if args.network_conf_path is None:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    else:
        conf_path = Path(args.network_conf_path)
        train_conf = ConfigFactory.parse_file(str(conf_path))

    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.training_datasets[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode)
    image_name = dataset_config.image_label_format.image.format(phase=args.phase)
    input_path = Path(args.input_dir).joinpath(image_name)
    output_dir = Path(args.output_dir)

    checkpoint = torch.load(str(model_path),
                            map_location=torch.device(args.device))

    # This key is read without using the result; kept for parity with the
    # original behaviour (get_conf may validate/raise on missing keys).
    get_conf(train_conf, group="network", key="experiment_dir")

    def network_conf(key):
        """Shorthand for looking up a key in the [network] config group."""
        return get_conf(train_conf, group="network", key=key)

    network = UNet(
        in_channels=network_conf("in_channels"),
        n_classes=network_conf("n_classes"),
        n_filters=network_conf("n_filters"),
    )
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)

    label_name = dataset_config.image_label_format.label.format(phase=args.phase)
    dataset = Torch2DSegmentationDataset(
        name=dataset_config.name,
        image_paths=[input_path],
        label_paths=[input_path.parent.joinpath(label_name)],
        feature_size=network_conf("feature_size"),
        n_slices=network_conf("n_slices"),
        is_3d=True,
    )

    image = torch.unsqueeze(dataset.get_image_tensor_from_index(0), 0)
    image = prepare_tensors(image, True, args.device)
    label = dataset.get_label_tensor_from_index(0)

    inference(
        image=image,
        label=label,
        image_path=input_path,
        network=network,
        output_dir=output_dir,
    )
Exemplo n.º 5
0
    def forward(self, inputs, atlas):
        """Segment the image and register the atlas template to it.

        Args:
            inputs: (image, label) pair of batched tensors.
            atlas: atlas tensor; stacked per-sample when the incoming batch
                is smaller than the precomputed ``self.batch_atlas`` batch.

        Returns:
            Tuple of (affine_warped_template, warped_template, pred_maps,
            preint_flow, warped_image, warped_label, affine_warped_label,
            batch_atlas, affine_warped_image).
        """
        image, label = inputs
        # Encode the image and decode per-class probability maps (sigmoid).
        img_down1, img_down2, img_down3, img_down4, img_bridge = self.image_encoder(
            image)
        pred_maps = self.seg_decoder(img_down1, img_down2, img_down3,
                                     img_down4, img_bridge)
        pred_maps = torch.sigmoid(pred_maps)

        # Reuse the precomputed batched atlas/affine rows when the batch is
        # full-size; otherwise build them for the (smaller) last batch.
        if image.shape[0] == self.batch_size:
            batch_atlas = self.batch_atlas
            batch_affine_added = self.batch_affine_added
        else:
            batch_atlas = torch.stack([atlas for _ in range(image.shape[0])],
                                      dim=0)
            # [0, 0, 0, 1] is the homogeneous row appended to each 3x4
            # affine to make it invertible as a 4x4 matrix.
            batch_affine_added = prepare_tensors(
                torch.stack([
                    torch.from_numpy(np.array([[0, 0, 0, 1]]))
                    for _ in range(image.shape[0])
                ],
                            dim=0),
                self.gpu,
                self.device,
            )

        # Stage 1: affine alignment of the atlas to the image.
        temp_down1, temp_down2, temp_down3, temp_down4, temp_bridge = self.template_encoder(
            batch_atlas)
        affine_params = self.affine_regressor(img_bridge, temp_bridge)
        affine_warped_template = self.affine_transformer(
            batch_atlas, affine_params)

        # Stage 2: deformable (VXM) registration, re-encoding the
        # affinely-warped template.
        temp_down1, temp_down2, temp_down3, temp_down4, temp_bridge = self.template_encoder(
            affine_warped_template)
        warped_template, preint_flow, pos_flow, neg_flow = self.decoder_vxm(
            affine_warped_template, None, img_down1, img_down2, img_down3,
            img_down4, img_bridge, temp_down1, temp_down2, temp_down3,
            temp_down4, temp_bridge)

        # inverse transform of label to template space
        inverse_affine_par = inverse_affine_params(affine_params,
                                                   batch_affine_added)
        affine_warped_label = self.affine_transformer(label,
                                                      inverse_affine_par)
        warped_label = self.decoder_vxm.transformer(affine_warped_label,
                                                    neg_flow)

        # inverse transform of image to template space
        affine_warped_image = self.affine_transformer(image,
                                                      inverse_affine_par)
        warped_image = self.decoder_vxm.transformer(affine_warped_image,
                                                    neg_flow)

        return affine_warped_template, warped_template, pred_maps, preint_flow, warped_image, warped_label, \
               affine_warped_label, batch_atlas, affine_warped_image
Exemplo n.º 6
0
 def update_atlas(self, train_data_loader, atlas_label, atlas):
     """Collect per-batch warped images/labels and feed them to the atlas.

     Runs one pass over ``train_data_loader`` with the current
     ``atlas_label`` and hands the accumulated numpy batches to
     ``atlas.update``.
     """
     label_batches = []
     image_batches = []
     progress = tqdm(enumerate(train_data_loader))
     for _, (inputs, outputs) in progress:
         inputs = prepare_tensors(inputs, self.config.gpu,
                                  self.config.device)
         # Forward pass with the current atlas label.
         predicted = self.network(inputs, atlas_label)
         label_batches.append(predicted[-3].cpu().detach().numpy())
         warped = predicted[-4].cpu().detach().numpy()
         # Drop the singleton channel axis of the warped image.
         image_batches.append(np.squeeze(warped, axis=1))
     atlas.update(image_batches, label_batches)
Exemplo n.º 7
0
    def forward(self, inputs, atlas):
        """Segment the image and register the atlas via STN + FFD.

        Args:
            inputs: (image, label) pair of batched tensors.
            atlas: atlas tensor; stacked per-sample when the incoming batch
                is smaller than the precomputed ``self.batch_atlas`` batch.

        Returns:
            Tuple of (affine_warped_atlas, warped_atlas, pred_maps_logits,
            preint_flow, warped_image, warped_label, affine_warped_label,
            batch_atlas, affine_warped_image).
        """
        image, label = inputs

        # Reuse the precomputed batched atlas/affine rows when the batch is
        # full-size; otherwise build them for the (smaller) last batch.
        if image.shape[0] == self.batch_size:
            batch_atlas = self.batch_atlas
            batch_affine_added = self.batch_affine_added
        else:
            batch_atlas = torch.stack([atlas for _ in range(image.shape[0])],
                                      dim=0)
            # [0, 0, 0, 1] is the homogeneous row appended to each 3x4
            # affine to make it invertible as a 4x4 matrix.
            batch_affine_added = prepare_tensors(
                torch.stack([
                    torch.from_numpy(np.array([[0, 0, 0, 1]]))
                    for _ in range(image.shape[0])
                ],
                            dim=0),
                self.gpu,
                self.device,
            )

        # ITN: encode the image and decode segmentation logits (no sigmoid
        # here — downstream losses/consumers receive raw logits).
        img_down1, img_down2, img_down3, img_down4, img_bridge = self.seg_encoder(
            image)
        pred_maps_logits = self.seg_decoder(img_down1, img_down2, img_down3,
                                            img_down4, img_bridge)
        # pred_maps = torch.sigmoid(pred_maps)
        # STN: regress a dense flow and affine params from logits vs atlas.
        flow, affine_params = self.stn(pred_maps_logits, batch_atlas)
        # inverse transform of label to template space
        inverse_affine_par = inverse_affine_params(affine_params,
                                                   batch_affine_added)

        affine_warped_atlas = self.affine_transformer(batch_atlas,
                                                      affine_params)
        affine_warped_label = self.affine_transformer(label,
                                                      inverse_affine_par)

        # Deformable registration of the affinely-aligned atlas (and the
        # inverse-warped label) with the regressed flow.
        warped_atlas, warped_label, preint_flow, pos_flow, neg_flow = self.ffd_transformer(
            affine_warped_atlas,
            affine_warped_label,
            flow,
            bidir=True,
            resize=False)
        # inverse transform of image to template space
        affine_warped_image = self.affine_transformer(image,
                                                      inverse_affine_par)
        warped_image = self.ffd_transformer.transformer(
            affine_warped_image, neg_flow)

        return affine_warped_atlas, warped_atlas, pred_maps_logits, preint_flow, warped_image, warped_label, \
               affine_warped_label, batch_atlas, affine_warped_image
Exemplo n.º 8
0
    def run(self, image: np.ndarray) -> np.ndarray:
        """Segment ``image`` with the model and return a label-map volume.

        Per-class binary predictions are collapsed into one volume where
        voxels of class ``i`` take the value ``i + 1``; the result is then
        transposed so the slice axis comes last.
        """
        batch = torch.unsqueeze(torch.from_numpy(image).float(), 0)
        batch = prepare_tensors(batch, True, self.device)
        probs = torch.sigmoid(self.model(batch))
        # Hard-threshold the probabilities and move them back to numpy.
        masks = (probs > 0.5).float().cpu().detach().numpy()
        masks = np.squeeze(masks, axis=0)
        # NOTE(review): indexing batch.shape[2:5] assumes a 5-D model input,
        # i.e. a 4-D input volume — TODO confirm against callers.
        label_map = np.zeros((batch.shape[2], batch.shape[3], batch.shape[4]))
        for cls in range(masks.shape[0]):
            label_map[masks[cls, :, :, :] > 0.5] = cls + 1
        return np.transpose(label_map, [1, 2, 0])
Exemplo n.º 9
0
 def __init__(self, in_channels, n_classes, n_slices, feature_size,
              n_filters, batch_norm: bool, group_norm, bidir, int_downsize,
              batch_size, gpu, device):
     """Build the encoder/decoder components of the registration network.

     Wires up an image encoder, a template (atlas) encoder, an affine
     regressor + spatial transformer, a VoxelMorph-style dense decoder and
     a segmentation decoder, plus the precomputed batched homogeneous
     affine row used to invert 3x4 affines.

     NOTE(review): the ``bidir`` parameter is accepted but not used —
     ``DecoderVxmDense`` is constructed with ``bidir=False``; confirm
     whether that is intentional.
     """
     super().__init__()
     # Encoders for the input image and the atlas template.
     self.image_encoder = Encoder(in_channels,
                                  n_filters,
                                  batch_norm=batch_norm,
                                  group_norm=group_norm)
     self.template_encoder = Encoder(n_classes,
                                     n_filters,
                                     batch_norm=batch_norm,
                                     group_norm=group_norm)
     # Affine stage: regress parameters and warp at full volume size.
     self.affine_regressor = AffineRegressor(group_norm=group_norm)
     self.affine_transformer = AffineSpatialTransformer(size=(n_slices,
                                                              feature_size,
                                                              feature_size),
                                                        mode="bilinear")
     # Deformable stage: VoxelMorph-style dense decoder.
     self.decoder_vxm = DecoderVxmDense(inshape=(n_slices, feature_size,
                                                 feature_size),
                                        n_filters=n_filters,
                                        batch_norm=batch_norm,
                                        group_norm=group_norm,
                                        int_downsize=int_downsize,
                                        bidir=False)
     self.seg_decoder = SegDecoder(n_filters, batch_norm, group_norm,
                                   n_classes)
     self.batch_size = batch_size
     # Batched atlas is filled in later (see update_batch_atlas callers).
     self.batch_atlas = None
     self.gpu = gpu
     self.device = device
     # Precompute a batch of [0, 0, 0, 1] homogeneous rows, one per sample,
     # used to extend 3x4 affines to invertible 4x4 matrices.
     self.batch_affine_added = prepare_tensors(
         torch.stack([
             torch.from_numpy(np.array([[0, 0, 0, 1]]))
             for _ in range(self.batch_size)
         ],
                     dim=0),
         self.gpu,
         self.device,
     )
Exemplo n.º 10
0
    def __init__(self, in_channels, n_classes, n_slices, feature_size,
                 n_filters, batch_norm: bool, group_norm, bidir, int_downsize,
                 batch_size, gpu, device):
        """Build the ITN (segmentation) and STN (registration) components.

        NOTE(review): the ``int_downsize`` parameter is accepted but not
        used — ``FFDTransformer`` is constructed with ``int_downsize=2``;
        confirm whether that is intentional.
        """
        super().__init__()
        # ITN
        self.seg_encoder = Encoder(in_channels,
                                   n_filters,
                                   batch_norm=batch_norm,
                                   group_norm=group_norm)
        self.seg_decoder = SegDecoder(n_filters, batch_norm, group_norm,
                                      n_classes)

        # STN
        self.stn = STN(num_filters=n_filters, group_norm=group_norm)
        self.batch_size = batch_size
        # Batched atlas is filled in later (see update_batch_atlas callers).
        self.batch_atlas = None
        self.gpu = gpu
        self.device = device
        # configure transformer
        self.ffd_transformer = FFDTransformer(inshape=(n_slices, feature_size,
                                                       feature_size),
                                              int_downsize=2,
                                              bidir=bidir,
                                              mode="bilinear")
        self.affine_transformer = AffineSpatialTransformer(size=(n_slices,
                                                                 feature_size,
                                                                 feature_size),
                                                           mode="bilinear")
        # Precompute a batch of [0, 0, 0, 1] homogeneous rows, one per
        # sample, used to extend 3x4 affines to invertible 4x4 matrices.
        self.batch_affine_added = prepare_tensors(
            torch.stack([
                torch.from_numpy(np.array([[0, 0, 0, 1]]))
                for _ in range(self.batch_size)
            ],
                        dim=0),
            self.gpu,
            self.device,
        )
Exemplo n.º 11
0
def inference(image: np.ndarray, label: torch.Tensor, image_path: Path,
              network: torch.nn.Module, output_dir: Path, gpu, device,
              n_slices: int = 96):
    """Run segmentation on one volume and save NIfTI outputs.

    The input volume (Z, X, Y) is centre-cropped/padded along Z to
    ``n_slices`` slices and padded in-plane to multiples of 32, fed through
    ``network``, thresholded at 0.5, mapped back onto the original grid and
    written to ``output_dir`` as ``seg.nii.gz`` alongside ``image.nii.gz``
    and ``label.nii.gz``.

    Args:
        image: input volume with axes (Z, X, Y).
        label: one-hot label tensor with axes (class, Z, X, Y).
        image_path: source NIfTI path, used for affine/pixdim metadata.
        network: segmentation network returning per-class logits.
        output_dir: directory the NIfTI files are written to.
        gpu, device: forwarded to ``prepare_tensors``.
        n_slices: number of slices the network expects (was a hard-coded
            96; parameterized with the same default for compatibility).
    """
    import math
    original_image = image
    Z, X, Y = image.shape
    # Pad in-plane dims up to the next multiple of 32 (network downsampling
    # depth requires divisibility by 32).
    X2, Y2 = int(math.ceil(X / 32.0)) * 32, int(math.ceil(Y / 32.0)) * 32
    x_pre, y_pre = int((X2 - X) / 2), int((Y2 - Y) / 2)
    x_post, y_post = (X2 - X) - x_pre, (Y2 - Y) - y_pre
    # Centre window of n_slices slices along Z, clamped to the volume; the
    # unused z_pre/z_post computations from the original were removed.
    z1, z2 = int(Z / 2) - int(n_slices / 2), int(Z / 2) + int(n_slices / 2)
    z1_, z2_ = max(z1, 0), min(z2, Z)
    image = image[z1_:z2_, :, :]
    image = np.pad(image,
                   ((z1_ - z1, z2 - z2_), (x_pre, x_post), (y_pre, y_post)),
                   'constant')
    image = np.expand_dims(image, 0)
    image = torch.from_numpy(image).float()
    image = torch.unsqueeze(image, 0)
    image = prepare_tensors(image, gpu, device)

    predicted = network(image)
    predicted = torch.sigmoid(predicted)
    # Hard-threshold probabilities into binary per-class masks.
    predicted = (predicted > 0.5).float()
    predicted = predicted.cpu().detach().numpy()

    nim = nib.load(str(image_path))
    # Transpose and crop the segmentation to recover the original size.
    predicted = np.squeeze(predicted, axis=0)
    predicted = np.pad(predicted, ((0, 0), (z1_, Z - z2_), (0, 0), (0, 0)),
                       "constant")
    predicted = predicted[:, z1_ - z1:z1_ - z1 + Z, x_pre:x_pre + X,
                          y_pre:y_pre + Y]

    # Collapse per-class masks into one label volume (class i -> value i+1).
    final_predicted = np.zeros(
        (original_image.shape[0], original_image.shape[1],
         original_image.shape[2]))
    for i in range(predicted.shape[0]):
        final_predicted[predicted[i, :, :, :] > 0.5] = i + 1
    final_predicted = np.transpose(final_predicted, [1, 2, 0])

    output_dir.mkdir(parents=True, exist_ok=True)
    nim2 = nib.Nifti1Image(final_predicted, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    nib.save(nim2, '{0}/seg.nii.gz'.format(str(output_dir)))

    # Save the (untouched) input volume with matching metadata.
    final_image = np.transpose(original_image, [1, 2, 0])
    nim2 = nib.Nifti1Image(final_image, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    nib.save(nim2, '{0}/image.nii.gz'.format(str(output_dir)))

    # Collapse the one-hot ground-truth label the same way as predictions.
    final_label = np.zeros((label.shape[1], label.shape[2], label.shape[3]))
    label = label.cpu().detach().numpy()
    for i in range(label.shape[0]):
        final_label[label[i, :, :, :] == 1.0] = i + 1
    final_label = np.transpose(final_label, [1, 2, 0])
    nim2 = nib.Nifti1Image(final_label, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    nib.save(nim2, '{0}/label.nii.gz'.format(str(output_dir)))
Exemplo n.º 12
0
def main():
    """Load an FCN checkpoint, segment one dataset image, save NIfTI outputs."""
    args = parse_args()
    model_path = Path(args.model_path)
    # Resolve the training configuration, preferring an explicit path.
    if args.network_conf_path is not None:
        train_conf = ConfigFactory.parse_file(str(Path(
            args.network_conf_path)))
        conf_path = Path(args.network_conf_path)
    else:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.dataset_names[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode)
    input_path = Path(args.input_dir).joinpath(
        dataset_config.image_label_format.image.format(phase=args.phase))
    output_dir = Path(args.output_dir)
    checkpoint = torch.load(str(model_path),
                            map_location=torch.device(args.device))

    # Result is unused; presumably get_conf validates/raises on missing
    # keys — TODO confirm whether this call can be dropped.
    get_conf(train_conf, group="network", key="experiment_dir")
    # Rebuild the network from the [network] config group, then load weights.
    network = FCN2DSegmentationModel(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
        up_conv_filter=get_conf(train_conf,
                                group="network",
                                key="up_conv_filter"),
        final_conv_filter=get_conf(train_conf,
                                   group="network",
                                   key="final_conv_filter"),
        feature_size=get_conf(train_conf, group="network", key="feature_size"))
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)
    # image = nib.load(str(input_path)).get_data()
    # if image.ndim == 4:
    #     image = np.squeeze(image, axis=-1).astype(np.int16)
    # image = image.astype(np.int16)
    # image = np.transpose(image, (2, 0, 1))
    # NOTE(review): read_image receives "in_channels" where a slice count
    # seems expected — confirm this is intentional.
    image = Torch2DSegmentationDataset.read_image(
        input_path,
        get_conf(train_conf, group="network", key="feature_size"),
        get_conf(train_conf, group="network", key="in_channels"),
        crop=args.crop_image,
    )
    # Add a batch dimension and move the tensor to the GPU.
    image = np.expand_dims(image, 0)
    image = torch.from_numpy(image).float()
    image = prepare_tensors(image, gpu=True, device=args.device)
    predicted = network(image)
    predicted = torch.sigmoid(predicted)
    print("sigmoid", torch.mean(predicted).item(), torch.max(predicted).item())
    # Hard-threshold probabilities into binary per-class masks.
    predicted = (predicted > 0.5).float()
    print("0.5", torch.mean(predicted).item(), torch.max(predicted).item())
    predicted = predicted.cpu().detach().numpy()

    nim = nib.load(str(input_path))
    # Transpose and crop the segmentation to recover the original size
    predicted = np.squeeze(predicted, axis=0)
    print(predicted.shape)

    # map back to original size
    # Collapse per-class masks into one volume (class i -> value i+1).
    final_predicted = np.zeros(
        (image.shape[1], image.shape[2], image.shape[3]))
    print(predicted.shape, final_predicted.shape)

    for i in range(predicted.shape[0]):
        a = predicted[i, :, :, :] > 0.5
        print(a.shape)
        final_predicted[predicted[i, :, :, :] > 0.5] = i + 1
    # image = nim.get_data()
    final_predicted = np.transpose(final_predicted, [1, 2, 0])
    print(predicted.shape, final_predicted.shape)
    # final_predicted = np.resize(final_predicted, (image.shape[0], image.shape[1], image.shape[2]))

    print(predicted.shape, final_predicted.shape, np.max(final_predicted),
          np.mean(final_predicted), np.min(final_predicted))
    # if Z < 64:
    #     pred_segt = pred_segt[x_pre:x_pre + X, y_pre:y_pre + Y, z1_ - z1:z1_ - z1 + Z]
    # else:
    #     pred_segt = pred_segt[x_pre:x_pre + X, y_pre:y_pre + Y, :]
    #     pred_segt = np.pad(pred_segt, ((0, 0), (0, 0), (z_pre, z_post)), 'constant')

    # Save the segmentation with the source image's affine and pixdim.
    nim2 = nib.Nifti1Image(final_predicted, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    output_dir.mkdir(parents=True, exist_ok=True)
    nib.save(nim2, '{0}/seg.nii.gz'.format(str(output_dir)))

    # Save the network input volume itself alongside the segmentation.
    final_image = image.cpu().detach().numpy()
    final_image = np.squeeze(final_image, 0)
    final_image = np.transpose(final_image, [1, 2, 0])
    print(final_image.shape)
    nim2 = nib.Nifti1Image(final_image, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    nib.save(nim2, '{0}/image.nii.gz'.format(str(output_dir)))
    # shutil.copy(str(input_path), str(output_dir.joinpath("image.nii.gz")))

    # NOTE(review): n_slices is read from the "in_channels" key here —
    # confirm this is not a copy-paste slip.
    label = Torch2DSegmentationDataset.read_label(
        input_path.parent.joinpath(
            dataset_config.image_label_format.label.format(phase=args.phase)),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="in_channels"),
        crop=args.crop_image,
    )
    # Collapse the one-hot label the same way as the predictions.
    final_label = np.zeros((image.shape[1], image.shape[2], image.shape[3]))
    for i in range(label.shape[0]):
        final_label[label[i, :, :, :] == 1.0] = i + 1

    final_label = np.transpose(final_label, [1, 2, 0])
    print(final_label.shape)
    nim2 = nib.Nifti1Image(final_label, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    output_dir.mkdir(parents=True, exist_ok=True)
    nib.save(nim2, '{0}/label.nii.gz'.format(str(output_dir)))
Exemplo n.º 13
0
    def train(self):
        """Run the full training loop with per-epoch atlas updates.

        Each epoch: trains the network against the current atlas label,
        rebuilds the atlas from network predictions, evaluates on the
        validation (and extra validation) sets, logs metrics to
        TensorBoard, checkpoints the weights and optionally runs inference.
        """
        self.network.train()
        train_data_loader = MultiDataLoader(
            *self.training_sets,
            batch_size=self.config.batch_size,
            sampler_cls="random",
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory)
        # NOTE(review): `set` shadows the builtin and is only used by the
        # commented-out LR-decay block below.
        set = False
        # Initialize the atlas from the training data and persist it.
        atlas = Atlas.from_data_loader(train_data_loader)
        atlas.save(output_dir=self.config.experiment_dir.joinpath(
            "atlas").joinpath("init"))
        for epoch in range(self.config.num_epochs):
            self.network.train()
            self.logger.info("{}: starting epoch {}/{}".format(
                datetime.now(), epoch, self.config.num_epochs))
            self.loss.reset()
            # if epoch > 10 and not set:
            #     self.optimizer.param_groups[0]['lr'] /= 10
            #     print("-------------Learning rate: {}-------------".format(self.optimizer.param_groups[0]['lr']))
            #     set = True

            pbar = tqdm(enumerate(train_data_loader))
            # Refresh the atlas tensor on-device and push it into the
            # network's precomputed batch.
            atlas_label = prepare_tensors(torch.from_numpy(atlas.label()),
                                          self.config.gpu, self.config.device)
            self.network.update_batch_atlas(atlas_label)
            for idx, (inputs, outputs) in pbar:
                inputs = prepare_tensors(inputs, self.config.gpu,
                                         self.config.device)
                outputs = prepare_tensors(outputs, self.config.gpu,
                                          self.config.device)
                predicted = self.network(inputs,
                                         atlas_label)  # pass updated atlas in
                loss = self.loss.cumulate(predicted, outputs)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                pbar.set_description("{:.2f} --- {}".format(
                    (idx + 1) / len(train_data_loader),
                    self.loss.description()))
            # update atlas
            self.network.eval()
            self.update_atlas(train_data_loader, atlas_label, atlas)
            atlas.save(output_dir=self.config.experiment_dir.joinpath(
                "atlas").joinpath("epoch_{}".format(epoch)))
            self.logger.info("Epoch finished !")
            # Evaluate: metric [0] is the (fresh) training loss, the rest
            # are the configured extra validation metrics.
            val_metrics = self.eval(self.loss.new(),
                                    *self.other_validation_metrics,
                                    datasets=self.validation_sets,
                                    atlas=atlas_label)
            self.logger.info("Validation loss: {}".format(
                val_metrics[0].description()))
            if val_metrics[1:]:
                self.logger.info("Other metrics on validation set.")
                for metric in val_metrics[1:]:
                    self.logger.info("{}".format(metric.description()))
            # Log training/validation curves to TensorBoard.
            self.tensor_board.add_scalar(
                "loss/training/{}".format(self.loss.document()),
                self.loss.log(), epoch)
            self.tensor_board.add_scalar(
                "loss/validation/loss_{}".format(val_metrics[0].document()),
                val_metrics[0].log(), epoch)
            for metric in val_metrics[1:]:
                self.tensor_board.add_scalar(
                    "other_metrics/validation/{}".format(metric.document()),
                    metric.avg(), epoch)

            # eval extra validation sets
            if self.extra_validation_sets:
                for val in self.extra_validation_sets:
                    val_metrics = self.eval(self.loss.new(),
                                            *self.other_validation_metrics,
                                            datasets=[val],
                                            atlas=atlas_label)
                    self.logger.info(
                        "Extra Validation loss on dataset {}: {}".format(
                            val.name, val_metrics[0].description()))
                    if val_metrics[1:]:
                        self.logger.info(
                            "Other metrics on extra validation set.")
                        for metric in val_metrics[1:]:
                            self.logger.info("{}".format(metric.description()))
                    self.tensor_board.add_scalar(
                        "loss/extra_validation_{}/loss_{}".format(
                            val.name, val_metrics[0].document()),
                        val_metrics[0].log(), epoch)
                    for metric in val_metrics[1:]:
                        self.tensor_board.add_scalar(
                            "other_metrics/extra_validation_{}/{}".format(
                                val.name, metric.document()), metric.avg(),
                            epoch)

            # Checkpoint the network weights for this epoch.
            checkpoint_dir = self.config.experiment_dir.joinpath("checkpoints")
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            output_path = checkpoint_dir.joinpath("CP_{}.pth".format(epoch))
            torch.save(self.network.state_dict(), str(output_path))
            self.logger.info("Checkpoint {} saved at {}!".format(
                epoch, str(output_path)))
            if self.inference_func is not None:
                self.inference(epoch, atlas_label)
Exemplo n.º 14
0
def inference(image: np.ndarray, label: np.ndarray, network: torch.nn.Module, output_dir: Path,
              atlas, output_size, gpu, device):
    """Run one registration/segmentation forward pass and dump all volumes as NIfTI.

    The (image, label) pair is centrally cropped/padded to ``output_size``,
    fed through ``network`` together with ``atlas``, and every input,
    intermediate and predicted volume is written under ``output_dir`` as a
    ``*.nii.gz`` file (same file names as before: ``<prefix>_seg``, ``image``,
    ``cropped_image``, ``padded_*`` and ``*label`` variants).

    Args:
        image: raw image volume; spatial shape (X, Y, Z).
        label: one-hot label stack of shape (C, X, Y, Z).
        network: called as ``network((image_batch, label_batch), atlas)`` and
            expected to return the 9-tuple unpacked below.
        output_dir: directory for all outputs (created if missing).
        atlas: atlas input, forwarded to the network unchanged.
        output_size: spatial size of the central crop fed to the network.
        gpu, device: forwarded to ``prepare_tensors`` for device placement.
    """
    np_cropped_image, np_cropped_label = central_crop_with_padding(
        image, label, output_size=output_size)

    # Build network inputs: image -> (1, 1, X, Y, Z), label -> (1, C, X, Y, Z).
    cropped_image = torch.from_numpy(np.expand_dims(np_cropped_image, 0)).float()
    cropped_image = torch.unsqueeze(cropped_image, 0)
    cropped_label = torch.unsqueeze(torch.from_numpy(np_cropped_label).float(), 0)
    cropped_image = prepare_tensors(cropped_image, gpu, device)
    cropped_label = prepare_tensors(cropped_label, gpu, device)

    (affine_warped_template, warped_template, pred_maps, flow, warped_image,
     warped_label, affine_warped_label, batch_atlas,
     affine_warped_image) = network((cropped_image, cropped_label), atlas)

    # mkdir once up front; every _save_volume below writes into this directory.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Predicted segmentations: binarize, pad back to the original image grid,
    # collapse the channel stack into an integer label map, and save.
    for prefix, predicted in zip(
            ["warped_template", "affine_warped_template", "pred_maps"],
            [warped_template, affine_warped_template, pred_maps]):
        channels = np.squeeze(_to_numpy((predicted > 0.5).float()), axis=0)
        __, channels = central_crop_with_padding(None, channels, image.shape)
        _save_volume(_channels_to_labelmap(channels), output_dir,
                     "{}_seg".format(prefix))

    # Raw input image and its cropped / padded-back-to-original-grid versions.
    _save_volume(image, output_dir, "image")
    _save_volume(np_cropped_image, output_dir, "cropped_image")
    padded_cropped_image, __ = central_crop_with_padding(
        np_cropped_image, None, image.shape)
    _save_volume(padded_cropped_image, output_dir, "padded_cropped_image")

    # Warped images: drop batch and channel dims, pad back to the original grid.
    for stem, tensor in [("padded_warped_image", warped_image),
                         ("padded_affine_warped_image", affine_warped_image)]:
        volume = np.squeeze(np.squeeze(_to_numpy(tensor), axis=0), axis=0)
        volume, __ = central_crop_with_padding(volume, None, image.shape)
        _save_volume(volume, output_dir, stem)

    # Warped label stacks: binarize, pad back, collapse to label maps.
    for stem, tensor in [("padded_affine_warped_label", affine_warped_label),
                         ("padded_warped_label", warped_label)]:
        channels = np.squeeze(_to_numpy(tensor), axis=0) > 0.5
        __, channels = central_crop_with_padding(None, channels, image.shape)
        _save_volume(_channels_to_labelmap(channels), output_dir, stem)

    # Ground-truth labels (cropped and original) as label maps.
    _save_volume(_channels_to_labelmap(np_cropped_label), output_dir, "cropped_label")
    _save_volume(_channels_to_labelmap(label), output_dir, "label")


def _to_numpy(tensor):
    """Detach ``tensor`` from the autograd graph and return a CPU numpy array."""
    return tensor.cpu().detach().numpy()


def _channels_to_labelmap(channels):
    """Collapse a (C, X, Y, Z) stack of binary masks into an (X, Y, Z) label map.

    Voxels of channel ``i`` get the integer id ``i + 1`` (background stays 0);
    later channels overwrite earlier ones where masks overlap.  Binarization
    uses ``> 0.5``, which is equivalent to the ``== 1.0`` test previously used
    for the (strictly 0/1) label arrays.
    """
    labelmap = np.zeros(channels.shape[1:4])
    for class_idx in range(channels.shape[0]):
        labelmap[channels[class_idx, :, :, :] > 0.5] = class_idx + 1
    return labelmap


def _save_volume(volume, output_dir, stem):
    """Save ``volume`` as ``<output_dir>/<stem>.nii.gz``.

    The volume is transposed with axes [1, 2, 0] before saving, matching the
    per-volume transpose previously applied inline before every ``nib.save``.
    No affine is available at this point, so ``None`` (identity) is used —
    NOTE(review): the original header/affine from the source NIfTI is lost;
    confirm this is acceptable downstream.
    """
    nib.save(nib.Nifti1Image(np.transpose(volume, [1, 2, 0]), None),
             '{}/{}.nii.gz'.format(str(output_dir), stem))