def ltr_collate_stack1(batch): """Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch""" error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" elem_type = type(batch[0]) if isinstance(batch[0], torch.Tensor): out = None if _check_use_shared_memory(): # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum([x.numel() for x in batch]) storage = batch[0].storage()._new_shared(numel) out = batch[0].new(storage) return torch.stack(batch, 1, out=out) # if batch[0].dim() < 4: # return torch.stack(batch, 0, out=out) # return torch.cat(batch, 0, out=out) elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ and elem_type.__name__ != 'string_': elem = batch[0] if elem_type.__name__ == 'ndarray': # array of string classes and object if torch.utils.data.dataloader.re.search( '[SaUO]', elem.dtype.str) is not None: raise TypeError(error_msg.format(elem.dtype)) return torch.stack([torch.from_numpy(b) for b in batch], 1) if elem.shape == (): # scalars py_type = float if elem.dtype.name.startswith('float') else int return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name]( list(map(py_type, batch))) elif isinstance(batch[0], int_classes): return torch.LongTensor(batch) elif isinstance(batch[0], float): return torch.DoubleTensor(batch) elif isinstance(batch[0], string_classes): return batch elif isinstance(batch[0], TensorDict): return TensorDict({ key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0] }) elif isinstance(batch[0], collections.Mapping): return { key: ltr_collate_stack1([d[key] for d in batch]) for key in batch[0] } elif isinstance(batch[0], TensorList): transposed = zip(*batch) return TensorList( [ltr_collate_stack1(samples) for samples in transposed]) elif isinstance(batch[0], collections.Sequence): transposed = zip(*batch) return [ltr_collate_stack1(samples) for samples in transposed] elif batch[0] is 
None: return batch raise TypeError((error_msg.format(type(batch[0]))))
def __call__(self, data: TensorDict, rng=None):
    """Crop and transform the train/test frames in ``data``.

    Runs the joint transform (if configured) over all images at once, then
    for each split jitters the ground-truth box, crops a search region
    centered on the jittered box, and applies the split-specific transform.

    args:
        data - TensorDict with 'train_images', 'test_images', 'train_anno',
            'test_anno' ('dataset' is used only for error reporting).
        rng - random generator forwarded to the box jittering.

    returns:
        TensorDict - the processed data.
    """
    joint_transform = self.transform['joint']
    if joint_transform is not None:
        n_train = len(data['train_images'])
        transformed = joint_transform(
            *(data['train_images'] + data['test_images']))
        data['train_images'] = transformed[:n_train]
        data['test_images'] = transformed[n_train:]

    for split in ('train', 'test'):
        assert self.mode == 'sequence' or len(data[split + '_images']) == 1, \
            "In pair mode, num train/test frames must be 1"

        # Add a uniform noise to the center pos
        jittered = [
            self._get_jittered_box(box, split, rng)
            for box in data[split + '_anno']
        ]

        # Crop image region centered at the jittered box
        try:
            crops, boxes = prutils.jittered_center_crop(
                data[split + '_images'],
                jittered,
                data[split + '_anno'],
                self.search_area_factor[split],
                self.output_sz[split],
                scale_type=self.scale_type,
                border_type=self.border_type)
        except Exception as e:
            print('{}, anno: {}'.format(data['dataset'],
                                        data[split + '_anno']))
            raise e

        # Apply the split-specific transform to each crop
        data[split + '_images'] = [self.transform[split](c) for c in crops]
        data[split + '_anno'] = boxes

    # Stack per-frame tensors in sequence mode; unwrap singletons otherwise
    if self.mode == 'sequence':
        return data.apply(prutils.stack_tensors)
    return data.apply(lambda x: x[0] if isinstance(x, list) else x)
