Example #1
    def __call__(self, data: TensorDict):
        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['train_images'], data['train_anno'] = self.transform['joint'](image=data['train_images'], bbox=data['train_anno'])
            data['test_images'], data['test_anno'] = self.transform['joint'](image=data['test_images'], bbox=data['test_anno'], new_roll=False)

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]

            # Crop image region centered at jittered_anno box
            crops, boxes, _ = prutils.jittered_center_crop(data[s + '_images'], jittered_anno, data[s + '_anno'],
                                                           self.search_area_factor, self.output_sz)

            # Apply transforms
            data[s + '_images'], data[s + '_anno'] = self.transform[s](image=crops, bbox=boxes, joint=False)

        # Generate proposals
        proposals, proposal_density, gt_density, proposal_iou = zip(
            *[self._generate_proposals(a) for a in data['test_anno']])

        data['test_proposals'] = proposals
        data['proposal_density'] = proposal_density
        data['gt_density'] = gt_density
        data['proposal_iou'] = proposal_iou
        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
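For context, here is a minimal sketch of how a processing callable like the one above is typically driven. The dummy frames, the box values, and the `processing` instance are assumptions for illustration, not code from this listing:

    import numpy as np
    import torch

    train_frame = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy HxWxC image
    test_frame = np.zeros((480, 640, 3), dtype=np.uint8)
    data = TensorDict({
        'train_images': [train_frame],
        'train_anno': [torch.tensor([100., 120., 50., 80.])],  # (x, y, w, h)
        'test_images': [test_frame],
        'test_anno': [torch.tensor([110., 130., 50., 80.])],
    })
    data = processing(data)  # 'processing' is an instance of the class above
    # data now also holds 'test_proposals', 'proposal_density', 'gt_density', 'proposal_iou'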
Example #2
    def __call__(self, data: TensorDict):
        # Apply joint transformations, i.e. the same transformation parameters are
        # applied to all train/test frames in a sequence
        if self.transform['joint'] is not None:
            data['train_images'], data['train_anno'], data['train_masks'] = self.transform['joint'](
                image=data['train_images'], bbox=data['train_anno'], mask=data['train_masks'])
            data['test_images'], data['test_anno'], data['test_masks'] = self.transform['joint'](
                image=data['test_images'], bbox=data['test_anno'], mask=data['test_masks'], new_roll=self.new_roll)

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]
            orig_anno = data[s + '_anno']

            # Extract a crop containing the target
            crops, boxes, mask_crops = prutils.target_image_crop(data[s + '_images'], jittered_anno,
                                                                 data[s + '_anno'], self.search_area_factor,
                                                                 self.output_sz, mode=self.crop_type,
                                                                 max_scale_change=self.max_scale_change,
                                                                 masks=data[s + '_masks'])

            # Apply independent transformations to each image
            data[s + '_images'], data[s + '_anno'], data[s + '_masks'] = self.transform[s](image=crops, bbox=boxes, mask=mask_crops, joint=False)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
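Both examples finish by stacking per-frame lists into batched tensors in 'sequence' mode. A minimal sketch of what the `stack_tensors` helper could look like, inferred from how it is called here rather than taken from this listing:

    import torch

    def stack_tensors(x):
        # Stack a list/tuple of equally-shaped tensors along a new first dimension;
        # pass anything else through unchanged.
        if isinstance(x, (list, tuple)) and isinstance(x[0], torch.Tensor):
            return torch.stack(x)
        return x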
Example #3
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -

        returns:
            TensorDict - output data block with following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
                'test_proposals'-
                'proposal_iou'  -
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops, boxes, _ = prutils.jittered_center_crop(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)

            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops]
            data[s + '_anno'] = boxes

        # Generate proposals
        frame2_proposals, regs = zip(
            *[self._generate_proposals(a) for a in data['test_anno']])

        data['test_proposals'] = list(frame2_proposals)
        data['regs'] = list(regs)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
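Every example jitters the annotation before cropping via `_get_jittered_box`. A plausible sketch of that step, assuming log-scale size jitter and a center offset proportional to the box size (the factor names and defaults are assumptions):

    import torch

    def get_jittered_box(box, scale_jitter_factor=0.25, center_jitter_factor=3.0):
        # box is (x, y, w, h); jitter the size on a log scale, then shift the
        # center by up to center_jitter_factor times the jittered box side length
        jittered_size = box[2:4] * torch.exp(torch.randn(2) * scale_jitter_factor)
        max_offset = jittered_size.prod().sqrt() * center_jitter_factor
        jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5)
        return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0)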
Example #4
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images', 'test_images', 'train_anno', 'test_anno'
        returns:
            TensorDict - output data block with following fields:
                'train_images', 'test_images', 'train_anno', 'test_anno', 'test_proposals', 'proposal_density', 'gt_density',
                'test_label' (optional), 'train_label' (optional), 'test_label_density' (optional), 'train_label_density' (optional)
        """

        if self.transform['joint'] is not None:
            data['train_images'], data['train_anno'] = self.transform['joint'](image=data['train_images'], bbox=data['train_anno'])
            data['test_images'], data['test_anno'] = self.transform['joint'](image=data['test_images'], bbox=data['test_anno'], new_roll=False)

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]

            crops, boxes = prutils.target_image_crop(data[s + '_images'], jittered_anno, data[s + '_anno'],
                                                     self.search_area_factor, self.output_sz, mode=self.crop_type,
                                                     max_scale_change=self.max_scale_change)

            data[s + '_images'], data[s + '_anno'] = self.transform[s](image=crops, bbox=boxes, joint=False)

        # Generate proposals
        proposals, proposal_density, gt_density = zip(*[self._generate_proposals(a) for a in data['test_anno']])

        data['test_proposals'] = proposals
        data['proposal_density'] = proposal_density
        data['gt_density'] = gt_density

        # For distractor frames, push the box far outside the image (sentinel
        # coordinates) so it cannot generate positive training labels
        for s in ['train', 'test']:
            is_distractor = data.get('is_distractor_{}_frame'.format(s), None)
            if is_distractor is not None:
                for is_dist, box in zip(is_distractor, data[s+'_anno']):
                    if is_dist:
                        box[0] = 99999999.9
                        box[1] = 99999999.9

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        # Generate label functions
        if self.label_function_params is not None:
            data['train_label'] = self._generate_label_function(data['train_anno'])
            data['test_label'] = self._generate_label_function(data['test_anno'])
        if self.label_density_params is not None:
            data['train_label_density'] = self._generate_label_density(data['train_anno'])
            data['test_label_density'] = self._generate_label_density(data['test_anno'])

        return data
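Example #4 optionally generates spatial label functions from the boxes. A minimal sketch of a Gaussian label map centered on the box, which is one common choice (the exact `_generate_label_function` used here is not shown, so the parameter names and defaults are assumptions):

    import torch

    def gaussian_label(box, output_sz=288, feat_sz=18, sigma_factor=0.25):
        # box is (x, y, w, h) in crop coordinates; returns a feat_sz x feat_sz map
        stride = output_sz / feat_sz
        cx = (box[0] + 0.5 * box[2]) / stride
        cy = (box[1] + 0.5 * box[3]) / stride
        sigma = sigma_factor * (box[2] * box[3]).sqrt() / stride
        k = torch.arange(feat_sz).float()
        gx = torch.exp(-0.5 * ((k - cx) / sigma) ** 2)
        gy = torch.exp(-0.5 * ((k - cy) / sigma) ** 2)
        return gy.view(-1, 1) * gx.view(1, -1)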
Example #5
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -

        returns:
            TensorDict - output data block with following fields:
                'train_images'  -processing
                'test_images'   -processing
                'train_anno'    -processing
                'test_anno'     -processing
                'test_proposals'-
                'proposal_iou'  -
        """

        for s in ['train']:
            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops, boxes = prutils.jittered_center_crop(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)
            crops1, boxes1 = prutils.jittered_center_crop(
                data[s + '_mask'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)
            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops]
            print("data[s + '_images'] len", len(crops))
            print("data['train_mask'] lem", len(crops1))
            #print("crops",crops1[1].shape)
            data['train_mask'] = [
                torch.from_numpy(x).view(x.shape[0], x.shape[1])
                for x in crops1
            ]
            data[s + '_anno'] = boxes
        print("data[s + '_images'] len", data['train_images'])
        print("data['train_mask'] lem", data['train_mask'])
        # Generate proposals
        #gt_mask = zip(*[self._generate_proposals(a) for a in data['test_anno']])

        # data['train_mask'] = torch.from_numpy(data['train_mask']).view(1,data['train_mask'].shape[0],data['train_mask'].shape[1])

        #data[s + '_mask'] = [torch.from_numpy(x) for x in crops1]
        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
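All of these examples rely on `TensorDict.apply` to post-process every field at once. A minimal sketch of such a container, assuming it simply maps a function over its values (the real class may do more):

    class TensorDict(dict):
        # Dict that can map a function over all of its values at once
        def apply(self, fn):
            return TensorDict({k: fn(v) for k, v in self.items()})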
Example #6
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'search_images', 'template_images', 'search_anno', 'template_anno'
        returns:
            TensorDict - output data block with following fields:
                'search_images', 'template_images', 'search_anno', 'template_anno'
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['search_images'], data['search_anno'] = self.transform[
                'joint'](image=data['search_images'], bbox=data['search_anno'])
            data['template_images'], data['template_anno'] = self.transform[
                'joint'](image=data['template_images'],
                         bbox=data['template_anno'],
                         new_roll=False)

        for s in ['search', 'template']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num search/template frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            if s == 'search':
                crops, boxes, _ = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.search_sz)
            elif s == 'template':
                crops, boxes, _ = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.template_area_factor, self.temp_sz)
            else:
                raise NotImplementedError

            # Apply transforms
            data[s + '_images'], data[s + '_anno'] = self.transform[s](
                image=crops, bbox=boxes, joint=False)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
        data['template_images'] = data['template_images'].squeeze()
        data['search_images'] = data['search_images'].squeeze()
        data['template_anno'] = data['template_anno'].squeeze()
        data['search_anno'] = data['search_anno'].squeeze()
        return data
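The search/template branches above differ only in area factor and output size. The square crop side is typically derived from the box area; a sketch of that geometry, assuming the usual definition of `search_area_factor`:

    import math

    def crop_side(w, h, search_area_factor=4.0):
        # Side length of a square crop covering search_area_factor^2 times the box area
        return math.ceil(math.sqrt(w * h) * search_area_factor)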
Example #7
    def __call__(self, data: TensorDict, *args, **kwargs):

        for s in ['train', 'test']:
            assert len(data[s + '_images']) == len(data[s + '_anno']) == 1


            ori_shape = data[s + '_images'][0].shape
            img, img_shape, pad_shape, scale_factor = self.transform[s](
                data[s + '_images'][0],
                scale=(1333,800),
                keep_ratio=True)
            img = to_tensor(img).unsqueeze(0)
            img_meta = [
                dict(
                    ori_shape=ori_shape,
                    img_shape=img_shape,
                    pad_shape=pad_shape,
                    scale_factor=scale_factor,
                    flip=False)
            ]

            anno = data[s + '_anno'][0]
            bbox = BoxList(anno, [ori_shape[1], ori_shape[0]], mode="xywh")

            image_sz = img_shape[1], img_shape[0]
            anno = bbox.resize(image_sz).convert("xyxy").bbox

            data[s + '_images'] = [img]
            data[s + '_anno'] = [anno]
            data[s + '_img_meta'] = [img_meta]
        data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
        return data
Example #8
    def forward(self, *args, **kwargs):
        data = TensorDict(kwargs)
        data = data.apply(lambda x: x[0] if isinstance(x, torch.Tensor) else x)

        img_meta = data['train_img_meta'][0].copy()
        # Unwrap tensor-valued meta entries back to Python scalars
        for key in img_meta:
            values = img_meta[key]
            for i in range(len(values)):
                if len(values) > 1:
                    img_meta[key][i] = img_meta[key][i].item()
                else:
                    img_meta[key] = img_meta[key][i].item()

        with torch.no_grad():
            self.detector.extract_feat_pre(data['train_images'])
            result = self.detector.simple_test_post(rescale=True,
                                                    img_meta=[img_meta])
            prediction = self.toBoxlist(result, 0.3)
            if prediction is None:
                box_gt = data['train_anno'].to("cpu")
                roi_box = torch.cat([box_gt])
            else:
                num_pred = len(prediction)
                top_k = min(self.top_k, num_pred)
                scores = prediction.get_field("scores")
                _, ind = torch.topk(scores, top_k)
                img_shape = img_meta['img_shape']
                boxes = prediction[ind].resize([img_shape[1],
                                                img_shape[0]]).bbox
                box_gt = data['train_anno'].to("cpu")
                roi_box = torch.cat([box_gt, boxes])

            iou = self.compute_iou(roi_box, box_gt)
            roi_ind = torch.zeros([len(roi_box), 1])

            roi1 = torch.cat([roi_ind, roi_box], dim=1).to("cuda")
            labels1 = (iou > 0.7).squeeze().float().to("cuda")
            roi_fea1 = self.detector.extract_roi_featrue(roi1)

        img_meta = data['test_img_meta'][0].copy()
        # Unwrap tensor-valued meta entries back to Python scalars
        for key in img_meta:
            values = img_meta[key]
            for i in range(len(values)):
                if len(values) > 1:
                    img_meta[key][i] = img_meta[key][i].item()
                else:
                    img_meta[key] = img_meta[key][i].item()

        with torch.no_grad():
            self.detector.extract_feat_pre(data['test_images'])
            result = self.detector.simple_test_post(rescale=True,
                                                    img_meta=[img_meta])
            prediction = self.toBoxlist(result, 0.3)

            if prediction is None:
                box_gt = data['test_anno'].to("cpu")
                roi_box = torch.cat([box_gt])
            else:
                num_pred = len(prediction)
                top_k = min(self.top_k, num_pred)
                scores = prediction.get_field("scores")
                _, ind = torch.topk(scores, top_k)
                img_shape = img_meta['img_shape']
                boxes = prediction[ind].resize([img_shape[1],
                                                img_shape[0]]).bbox
                box_gt = data['test_anno'].to("cpu")
                roi_box = torch.cat([box_gt, boxes])

            iou = self.compute_iou(roi_box, box_gt)
            roi_ind = torch.zeros([len(roi_box), 1])

            roi2 = torch.cat([roi_ind, roi_box], dim=1).to("cuda")
            labels2 = (iou > 0.7).squeeze().float().to("cuda")
            roi_fea2 = self.detector.extract_roi_featrue(roi2)

        predict_scores1 = self.selector(roi_fea2[0][None, ], roi_fea1)
        predict_scores2 = self.selector(roi_fea1[0][None, ], roi_fea2)

        return predict_scores1, labels1, predict_scores2, labels2
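Example #8 leans on `self.compute_iou` to label the pooled ROIs. A minimal sketch of a pairwise IoU for (x, y, w, h) boxes, which is what that call appears to expect (an assumption, since compute_iou itself is not shown):

    import torch

    def rect_iou(a, b):
        # a: (N, 4), b: (M, 4), both (x, y, w, h); returns an (N, M) IoU matrix
        tl = torch.max(a[:, None, :2], b[None, :, :2])
        br = torch.min(a[:, None, :2] + a[:, None, 2:], b[None, :, :2] + b[None, :, 2:])
        inter = (br - tl).clamp(min=0).prod(dim=2)
        union = a[:, 2:].prod(dim=1)[:, None] + b[:, 2:].prod(dim=1)[None, :] - inter
        return inter / union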
Example #9
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
        """

        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            if self.crop_type == 'replicate':
                crops, boxes = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.output_sz)
            elif self.crop_type == 'nopad':
                crops, boxes = prutils.jittered_center_crop_nopad(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.output_sz)
            else:
                raise ValueError('Unknown crop type {}'.format(self.crop_type))

            data[s + '_images'] = [self.transform[s](x) for x in crops]
            # Convert boxes to (x0, y0, x1, y1), clamp them to the 288x288 crop,
            # then convert back to (x, y, w, h)
            boxes = torch.stack(boxes)
            boxes[:, 2:4] = boxes[:, 0:2] + boxes[:, 2:4]
            boxes = boxes.clamp(0.0, 287.0)
            boxes[:, 2:4] = boxes[:, 2:4] - boxes[:, 0:2]
            data[s + '_anno'] = boxes

        if self.proposal_params:
            frame2_proposals, gt_iou = zip(
                *[self._generate_proposals(a) for a in data['test_anno']])

            data['test_proposals'] = list(frame2_proposals)
            data['proposal_iou'] = list(gt_iou)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        test_anno = data['test_anno'].clone()
        test_anno[:, 2:4] = test_anno[:, 0:2] + test_anno[:, 2:4]
        center_288 = (test_anno[:, 0:2] + test_anno[:, 2:4]) * 0.5
        w_288 = test_anno[:, 2] - test_anno[:, 0]
        h_288 = test_anno[:, 3] - test_anno[:, 1]
        wl_288 = center_288[:, 0] - test_anno[:, 0]
        wr_288 = test_anno[:, 2] - center_288[:, 0]
        ht_288 = center_288[:, 1] - test_anno[:, 1]
        hb_288 = test_anno[:, 3] - center_288[:, 1]
        w2h2_288 = torch.stack((wl_288, wr_288, ht_288, hb_288), dim=1)  # [num_images, 4]

        boxes_72 = (data['test_anno'] * self.output_spatial_scale).float()
        # boxes is in format xywh, convert it to x0y0x1y1 format
        boxes_72[:, 2:4] = boxes_72[:, 0:2] + boxes_72[:, 2:4]

        center_float = torch.stack(((boxes_72[:, 0] + boxes_72[:, 2]) / 2.,
                                    (boxes_72[:, 1] + boxes_72[:, 3]) / 2.), dim=1)
        center_int = center_float.int().float()
        ind_72 = center_int[:, 1] * self.output_w + center_int[:, 0]  # [num_images]

        data['ind_72'] = ind_72.long()
        data['w2h2_288'] = w2h2_288
        data['w2h2_72'] = w2h2_288 * 0.25

        # Generate label functions
        if self.label_function_params is not None:
            data['train_label'] = self._generate_label_function(
                data['train_anno'])
            data['test_label'] = self._generate_label_function(
                data['test_anno'])

            # data['train_label_36'] = self._generate_label_36_function(data['train_anno'])
            # data['test_label_36'] = self._generate_label_36_function(data['test_anno'])

            data['train_label_72'] = self._generate_label_72_function(
                data['train_anno'])
            data['test_label_72'] = self._generate_label_72_function(
                data['test_anno'])

        return data
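The `ind_72` computation above flattens the integer box center into a single index on the output_w x output_w grid, which is convenient for gather-style lookups into flattened score maps. A toy round-trip check, assuming output_w = 72:

    import torch

    output_w = 72
    center = torch.tensor([[10., 20.]])          # (x, y) on the 72x72 grid
    ind = (center[:, 1] * output_w + center[:, 0]).long()
    y, x = ind // output_w, ind % output_w       # recover (row, col)
    assert (x == 10).all() and (y == 20).all()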
Example #10
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  -
                'test_images'   -
                'train_anno'    -
                'test_anno'     -
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops, boxes = prutils.jittered_center_crop(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)

            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops]
            data[s + '_anno'] = boxes

            # Random horizontal flip (image tensors are CxHxW, so dim 2 is the width)
            FLIP = random.random() < 0.5
            if FLIP:
                data[s + '_images'][0] = data[s + '_images'][0].flip(2)
                WIDTH = data[s + '_images'][0].shape[2]
                data[s + '_anno'][0][0] = WIDTH - data[s + '_anno'][0][0] - data[s + '_anno'][0][2]

        # Generate train and test proposals for scaler
        train_scaler_proposals, train_scaler_labels = zip(
            *[self._generate_scaler_proposals(a) for a in data['train_anno']])
        test_scaler_proposals, test_scaler_labels = zip(
            *[self._generate_scaler_proposals(a) for a in data['test_anno']])

        data['train_scaler_proposals'], data['train_scaler_labels'] = list(
            train_scaler_proposals), list(train_scaler_labels)
        data['test_scaler_proposals'], data['test_scaler_labels'] = list(
            test_scaler_proposals), list(test_scaler_labels)

        # Generate train and test proposals for locator
        data['test_anno_jittered'] = [
            self._get_jittered_box2(a) for a in data['test_anno']
        ]
        train_locator_proposals, train_locator_labels = zip(
            *[self._generate_locator_proposals(a) for a in data['train_anno']])
        test_locator_proposals, test_locator_labels = zip(*[
            self._generate_locator_proposals(a)
            for a in data['test_anno_jittered']
        ])

        data['train_locator_proposals'], data['train_locator_labels'] = list(
            train_locator_proposals), list(train_locator_labels)
        data['test_locator_proposals'], data['test_locator_labels'] = list(
            test_locator_proposals), list(test_locator_labels)

        data['train_locator_proposals'][0] = torch.cat(
            (data['train_locator_proposals'][0], data['train_anno'][0].reshape(
                1, -1)),
            dim=0)
        data['train_locator_labels'][0] = torch.cat(
            (data['train_locator_labels'][0], torch.Tensor([1.0])), dim=0)
        data['test_locator_proposals'][0] = torch.cat(
            (data['test_locator_proposals'][0], data['test_anno'][0].reshape(
                1, -1)),
            dim=0)
        data['test_locator_labels'][0] = torch.cat(
            (data['test_locator_labels'][0], torch.Tensor([1.0])), dim=0)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
Example #11
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'
                'test_images'
                'train_anno'
                'test_anno'

        returns:
            TensorDict - output data block with following fields:
                'train_images'
                'test_images'
                'train_anno'
                'test_anno'
                'test_proposals'
                'proposal_iou'
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops, boxes = prutils_SE.jittered_center_crop_SE(
                data[s + '_images'],
                jittered_anno,
                data[s + '_anno'],
                self.search_area_factor,
                self.output_sz,
                mode=cv.BORDER_CONSTANT)
            # Apply transforms
            data[s + '_images'] = [self.transform[s](x)
                                   for x in crops]  # x : `numpy.ndarray`
            data[s + '_anno'] = boxes

            mask_crops = prutils_SE.jittered_center_crop_SE(
                data[s + '_masks'],
                jittered_anno,
                data[s + '_anno'],
                self.search_area_factor,
                self.output_sz,
                get_bbox_coord=False,
                mode=cv.BORDER_CONSTANT)
            data[s + '_masks'] = [self.mask_np2torch(x) for x in mask_crops]

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
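Example #11 converts cropped masks with `self.mask_np2torch`. A minimal sketch of what that helper could be, assuming an HxW numpy mask is converted to a 1xHxW float tensor (an assumption; the helper itself is not shown):

    import numpy as np
    import torch

    def mask_np2torch(mask):
        # HxW numpy mask -> 1xHxW float32 tensor
        return torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)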
Example #12
    def __call__(self, data: TensorDict):
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            crops, boxes = prutils.jittered_center_crop_comb(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)

            data[s + '_images'] = [self.transform[s](x) for x in crops]
            data[s + '_anno'] = boxes

        if self.proposal_params:
            frame2_proposals, gt_iou = zip(*[
                self._generate_proposals(a.numpy()) for a in data['test_anno']
            ])

            data['test_proposals'] = [
                torch.tensor(p, dtype=torch.float32) for p in frame2_proposals
            ]
            data['proposal_iou'] = [
                torch.tensor(gi, dtype=torch.float32) for gi in gt_iou
            ]

        if 'is_distractor_test_frame' in data:
            data['is_distractor_test_frame'] = torch.tensor(
                data['is_distractor_test_frame'], dtype=torch.uint8)
        else:
            data['is_distractor_test_frame'] = torch.zeros(
                len(data['test_images']), dtype=torch.uint8)

        if 'is_distractor_train_frame' in data:
            data['is_distractor_train_frame'] = torch.tensor(
                data['is_distractor_train_frame'], dtype=torch.uint8)
        else:
            data['is_distractor_train_frame'] = torch.zeros(
                len(data['train_images']), dtype=torch.uint8)

        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        if self.label_function_params is not None:
            data['train_label'] = self._generate_label_function(
                data['train_anno'], data['is_distractor_train_frame'])

            data['test_label'] = self._generate_label_function(
                data['test_anno'], data['is_distractor_test_frame'])

        return data
Example #13
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  - list of train frames
                'test_images'   - list of test frames
                'train_anno'    - list of train-frame boxes in (x, y, w, h) format
                'test_anno'     - list of test-frame boxes in (x, y, w, h) format

        returns:
            TensorDict - output data block with following fields:
                'train_images'  - processed train frames
                'test_images'   - processed test frames
                'train_anno'    - boxes in crop coordinates
                'test_anno'     - boxes in crop coordinates
                'test_proposals' (optional) - proposals sampled around the test boxes
                'proposal_iou'  (optional)  - IoU of each proposal with the ground-truth box
                'test_label' (optional)     - label maps for the test boxes
                'train_label' (optional)    - label maps for the train boxes
        """

        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            if self.crop_type == 'replicate':
                crops, boxes = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.output_sz)
            elif self.crop_type == 'nopad':
                crops, boxes = prutils.jittered_center_crop_nopad(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.output_sz)
            else:
                raise ValueError('Unknown crop type {}'.format(self.crop_type))

            data[s + '_images'] = [
                self.transform[s](x.astype(np.float32)) for x in crops
            ]
            data[s + '_anno'] = boxes

        # Generate proposals
        if self.proposal_params:
            frame2_proposals, gt_iou = zip(
                *[self._generate_proposals(a) for a in data['test_anno']])

            data['test_proposals'] = list(frame2_proposals)
            data['proposal_iou'] = list(gt_iou)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        # Generate label functions
        if self.label_function_params is not None:
            data['train_label'] = self._generate_label_function(
                data['train_anno'])
            data['test_label'] = self._generate_label_function(
                data['test_anno'])

        return data
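Several examples (#3, #9, #12, #13, #14) sample box proposals with `_generate_proposals` and record their overlap with the ground truth. A plausible sketch of one such sampler, reusing the `rect_iou` helper sketched after Example #8 (the sigma and min_iou values are assumptions):

    import torch

    def generate_proposal(box, min_iou=0.1, sigma=0.5, max_tries=20):
        # Perturb an (x, y, w, h) box until its IoU with the original exceeds min_iou
        for _ in range(max_tries):
            c_x = box[0] + 0.5 * box[2] + sigma * box[2] * torch.randn(1)
            c_y = box[1] + 0.5 * box[3] + sigma * box[3] * torch.randn(1)
            w = box[2] * torch.exp(sigma * torch.randn(1))
            h = box[3] * torch.exp(sigma * torch.randn(1))
            proposal = torch.cat((c_x - 0.5 * w, c_y - 0.5 * h, w, h))
            iou = rect_iou(proposal[None, :], box[None, :])[0, 0]
            if iou >= min_iou:
                break
        return proposal, iou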
Example #14
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images', 'test_images', 'train_anno', 'test_anno', 'train_depths', 'test_depths'
        returns:
            TensorDict - output data block with following fields:
                'train_images', 'test_images', 'train_anno', 'test_anno', 'test_proposals', 'proposal_iou'
        """

        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['train_images'], data['train_anno'] = self.transform['joint'](image=data['train_images'], bbox=data['train_anno'])
            data['test_images'], data['test_anno'] = self.transform['joint'](image=data['test_images'], bbox=data['test_anno'], new_roll=False)


        for s in ['train', 'test']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]

            # Crop image region centered at jittered_anno box
            crops, boxes = prutils.jittered_center_crop(data[s + '_images'], jittered_anno, data[s + '_anno'],
                                                        self.search_area_factor, self.output_sz)
            crops_depth, boxes_depth = prutils.jittered_center_crop(data[s + '_depths'], jittered_anno, data[s + '_anno'],
                                                                    self.search_area_factor, self.output_sz)
            # Apply transforms
            data[s + '_images'], data[s + '_anno'] = self.transform[s](image=crops, bbox=boxes, joint=False)

            # Depth crops do not need brightness normalization; converting them
            # to tensors is enough (Song)
            if isinstance(crops_depth, (list, tuple)):
                data[s + '_depths'] = [torch.from_numpy(np.asarray(x).transpose((2, 0, 1))) for x in crops_depth]
            else:
                crops_depth = np.asarray(crops_depth)
                if len(crops_depth.shape) == 3:
                    data[s + '_depths'] = [torch.from_numpy(np.asarray(crops_depth).transpose((2, 0, 1)))]
                elif len(crops_depth.shape) == 4:
                    data[s + '_depths'] = [torch.from_numpy(np.asarray(crops_depth).transpose((0, 3, 1, 2)))]
                else:
                    print('crops_depth dimensions error, num_dim=', np.ndim(crops_depth))
                    data[s + '_depths'] = torch.from_numpy(np.asarray(crops_depth))
                    
        # Generate proposals
        frame2_proposals, gt_iou = zip(*[self._generate_proposals(a) for a in data['test_anno']])

        data['test_proposals'] = list(frame2_proposals)
        data['proposal_iou'] = list(gt_iou)

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
Example #15
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'search_images', 'template_images', 'search_anno', 'template_anno'
        returns:
            TensorDict - output data block with following fields:
                'search_images', 'template_images', 'search_anno', 'template_anno'
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            data['search_images'], data['search_anno'] = self.transform[
                'joint'](image=data['search_images'], bbox=data['search_anno'])
            data['template_images'], data['template_anno'] = self.transform[
                'joint'](image=data['template_images'],
                         bbox=data['template_anno'],
                         new_roll=False)

        for s in ['search', 'template']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num search/template frames must be 1"

            # Add uniform noise to the center position
            if self.rand:
                rand_size_a = torch.randn(2)
                rand_center_a = torch.rand(2)
                rand_size_b = torch.randn(2)
                rand_center_b = torch.rand(2)

                # Linearly interpolate from 0 to rand_size/center
                size_step = torch.tensor(
                    np.linspace(rand_size_a, rand_size_b,
                                len(data[s + '_anno'])))
                center_step = torch.tensor(
                    np.linspace(rand_center_a, rand_center_b,
                                len(data[s + '_anno'])))
                jittered_anno = [
                    self._get_jittered_box(a, s, rand_size=rs, rand_center=rc)
                    for a, rs, rc in zip(data[s +
                                              '_anno'], size_step, center_step)
                ]
            else:
                jittered_anno = [
                    self._get_jittered_box(a, s) for a in data[s + '_anno']
                ]

            # Crop image region centered at jittered_anno box

            if s == 'search':
                if torch.any(data['search_visible'] == 0):
                    # For empty annotations, reuse the most recent valid crop box coordinates
                    filler_anno = jittered_anno[0]
                    filler_jitter = data[s + '_anno'][0]
                    for mi in range(len(data['search_visible'])):
                        if data['search_visible'][mi] == 0:
                            jittered_anno[mi] = filler_anno
                            data[s + '_anno'][mi] = filler_jitter
                        else:
                            filler_anno = jittered_anno[mi]
                            filler_jitter = data[s + '_anno'][mi]
                crops, boxes, _ = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.search_area_factor, self.search_sz)
            elif s == 'template':
                crops, boxes, _ = prutils.jittered_center_crop(
                    data[s + '_images'], jittered_anno, data[s + '_anno'],
                    self.template_area_factor, self.temp_sz)
            else:
                raise NotImplementedError
            # Boxes are in (col, row, col-extent, row-extent) format

            # Apply transforms
            if s == "search" and self.occlusion:
                maybe_occlusion = np.random.rand() > 0.5
                crops = list(crops)
                min_size = 1  # Minimum occluder extent in pixels
                min_frames = 7  # Earliest frame at which the occlusion may start
                if maybe_occlusion:
                    crop_len = len(crops)
                    rand_frames_start = np.random.randint(low=min_frames, high=crop_len)
                    top_side = rand_frames_start % 2

                    # Find the box in the first occluded frame and use it to construct the occluder
                    start_box = boxes[rand_frames_start].int()
                    apply_occlusion = False
                    pass_check = (start_box[2] // 2 > min_size and start_box[3] // 2 > min_size
                                  and crops[0].shape[0] > min_size and crops[0].shape[1] > min_size)
                    if top_side and pass_check:
                        # These are row indices
                        rand_start = np.random.randint(low=0, high=start_box[3] - min_size - 1)
                        if rand_start > start_box[3] // 2:
                            margin = np.copy(rand_start)
                            rand_start = np.random.randint(low=0, high=margin - min_size)
                            rand_extent = margin - rand_start
                        else:
                            remainder = np.maximum(start_box[3] - rand_start, min_size)
                            mc, xc = np.minimum(rand_start, remainder), np.maximum(rand_start, remainder)
                            if mc == xc:
                                xc += 1
                                rand_extent = mc + 1
                            else:
                                rand_extent = np.random.randint(low=mc, high=xc)

                        rand_start += start_box[1]
                        if rand_start + rand_extent < crops[0].shape[0] and rand_start > 0:
                            apply_occlusion = True
                    elif not top_side and pass_check:
                        # These are column indices
                        rand_start = np.random.randint(low=0, high=start_box[2] - min_size - 1)
                        if rand_start > start_box[2] // 2:
                            margin = np.copy(rand_start)
                            rand_start = np.random.randint(low=0, high=margin - min_size)
                            rand_extent = margin - rand_start
                        else:
                            remainder = np.maximum(start_box[3] - rand_start, min_size)
                            mc, xc = np.minimum(rand_start, remainder), np.maximum(rand_start, remainder)
                            if mc == xc:
                                xc += 1
                                rand_extent = mc + 1
                            else:
                                rand_extent = np.random.randint(low=mc, high=xc)

                        rand_start += start_box[0]
                        if rand_start + rand_extent < crops[0].shape[1] and rand_start > 0:
                            apply_occlusion = True
                    if apply_occlusion:
                        for bidx in range(rand_frames_start, crop_len):
                            # Apply an occluder to each remaining frame by randomly
                            # permuting the pixels inside the chosen strip
                            if top_side:
                                shuffle_box = crops[bidx][rand_start:rand_start + rand_extent]
                                shuffle_shape = shuffle_box.shape
                                shuffle_box = shuffle_box.reshape(-1, shuffle_shape[-1])  # channels last
                                shuffle_box = shuffle_box[np.random.permutation(shuffle_shape[0] * shuffle_shape[1])]
                                crops[bidx][rand_start:rand_start + rand_extent] = shuffle_box.reshape(shuffle_shape)
                            else:
                                shuffle_box = crops[bidx][:, rand_start:rand_start + rand_extent]
                                shuffle_shape = shuffle_box.shape
                                shuffle_box = shuffle_box.reshape(-1, shuffle_shape[-1])  # channels last
                                shuffle_box = shuffle_box[np.random.permutation(shuffle_shape[0] * shuffle_shape[1])]
                                crops[bidx][:, rand_start:rand_start + rand_extent] = shuffle_box.reshape(shuffle_shape)
            data[s + '_images'], data[s + '_anno'] = self.transform[s](
                image=crops, bbox=boxes, joint=self.joint)
            if s == "search":
                # Binary "bump" map: 1 inside the ground-truth box, 0 elsewhere
                im_shape = [len(data[s + '_images']), 1] + list(data[s + '_images'][0].shape[1:])
                bumps = torch.zeros(im_shape, device=data[s + '_images'][0].device).float()
                for bidx in range(bumps.shape[0]):
                    box = boxes[bidx].int()
                    bumps[bidx, :, box[1]:box[1] + box[3], box[0]:box[0] + box[2]] = 1
                data["bump"] = bumps
        self.prev_annos = jittered_anno

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)
        data['template_images'] = data['template_images'].squeeze()
        data['search_images'] = data['search_images'].squeeze()
        data['template_anno'] = data['template_anno'].squeeze()
        data['search_anno'] = data['search_anno'].squeeze()
        return data
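The occlusion branch in Example #15 is long, but the core trick is simple: pick a strip of the crop and destroy it by randomly permuting its pixels. A distilled sketch of that operation for a horizontal strip (a simplification of the code above, not a drop-in replacement):

    import numpy as np

    def shuffle_occlude_rows(img, row_start, extent):
        # img is HxWxC; scramble the pixels of rows [row_start, row_start + extent)
        strip = img[row_start:row_start + extent].reshape(-1, img.shape[-1])
        strip = strip[np.random.permutation(strip.shape[0])]
        img[row_start:row_start + extent] = strip.reshape(extent, img.shape[1], img.shape[-1])
        return img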
Example #16
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  - list of train frames
                'test_images'   - list of test frames
                'train_anno'    - list of train-frame boxes in (x, y, w, h) format
                'test_anno'     - list of test-frame boxes in (x, y, w, h) format
                'train_masks'   - list of train-frame segmentation masks
                'test_masks'    - list of test-frame segmentation masks

        returns:
            TensorDict - output data block with following fields:
                'train_images'  - processed train frames
                'test_images'   - processed test frames
                'train_anno'    - boxes in crop coordinates
                'test_anno'     - boxes in crop coordinates
                'train_masks'   - cropped masks as tensors
                'test_masks'    - cropped masks as tensors
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        # extract patches from images
        for s in ['test', 'train']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [
                self._get_jittered_box(a, s) for a in data[s + '_anno']
            ]

            # Crop image region centered at jittered_anno box
            crops_img, boxes = prutils.jittered_center_crop(
                data[s + '_images'], jittered_anno, data[s + '_anno'],
                self.search_area_factor, self.output_sz)

            # Crop mask region centered at jittered_anno box
            crops_mask, _ = prutils.jittered_center_crop(
                data[s + '_masks'],
                jittered_anno,
                data[s + '_anno'],
                self.search_area_factor,
                self.output_sz,
                pad_val=float(0))

            if s == 'train' and self.use_distance:
                # use target center only to create distance map
                cx_ = (boxes[0][0] + boxes[0][2] / 2).item()
                cy_ = (boxes[0][1] + boxes[0][3] / 2).item()
                x_ = np.linspace(1, crops_img[0].shape[1],
                                 crops_img[0].shape[1]) - 1 - cx_
                y_ = np.linspace(1, crops_img[0].shape[0],
                                 crops_img[0].shape[0]) - 1 - cy_
                X, Y = np.meshgrid(x_, y_)
                D = np.sqrt(np.square(X) + np.square(Y)).astype(np.float32)

                data['test_dist'] = [
                    torch.from_numpy(np.expand_dims(D, axis=0))
                ]

            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops_img]
            data[s + '_anno'] = boxes
            data[s + '_masks'] = [
                torch.from_numpy(np.expand_dims(x, axis=0)) for x in crops_mask
            ]

            if s == 'train':
                data[s + '_init_masks'] = [
                    torch.from_numpy(
                        np.expand_dims(self._make_aabb_mask(x_.shape, bb_),
                                       axis=0))
                    for x_, bb_ in zip(crops_mask, boxes)
                ]

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
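The distance map in Example #16 is just the per-pixel Euclidean distance to the target center. The same computation as a standalone helper (a sketch; the in-line version above uses np.linspace, but the result is identical):

    import numpy as np

    def center_distance_map(height, width, cx, cy):
        # Per-pixel distance to (cx, cy); float32 array of shape (height, width)
        x = np.arange(width, dtype=np.float32) - cx
        y = np.arange(height, dtype=np.float32) - cy
        X, Y = np.meshgrid(x, y)
        return np.sqrt(X ** 2 + Y ** 2)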
Example #17
    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  - list of train frames
                'test_images'   - list of test frames
                'train_anno'    - list of train-frame boxes in (x, y, w, h) format
                'test_anno'     - list of test-frame boxes in (x, y, w, h) format
                'train_masks'   - list of train-frame segmentation masks
                'test_masks'    - list of test-frame segmentation masks

        returns:
            TensorDict - output data block with following fields:
                'train_images'  - processed train frames
                'test_images'   - processed test frames
                'train_anno'    - boxes in crop coordinates
                'test_anno'     - boxes in crop coordinates
                'train_masks'   - cropped masks as tensors
                'test_masks'    - cropped masks and contours as tensors
        """
        # Apply joint transforms
        if self.transform['joint'] is not None:
            num_train_images = len(data['train_images'])
            all_images = data['train_images'] + data['test_images']
            all_images_trans = self.transform['joint'](*all_images)

            data['train_images'] = all_images_trans[:num_train_images]
            data['test_images'] = all_images_trans[num_train_images:]

        # extract patches from images
        for s in ['test', 'train']:
            assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Add uniform noise to the center position
            jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]

            # Crop image region centered at jittered_anno box
            crops_img, boxes = prutils.jittered_center_crop(data[s + '_images'], jittered_anno, data[s + '_anno'],
                                                self.search_area_factor, self.output_sz)

            # Crop mask region centered at jittered_anno box
            crops_mask, _ = prutils.jittered_center_crop(data[s + '_masks'], jittered_anno, data[s + '_anno'],
                                                            self.search_area_factor, self.output_sz, pad_val=float(0))

            if s == 'test' and self.use_distance:
                # use target center only to create distance map
                cx_ = (boxes[0][0] + boxes[0][2] / 2).item() + ((0.25 * boxes[0][2].item()) * (random.random() - 0.5))
                cy_ = (boxes[0][1] + boxes[0][3] / 2).item() + ((0.25 * boxes[0][3].item()) * (random.random() - 0.5))
                x_ = np.linspace(1, crops_img[0].shape[1], crops_img[0].shape[1]) - 1 - cx_
                y_ = np.linspace(1, crops_img[0].shape[0], crops_img[0].shape[0]) - 1 - cy_
                X, Y = np.meshgrid(x_, y_)
                D = np.sqrt(np.square(X) + np.square(Y)).astype(np.float32)

                data['test_dist'] = [torch.from_numpy(np.expand_dims(D, axis=0))]

            # Apply transforms
            data[s + '_images'] = [self.transform[s](x) for x in crops_img]
            data[s + '_anno'] = boxes

            if s == 'train':
                data[s + '_masks'] = [torch.from_numpy(np.expand_dims(x, axis=0)) for x in crops_mask]
            # Generate contours of masks (added by Yang, 2020.10)
            if s == 'test':
                for x in crops_mask:
                    contours, _ = cv2.findContours(x.astype('uint8'), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
                    mask_contour = cv2.drawContours(np.zeros((x.shape[1],x.shape[0])).astype('float32'), contours, -1, 1, thickness=1)
                    mask_ = cv2.drawContours(x, contours, -1, 1, thickness=1)
                    data['test_masks'] = [torch.from_numpy(np.expand_dims(mask_, axis=0))]
                    data['test_contour'] = [torch.from_numpy(np.expand_dims(mask_contour, axis=0))]
            if s == 'train' and random.random() < 0.001:
                # With small probability, use a binary mask generated from the
                # axis-aligned bbox instead of the ground-truth mask
                data['test_images'] = copy.deepcopy(data['train_images'])
                data['test_masks'] = copy.deepcopy(data['train_masks'])
                data['test_anno'] = copy.deepcopy(data['train_anno'])
                data[s + '_masks'] = [torch.from_numpy(np.expand_dims(self._make_aabb_mask(x_.shape, bb_), axis=0)) for x_, bb_ in zip(crops_mask, boxes)]

                if self.use_distance:
                    # there is no need to randomly perturb center since we are working with ground-truth here
                    cx_ = (boxes[0][0] + boxes[0][2] / 2).item()
                    cy_ = (boxes[0][1] + boxes[0][3] / 2).item()
                    x_ = np.linspace(1, crops_img[0].shape[1], crops_img[0].shape[1]) - 1 - cx_
                    y_ = np.linspace(1, crops_img[0].shape[0], crops_img[0].shape[0]) - 1 - cy_
                    X, Y = np.meshgrid(x_, y_)
                    D = np.sqrt(np.square(X) + np.square(Y)).astype(np.float32)
                    data['test_dist'] = [torch.from_numpy(np.expand_dims(D, axis=0))]

        # Prepare output
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

        return data
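The contour branch in Example #17 can be isolated into a small helper. A sketch of a mask-to-contour conversion using the same OpenCV calls (it assumes the OpenCV 4 two-value return of findContours, as the code above does):

    import cv2
    import numpy as np

    def mask_to_contour(mask):
        # Binary HxW mask -> float32 map with 1-pixel-wide object contours
        contours, _ = cv2.findContours(mask.astype('uint8'), cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_NONE)
        canvas = np.zeros(mask.shape, dtype=np.float32)
        return cv2.drawContours(canvas, contours, -1, 1, thickness=1)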