Example #1
    def test_type_shape(self, input_data, expected_type, expected_shape):
        result = list_data_collate(input_data)
        self.assertIsInstance(result, expected_type)
        if isinstance(result, dict):
            data = result['image']
        else:
            data = result[0]
        self.assertEqual(data.shape, expected_shape)
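For reference, a minimal sketch of the behavior this parameterized test exercises (the shapes and keys below are illustrative assumptions, not taken from the test's parameter list): `list_data_collate` stacks a list of samples, or dicts of samples, into a batch, much like `torch.utils.data.default_collate`.

import torch
from monai.data import list_data_collate

# A list of dict samples, as produced by a dictionary-based dataset,
# collates into a single dict of batched tensors.
samples = [{"image": torch.ones(1, 10, 10)} for _ in range(4)]
batch = list_data_collate(samples)
print(batch["image"].shape)  # torch.Size([4, 1, 10, 10])

# A list of tuple samples collates element-wise.
pairs = [(torch.zeros(1, 8), torch.tensor(1)) for _ in range(4)]
images, labels = list_data_collate(pairs)
print(images.shape, labels.shape)  # torch.Size([4, 1, 8]) torch.Size([4])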
Example #2
    def __call__(self, batch: Any):
        """
        Args:
            batch: batch of data to pad-collate
        """
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)
        # loop over items inside of each element in a batch
        for key_or_idx in batch[0].keys() if is_list_of_dicts else range(
                len(batch[0])):
            # calculate max size of each dimension
            max_shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx],
                                  (torch.Tensor, np.ndarray)):
                    break
                max_shapes.append(elem[key_or_idx].shape[1:])
            # len > 0 if objects were arrays, else skip as no padding to be done
            if not max_shapes:
                continue
            max_shape = np.array(max_shapes).max(axis=0)
            # If all same size, skip
            if np.all(np.array(max_shapes).min(axis=0) == max_shape):
                continue
            # Do we need to convert output to Tensor?
            output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

            # Use `SpatialPad` to match sizes
            # Default params are central padding, padding with 0's
            # For dict input, the padding applied here is recorded via
            # `push_transform` below so that it can be inverted later

            padder = SpatialPad(spatial_size=max_shape,
                                method=self.method,
                                mode=self.mode,
                                **self.np_kwargs)
            transform = padder if not output_to_tensor else Compose(
                [padder, ToTensor()])

            for idx, batch_i in enumerate(batch):
                im = batch_i[key_or_idx]
                orig_size = im.shape[1:]
                padded = transform(batch_i[key_or_idx])
                batch = replace_element(padded, batch, idx, key_or_idx)

                # If we have a dictionary of data, append to list
                if is_list_of_dicts:
                    self.push_transform(batch[idx],
                                        key_or_idx,
                                        orig_size=orig_size)

        # After padding, use default list collator
        return list_data_collate(batch)
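A usage-oriented sketch of how this collator is typically wired up (the dataset below is hypothetical): `PadListDataCollate` is passed as the `DataLoader`'s `collate_fn` so that samples of differing spatial sizes can still be batched.

import torch
from torch.utils.data import DataLoader, Dataset
from monai.transforms import PadListDataCollate

class VariableSizeDataset(Dataset):
    # Hypothetical dataset whose samples differ in spatial size.
    sizes = [(1, 10, 10), (1, 12, 8), (1, 9, 11)]

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        return {"image": torch.ones(self.sizes[idx])}

# Default collation would fail on mismatched shapes; PadListDataCollate
# pads each item up to the per-dimension maximum of the batch first.
loader = DataLoader(VariableSizeDataset(), batch_size=3,
                    collate_fn=PadListDataCollate())
batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([3, 1, 12, 11])

The information recorded by `push_transform` is what allows the padding to be undone per sample later, e.g. via `PadListDataCollate.inverse` after `decollate_batch`.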
Example #3
    def test_collate(self, device, dtype):
        numel = 3
        ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)]
        collated = list_data_collate(ims)
        # tensor
        self.assertIsInstance(collated, MetaTensor)
        expected_shape = (numel,) + tuple(ims[0].shape)
        self.assertTupleEqual(tuple(collated.shape), expected_shape)
        for i, im in enumerate(ims):
            self.check(im, ims[i], ids=True)
        # affine
        self.assertIsInstance(collated.affine, torch.Tensor)
        expected_shape = (numel,) + tuple(ims[0].affine.shape)
        self.assertTupleEqual(tuple(collated.affine.shape), expected_shape)
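As a standalone counterpart to this test, a minimal sketch (shapes assumed for illustration, relying on the default 4x4 identity affine that `MetaTensor` supplies when none is given): collating `MetaTensor`s stacks the image data and the per-sample affines alike.

import torch
from monai.data import MetaTensor, list_data_collate

ims = [MetaTensor(torch.rand(1, 10, 8)) for _ in range(3)]
collated = list_data_collate(ims)
print(collated.shape)         # torch.Size([3, 1, 10, 8])
print(collated.affine.shape)  # torch.Size([3, 4, 4]): affines are stacked too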
Example #4
    def test_default_device(self, data_type):
        device = "cuda" if torch.cuda.is_available() else "cpu:0"
        inputs = data_type(torch.ones((3, 16, 15, 7))).to(device=device)
        inputs = list_data_collate([inputs])  # make a proper batch
        roi_shape = (4, 10, 7)
        sw_batch_size = 10

        def compute(data):
            return data + 1

        result = sliding_window_inference(inputs, roi_shape, sw_batch_size,
                                          compute)
        np.testing.assert_string_equal(inputs.device.type, result.device.type)
        expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1
        np.testing.assert_allclose(result.cpu().numpy(), expected_val)
Example #5
    def test_sw_device(self, data_type):
        inputs = data_type(torch.ones((3, 16, 15, 7))).to(device="cpu")
        inputs = list_data_collate([inputs])  # make a proper batch
        roi_shape = (4, 10, 7)
        sw_batch_size = 10

        def compute(data):
            self.assertEqual(data.device.type, "cuda")
            return data + torch.tensor(1, device="cuda")

        result = sliding_window_inference(inputs,
                                          roi_shape,
                                          sw_batch_size,
                                          compute,
                                          sw_device="cuda")
        np.testing.assert_string_equal(inputs.device.type, result.device.type)
        expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1
        np.testing.assert_allclose(result.cpu().numpy(), expected_val)
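This example and the previous one exercise the two device knobs of `sliding_window_inference`: `sw_device` sets where each window is computed, while `device` (defaulting to the input's device) sets where the stitched output lives. A minimal sketch, assuming a CUDA device is available:

import torch
from monai.data import list_data_collate
from monai.inferers import sliding_window_inference

inputs = list_data_collate([torch.ones((3, 16, 15, 7))])  # shape (1, 3, 16, 15, 7)

# Keep the full volume on the CPU, run each window on the GPU; the
# stitched result is returned on the requested (CPU) device.
result = sliding_window_inference(inputs, (4, 10, 7), 10, lambda x: x + 1,
                                  sw_device="cuda", device="cpu")
print(result.device, result.shape)  # cpu torch.Size([1, 3, 16, 15, 7])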
Example #6
File: batch.py Project: Nic-Ma/MONAI
    def __call__(self, batch: Any):
        """
        Args:
            batch: batch of data to pad-collate
        """
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)
        # loop over items inside of each element in a batch
        batch_item = tuple(batch[0].keys()) if is_list_of_dicts else range(
            len(batch[0]))
        for key_or_idx in batch_item:
            # calculate max size of each dimension
            max_shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx],
                                  (torch.Tensor, np.ndarray)):
                    break
                max_shapes.append(elem[key_or_idx].shape[1:])
            # len > 0 if objects were arrays, else skip as no padding to be done
            if not max_shapes:
                continue
            max_shape = np.array(max_shapes).max(axis=0)
            # If all same size, skip
            if np.all(np.array(max_shapes).min(axis=0) == max_shape):
                continue

            # Use `SpatialPad` to match sizes, Default params are central padding, padding with 0's
            padder = SpatialPad(spatial_size=max_shape,
                                method=self.method,
                                mode=self.mode,
                                **self.kwargs)
            for idx, batch_i in enumerate(batch):
                orig_size = batch_i[key_or_idx].shape[1:]
                padded = padder(batch_i[key_or_idx])
                batch = replace_element(padded, batch, idx, key_or_idx)

                # If we have a dictionary of data, append to list
                if is_list_of_dicts:
                    self.push_transform(batch[idx],
                                        key_or_idx,
                                        orig_size=orig_size)

        # After padding, use default list collator
        return list_data_collate(batch)
Example #7
    def update_meta(rets: Sequence, func, args, kwargs):
        """Update the metadata from the output of `__torch_function__`.
        The output could be a single object, or a sequence of them. Hence, they get
        converted to a sequence if necessary and then processed by iterating across them.

        For each element, if it is not of type `MetaTensor`, there is nothing to do.
        """
        out = []
        metas = None
        for idx, ret in enumerate(rets):
            # if not `MetaTensor`, nothing to do.
            if not isinstance(ret, MetaTensor):
                pass
            # if not tracking, convert to `torch.Tensor`.
            elif not (get_track_meta() or get_track_transforms()):
                ret = ret.as_tensor()
            # else, handle the `MetaTensor` metadata.
            else:
                meta_args = MetaObj.flatten_meta_objs(
                    list(args) + list(kwargs.values()))
                ret._copy_meta(meta_args)

                # If we have a batch of data, then we need to be careful if a slice of
                # the data is returned. Depending on how the data are indexed, we return
                # some or all of the metadata, and the return object may or may not be a
                # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
                if ret.is_batch:
                    # only decollate metadata once
                    if metas is None:
                        metas = decollate_batch(ret.meta)
                    # if indexing e.g., `batch[0]`
                    if func == torch.Tensor.__getitem__:
                        idx = args[1]
                        if isinstance(idx, Sequence):
                            idx = idx[0]
                        # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                        # first element will be `slice(None, None, None)` and `Ellipsis`,
                        # respectively. Don't need to do anything with the metadata.
                        if idx not in (slice(None, None, None), Ellipsis):
                            meta = metas[idx]
                            # if using e.g., `batch[0:2]`, then `is_batch` should still be
                            # `True`. Also re-collate the remaining elements.
                            if isinstance(meta, list) and len(meta) > 1:
                                ret.meta = list_data_collate(meta)
                            # if using e.g., `batch[0]` or `batch[0, 1]`, then return single
                            # element from batch, and set `is_batch` to `False`.
                            else:
                                ret.meta = meta
                                ret.is_batch = False
                    # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                    # But we only want to split the batch if the `unbind` is along the 0th
                    # dimension.
                    elif func == torch.Tensor.unbind:
                        if len(args) > 1:
                            dim = args[1]
                        elif "dim" in kwargs:
                            dim = kwargs["dim"]
                        else:
                            dim = 0
                        if dim == 0:
                            ret.meta = metas[idx]
                            ret.is_batch = False

                ret.affine = ret.affine.to(ret.device)
            out.append(ret)
        # if the input was a tuple, then return it as a tuple
        return tuple(out) if isinstance(rets, tuple) else out
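The branches above can be observed directly; a minimal sketch (shapes assumed for illustration, reflecting the `MetaTensor` indexing semantics in recent MONAI versions) of how indexing a collated batch drives `is_batch` and the attached metadata:

import torch
from monai.data import MetaTensor, list_data_collate

batch = list_data_collate([MetaTensor(torch.ones(1, 10, 8)) for _ in range(3)])

print(batch.is_batch)        # True: collation marks the result as a batch
print(batch[0].is_batch)     # False: integer indexing returns one sample
print(batch[1:3].is_batch)   # True: slicing keeps a (smaller) batch
print(batch[:, 0].is_batch)  # True: indexing a non-batch dim keeps all metadata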
Example #8
    def test_indexing(self):
        """
        Check the metadata is returned in the expected format depending on whether
        the input `MetaTensor` is a batch of data or not.
        """
        ims = [self.get_im()[0] for _ in range(5)]
        data = list_data_collate(ims)

        # check that when using non-batch data, metadata is copied wholly when indexing
        # or iterating across data.
        im = ims[0]
        self.check_meta(im[0], im)
        self.check_meta(next(iter(im)), im)

        self.assertEqual(im[None].shape, (1, 1, 10, 8))
        self.assertEqual(data[None].shape, (1, 5, 1, 10, 8))

        # index
        d = data[0]
        self.check(d, ims[0], ids=False)

        # iter
        d = next(iter(data))
        self.check(d, ims[0], ids=False)

        # complex indexing

        # `is_batch==True`, should have subset of image and metadata.
        d = data[1:3]
        self.check(d, list_data_collate(ims[1:3]), ids=False)

        # `is_batch==True`, should have subset of image and same metadata as `[1:3]`.
        d = data[1:3, 0]
        self.check(d, list_data_collate([i[0] for i in ims[1:3]]), ids=False)

        # `is_batch==False`, should have first metadata and subset of first image.
        d = data[0, 0]
        self.check(d, ims[0][0], ids=False)
        self.assertEqual(d.applied_operations, ims[0][0].applied_operations)

        # `is_batch==True`, should have all metadata and subset of all images.
        d = data[:, 0]
        self.check(d, list_data_collate([i[0] for i in ims]), ids=False)

        # `is_batch==True`, should have all metadata and subset of all images.
        d = data[..., -1]
        self.check(d, list_data_collate([i[..., -1] for i in ims]), ids=False)

        # `is_batch==False`, tuple split along batch dim. Should have individual
        # metadata.
        d = data.unbind(0)
        self.assertIsInstance(d, tuple)
        self.assertEqual(len(d), len(ims))
        for _d, _im in zip(d, ims):
            self.check(_d, _im, ids=False)

        # `is_batch==False`, tuple split along batch dim. Should have individual
        # metadata.
        d = data.unbind(dim=0)
        self.assertIsInstance(d, tuple)
        self.assertEqual(len(d), len(ims))
        for _d, _im in zip(d, ims):
            self.check(_d, _im, ids=False)
            self.assertEqual(_d.applied_operations, _im.applied_operations)

        # `is_batch==True`, tuple split along non-batch dim. Should have all metadata.
        d = data.unbind(-1)
        self.assertIsInstance(d, tuple)
        self.assertEqual(len(d), ims[0].shape[-1])
        for _d in d:
            self.check_meta(_d, data)

        # `is_batch==True`, tuple split along non-batch dim. Should have all metadata.
        d = data.unbind(dim=-1)
        self.assertIsInstance(d, tuple)
        self.assertEqual(len(d), ims[0].shape[-1])
        for _d in d:
            self.check_meta(_d, data)
Example #9
    def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
        """
        Update the metadata from the output of `MetaTensor.__torch_function__`.

        The output of `torch.Tensor.__torch_function__` could be a single object or a
        sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
        list if not already, and then we loop across each element, processing metadata
        as necessary. For each element, if it is not of type `MetaTensor`, there is
        nothing to do.

        Args:
            rets: the output from `torch.Tensor.__torch_function__`, which has been
                converted to a list in `MetaTensor.__torch_function__` if it wasn't
                already a `Sequence`.
            func: the torch function that was applied. Examples might be `torch.squeeze`
                or `torch.Tensor.__add__`. We need this since the metadata need to be
                treated differently if a batch of data is considered. For example,
                slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
                dimension of a batch of data should return the ith tensor with the ith
                metadata.
            args: positional arguments that were passed to `func`.
            kwargs: keyword arguments that were passed to `func`.

        Returns:
            A sequence with the same number of elements as `rets`. For each element, if
            the input type was not `MetaTensor`, then no modifications will have been
            made. If global parameters have been set to false (e.g.,
            `not get_track_meta()`), then any `MetaTensor` will be converted to
            `torch.Tensor`. Else, metadata will be propagated as necessary (see
            :py:func:`MetaTensor._copy_meta`).
        """
        out = []
        metas = None
        is_batch = any(
            x.is_batch
            for x in MetaObj.flatten_meta_objs(args, kwargs.values())
            if hasattr(x, "is_batch"))
        for idx, ret in enumerate(rets):
            # if not `MetaTensor`, nothing to do.
            if not isinstance(ret, MetaTensor):
                pass
            # if not tracking, convert to `torch.Tensor`.
            elif not get_track_meta():
                ret = ret.as_tensor()
            # else, handle the `MetaTensor` metadata.
            else:
                meta_args = MetaObj.flatten_meta_objs(args, kwargs.values())
                ret.is_batch = is_batch
                ret.copy_meta_from(meta_args, copy_attr=not is_batch)
                # the following is not implemented but the network arch may run into this case:
                # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
                #     raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")

                # If we have a batch of data, then we need to be careful if a slice of
                # the data is returned. Depending on how the data are indexed, we return
                # some or all of the metadata, and the return object may or may not be a
                # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
                if is_batch:
                    # if indexing e.g., `batch[0]`
                    if func == torch.Tensor.__getitem__:
                        batch_idx = args[1]
                        if isinstance(batch_idx, Sequence):
                            batch_idx = batch_idx[0]
                        # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
                        # first element will be `slice(None, None, None)` and `Ellipsis`,
                        # respectively. Don't need to do anything with the metadata.
                        if batch_idx not in (slice(None, None, None), Ellipsis,
                                             None) and idx == 0:
                            ret_meta = decollate_batch(args[0],
                                                       detach=False)[batch_idx]
                            if isinstance(ret_meta,
                                          list):  # e.g. batch[0:2], re-collate
                                ret_meta = list_data_collate(ret_meta)
                            else:  # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer
                                ret_meta.is_batch = False
                            ret.__dict__ = ret_meta.__dict__.copy()
                    # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                    # But we only want to split the batch if the `unbind` is along the 0th
                    # dimension.
                    elif func == torch.Tensor.unbind:
                        if len(args) > 1:
                            dim = args[1]
                        elif "dim" in kwargs:
                            dim = kwargs["dim"]
                        else:
                            dim = 0
                        if dim == 0:
                            if metas is None:
                                metas = decollate_batch(args[0], detach=False)
                            ret.__dict__ = metas[idx].__dict__.copy()
                            ret.is_batch = False

            out.append(ret)
        # if the input was a tuple, then return it as a tuple
        return tuple(out) if isinstance(rets, tuple) else out
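Finally, the `unbind` branch above is what makes batch iteration and `decollate_batch` hand back per-sample metadata; a minimal sketch (shapes assumed for illustration):

import torch
from monai.data import MetaTensor, decollate_batch, list_data_collate

batch = list_data_collate([MetaTensor(torch.ones(1, 10, 8)) for _ in range(3)])

first = next(iter(batch))  # tensor iteration calls `unbind` along dim 0
print(first.is_batch, first.shape)  # False torch.Size([1, 10, 8])

items = decollate_batch(batch)  # back to a list of per-sample MetaTensors
print(len(items), items[0].is_batch)  # 3 False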