def __iter__(self):
    """Generator yielding one processed training sample.

    args:
        index (int): Index (Ignored since we sample randomly)

    returns:
        TensorDict - dict containing all the data blocks

    NOTE(review): this generator yields exactly once per __iter__ call —
    presumably the surrounding loader re-creates the iterator; confirm
    against the caller.
    """
    # Select a dataset index, weighted by self.p_datasets.
    # (replace=False is a no-op here since only one index is drawn.)
    dataset_idx = self.rng.choice(range(len(self.datasets)),
                                  p=self.p_datasets,
                                  replace=False)
    dataset = self.datasets[dataset_idx]

    is_video_dataset = dataset.is_video_sequence()

    # Require at least twice the number of frames we intend to sample.
    min_visible_frames = 2 * (self.num_test_frames + self.num_train_frames)
    enough_visible_frames = False

    # Sample a sequence with enough visible frames and get anno for the same.
    # Image datasets always pass; videos additionally need enough visible
    # frames and a minimum sequence length of 20.
    while not enough_visible_frames:
        seq_id = self.rng.randint(0, dataset.get_num_sequences() - 1)
        anno, visible = dataset.get_sequence_info(seq_id)
        num_visible = np.sum(visible.astype('int64'))
        enough_visible_frames = not is_video_dataset or (
            num_visible > min_visible_frames and len(visible) >= 20)

    if is_video_dataset:
        train_frame_ids = None
        test_frame_ids = None
        gap_increase = 0
        if self.frame_sample_mode == 'default':
            # Sample frame numbers: train frames anywhere visible, test
            # frames within +-max_gap of the first train frame.
            while test_frame_ids is None:
                train_frame_ids = self._sample_visible_ids(
                    visible, num_ids=self.num_train_frames)
                test_frame_ids = self._sample_visible_ids(
                    visible,
                    min_id=train_frame_ids[0] - self.max_gap - gap_increase,
                    max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                    num_ids=self.num_test_frames)
                gap_increase += 5  # Increase gap until a frame is found
        elif self.frame_sample_mode == 'causal':
            # Sample frame numbers in a causal manner, i.e.
            # test_frame_ids > train_frame_ids.
            while test_frame_ids is None:
                # Base frame leaves room for earlier train frames and
                # later test frames.
                base_frame_id = self._sample_visible_ids(
                    visible,
                    num_ids=1,
                    min_id=self.num_train_frames - 1,
                    max_id=len(visible) - self.num_test_frames)
                prev_frame_ids = self._sample_visible_ids(
                    visible,
                    num_ids=self.num_train_frames - 1,
                    min_id=base_frame_id[0] - self.max_gap - gap_increase,
                    max_id=base_frame_id[0])
                if prev_frame_ids is None:
                    # Not enough visible frames before the base frame:
                    # widen the window and retry.
                    gap_increase += 5
                    continue
                train_frame_ids = base_frame_id + prev_frame_ids
                # Test frames strictly after the base (first train) frame.
                test_frame_ids = self._sample_visible_ids(
                    visible,
                    min_id=train_frame_ids[0] + 1,
                    max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                    num_ids=self.num_test_frames)
                gap_increase += 5  # Increase gap until a frame is found
        else:
            raise ValueError('Unknown frame_sample_mode.')
    else:
        # Image dataset: repeat the single image as a synthetic sequence.
        train_frame_ids = [1] * self.num_train_frames
        test_frame_ids = [1] * self.num_test_frames

    # Get frames
    train_frames, train_anno, _ = dataset.get_frames(
        seq_id, train_frame_ids, anno)
    test_frames, test_anno, _ = dataset.get_frames(seq_id, test_frame_ids,
                                                   anno)

    # Prepare data
    data = TensorDict({
        'train_images': train_frames,
        'train_anno': train_anno,
        'test_images': test_frames,
        'test_anno': test_anno,
        'dataset': dataset.get_name()
    })

    # Send for processing
    yield self.processing(data, rng=self.rng)
