Example #1
def multiple_samples_collate(batch, fold=False):
    """
    Collate function for repeated augmentation. Each instance in the batch has
    more than one sample.
    Args:
        batch (tuple or list): data batch to collate.
    Returns:
        (tuple): collated data batch.
    """
    inputs, labels, video_idx, time, extra_data = zip(*batch)
    inputs = [item for sublist in inputs for item in sublist]
    labels = [item for sublist in labels for item in sublist]
    video_idx = [item for sublist in video_idx for item in sublist]
    time = [item for sublist in time for item in sublist]

    inputs, labels, video_idx, time, extra_data = (
        default_collate(inputs),
        default_collate(labels),
        default_collate(video_idx),
        default_collate(time),
        default_collate(extra_data),
    )
    if fold:
        return [inputs], labels, video_idx, time, extra_data
    else:
        return inputs, labels, video_idx, time, extra_data
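A minimal toy run (a hedged sketch; the dataset below is a made-up stand-in, not part of the original code): each item returns lists holding two repeated augmentations, so the collated batch is twice the DataLoader batch size.

import torch
from torch.utils.data import DataLoader, Dataset

class ToyRepeatedAugDataset(Dataset):
    # Hypothetical dataset: every field is a list with num_repeats=2 entries.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        clips = [torch.randn(3, 4, 4) for _ in range(2)]
        return clips, [idx, idx], [idx, idx], [0.0, 0.0], {}

loader = DataLoader(ToyRepeatedAugDataset(), batch_size=4,
                    collate_fn=multiple_samples_collate)
inputs, labels, video_idx, time, extra = next(iter(loader))
print(inputs.shape)  # torch.Size([8, 3, 4, 4]): 4 items x 2 repeats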
Example #2
def collate(batch):
    return {
        IMAGE: default_collate([d[IMAGE] for d in batch]).float(),
        KEYPOINT_MAP: default_collate([d[KEYPOINT_MAP] for d in batch]).float(),
        MASK: default_collate([d[MASK] for d in batch]).float().unsqueeze(0),
    }
Example #3
def collate_fn(batch):
    batch_clips, batch_targets, batch_keys = zip(*batch)

    # batch_keys = [key for key in batch_keys]

    return default_collate(batch_clips), default_collate(batch_targets), batch_keys
Example #4
def maskv3_collate_fn(data):
    imgs, masks, targets = zip(*data)

    imgs = default_collate(imgs)
    targets = default_collate(targets)

    return imgs, masks, targets
Example #5
def collate(batch):
    if len(batch) == 0:
        return batch
    elem = batch[0]
    if elem is None:
        return None
    elif isinstance(elem, container_abcs.Sequence):
        # We assume these are the maps, map points, headings and patch_size.
        if len(elem) == 4:
            scene_map, scene_pts, heading_angle, patch_size = zip(*batch)
            if heading_angle[0] is None:
                heading_angle = None
            else:
                heading_angle = torch.Tensor(heading_angle)
            map = scene_map[0].get_cropped_maps_from_scene_map_batch(
                scene_map,
                scene_pts=torch.Tensor(scene_pts),
                patch_size=patch_size[0],
                rotation=heading_angle)
            return map
        if isinstance(elem, (str, int)):
            return default_collate(batch)
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    elif isinstance(elem, container_abcs.Mapping):
        # We have to dill the neighbor structures; otherwise each tensor is put
        # into shared memory separately -> slow, file-pointer overhead.
        # We only do this in multiprocessing (i.e., inside a worker process).
        neighbor_dict = {key: [d[key] for d in batch] for key in elem}
        return (dill.dumps(neighbor_dict)
                if torch.utils.data.get_worker_info() else neighbor_dict)
    return default_collate(batch)
Example #6
def detection_collate(batch):
    """
    Collate function for the detection task. Concatenate bboxes, labels and
    metadata from different samples in the first dimension instead of
    stacking them to have a batch-size dimension.
    Args:
        batch (tuple or list): data batch to collate.
    Returns:
        (tuple): collated detection data batch.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    inputs, video_idx = default_collate(inputs), default_collate(video_idx)
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    for key in extra_data[0].keys():
        data = [d[key] for d in extra_data]
        if key == "boxes" or key == "ori_boxes":
            # Append idx info to the bboxes before concatenating them.
            bboxes = [
                np.concatenate(
                    [np.full((data[i].shape[0], 1), float(i)), data[i]],
                    axis=1) for i in range(len(data))
            ]
            bboxes = np.concatenate(bboxes, axis=0)
            collated_extra_data[key] = torch.tensor(bboxes).float()
        elif key == "metadata":
            collated_extra_data[key] = torch.tensor(
                list(itertools.chain(*data))).view(-1, 2)
        else:
            collated_extra_data[key] = default_collate(data)

    return inputs, labels, video_idx, collated_extra_data
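A small hedged check of the box layout (toy shapes, not from the original repo): after collation, each row of extra_data["boxes"] starts with the index of the sample it came from.

import numpy as np
import torch

batch = [
    (torch.randn(3, 8, 8), np.zeros((2, 80), dtype=np.float32), 0,
     {"boxes": np.array([[0., 0., 4., 4.], [1., 1., 5., 5.]]),
      "ori_boxes": np.array([[0., 0., 4., 4.], [1., 1., 5., 5.]]),
      "metadata": [[0, 0], [0, 1]]}),
    (torch.randn(3, 8, 8), np.zeros((1, 80), dtype=np.float32), 1,
     {"boxes": np.array([[2., 2., 6., 6.]]),
      "ori_boxes": np.array([[2., 2., 6., 6.]]),
      "metadata": [[1, 0]]}),
]
inputs, labels, video_idx, extra = detection_collate(batch)
print(extra["boxes"][:, 0])  # tensor([0., 0., 1.]): per-box sample indices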
Example #7
def collate_fn(batch):
    feats, n_feats, logits = zip(*batch)
    # the feats all have different lengths
    feats = torch.cat(feats, dim=0)
    n_feats = default_collate(n_feats)
    logits = default_collate(logits)
    return feats, n_feats, logits
Example #8
def collate(batch, *, root=True):
    "Puts each data field into a tensor with outer dimension batch size"

    if len(batch) == 0:
        return batch

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if torch.is_tensor(batch[0]):
        return default_collate(batch)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        return default_collate(batch)
    elif isinstance(batch[0], int_classes):
        return batch
    elif isinstance(batch[0], float):
        return batch
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], CameraIntrinsics):
        return batch
    elif isinstance(batch[0], Mapping):
        if root:
            return {
                key: collate([d[key] for d in batch], root=False)
                for key in batch[0]
            }
        else:
            return batch
    elif isinstance(batch[0], Sequence):
        return [collate(e, root=False) for e in batch]

    raise TypeError((error_msg.format(type(batch[0]))))
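A toy sketch of the root-only mapping behavior (a made-up batch; assumes the snippet's imports such as string_classes are in scope): top-level dict fields are collated, but containers nested under them fall into the Sequence branch with root=False.

import torch

batch = [{"img": torch.zeros(2), "meta": ["a"]},
         {"img": torch.ones(2), "meta": ["b"]}]
out = collate(batch)
print(out["img"].shape)  # torch.Size([2, 2]), stacked by default_collate
print(out["meta"])       # [['a'], ['b']]: strings are passed through untouched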
Example #9
def bbox_collate(batch):
    '''
    Pad the bounding boxes in the collate function.
    '''
    if torch.is_tensor(batch[0][0]):
        return default_collate(batch)
    if isinstance(batch[0][0], dict) and 'xs' not in batch[0][0] \
            and 'masks' not in batch[0][0]:
        return default_collate(batch)
    if 'masks' not in batch[0][0] and batch[0][0]['xs'].ndim == 0:
        return default_collate(batch)

    samples = [item[0] for item in batch]
    if 'masks' not in batch[0][0]:  # do the bbox
        # pad the bboxes into xs, ys, ws, hs
        data = {
            'xs': pad_sequence([sample['xs'] for sample in samples],
                               batch_first=True, padding_value=-1.),
            'ys': pad_sequence([sample['ys'] for sample in samples],
                               batch_first=True, padding_value=-1.),
            'ws': pad_sequence([sample['ws'] for sample in samples],
                               batch_first=True, padding_value=-1.),
            'hs': pad_sequence([sample['hs'] for sample in samples],
                               batch_first=True, padding_value=-1.),
        }
    else:  # only do the masks instead
        data = {}
        masks = [s['masks'] for s in samples]
        if masks == [None] * len(masks):
            # Hack: since we can't pass None in pl, we pass a tensor that
            #   makes __base__.py see no bbox.
            data['masks'] = torch.zeros(len(masks), 1, 1, 1)
        else:
            masks = [
                torch.zeros(1, *samples[0]['imgs'].shape[1:])
                if m is None else m for m in masks
            ]
            data['masks'] = torch.stack(masks, dim=0)

    data['imgs'] = default_collate([item['imgs'] for item in samples])
    if 'imgs_cf' in samples[0]:
        imgs_cf = [s['imgs_cf'] for s in samples]
        if imgs_cf != [None] * len(imgs_cf):
            imgs_cf = [
                torch.zeros_like(samples[0]['imgs']) if m is None else m
                for m in imgs_cf
            ]
            data['imgs_cf'] = torch.stack(imgs_cf, dim=0)

    targets = default_collate([item[1] for item in batch])
    return [data, targets]
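A quick illustration of the padding_value=-1. convention (toy coordinates): padded slots come out as -1, so downstream code can mask them without a separate length tensor.

import torch
from torch.nn.utils.rnn import pad_sequence

xs = pad_sequence([torch.tensor([0.1, 0.2]), torch.tensor([0.3])],
                  batch_first=True, padding_value=-1.)
print(xs)       # tensor([[ 0.1000,  0.2000], [ 0.3000, -1.0000]])
print(xs >= 0)  # tensor([[ True,  True], [ True, False]]): real vs. padded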
Example #10
File: dataset.py  Project: d4l3k/ourgraph
def collate_docs(docs):
    dense, doc_ids, tag_ids, tag_offsets = zip(*docs)
    return (
        default_collate(dense),
        default_collate(doc_ids),
        torch.cat(tag_ids),
        cum_offsets(tag_offsets),
    )
Example #11
    def batch_collator(batch):
        transposed_batch = list(zip(*batch))
        transposed_batch[0] = default_collate(transposed_batch[0])
        transposed_batch[1] = default_collate(transposed_batch[1])

        if len(transposed_batch) > 4:
            transposed_batch[2] = default_collate(transposed_batch[2])
            transposed_batch[3] = default_collate(transposed_batch[3])

        return transposed_batch
Example #12
def collate(batch):
    if len(batch) == 0:
        return batch

    elem = batch[0]
    if elem is None:
        return None

    elif isinstance(elem, container_abcs.Sequence):
        # We assume these are the maps, map points, headings and patch_size.
        if len(elem) == 4:
            scene_map, scene_pts, heading_angle, patch_size = zip(*batch)
            if heading_angle[0] is None:
                heading_angle = None
            else:
                heading_angle = torch.Tensor(heading_angle)
            map = scene_map[0].get_cropped_maps_from_scene_map_batch(
                scene_map,
                scene_pts=torch.Tensor(scene_pts),
                patch_size=patch_size[0],
                rotation=heading_angle)
            return map
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]

    elif isinstance(elem, container_abcs.Mapping):
        return {key: collate([d[key] for d in batch]) for key in elem}

    elif isinstance(elem, torch.Tensor):
        batch_shapes = set([el.shape for el in batch])
        if len(batch_shapes) != 1:
            new_batch = list(batch)
            # This is the case when we have different agent_count or largest_state_dim.
            max_agent_count = max(batch_shapes,
                                  key=lambda batch_shape: batch_shape[0])[0]
            max_largest_state_dim = max(
                batch_shapes, key=lambda batch_shape: batch_shape[-1])[-1]
            for idx in range(len(batch)):
                # num_nodes, time_length, largest_state_dim
                curr_shape = batch[idx].shape
                if len(curr_shape) == 3:
                    pad = (0, max_largest_state_dim - curr_shape[-1], 0, 0, 0,
                           max_agent_count - curr_shape[0])
                    const = np.nan
                elif len(curr_shape) == 2:
                    pad = (0, max_largest_state_dim - curr_shape[-1], 0, 0)
                    const = np.nan
                else:
                    pad = (0, max_largest_state_dim - curr_shape[-1])
                    const = -1
                new_batch[idx] = F.pad(batch[idx], pad, "constant", const)
            return default_collate(new_batch)

    return default_collate(batch)
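A small check of the F.pad convention used above (toy shapes): the pad tuple works backwards from the last dimension, so (0, 1, 0, 0, 0, 2) grows the state dim by 1 and the agent dim by 2.

import torch
import torch.nn.functional as F

t = torch.zeros(3, 5, 4)  # (num_nodes, time_length, state_dim)
padded = F.pad(t, (0, 1, 0, 0, 0, 2), "constant", float("nan"))
print(padded.shape)       # torch.Size([5, 5, 5])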
Example #13
def caption_collate(batch):
    if len(batch[0]) == 3:  # is_train=True
        return default_collate(batch)
    elif len(batch[0]) == 4:  # is_train=False
        collated = list()
        for item_i, item_list in enumerate(zip(*batch)):
            if item_i < 3:
                collated.append(default_collate(item_list))
            else:
                collated.append(item_list)
        return collated
    else:
        raise NotImplementedError
Example #14
File: loader.py  Project: vhvkhoa/SlowFast
def feature_extract_bbox(batch):
    inputs, bboxes, segment_indices, indices = zip(*batch)
    bboxes = [
        np.concatenate([np.full((bboxes[i].shape[0], 1), float(i)), bboxes[i]],
                       axis=1) for i in range(len(bboxes))
        if len(bboxes[i]) > 0
    ]

    if len(bboxes) > 0:
        bboxes = torch.tensor(np.concatenate(bboxes, axis=0)).float()
    inputs = default_collate(inputs)
    segment_indices = default_collate(segment_indices)
    indices = default_collate(indices)

    return inputs, bboxes, segment_indices, indices
Example #15
def multi_supervision_collate_fn(batch: Sequence[Dict]) -> Dict:
    """
    Custom collate_fn for K2SpeechRecognitionDataset.

    It merges the items provided by K2SpeechRecognitionDataset into the following structure:

    .. code-block::

        {
            'features': float tensor of shape (B, T, F)
            'supervisions': [
                {
                    'sequence_idx': Tensor[int] of shape (S,)
                    'text': List[str] of len S
                    'start_frame': Tensor[int] of shape (S,)
                    'num_frames': Tensor[int] of shape (S,)
                }
            ]
        }

    Dimension symbols legend:
    * ``B`` - batch size (number of Cuts),
    * ``S`` - number of supervision segments (greater than or equal to B, as each Cut may have multiple supervisions),
    * ``T`` - number of frames of the longest Cut
    * ``F`` - number of features
    """
    from torch.utils.data._utils.collate import default_collate

    dataset_idx_to_batch_idx = {
        example['supervisions'][0]['sequence_idx']: batch_idx
        for batch_idx, example in enumerate(batch)
    }

    def update(d: Dict, **kwargs) -> Dict:
        for key, value in kwargs.items():
            d[key] = value
        return d

    supervisions = default_collate([
        update(sup, sequence_idx=dataset_idx_to_batch_idx[sup['sequence_idx']])
        for example in batch
        for sup in example['supervisions']
    ])
    feats = default_collate([example['features'] for example in batch])
    return {
        'features': feats,
        'supervisions': supervisions
    }
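A toy invocation (assumed feature shapes; real items come from K2SpeechRecognitionDataset): two cuts, the second carrying two supervision segments, so S = 3 while B = 2.

import torch

batch = [
    {'features': torch.randn(100, 40),
     'supervisions': [{'sequence_idx': 0, 'text': 'hello',
                       'start_frame': 0, 'num_frames': 100}]},
    {'features': torch.randn(100, 40),
     'supervisions': [{'sequence_idx': 1, 'text': 'foo',
                       'start_frame': 0, 'num_frames': 50},
                      {'sequence_idx': 1, 'text': 'bar',
                       'start_frame': 50, 'num_frames': 50}]},
]
out = multi_supervision_collate_fn(batch)
print(out['features'].shape)                # torch.Size([2, 100, 40])
print(out['supervisions']['sequence_idx'])  # tensor([0, 1, 1])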
Example #16
def collate_fn_dailydialog(batch, include_lens=False):
    result = default_collate(batch)
    if include_lens:
        return (result[0],
                result[1]), result[2], result[3], result[4], batch[-1][-1]
    else:
        return result[0], result[2], result[3], result[4], batch[-1][-1]
Example #17
def mixup_collate(data, alpha=0.1):
    """Implements a batch collate function with MixUp strategy from
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/pdf/1710.09412.pdf>`_

    Args:
        data (list): list of elements
        alpha (float, optional): mixup factor

    Example::
        >>> import torch
        >>> from holocron import utils
        >>> loader = torch.utils.data.DataLoader(dataset, batch_size, collate_fn=utils.data.mixup_collate)
    """

    inputs, targets = default_collate(data)

    # Sample lambda
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    # Mix batch indices
    batch_size = inputs.size()[0]
    index = torch.randperm(batch_size)

    # Create the new input and targets
    inputs = lam * inputs + (1 - lam) * inputs[index, :]
    targets_a, targets_b = targets, targets[index]

    return inputs, targets_a, targets_b, lam
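Downstream, the mixed batch is typically consumed by interpolating the loss with the same lam (a sketch; model and criterion are placeholders, not part of the original):

def mixup_criterion(criterion, outputs, targets_a, targets_b, lam):
    # Apply the same convex combination to the loss as was applied to the inputs.
    return lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)

# for inputs, targets_a, targets_b, lam in loader:
#     loss = mixup_criterion(criterion, model(inputs), targets_a, targets_b, lam)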
Example #18
def run(dataset_type,
        output_dir,
        use_intensity,
        use_squared_falloff,
        dc_count,
        _config):  # The entire config dict for this experiment
    print("dataset_type: {}".format(dataset_type))
    dataset = load_data(dataset_type)
    all_spad_counts = []
    all_intensities = []
    for i in range(len(dataset)):
        print("Simulating SPAD for entry {}".format(i))
        data = default_collate([dataset[i]])
        intensity = rgb2gray(data["rgb_cropped"].numpy()/255.)
        spad_counts = simulate_spad(depth_truth=data["depth_cropped"].numpy(),
                                    intensity=intensity,
                                    mask=np.ones_like(intensity))
        all_spad_counts.append(spad_counts)
        all_intensities.append(intensity)

    output = {
        "config": _config,
        "spad": np.array(all_spad_counts),
        "intensity": np.concatenate(all_intensities),
    }

    print("saving {}_int_{}_fall_{}_dc_{}_spad.npy to {}".format(dataset_type,
                                                                 use_intensity, use_squared_falloff, dc_count,
                                                                 output_dir))
    np.save(os.path.join(output_dir, "{}_int_{}_fall_{}_dc_{}_spad.npy".format(dataset_type,
                                                                               use_intensity,
                                                                               use_squared_falloff,
                                                                               dc_count)), output)
Example #19
File: dataset.py  Project: d4l3k/ourgraph
def graph_collate(batch):
    a, b, liked = zip(*batch)
    return (
        collate_docs(a),
        collate_docs(b),
        default_collate(liked),
    )
Example #20
    def collate(batch):
        cmap_offset = (ColoredMNIST.cmap -
                       1) * ColoredMNIST.mean / ColoredMNIST.std

        cs = np.random.choice(len(ColoredMNIST.cmap), 2, replace=False)
        c1 = ColoredMNIST.cmap[cs[0]]
        c2 = ColoredMNIST.cmap[cs[1]]
        o1 = cmap_offset[cs[0]]
        o2 = cmap_offset[cs[1]]
        pair_list = []
        for batch_idx in range(len(batch)):
            data = batch[batch_idx][0]
            target = batch[batch_idx][1]
            data1 = data[0]
            data2 = data[1]
            target1 = target[0]
            target2 = target[1]
            data1 = torch.cat([data1 * c1[0] + o1[0],
                               data1 * c1[1] + o1[1],
                               data1 * c1[2] + o1[2]], dim=0)
            data2 = torch.cat([data2 * c2[0] + o2[0],
                               data2 * c2[1] + o2[1],
                               data2 * c2[2] + o2[2]], dim=0)
            target1 = (torch.tensor(cs[0]), target1)
            target2 = (torch.tensor(cs[1]), target2)
            pair = ((data1, data2), (target1, target2))
            pair_list.append(pair)
        batch = default_collate(pair_list)
        return batch
Example #21
    def collate_fn(self, batch: List[Feature]):
        word_seq_len = [len(feature.orig_to_tok_index) for feature in batch]
        max_seq_len = max(word_seq_len)
        max_wordpiece_length = max(
            [len(feature.input_ids) for feature in batch])
        for i, feature in enumerate(batch):
            padding_length = max_wordpiece_length - len(feature.input_ids)
            input_ids = feature.input_ids + [self.tokenizer.pad_token_id] * padding_length
            mask = feature.attention_mask + [0] * padding_length
            type_ids = (feature.token_type_ids +
                        [self.tokenizer.pad_token_type_id] * padding_length)
            padding_word_len = max_seq_len - len(feature.orig_to_tok_index)
            orig_to_tok_index = feature.orig_to_tok_index + [0] * padding_word_len
            label_ids = feature.label_ids + [0] * padding_word_len

            batch[i] = Feature(input_ids=np.asarray(input_ids),
                               attention_mask=np.asarray(mask),
                               token_type_ids=np.asarray(type_ids),
                               orig_to_tok_index=np.asarray(orig_to_tok_index),
                               word_seq_len=feature.word_seq_len,
                               label_ids=np.asarray(label_ids))
        results = Feature(*(default_collate(samples)
                            for samples in zip(*batch)))
        return results
Example #22
def collate_flatten(batch, x_dim_flatten=5, y_dim_flatten=2):
    x, y = default_collate(batch)
    if len(x.shape) > x_dim_flatten:
        x = x.flatten(start_dim=0, end_dim=len(x.shape) - x_dim_flatten)
    if len(y.shape) > y_dim_flatten:
        y = y.flatten(start_dim=0, end_dim=len(y.shape) - y_dim_flatten)
    return [x, y]
Example #23
def my_collate(batch):
    modified_batch = []
    for item in batch:
        image, label = item
        if label == 1 or label == 2:
            modified_batch.append(item)
    return default_collate(modified_batch)
Example #24
    def collate(cls, samples):
        """Collates a sequence of samples containing graphs into a batch

        The samples in the sequence can contain multiple types of inputs, such as:

        >>> [
        >>>   (input_graph, tensor, other_tensor, output_graph),
        >>>   (input_graph, tensor, other_tensor, output_graph),
        >>>   ...
        >>> ]

        """
        if isinstance(samples[0], Graph):
            return cls.from_graphs(samples)
        elif isinstance(samples[0], (str, bytes)):
            return samples
        elif isinstance(samples[0], collections.abc.Mapping):
            return {
                key: cls.collate([d[key] for d in samples])
                for key in samples[0]
            }
        elif isinstance(samples[0], collections.abc.Sequence):
            transposed = zip(*samples)
            return [cls.collate(samples) for samples in transposed]
        else:
            return default_collate(samples)
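Because the Sequence branch transposes the batch and recurses, a batch of (graph, tensor, ...) tuples is collated field-wise. A toy sketch of just the non-graph fields (Graph and cls.from_graphs are project-specific and omitted here):

import torch
from torch.utils.data._utils.collate import default_collate

samples = [(torch.tensor([1.0]), "id-0"),
           (torch.tensor([2.0]), "id-1")]
transposed = list(zip(*samples))
print(default_collate(list(transposed[0])))  # tensor([[1.], [2.]])
print(list(transposed[1]))                   # ['id-0', 'id-1'], kept as strings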
Example #25
    def collate(self, samples: Any) -> Any:
        if not isinstance(samples, Tensor):
            elem = samples[0]
            if isinstance(elem, container_abcs.Sequence):
                return tuple(zip(*samples))
            return default_collate(samples)
        return samples.unsqueeze(dim=0)
Example #26
def list_data_collate(batch: Sequence):
    """
    Enhancement for PyTorch DataLoader default collate.
    If the dataset already returns a list of batch data generated by its transforms, all of it needs to be merged into one list.
    After that, it's the same as the default collate behavior.

    Note:
        Use this collate if you apply transforms that can generate batch data.

    """
    elem = batch[0]
    data = [i for k in batch for i in k] if isinstance(elem, list) else batch
    try:
        return default_collate(data)
    except RuntimeError as re:
        re_str = str(re)
        if "equal size" in re_str:
            re_str += (
                "\n\nMONAI hint: if your transforms intentionally create images of different shapes, "
                "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this "
                "problem (check its documentation).")
        raise RuntimeError(re_str)
    except TypeError as re:
        re_str = str(re)
        if "numpy" in re_str and "Tensor" in re_str:
            re_str += (
                "\n\nMONAI hint: if your transforms intentionally create mixtures of torch Tensor and "
                "numpy ndarray, creating your `DataLoader` with `collate_fn=pad_list_data_collate` "
                "might solve this problem (check its documentation).")
        raise TypeError(re_str)
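A toy check (shapes assumed): when each dataset item yields a list of samples, list_data_collate flattens the lists before collating, so the batch dimension grows accordingly.

import torch

batch = [
    [torch.randn(1, 8, 8), torch.randn(1, 8, 8)],  # item 0 produced 2 patches
    [torch.randn(1, 8, 8), torch.randn(1, 8, 8)],  # item 1 produced 2 patches
]
out = list_data_collate(batch)
print(out.shape)  # torch.Size([4, 1, 8, 8])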
Example #27
    def collate_fn(self, batch: List[Feature]):
        word_seq_lens = [len(feature.words) for feature in batch]
        max_seq_len = max(word_seq_lens)
        max_char_seq_len = -1
        for feature in batch:
            curr_max_char_seq_len = max(feature.char_seq_lens)
            max_char_seq_len = max(curr_max_char_seq_len, max_char_seq_len)
        for i, feature in enumerate(batch):
            padding_length = max_seq_len - len(feature.words)
            words = feature.words + [0] * padding_length
            chars = []
            char_seq_lens = feature.char_seq_lens + [1] * padding_length
            for word_idx in range(feature.word_seq_len):
                pad_char_length = max_char_seq_len - feature.char_seq_lens[word_idx]
                word_chars = feature.chars[word_idx] + [0] * pad_char_length
                chars.append(word_chars)
            for _ in range(max_seq_len - feature.word_seq_len):
                chars.append([0] * max_char_seq_len)
            labels = (feature.labels + [0] * padding_length
                      if feature.labels is not None else None)

            batch[i] = Feature(
                words=np.asarray(words),
                chars=np.asarray(chars),
                char_seq_lens=np.asarray(char_seq_lens),
                context_emb=feature.context_emb,
                word_seq_len=feature.word_seq_len,
                labels=np.asarray(labels) if labels is not None else None)
        results = Feature(*(default_collate(samples)
                            if not check_all_obj_is_None(samples) else None
                            for samples in zip(*batch)))
        return results
Example #28
    def impute(self, x, ids, dropout_threshold=None):

        assert len(x) == len(ids)
        assert len(x) == self.n_channels

        union = sorted(set(reduce(add, ids)))
        if len(union) != 0:
            x_hat = []
            for s in union:
                # print(s, union.index(s))
                # available_channels_for_s = [i if s in _ and i in self.enc_channels else None for i, _ in enumerate(ids)]
                available_channels_for_s = [
                    i if s in _ else None for i, _ in enumerate(ids)
                ]
                xs = [
                    x[ch][ids[ch].index(s)].unsqueeze(0)
                    if ch is not None else None
                    for ch in available_channels_for_s
                ]
                qs = [
                    self.vae[ch].encode(_) if _ is not None else None
                    for ch, _ in enumerate(xs)
                ]
                zs_ = [_.loc if _ is not None else None for _ in qs]
                if dropout_threshold is not None:
                    zs = [
                        self.apply_threshold(_, dropout_threshold)
                        if _ is not None else None for _ in zs_
                    ]
                else:
                    zs = zs_
                xs_hat = []
                for i in range(self.n_channels):
                    if i in self.dec_channels:
                        ps_hat_i = [
                            self.vae[i].decode(z) for z in zs if z is not None
                        ]
                        ev_hat_i = [
                            self.vae[i].p_to_expected_value(_)
                            for _ in ps_hat_i
                        ]
                        xs_hat_i = torch.stack(ev_hat_i, 0).mean(0)
                        if isinstance(ps_hat_i[0],
                                      torch.distributions.Categorical):
                            xs_hat_i = xs_hat_i.argmax(1, keepdim=True)
                        xs_hat.append(xs_hat_i)
                        del ps_hat_i, ev_hat_i, xs_hat_i
                    else:
                        xs_hat.append([])
                x_hat.append(xs_hat)
                del xs_hat

            ret = [
                _.squeeze() if not isinstance(_, list) else []
                for _ in default_collate(x_hat)
            ]
        else:
            ret = [[] for _ in x]

        return ret
Example #29
    def collate_none_when_even(self, batch):
        if self.counter % 2 == 0:
            result = None
        else:
            result = default_collate(batch)
        self.counter += 1
        return result
Example #30
File: __init__.py  Project: ChoiDM/BraTs
def test_collate(batch):
    imgs = torch.stack([torch.Tensor(item[0]) for item in batch], 0)
    masks_cropped = torch.stack([torch.Tensor(item[1]) for item in batch], 0)
    masks_org = [torch.Tensor(item[2]) for item in batch]
    meta = default_collate([item[3] for item in batch])

    return imgs, masks_cropped, masks_org, meta