Example #1
    def __init__(self, hparams, dbg=False):
        '''
        PyTorch Lightning module for training the GOTURN tracker.

        @hparams: argparse arguments used for training
        @dbg: enable the Visdom visualizer
        '''
        logger.info('=' * 15)
        logger.info('GOTURN TRACKER')
        logger.info('=' * 15)

        super(GoturnTrain, self).__init__()

        self.__set_seed(hparams.seed)
        self.hparams = hparams
        logger.info('Setting up the network...')

        # network with pretrained model
        if not self.hparams.finetune:
            self._model = GoturnNetwork(self.hparams.pretrained_model)
        else:
            self._model = GoturnNetwork()
            checkpoint = torch.load(
                self.hparams.pretrained_model)['state_dict']
            checkpoint = self.load_model_param(checkpoint)
            self._model.load_state_dict(checkpoint, strict=False)
        self._dbg = dbg
        if dbg:
            self._viz = Visualizer(port=8097)
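For reference, a minimal sketch of driving the two branches of this constructor. The argument names (seed, finetune, pretrained_model) come from the snippet above; the concrete values and file paths are placeholders, not part of the original code.

# Hypothetical usage sketch -- values and paths are placeholders.
from argparse import Namespace

# Branch 1: train from a pretrained backbone.
hparams = Namespace(seed=800, finetune=False,
                    pretrained_model='pretrained/backbone_weights.npy')
model = GoturnTrain(hparams, dbg=False)

# Branch 2: fine-tune from a Lightning checkpoint; its 'state_dict' is
# loaded into GoturnNetwork with strict=False.
hparams_ft = Namespace(seed=800, finetune=True,
                       pretrained_model='checkpoints/last.ckpt')
model_ft = GoturnTrain(hparams_ft, dbg=False)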
Example #2
    def __init__(self,
                 imgs_dir,
                 ann_dir,
                 isTrain=True,
                 val_ratio=0.2,
                 dbg=False):
        '''
        Load video frames and annotations from the ALOV dataset.

        @imgs_dir: ALOV video frames directory
        @ann_dir: annotations directory
        @isTrain: True for the training split, False for the validation split
        @val_ratio: fraction of the data held out for validation
        @dbg: enable visualization
        '''

        if not Path(imgs_dir).is_dir():
            logger.error('{} is not a valid directory'.format(imgs_dir))

        self._imgs_dir = Path(imgs_dir)
        self._ann_dir = Path(ann_dir)

        self._cats = {}
        self._isTrain = isTrain
        self._val_ratio = val_ratio

        self.__loaderAlov()
        self._alov_imgpairs = []
        self._alov_vids = self.__get_videos(self._isTrain, self._val_ratio)
        self.__parse_all()  # get all the image pairs in a list
        self._dbg = dbg
        if dbg:
            self._env = 'Alov'
            self._viz = Visualizer(env=self._env)
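A minimal usage sketch for this loader follows. The class name AlovDataset and the directory paths are assumptions for illustration; the constructor arguments match the snippet above.

# Hypothetical usage -- AlovDataset is an assumed name for the class above.
train_data = AlovDataset(imgs_dir='/data/alov/imagedata++',
                         ann_dir='/data/alov/alov300++_rectangleAnnotation_full',
                         isTrain=True, val_ratio=0.2, dbg=False)
val_data = AlovDataset(imgs_dir='/data/alov/imagedata++',
                       ann_dir='/data/alov/alov300++_rectangleAnnotation_full',
                       isTrain=False, val_ratio=0.2)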
Example #3
    def __init__(self, args, dbg=False):
        """load model """
        loader = loadfromfolder(args.input)
        self._vid_frames = loader.get_video_frames()

        model_dir = Path(args.model_dir)
        # Checkpoint path
        ckpt_dir = model_dir.joinpath('checkpoints')
        ckpt_path = next(ckpt_dir.glob('*.ckpt'))

        model = GoturnTrain.load_from_checkpoint(ckpt_path)
        model.eval()
        model.freeze()

        self._model = model
        if dbg:
            self._viz = Visualizer()

        self._dbg = dbg
Example #4
    def __init__(self,
                 lamda_shift,
                 lamda_scale,
                 min_scale,
                 max_scale,
                 dbg=False,
                 env='sample_generator'):
        """set parameters """

        self._lamda_shift = lamda_shift
        self._lamda_scale = lamda_scale
        self._min_scale = min_scale
        self._max_scale = max_scale
        self._kSamplesPerImage = 10  # number of synthetic samples per image

        self._viz = None
        if dbg:
            self._env = env
            self._viz = Visualizer(env=self._env)

        self._dbg = dbg
Example #5
    def __init__(self, hparams, dbg=False):
        '''
        PyTorch Lightning module for training the GOTURN tracker.

        @hparams: argparse arguments used for training
        @dbg: enable the Visdom visualizer
        '''
        logger.info('=' * 15)
        logger.info('GOTURN TRACKER')
        logger.info('=' * 15)

        super(GoturnTrain, self).__init__()

        self.__set_seed(hparams.seed)
        self.hparams = hparams
        logger.info('Setting up the network...')

        # network with pretrained model
        self._model = GoturnNetwork(self.hparams.pretrained_model)
        self._dbg = dbg
        if dbg:
            self._viz = Visualizer(port=8097)
Example #6
    def __init__(self, imgs_dir, ann_dir, isTrain=True, val_ratio=0.2, dbg=False):
        '''
        Load images and annotations from the ImageNet DET dataset.

        @imgs_dir: images directory
        @ann_dir: annotations directory
        @isTrain: True for the training split, False for the validation split
        @val_ratio: fraction of the data held out for validation
        @dbg: enable visualization
        '''

        if not Path(imgs_dir).is_dir():
            logger.error('{} is not a valid directory'.format(imgs_dir))

        self._imgs_dir = Path(imgs_dir)
        self._ann_dir = Path(ann_dir)
        self._kMaxRatio = 0.66
        self._list_of_annotations = self.__loadImageNetDet(isTrain=isTrain, val_ratio=val_ratio)
        self._data_fetched = []  # for debug purposes
        assert len(self._list_of_annotations) > 0, 'Number of valid annotations is {}'.format(len(self._list_of_annotations))

        self._dbg = dbg
        if dbg:
            self._env = 'ImageNet'
            self._viz = Visualizer(env=self._env)
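As with the ALOV loader, a short usage sketch; the class name ImageNetDataset and the ImageNet DET paths are assumptions, while val_ratio controls how many annotations go to the validation split.

# Hypothetical usage -- ImageNetDataset is an assumed name for the class above.
train_data = ImageNetDataset(imgs_dir='/data/ILSVRC/Data/DET/train',
                             ann_dir='/data/ILSVRC/Annotations/DET/train',
                             isTrain=True, val_ratio=0.2)
val_data = ImageNetDataset(imgs_dir='/data/ILSVRC/Data/DET/train',
                           ann_dir='/data/ILSVRC/Annotations/DET/train',
                           isTrain=False, val_ratio=0.2)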
Example #7
class GoturnTracker:
    """Docstring for . """
    def __init__(self, args, dbg=False):
        """load model """
        loader = loadfromfolder(args.input)
        self._vid_frames = loader.get_video_frames()

        model_dir = Path(args.model_dir)
        # Checkpoint path
        ckpt_dir = model_dir.joinpath('checkpoints')
        ckpt_path = next(ckpt_dir.glob('*.ckpt'))

        model = GoturnTrain.load_from_checkpoint(ckpt_path)
        model.eval()
        model.freeze()

        self._model = model
        if dbg:
            self._viz = Visualizer()

        self._dbg = dbg

    def vis_images(self, prev, curr, gt_bb, pred_bb, prefix='train'):
        def unnormalize(image, mean, std):
            image = np.transpose(image, (1, 2, 0)) * std + mean
            image = image.astype(np.float32)

            return image

        for i in range(0, prev.shape[0]):
            _mean = np.array([104, 117, 123])
            _std = np.ones_like(_mean)

            prev_img = prev[i].cpu().detach().numpy()
            curr_img = curr[i].cpu().detach().numpy()

            prev_img = unnormalize(prev_img, _mean, _std)
            curr_img = unnormalize(curr_img, _mean, _std)

            gt_bb_i = BoundingBox(*gt_bb[i].cpu().detach().numpy().tolist())
            gt_bb_i.unscale(curr_img)
            curr_img = draw.bbox(curr_img, gt_bb_i, color=(255, 255, 255))

            pred_bb_i = BoundingBox(
                *pred_bb[i].cpu().detach().numpy().tolist())
            pred_bb_i.unscale(curr_img)
            curr_img = draw.bbox(curr_img, pred_bb_i)

            out = np.concatenate(
                (prev_img[np.newaxis, ...], curr_img[np.newaxis, ...]), axis=0)
            out = np.transpose(out, [0, 3, 1, 2])

            self._viz.plot_images_np(out,
                                     title='sample_{}'.format(i),
                                     env='goturn_{}'.format(prefix))

    def _track(self, curr_frame, prev_frame, rect):
        """track current frame
        @curr_frame: current frame
        @prev_frame: prev frame
        @rect: bounding box of previous frame
        """
        prev_bbox = rect

        target_pad, _, _, _ = cropPadImage(prev_bbox, prev_frame)
        cur_search_region, search_location, edge_spacing_x, edge_spacing_y = cropPadImage(
            prev_bbox, curr_frame)

        if self._dbg:
            self._viz.plot_image_opencv(target_pad, 'target')
            self._viz.plot_image_opencv(cur_search_region, 'current')

        target_pad_in = self.preprocess(target_pad, mean=None).unsqueeze(0)
        cur_search_region_in = self.preprocess(cur_search_region,
                                               mean=None).unsqueeze(0)
        pred_bb = self._model.forward(target_pad_in, cur_search_region_in)
        if self._dbg:
            prev_bbox.scale(prev_frame)
            x1, y1, x2, y2 = prev_bbox.x1, prev_bbox.y1, prev_bbox.x2, prev_bbox.y2
            prev_bbox = torch.tensor([x1, y1, x2, y2]).unsqueeze(0)
            target_dbg = target_pad_in.clone()
            cur_search_region_dbg = cur_search_region_in.clone()
            self.vis_images(target_dbg, cur_search_region_dbg, prev_bbox,
                            pred_bb)

        pred_bb = BoundingBox(*pred_bb[0].cpu().detach().numpy().tolist())
        pred_bb.unscale(cur_search_region)
        pred_bb.uncenter(curr_frame, search_location, edge_spacing_x,
                         edge_spacing_y)
        x1, y1, x2, y2 = int(pred_bb.x1), int(pred_bb.y1), int(
            pred_bb.x2), int(pred_bb.y2)
        pred_bb = BoundingBox(x1, y1, x2, y2)
        return pred_bb

    def preprocess(self, im, mean=None):
        """Preprocess an image before the forward pass. This is the same
        preprocessing used during training; see the collate function in
        train.py for reference.
        @im: input image
        @mean: per-channel mean to subtract (defaults to the BGR mean used during training)
        """
        if mean is None:
            mean = np.array([104, 117, 123])  # BGR mean used during training
        im = resize(im, (227, 227)) - mean
        im = image_io.image_to_tensor(im)
        return im

    def track(self):
        """Track"""
        vid_frames = self._vid_frames[0]
        num_frames = len(vid_frames)
        f_path = vid_frames[0]
        frame_0 = image_io.load(f_path)

        prev = np.asarray(frame_0)
        global image
        image = prev
        # refPt is filled by a mouse callback registered elsewhere; draw the
        # initial box and press 's' to start tracking ('r' behaves the same here).
        while True:
            # prev_out = cv2.cvtColor(prev, cv2.COLOR_RGB2BGR)
            prev_out = np.copy(prev)
            cv2.imshow('image', prev_out)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('s'):
                (x1, y1), (x2, y2) = refPt[0], refPt[1]
                bbox_0 = BoundingBox(x1, y1, x2, y2)
                break
            elif key == ord('r'):
                (x1, y1), (x2, y2) = refPt[0], refPt[1]
                bbox_0 = BoundingBox(x1, y1, x2, y2)
                break

        for i in range(1, num_frames):
            f_path = vid_frames[i]
            frame_1 = image_io.load(f_path)
            curr = np.asarray(frame_1)
            bbox_0 = self._track(curr, prev, bbox_0)
            bbox = bbox_0
            prev = curr

            # Press 'p' to pause; re-draw the box and press 's' to resume
            # tracking from the new box.
            if cv2.waitKey(1) & 0xFF == ord('p'):
                while True:
                    image = curr
                    cv2.imshow("image", curr)
                    key = cv2.waitKey(0) & 0xFF
                    if key == ord("s"):
                        (x1, y1), (x2, y2) = refPt[0], refPt[1]
                        bbox_0 = BoundingBox(x1, y1, x2, y2)
                        break

            curr_dbg = np.copy(curr)
            curr_dbg = cv2.rectangle(curr_dbg, (int(bbox.x1), int(bbox.y1)),
                                     (int(bbox.x2), int(bbox.y2)),
                                     (255, 255, 0), 2)

            # curr_dbg = cv2.cvtColor(curr_dbg, cv2.COLOR_RGB2BGR)
            cv2.imshow('image', curr_dbg)
            # cv2.imwrite('./output/{:04d}.png'.format(i), curr_dbg)
            cv2.waitKey(20)
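A small driver for this tracker might look like the sketch below. The argument names input and model_dir come from the constructor above; everything else (the OpenCV window and the mouse callback that fills refPt) is assumed to be set up elsewhere in the script.

# Hypothetical driver -- argument wiring only; the cv2 window and the mouse
# callback that fills refPt must be registered elsewhere in the script.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', required=True, help='folder of video frames')
    parser.add_argument('--model_dir', required=True,
                        help='run directory containing checkpoints/*.ckpt')
    args = parser.parse_args()

    tracker = GoturnTracker(args, dbg=False)
    tracker.track()  # draw a box on the first frame, press 's' to start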
Example #8
class GoturnTrain(LightningModule):
    """Docstring for GoturnTrain. """
    def __init__(self, hparams, dbg=False):
        '''
        PyTorch Lightning module for training the GOTURN tracker.

        @hparams: argparse arguments used for training
        @dbg: enable the Visdom visualizer
        '''
        logger.info('=' * 15)
        logger.info('GOTURN TRACKER')
        logger.info('=' * 15)

        super(GoturnTrain, self).__init__()

        self.__set_seed(hparams.seed)
        self.hparams = hparams
        logger.info('Setting up the network...')

        # network with pretrained model
        self._model = GoturnNetwork(self.hparams.pretrained_model)
        self._dbg = dbg
        if dbg:
            self._viz = Visualizer(port=8097)

    def __freeze(self):
        """Freeze the model features layer
        """
        features_layer = self._model._net
        for param in features_layer.parameters():
            param.requires_grad = False

    def _set_conv_layer(self, conv_layers, param_dict):
        for layer in conv_layers.modules():
            if type(layer) == torch.nn.modules.conv.Conv2d:
                param_dict.append({
                    'params': layer.weight,
                    'lr': 0,
                    'weight_decay': self.hparams.wd
                })
                param_dict.append({
                    'params': layer.bias,
                    'lr': 0,
                    'weight_decay': 0
                })
        return param_dict

    def __set_lr(self):
        '''Set per-layer learning rates: conv layers are kept frozen (lr=0),
        while the regression (fully connected) layers use 10x / 20x the base
        learning rate for weights / biases.'''
        param_dict = []
        conv_layers = self._model._net_1
        param_dict = self._set_conv_layer(conv_layers, param_dict)
        conv_layers = self._model._net_2
        param_dict = self._set_conv_layer(conv_layers, param_dict)

        regression_layer = self._model._classifier
        for layer in regression_layer.modules():
            if type(layer) == torch.nn.modules.linear.Linear:
                param_dict.append({
                    'params': layer.weight,
                    'lr': 10 * self.hparams.lr,
                    'weight_decay': self.hparams.wd
                })
                param_dict.append({
                    'params': layer.bias,
                    'lr': 20 * self.hparams.lr,
                    'weight_decay': 0
                })
        return param_dict

    def find_lr(self):
        """finding suitable learning rate """
        model = self._model
        params = self.__set_lr()

        criterion = torch.nn.L1Loss(size_average=False)
        optimizer = CaffeSGD(params,
                             lr=1e-8,
                             momentum=self.hparams.momentum,
                             weight_decay=self.hparams.wd)

        lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
        trainloader = self.train_dataloader()
        lr_finder.range_test(trainloader,
                             start_lr=1e-7,
                             end_lr=1,
                             num_iter=500)
        lr_finder.plot()

    def __set_seed(self, SEED):
        ''' set all the seeds for reproducibility '''
        logger.info('Setting seed = {}'.format(SEED))
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        cudnn.deterministic = True

    @staticmethod
    def add_model_specific_args(parent_parser):
        ''' These are specific parameters for the sample generator '''
        ap = argparse.ArgumentParser(parents=[parent_parser])

        ap.add_argument('--min_scale',
                        type=float,
                        default=-0.4,
                        help='min scale')
        ap.add_argument('--max_scale',
                        type=float,
                        default=0.4,
                        help='max scale')
        ap.add_argument('--lamda_shift', type=float, default=5)
        ap.add_argument('--lamda_scale', type=int, default=15)
        return ap

    def configure_optimizers(self):
        """Configure optimizers"""
        logger.info(
            'Configuring optimizer: SGD with lr = {}, momentum = {}'.format(
                self.hparams.lr, self.hparams.momentum))
        params = self.__set_lr()
        optimizer = CaffeSGD(params,
                             lr=self.hparams.lr,
                             momentum=self.hparams.momentum,
                             weight_decay=self.hparams.wd)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.hparams.lr_step,
            gamma=self.hparams.gamma)
        return [optimizer], [scheduler]

    @pl.data_loader
    def train_dataloader(self):
        """train dataloader"""
        logger.info('===' * 20)
        logger.info('Loading dataset for training, please wait...')
        logger.info('===' * 20)

        imagenet_path = self.hparams.imagenet_path
        alov_path = self.hparams.alov_path
        mean_file = None
        manager = Manager()
        objGoturn = GoturnDataloader(imagenet_path,
                                     alov_path,
                                     mean_file=mean_file,
                                     images_p=manager.list(),
                                     targets_p=manager.list(),
                                     bboxes_p=manager.list(),
                                     val_ratio=0.005,
                                     isTrain=True,
                                     dbg=False)
        train_loader = DataLoader(objGoturn,
                                  batch_size=self.hparams.batch_size,
                                  shuffle=True,
                                  num_workers=6,
                                  collate_fn=objGoturn.collate)

        return train_loader

    @pl.data_loader
    def val_dataloader(self):
        """validation dataloader"""
        logger.info('===' * 20)
        logger.info('Loading dataset for Validation, please wait...')
        logger.info('===' * 20)

        imagenet_path = self.hparams.imagenet_path
        alov_path = self.hparams.alov_path
        mean_file = None

        manager = Manager()
        objGoturn = GoturnDataloader(imagenet_path,
                                     alov_path,
                                     mean_file=mean_file,
                                     images_p=manager.list(),
                                     targets_p=manager.list(),
                                     bboxes_p=manager.list(),
                                     val_ratio=0.005,
                                     isTrain=False,
                                     dbg=False)
        val_loader = DataLoader(objGoturn,
                                batch_size=self.hparams.batch_size,
                                shuffle=True,
                                num_workers=6,
                                collate_fn=objGoturn.collate)
        return val_loader

    def forward(self, prev, curr):
        """forward function
        """
        pred_bb = self._model(prev.float(), curr.float())
        return pred_bb

    def vis_images(self, prev, curr, gt_bb, pred_bb, prefix='train'):
        def unnormalize(image, mean):
            image = np.transpose(image, (1, 2, 0)) + mean
            image = image.astype(np.float32)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            return image

        for i in range(0, prev.shape[0]):
            # _mean = np.load(self.hparams.mean_file)
            _mean = np.array([104, 117, 123])
            prev_img = prev[i].cpu().detach().numpy()
            curr_img = curr[i].cpu().detach().numpy()

            prev_img = unnormalize(prev_img, _mean)
            curr_img = unnormalize(curr_img, _mean)

            gt_bb_i = BoundingBox(*gt_bb[i].cpu().detach().numpy().tolist())
            gt_bb_i.unscale(curr_img)
            curr_img = draw.bbox(curr_img, gt_bb_i, color=(255, 0, 255))

            pred_bb_i = BoundingBox(
                *pred_bb[i].cpu().detach().numpy().tolist())
            pred_bb_i.unscale(curr_img)
            curr_img = draw.bbox(curr_img, pred_bb_i)

            out = np.concatenate(
                (prev_img[np.newaxis, ...], curr_img[np.newaxis, ...]), axis=0)
            out = np.transpose(out, [0, 3, 1, 2])

            self._viz.plot_images_np(out,
                                     title='sample_{}'.format(i),
                                     env='goturn_{}'.format(prefix))

    def training_step(self, batch, batch_idx):
        """Training step
        @batch: current batch data
        @batch_idx: current batch index
        """
        curr, prev, gt_bb = batch
        pred_bb = self.forward(prev, curr)
        loss = torch.nn.L1Loss(size_average=False)(pred_bb.float(),
                                                   gt_bb.float())

        if self.trainer.use_dp:
            loss = loss.unsqueeze(0)

        if self._dbg:
            if batch_idx % 1000 == 0:
                d = {'loss': loss.item()}
                iters = (self.trainer.num_training_batches -
                         1) * self.current_epoch + batch_idx
                self._viz.plot_curves(d,
                                      iters,
                                      title='Train',
                                      ylabel='train_loss')
            if batch_idx % 1000 == 0:
                self.vis_images(prev, curr, gt_bb, pred_bb)

        tqdm_dict = {'batch_loss': loss}
        output = OrderedDict({
            'loss': loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    def validation_step(self, batch, batch_idx):
        """validation step
        @batch: current batch data
        @batch_idx: current batch index
        """
        curr, prev, gt_bb = batch
        pred_bb = self.forward(prev, curr)
        loss = torch.nn.L1Loss(size_average=False)(pred_bb, gt_bb.float())

        if self.trainer.use_dp:
            loss = loss.unsqueeze(0)

        if self._dbg:
            if batch_idx % 100 == 0:
                d = {'loss': loss.item()}
                iters = (self.trainer.num_val_batches -
                         1) * self.current_epoch + batch_idx
                self._viz.plot_curves(d,
                                      iters,
                                      title='Validation',
                                      ylabel='val_loss')

            if batch_idx % 1000 == 0:
                self.vis_images(prev, curr, gt_bb, pred_bb, prefix='val')

        tqdm_dict = {'val_loss': loss}
        output = OrderedDict({
            'val_loss': loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}
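Putting the module to work requires an hparams namespace carrying every field referenced above (seed, lr, momentum, wd, lr_step, gamma, batch_size, pretrained_model, imagenet_path, alov_path) plus the sample-generator arguments added by add_model_specific_args. A hedged sketch, assuming the older PyTorch Lightning API this class targets; the default values and Trainer flags are assumptions:

# Hypothetical training driver -- argument names come from the class above.
import argparse
from pytorch_lightning import Trainer

parent = argparse.ArgumentParser(add_help=False)
parent.add_argument('--seed', type=int, default=800)
parent.add_argument('--lr', type=float, default=1e-6)
parent.add_argument('--momentum', type=float, default=0.9)
parent.add_argument('--wd', type=float, default=5e-4)
parent.add_argument('--lr_step', type=int, default=1)
parent.add_argument('--gamma', type=float, default=0.1)
parent.add_argument('--batch_size', type=int, default=32)
parent.add_argument('--pretrained_model', required=True)
parent.add_argument('--imagenet_path', required=True)
parent.add_argument('--alov_path', required=True)
parser = GoturnTrain.add_model_specific_args(parent)
hparams = parser.parse_args()

model = GoturnTrain(hparams, dbg=False)
trainer = Trainer(gpus=1)  # exact flags depend on the Lightning version
trainer.fit(model)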
Example #9
class sample_generator:
    """Generate samples from single frame"""
    def __init__(self,
                 lamda_shift,
                 lamda_scale,
                 min_scale,
                 max_scale,
                 dbg=False,
                 env='sample_generator'):
        """set parameters """

        self._lamda_shift = lamda_shift
        self._lamda_scale = lamda_scale
        self._min_scale = min_scale
        self._max_scale = max_scale
        self._kSamplesPerImage = 10  # number of synthetic samples per image

        self._viz = None
        if dbg:
            self._env = env
            self._viz = Visualizer(env=self._env)

        self._dbg = dbg

    def make_true_sample(self):
        """Generate true target:search_region pair"""

        curr_prior_tight = self.bbox_prev_gt_
        target_pad = self.target_pad_
        # The search region in the current frame is obtained by cropping
        # (with context padding) around the previous frame's bounding box.
        output = cropPadImage(curr_prior_tight, self.img_curr_, self._dbg,
                              self._viz)
        curr_search_region, curr_search_location, edge_spacing_x, edge_spacing_y = output

        bbox_curr_gt = self.bbox_curr_gt_
        bbox_curr_gt_recentered = BoundingBox(0, 0, 0, 0)
        bbox_curr_gt_recentered = bbox_curr_gt.recenter(
            curr_search_location, edge_spacing_x, edge_spacing_y,
            bbox_curr_gt_recentered)

        if self._dbg:
            env = self._env + '_make_true_sample'
            search_dbg = draw.bbox(self.img_curr_, curr_search_location)
            search_dbg = draw.bbox(search_dbg,
                                   bbox_curr_gt,
                                   color=(255, 255, 0))
            self._viz.plot_image_opencv(search_dbg, 'search_region', env=env)

            recentered_img = draw.bbox(curr_search_region,
                                       bbox_curr_gt_recentered,
                                       color=(255, 255, 0))
            self._viz.plot_image_opencv(recentered_img,
                                        'cropped_search_region',
                                        env=env)
            del recentered_img
            del search_dbg

        bbox_curr_gt_recentered.scale(curr_search_region)

        return curr_search_region, target_pad, bbox_curr_gt_recentered

    def make_training_samples(self, num_samples, images, targets,
                              bbox_gt_scales):
        """
        @num_samples: number of samples
        @images: set of num_samples appended to images list
        @target: set of num_samples targets appended to targets list
        @bbox_gt_scales: bounding box to be regressed (scaled version)
        """
        for i in range(num_samples):
            image_rand_focus, target_pad, bbox_gt_scaled = self.make_training_sample_BBShift(
            )
            images.append(image_rand_focus)
            targets.append(target_pad)
            bbox_gt_scales.append(bbox_gt_scaled)
            if self._dbg:
                self.visualize(image_rand_focus, target_pad, bbox_gt_scaled, i)

        return images, targets, bbox_gt_scales

    def visualize(self, image, target, bbox, idx):
        """
        sample generator prepares image and the respective targets (with
        bounding box). This function helps you to visualize it.

        The visualization is based on the Visdom server, please
        initialize the visdom server by running the command
        $ python -m visdom.server
        open http://localhost:8097 in your browser to visualize the
        images
        """

        if image_io._is_pil_image(image):
            image = np.asarray(image)

        if image_io._is_pil_image(target):
            target = np.asarray(target)

        target = cv2.resize(target, (227, 227))
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (227, 227))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        bbox.unscale(image)
        bbox.x1, bbox.x2, bbox.y1, bbox.y2 = int(bbox.x1), int(bbox.x2), int(
            bbox.y1), int(bbox.y2)

        image_bb = draw.bbox(image, bbox)
        out = np.concatenate(
            (target[np.newaxis, ...], image_bb[np.newaxis, ...]), axis=0)
        out = np.transpose(out, [0, 3, 1, 2])
        self._viz.plot_images_np(out,
                                 title='sample_{}'.format(idx),
                                 env=self._env + '_train')

    def get_default_bb_params(self):
        """default bb parameters"""
        default_params = bbParams(self._lamda_shift, self._lamda_scale,
                                  self._min_scale, self._max_scale)
        return default_params

    def make_training_sample_BBShift_(self, bbParams, dbg=False):
        """generate training samples based on bbparams"""

        bbox_curr_gt = self.bbox_curr_gt_
        bbox_curr_shift = BoundingBox(0, 0, 0, 0)
        bbox_curr_shift = bbox_curr_gt.shift(
            self.img_curr_, bbParams.lamda_scale, bbParams.lamda_shift,
            bbParams.min_scale, bbParams.max_scale, True, bbox_curr_shift)
        rand_search_region, rand_search_location, edge_spacing_x, edge_spacing_y = cropPadImage(
            bbox_curr_shift, self.img_curr_, dbg=self._dbg, viz=self._viz)

        bbox_curr_gt = self.bbox_curr_gt_
        bbox_gt_recentered = BoundingBox(0, 0, 0, 0)
        bbox_gt_recentered = bbox_curr_gt.recenter(rand_search_location,
                                                   edge_spacing_x,
                                                   edge_spacing_y,
                                                   bbox_gt_recentered)

        if dbg:
            env = self._env + '_make_training_sample_bbshift'
            viz = self._viz
            curr_img_bbox = draw.bbox(self.img_curr_, bbox_curr_gt)
            recentered_img = draw.bbox(rand_search_region, bbox_gt_recentered)

            viz.plot_image_opencv(curr_img_bbox, 'curr shifted bbox', env=env)
            viz.plot_image_opencv(recentered_img,
                                  'recentered shifted bbox',
                                  env=env)

        bbox_gt_recentered.scale(rand_search_region)
        bbox_gt_scaled = bbox_gt_recentered

        return rand_search_region, self.target_pad_, bbox_gt_scaled

    def make_training_sample_BBShift(self):
        """
        bb_params consists of shift, scale, min-max scale for shifting
        the current bounding box
        """
        default_bb_params = self.get_default_bb_params()
        image_rand_focus, target_pad, bbox_gt_scaled = self.make_training_sample_BBShift_(
            default_bb_params, self._dbg)

        return image_rand_focus, target_pad, bbox_gt_scaled

    def reset(self, bbox_curr, bbox_prev, img_curr, img_prev):
        """This prepares the target image with enough context (search
        region)
        @bbox_curr: current frame bounding box
        @bbox_prev: previous frame bounding box
        @img_curr: current frame
        @img_prev: previous frame
        """

        target_pad, pad_image_location, _, _ = cropPadImage(bbox_prev,
                                                            img_prev,
                                                            dbg=self._dbg,
                                                            viz=self._viz)
        self.img_curr_ = img_curr
        self.bbox_curr_gt_ = bbox_curr
        self.bbox_prev_gt_ = bbox_prev
        self.target_pad_ = target_pad  # padded crop (context around bbox_prev) from the previous frame

        if self._dbg:
            env = self._env + '_targetpad'
            search_dbg = draw.bbox(img_prev, bbox_prev, color=(0, 0, 255))
            search_dbg = draw.bbox(search_dbg, pad_image_location)
            self._viz.plot_image_opencv(search_dbg, 'target_region', env=env)
            self._viz.plot_image_opencv(target_pad,
                                        'cropped_target_region',
                                        env=env)
            del search_dbg
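A usage sketch for the generator: reset() fixes the frame pair and ground-truth boxes, make_true_sample() yields the true target/search pair, and make_training_samples() appends synthetic shifted/scaled pairs. The file paths, box coordinates, and the image_io/BoundingBox imports are assumptions for illustration.

# Hypothetical usage -- paths and coordinates are placeholders.
sample_gen = sample_generator(lamda_shift=5, lamda_scale=15,
                              min_scale=-0.4, max_scale=0.4, dbg=False)

prev_img = image_io.load('frames/0001.jpg')   # previous frame
curr_img = image_io.load('frames/0002.jpg')   # current frame
bbox_prev = BoundingBox(100, 120, 180, 200)   # ground truth in the previous frame
bbox_curr = BoundingBox(105, 122, 185, 204)   # ground truth in the current frame

# Crop the padded target from the previous frame and store the current pair.
sample_gen.reset(bbox_curr, bbox_prev, curr_img, prev_img)

images, targets, bbox_gt_scales = [], [], []

# One true target/search pair...
search, target, bbox_scaled = sample_gen.make_true_sample()
images.append(search)
targets.append(target)
bbox_gt_scales.append(bbox_scaled)

# ...plus synthetic samples with randomly shifted/scaled search regions.
sample_gen.make_training_samples(10, images, targets, bbox_gt_scales)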