def __iter__(self):
    """Generator yielding one processed (possibly negative) sample pair.

    args:
        index (int): Index (Ignored since we sample randomly)

    returns:
        TensorDict - dict containing all the data blocks
    """
    # Decide whether this sample is a negative pair.
    # NOTE(review): when self.neg is 0/None, `neg` is that falsy value
    # rather than False; downstream truthiness tests still work.
    neg = self.neg and self.neg > random.random()

    # Select a dataset
    if neg:
        # Negative pair: train and test frames come from two independently
        # sampled datasets/sequences.
        dataset_idx = self.rng.choice(range(len(self.datasets)),
                                      p=self.p_datasets,
                                      replace=False)
        train_dataset = self.datasets[dataset_idx]
        dataset_idx = self.rng.choice(range(len(self.datasets)),
                                      p=self.p_datasets,
                                      replace=False)
        test_dataset = self.datasets[dataset_idx]

        train_seq_id, test_seq_id, train_frame_ids, test_frame_ids, train_anno, test_anno = \
            self._get_random_pair(train_dataset, test_dataset)

        # Get frames; negative pairs get all-zero dummy masks (H x W x 1).
        train_frames, train_anno, _ = train_dataset.get_frames(
            train_seq_id, train_frame_ids, train_anno)
        train_masks = [
            np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32)
            for frame in train_frames
        ]
        test_frames, test_anno, _ = test_dataset.get_frames(
            test_seq_id, test_frame_ids, test_anno)
        test_masks = [
            np.zeros([frame.shape[0], frame.shape[1], 1], dtype=np.float32)
            for frame in test_frames
        ]
    else:
        # Positive pair: both splits from the same sequence.
        dataset_idx = self.rng.choice(range(len(self.datasets)),
                                      p=self.p_datasets,
                                      replace=False)
        dataset = self.datasets[dataset_idx]

        seq_id, train_frame_ids, test_frame_ids, anno = self._get_positive_pair(
            dataset)

        # Get frames; use real masks when the dataset provides them,
        # otherwise zero-filled dummies.
        if self.has_mask(dataset):
            train_frames, train_anno, train_masks, _ = dataset.get_frames_mask(
                seq_id, train_frame_ids, anno)
            test_frames, test_anno, test_masks, _ = dataset.get_frames_mask(
                seq_id, test_frame_ids, anno)
        else:
            train_frames, train_anno, _ = dataset.get_frames(
                seq_id, train_frame_ids, anno)
            train_masks = [
                np.zeros([frame.shape[0], frame.shape[1], 1],
                         dtype=np.float32) for frame in train_frames
            ]
            test_frames, test_anno, _ = dataset.get_frames(
                seq_id, test_frame_ids, anno)
            test_masks = [
                np.zeros([frame.shape[0], frame.shape[1], 1],
                         dtype=np.float32) for frame in test_frames
            ]

    # Prepare data
    data = TensorDict({
        'train_images': train_frames,
        'train_anno': train_anno,
        'train_masks': train_masks,
        'test_images': test_frames,
        'test_anno': test_anno,
        'test_masks': test_masks,
        'neg': neg
    })

    # Send for processing
    yield self.processing(data, rng=self.rng)
def __call__(self, data: TensorDict, rng=None):
    """
    args:
        data - The input data, should contain the following fields:
            'train_images' -
            'test_images'  -
            'train_anno'   -
            'test_anno'    -

    returns:
        TensorDict - output data block with following fields:
            'train_images'   -
            'test_images'    -
            'train_anno'     -
            'test_anno'      -
            'test_proposals' -
            'proposal_iou'   -
    """
    # Apply joint transforms over all train+test images at once
    joint_transform = self.transform['joint']
    if joint_transform is not None:
        n_train = len(data['train_images'])
        transformed = joint_transform(
            *(data['train_images'] + data['test_images']))
        data['train_images'] = transformed[:n_train]
        data['test_images'] = transformed[n_train:]

    for split in ('train', 'test'):
        assert self.mode == 'sequence' or len(data[split + '_images']) == 1, \
            "In pair mode, num train/test frames must be 1"

        # Add a uniform noise to the center pos
        jittered = [
            self._get_jittered_box(box, split, rng)
            for box in data[split + '_anno']
        ]

        # Crop image region centered at the jittered box
        try:
            crops, boxes = prutils.jittered_center_crop(
                data[split + '_images'], jittered, data[split + '_anno'],
                self.search_area_factor, self.output_sz)
        except Exception as e:
            print('{}, anno: {}'.format(data['dataset'],
                                        data[split + '_anno']))
            raise e

        # Apply the split-specific transform to each crop
        data[split + '_images'] = [self.transform[split](c) for c in crops]
        data[split + '_anno'] = boxes

    # Generate proposals and their IoU with the ground truth, one set per
    # test frame
    proposals, ious = zip(
        *(self._generate_proposals(box, rng) for box in data['test_anno']))
    data['test_proposals'] = list(proposals)
    data['proposal_iou'] = list(ious)

    # Stack per-frame tensors in sequence mode; unwrap singletons otherwise
    if self.mode == 'sequence':
        return data.apply(prutils.stack_tensors)
    return data.apply(lambda x: x[0] if isinstance(x, list) else x)
