Example 1
    def __call__(self, data: TensorDict):
        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['train_images'], data['train_anno'] = self.transform['joint'](image=data['train_images'], bbox=data['train_anno'])
            data['test_images'], data['test_anno'] = self.transform['joint'](image=data['test_images'], bbox=data['test_anno'], new_roll=False)

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]

            # Crop image region centered at jittered_anno box
            crops, boxes, _ = prutils.jittered_center_crop(data[s + '_images'], jittered_anno, data[s + '_anno'],
                                                           self.search_area_factor, self.output_sz)

            # Apply transforms
            data[s + '_images'], data[s + '_anno'] = self.transform[s](image=crops, bbox=boxes, joint=False)

        # Generate proposals
        proposals, proposal_density, gt_density, proposal_iou = zip(
            *[self._generate_proposals(a) for a in data['test_anno']])

        data['test_proposals'] = proposals
        data['proposal_density'] = proposal_density
        data['gt_density'] = gt_density
        data['proposal_iou'] = proposal_iou
        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
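
All of these __call__ implementations finish the same way: in 'sequence' mode the per-frame lists are stacked into batched tensors, in 'pair' mode the single list element is unwrapped. The stack_tensors helper itself is not shown in these snippets; a minimal sketch of what it plausibly does, inferred from how it is used:

import torch

def stack_tensors(x):
    # Stack a list of equally-shaped tensors along a new first dimension;
    # pass anything else (strings, scalars, ...) through unchanged.
    if isinstance(x, (list, tuple)) and x and isinstance(x[0], torch.Tensor):
        return torch.stack(x)
    return x
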
Example 2
    def __call__(self, data: TensorDict):
        # Apply joint transformations, i.e. the same transformation parameters
        # are applied to all train/test frames in a sequence
        if self.transform['joint'] is not None:
            data['train_images'], data['train_anno'], data['train_masks'] = self.transform['joint'](
                image=data['train_images'], bbox=data['train_anno'], mask=data['train_masks'])
            data['test_images'], data['test_anno'], data['test_masks'] = self.transform['joint'](
                image=data['test_images'], bbox=data['test_anno'], mask=data['test_masks'], new_roll=self.new_roll)

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]
            orig_anno = data[s + '_anno']

            # Extract a crop containing the target
            crops, boxes, mask_crops = prutils.target_image_crop(data[s + '_images'], jittered_anno,
                                                                 data[s + '_anno'], self.search_area_factor,
                                                                 self.output_sz, mode=self.crop_type,
                                                                 max_scale_change=self.max_scale_change,
                                                                 masks=data[s + '_masks'])

            # Apply independent transformations to each image
            data[s + '_images'], data[s + '_anno'], data[s + '_masks'] = self.transform[s](image=crops, bbox=boxes, mask=mask_crops, joint=False)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
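
The _get_jittered_box helper used throughout these examples is not included above. A plausible sketch in the pytracking style, jittering the scale and center of an (x, y, w, h) box; the exact jitter factors are assumptions:

import torch

def get_jittered_box(box, scale_jitter=0.25, center_jitter=3.0):
    # box: tensor [x, y, w, h]. Scale w/h log-normally, then shift the
    # center by a uniform offset proportional to the jittered size.
    jittered_size = box[2:4] * torch.exp(torch.randn(2) * scale_jitter)
    max_offset = (jittered_size.prod().sqrt() * center_jitter).item()
    jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5)
    return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size))
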
Example 3
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -

        returns:
            TensorDict - output data block with following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
                'test_proposals'-
                'proposal_iou'  -
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops, boxes, _ = prutils.jittered_center_crop(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)

            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops]
            data[s + '_anno'] = boxes

        # Generate proposals
        frame2_proposals, regs = zip(
            *[self._generate_proposals(a) for a in data['test_anno']])

        data['test_proposals'] = list(frame2_proposals)
        data['regs'] = list(regs)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
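
The _generate_proposals methods return different tuples across these variants (IoU labels, density values, or the regression targets 'regs' seen here). As a rough, hypothetical illustration of the IoU-labelled flavor: perturb the ground-truth box with Gaussian noise and record the overlap of each sample.

import torch

def iou_xywh(a, b):
    # IoU between (N, 4) boxes a and a single (4,) box b, all in (x, y, w, h).
    ix1 = torch.max(a[:, 0], b[0])
    iy1 = torch.max(a[:, 1], b[1])
    ix2 = torch.min(a[:, 0] + a[:, 2], b[0] + b[2])
    iy2 = torch.min(a[:, 1] + a[:, 3], b[1] + b[3])
    inter = (ix2 - ix1).clamp(min=0) * (iy2 - iy1).clamp(min=0)
    union = a[:, 2] * a[:, 3] + b[2] * b[3] - inter
    return inter / union

def generate_proposals(box, num_proposals=16, sigma=0.2):
    # Perturb the box with noise proportional to its size, keep the size
    # positive, and label every proposal with its IoU against the input.
    noise = torch.randn(num_proposals, 4) * sigma * box[2:4].repeat(2)
    proposals = box.unsqueeze(0) + noise
    proposals[:, 2:] = proposals[:, 2:].clamp(min=1.0)
    return proposals, iou_xywh(proposals, box)
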
Example 4
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images', 'test_images', 'train_anno', 'test_anno'
        returns:
            TensorDict - output data block with following fields:
                'train_images', 'test_images', 'train_anno', 'test_anno', 'test_proposals', 'proposal_density', 'gt_density',
                'test_label' (optional), 'train_label' (optional), 'test_label_density' (optional), 'train_label_density' (optional)
        """

        if self.transform['joint'] is not None:
            data['train_images'], data['train_anno'] = self.transform['joint'](image=data['train_images'], bbox=data['train_anno'])
            data['test_images'], data['test_anno'] = self.transform['joint'](image=data['test_images'], bbox=data['test_anno'], new_roll=False)

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]

            crops, boxes = prutils.target_image_crop(data[s + '_images'], jittered_anno, data[s + '_anno'],
                                                     self.search_area_factor, self.output_sz, mode=self.crop_type,
                                                     max_scale_change=self.max_scale_change)

            data[s + '_images'], data[s + '_anno'] = self.transform[s](image=crops, bbox=boxes, joint=False)

        # Generate proposals
        proposals, proposal_density, gt_density = zip(*[self._generate_proposals(a) for a in data['test_anno']])

        data['test_proposals'] = proposals
        data['proposal_density'] = proposal_density
        data['gt_density'] = gt_density

        for s in ['train', 'test']:
            is_distractor = data.get('is_distractor_{}_frame'.format(s), None)
            if is_distractor is not None:
                for is_dist, box in zip(is_distractor, data[s+'_anno']):
                    if is_dist:
                        box[0] = 99999999.9
                        box[1] = 99999999.9

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        # Generate label functions
        if self.label_function_params is not None:
            data['train_label'] = self._generate_label_function(data['train_anno'])
            data['test_label'] = self._generate_label_function(data['test_anno'])
        if self.label_density_params is not None:
            data['train_label_density'] = self._generate_label_density(data['train_anno'])
            data['test_label_density'] = self._generate_label_density(data['test_anno'])

        return data
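
The optional _generate_label_function here typically renders a Gaussian centered on the target onto a low-resolution score map. A self-contained sketch; the feature size, image size, and sigma values are assumptions:

import torch

def gaussian_label(box, feat_sz=18, img_sz=288, sigma=0.25):
    # Map the (x, y, w, h) box center from image to feature coordinates and
    # render a separable 2-D Gaussian on a feat_sz x feat_sz score map.
    center = (box[0:2] + 0.5 * box[2:4]) * feat_sz / img_sz
    k = torch.arange(feat_sz, dtype=torch.float32)
    gy = torch.exp(-0.5 * ((k - center[1]) / (sigma * feat_sz)) ** 2)
    gx = torch.exp(-0.5 * ((k - center[0]) / (sigma * feat_sz)) ** 2)
    return gy.view(-1, 1) * gx.view(1, -1)
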
Example 5
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -

        returns:
            TensorDict - output data block with following fields:
                'train_images'  -processing
                'test_images'   -processing
                'train_anno'    -processing
                'test_anno'     -processing
                'test_proposals'-
                'proposal_iou'  -
        """

        for s in ['train']:
            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops, boxes = prutils.jittered_center_crop(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)
            crops1, boxes1 = prutils.jittered_center_crop(
                data[s + '_mask'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)
            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops]
            print("data[s + '_images'] len", len(crops))
            print("data['train_mask'] lem", len(crops1))
            #print("crops",crops1[1].shape)
            data['train_mask'] = [
                torch.from_numpy(x).view(x.shape[0], x.shape[1])
                for x in crops1
            ]
            data[s + '_anno'] = boxes
        print("data[s + '_images'] len", data['train_images'])
        print("data['train_mask'] lem", data['train_mask'])
        # Generate proposals
        #gt_mask = zip(*[self._generate_proposals(a) for a in data['test_anno']])

        # data['train_mask'] = torch.from_numpy(data['train_mask']).view(1,data['train_mask'].shape[0],data['train_mask'].shape[1])

        #data[s + '_mask'] = [torch.from_numpy(x) for x in crops1]
        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
Example 6
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'search_images', 'template_images', 'search_anno', 'template_anno'
        returns:
            TensorDict - output data block with following fields:
                'search_images', 'template_images', 'search_anno', 'template_anno'
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['search_images'], data['search_anno'] = self.transform[
                'joint'](image=data['search_images'], bbox=data['search_anno'])
            data['template_images'], data['template_anno'] = self.transform[
                'joint'](image=data['template_images'],
                         bbox=data['template_anno'],
                         new_roll=False)

        for s in ['search', 'template']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num search/template frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            if s == 'search':
                crops, boxes, _ = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.search_sz)
            elif s == 'template':
                crops, boxes, _ = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.template_area_factor, self.temp_sz)
            else:
                raise NotImplementedError

            # Apply transforms
            data[s + '_images'], data[s + '_anno'] = self.transform[s](
                image=crops, bbox=boxes, joint=False)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
        data['template_images'] = data['template_images'].squeeze()
        data['search_images'] = data['search_images'].squeeze()
        data['template_anno'] = data['template_anno'].squeeze()
        data['search_anno'] = data['search_anno'].squeeze()
        return data
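
Note that the trailing .squeeze() calls remove every size-1 dimension, not just the frame dimension; squeeze(0) is the safer spelling if single-channel data could ever pass through. A quick demonstration:

import torch

t = torch.zeros(1, 3, 128, 128)    # (frames, C, H, W)
print(t.squeeze().shape)           # torch.Size([3, 128, 128])
t1 = torch.zeros(1, 1, 128, 128)   # single-channel case
print(t1.squeeze().shape)          # torch.Size([128, 128]) - channel dim lost too
print(t1.squeeze(0).shape)         # torch.Size([1, 128, 128])
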
Example 7
    def __call__(self, data: TensorDict, *args, **kwargs):

        for s in ['train', 'test']:
            assert len(data[s + '_images']) == len(data[s + '_anno']) == 1

            # data[s + '_images'] = [self.transform[s](x) for x in crops]

            ori_shape = data[s + '_images'][0].shape
            img, img_shape, pad_shape, scale_factor = self.transform[s](
                data[s + '_images'][0],
                scale=(1333, 800),
                keep_ratio=True)
            img = to_tensor(img).unsqueeze(0)
            img_meta = [
                dict(
                    ori_shape=ori_shape,
                    img_shape=img_shape,
                    pad_shape=pad_shape,
                    scale_factor=scale_factor,
                    flip=False)
            ]

            anno = data[s + '_anno'][0]
            bbox = BoxList(anno, [ori_shape[1], ori_shape[0]], mode="xywh")

            image_sz = img_shape[1], img_shape[0]
            anno = bbox.resize(image_sz).convert("xyxy").bbox
            # confidence = torch.ones(len(anno)).to(device)
            # proposals = [torch.cat([anno.bbox, confidence[:, None]], dim=1)]

            data[s + '_images'] = [img]
            data[s + '_anno'] = [anno]
            data[s + '_img_meta'] = [img_meta]
        data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
        return data
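
This variant relies on BoxList from maskrcnn-benchmark to rescale the annotation from the original image size to the resized one and to switch from (x, y, w, h) to (x1, y1, x2, y2). A small usage sketch, assuming maskrcnn_benchmark is installed:

import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList

anno = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # one (x, y, w, h) box
bbox = BoxList(anno, (640, 480), mode="xywh")    # image size is (width, height)
resized = bbox.resize((1333, 800)).convert("xyxy")
print(resized.bbox)  # coordinates scaled by (1333/640, 800/480), now in xyxy
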
Example 8
    def __getitem__(self, index):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks
        """

        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        min_visible_frames = 2 * (self.num_test_frames + self.num_train_frames)
        enough_visible_frames = False

        # Sample a sequence with enough visible frames and get anno for the same
        while not enough_visible_frames:
            seq_id = random.randint(0, dataset.get_num_sequences() - 1)
            anno, visible = dataset.get_sequence_info(seq_id)
            num_visible = visible.type(torch.int64).sum().item()
            enough_visible_frames = not is_video_dataset or (num_visible > min_visible_frames and len(visible) >= 20)
        gap_increase = 0
        if is_video_dataset:
            train_frame_ids = None
            if self.frame_sample_mode == 'default':
                train_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames)       
            elif self.frame_sample_mode == 'causal':
                # Sample frame numbers in a causal manner, i.e. test_frame_ids > train_frame_ids
               
                base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_train_frames - 1,
                                                             max_id=len(visible)-self.num_test_frames)
                prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames - 1,
                                                              min_id=base_frame_id[0] - self.max_gap - gap_increase,
                                                              max_id=base_frame_id[0])
                while prev_frame_ids is None:
                    gap_increase += 5
                    prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames - 1,
                                                              min_id=base_frame_id[0] - self.max_gap - gap_increase,
                                                              max_id=base_frame_id[0])
                train_frame_ids = base_frame_id + prev_frame_ids
                    
            else:
                raise ValueError('Unknown frame_sample_mode.')
        else:
            train_frame_ids = [1]*self.num_train_frames
           
        # Get frames
        train_frames, train_anno, train_mask, _ = dataset.get_frames_mask(seq_id, train_frame_ids, anno)

        # Prepare data
        data = TensorDict({'train_images': train_frames,
                           'train_anno': train_anno,
                           'train_mask': train_mask,
                           'dataset': dataset.get_name()})

        # Send for processing
        return self.processing(data)
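
The _sample_visible_ids helper is the workhorse of these samplers but is not reproduced here. A plausible sketch following the pytracking pattern, drawing num_ids indices (with replacement) from the visible frames inside [min_id, max_id):

import random

def sample_visible_ids(visible, num_ids=1, min_id=None, max_id=None):
    # visible: per-frame boolean sequence. Returns num_ids frame indices
    # drawn from the visible frames in [min_id, max_id), or None if empty.
    if min_id is None or min_id < 0:
        min_id = 0
    if max_id is None or max_id > len(visible):
        max_id = len(visible)
    valid_ids = [i for i in range(min_id, max_id) if visible[i]]
    if not valid_ids:
        return None
    return random.choices(valid_ids, k=num_ids)
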
Example 9
def ltr_collate_stack1(batch):
    """Puts each data field into a tensor. The tensors are stacked at dim=1 to form the batch"""

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _check_use_shared_memory():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 1, out=out)
        # if batch[0].dim() < 4:
        #     return torch.stack(batch, 0, out=out)
        # return torch.cat(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            # if torch.utils.data.dataloader.re.search('[SaUO]', elem.dtype.str) is not None:
            # Modified by Song
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 1)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return torch.utils.data.dataloader.numpy_type_map[elem.dtype.name](
                list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], TensorDict):
        return TensorDict({
            key: ltr_collate_stack1([d[key] for d in batch])
            for key in batch[0]
        })
    elif isinstance(batch[0], collections.abc.Mapping):  # collections.abc since Python 3.3
        return {
            key: ltr_collate_stack1([d[key] for d in batch])
            for key in batch[0]
        }
    elif isinstance(batch[0], TensorList):
        transposed = zip(*batch)
        return TensorList(
            [ltr_collate_stack1(samples) for samples in transposed])
    elif isinstance(batch[0], collections.abc.Sequence):
        transposed = zip(*batch)
        return [ltr_collate_stack1(samples) for samples in transposed]
    elif batch[0] is None:
        return batch

    raise TypeError((error_msg.format(type(batch[0]))))
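
Plugged into a standard DataLoader, this collate function stacks the batch at dim=1, so image tensors come out as (num_frames, batch, C, H, W) instead of batch-first. A toy usage sketch, assuming ltr_collate_stack1 and its helpers are importable from the module above:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    # Stand-in dataset yielding fixed-size tensors.
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.zeros(3, 32, 32)

loader = DataLoader(ToyDataset(), batch_size=8, collate_fn=ltr_collate_stack1)
print(next(iter(loader)).shape)  # torch.Size([3, 8, 32, 32]) - batch at dim=1
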
Example 10
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops, boxes = prutils.jittered_center_crop(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)

            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops]
            data[s + '_anno'] = boxes

            # Random horizontal flip
            FLIP = random.random() < 0.5
            if FLIP:
                data[s + '_images'][0] = data[s + '_images'][0].flip(2)
                # Width is dim 2 of a CHW tensor (shape[1] would be the height)
                WIDTH = data[s + '_images'][0].shape[2]
                data[s + '_anno'][0][0] = WIDTH - data[
                    s + '_anno'][0][0] - data[s + '_anno'][0][2]

        # torch.set_printoptions(threshold=20000)
        # Generate train and test proposals for scaler
        train_scaler_proposals, train_scaler_labels = zip(
            *[self._generate_scaler_proposals(a) for a in data['train_anno']])
        test_scaler_proposals, test_scaler_labels = zip(
            *[self._generate_scaler_proposals(a) for a in data['test_anno']])

        data['train_scaler_proposals'], data['train_scaler_labels'] = list(
            train_scaler_proposals), list(train_scaler_labels)
        data['test_scaler_proposals'], data['test_scaler_labels'] = list(
            test_scaler_proposals), list(test_scaler_labels)

        # Generate train and test proposals for locator
        data['test_anno_jittered'] = [
            self._get_jittered_box2(a) for a in data['test_anno']
        ]
        train_locator_proposals, train_locator_labels = zip(
            *[self._generate_locator_proposals(a) for a in data['train_anno']])
        test_locator_proposals, test_locator_labels = zip(*[
            self._generate_locator_proposals(a)
            for a in data['test_anno_jittered']
        ])

        data['train_locator_proposals'], data['train_locator_labels'] = list(
            train_locator_proposals), list(train_locator_labels)
        data['test_locator_proposals'], data['test_locator_labels'] = list(
            test_locator_proposals), list(test_locator_labels)

        data['train_locator_proposals'][0] = torch.cat(
            (data['train_locator_proposals'][0], data['train_anno'][0].reshape(
                1, -1)),
            dim=0)
        data['train_locator_labels'][0] = torch.cat(
            (data['train_locator_labels'][0], torch.Tensor([1.0])), dim=0)
        data['test_locator_proposals'][0] = torch.cat(
            (data['test_locator_proposals'][0], data['test_anno'][0].reshape(
                1, -1)),
            dim=0)
        data['test_locator_labels'][0] = torch.cat(
            (data['test_locator_labels'][0], torch.Tensor([1.0])), dim=0)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
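
The horizontal flip above mirrors the x coordinate of an (x, y, w, h) box as x' = W - x - w, so the old right edge becomes the new left edge. A quick sanity check of that identity:

import torch

W = 288
box = torch.tensor([40.0, 60.0, 100.0, 50.0])  # (x, y, w, h)
x_flipped = W - box[0] - box[2]
assert x_flipped + box[2] == W - box[0]  # new right edge mirrors the old left edge
print(x_flipped)  # tensor(148.)
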
Example 11
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'
                'test_images'
                'train_anno'
                'test_anno'

        returns:
            TensorDict - output data block with following fields:
                'train_images'
                'test_images'
                'train_anno'
                'test_anno'
                'test_proposals'
                'proposal_iou'
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops, boxes = prutils_SE.jittered_center_crop_SE(
                data[s + '_images'],
                jittered_anno,
                data[s + '_anno'],
                self.search_area_factor,
                self.output_sz,
                mode=cv.BORDER_CONSTANT)
            # Apply transforms
            data[s + '_images'] = [self.transform[s](x)
                                   for x in crops]  # x : `numpy.ndarray`
            data[s + '_anno'] = boxes

            mask_crops = prutils_SE.jittered_center_crop_SE(
                data[s + '_masks'],
                jittered_anno,
                data[s + '_anno'],
                self.search_area_factor,
                self.output_sz,
                get_bbox_coord=False,
                mode=cv.BORDER_CONSTANT)
            data[s + '_masks'] = [self.mask_np2torch(x) for x in mask_crops]

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
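
mask_np2torch is referenced but not defined in the snippet; a hypothetical conversion from a numpy mask crop to a channel-first float tensor:

import numpy as np
import torch

def mask_np2torch(mask):
    # Convert an H x W (or H x W x 1) numpy mask crop into a 1 x H x W
    # float tensor.
    mask = np.asarray(mask, dtype=np.float32)
    if mask.ndim == 3:
        mask = mask[..., 0]
    return torch.from_numpy(mask).unsqueeze(0)
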
Example 12
    def __call__(self, data: TensorDict):
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            crops, boxes = prutils.jittered_center_crop_comb(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)

            data[s + '_images'] = [self.transform[s](x) for x in crops]
            data[s + '_anno'] = boxes

        if self.proposal_params:
            frame2_proposals, gt_iou = zip(*[
                self._generate_proposals(a.numpy()) for a in data['test_anno']
            ])

            data['test_proposals'] = [
                torch.tensor(p, dtype=torch.float32) for p in frame2_proposals
            ]
            data['proposal_iou'] = [
                torch.tensor(gi, dtype=torch.float32) for gi in gt_iou
            ]

        if 'is_distractor_test_frame' in data:
            data['is_distractor_test_frame'] = torch.tensor(
                data['is_distractor_test_frame'], dtype=torch.uint8)
        else:
            data['is_distractor_test_frame'] = torch.zeros(
                len(data['test_images']), dtype=torch.uint8)

        if 'is_distractor_train_frame' in data:
            data['is_distractor_train_frame'] = torch.tensor(
                data['is_distractor_train_frame'], dtype=torch.uint8)
        else:
            data['is_distractor_train_frame'] = torch.zeros(
                len(data['train_images']), dtype=torch.uint8)

        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        if self.label_function_params is not None:
            data['train_label'] = self._generate_label_function(
                data['train_anno'], data['is_distractor_train_frame'])

            data['test_label'] = self._generate_label_function(
                data['test_anno'], data['is_distractor_test_frame'])

        return data
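
Here _generate_label_function also receives the distractor flags, presumably so that frames marked as distractors get an all-zero label and the model learns to give no response on them. A hypothetical sketch of that masking step:

import torch

def mask_distractor_labels(labels, is_distractor):
    # labels: (num_frames, H, W) score maps; is_distractor: (num_frames,)
    # uint8 flags. Zero out the label of every flagged frame.
    labels = labels.clone()
    labels[is_distractor.bool()] = 0
    return labels
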
Example 13
    def forward(self, *args, **kwargs):
        data = TensorDict(kwargs)
        data = data.apply(lambda x: x[0] if isinstance(x, torch.Tensor) else x)

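        # Unpack the collated img_meta: convert each entry back to a plain
        # Python scalar (length-1 value lists collapse to a single scalar)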
        img_meta = data['train_img_meta'][0].copy()
        for key in img_meta:
            values = img_meta[key]
            for i in range(len(values)):
                if len(values) > 1:
                    img_meta[key][i] = img_meta[key][i].item()
                else:
                    img_meta[key] = img_meta[key][i].item()

        with torch.no_grad():
            self.detector.extract_feat_pre(data['train_images'])
            result = self.detector.simple_test_post(rescale=True,
                                                    img_meta=[img_meta])
            prediction = self.toBoxlist(result, 0.3)
            if prediction is None:
                box_gt = data['train_anno'].to("cpu")
                roi_box = torch.cat([box_gt])
            else:
                num_pred = len(prediction)
                top_k = min(self.top_k, num_pred)
                scores = prediction.get_field("scores")
                _, ind = torch.topk(scores, top_k)
                img_shape = img_meta['img_shape']
                boxes = prediction[ind].resize([img_shape[1],
                                                img_shape[0]]).bbox
                box_gt = data['train_anno'].to("cpu")
                roi_box = torch.cat([box_gt, boxes])

            iou = self.compute_iou(roi_box, box_gt)
            roi_ind = torch.zeros([len(roi_box), 1])

            roi1 = torch.cat([roi_ind, roi_box], dim=1).to("cuda")
            labels1 = (iou > 0.7).squeeze().float().to("cuda")
            roi_fea1 = self.detector.extract_roi_featrue(roi1)

        img_meta = data['test_img_meta'][0].copy()
        for key in img_meta:
            values = img_meta[key]
            for i in range(len(values)):
                if len(values) > 1:
                    img_meta[key][i] = img_meta[key][i].item()
                else:
                    img_meta[key] = img_meta[key][i].item()

        with torch.no_grad():
            self.detector.extract_feat_pre(data['test_images'])
            result = self.detector.simple_test_post(rescale=True,
                                                    img_meta=[img_meta])
            prediction = self.toBoxlist(result, 0.3)

            if prediction is None:
                box_gt = data['test_anno'].to("cpu")
                roi_box = torch.cat([box_gt])
            else:
                num_pred = len(prediction)
                top_k = min(self.top_k, num_pred)
                scores = prediction.get_field("scores")
                _, ind = torch.topk(scores, top_k)
                img_shape = img_meta['img_shape']
                boxes = prediction[ind].resize([img_shape[1],
                                                img_shape[0]]).bbox
                box_gt = data['test_anno'].to("cpu")
                roi_box = torch.cat([box_gt, boxes])

            iou = self.compute_iou(roi_box, box_gt)
            roi_ind = torch.zeros([len(roi_box), 1])

            roi2 = torch.cat([roi_ind, roi_box], dim=1).to("cuda")
            labels2 = (iou > 0.7).squeeze().float().to("cuda")
            roi_fea2 = self.detector.extract_roi_featrue(roi2)

        predict_scores1 = self.selector(roi_fea2[0][None, ], roi_fea1)
        predict_scores2 = self.selector(roi_fea1[0][None, ], roi_fea2)

        return predict_scores1, labels1, predict_scores2, labels2
Example 14
    def __getitem__(self, index):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks
        """

        # Select a dataset
        p_datasets = self.p_datasets

        dataset = random.choices(self.datasets, p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        num_train_frames = self.sequence_sample_info['num_train_frames']
        num_test_frames = self.sequence_sample_info['num_test_frames']
        max_train_gap = self.sequence_sample_info['max_train_gap']
        allow_missing_target = self.sequence_sample_info['allow_missing_target']
        min_fraction_valid_frames = self.sequence_sample_info.get('min_fraction_valid_frames', 0.0)

        if allow_missing_target:
            min_visible_frames = 0
        else:
            raise NotImplementedError

        valid_sequence = False

        # Sample a sequence with enough visible frames and get anno for the same
        while not valid_sequence:
            seq_id = random.randint(0, dataset.get_num_sequences() - 1)

            seq_info_dict = dataset.get_sequence_info(seq_id)
            visible = seq_info_dict['visible']
            visible_ratio = seq_info_dict.get('visible_ratio', visible)

            num_visible = visible.type(torch.int64).sum().item()

            enough_visible_frames = not is_video_dataset or (num_visible > min_visible_frames and len(visible) >= 20)

            valid_sequence = enough_visible_frames

        if self.sequence_sample_info['mode'] == 'Sequence':
            if is_video_dataset:
                train_frame_ids = None
                test_frame_ids = None
                gap_increase = 0

                test_valid_image = torch.zeros(num_test_frames, dtype=torch.int8)
                # Sample frame numbers in a causal manner, i.e. test_frame_ids > train_frame_ids
                while test_frame_ids is None:
                    occlusion_sampling = False
                    if dataset.has_occlusion_info() and self.sample_occluded_sequences:
                        target_not_fully_visible = visible_ratio < 0.9
                        if target_not_fully_visible.float().sum() > 0:
                            occlusion_sampling = True

                    if occlusion_sampling:
                        first_occ_frame = target_not_fully_visible.nonzero()[0]

                        occ_end_frame = self.find_occlusion_end_frame(first_occ_frame, target_not_fully_visible)

                        # Make sure target visible in first frame
                        base_frame_id = self._sample_ids(visible, num_ids=1, min_id=max(0, first_occ_frame - 20),
                                                         max_id=first_occ_frame - 5)

                        if base_frame_id is None:
                            base_frame_id = 0
                        else:
                            base_frame_id = base_frame_id[0]

                        prev_frame_ids = self._sample_ids(visible, num_ids=num_train_frames,
                                                          min_id=base_frame_id - max_train_gap - gap_increase - 1,
                                                          max_id=base_frame_id - 1)

                        if prev_frame_ids is None:
                            if base_frame_id - max_train_gap - gap_increase - 1 < 0:
                                prev_frame_ids = [base_frame_id] * num_train_frames
                            else:
                                gap_increase += 5
                                continue

                        train_frame_ids = prev_frame_ids

                        end_frame = min(occ_end_frame + random.randint(5, 20), len(visible) - 1)

                        if (end_frame - base_frame_id) < num_test_frames:
                            rem_frames = num_test_frames - (end_frame - base_frame_id)
                            end_frame = random.randint(end_frame, min(len(visible) - 1, end_frame + rem_frames))
                            base_frame_id = max(0, end_frame - num_test_frames + 1)

                            end_frame = min(end_frame, len(visible) - 1)

                        step_len = float(end_frame - base_frame_id) / float(num_test_frames)

                        test_frame_ids = [base_frame_id + int(x * step_len) for x in range(0, num_test_frames)]
                        test_valid_image[:len(test_frame_ids)] = 1

                        test_frame_ids = test_frame_ids + [0] * (num_test_frames - len(test_frame_ids))
                    else:
                        # Make sure target visible in first frame
                        base_frame_id = self._sample_ids(visible, num_ids=1, min_id=2*num_train_frames,
                                                         max_id=len(visible) - int(num_test_frames * min_fraction_valid_frames))
                        if base_frame_id is None:
                            base_frame_id = 0
                        else:
                            base_frame_id = base_frame_id[0]

                        prev_frame_ids = self._sample_ids(visible, num_ids=num_train_frames,
                                                          min_id=base_frame_id - max_train_gap - gap_increase - 1,
                                                          max_id=base_frame_id - 1)
                        if prev_frame_ids is None:
                            if base_frame_id - max_train_gap - gap_increase - 1 < 0:
                                prev_frame_ids = [base_frame_id] * num_train_frames
                            else:
                                gap_increase += 5
                                continue

                        train_frame_ids = prev_frame_ids

                        test_frame_ids = list(range(base_frame_id, min(len(visible), base_frame_id + num_test_frames)))
                        test_valid_image[:len(test_frame_ids)] = 1

                        test_frame_ids = test_frame_ids + [0]*(num_test_frames - len(test_frame_ids))
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

        # Get frames
        train_frames, train_anno_dict, _ = dataset.get_frames(seq_id, train_frame_ids, seq_info_dict)
        train_anno = train_anno_dict['bbox']

        test_frames, test_anno_dict, _ = dataset.get_frames(seq_id, test_frame_ids, seq_info_dict)
        test_anno = test_anno_dict['bbox']
        test_valid_anno = test_anno_dict['valid']
        test_visible = test_anno_dict['visible']
        test_visible_ratio = test_anno_dict.get('visible_ratio', torch.ones(len(test_visible)))

        # Prepare data
        data = TensorDict({'train_images': train_frames,
                           'train_anno': train_anno,
                           'test_images': test_frames,
                           'test_anno': test_anno,
                           'test_valid_anno': test_valid_anno,
                           'test_visible': test_visible,
                           'test_valid_image': test_valid_image,
                           'test_visible_ratio': test_visible_ratio,
                           'dataset': dataset.get_name()})

        # Send for processing
        return self.processing(data)
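
The find_occlusion_end_frame helper used in the occlusion-sampling branch is not shown. A plausible sketch: walk forward from the first occluded frame until the target becomes fully visible again.

def find_occlusion_end_frame(first_occ_frame, target_not_fully_visible):
    # Scan forward from the first occluded frame; return the first index
    # where the target is fully visible again, else the last frame.
    for i in range(int(first_occ_frame), len(target_not_fully_visible)):
        if not target_not_fully_visible[i]:
            return i
    return len(target_not_fully_visible) - 1
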
Example 15
    def __getitem__(self,
                    index,
                    vis_thresh=20,
                    challenge_thresh=0.5):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks
        """

        # Order dataset by difficulty
        # ids = []
        # for dataset in self.datasets:
        #     if os.path.exists("data_stats/{}.npy".format(dataset.name)):
        #         ids.append(np.load("data_stats/{}.npy".format(dataset.name)))
        #     else:
        #         ids.append(self.get_data_difficulty(dataset))

        # Select a dataset
        idx = random.choices(range(len(self.p_datasets)), self.p_datasets)[0]
        dataset = self.datasets[idx]
        # cid = ids[idx][:int(challenge_thresh * len(ids[idx]))]
        # dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        # Sample a sequence with enough visible frames
        seq_id, seq_info_dict, visible, enough_visible_frames = self.sample_seq(
            dataset, vis_thresh, is_video_dataset)
        count = 0
        if is_video_dataset:
            template_frame_ids = None
            search_frame_ids = None
            gap_increase = 0

            if self.frame_sample_mode == 'interval':
                # Sample frame numbers within interval defined by the first frame
                while search_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(visible,
                                                             num_ids=1)
                    extra_template_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_template_frames - 1,
                        min_id=base_frame_id[0] - self.max_gap - gap_increase,
                        max_id=base_frame_id[0] + self.max_gap + gap_increase)
                    if extra_template_frame_ids is None:
                        gap_increase += 5
                        continue
                    template_frame_ids = base_frame_id + extra_template_frame_ids
                    search_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_search_frames,
                        min_id=template_frame_ids[0] - self.max_gap -
                        gap_increase,
                        max_id=template_frame_ids[0] + self.max_gap +
                        gap_increase)
                    gap_increase += 5  # Increase gap until a frame is found

            elif self.frame_sample_mode == 'interval_sorted':
                # Sample frame numbers within interval defined by the first frame
                while search_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(visible,
                                                             num_ids=1)
                    extra_template_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_template_frames - 1,
                        min_id=base_frame_id[0] - self.max_gap - gap_increase,
                        max_id=base_frame_id[0] + self.max_gap + gap_increase)
                    if extra_template_frame_ids is None:
                        gap_increase += 5
                        continue
                    template_frame_ids = base_frame_id + extra_template_frame_ids
                    search_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_search_frames,
                        min_id=template_frame_ids[0] - self.max_gap -
                        gap_increase,
                        max_id=template_frame_ids[0] + self.max_gap +
                        gap_increase)
                    if template_frame_ids[0] > max(search_frame_ids):
                        search_frame_ids = sorted(search_frame_ids)
                    else:
                        search_frame_ids = sorted(
                            search_frame_ids)[::-1]  # Sort for the RNN
                    gap_increase += 5  # Increase gap until a frame is found

            elif self.frame_sample_mode == 'causal':
                # Sample search and template frames in a causal manner, i.e. search_frame_ids > template_frame_ids
                while search_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(
                        visible,
                        num_ids=1,
                        min_id=self.num_template_frames - 1,
                        max_id=len(visible) - self.num_search_frames)
                    prev_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_template_frames - 1,
                        min_id=base_frame_id[0] - self.max_gap - gap_increase,
                        max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    template_frame_ids = base_frame_id + prev_frame_ids
                    search_frame_ids = self._sample_visible_ids(
                        visible,
                        min_id=template_frame_ids[0] + 1,
                        max_id=template_frame_ids[0] + self.max_gap +
                        gap_increase,
                        num_ids=self.num_search_frames)
                    # Increase gap until a frame is found
                    gap_increase += 5
            elif self.frame_sample_mode == 'rnn_causal':
                # Sample search and template frames in a causal manner, i.e. search_frame_ids > template_frame_ids
                # visible = torch.ones_like(visible)  # Force everything visible
                while search_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(
                        visible,
                        num_ids=1,
                        min_id=self.num_template_frames - 1,
                        max_id=len(visible) - self.num_search_frames)
                    prev_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_template_frames - 1,
                        min_id=base_frame_id[0] - self.max_gap - gap_increase,
                        max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    template_frame_ids = base_frame_id + prev_frame_ids
                    # search_frame_ids = np.arange(template_frame_ids[0] + 1, template_frame_ids[0] + 1 + self.num_search_frames)  # Rather than sample just take the rest of the sequence in order.
                    # Sample from template to the next self.num_search_frames
                    search_frame_ids = self._sample_seq_ids(
                        visible,
                        min_id=template_frame_ids[0] +
                        1,  # template_frame_ids[0] + 1,
                        max_id=template_frame_ids[0] + self.max_gap +
                        gap_increase,
                        num_ids=self.num_search_frames)
                    # Increase gap until a frame is found
                    gap_increase += 5

            elif self.frame_sample_mode == 'rnn_interval':
                # Sample search and template frames in a causal manner, i.e. search_frame_ids > template_frame_ids
                # visible = torch.ones_like(visible)  # Force everything visible
                gap_increase = 0
                count = 0
                num_search_frames = self.num_search_frames
                while search_frame_ids is None:
                    base_frame_id = self._sample_visible_ids_ar(
                        visible, num_ids=1, num_search=num_search_frames)
                    extra_template_frame_ids = self._sample_visible_ids_ar(
                        visible,
                        num_ids=self.num_template_frames - 1,
                        num_search=num_search_frames,
                        min_id=base_frame_id[0] - self.max_gap - 1 -
                        gap_increase,
                        max_id=base_frame_id[0] + self.max_gap + 1 +
                        gap_increase)
                    # if extra_template_frame_ids is None:
                    #     gap_increase += 5
                    #     continue
                    template_frame_ids = base_frame_id + extra_template_frame_ids
                    # 4 cases: min -> mid, mid -> max, max -> mid, mid -> min
                    min_to_mid = 1 if (template_frame_ids[0] - self.max_gap -
                                       gap_increase) > 0 else 0
                    mid_to_max = 1 if (template_frame_ids[0] + self.max_gap +
                                       gap_increase) < len(visible) else 0
                    max_to_mid = 1 if (template_frame_ids[0] + self.max_gap +
                                       gap_increase) < len(visible) else 0
                    mid_to_min = 1 if (template_frame_ids[0] - self.max_gap -
                                       gap_increase) > 0 else 0
                    # case = random.choices(np.arange(5), [min_to_mid, mid_to_max, max_to_mid, mid_to_min, 1])[0]
                    if sum([0, mid_to_max, 0, mid_to_min]) == 0:
                        # gap_increase += 5  # Increase gap until a frame is found
                        # count += 1
                        # if count > 4:
                        # num_search_frames = num_search_frames // 2
                        # Let's just sample a new dataset.
                        seq_id, seq_info_dict, visible, enough_visible_frames = self.sample_seq(
                            dataset, vis_thresh, is_video_dataset)
                        continue

                    case = random.choices(np.arange(4),
                                          [0, mid_to_max, 0, mid_to_min])[0]
                    if case == 0:
                        search_frame_ids = self._sample_seq_ids(
                            visible,
                            min_id=template_frame_ids[0] - self.max_gap -
                            gap_increase,  # template_frame_ids[0] + 1,
                            max_id=template_frame_ids[0],
                            num_ids=num_search_frames)
                    elif case == 1:
                        search_frame_ids = self._sample_seq_ids(
                            visible,
                            min_id=template_frame_ids[
                                0],  # template_frame_ids[0] + 1,
                            max_id=template_frame_ids[0] + self.max_gap +
                            gap_increase,
                            num_ids=num_search_frames)
                    elif case == 2:
                        search_frame_ids = self._sample_seq_ids(
                            visible,
                            min_id=template_frame_ids[0] + self.max_gap +
                            gap_increase,  # template_frame_ids[0] + 1,
                            max_id=template_frame_ids[0],
                            num_ids=num_search_frames)
                    elif case == 3:
                        search_frame_ids = self._sample_seq_ids(
                            visible,
                            min_id=template_frame_ids[
                                0],  # template_frame_ids[0] + 1,
                            max_id=template_frame_ids[0] - self.max_gap -
                            gap_increase,
                            num_ids=num_search_frames)

                # if search_frame_ids is None or np.any(search_frame_ids < 0):
                #     import pdb;pdb.set_trace()

        else:
            # In case of image dataset, just repeat the image to generate synthetic video
            template_frame_ids = [1] * self.num_template_frames
            search_frame_ids = [1] * self.num_search_frames

        template_frames, template_anno, meta_obj_template = dataset.get_frames(
            seq_id, template_frame_ids, seq_info_dict)
        search_frames, search_anno, meta_obj_search = dataset.get_frames(
            seq_id, search_frame_ids, seq_info_dict)
        data = TensorDict({
            'template_images': template_frames,
            'template_anno': template_anno['bbox'],
            'search_images': search_frames,
            'search_visible': visible[search_frame_ids],
            'search_anno': search_anno['bbox']
        })
        return self.processing(data)
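
The RNN-oriented modes call _sample_seq_ids, which judging by its usage should return a contiguous run of frame ids rather than independent samples. A hedged sketch of such a helper:

import random

def sample_seq_ids(visible, num_ids, min_id=None, max_id=None):
    # Pick a random start in [min_id, max_id) and return up to num_ids
    # consecutive frame indices, clamped to the sequence length.
    lo = max(0, min_id if min_id is not None else 0)
    hi = min(len(visible), max_id if max_id is not None else len(visible))
    if hi <= lo:
        return None
    start = random.randint(lo, hi - 1)
    return list(range(start, min(start + num_ids, len(visible))))
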
Example 16
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images', 'test_images', 'train_anno', 'test_anno'
        returns:
            TensorDict - output data block with following fields:
                'train_images', 'test_images', 'train_anno', 'test_anno', 'test_proposals', 'proposal_iou'
        """

        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['train_images'], data['train_anno'] = self.transform['joint'](image=data['train_images'], bbox=data['train_anno'])
            data['test_images'], data['test_anno'] = self.transform['joint'](image=data['test_images'], bbox=data['test_anno'], new_roll=False)

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]

            # Crop image region centered at jittered_anno box
            # crops, boxes, depth_crops = prutils.jittered_center_crop(data[s + '_images'], jittered_anno, data[s + '_anno'],
            #                                                          self.search_area_factor, self.output_sz, masks=data[s + '_depths'])
            crops, boxes = prutils.jittered_center_crop(data[s + '_images'], jittered_anno, data[s + '_anno'],
                                                           self.search_area_factor, self.output_sz)
            crops_depth, boxes_depth = prutils.jittered_center_crop(data[s + '_depths'], jittered_anno, data[s + '_anno'],
                                                           self.search_area_factor, self.output_sz)
            # data[s + '_depths'] = crops_depth
            # Apply transforms
            data[s + '_images'], data[s + '_anno'] = self.transform[s](image=crops, bbox=boxes, joint=False)

            # Depth crops do not need brightness normalization
            # data[s + '_depths'], _ = self.transform[s](image=crops_depth, bbox=boxes_depth, joint=False)
            # Song: add depth; it only needs ToTensor
            if isinstance(crops_depth, (list, tuple)):
                data[s + '_depths'] = [torch.from_numpy(np.asarray(x).transpose((2, 0, 1))) for x in  crops_depth]
            else:
                crops_depth = np.asarray(crops_depth)
                if len(crops_depth.shape) == 3:
                    data[s + '_depths'] = [torch.from_numpy(np.asarray(crops_depth).transpose((2, 0, 1)))]
                elif len(crops_depth.shape) == 4:
                    data[s + '_depths'] = [torch.from_numpy(np.asarray(crops_depth).transpose((0, 3, 1, 2)))]
                else:
                    print('crops_depth has an unexpected number of dimensions, num_dim =', np.ndim(crops_depth))
                    data[s + '_depths'] = torch.from_numpy(np.asarray(crops_depth))
                    
        # Generate proposals
        frame2_proposals, gt_iou = zip(*[self._generate_proposals(a) for a in data['test_anno']])

        data['test_proposals'] = list(frame2_proposals)
        data['proposal_iou'] = list(gt_iou)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
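
The depth branch above bypasses the usual transform chain and only needs the HWC -> CHW tensor conversion. A minimal standalone sketch of that conversion (the helper name depth_to_tensor is hypothetical, not part of the original code):

import numpy as np
import torch


def depth_to_tensor(crops_depth):
    """Convert depth crops (HWC or NHWC numpy arrays) to CHW torch tensors."""
    if isinstance(crops_depth, (list, tuple)):
        return [torch.from_numpy(np.asarray(x).transpose((2, 0, 1))) for x in crops_depth]
    crops_depth = np.asarray(crops_depth)
    if crops_depth.ndim == 3:    # single crop: (H, W, C) -> (C, H, W)
        return [torch.from_numpy(crops_depth.transpose((2, 0, 1)))]
    if crops_depth.ndim == 4:    # batch of crops: (N, H, W, C) -> (N, C, H, W)
        return [torch.from_numpy(crops_depth.transpose((0, 3, 1, 2)))]
    raise ValueError('unexpected depth crop shape: {}'.format(crops_depth.shape))


# e.g. depth_to_tensor([np.zeros((288, 288, 1), dtype=np.float32)])[0].shape == (1, 288, 288)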
Esempio n. 17
0
    def __getitem__(self, index):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks
        """
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        min_visible_frames = 2 * (self.num_test_frames + self.num_train_frames)
        enough_visible_frames = False

        # Sample a sequence with enough visible frames and get anno for the same
        while not enough_visible_frames:
            seq_id = random.randint(0, dataset.get_num_sequences() - 1)
            seq_info_dict = dataset.get_sequence_info(seq_id)
            visible = seq_info_dict['visible']

            num_visible = visible.type(torch.int64).sum().item()
            # Require enough visible frames: >= 20 total frames for datasets without masks, >= 2 for datasets with masks
            enough_visible_frames = ((not is_video_dataset) and num_visible > 0) or\
                                    (num_visible > min_visible_frames and (not dataset.has_mask()) and len(visible) >= 20) or \
                                    (num_visible > min_visible_frames and (dataset.has_mask()) and len(visible) >= 2)

        if is_video_dataset:
            train_frame_ids = None
            test_frame_ids = None
            gap_increase = 0
            if self.frame_sample_mode == 'default':
                # Sample frame numbers
                while test_frame_ids is None:
                    train_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames)
                    test_frame_ids = self._sample_visible_ids(visible, min_id=train_frame_ids[0] - self.max_gap - gap_increase,
                                                              max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                                                              num_ids=self.num_test_frames)
                    gap_increase += 5   # Increase gap until a frame is found
            elif self.frame_sample_mode == 'causal':
                # Sample frame numbers in a causal manner, i.e. test_frame_ids > train_frame_ids
                while test_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_train_frames - 1,
                                                             max_id=len(visible)-self.num_test_frames)
                    prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames - 1,
                                                              min_id=base_frame_id[0] - self.max_gap - gap_increase,
                                                              max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + prev_frame_ids
                    test_frame_ids = self._sample_visible_ids(visible, min_id=train_frame_ids[0]+1,
                                                              max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                                                              num_ids=self.num_test_frames)
                    gap_increase += 5   # Increase gap until a frame is found
            else:
                raise ValueError('Unknown frame_sample_mode.')
        else:
            train_frame_ids = [1]*self.num_train_frames
            test_frame_ids = [1]*self.num_test_frames

        # Get frames
        if not dataset.has_mask():
            # standard procedure as ATOM
            train_frames, train_anno_dict, _ = dataset.get_frames(seq_id, train_frame_ids, seq_info_dict)
            train_anno = train_anno_dict['bbox']

            test_frames, test_anno_dict, _ = dataset.get_frames(seq_id, test_frame_ids, seq_info_dict)
            test_anno = test_anno_dict['bbox']

            # Prepare data
            H, W, _ = train_frames[0].shape
            data = TensorDict({'train_images': train_frames,
                               'train_masks': [np.zeros((H, W, 1))],
                               'train_anno': train_anno,  # list [(4,) torch tensor]
                               'test_images': test_frames,
                               'test_masks': [np.zeros((H, W, 1))],
                               'test_anno': test_anno,  # list [(4,) torch tensor]
                               'dataset': dataset.get_name(),
                               'mask': False})

        else:
            # if mask exists in data, process it here
            train_frames, train_masks, train_anno_dict, _ = dataset.get_frames(seq_id, train_frame_ids, seq_info_dict)
            train_anno = train_anno_dict['bbox']
            test_frames, test_masks, test_anno_dict, _ = dataset.get_frames(seq_id, test_frame_ids, seq_info_dict)
            test_anno = test_anno_dict['bbox']
            # Prepare data
            data = TensorDict({'train_images': train_frames,
                               'train_masks': train_masks,  # [ndarray (H,W,1)]
                               'train_anno': train_anno,  # list [(4,) torch tensor]
                               'test_images': test_frames,
                               'test_masks': test_masks,  # [ndarray (H,W,1)]
                               'test_anno': test_anno,  # list [(4,) torch tensor]
                               'dataset': dataset.get_name(),
                               'mask': True})
        # Send for processing
        data = self.processing(data)
        return data
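
_sample_visible_ids is used by every sampler in this listing but its body is never shown. A plausible sketch, assuming visible is a 1-D boolean torch tensor and the method returns None when no valid frame exists in the requested window:

import random
import torch


def sample_visible_ids(visible, num_ids=1, min_id=None, max_id=None):
    """Sample num_ids frame indices (with replacement) whose visible flag is set.

    A sketch of the _sample_visible_ids helper assumed by the samplers above;
    returns None when the [min_id, max_id) window contains no visible frame.
    """
    if min_id is None or min_id < 0:
        min_id = 0
    if max_id is None or max_id > len(visible):
        max_id = len(visible)

    valid_ids = [i for i in range(min_id, max_id) if visible[i]]
    if not valid_ids:
        return None
    return random.choices(valid_ids, k=num_ids)


# e.g. sample_visible_ids(torch.tensor([0, 1, 1, 0, 1], dtype=torch.bool), num_ids=2)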
Esempio n. 18
0
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
        """

        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add a uniform noise to the center pos
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            if self.crop_type == 'replicate':
                crops, boxes = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.output_sz)
            elif self.crop_type == 'nopad':
                crops, boxes = prutils.jittered_center_crop_nopad(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.output_sz)
            else:
                raise ValueError('Unknown crop type {}'.format(self.crop_type))

            data[s + '_images'] = [self.transform[s](x) for x in crops]
            boxes = torch.stack(boxes)
            # Convert (x, y, w, h) -> (x0, y0, x1, y1), clamp to the crop bounds
            # (output size 288, valid coordinates 0..287), then convert back
            boxes[:, 2:4] = boxes[:, 0:2] + boxes[:, 2:4]
            boxes = boxes.clamp(0.0, 287.0)
            boxes[:, 2:4] = boxes[:, 2:4] - boxes[:, 0:2]
            data[s + '_anno'] = boxes

        if self.proposal_params:
            frame2_proposals, gt_iou = zip(
                *[self._generate_proposals(a) for a in data['test_anno']])

            data['test_proposals'] = list(frame2_proposals)
            data['proposal_iou'] = list(gt_iou)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        test_anno = data['test_anno'].clone()
        test_anno[:, 2:4] = test_anno[:, 0:2] + test_anno[:, 2:4]
        center_288 = (test_anno[:, 0:2] + test_anno[:, 2:4]) * 0.5
        w_288 = test_anno[:, 2] - test_anno[:, 0]
        h_288 = test_anno[:, 3] - test_anno[:, 1]
        wl_288 = center_288[:, 0] - test_anno[:, 0]
        wr_288 = test_anno[:, 2] - center_288[:, 0]
        ht_288 = center_288[:, 1] - test_anno[:, 1]
        hb_288 = test_anno[:, 3] - center_288[:, 1]
        w2h2_288 = torch.stack((wl_288, wr_288, ht_288, hb_288), dim=1)  # [num_images, 4]

        boxes_72 = (data['test_anno'] * self.output_spatial_scale).float()
        # boxes is in format xywh, convert it to x0y0x1y1 format
        boxes_72[:, 2:4] = boxes_72[:, 0:2] + boxes_72[:, 2:4]

        center_float = torch.stack(((boxes_72[:, 0] + boxes_72[:, 2]) / 2.,
                                    (boxes_72[:, 1] + boxes_72[:, 3]) / 2.), dim=1)
        center_int = center_float.int().float()
        ind_72 = center_int[:, 1] * self.output_w + center_int[:, 0]  # [num_images]

        data['ind_72'] = ind_72.long()
        data['w2h2_288'] = w2h2_288
        data['w2h2_72'] = w2h2_288 * 0.25

        # Generate label functions
        if self.label_function_params is not None:
            data['train_label'] = self._generate_label_function(data['train_anno'])
            data['test_label'] = self._generate_label_function(data['test_anno'])

            data['train_label_72'] = self._generate_label_72_function(data['train_anno'])
            data['test_label_72'] = self._generate_label_72_function(data['test_anno'])

        return data
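
The w2h2_288 / ind_72 bookkeeping above reduces to two small operations: left/right/top/bottom distances of the box edges from the box center, and a flattened row-major index of that center on a coarse feature grid. A self-contained sketch of both (function names are illustrative, not from the original code):

import torch


def box_center_offsets(boxes_xyxy):
    """Per-box (left, right, top, bottom) distances from the box center to its edges."""
    center = (boxes_xyxy[:, 0:2] + boxes_xyxy[:, 2:4]) * 0.5
    wl = center[:, 0] - boxes_xyxy[:, 0]
    wr = boxes_xyxy[:, 2] - center[:, 0]
    ht = center[:, 1] - boxes_xyxy[:, 1]
    hb = boxes_xyxy[:, 3] - center[:, 1]
    return torch.stack((wl, wr, ht, hb), dim=1)  # [num_boxes, 4]


def flattened_center_index(boxes_xyxy, grid_w):
    """Row-major index of each box center on a grid_w-wide feature grid."""
    cx = ((boxes_xyxy[:, 0] + boxes_xyxy[:, 2]) / 2.).int()
    cy = ((boxes_xyxy[:, 1] + boxes_xyxy[:, 3]) / 2.).int()
    return (cy * grid_w + cx).long()


boxes = torch.tensor([[10., 20., 50., 60.]])     # (x0, y0, x1, y1)
print(box_center_offsets(boxes))                 # tensor([[20., 20., 20., 20.]])
print(flattened_center_index(boxes, grid_w=72))  # tensor([2910])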
Esempio n. 19
0
    def __getitem__(self, index):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks
        """
        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        # Sample a sequence with enough visible frames
        enough_visible_frames = False
        while not enough_visible_frames:
            # Sample a sequence
            seq_id = random.randint(0, dataset.get_num_sequences() - 1)

            # Sample frames
            seq_info_dict = dataset.get_sequence_info(seq_id)
            visible = seq_info_dict['valid']
            visible_sum = visible.sum()
            enough_visible_frames = visible_sum > 2 * (self.num_test_frames + self.num_train_frames) \
                and visible_sum >= 20

            enough_visible_frames = enough_visible_frames or not is_video_dataset

        if is_video_dataset:
            train_frame_ids = None
            test_frame_ids = None
            gap_increase = 0

            if self.frame_sample_mode == 'interval':
                # Sample frame numbers within interval defined by the first frame
                while test_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(visible,
                                                             num_ids=1)
                    extra_train_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_train_frames - 1,
                        min_id=base_frame_id[0] - self.max_gap - gap_increase,
                        max_id=base_frame_id[0] + self.max_gap + gap_increase)
                    if extra_train_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + extra_train_frame_ids
                    test_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_test_frames,
                        min_id=train_frame_ids[0] - self.max_gap -
                        gap_increase,
                        max_id=train_frame_ids[0] + self.max_gap +
                        gap_increase)
                    gap_increase += 5  # Increase gap until a frame is found

            elif self.frame_sample_mode == 'causal':
                # Sample test and train frames in a causal manner, i.e. test_frame_ids > train_frame_ids
                while test_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(
                        visible,
                        num_ids=1,
                        min_id=self.num_train_frames - 1,
                        max_id=len(visible) - self.num_test_frames)
                    prev_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_train_frames - 1,
                        min_id=base_frame_id[0] - self.max_gap - gap_increase,
                        max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + prev_frame_ids
                    test_frame_ids = self._sample_visible_ids(
                        visible,
                        min_id=train_frame_ids[0] + 1,
                        max_id=train_frame_ids[0] + self.max_gap +
                        gap_increase,
                        num_ids=self.num_test_frames)
                    # Increase gap until a frame is found
                    gap_increase += 5
        else:
            # In case of image dataset, just repeat the image to generate synthetic video
            train_frame_ids = [1] * self.num_train_frames
            test_frame_ids = [1] * self.num_test_frames

        train_frames, train_depth, train_anno = dataset.get_frames(
            seq_id, train_frame_ids, seq_info_dict)
        test_frames, test_depth, test_anno = dataset.get_frames(
            seq_id, test_frame_ids, seq_info_dict)

        data = TensorDict({
            'train_images': train_frames,
            'train_depths': train_depth,
            'train_anno': train_anno['bbox'],
            'test_images': test_frames,
            'test_depths': test_depth,
            'test_anno': test_anno['bbox'],
            'dataset': dataset.get_name()
        })
        return self.processing(data)
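
All the samplers in this listing share the same retry pattern: try to draw test frames near the train frames, and widen the allowed gap by 5 frames after every failed attempt. The pattern in isolation, as a sketch; sample_in_window stands in for self._sample_visible_ids bound to a sequence (see the sketch after Esempio n. 17):

def sample_with_growing_gap(sample_in_window, base_id, max_gap, num_ids, step=5):
    """Retry sampling with a progressively wider window around base_id.

    NOTE: like the samplers above, this loops forever if no frame is ever
    visible; Esempio n. 20 guards against that with a gap_increase > 1000 check.
    """
    gap_increase = 0
    test_ids = None
    while test_ids is None:
        test_ids = sample_in_window(min_id=base_id + 1,
                                    max_id=base_id + max_gap + gap_increase,
                                    num_ids=num_ids)
        gap_increase += step  # widen the window until frames are found
    return test_ids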
Esempio n. 20
0
    def __getitem__(self, index):
        """
        args:
            index (int): Index (dataset index)

        returns:
            TensorDict - dict containing all the data blocks
        """

        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]

        is_video_dataset = dataset.is_video_sequence()

        reverse_sequence = False
        if self.p_reverse is not None:
            reverse_sequence = random.random() < self.p_reverse

        # Sample a sequence with enough visible frames
        enough_visible_frames = False
        while not enough_visible_frames:
            # Sample a sequence
            seq_id = random.randint(0, dataset.get_num_sequences() - 1)

            # Sample frames
            seq_info_dict = dataset.get_sequence_info(seq_id)
            visible = seq_info_dict['visible']

            enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (self.num_test_frames + self.num_train_frames)

            enough_visible_frames = enough_visible_frames or not is_video_dataset

        if is_video_dataset:
            train_frame_ids = None
            test_frame_ids = None
            gap_increase = 0

            # Sample test and train frames in a causal manner, i.e. test_frame_ids > train_frame_ids
            while test_frame_ids is None:
                if gap_increase > 1000:
                    raise Exception('Frame not found')

                if not reverse_sequence:
                    base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_train_frames - 1,
                                                             max_id=len(visible)-self.num_test_frames)
                    prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames - 1,
                                                              min_id=base_frame_id[0] - self.max_gap - gap_increase,
                                                              max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + prev_frame_ids
                    test_frame_ids = self._sample_visible_ids(visible, min_id=train_frame_ids[0]+1,
                                                              max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                                                              num_ids=self.num_test_frames)

                    # Increase gap until a frame is found
                    gap_increase += 5
                else:
                    # Sample in reverse order, i.e. train frames come after the test frames
                    base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_test_frames + 1,
                                                             max_id=len(visible) - self.num_train_frames - 1)
                    prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames - 1,
                                                              min_id=base_frame_id[0],
                                                              max_id=base_frame_id[0] + self.max_gap + gap_increase)
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + prev_frame_ids
                    test_frame_ids = self._sample_visible_ids(visible, min_id=0,
                                                              max_id=train_frame_ids[0] - 1,
                                                              num_ids=self.num_test_frames)

                    # Increase gap until a frame is found
                    gap_increase += 5
        else:
            # In case of image dataset, just repeat the image to generate synthetic video
            train_frame_ids = [1]*self.num_train_frames
            test_frame_ids = [1]*self.num_test_frames

        # Sort frames
        train_frame_ids = sorted(train_frame_ids, reverse=reverse_sequence)
        test_frame_ids = sorted(test_frame_ids, reverse=reverse_sequence)

        all_frame_ids = train_frame_ids + test_frame_ids

        # Load frames
        all_frames, all_anno, meta_obj = dataset.get_frames(seq_id, all_frame_ids, seq_info_dict)

        train_frames = all_frames[:len(train_frame_ids)]
        test_frames = all_frames[len(train_frame_ids):]

        train_anno = {}
        test_anno = {}
        for key, value in all_anno.items():
            train_anno[key] = value[:len(train_frame_ids)]
            test_anno[key] = value[len(train_frame_ids):]

        train_masks = train_anno['mask'] if 'mask' in train_anno else None
        test_masks = test_anno['mask'] if 'mask' in test_anno else None

        data = TensorDict({'train_images': train_frames,
                           'train_masks': train_masks,
                           'train_anno': train_anno['bbox'],
                           'test_images': test_frames,
                           'test_masks': test_masks,
                           'test_anno': test_anno['bbox'],
                           'dataset': dataset.get_name()})

        return self.processing(data)
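
Loading train and test frames with a single get_frames call and then slicing, as done above, keeps the per-key annotation lists aligned with the frame lists. The split in isolation (a sketch of the slicing logic above):

def split_anno(all_anno, num_train):
    """Split per-frame annotation lists into train/test parts, key by key."""
    train_anno = {key: value[:num_train] for key, value in all_anno.items()}
    test_anno = {key: value[num_train:] for key, value in all_anno.items()}
    return train_anno, test_anno


anno = {'bbox': [1, 2, 3, 4, 5], 'visible': [True, True, False, True, True]}
train, test = split_anno(anno, num_train=3)
# train == {'bbox': [1, 2, 3], 'visible': [True, True, False]}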
Esempio n. 21
0
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -

        returns:
            TensorDict - output data block with following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
                'test_proposals' (optional) -
                'proposal_iou'  (optional)  -
                'test_label' (optional)     -
                'train_label' (optional)    -
        """

        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

            # e.g. all_images[0].shape == (480, 640, 4) here (4-channel input)

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add a uniform noise to the center pos
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            if self.crop_type == 'replicate':
                crops, boxes = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.output_sz)
            elif self.crop_type == 'nopad':
                crops, boxes = prutils.jittered_center_crop_nopad(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.output_sz)
            else:
                raise ValueError('Unknown crop type {}'.format(self.crop_type))

            # e.g. crops[0].shape == (288, 288, 4)
            data[s + '_images'] = [self.transform[s](x.astype(np.float32)) for x in crops]
            data[s + '_anno'] = boxes

        # Generate proposals
        if self.proposal_params:
            frame2_proposals, gt_iou = zip(
                *[self._generate_proposals(a) for a in data['test_anno']])

            data['test_proposals'] = list(frame2_proposals)
            data['proposal_iou'] = list(gt_iou)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        # Generate label functions
        if self.label_function_params is not None:
            data['train_label'] = self._generate_label_function(
                data['train_anno'])
            data['test_label'] = self._generate_label_function(
                data['test_anno'])

        return data
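
_generate_proposals is called here and in Esempio n. 16/18 but never shown. In ATOM-style processing classes it typically perturbs the ground-truth box with Gaussian noise and records the IoU of each perturbed box; a minimal sketch under that assumption (the real method is driven by proposal_params and may resample until a target IoU is hit):

import torch


def iou_xywh(a, b):
    """IoU of two (x, y, w, h) boxes given as 1-D tensors."""
    ax0, ay0, ax1, ay1 = a[0], a[1], a[0] + a[2], a[1] + a[3]
    bx0, by0, bx1, by1 = b[0], b[1], b[0] + b[2], b[1] + b[3]
    iw = torch.clamp(torch.min(ax1, bx1) - torch.max(ax0, bx0), min=0)
    ih = torch.clamp(torch.min(ay1, by1) - torch.max(ay0, by0), min=0)
    inter = iw * ih
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union


def generate_proposals(box, num_proposals=16, sigma=0.2):
    """Jitter box (x, y, w, h) with size-scaled Gaussian noise; return proposals and IoUs."""
    proposals = torch.zeros(num_proposals, 4)
    gt_iou = torch.zeros(num_proposals)
    for i in range(num_proposals):
        noise = torch.randn(4) * sigma * box[2:4].repeat(2)
        proposals[i] = box + noise
        proposals[i, 2:4] = torch.clamp(proposals[i, 2:4], min=1.0)  # keep sizes positive
        gt_iou[i] = iou_xywh(box, proposals[i])
    return proposals, gt_iou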
Esempio n. 22
0
    def __getitem__(self, index):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks
        """

        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        # Sample a sequence with enough visible frames
        enough_visible_frames = False
        while not enough_visible_frames:
            # Sample a sequence
            seq_id = random.randint(0, dataset.get_num_sequences() - 1)

            # Sample frames
            seq_info_dict = dataset.get_sequence_info(seq_id)
            visible = seq_info_dict['visible']

            enough_visible_frames = visible.type(torch.int64).sum().item() > \
                2 * (self.num_test_frames + self.num_train_frames) and len(visible) >= 20

            enough_visible_frames = enough_visible_frames or not is_video_dataset

        if is_video_dataset:
            train_frame_ids = None
            test_frame_ids = None
            gap_increase = 0

            # Sample test and train frames in a causal manner, i.e. test_frame_ids > train_frame_ids
            while test_frame_ids is None:
                base_frame_id = self._sample_visible_ids(
                    visible,
                    num_ids=1,
                    min_id=self.num_train_frames - 1,
                    max_id=len(visible) - self.num_test_frames)
                prev_frame_ids = self._sample_visible_ids(
                    visible,
                    num_ids=self.num_train_frames - 1,
                    min_id=base_frame_id[0] - self.max_gap - gap_increase,
                    max_id=base_frame_id[0])
                if prev_frame_ids is None:
                    gap_increase += 5
                    continue
                train_frame_ids = base_frame_id + prev_frame_ids
                test_frame_ids = self._sample_visible_ids(
                    visible,
                    min_id=train_frame_ids[0] + 1,
                    max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                    num_ids=self.num_test_frames)
                # Increase gap until a frame is found
                gap_increase += 5
        else:
            # In case of image dataset, just repeat the image to generate synthetic video
            train_frame_ids = [1] * self.num_train_frames
            test_frame_ids = [1] * self.num_test_frames

        train_frames, train_anno, meta_obj_train = dataset.get_frames(
            seq_id, train_frame_ids, seq_info_dict)
        test_frames, test_anno, meta_obj_test = dataset.get_frames(
            seq_id, test_frame_ids, seq_info_dict)

        # e.g. train_frames[0].shape == (480, 640, 3)

        data = TensorDict({
            'train_images': train_frames,
            'train_anno': train_anno['bbox'],
            'test_images': test_frames,
            'test_anno': test_anno['bbox'],
            'dataset': dataset.get_name()
        })

        return self.processing(data)
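
data.apply(stack_tensors) (or prutils.stack_tensors) in 'sequence' mode turns the per-frame lists into batched tensors. A sketch of that helper, assuming it passes non-tensor values through untouched:

import torch


def stack_tensors(x):
    """Stack a list of equal-shaped tensors into one tensor; pass everything else through."""
    if isinstance(x, (list, tuple)) and len(x) > 0 and isinstance(x[0], torch.Tensor):
        return torch.stack(x)
    return x


frames = [torch.zeros(3, 288, 288) for _ in range(4)]
print(stack_tensors(frames).shape)   # torch.Size([4, 3, 288, 288])
print(stack_tensors('got10k'))       # non-tensor values are returned unchanged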
Esempio n. 23
0
    def __getitem__(self, index):
        """
        Args:
            index (int): Index (Ignored since we sample randomly)

        Returns:
            TensorDict - dict containing all the data blocks
        """

        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        enough_visible_frames = False
        # TODO clean this part
        while not enough_visible_frames:
            # Select a class
            if self.sample_mode == 'sequence':
                while not enough_visible_frames:
                    # Sample a sequence
                    seq_id = random.randint(0, dataset.get_num_sequences() - 1)
                    # Sample frames
                    seq_info_dict = dataset.get_sequence_info(seq_id)
                    visible = seq_info_dict['visible']

                    enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (self.num_seq_test_frames + self.num_seq_train_frames) and \
                        len(visible) >= 20

                    enough_visible_frames = enough_visible_frames or not is_video_dataset
                if self.use_class_info:
                    class_name = dataset.get_class_name(seq_id)
                    class_sequences = dataset.get_sequences_in_class(
                        class_name)
            elif self.sample_mode == 'class':
                class_name = random.choices(dataset.get_class_list())[0]
                class_sequences = dataset.get_sequences_in_class(class_name)

                # Sample test frames from the sequence
                try_ct = 0
                while not enough_visible_frames and try_ct < 5:
                    # Sample a sequence
                    seq_id = random.choices(class_sequences)[0]
                    # Sample frames
                    seq_info_dict = dataset.get_sequence_info(seq_id)
                    visible = seq_info_dict['visible']

                    # TODO probably filter sequences where we don't have enough visible frames in a pre-processing step
                    #  so that we are not stuck in a while loop
                    enough_visible_frames = visible.type(torch.int64).sum().item() > self.num_seq_test_frames + \
                                            self.num_seq_train_frames
                    enough_visible_frames = enough_visible_frames or not is_video_dataset
                    try_ct += 1
            else:
                raise ValueError

        if is_video_dataset:
            train_frame_ids = None
            test_frame_ids = None
            gap_increase = 0
            if self.frame_sample_mode == 'default':
                while test_frame_ids is None:
                    train_frame_ids = self._sample_visible_ids(
                        visible, num_ids=self.num_seq_train_frames)
                    test_frame_ids = self._sample_visible_ids(
                        visible,
                        min_id=train_frame_ids[0] - self.max_gap -
                        gap_increase,
                        max_id=train_frame_ids[0] + self.max_gap +
                        gap_increase,
                        num_ids=self.num_seq_test_frames)
                    gap_increase += 5  # Increase gap until a frame is found
            elif self.frame_sample_mode == 'causal':
                while test_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(
                        visible,
                        num_ids=1,
                        min_id=self.num_seq_train_frames - 1,
                        max_id=len(visible) - self.num_seq_test_frames)
                    prev_frame_ids = self._sample_visible_ids(
                        visible,
                        num_ids=self.num_seq_train_frames - 1,
                        min_id=base_frame_id[0] - self.max_gap - gap_increase,
                        max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + prev_frame_ids
                    test_frame_ids = self._sample_visible_ids(
                        visible,
                        min_id=train_frame_ids[0] + 1,
                        max_id=train_frame_ids[0] + self.max_gap +
                        gap_increase,
                        num_ids=self.num_seq_test_frames)
                    gap_increase += 5  # Increase gap until a frame is found
            else:
                raise ValueError('Unknown frame_sample_mode.')
        else:
            train_frame_ids = [1] * self.num_seq_train_frames
            test_frame_ids = [1] * self.num_seq_test_frames

        seq_train_frames, seq_train_anno, meta_obj_train = dataset.get_frames(
            seq_id, train_frame_ids, seq_info_dict)

        seq_test_frames, seq_test_anno, meta_obj_test = dataset.get_frames(
            seq_id, test_frame_ids, seq_info_dict)

        assert meta_obj_train['object_class'] == meta_obj_test['object_class'], \
            "Train and test classes don't match!"

        # Sample from sequences with the same class
        # TODO fix sequences which do not have a single visible frame
        if self.use_class_info and len(class_sequences) > 5:
            cls_dist_train_frames, cls_dist_train_anno, cls_dist_test_frames, cls_dist_test_anno = \
                self._sample_class_distractors(dataset, seq_id, class_sequences, self.num_class_distractor_frames,
                self.num_class_distractor_train_frames)
            num_rnd_distractors = self.num_random_distractor_frames
            num_rnd_train_distractors = self.num_random_distractor_train_frames
        else:
            cls_dist_train_frames = []
            cls_dist_train_anno = []
            cls_dist_test_frames = []
            cls_dist_test_anno = []
            num_rnd_distractors = self.num_random_distractor_frames + self.num_class_distractor_frames
            num_rnd_train_distractors = self.num_random_distractor_train_frames + self.num_class_distractor_train_frames

        # Sample sequences from any class
        rnd_dist_train_frames, rnd_dist_train_anno, rnd_dist_test_frames, rnd_dist_test_anno = \
            self._sample_random_distractors(num_rnd_distractors, num_rnd_train_distractors)

        train_frames = seq_train_frames + cls_dist_train_frames + rnd_dist_train_frames
        test_frames = seq_test_frames + cls_dist_test_frames + rnd_dist_test_frames

        train_anno = self._dict_cat(seq_train_anno, cls_dist_train_anno,
                                    rnd_dist_train_anno)
        test_anno = self._dict_cat(seq_test_anno, cls_dist_test_anno,
                                   rnd_dist_test_anno)

        is_distractor_train_frame = [False] * self.num_seq_train_frames + \
            [True] * (self.num_class_distractor_train_frames + self.num_random_distractor_train_frames)
        is_distractor_test_frame = [False] * self.num_seq_test_frames + \
            [True] * (self.num_class_distractor_frames + self.num_random_distractor_frames)
        if self.parent_class_list:
            if meta_obj_train['object_class']:
                parent_class = self.map_parent[meta_obj_train['object_class']]
                parent_class_id = self.parent_class_list.index(parent_class)
            else:
                parent_class = None
                parent_class_id = -1
        else:
            parent_class = None
            parent_class_id = -1

        # TODO send in class name for each frame
        data = TensorDict({
            'train_images': train_frames,
            'train_anno': train_anno['bbox'],
            'test_images': test_frames,
            'test_anno': test_anno['bbox'],
            'object_class_id': parent_class_id,
            'motion_class': meta_obj_train['motion_class'],
            'major_class': meta_obj_train['major_class'],
            'root_class': meta_obj_train['root_class'],
            'motion_adverb': meta_obj_train['motion_adverb'],
            'object_class_name': meta_obj_train['object_class'],
            # 'object_class_name': parent_class,
            'dataset': dataset.get_name(),
            'is_distractor_train_frame': is_distractor_train_frame,
            'is_distractor_test_frame': is_distractor_test_frame
        })

        return self.processing(data)
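
_dict_cat above merges the per-key annotation lists of the sequence frames with those of the distractor frames. A plausible sketch, assuming every dict holds lists keyed the same way:

def dict_cat(*dicts):
    """Concatenate the list values of several annotation dicts, key by key (sketch of _dict_cat)."""
    keys = dicts[0].keys()
    return {key: sum((list(d[key]) for d in dicts), []) for key in keys}


a = {'bbox': [1, 2], 'visible': [True, True]}
b = {'bbox': [3], 'visible': [False]}
print(dict_cat(a, b))  # {'bbox': [1, 2, 3], 'visible': [True, True, False]}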
Esempio n. 24
0
    def __getitem__(self, index):
        """
        args:
            index (int): Index (Ignored since we sample randomly)

        returns:
            TensorDict - dict containing all the data blocks
        """

        # Select a dataset
        dataset = random.choices(self.datasets, self.p_datasets)[0]
        is_video_dataset = dataset.is_video_sequence()

        # Sample a sequence with enough visible frames
        enough_visible_frames = False
        while not enough_visible_frames:
            # Sample a sequence
            seq_id = random.randint(0, dataset.get_num_sequences() - 1)

            # Sample frames
            seq_info_dict = dataset.get_sequence_info(seq_id)
            visible = seq_info_dict['visible']

            enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (
                    self.num_test_frames + self.num_train_frames) and len(visible) >= 20

            enough_visible_frames = enough_visible_frames or not is_video_dataset

        if is_video_dataset:
            train_frame_ids = None
            test_frame_ids = None
            gap_increase = 0

            if self.frame_sample_mode == 'interval':
                # Sample frame numbers within interval defined by the first frame
                while test_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(visible, num_ids=1)
                    extra_train_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames - 1,
                                                                     min_id=base_frame_id[0] - self.max_gap - gap_increase,
                                                                     max_id=base_frame_id[0] + self.max_gap + gap_increase)
                    if extra_train_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + extra_train_frame_ids
                    test_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_test_frames,
                                                              min_id=train_frame_ids[0] - self.max_gap - gap_increase,
                                                              max_id=train_frame_ids[0] + self.max_gap + gap_increase)
                    gap_increase += 5  # Increase gap until a frame is found

            elif self.frame_sample_mode == 'causal':
                # Sample test and train frames in a causal manner, i.e. test_frame_ids > train_frame_ids
                while test_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_train_frames - 1,
                                                             max_id=len(visible) - self.num_test_frames)
                    prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames - 1,
                                                              min_id=base_frame_id[0] - self.max_gap - gap_increase,
                                                              max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + prev_frame_ids
                    test_frame_ids = self._sample_visible_ids(visible, min_id=train_frame_ids[0] + 1,
                                                              max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                                                              num_ids=self.num_test_frames)
                    # Increase gap until a frame is found
                    gap_increase += 5
                    
            elif self.frame_sample_mode == '3d':
                # Sample test and train frames in a causal manner, i.e. test_frame_ids > train_frame_ids
                while test_frame_ids is None:
                    base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=4,
                                                             max_id=len(visible) - self.num_test_frames)
                    prev_frame_ids = self._sample_visible_ids(visible, num_ids=self.num_train_frames - 1,
                                                              min_id=max(4,base_frame_id[0] - self.max_gap - gap_increase),
                                                              max_id=base_frame_id[0])
                    if prev_frame_ids is None:
                        gap_increase += 5
                        continue
                    train_frame_ids = base_frame_id + prev_frame_ids
                    test_frame_ids = self._sample_visible_ids(visible, min_id=max(4, train_frame_ids[0] + 1),
                                                              max_id=train_frame_ids[0] + self.max_gap + gap_increase,
                                                              num_ids=self.num_test_frames)
                    # Increase gap until a frame is found
                    gap_increase += 5
                    if test_frame_ids is None:
                        continue
                    
                    # Expand each sampled frame id into a clip of 4 consecutive frames ending at that id
                    train_frame_ids = [fid + off for fid in train_frame_ids for off in (-3, -2, -1, 0)]
                    test_frame_ids = [fid + off for fid in test_frame_ids for off in (-3, -2, -1, 0)]
        else:
            # In case of image dataset, just repeat the image to generate synthetic video
            train_frame_ids = [1] * self.num_train_frames
            test_frame_ids = [1] * self.num_test_frames

        train_frames, train_anno, meta_obj_train = dataset.get_frames(seq_id, train_frame_ids, seq_info_dict)
        test_frames, test_anno, meta_obj_test = dataset.get_frames(seq_id, test_frame_ids, seq_info_dict)

        data = TensorDict({'train_images': train_frames,
                           'train_anno': train_anno['bbox'][self.clip_len-1::self.clip_len],
                           'test_images': test_frames,
                           'test_anno': test_anno['bbox'][self.clip_len-1::self.clip_len],
                           'dataset': dataset.get_name(),
                           'test_class': meta_obj_test.get('object_class_name')})

        return self.processing(data)
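
The '3d' mode expands every sampled frame id into a clip of clip_len consecutive frames ending at that id (hence the min_id=4 constraint, so a clip never underflows the sequence), and the annotation is then taken from the last frame of each clip. The expansion and the matching slice in isolation (a sketch, assuming clip_len == 4 as in the code above):

def expand_to_clips(frame_ids, clip_len=4):
    """Expand each frame id into clip_len consecutive ids ending at that frame."""
    return [fid + off for fid in frame_ids for off in range(1 - clip_len, 1)]


frame_ids = [10, 25]
clip_ids = expand_to_clips(frame_ids)  # [7, 8, 9, 10, 22, 23, 24, 25]
# Annotation for the last frame of each clip, matching the sampler's slicing:
anno = list(range(len(clip_ids)))
print(anno[4 - 1::4])                  # [3, 7] -> one entry per clip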
Esempio n. 25
0
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'search_images', 'template_images', 'search_anno', 'template_anno'
        returns:
            TensorDict - output data block with following fields:
                'search_images', 'template_images', 'search_anno', 'template_anno'
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['search_images'], data['search_anno'] = self.transform['joint'](
                image=data['search_images'], bbox=data['search_anno'])
            data['template_images'], data['template_anno'] = self.transform['joint'](
                image=data['template_images'], bbox=data['template_anno'], new_roll=False)

        # self.label_function_params = {"kernel_sz": 4, "feature_sz": 256, "output_sz": self.search_sz, "end_pad_if_even": False, "sigma_factor": 0.05}
        for s in ['search', 'template']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num search/template frames must be 1"

            # Add a uniform noise to the center pos
            if self.rand:
                rand_size_a = torch.randn(2)
                rand_center_a = torch.rand(2)
                rand_size_b = torch.randn(2)
                rand_center_b = torch.rand(2)

                # Linearly interpolate the jitter from rand_size/center a to b across the sequence
                size_step = torch.tensor(np.linspace(rand_size_a, rand_size_b, len(data[s + '_anno'])))
                center_step = torch.tensor(np.linspace(rand_center_a, rand_center_b, len(data[s + '_anno'])))
                jittered_anno = [self._get_jittered_box(a, s, rand_size=rs, rand_center=rc)
                                 for a, rs, rc in zip(data[s + '_anno'], size_step, center_step)]
            else:
                jittered_anno = [
                    self._get_jittered_box(a, s) for a in data[s + '_anno']
                ]

            # Crop image region centered at jittered_anno box

            if s == 'search':
                if torch.any(data['search_visible'] == 0):
                    # For empty annos, use the most recent crop box coords.
                    filler_anno = jittered_anno[0]
                    # assert filler_anno.sum(), "First frame was empty."  # Only last frame matters
                    filler_jitter = data[s + '_anno'][0]
                    for mi in range(len(data['search_visible'])):
                        if data['search_visible'][mi] == 0:
                            jittered_anno[mi] = filler_anno
                            data[s + '_anno'][mi] = filler_jitter
                        else:
                            filler_anno = jittered_anno[mi]
                            filler_jitter = data[s + '_anno'][mi]
                crops, boxes, _ = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.search_sz)
            elif s == 'template':
                crops, boxes, _ = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.template_area_factor, self.temp_sz)
            else:
                raise NotImplementedError
            # boxes format: (column, row, width, height), i.e. (x, y, w, h)

            # Apply transforms
            if s == "search" and self.occlusion:
                maybe_occlusion = np.random.rand() > 0.5
                crops = list(crops)
                min_size = 1  # 10
                min_frames = 7  # 10  # When should the occlusion start
                if maybe_occlusion:
                    crop_len = len(crops)
                    rand_frames_start = np.random.randint(low=min_frames,
                                                          high=crop_len)
                    rand_frames_len = crop_len - rand_frames_start
                    top_side = rand_frames_start % 2

                    # Find the box in the first occluded frame, and use this to construct the occluder
                    start_box = boxes[rand_frames_start].int()
                    crop_shape = crops[0].shape  # data[s + '_images'][0].shape
                    apply_occlusion = False
                    pass_check = (start_box[2] // 2 > min_size and start_box[3] // 2 > min_size
                                  and crops[0].shape[0] > min_size and crops[0].shape[1] > min_size)
                    if top_side and pass_check:
                        # These are row inds
                        rand_start = np.random.randint(low=0, high=start_box[3] - min_size - 1)
                        if rand_start > start_box[3] // 2:
                            margin = np.copy(rand_start)
                            rand_start = np.random.randint(low=0, high=margin - min_size)
                            rand_extent = margin - rand_start
                        else:
                            remainder = np.maximum(start_box[3] - rand_start, min_size)
                            mc, xc = np.minimum(rand_start, remainder), np.maximum(rand_start, remainder)
                            if mc == xc:
                                xc += 1
                                rand_extent = mc + 1
                            else:
                                rand_extent = np.random.randint(low=mc, high=xc)

                        rand_start += start_box[1]
                        if rand_start + rand_extent < crops[0].shape[0] and rand_start > 0:
                            apply_occlusion = True
                    elif not top_side and pass_check:
                        # These are width inds
                        rand_start = np.random.randint(low=0, high=start_box[2] - min_size - 1)
                        if rand_start > start_box[2] // 2:
                            margin = np.copy(rand_start)
                            rand_start = np.random.randint(low=0, high=margin - min_size)
                            rand_extent = margin - rand_start
                        else:
                            remainder = np.maximum(start_box[2] - rand_start, min_size)  # box width (this branch indexes columns)
                            mc, xc = np.minimum(rand_start, remainder), np.maximum(rand_start, remainder)
                            if mc == xc:
                                xc += 1
                                rand_extent = mc + 1
                            else:
                                rand_extent = np.random.randint(low=mc, high=xc)

                        rand_start += start_box[0]
                        if rand_start + rand_extent < crops[0].shape[1] and rand_start > 0:
                            apply_occlusion = True
                    if apply_occlusion:
                        for bidx in range(rand_frames_start, crop_len):
                            # Occlude by shuffling the pixels inside the chosen band of the crop
                            if top_side:
                                shuffle_box = crops[bidx][rand_start:rand_start + rand_extent]
                                shuffle_shape = shuffle_box.shape
                                shuffle_box = shuffle_box.reshape(-1, shuffle_shape[-1])  # channels last
                                shuffle_box = shuffle_box[np.random.permutation(shuffle_shape[0] * shuffle_shape[1])]
                                crops[bidx][rand_start:rand_start + rand_extent] = shuffle_box.reshape(shuffle_shape)
                            else:
                                shuffle_box = crops[bidx][:, rand_start:rand_start + rand_extent]
                                shuffle_shape = shuffle_box.shape
                                shuffle_box = shuffle_box.reshape(-1, shuffle_shape[-1])  # channels last
                                shuffle_box = shuffle_box[np.random.permutation(shuffle_shape[0] * shuffle_shape[1])]
                                crops[bidx][:, rand_start:rand_start + rand_extent] = shuffle_box.reshape(shuffle_shape)
            data[s + '_images'], data[s + '_anno'] = self.transform[s](image=crops, bbox=boxes, joint=self.joint)
            if s == "search":
                # Binary "bump" map: 1 inside the (cropped) target box, 0 elsewhere
                im_shape = [len(data[s + '_images']), 1] + list(data[s + '_images'][0].shape[1:])
                bumps = torch.zeros(im_shape, device=data[s + '_images'][0].device).float()
                for bidx in range(bumps.shape[0]):
                    box = boxes[bidx].int()
                    bumps[bidx, :, box[1]:box[1] + box[3], box[0]:box[0] + box[2]] = 1
                data["bump"] = bumps
        self.prev_annos = jittered_anno

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
        data['template_images'] = data['template_images'].squeeze()
        data['search_images'] = data['search_images'].squeeze()
        data['template_anno'] = data['template_anno'].squeeze()
        data['search_anno'] = data['search_anno'].squeeze()
        return data
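
The occlusion augmentation above picks a horizontal or vertical band through the target box and, from a random frame onwards, destroys it by shuffling its pixels. Stripped of the band-sampling logic, the core operation is this (a sketch on a single numpy HWC crop; the helper name is illustrative):

import numpy as np


def shuffle_band(crop, start, extent, rows=True):
    """Scramble the pixels of a row band (rows=True) or column band of an HWC crop, in place."""
    band = crop[start:start + extent] if rows else crop[:, start:start + extent]
    shape = band.shape
    flat = band.reshape(-1, shape[-1])                       # channels last
    flat = flat[np.random.permutation(shape[0] * shape[1])]  # shuffle pixel positions
    if rows:
        crop[start:start + extent] = flat.reshape(shape)
    else:
        crop[:, start:start + extent] = flat.reshape(shape)
    return crop


crop = np.arange(288 * 288 * 3, dtype=np.float32).reshape(288, 288, 3)
shuffle_band(crop, start=100, extent=40, rows=True)  # occlude rows 100..139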
Esempio n. 26
0
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
                'train_masks'   -
                'test_masks'    -

        returns:
            TensorDict - output data block with following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
                'train_masks'   -
                'test_masks'    -
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        # extract patches from images
        for s in ['test', 'train']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add a uniform noise to the center pos
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops_img, boxes = prutils.jittered_center_crop(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)

            # Crop mask region centered at jittered_anno box
            crops_mask, _ = prutils.jittered_center_crop(
                data[s + '_masks'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz, pad_val=0.0)

            if s == 'train' and self.use_distance:
                # Distance map: Euclidean distance of every pixel in the crop
                # from the target center (note: computed during the 'train'
                # pass but stored as 'test_dist')
                cx_ = (boxes[0][0] + boxes[0][2] / 2).item()
                cy_ = (boxes[0][1] + boxes[0][3] / 2).item()
                x_ = np.arange(crops_img[0].shape[1]) - cx_
                y_ = np.arange(crops_img[0].shape[0]) - cy_
                X, Y = np.meshgrid(x_, y_)
                D = np.sqrt(np.square(X) + np.square(Y)).astype(np.float32)

                data['test_dist'] = [torch.from_numpy(np.expand_dims(D, axis=0))]

            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops_img]
            data[s + '_anno'] = boxes
            data[s + '_masks'] = [
                torch.from_numpy(np.expand_dims(x, axis=0)) for x in crops_mask
            ]

            if s == 'train':
                data[s + '_init_masks'] = [
                    torch.from_numpy(
                        np.expand_dims(self._make_aabb_mask(x_.shape, bb_),
                                       axis=0))
                    for x_, bb_ in zip(crops_mask, boxes)
                ]

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
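The init masks above rely on self._make_aabb_mask, which is not shown in this snippet. Below is a plausible sketch, assuming it rasterizes an axis-aligned (x, y, w, h) box into a binary float32 mask; the name and signature mirror the call site, but the body is an assumption, not the original implementation.

import numpy as np

def make_aabb_mask(mask_shape, box):
    # Hypothetical stand-in for self._make_aabb_mask: fill the axis-aligned
    # (x, y, w, h) box region of an (H, W) mask with ones
    m = np.zeros(mask_shape[:2], dtype=np.float32)
    x, y, w, h = [int(round(float(v))) for v in box]
    m[max(y, 0):y + h, max(x, 0):x + w] = 1.0
    return m

print(make_aabb_mask((288, 288), (40, 60, 80, 50)).sum())  # 80 * 50 = 4000.0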
Example no. 27
0
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
                'train_masks'   -
                'test_masks'    -

        returns:
            TensorDict - output data block with following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
                'train_masks'   -
                'test_masks'    -
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        # extract patches from images
        for s in ['test', 'train']:  # 'test' first: the 'train' branch below may overwrite test_* fields
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add a uniform noise to the center pos
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]

            # Crop image region centered at jittered_anno box
            crops_img, boxes = prutils.jittered_center_crop(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)

            # Crop mask region centered at jittered_anno box
            crops_mask, _ = prutils.jittered_center_crop(
                data[s + '_masks'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz, pad_val=0.0)

            if s == 'test' and self.use_distance:
                # Distance map from a jittered target center (the center is
                # perturbed by up to +/- 12.5% of the box size in each direction)
                cx_ = (boxes[0][0] + boxes[0][2] / 2).item() + 0.25 * boxes[0][2].item() * (random.random() - 0.5)
                cy_ = (boxes[0][1] + boxes[0][3] / 2).item() + 0.25 * boxes[0][3].item() * (random.random() - 0.5)
                x_ = np.arange(crops_img[0].shape[1]) - cx_
                y_ = np.arange(crops_img[0].shape[0]) - cy_
                X, Y = np.meshgrid(x_, y_)
                D = np.sqrt(np.square(X) + np.square(Y)).astype(np.float32)

                data['test_dist'] = [torch.from_numpy(np.expand_dims(D, axis=0))]

            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops_img]
            data[s + '_anno'] = boxes

            if s == 'train':
                data[s + '_masks'] = [torch.from_numpy(np.expand_dims(x, axis=0)) for x in crops_mask]
            if s == 'test':
                # Generate contours of the test masks (added by Yang, 2020.10):
                # draw the outer contour both into the mask itself and into a
                # separate one-pixel-wide contour map. Results are accumulated
                # so that all test frames are kept, not just the last one.
                test_masks, test_contours = [], []
                for x in crops_mask:
                    contours, _ = cv2.findContours(x.astype('uint8'), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
                    # the contour canvas matches the mask shape (H, W)
                    mask_contour = cv2.drawContours(np.zeros(x.shape[:2], dtype='float32'), contours, -1, 1, thickness=1)
                    mask_ = cv2.drawContours(x, contours, -1, 1, thickness=1)
                    test_masks.append(torch.from_numpy(np.expand_dims(mask_, axis=0)))
                    test_contours.append(torch.from_numpy(np.expand_dims(mask_contour, axis=0)))
                data['test_masks'] = test_masks
                data['test_contour'] = test_contours
            if s == 'train' and random.random() < 0.001:
                # Occasionally (p = 0.001) mirror the train frame into the test
                # fields and replace the train masks with binary masks generated
                # from the axis-aligned bboxes
                data['test_images'] = copy.deepcopy(data['train_images'])
                data['test_masks'] = copy.deepcopy(data['train_masks'])
                data['test_anno'] = copy.deepcopy(data['train_anno'])
                data[s + '_masks'] = [
                    torch.from_numpy(np.expand_dims(self._make_aabb_mask(x_.shape, bb_), axis=0))
                    for x_, bb_ in zip(crops_mask, boxes)
                ]

                if self.use_distance:
                    # No random perturbation of the center here, since the
                    # ground-truth box is used directly
                    cx_ = (boxes[0][0] + boxes[0][2] / 2).item()
                    cy_ = (boxes[0][1] + boxes[0][3] / 2).item()
                    x_ = np.arange(crops_img[0].shape[1]) - cx_
                    y_ = np.arange(crops_img[0].shape[0]) - cy_
                    X, Y = np.meshgrid(x_, y_)
                    D = np.sqrt(np.square(X) + np.square(Y)).astype(np.float32)
                    data['test_dist'] = [torch.from_numpy(np.expand_dims(D, axis=0))]

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
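A self-contained demo of the contour step in Example no. 27, using the same cv2.findContours / cv2.drawContours calls on a synthetic mask. It assumes OpenCV 4.x, where findContours returns two values; the mask and shape here are illustrative.

import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
mask[20:40, 15:50] = 1  # synthetic rectangular "object"

contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# Rasterize the outer contour as a one-pixel-wide float32 contour map
contour_map = cv2.drawContours(np.zeros(mask.shape, dtype=np.float32),
                               contours, -1, 1, thickness=1)
print(int(contour_map.sum()))  # number of contour pixels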