def multiple_samples_collate(batch, fold=False):
    """
    Collate function for repeated augmentation. Each instance in the batch has
    more than one sample.
    Args:
        batch (tuple or list): data batch to collate.
    Returns:
        (tuple): collated data batch.
    """
    inputs, labels, video_idx, time, extra_data = zip(*batch)
    inputs = [item for sublist in inputs for item in sublist]
    labels = [item for sublist in labels for item in sublist]
    video_idx = [item for sublist in video_idx for item in sublist]
    time = [item for sublist in time for item in sublist]

    inputs, labels, video_idx, time, extra_data = (
        default_collate(inputs),
        default_collate(labels),
        default_collate(video_idx),
        default_collate(time),
        default_collate(extra_data),
    )
    if fold:
        return [inputs], labels, video_idx, time, extra_data
    else:
        return inputs, labels, video_idx, time, extra_data
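
# Hedged usage sketch (not from the original source): multiple_samples_collate is
# typically bound to a DataLoader via functools.partial so the `fold` flag can be set.
# `repeated_aug_dataset` is a hypothetical dataset whose __getitem__ returns lists of
# samples per instance (one entry per repeated augmentation).
from functools import partial

from torch.utils.data import DataLoader

loader = DataLoader(
    repeated_aug_dataset,  # hypothetical repeated-augmentation dataset
    batch_size=8,
    collate_fn=partial(multiple_samples_collate, fold=False),
)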
def collate(batch):
    return {
        IMAGE: default_collate([d[IMAGE] for d in batch]).float(),
        KEYPOINT_MAP: default_collate([d[KEYPOINT_MAP] for d in batch]).float(),
        MASK: default_collate([d[MASK] for d in batch]).float().unsqueeze(0)
    }
def collate_fn(batch):
    batch_clips, batch_targets, batch_keys = zip(*batch)
    # batch_keys = [key for key in batch_keys]
    return default_collate(batch_clips), default_collate(batch_targets), batch_keys
def maskv3_collate_fn(data):
    imgs, masks, targets = zip(*data)
    imgs = default_collate(imgs)
    targets = default_collate(targets)
    return imgs, masks, targets
def collate(batch):
    if len(batch) == 0:
        return batch
    elem = batch[0]
    if elem is None:
        return None
    elif isinstance(elem, container_abcs.Sequence):
        if len(elem) == 4:  # We assume those are the maps, map points, headings and patch_size
            scene_map, scene_pts, heading_angle, patch_size = zip(*batch)
            if heading_angle[0] is None:
                heading_angle = None
            else:
                heading_angle = torch.Tensor(heading_angle)
            map = scene_map[0].get_cropped_maps_from_scene_map_batch(
                scene_map,
                scene_pts=torch.Tensor(scene_pts),
                patch_size=patch_size[0],
                rotation=heading_angle)
            return map
        if isinstance(elem, (str, int)):
            return default_collate(batch)
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    elif isinstance(elem, container_abcs.Mapping):
        # We have to dill the neighbors structures. Otherwise each tensor is put into
        # shared memory separately -> slow, file pointer overhead.
        # We only do this in multiprocessing.
        neighbor_dict = {key: [d[key] for d in batch] for key in elem}
        return dill.dumps(neighbor_dict) if torch.utils.data.get_worker_info() else neighbor_dict
    return default_collate(batch)
def detection_collate(batch):
    """
    Collate function for the detection task. Concatenate bboxes, labels and
    metadata from different samples along the first dimension instead of
    stacking them into a batch-size dimension.
    Args:
        batch (tuple or list): data batch to collate.
    Returns:
        (tuple): collated detection data batch.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    inputs, video_idx = default_collate(inputs), default_collate(video_idx)
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    for key in extra_data[0].keys():
        data = [d[key] for d in extra_data]
        if key == "boxes" or key == "ori_boxes":
            # Append idx info to the bboxes before concatenating them.
            bboxes = [
                np.concatenate(
                    [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1)
                for i in range(len(data))
            ]
            bboxes = np.concatenate(bboxes, axis=0)
            collated_extra_data[key] = torch.tensor(bboxes).float()
        elif key == "metadata":
            collated_extra_data[key] = torch.tensor(
                list(itertools.chain(*data))).view(-1, 2)
        else:
            collated_extra_data[key] = default_collate(data)

    return inputs, labels, video_idx, collated_extra_data
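
# Hedged usage sketch (an assumption, not taken from the source repo): detection_collate
# is wired into a DataLoader like any custom collate_fn. The collated "boxes" tensor
# carries the within-batch index in its first column, which lets a downstream
# RoIAlign-style op match boxes to their clips. `detection_dataset` is hypothetical.
from torch.utils.data import DataLoader

det_loader = DataLoader(
    detection_dataset,  # hypothetical dataset yielding (inputs, labels, video_idx, extra_data)
    batch_size=4,
    collate_fn=detection_collate,
)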
def collate_fn(batch):
    feats, n_feats, logits = zip(*batch)
    # The feats all have different lengths, so concatenate them along dim 0
    # instead of stacking.
    feats = torch.cat(feats, dim=0)
    n_feats = default_collate(n_feats)
    logits = default_collate(logits)
    return feats, n_feats, logits
def collate(batch, *, root=True):
    "Puts each data field into a tensor with outer dimension batch size"
    if len(batch) == 0:
        return batch

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if torch.is_tensor(batch[0]):
        return default_collate(batch)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        return default_collate(batch)
    elif isinstance(batch[0], int_classes):
        return batch
    elif isinstance(batch[0], float):
        return batch
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], CameraIntrinsics):
        return batch
    elif isinstance(batch[0], Mapping):
        if root:
            return {
                key: collate([d[key] for d in batch], root=False)
                for key in batch[0]
            }
        else:
            return batch
    elif isinstance(batch[0], Sequence):
        return [collate(e, root=False) for e in batch]

    raise TypeError((error_msg.format(type(batch[0]))))
def bbox_collate(batch):
    '''
    Pad the bounding boxes in the collate function.
    '''
    if torch.is_tensor(batch[0][0]):
        return default_collate(batch)
    if isinstance(batch[0][0], dict) and 'xs' not in batch[0][0] \
            and 'masks' not in batch[0][0]:
        return default_collate(batch)
    if 'masks' not in batch[0][0] and batch[0][0]['xs'].ndim == 0:
        return default_collate(batch)

    samples = [item[0] for item in batch]
    if 'masks' not in batch[0][0]:
        # Handle the bboxes: pad them into xs, ys, ws, hs.
        data = {
            'xs': pad_sequence([sample['xs'] for sample in samples],
                               batch_first=True, padding_value=-1.),
            'ys': pad_sequence([sample['ys'] for sample in samples],
                               batch_first=True, padding_value=-1.),
            'ws': pad_sequence([sample['ws'] for sample in samples],
                               batch_first=True, padding_value=-1.),
            'hs': pad_sequence([sample['hs'] for sample in samples],
                               batch_first=True, padding_value=-1.),
        }
    else:
        # Only handle the masks instead.
        data = {}
        masks = [s['masks'] for s in samples]
        if masks == [None] * len(masks):
            # Hack: since we can't pass None in pl, we pass a tensor that
            # makes __base__.py see no bbox.
            data['masks'] = torch.zeros(len(masks), 1, 1, 1)
        else:
            masks = [
                torch.zeros(1, *samples[0]['imgs'].shape[1:]) if m is None else m
                for m in masks
            ]
            data['masks'] = torch.stack(masks, dim=0)

    data['imgs'] = default_collate([item['imgs'] for item in samples])
    if 'imgs_cf' in samples[0]:
        imgs_cf = [s['imgs_cf'] for s in samples]
        if imgs_cf != [None] * len(imgs_cf):
            imgs_cf = [
                torch.zeros_like(samples[0]['imgs']) if m is None else m
                for m in imgs_cf
            ]
            data['imgs_cf'] = torch.stack(imgs_cf, dim=0)

    targets = default_collate([item[1] for item in batch])
    return [data, targets]
def collate_docs(docs):
    dense, doc_ids, tag_ids, tag_offsets = zip(*docs)
    return (
        default_collate(dense),
        default_collate(doc_ids),
        torch.cat(tag_ids),
        cum_offsets(tag_offsets),
    )
def batch_collator(batch):
    transposed_batch = list(zip(*batch))
    transposed_batch[0] = default_collate(transposed_batch[0])
    transposed_batch[1] = default_collate(transposed_batch[1])
    if len(transposed_batch) > 4:
        transposed_batch[2] = default_collate(transposed_batch[2])
        transposed_batch[3] = default_collate(transposed_batch[3])
    return transposed_batch
def collate(batch):
    if len(batch) == 0:
        return batch
    elem = batch[0]
    if elem is None:
        return None
    elif isinstance(elem, container_abcs.Sequence):
        if len(elem) == 4:  # We assume those are the maps, map points, headings and patch_size
            scene_map, scene_pts, heading_angle, patch_size = zip(*batch)
            if heading_angle[0] is None:
                heading_angle = None
            else:
                heading_angle = torch.Tensor(heading_angle)
            map = scene_map[0].get_cropped_maps_from_scene_map_batch(
                scene_map,
                scene_pts=torch.Tensor(scene_pts),
                patch_size=patch_size[0],
                rotation=heading_angle)
            return map
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    elif isinstance(elem, container_abcs.Mapping):
        return {key: collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, torch.Tensor):
        batch_shapes = set([el.shape for el in batch])
        if len(batch_shapes) != 1:
            new_batch = list(batch)
            # This is the case when we have different agent_count or largest_state_dim.
            max_agent_count = max(batch_shapes,
                                  key=lambda batch_shape: batch_shape[0])[0]
            max_largest_state_dim = max(batch_shapes,
                                        key=lambda batch_shape: batch_shape[-1])[-1]
            for idx in range(len(batch)):
                # num_nodes, time_length, largest_state_dim
                curr_shape = batch[idx].shape
                if len(curr_shape) == 3:
                    pad = (0, max_largest_state_dim - curr_shape[-1],
                           0, 0,
                           0, max_agent_count - curr_shape[0])
                    const = np.nan
                elif len(curr_shape) == 2:
                    pad = (0, max_largest_state_dim - curr_shape[-1], 0, 0)
                    const = np.nan
                else:
                    pad = (0, max_largest_state_dim - curr_shape[-1])
                    const = -1
                new_batch[idx] = F.pad(batch[idx], pad, "constant", const)
            return default_collate(new_batch)
    return default_collate(batch)
def caption_collate(batch):
    if len(batch[0]) == 3:  # is_train=True
        return default_collate(batch)
    elif len(batch[0]) == 4:  # is_train=False
        collated = list()
        for item_i, item_list in enumerate(zip(*batch)):
            if item_i < 3:
                collated.append(default_collate(item_list))
            else:
                collated.append(item_list)
        return collated
    else:
        raise NotImplementedError
def feature_extract_bbox(batch):
    inputs, bboxes, segment_indices, indices = zip(*batch)
    bboxes = [
        np.concatenate([np.full((bboxes[i].shape[0], 1), float(i)), bboxes[i]],
                       axis=1)
        for i in range(len(bboxes)) if len(bboxes[i]) > 0
    ]
    if len(bboxes) > 0:
        bboxes = torch.tensor(np.concatenate(bboxes, axis=0)).float()
    inputs = default_collate(inputs)
    segment_indices = default_collate(segment_indices)
    indices = default_collate(indices)
    return inputs, bboxes, segment_indices, indices
def multi_supervision_collate_fn(batch: Sequence[Dict]) -> Dict:
    """
    Custom collate_fn for K2SpeechRecognitionDataset.

    It merges the items provided by K2SpeechRecognitionDataset into the following structure:

    .. code-block::

        {
            'features': float tensor of shape (B, T, F)
            'supervisions': [
                {
                    'sequence_idx': Tensor[int] of shape (S,)
                    'text': List[str] of len S
                    'start_frame': Tensor[int] of shape (S,)
                    'num_frames': Tensor[int] of shape (S,)
                }
            ]
        }

    Dimension symbols legend:
    * ``B`` - batch size (number of Cuts),
    * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions),
    * ``T`` - number of frames of the longest Cut,
    * ``F`` - number of features.
    """
    from torch.utils.data._utils.collate import default_collate

    dataset_idx_to_batch_idx = {
        example['supervisions'][0]['sequence_idx']: batch_idx
        for batch_idx, example in enumerate(batch)
    }

    def update(d: Dict, **kwargs) -> Dict:
        for key, value in kwargs.items():
            d[key] = value
        return d

    supervisions = default_collate([
        update(sup, sequence_idx=dataset_idx_to_batch_idx[sup['sequence_idx']])
        for example in batch
        for sup in example['supervisions']
    ])
    feats = default_collate([example['features'] for example in batch])
    return {
        'features': feats,
        'supervisions': supervisions
    }
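
# Hedged usage sketch: in a Lhotse-style pipeline this collate_fn is handed to a plain
# PyTorch DataLoader wrapping a K2SpeechRecognitionDataset. The dataset instance and
# batch size below are assumptions for illustration only.
from torch.utils.data import DataLoader

asr_loader = DataLoader(
    k2_dataset,  # hypothetical K2SpeechRecognitionDataset instance
    batch_size=16,
    collate_fn=multi_supervision_collate_fn,
)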
def collate_fn_dailydialog(batch, include_lens=False):
    result = default_collate(batch)
    if include_lens:
        return (result[0], result[1]), result[2], result[3], result[4], batch[-1][-1]
    else:
        return result[0], result[2], result[3], result[4], batch[-1][-1]
def mixup_collate(data, alpha=0.1):
    """Implements a batch collate function with MixUp strategy from
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/pdf/1710.09412.pdf>`_

    Args:
        data (list): list of elements
        alpha (float, optional): mixup factor

    Example::

        >>> import torch
        >>> from holocron import utils
        >>> loader = torch.utils.data.DataLoader(dataset, batch_size, collate_fn=utils.data.mixup_collate)
    """
    inputs, targets = default_collate(data)

    # Sample lambda
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    # Mix batch indices
    batch_size = inputs.size()[0]
    index = torch.randperm(batch_size)

    # Create the new input and targets
    inputs = lam * inputs + (1 - lam) * inputs[index, :]
    targets_a, targets_b = targets, targets[index]

    return inputs, targets_a, targets_b, lam
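
# Hedged follow-up sketch (not from the holocron source): how the four values returned
# by mixup_collate are typically consumed in a training step. The loss for the two
# target sets is mixed with the same `lam` used for the inputs; `loader`, `model`,
# `criterion` and `optimizer` are assumed to exist and are named for illustration only.
for inputs, targets_a, targets_b, lam in loader:
    outputs = model(inputs)
    # Weight the two losses by the mixup coefficient.
    loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()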
def run(dataset_type, output_dir, use_intensity, use_squared_falloff, dc_count,
        _config):  # _config: the entire config dict for this experiment
    print("dataset_type: {}".format(dataset_type))
    dataset = load_data(dataset_type)
    all_spad_counts = []
    all_intensities = []
    for i in range(len(dataset)):
        print("Simulating SPAD for entry {}".format(i))
        data = default_collate([dataset[i]])
        intensity = rgb2gray(data["rgb_cropped"].numpy() / 255.)
        spad_counts = simulate_spad(depth_truth=data["depth_cropped"].numpy(),
                                    intensity=intensity,
                                    mask=np.ones_like(intensity))
        all_spad_counts.append(spad_counts)
        all_intensities.append(intensity)
    output = {
        "config": _config,
        "spad": np.array(all_spad_counts),
        "intensity": np.concatenate(all_intensities),
    }
    print("saving {}_int_{}_fall_{}_dc_{}_spad.npy to {}".format(
        dataset_type, use_intensity, use_squared_falloff, dc_count, output_dir))
    np.save(
        os.path.join(
            output_dir,
            "{}_int_{}_fall_{}_dc_{}_spad.npy".format(
                dataset_type, use_intensity, use_squared_falloff, dc_count)),
        output)
def graph_collate(batch):
    a, b, liked = zip(*batch)
    return (
        collate_docs(a),
        collate_docs(b),
        default_collate(liked),
    )
def collate(batch):
    cmap_offset = (ColoredMNIST.cmap - 1) * ColoredMNIST.mean / ColoredMNIST.std
    cs = np.random.choice(len(ColoredMNIST.cmap), 2, replace=False)
    c1 = ColoredMNIST.cmap[cs[0]]
    c2 = ColoredMNIST.cmap[cs[1]]
    o1 = cmap_offset[cs[0]]
    o2 = cmap_offset[cs[1]]

    pair_list = []
    for batch_idx in range(len(batch)):
        data = batch[batch_idx][0]
        target = batch[batch_idx][1]
        data1 = data[0]
        data2 = data[1]
        target1 = target[0]
        target2 = target[1]
        data1 = torch.cat([
            data1 * c1[0] + o1[0],
            data1 * c1[1] + o1[1],
            data1 * c1[2] + o1[2]
        ], dim=0)
        data2 = torch.cat([
            data2 * c2[0] + o2[0],
            data2 * c2[1] + o2[1],
            data2 * c2[2] + o2[2]
        ], dim=0)
        target1 = (torch.tensor(cs[0]), target1)
        target2 = (torch.tensor(cs[1]), target2)
        pair = ((data1, data2), (target1, target2))
        pair_list.append(pair)

    batch = default_collate(pair_list)
    return batch
def collate_fn(self, batch: List[Feature]):
    word_seq_len = [len(feature.orig_to_tok_index) for feature in batch]
    max_seq_len = max(word_seq_len)
    max_wordpiece_length = max([len(feature.input_ids) for feature in batch])
    for i, feature in enumerate(batch):
        padding_length = max_wordpiece_length - len(feature.input_ids)
        input_ids = feature.input_ids + [self.tokenizer.pad_token_id] * padding_length
        mask = feature.attention_mask + [0] * padding_length
        type_ids = feature.token_type_ids + [self.tokenizer.pad_token_type_id] * padding_length
        padding_word_len = max_seq_len - len(feature.orig_to_tok_index)
        orig_to_tok_index = feature.orig_to_tok_index + [0] * padding_word_len
        label_ids = feature.label_ids + [0] * padding_word_len

        batch[i] = Feature(input_ids=np.asarray(input_ids),
                           attention_mask=np.asarray(mask),
                           token_type_ids=np.asarray(type_ids),
                           orig_to_tok_index=np.asarray(orig_to_tok_index),
                           word_seq_len=feature.word_seq_len,
                           label_ids=np.asarray(label_ids))
    results = Feature(*(default_collate(samples) for samples in zip(*batch)))
    return results
def collate_flatten(batch, x_dim_flatten=5, y_dim_flatten=2):
    x, y = default_collate(batch)
    if len(x.shape) > x_dim_flatten:
        x = x.flatten(start_dim=0, end_dim=len(x.shape) - x_dim_flatten)
    if len(y.shape) > y_dim_flatten:
        y = y.flatten(start_dim=0, end_dim=len(y.shape) - y_dim_flatten)
    return [x, y]
def my_collate(batch):
    modified_batch = []
    for item in batch:
        image, label = item
        if label == 1 or label == 2:
            modified_batch.append(item)
    return default_collate(modified_batch)
def collate(cls, samples):
    """Collates a sequence of samples containing graphs into a batch

    The samples in the sequence can contain multiple types of inputs, such as:

    >>> [
    >>>     (input_graph, tensor, other_tensor, output_graph),
    >>>     (input_graph, tensor, other_tensor, output_graph),
    >>>     ...
    >>> ]
    """
    if isinstance(samples[0], Graph):
        return cls.from_graphs(samples)
    elif isinstance(samples[0], (str, bytes)):
        return samples
    elif isinstance(samples[0], collections.abc.Mapping):
        return {
            key: cls.collate([d[key] for d in samples])
            for key in samples[0]
        }
    elif isinstance(samples[0], collections.abc.Sequence):
        transposed = zip(*samples)
        return [cls.collate(samples) for samples in transposed]
    else:
        return default_collate(samples)
def collate(self, samples: Any) -> Any:
    if not isinstance(samples, Tensor):
        elem = samples[0]
        if isinstance(elem, container_abcs.Sequence):
            return tuple(zip(*samples))
        return default_collate(samples)
    return samples.unsqueeze(dim=0)
def list_data_collate(batch: Sequence):
    """
    Enhancement for the PyTorch DataLoader default collate.
    If the dataset already returns a list of batch data generated by the transforms,
    merge all the data into one list; after that it behaves the same as the default collate.

    Note:
        Use this collate function if any applied transform can generate batch data.
    """
    elem = batch[0]
    data = [i for k in batch for i in k] if isinstance(elem, list) else batch
    try:
        return default_collate(data)
    except RuntimeError as re:
        re_str = str(re)
        if "equal size" in re_str:
            re_str += (
                "\n\nMONAI hint: if your transforms intentionally create images of different shapes, creating your "
                + "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its "
                + "documentation).")
        raise RuntimeError(re_str)
    except TypeError as re:
        re_str = str(re)
        if "numpy" in re_str and "Tensor" in re_str:
            re_str += (
                "\n\nMONAI hint: if your transforms intentionally create mixtures of torch Tensor and numpy ndarray, "
                + "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem "
                + "(check its documentation).")
        raise TypeError(re_str)
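
# Hedged usage sketch: list_data_collate helps when a transform (e.g. a random crop that
# returns several patches) makes each dataset item a list of sample dicts. The loader
# below is plain PyTorch; `patch_dataset` is a hypothetical dataset wrapping such a
# transform chain.
from torch.utils.data import DataLoader

patch_loader = DataLoader(
    patch_dataset,  # hypothetical dataset whose items are lists of sample dicts
    batch_size=2,
    collate_fn=list_data_collate,
)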
def collate_fn(self, batch: List[Feature]):
    word_seq_lens = [len(feature.words) for feature in batch]
    max_seq_len = max(word_seq_lens)
    max_char_seq_len = -1
    for feature in batch:
        curr_max_char_seq_len = max(feature.char_seq_lens)
        max_char_seq_len = max(curr_max_char_seq_len, max_char_seq_len)
    for i, feature in enumerate(batch):
        padding_length = max_seq_len - len(feature.words)
        words = feature.words + [0] * padding_length
        chars = []
        char_seq_lens = feature.char_seq_lens + [1] * padding_length
        for word_idx in range(feature.word_seq_len):
            pad_char_length = max_char_seq_len - feature.char_seq_lens[word_idx]
            word_chars = feature.chars[word_idx] + [0] * pad_char_length
            chars.append(word_chars)
        for _ in range(max_seq_len - feature.word_seq_len):
            chars.append([0] * max_char_seq_len)
        labels = feature.labels + [0] * padding_length if feature.labels is not None else None

        batch[i] = Feature(
            words=np.asarray(words),
            chars=np.asarray(chars),
            char_seq_lens=np.asarray(char_seq_lens),
            context_emb=feature.context_emb,
            word_seq_len=feature.word_seq_len,
            labels=np.asarray(labels) if labels is not None else None)
    results = Feature(*(default_collate(samples)
                        if not check_all_obj_is_None(samples) else None
                        for samples in zip(*batch)))
    return results
def impute(self, x, ids, dropout_threshold=None):
    assert len(x) == len(ids)
    assert len(x) == self.n_channels
    union = sorted(set(reduce(add, ids)))
    if len(union) != 0:
        x_hat = []
        for s in union:
            # print(s, union.index(s))
            # available_channels_for_s = [i if s in _ and i in self.enc_channels else None for i, _ in enumerate(ids)]
            available_channels_for_s = [
                i if s in _ else None for i, _ in enumerate(ids)
            ]
            xs = [
                x[ch][ids[ch].index(s)].unsqueeze(0) if ch is not None else None
                for ch in available_channels_for_s
            ]
            qs = [
                self.vae[ch].encode(_) if _ is not None else None
                for ch, _ in enumerate(xs)
            ]
            zs_ = [_.loc if _ is not None else None for _ in qs]
            if dropout_threshold is not None:
                zs = [
                    self.apply_threshold(_, dropout_threshold)
                    if _ is not None else None for _ in zs_
                ]
            else:
                zs = zs_
            xs_hat = []
            for i in range(self.n_channels):
                if i in self.dec_channels:
                    ps_hat_i = [
                        self.vae[i].decode(z) for z in zs if z is not None
                    ]
                    ev_hat_i = [
                        self.vae[i].p_to_expected_value(_) for _ in ps_hat_i
                    ]
                    xs_hat_i = torch.stack(ev_hat_i, 0).mean(0)
                    if isinstance(ps_hat_i[0], torch.distributions.Categorical):
                        xs_hat_i = xs_hat_i.argmax(1, keepdim=True)
                    xs_hat.append(xs_hat_i)
                    del ps_hat_i, ev_hat_i, xs_hat_i
                else:
                    xs_hat.append([])
            x_hat.append(xs_hat)
            del xs_hat
        ret = [
            _.squeeze() if not isinstance(_, list) else []
            for _ in default_collate(x_hat)
        ]
    else:
        ret = [[] for _ in x]
    return ret
def collate_none_when_even(self, batch):
    if self.counter % 2 == 0:
        result = None
    else:
        result = default_collate(batch)
    self.counter += 1
    return result
def test_collate(batch):
    imgs = torch.stack([torch.Tensor(item[0]) for item in batch], 0)
    masks_cropped = torch.stack([torch.Tensor(item[1]) for item in batch], 0)
    masks_org = [torch.Tensor(item[2]) for item in batch]
    meta = default_collate([item[3] for item in batch])
    return imgs, masks_cropped, masks_org, meta