def __getitem__(self, index):
    """Sample one training item ('interval' or 'causal' frame sampling).

    args:
        index (int): Index (Ignored since we sample randomly)

    returns:
        TensorDict - dict containing all the data blocks
    """
    # Select a dataset, weighted by self.p_datasets
    dataset = random.choices(self.datasets, self.p_datasets)[0]

    is_video_dataset = dataset.is_video_sequence()

    # Sample a sequence with enough visible frames (image datasets always
    # pass; videos need > 2*(num_test+num_train) visible frames and length
    # >= 20).
    enough_visible_frames = False
    while not enough_visible_frames:
        # Sample a sequence
        seq_id = random.randint(0, dataset.get_num_sequences() - 1)

        # Sample frames
        seq_info_dict = dataset.get_sequence_info(seq_id)
        visible = seq_info_dict['visible']

        enough_visible_frames = visible.type(torch.int64).sum().item(
        ) > 2 * (self.num_test_frames +
                 self.num_train_frames) and len(visible) >= 20

        enough_visible_frames = enough_visible_frames or not is_video_dataset

    if is_video_dataset:
        train_frame_ids = None
        test_frame_ids = None
        gap_increase = 0

        if self.frame_sample_mode == 'interval':
            # Sample frame numbers within interval defined by the first frame
            while test_frame_ids is None:
                base_frame_id = self._sample_visible_ids(visible, num_ids=1)
                extra_train_frame_ids = self._sample_visible_ids(
                    visible,
                    num_ids=self.num_train_frames - 1,
                    min_id=base_frame_id[0] - self.max_gap - gap_increase,
                    max_id=base_frame_id[0] + self.max_gap + gap_increase)
                if extra_train_frame_ids is None:
                    # Window too tight: widen and retry
                    gap_increase += 5
                    continue
                train_frame_ids = base_frame_id + extra_train_frame_ids
                test_frame_ids = self._sample_visible_ids(
                    visible,
                    num_ids=self.num_test_frames,
                    min_id=train_frame_ids[0] - self.max_gap - gap_increase,
                    max_id=train_frame_ids[0] + self.max_gap + gap_increase)
                gap_increase += 5  # Increase gap until a frame is found
        elif self.frame_sample_mode == 'causal':
            # Sample test and train frames in a causal manner, i.e.
            # test_frame_ids > train_frame_ids
            while test_frame_ids is None:
                base_frame_id = self._sample_visible_ids(
                    visible,
                    num_ids=1,
                    min_id=self.num_train_frames - 1,
                    max_id=len(visible) - self.num_test_frames)
                prev_frame_ids = self._sample_visible_ids(
                    visible,
                    num_ids=self.num_train_frames - 1,
                    min_id=base_frame_id[0] - self.max_gap - gap_increase,
                    max_id=base_frame_id[0])
                if prev_frame_ids is None:
                    gap_increase += 5
                    continue
                train_frame_ids = base_frame_id + prev_frame_ids
                test_frame_ids = self._sample_visible_ids(
                    visible,
                    min_id=train_frame_ids[0] + 1,
                    max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                    num_ids=self.num_test_frames)
                # Increase gap until a frame is found
                gap_increase += 5
    else:
        # In case of image dataset, just repeat the image to generate synthetic video
        train_frame_ids = [1] * self.num_train_frames
        test_frame_ids = [1] * self.num_test_frames

    train_frames, train_anno, meta_obj_train = dataset.get_frames(
        seq_id, train_frame_ids, seq_info_dict)
    test_frames, test_anno, meta_obj_test = dataset.get_frames(
        seq_id, test_frame_ids, seq_info_dict)

    data = TensorDict({
        'train_images': train_frames,
        'train_anno': train_anno['bbox'],
        'test_images': test_frames,
        'test_anno': test_anno['bbox'],
        'dataset': dataset.get_name(),
        'test_class': meta_obj_test.get('object_class_name')
    })

    return self.processing(data)
