Example #1
# Imports as used by the PyTorch (<1.9) default_collate implementation
# these snippets are adapted from.
import torch
from torch._six import container_abcs, int_classes, string_classes
from torch.utils.data._utils.collate import (
    default_collate_err_msg_format,
    np_str_obj_array_pattern,
)


def named_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.
    If the tensors are named, a 'batch' name is added to the new outer dimension."""
    elem = batch[0]
    elem_type = type(elem)

    if isinstance(elem, torch.Tensor):
        out = None
        # TODO: Named tensor: once stack supports named tensors, drop the rename.
        names = elem.names
        elem = elem.rename(None)
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.rename(None).numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        out_batch = torch.stack([t.rename(None) for t in batch], 0, out=out)
        if any(names):
            out_batch = out_batch.refine_names('batch', *names)
        return out_batch

    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(elem.dtype))

            return named_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: named_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(named_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [named_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
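A minimal usage sketch for named_collate, assuming a toy dataset that yields named tensors (the NamedDataset class and the ('C', 'H') names are illustrative, not part of the original snippet):
import torch
from torch.utils.data import DataLoader, Dataset

class NamedDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # each sample is a 2-D tensor with dimension names ('C', 'H')
        return torch.zeros(3, 4, names=('C', 'H'))

loader = DataLoader(NamedDataset(), batch_size=4, collate_fn=named_collate)
batch = next(iter(loader))
print(batch.shape, batch.names)  # torch.Size([4, 3, 4]) ('batch', 'C', 'H')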
Example #2
    def __call__(self, filename):
        """
        Args:
            filename (str, list, tuple, file): path file or file-like object or a list of files.
        """
        filename = ensure_tuple(filename)
        img_array = list()
        compatible_meta = dict()
        for name in filename:
            img = nib.load(name)
            img = correct_nifti_header_if_necessary(img)
            header = dict(img.header)
            header["filename_or_obj"] = name
            header["affine"] = img.affine
            header["original_affine"] = img.affine.copy()
            header["as_closest_canonical"] = self.as_closest_canonical
            ndim = img.header["dim"][0]
            spatial_rank = min(ndim, 3)
            header["spatial_shape"] = img.header["dim"][1:spatial_rank + 1]

            if self.as_closest_canonical:
                img = nib.as_closest_canonical(img)
                header["affine"] = img.affine

            img_array.append(np.array(img.get_fdata(dtype=self.dtype)))
            img.uncache()

            if self.image_only:
                continue

            if not compatible_meta:
                for meta_key in header:
                    meta_datum = header[meta_key]
                    # pytype: disable=attribute-error
                    if (type(meta_datum).__name__ == "ndarray"
                            and np_str_obj_array_pattern.search(
                                meta_datum.dtype.str) is not None):
                        continue
                    # pytype: enable=attribute-error
                    compatible_meta[meta_key] = meta_datum
            else:
                assert np.allclose(
                    header["affine"], compatible_meta["affine"]
                ), "affine data of all images should be same."

        img_array = np.stack(img_array,
                             axis=0) if len(img_array) > 1 else img_array[0]
        if self.image_only:
            return img_array
        return img_array, compatible_meta
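A hedged usage sketch, assuming the __call__ above belongs to a NIfTI loader class in the spirit of MONAI's LoadNifti (the class name, constructor arguments, and file path below are assumptions for illustration):
import numpy as np

loader = LoadNifti(as_closest_canonical=False, image_only=False, dtype=np.float32)  # hypothetical constructor
img, meta = loader("subject01_t1.nii.gz")  # illustrative path
print(img.shape, meta["affine"], meta["spatial_shape"])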
Example #3
from functools import partial

from torch.nn.utils.rnn import pad_sequence


def collate_fn(batch, ignore_lists=True, pad_before_stack=True, cat_tensors=False):
    r"""Puts each data field into a tensor with outer dimension batch size."""
    # Adapted from the PyTorch default collate function, but by default leaves
    # lists untouched (avoiding the transposing zip(*batch) behaviour).
    f = partial(collate_fn, ignore_lists=ignore_lists,
                pad_before_stack=pad_before_stack, cat_tensors=cat_tensors)
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        if cat_tensors:
            return torch.cat(batch, 0, out=out)
        else:
            if pad_before_stack:
                return pad_sequence(batch, batch_first=True)
            else:
                return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return f([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: f([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(f(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        if ignore_lists:
            return batch
        else:
            transposed = zip(*batch)
            # recurse with the same flags so nested lists honour the options
            return [f(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
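Since the options are keyword arguments, functools.partial is the natural way to hand a configured variant to a DataLoader; a minimal sketch (`dataset` is a placeholder for any map-style dataset):
from functools import partial
from torch.utils.data import DataLoader

# pad variable-length tensors instead of requiring equal sizes, keep lists as-is
my_collate = partial(collate_fn, ignore_lists=True, pad_before_stack=True)
loader = DataLoader(dataset, batch_size=4, collate_fn=my_collate)  # `dataset` is a placeholder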
Example #4
def dict_collate_fn(batch):
    # Keep each sample's label dict intact: PyTorch's default behaviour would
    # collate (stack) the values of every key across the batch. Only one line
    # of the PyTorch 1.6.0 default collate function is modified (see the
    # Mapping branch below).

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(elem.dtype))

            return dict_collate_fn([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return batch  # the only modified line: keep per-sample dicts intact
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(dict_collate_fn(samples)
                           for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [dict_collate_fn(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
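A small illustration of the modified Mapping branch, with made-up keys; the per-sample dicts come back untouched instead of being merged key-by-key:
import torch

batch = [
    {"boxes": torch.zeros(2, 4), "class": 1},
    {"boxes": torch.zeros(5, 4), "class": 3},
]
out = dict_collate_fn(batch)
# out is the original list of two dicts, not {"boxes": ..., "class": ...}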
Example #5
def collate_any(batch, key=None):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(elem.dtype))

            return collate_any([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {
            key: collate_any([d[key] for d in batch], key=key)
            for key in elem
        }
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(collate_any(samples) for samples in zip(*batch)))
    elif key is not None:
        # not at the top level, so no transposition: return unchanged batches
        # (e.g. for keypoints or object_infos)
        return batch
    elif isinstance(elem, container_abcs.Sequence):
        assert key is None, "sequences must only be transposed at the top level"
        transposed = zip(*batch)
        return [collate_any(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
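An illustrative call showing the effect of the key argument: tensor-valued fields are batched normally, while nested lists (e.g. per-image keypoints) are reached with key set and passed through as-is:
import torch

samples = [
    {"image": torch.zeros(3, 2, 2), "keypoints": [[0, 1], [2, 3]]},
    {"image": torch.ones(3, 2, 2), "keypoints": [[4, 5]]},
]
out = collate_any(samples)
# out["image"].shape == (2, 3, 2, 2)
# out["keypoints"] == [[[0, 1], [2, 3]], [[4, 5]]]  (no transposition)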
Example #6
def default_collate1(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(elem.dtype))

            # nested collation delegates to PyTorch's stock default_collate
            # (assumed to be imported from torch.utils.data.dataloader)
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples)
                           for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        if elem_type.__name__ == 'tuple':
            # (image, target) samples: stack the images, but keep the
            # per-sample targets as a list of FloatTensors
            targets = []
            imgs = []
            for sample in batch:
                imgs.append(sample[0])
                targets.append(torch.FloatTensor(sample[1]))
            return torch.stack(imgs, 0), targets
        else:
            transposed = zip(*batch)
            return [default_collate(samples) for samples in transposed]
    raise TypeError(default_collate_err_msg_format.format(elem_type))
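A sketch of the detection-style case this variant targets, with made-up (image, boxes) samples:
import torch

batch = [
    (torch.zeros(3, 8, 8), [[0.0, 0.0, 4.0, 4.0, 1.0]]),
    (torch.ones(3, 8, 8), [[1.0, 1.0, 5.0, 5.0, 2.0], [2.0, 2.0, 6.0, 6.0, 0.0]]),
]
imgs, targets = default_collate1(batch)
# imgs.shape == (2, 3, 8, 8); targets is a list of FloatTensors of
# shapes (1, 5) and (2, 5), one per image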
Example #7
    def __getitem__(self, index):
        meta_data = None
        if self.image_only:
            img = load_nifti(
                self.image_files[index],
                as_closest_canonical=self.as_closest_canonical,
                image_only=self.image_only,
                dtype=self.dtype,
            )
        else:
            img, meta_data = load_nifti(
                self.image_files[index],
                as_closest_canonical=self.as_closest_canonical,
                image_only=self.image_only,
                dtype=self.dtype,
            )
        target = None
        if self.seg_files is not None:
            target = load_nifti(self.seg_files[index])
        elif self.labels is not None:
            target = self.labels[index]

        # one seed so that paired Randomizable transforms roll the same
        # random values for image and segmentation
        seed = np.random.randint(2147483647)

        if self.transform is not None:
            if isinstance(self.transform, Randomizable):
                self.transform.set_random_state(seed=seed)
            img = self.transform(img)

        if self.seg_transform is not None:
            if isinstance(self.seg_transform, Randomizable):
                self.seg_transform.set_random_state(seed=seed)
            target = self.seg_transform(target)

        if self.image_only or meta_data is None:
            return img, target

        compatible_meta = {}
        for meta_key in meta_data:
            meta_datum = meta_data[meta_key]
            if (
                type(meta_datum).__name__ == "ndarray"
                and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None
            ):
                continue
            compatible_meta[meta_key] = meta_datum
        return img, target, compatible_meta
Example #8
    def _collate(self, batch: Sequence[Any]) -> Any:
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:  # type: ignore
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum(x.numel() for x in batch)
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage).resize_(len(batch), *list(elem.size()))
            ndims = elem.dim()
            # even-ndim tensors are concatenated along dim 0; odd-ndim tensors
            # are stacked into a new leading batch dimension
            if (ndims > 0) and ((ndims % 2) == 0):
                return torch.cat(batch, dim=0, out=out)  # type: ignore
            return torch.stack(batch, dim=0, out=out)  # type: ignore
        elif (
            elem_type.__module__ == "numpy"
            and elem_type.__name__ != "str_"
            and elem_type.__name__ != "string_"
        ):
            elem = batch[0]
            if elem_type.__name__ == "ndarray":
                # array of string classes and object
                if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(default_collate_err_msg_format.format(elem.dtype))
                return self._collate([torch.as_tensor(b) for b in batch])
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, string_classes):
            return batch
        elif isinstance(elem, Mapping):
            return {key: self._collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
            # unpack positionally with * (** would fail on a generator and
            # namedtuple fields are positional anyway)
            return elem_type(*(self._collate(samples) for samples in zip(*batch)))
        elif is_dataclass(elem):  # dataclass
            return elem_type(
                **{
                    field.name: self._collate([getattr(d, field.name) for d in batch])
                    for field in fields(elem)
                }
            )
        elif isinstance(elem, (tuple, list)):
            transposed = zip(*batch)
            return [self._collate(samples) for samples in transposed]
        raise TypeError(default_collate_err_msg_format.format(elem_type))
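Two illustrative calls, assuming `collator` is an instance of the (unshown) class that defines _collate; they show the even/odd-ndim rule and the dataclass branch (Sample is a made-up dataclass):
import torch
from dataclasses import dataclass
from torch import Tensor

out = collator._collate([torch.zeros(2, 3), torch.ones(2, 3)])  # 2-D (even): cat -> shape (4, 3)
out = collator._collate([torch.zeros(5), torch.ones(5)])        # 1-D (odd): stack -> shape (2, 5)

@dataclass
class Sample:
    image: Tensor
    label: int

out = collator._collate([Sample(torch.zeros(4), 0), Sample(torch.ones(4), 1)])
# out.image.shape == (2, 4); out.label == tensor([0, 1])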
Example #9
def _convert_to_tensor(data: Any) -> Any:
    """
    Maps numbers and NumPy arrays to tensors; tensors are passed through unchanged.

    Args:
        data: the data to convert to tensor

    Return:
        the converted data
    """
    if isinstance(data, numbers.Number):
        return torch.tensor([data])
    # is not array of object
    elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
        return torch.from_numpy(data)
    elif isinstance(data, torch.Tensor):
        return data

    raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!")
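Expected behaviour, for reference:
import numpy as np
import torch

print(_convert_to_tensor(3.5))            # tensor([3.5000])
print(_convert_to_tensor(np.arange(4)))   # tensor([0, 1, 2, 3])
print(_convert_to_tensor(torch.ones(2)))  # the tensor itself, unchanged
# _convert_to_tensor("nope") would raise TypeError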
Example #10
    def __getitem__(self, index):
        meta_data = None
        if self.image_only:
            img = load_nifti(self.image_files[index],
                             as_closest_canonical=self.as_closest_canonical,
                             image_only=self.image_only,
                             dtype=self.dtype)
        else:
            img, meta_data = load_nifti(
                self.image_files[index],
                as_closest_canonical=self.as_closest_canonical,
                image_only=self.image_only,
                dtype=self.dtype)
        seg = load_nifti(self.seg_files[index])

        # https://github.com/pytorch/vision/issues/9#issuecomment-304224800
        seed = np.random.randint(2147483647)

        random_sync_test = None
        if self.transform is not None:
            np.random.seed(seed)
            img = self.transform(img)
            random_sync_test = np.random.randint(2147483647)

        if self.seg_transform is not None:
            # ensure randomized transforms roll the same values for
            # segmentations as for images
            np.random.seed(seed)
            seg = self.seg_transform(seg)
            seg_seed = np.random.randint(2147483647)
            # only check the sync when an image transform actually ran
            if random_sync_test is not None:
                assert random_sync_test == seg_seed

        if self.image_only or meta_data is None:
            return img, seg

        compatible_meta = {}
        for meta_key in meta_data:
            meta_datum = meta_data[meta_key]
            if type(meta_datum).__name__ == 'ndarray' \
                    and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
                continue
            compatible_meta[meta_key] = meta_datum
        return img, seg, compatible_meta
Example #11
def _copy_compatible_dict(from_dict: Dict, to_dict: Dict):
    if not isinstance(to_dict, dict):
        raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.")
    if not to_dict:
        for key in from_dict:
            datum = from_dict[key]
            if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None:
                continue
            to_dict[key] = datum
    else:
        affine_key, shape_key = "affine", "spatial_shape"
        if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]):
            raise RuntimeError(
                "affine matrix of all images should be the same for channel-wise concatenation. "
                f"Got {from_dict[affine_key]} and {to_dict[affine_key]}."
            )
        if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]):
            raise RuntimeError(
                "spatial_shape of all images should be the same for channel-wise concatenation. "
                f"Got {from_dict[shape_key]} and {to_dict[shape_key]}."
            )
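A minimal sketch of the intended call pattern (the keys follow the function's own conventions; the values are illustrative): the first call populates the running meta dict, later calls only verify geometric consistency:
import numpy as np

meta = {}
header = {"affine": np.eye(4), "spatial_shape": np.array([64, 64, 32])}
_copy_compatible_dict(header, meta)  # first call copies the compatible entries
_copy_compatible_dict(header, meta)  # passes: affine and spatial_shape match
# a header with a different affine would raise RuntimeError here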
Example #12
def adaptive_collate(batch: Any) -> Any:
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        ndims = elem.dim()
        # even-ndim tensors are concatenated along dim 0; odd-ndim tensors
        # are stacked into a new leading batch dimension
        if ndims > 0 and ndims % 2 == 0:
            return torch.cat(batch, dim=0, out=out)
        else:
            return torch.stack(batch, dim=0, out=out)
    elif (elem_type.__module__ == "numpy" and elem_type.__name__ != "str_"
          and elem_type.__name__ != "string_"):
        elem = batch[0]
        if elem_type.__name__ == "ndarray":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(elem.dtype))
            return adaptive_collate([torch.as_tensor(b) for b in batch])
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(adaptive_collate(samples)
                           for samples in zip(*batch)))
    elif isinstance(elem, (tuple, list)):
        transposed = zip(*batch)
        return [adaptive_collate(samples) for samples in transposed]
    raise TypeError(default_collate_err_msg_format.format(elem_type))
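An illustrative namedtuple batch, collated field-by-field:
import torch
from collections import namedtuple

Point = namedtuple("Point", ["xy", "label"])
batch = [Point(torch.zeros(2), 0), Point(torch.ones(2), 1)]
out = adaptive_collate(batch)
# out.xy.shape == (2, 2) (1-D tensors are stacked); out.label == tensor([0, 1])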