Example #1
    def _finalize(self, _engine: Engine) -> None:
        """
        All-gather classification results from all ranks and save them to a CSV file.

        Args:
            _engine: Ignite Engine, unused argument.
        """
        ws = idist.get_world_size()
        if self.save_rank >= ws:
            raise ValueError(
                "target save rank is greater than the distributed group size.")

        outputs = torch.stack(self._outputs, dim=0)
        filenames = self._filenames
        if ws > 1:
            outputs = evenly_divisible_all_gather(outputs, concat=True)
            filenames = string_list_all_gather(filenames)

        if len(filenames) == 0:
            meta_dict = None
        else:
            if len(filenames) != len(outputs):
                warnings.warn(
                    f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}."
                )
            meta_dict = {Key.FILENAME_OR_OBJ: filenames}

        # save to CSV file only on the expected rank
        if idist.get_rank() == self.save_rank:
            saver = self.saver or CSVSaver(output_dir=self.output_dir,
                                           filename=self.filename,
                                           overwrite=self.overwrite,
                                           delimiter=self.delimiter)
            saver.save_batch(outputs, meta_dict)
            saver.finalize()
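
A minimal wiring sketch (not part of the example above): how a handler like this is commonly attached to an Ignite evaluation loop. The names `evaluator` and `saver_handler`, and the choice of events, are illustrative assumptions.

# illustrative only: accumulate outputs every iteration, write the CSV once at the end
from ignite.engine import Engine, Events

def attach_classification_saver(evaluator: Engine, saver_handler) -> None:
    # the handler's __call__ is expected to collect per-iteration outputs/filenames
    evaluator.add_event_handler(Events.ITERATION_COMPLETED, saver_handler)
    # once the epoch finishes: all-gather across ranks and save to CSV on the target rank
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, saver_handler._finalize)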
Example #2
    def _finalize(self, engine: Engine) -> None:
        """
        All-gather classification results from all ranks and save them to a CSV file.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        ws = idist.get_world_size()
        if self.save_rank >= ws:
            raise ValueError(
                "target save rank is greater than the distributed group size."
            )

        outputs = torch.stack(self._outputs, dim=0)
        filenames = self._filenames
        if ws > 1:
            outputs = evenly_divisible_all_gather(outputs, concat=True)
            filenames = string_list_all_gather(filenames)

        if len(filenames) == 0:
            meta_dict = None
        else:
            if len(filenames) != len(outputs):
                warnings.warn(
                    f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}."
                )
            meta_dict = {Key.FILENAME_OR_OBJ: filenames}

        # save to CSV file only on the expected rank
        if idist.get_rank() == self.save_rank:
            # this variant also saves the locally collected labels alongside the predictions
            self.saver.save_batch(outputs, self._labels, meta_dict)
            self.saver.finalize()
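
For intuition, a comment-only sketch of what the two gather helpers used above return on a hypothetical 2-rank run; all shapes and file names are made up for illustration.

# hypothetical 2-rank run, illustrative values only:
# rank 0: outputs.shape == (3, C), filenames == ["a.nii", "b.nii", "c.nii"]
# rank 1: outputs.shape == (2, C), filenames == ["d.nii", "e.nii"]
# evenly_divisible_all_gather(outputs, concat=True) -> tensor of shape (5, C) on every rank
# string_list_all_gather(filenames) -> all five names, ordered by rank, on every rank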
Example #3
    def __call__(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        ws = idist.get_world_size()
        if self.save_rank >= ws:
            raise ValueError(
                "target save rank is greater than the distributed group size.")

        # all gather file names across ranks
        _images = string_list_all_gather(
            strings=self._filenames) if ws > 1 else self._filenames

        # only save metrics to file on the specified rank
        if idist.get_rank() == self.save_rank:
            _metrics = {}
            if self.metrics is not None and len(engine.state.metrics) > 0:
                _metrics = {
                    k: v
                    for k, v in engine.state.metrics.items()
                    if k in self.metrics or "*" in self.metrics
                }
            _metric_details = {}
            if hasattr(engine.state, "metric_details"):
                details = engine.state.metric_details  # type: ignore
                if self.metric_details is not None and len(details) > 0:
                    for k, v in details.items():
                        if k in self.metric_details or "*" in self.metric_details:
                            _metric_details[k] = v

            write_metrics_reports(
                save_dir=self.save_dir,
                images=None if len(_images) == 0 else _images,
                metrics=_metrics,
                metric_details=_metric_details,
                summary_ops=self.summary_ops,
                deli=self.deli,
                output_type=self.output_type,
            )
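
As a rough guide (see also the comment in Example #4 below), a call like the one above is expected to produce one overall summary file plus raw/summary files per selected metric; exact names depend on the configured metrics and output_type.

# expected layout of save_dir after write_metrics_reports (illustrative):
#   <save_dir>/metrics.csv           # one row per entry in `_metrics`
#   <save_dir>/<metric>_raw.csv      # per-image values from `_metric_details`
#   <save_dir>/<metric>_summary.csv  # aggregation defined by `summary_ops`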
Example #4
def compute(args):
    # generate synthetic data for the example
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random pred/label pairs for evaluation
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # if there are multiple nodes, fix the random seed so every node generates the same data
        np.random.seed(seed=0)
        for i in range(16):
            pred, label = create_test_image_3d(128,
                                               128,
                                               128,
                                               num_seg_classes=1,
                                               channel_dim=-1,
                                               noise_max=0.5)
            n = nib.Nifti1Image(pred, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"pred{i:d}.nii.gz"))
            n = nib.Nifti1Image(label, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"label{i:d}.nii.gz"))

    # initialize the distributed evaluation process; switch to the NCCL backend when computing on GPUs
    dist.init_process_group(backend="gloo", init_method="env://")

    preds = sorted(glob(os.path.join(args.dir, "pred*.nii.gz")))
    labels = sorted(glob(os.path.join(args.dir, "label*.nii.gz")))
    datalist = [{
        "pred": pred,
        "label": label
    } for pred, label in zip(preds, labels)]

    # split the data across the subprocesses, e.g. 16 processes computing in parallel
    data_part = partition_dataset(
        data=datalist,
        num_partitions=dist.get_world_size(),
        shuffle=False,
        even_divisible=False,
    )[dist.get_rank()]

    # define transforms for predictions and labels
    transforms = Compose([
        LoadImaged(keys=["pred", "label"]),
        EnsureChannelFirstd(keys=["pred", "label"]),
        ScaleIntensityd(keys="pred"),
        EnsureTyped(keys=["pred", "label"]),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    data_part = [transforms(item) for item in data_part]

    # compute metrics for current process
    metric = DiceMetric(include_background=True,
                        reduction="mean",
                        get_not_nans=False)
    metric(y_pred=[i["pred"] for i in data_part],
           y=[i["label"] for i in data_part])
    filenames = [
        item["pred_meta_dict"]["filename_or_obj"] for item in data_part
    ]
    # all-gather results from all the processes and reduce for final result
    result = metric.aggregate().item()
    filenames = string_list_all_gather(strings=filenames)

    if args.local_rank == 0:
        print("mean dice: ", result)
        # generate metrics reports at: output/mean_dice_raw.csv, output/mean_dice_summary.csv, output/metrics.csv
        write_metrics_reports(
            save_dir="./output",
            images=filenames,
            metrics={"mean_dice": result},
            metric_details={"mean_dice": metric.get_buffer()},
            summary_ops="*",
        )

    metric.reset()

    dist.destroy_process_group()
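
A minimal launch sketch for the script above, assuming `compute` is driven by an argparse entry point; the flag names mirror the attributes the function reads (`args.dir`, `args.local_rank`), and the torchrun command line is illustrative.

# illustrative entry point, not part of the original example
# example launch:  torchrun --nproc_per_node=2 evaluate_dice.py --dir ./testdata
import argparse
import os

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", type=str, default="./testdata",
                        help="directory holding the synthetic pred/label NIfTI pairs")
    parser.add_argument("--local_rank", type=int,
                        default=int(os.environ.get("LOCAL_RANK", 0)),
                        help="rank of this process on the local node (set by torchrun)")
    args = parser.parse_args()
    compute(args)

if __name__ == "__main__":
    main()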