def __getitem__(self, index):
    """Sample a training sequence, optionally biased toward occlusions.

    args:
        index (int): Index (Ignored since we sample randomly)

    returns:
        TensorDict - dict containing all the data blocks
    """
    # Select a dataset, weighted by p_datasets
    p_datasets = self.p_datasets

    dataset = random.choices(self.datasets, p_datasets)[0]

    is_video_dataset = dataset.is_video_sequence()

    num_train_frames = self.sequence_sample_info['num_train_frames']
    num_test_frames = self.sequence_sample_info['num_test_frames']
    max_train_gap = self.sequence_sample_info['max_train_gap']
    allow_missing_target = self.sequence_sample_info[
        'allow_missing_target']
    min_fraction_valid_frames = self.sequence_sample_info.get(
        'min_fraction_valid_frames', 0.0)

    # Only the allow_missing_target mode is implemented
    if allow_missing_target:
        min_visible_frames = 0
    else:
        raise NotImplementedError

    valid_sequence = False

    # Sample a sequence with enough visible frames and get anno for the same
    while not valid_sequence:
        seq_id = random.randint(0, dataset.get_num_sequences() - 1)

        seq_info_dict = dataset.get_sequence_info(seq_id)
        visible = seq_info_dict['visible']
        # NOTE(review): falls back to the binary `visible` flags when the
        # dataset provides no per-frame visible_ratio — confirm the two are
        # interchangeable for the < 0.9 occlusion test below.
        visible_ratio = seq_info_dict.get('visible_ratio', visible)

        num_visible = visible.type(torch.int64).sum().item()

        enough_visible_frames = not is_video_dataset or (
            num_visible > min_visible_frames and len(visible) >= 20)

        valid_sequence = enough_visible_frames

    if self.sequence_sample_info['mode'] == 'Sequence':
        if is_video_dataset:
            train_frame_ids = None
            test_frame_ids = None
            gap_increase = 0

            # Marks which of the (fixed-length) test slots hold real frames
            test_valid_image = torch.zeros(num_test_frames,
                                           dtype=torch.int8)
            # Sample frame numbers in a causal manner, i.e.
            # test_frame_ids > train_frame_ids
            while test_frame_ids is None:
                occlusion_sampling = False
                if dataset.has_occlusion_info(
                ) and self.sample_occluded_sequences:
                    target_not_fully_visible = visible_ratio < 0.9
                    if target_not_fully_visible.float().sum() > 0:
                        occlusion_sampling = True

                if occlusion_sampling:
                    # Anchor the test window around the first occlusion
                    first_occ_frame = target_not_fully_visible.nonzero()[0]

                    occ_end_frame = self.find_occlusion_end_frame(
                        first_occ_frame, target_not_fully_visible)

                    # Make sure target visible in first frame
                    base_frame_id = self._sample_ids(
                        visible,
                        num_ids=1,
                        min_id=max(0, first_occ_frame - 20),
                        max_id=first_occ_frame - 5)
                    if base_frame_id is None:
                        base_frame_id = 0
                    else:
                        base_frame_id = base_frame_id[0]

                    prev_frame_ids = self._sample_ids(
                        visible,
                        num_ids=num_train_frames,
                        min_id=base_frame_id - max_train_gap -
                        gap_increase - 1,
                        max_id=base_frame_id - 1)

                    if prev_frame_ids is None:
                        # No earlier frames available: repeat the base
                        # frame if we're at the sequence start, otherwise
                        # widen the gap and retry.
                        if base_frame_id - max_train_gap - gap_increase - 1 < 0:
                            prev_frame_ids = [base_frame_id
                                              ] * num_train_frames
                        else:
                            gap_increase += 5
                            continue

                    train_frame_ids = prev_frame_ids

                    # Extend the test window a random amount past the end
                    # of the occlusion, clamped to the sequence length
                    end_frame = min(occ_end_frame + random.randint(5, 20),
                                    len(visible) - 1)

                    if (end_frame - base_frame_id) < num_test_frames:
                        # Window shorter than num_test_frames: push the end
                        # forward and/or pull the base back to fit
                        rem_frames = num_test_frames - (end_frame -
                                                        base_frame_id)
                        end_frame = random.randint(
                            end_frame,
                            min(len(visible) - 1, end_frame + rem_frames))
                        base_frame_id = max(
                            0, end_frame - num_test_frames + 1)

                        end_frame = min(end_frame, len(visible) - 1)

                    # Evenly spread num_test_frames ids over the window
                    step_len = float(end_frame - base_frame_id) / float(
                        num_test_frames)

                    test_frame_ids = [
                        base_frame_id + int(x * step_len)
                        for x in range(0, num_test_frames)
                    ]
                    test_valid_image[:len(test_frame_ids)] = 1
                    # Pad with frame 0 so the id list has a fixed length
                    test_frame_ids = test_frame_ids + [0] * (
                        num_test_frames - len(test_frame_ids))
                else:
                    # Make sure target visible in first frame
                    base_frame_id = self._sample_ids(
                        visible,
                        num_ids=1,
                        min_id=2 * num_train_frames,
                        max_id=len(visible) -
                        int(num_test_frames * min_fraction_valid_frames))
                    if base_frame_id is None:
                        base_frame_id = 0
                    else:
                        base_frame_id = base_frame_id[0]

                    prev_frame_ids = self._sample_ids(
                        visible,
                        num_ids=num_train_frames,
                        min_id=base_frame_id - max_train_gap -
                        gap_increase - 1,
                        max_id=base_frame_id - 1)
                    if prev_frame_ids is None:
                        # Same start-of-sequence fallback as above
                        if base_frame_id - max_train_gap - gap_increase - 1 < 0:
                            prev_frame_ids = [base_frame_id
                                              ] * num_train_frames
                        else:
                            gap_increase += 5
                            continue

                    train_frame_ids = prev_frame_ids
                    # Consecutive test frames starting at the base frame,
                    # zero-padded to a fixed length
                    test_frame_ids = list(
                        range(
                            base_frame_id,
                            min(len(visible),
                                base_frame_id + num_test_frames)))
                    test_valid_image[:len(test_frame_ids)] = 1
                    test_frame_ids = test_frame_ids + [0] * (
                        num_test_frames - len(test_frame_ids))
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError

    # Get frames
    train_frames, train_anno_dict, _ = dataset.get_frames(
        seq_id, train_frame_ids, seq_info_dict)
    train_anno = train_anno_dict['bbox']

    test_frames, test_anno_dict, _ = dataset.get_frames(
        seq_id, test_frame_ids, seq_info_dict)
    test_anno = test_anno_dict['bbox']
    test_valid_anno = test_anno_dict['valid']
    test_visible = test_anno_dict['visible']
    test_visible_ratio = test_anno_dict.get('visible_ratio',
                                            torch.ones(len(test_visible)))

    # Prepare data
    data = TensorDict({
        'train_images': train_frames,
        'train_anno': train_anno,
        'test_images': test_frames,
        'test_anno': test_anno,
        'test_valid_anno': test_valid_anno,
        'test_visible': test_visible,
        'test_valid_image': test_valid_image,
        'test_visible_ratio': test_visible_ratio,
        'dataset': dataset.get_name()
    })

    # Send for processing
    return self.processing(data)
def __getitem__(self, index):
    """Sample one item, optionally traversing the sequence in reverse.

    args:
        index (int): Index (dataset index)

    returns:
        TensorDict - dict containing all the data blocks
    """
    # Select a dataset, weighted by self.p_datasets
    dataset = random.choices(self.datasets, self.p_datasets)[0]

    is_video_dataset = dataset.is_video_sequence()

    # With probability p_reverse, sample the sequence backwards in time
    reverse_sequence = False
    if self.p_reverse is not None:
        reverse_sequence = random.random() < self.p_reverse

    # Sample a sequence with enough visible frames
    enough_visible_frames = False
    while not enough_visible_frames:
        # Sample a sequence
        seq_id = random.randint(0, dataset.get_num_sequences() - 1)

        # Sample frames
        seq_info_dict = dataset.get_sequence_info(seq_id)
        visible = seq_info_dict['visible']

        enough_visible_frames = visible.type(torch.int64).sum().item(
        ) > 2 * (self.num_test_frames + self.num_train_frames)

        # Image datasets always pass the visibility check
        enough_visible_frames = enough_visible_frames or not is_video_dataset

    if is_video_dataset:
        train_frame_ids = None
        test_frame_ids = None
        gap_increase = 0

        # Sample test and train frames in a causal manner, i.e. test_frame_ids > train_frame_ids
        while test_frame_ids is None:
            # Give up rather than loop forever if no valid frames exist
            if gap_increase > 1000:
                raise Exception('Frame not found')

            if not reverse_sequence:
                base_frame_id = self._sample_visible_ids(
                    visible,
                    num_ids=1,
                    min_id=self.num_train_frames - 1,
                    max_id=len(visible) - self.num_test_frames)
                prev_frame_ids = self._sample_visible_ids(
                    visible,
                    num_ids=self.num_train_frames - 1,
                    min_id=base_frame_id[0] - self.max_gap - gap_increase,
                    max_id=base_frame_id[0])
                if prev_frame_ids is None:
                    gap_increase += 5
                    continue
                train_frame_ids = base_frame_id + prev_frame_ids
                test_frame_ids = self._sample_visible_ids(
                    visible,
                    min_id=train_frame_ids[0] + 1,
                    max_id=train_frame_ids[0] + self.max_gap +
                    gap_increase,
                    num_ids=self.num_test_frames)

                # Increase gap until a frame is found
                gap_increase += 5
            else:
                # Sample in reverse order, i.e.
                # train frames come after the test frames
                base_frame_id = self._sample_visible_ids(
                    visible,
                    num_ids=1,
                    min_id=self.num_test_frames + 1,
                    max_id=len(visible) - self.num_train_frames - 1)
                prev_frame_ids = self._sample_visible_ids(
                    visible,
                    num_ids=self.num_train_frames - 1,
                    min_id=base_frame_id[0],
                    max_id=base_frame_id[0] + self.max_gap + gap_increase)
                if prev_frame_ids is None:
                    gap_increase += 5
                    continue
                train_frame_ids = base_frame_id + prev_frame_ids
                test_frame_ids = self._sample_visible_ids(
                    visible,
                    min_id=0,
                    max_id=train_frame_ids[0] - 1,
                    num_ids=self.num_test_frames)

                # Increase gap until a frame is found
                gap_increase += 5
    else:
        # In case of image dataset, just repeat the image to generate synthetic video
        train_frame_ids = [1] * self.num_train_frames
        test_frame_ids = [1] * self.num_test_frames

    # Sort frames (descending when the sequence is reversed)
    train_frame_ids = sorted(train_frame_ids, reverse=reverse_sequence)
    test_frame_ids = sorted(test_frame_ids, reverse=reverse_sequence)

    all_frame_ids = train_frame_ids + test_frame_ids

    # Load frames in one call, then split back into train/test
    all_frames, all_anno, meta_obj = dataset.get_frames(
        seq_id, all_frame_ids, seq_info_dict)

    train_frames = all_frames[:len(train_frame_ids)]
    test_frames = all_frames[len(train_frame_ids):]

    train_anno = {}
    test_anno = {}
    for key, value in all_anno.items():
        train_anno[key] = value[:len(train_frame_ids)]
        test_anno[key] = value[len(train_frame_ids):]

    # Masks are optional; None when the dataset has no segmentation
    train_masks = train_anno['mask'] if 'mask' in train_anno else None
    test_masks = test_anno['mask'] if 'mask' in test_anno else None

    data = TensorDict({
        'train_images': train_frames,
        'train_masks': train_masks,
        'train_anno': train_anno['bbox'],
        'test_images': test_frames,
        'test_masks': test_masks,
        'test_anno': test_anno['bbox'],
        'dataset': dataset.get_name()
    })

    return self.processing(data)
def __call__(self, data: TensorDict, rng=None):
    """Crop images + masks, apply transforms, and (optionally) build
    classification/localization/mask training labels.

    args:
        data - TensorDict with 'train_images', 'test_images', 'train_anno',
            'test_anno', 'train_masks', 'test_masks' and 'neg' (negative-pair
            flag produced by the sampler).
        rng - random generator forwarded to the box jittering.

    returns:
        TensorDict - processed data; when self.label_params is set, the
        anno/mask fields are replaced by 'label_*' fields.
    """
    neg = data['neg']

    # Apply joint transforms over all train+test images at once
    if self.transform['joint'] is not None:
        num_train_images = len(data['train_images'])
        all_images = data['train_images'] + data['test_images']
        all_images_trans = self.transform['joint'](*all_images)
        data['train_images'] = all_images_trans[:num_train_images]
        data['test_images'] = all_images_trans[num_train_images:]

    for s in ['train', 'test']:
        assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
            "In pair mode, num train/test frames must be 1"

        # Add a uniform noise to the center pos
        jittered_anno = [
            self._get_jittered_box(a, s, rng) for a in data[s + '_anno']
        ]

        # Crop image region centered at jittered_anno box; masks are
        # cropped identically but always with zero padding
        try:
            crops, boxes = prutils.jittered_center_crop(
                data[s + '_images'],
                jittered_anno,
                data[s + '_anno'],
                self.search_area_factor[s],
                self.output_sz[s],
                scale_type=self.scale_type,
                border_type=self.border_type)
            mask_crops, _ = prutils.jittered_center_crop(
                data[s + '_masks'],
                jittered_anno,
                data[s + '_anno'],
                self.search_area_factor[s],
                self.output_sz[s],
                scale_type=self.scale_type,
                border_type='zeropad')
        except Exception as e:
            print('{}, anno: {}'.format(data['dataset'],
                                        data[s + '_anno']))
            raise e

        # Apply transforms
        data[s + '_images'] = [self.transform[s](x) for x in crops]
        data[s + '_anno'] = boxes
        data[s + '_masks'] = [
            self.transform[s + '_mask'](x) for x in mask_crops
        ]

    # Prepare output
    if self.mode == 'sequence':
        data = data.apply(prutils.stack_tensors)
    else:
        data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

    # Get labels
    if self.label_params is not None:
        assert data['test_anno'].shape[0] == 1
        # Convert [x, y, w, h] -> [x1, y1, x2, y2] in place
        gt_box = data['test_anno'][0]
        gt_box[2:] += gt_box[:2]
        cls, delta, delta_weight, overlap = self._get_label(gt_box, neg)

        # Mask loss weight: reuse the per-location cls max where a target
        # mask exists, zero weight otherwise
        mask = data['test_masks'][0]
        if np.sum(mask) > 0:
            mask_weight = cls.max(axis=0, keepdims=True)
        else:
            mask_weight = np.zeros([1, cls.shape[1], cls.shape[2]],
                                   dtype=np.float32)
        # Binarize mask to {-1, +1}
        mask = (mask > 0.5) * 2. - 1.

        data['label_cls'] = cls
        data['label_loc'] = delta
        data['label_loc_weight'] = delta_weight
        data['label_mask'] = mask
        data['label_mask_weight'] = mask_weight
        # Raw annos/masks are no longer needed once labels are built
        data.pop('train_anno')
        data.pop('test_anno')
        data.pop('train_masks')
        data.pop('test_masks')
    return data