Пример #1
0
 def init_target_boxes(self):
     """Build the target boxes of the initial augmented samples.

     Stores the un-shifted box in ``self.classifier_target_box`` and a
     zero-initialised ``(sample_memory_size, 4)`` buffer, whose leading rows
     hold the initial boxes, in ``self.target_boxes``.

     Returns:
         The stacked initial boxes, one row per entry of ``self.transforms``.
     """
     self.classifier_target_box = self.get_iounet_box(
         self.pos, self.target_sz, self.init_sample_pos,
         self.init_sample_scale)

     # One box per transform: the transform shift (components swapped) is
     # added to the first two box entries, the last two stay unchanged.
     shifted_boxes = TensorList()
     for transform in self.transforms:
         offset = torch.Tensor(
             [transform.shift[1], transform.shift[0], 0, 0])
         shifted_boxes.append(self.classifier_target_box + offset)

     stacked = torch.cat(shifted_boxes.view(1, 4), 0)
     stacked = stacked.to(self.params.device)

     # Memory buffer for all samples; only the first rows are filled here.
     self.target_boxes = stacked.new_zeros(self.params.sample_memory_size,
                                           4)
     num_init = stacked.shape[0]
     self.target_boxes[:num_init, :] = stacked
     return stacked
Пример #2
0
    def init_iou_net(self):
        """Freeze the IoU predictor and cache the target modulation features.

        Sets ``self.iou_predictor``, ``self.iou_target_box`` and
        ``self.target_feat`` (the modulation output averaged over the first
        dimension for every feature entry).
        """
        # Fetch the IoU predictor and exclude its parameters from training.
        predictor = self.params.features.get_unique_attribute('iou_predictor')
        self.iou_predictor = predictor
        for weight in predictor.parameters():
            weight.requires_grad = False

        # Reference box of the target.
        self.iou_target_box = self.get_iounet_box(
            self.pos, self.target_sz, self.pos.round(), self.target_scale)

        # Collect one box per usable augmentation.
        box_list = TensorList()
        if not self.params.iounet_augmentation:
            box_list.append(self.iou_target_box.clone())
        else:
            usable = (augmentation.Identity, augmentation.Translation,
                      augmentation.FlipHorizontal, augmentation.FlipVertical,
                      augmentation.Blur)
            for transform in self.transforms:
                # Stop at the first transform of any other kind.
                if not isinstance(transform, usable):
                    break
                offset = torch.Tensor(
                    [transform.shift[1], transform.shift[0], 0, 0])
                box_list.append(self.iou_target_box + offset)
        target_boxes = torch.cat(box_list.view(1, 4), 0)
        target_boxes = target_boxes.to(self.params.device)

        # Backbone features, truncated to the samples that have a box
        # (drops the remaining augmentations such as rotation).
        backbone_feats = self.get_iou_backbone_features()
        num_boxes = target_boxes.shape[0]
        backbone_feats = TensorList(
            [feat[:num_boxes, ...] for feat in backbone_feats])

        # Modulation features of the target, averaged over the samples.
        with torch.no_grad():
            modulation = self.iou_predictor.get_modulation(
                backbone_feats, target_boxes)
        self.target_feat = TensorList(
            [feat.detach().mean(0) for feat in modulation])

        # Optionally replace every feature by a constant tensor filled with
        # norm / number-of-elements of that feature.
        if getattr(self.params, 'iounet_not_use_reference', False):
            constant_feats = []
            for feat in self.target_feat:
                fill_value = feat.norm() / feat.numel()
                constant_feats.append(torch.full_like(feat, fill_value))
            self.target_feat = TensorList(constant_feats)
Пример #3
0
    def init_iou_net(self, backbone_feat):
        """Freeze the box regressor and compute the IoU modulation.

        Args:
            backbone_feat: Backbone features, forwarded unchanged to
                ``self.get_iou_backbone_features``.

        Sets ``self.classifier_target_box`` and ``self.iou_modulation``.
        """
        # The box regressor is not trained here.
        for weight in self.net.bb_regressor.parameters():
            weight.requires_grad = False

        # Reference box of the target.
        self.classifier_target_box = self.get_iounet_box(
            self.pos, self.target_sz, self.init_sample_pos,
            self.init_sample_scale)

        # Collect one box per usable augmentation.
        box_list = TensorList()
        if not self.params.iounet_augmentation:
            # Without augmentation only the first transform contributes.
            first = self.transforms[0]
            offset = torch.Tensor([first.shift[1], first.shift[0], 0, 0])
            box_list.append(self.classifier_target_box + offset)
        else:
            usable = (augmentation.Identity, augmentation.Translation,
                      augmentation.FlipHorizontal, augmentation.FlipVertical,
                      augmentation.Blur)
            for transform in self.transforms:
                # Stop at the first transform of any other kind.
                if not isinstance(transform, usable):
                    break
                offset = torch.Tensor(
                    [transform.shift[1], transform.shift[0], 0, 0])
                box_list.append(self.classifier_target_box + offset)
        target_boxes = torch.cat(box_list.view(1, 4), 0)
        target_boxes = target_boxes.to(self.params.device)

        # IoU features, truncated to the samples that have a box
        # (drops the remaining augmentations such as rotation).
        iou_feats = self.get_iou_backbone_features(backbone_feat)
        num_boxes = target_boxes.shape[0]
        iou_feats = TensorList([feat[:num_boxes, ...] for feat in iou_feats])

        # Modulation vector; tensor outputs are averaged over the samples.
        self.iou_modulation = self.get_iou_modulation(iou_feats, target_boxes)
        if torch.is_tensor(self.iou_modulation[0]):
            averaged = [mod.detach().mean(0) for mod in self.iou_modulation]
            self.iou_modulation = TensorList(averaged)
Пример #4
0
    def init_dr_net(self):
        """Freeze the box predictor and cache the target filter features.

        Sets ``self.box_predictor``, ``self.iou_target_box`` and
        ``self.target_feat`` (the ``get_filter`` output averaged over the
        first dimension for every feature entry).
        """
        # The predictor is stored under the 'iou_predictor' attribute.
        predictor = self.params.features.get_unique_attribute('iou_predictor')
        self.box_predictor = predictor
        for weight in predictor.parameters():
            weight.requires_grad = False

        # Single reference box of the target (no augmentation boxes here).
        self.iou_target_box = self.get_iounet_box(
            self.pos, self.target_sz, self.pos.round(), self.target_scale)
        box_list = TensorList()
        box_list.append(self.iou_target_box.clone())
        target_boxes = torch.cat(box_list.view(1, 4), 0)
        target_boxes = target_boxes.to(self.params.device)

        # Backbone features, truncated to the samples that have a box.
        backbone_feats = self.get_iou_backbone_features()
        num_boxes = target_boxes.shape[0]
        backbone_feats = TensorList(
            [feat[:num_boxes, ...] for feat in backbone_feats])

        # Target features from the predictor, averaged over the samples.
        with torch.no_grad():
            filter_feats = self.box_predictor.get_filter(
                backbone_feats, target_boxes)
        self.target_feat = TensorList(
            [feat.detach().mean(0) for feat in filter_feats])

        # Optionally replace every feature by a constant tensor filled with
        # norm / number-of-elements of that feature.
        if getattr(self.params, 'iounet_not_use_reference', False):
            constant_feats = []
            for feat in self.target_feat:
                fill_value = feat.norm() / feat.numel()
                constant_feats.append(torch.full_like(feat, fill_value))
            self.target_feat = TensorList(constant_feats)
Пример #5
0
class ECO(BaseTracker):
    def initialize_features(self):
        if not getattr(self, 'features_initialized', False):
            self.params.features.initialize()
        self.features_initialized = True

    def initialize(self, image, info: dict) -> dict:
        """Initialize the tracker on the first frame.

        Seeds the random generators, loads the metric network and fits the
        LOF model on positive samples around the initial box, then builds
        the ECO state (sizes, windows, label functions, projection matrix,
        sample memory) and runs the initial filter optimization.

        Args:
            image: First frame; indexed as ``image.shape[0..2]`` and
                converted with ``np.array`` / ``numpy_to_torch``.
            info: Dictionary whose ``'init_bbox'`` entry holds the initial
                box; it is read as ``[state[0], state[1], state[2],
                state[3]]`` with the last two entries used as sizes.

        Returns:
            Nothing is returned explicitly (the annotation is kept for
            interface compatibility).
        """
        # Make the run reproducible.
        init_seed = 1
        torch.manual_seed(init_seed)
        torch.cuda.manual_seed(init_seed)
        torch.cuda.manual_seed_all(init_seed)
        np.random.seed(init_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        os.environ['PYTHONHASHSEED'] = str(init_seed)
        state = info['init_bbox']

        # Initialize some stuff
        self.frame_num = 1
        if not hasattr(self.params, 'device'):
            self.params.device = 'cuda' if self.params.use_gpu else 'cpu'

        # Initialize features
        self.initialize_features()

        # metricnet
        self.metric_model = model_load(self.params.metric_model_path)
        with torch.no_grad():
            # Warm-up forward pass on random input; the output is discarded.
            # NOTE(review): the warm-up tensor is moved with .cuda()
            # regardless of self.params.device.
            tmp = np.random.rand(5, 3, 107, 107)
            tmp = torch.Tensor(tmp)
            tmp = (Variable(tmp)).type(torch.FloatTensor).cuda()
            tmp = self.metric_model(tmp)
            self.target_metric_feature = get_target_feature(
                self.metric_model, np.array(state), np.array(image))

        # Sample positive boxes around the initial box. The IoU lower bound
        # is relaxed in steps of 0.1 until at least one sample is produced.
        # The builtin ``int`` replaces ``np.int``, which was removed from
        # NumPy and would raise AttributeError.
        pos_generator = SampleGenerator(
            'gaussian', np.array([image.shape[1], image.shape[0]]), 0.1, 1.3)
        gt_pos_examples = pos_generator(
            np.array(state).astype(int), 20, [0.7, 1])
        gt_iou = 0.7
        while gt_pos_examples.shape[0] == 0:
            gt_iou = gt_iou - 0.1
            gt_pos_examples = pos_generator(
                np.array(state).astype(int), 20, [gt_iou, 1])

        with torch.no_grad():
            gt_pos_features0 = get_anchor_feature(self.metric_model,
                                                  np.array(image),
                                                  gt_pos_examples)
            gt_pos_features = gt_pos_features0.cpu().detach().numpy()
            target_metric_feature = self.target_metric_feature.repeat(
                gt_pos_features.shape[0], 1)
            # Distance of every positive sample to the target feature; the
            # scaled mean is the similarity threshold used in track().
            pos_all = torch.norm(gt_pos_features0 - target_metric_feature,
                                 2,
                                 dim=1).view(-1)
            self.similar = pos_all.mean() * self.params.sim_rate
            print('similarThresh', self.similar)
        self.target_features_all = []
        self.target_features_all.append(self.target_metric_feature)
        self.clf = lof_fit(gt_pos_features, k=5)

        # Check if image is color
        self.params.features.set_is_color(image.shape[2] == 3)

        # Get feature specific params
        self.fparams = self.params.features.get_fparams('feature_params')

        # Get position and size
        self.pos = torch.Tensor(
            [state[1] + (state[3] - 1) / 2, state[0] + (state[2] - 1) / 2])
        self.target_sz = torch.Tensor([state[3], state[2]])

        # Set search area
        self.target_scale = 1.0
        search_area = torch.prod(self.target_sz *
                                 self.params.search_area_scale).item()
        if search_area > self.params.max_image_sample_size:
            self.target_scale = math.sqrt(search_area /
                                          self.params.max_image_sample_size)
        elif search_area < self.params.min_image_sample_size:
            self.target_scale = math.sqrt(search_area /
                                          self.params.min_image_sample_size)

        # Target size in base scale
        self.base_target_sz = self.target_sz / self.target_scale

        # Use odd square search area and set sizes
        feat_max_stride = max(self.params.features.stride())
        self.img_sample_sz = torch.round(
            torch.sqrt(
                torch.prod(self.base_target_sz *
                           self.params.search_area_scale))) * torch.ones(2)
        self.img_sample_sz += feat_max_stride - self.img_sample_sz % (
            2 * feat_max_stride)

        # Set other sizes (corresponds to ECO code)
        self.img_support_sz = self.img_sample_sz
        self.feature_sz = self.params.features.size(self.img_sample_sz)
        self.filter_sz = self.feature_sz + (self.feature_sz + 1) % 2
        # Interpolated size of the output
        self.output_sz = self.params.score_upsample_factor * self.img_support_sz
        self.compressed_dim = self.fparams.attribute('compressed_dim')

        # Number of filters
        self.num_filters = len(self.filter_sz)

        # Get window function
        self.window = TensorList(
            [dcf.hann2d(sz).to(self.params.device) for sz in self.feature_sz])

        # Get interpolation function
        self.interp_fs = TensorList([
            dcf.get_interp_fourier(sz, self.params.interpolation_method,
                                   self.params.interpolation_bicubic_a,
                                   self.params.interpolation_centering,
                                   self.params.interpolation_windowing,
                                   self.params.device) for sz in self.filter_sz
        ])

        # Get regularization filter
        self.reg_filter = TensorList([
            dcf.get_reg_filter(self.img_support_sz, self.base_target_sz,
                               fparams).to(self.params.device)
            for fparams in self.fparams
        ])
        self.reg_energy = self.reg_filter.view(-1) @ self.reg_filter.view(-1)

        # Get label function
        output_sigma_factor = self.fparams.attribute('output_sigma_factor')
        sigma = (self.filter_sz / self.img_support_sz) * torch.sqrt(
            self.base_target_sz.prod()) * output_sigma_factor
        self.yf = TensorList([
            dcf.label_function(sz, sig).to(self.params.device)
            for sz, sig in zip(self.filter_sz, sigma)
        ])

        # Optimization options
        self.params.precond_learning_rate = self.fparams.attribute(
            'learning_rate')
        if self.params.CG_forgetting_rate is None or max(
                self.params.precond_learning_rate) >= 1:
            self.params.direction_forget_factor = 0
        else:
            self.params.direction_forget_factor = (
                1 - max(self.params.precond_learning_rate)
            )**self.params.CG_forgetting_rate

        # Convert image
        im = numpy_to_torch(image)

        # Setup bounds
        self.image_sz = torch.Tensor([im.shape[2], im.shape[3]])
        self.min_scale_factor = torch.max(10 / self.base_target_sz)
        self.max_scale_factor = torch.min(self.image_sz / self.base_target_sz)

        # Extract and transform sample
        x = self.generate_init_samples(im)

        # Initialize projection matrix from the leading left singular
        # vectors of the mean-centred feature covariance.
        x_mat = TensorList(
            [e.permute(1, 0, 2, 3).reshape(e.shape[1], -1).clone() for e in x])
        x_mat -= x_mat.mean(dim=1, keepdim=True)
        cov_x = x_mat @ x_mat.t()
        self.projection_matrix = TensorList([
            torch.svd(C)[0][:, :cdim].clone()
            for C, cdim in zip(cov_x, self.compressed_dim)
        ])

        # Transform to get the training sample
        train_xf = self.preprocess_sample(x)

        # Shift the samples back
        if 'shift' in self.params.augmentation:
            for xf in train_xf:
                # A single sample means no augmented samples were kept.
                if xf.shape[0] == 1:
                    continue
                for i, shift in enumerate(self.params.augmentation['shift']):
                    shift_samp = 2 * math.pi * torch.Tensor(
                        shift) / self.img_support_sz
                    xf[1 + i:2 + i, ...] = fourier.shift_fs(xf[1 + i:2 + i,
                                                               ...],
                                                            shift=shift_samp)

        # Shift sample by the sub-pixel part of the position
        shift_samp = 2 * math.pi * (self.pos - self.pos.round()) / (
            self.target_scale * self.img_support_sz)
        train_xf = fourier.shift_fs(train_xf, shift=shift_samp)

        # Initialize first-frame training samples
        num_init_samples = train_xf.size(0)
        self.init_sample_weights = TensorList(
            [xf.new_ones(1) / xf.shape[0] for xf in train_xf])
        self.init_training_samples = train_xf.permute(2, 3, 0, 1, 4)

        # Sample counters and weights
        self.num_stored_samples = num_init_samples
        self.previous_replace_ind = [None] * len(self.num_stored_samples)
        self.sample_weights = TensorList(
            [xf.new_zeros(self.params.sample_memory_size) for xf in train_xf])
        for sw, init_sw, num in zip(self.sample_weights,
                                    self.init_sample_weights,
                                    num_init_samples):
            sw[:num] = init_sw

        # Initialize memory
        self.training_samples = TensorList([
            xf.new_zeros(xf.shape[2], xf.shape[3],
                         self.params.sample_memory_size, cdim, 2)
            for xf, cdim in zip(train_xf, self.compressed_dim)
        ])

        # Initialize filter
        self.filter = TensorList([
            xf.new_zeros(1, cdim, xf.shape[2], xf.shape[3], 2)
            for xf, cdim in zip(train_xf, self.compressed_dim)
        ])

        # Do joint optimization
        self.joint_problem = FactorizedConvProblem(self.init_training_samples,
                                                   self.yf, self.reg_filter,
                                                   self.projection_matrix,
                                                   self.params,
                                                   self.init_sample_weights)
        joint_var = self.filter.concat(self.projection_matrix)
        self.joint_optimizer = GaussNewtonCG(self.joint_problem,
                                             joint_var,
                                             debug=(self.params.debug >= 1),
                                             visdom=self.visdom)

        if self.params.update_projection_matrix:
            self.joint_optimizer.run(
                self.params.init_CG_iter // self.params.init_GN_iter,
                self.params.init_GN_iter)

        # Re-project samples with the new projection matrix
        compressed_samples = complex.mtimes(self.init_training_samples,
                                            self.projection_matrix)
        for train_samp, init_samp in zip(self.training_samples,
                                         compressed_samples):
            train_samp[:, :, :init_samp.shape[2], :, :] = init_samp

        # Initialize optimizer
        self.filter_optimizer = FilterOptim(self.params, self.reg_energy)
        self.filter_optimizer.register(self.filter, self.training_samples,
                                       self.yf, self.sample_weights,
                                       self.reg_filter)
        self.filter_optimizer.sample_energy = self.joint_problem.sample_energy
        self.filter_optimizer.residuals = self.joint_optimizer.residuals.clone(
        )

        # Without a joint run the filter still needs its initial iterations.
        if not self.params.update_projection_matrix:
            self.filter_optimizer.run(self.params.init_CG_iter)

        # Post optimization
        self.filter_optimizer.run(self.params.post_init_CG_iter)

        self.symmetrize_filter()

        # metricnet_lof: per-frame buffers that track() fills and consumes.
        self.current_target_metric_feature = []
        self.train_xf = []

        self.lof_thresh = self.params.lof_rate

    def track(self, image) -> dict:
        """Localize the target in a new frame and update the model.

        The frame's training sample and metric feature are buffered every
        call. When ``frame_num % train_skipping == 1`` the buffered samples
        are filtered with the LOF model and the similarity threshold, the
        accepted ones are written to the sample memory, and the filter is
        re-optimized.

        Args:
            image: Current frame; converted with ``numpy_to_torch`` and
                ``np.array``.

        Returns:
            ``{'target_bbox': [...]}`` holding the list built from
            ``pos[[1, 0]] - (target_sz[[1, 0]] - 1) / 2`` followed by
            ``target_sz[[1, 0]]``.
        """
        self.debug_info = {}

        self.frame_num += 1
        self.debug_info['frame_num'] = self.frame_num

        # Convert image
        im = numpy_to_torch(image)

        # ------- LOCALIZATION ------- #

        # Get sample
        sample_pos = self.pos.round()
        sample_scales = self.target_scale * self.params.scale_factors
        test_xf = self.extract_fourier_sample(im, self.pos, sample_scales,
                                              self.img_sample_sz)

        # Compute scores
        sf = self.apply_filter(test_xf)
        translation_vec, scale_ind, s = self.localize_target(sf)
        scale_change_factor = self.params.scale_factors[scale_ind]

        # Update position and scale
        self.update_state(sample_pos + translation_vec,
                          self.target_scale * scale_change_factor)

        score_map = s[scale_ind, ...]
        max_score = torch.max(score_map).item()
        self.debug_info['max_score'] = max_score

        if self.visdom is not None:
            self.visdom.register(score_map, 'heatmap', 2, 'Score Map')
            self.visdom.register(self.debug_info, 'info_dict', 1, 'Status')
        elif self.params.debug >= 2:
            show_tensor(score_map,
                        5,
                        title='Max score = {:.2f}'.format(max_score))

        # Metric feature of the newly estimated box, buffered for the next
        # update step.
        state_tmp = torch.cat(
            (self.pos[[1, 0]] - (self.target_sz[[1, 0]] - 1) / 2,
             self.target_sz[[1, 0]]))
        state_tmp = state_tmp.numpy()
        with torch.no_grad():
            self.current_target_metric_feature.append(
                get_target_feature(self.metric_model, state_tmp,
                                   np.array(image)).cpu().detach().numpy())

        # ------- UPDATE ------- #

        # Get train sample
        train_xf = TensorList(
            [xf[scale_ind:scale_ind + 1, ...] for xf in test_xf])

        # Shift the sample
        shift_samp = 2 * math.pi * (self.pos - sample_pos) / (
            sample_scales[scale_ind] * self.img_support_sz)
        train_xf = fourier.shift_fs(train_xf, shift=shift_samp)

        self.train_xf.append(train_xf)

        if self.frame_num == 1:
            # NOTE(review): frame_num is incremented at the top of this
            # method, so this branch only runs if it was 0 beforehand.
            self.update_memory(train_xf)
            self.filter_optimizer.run(self.params.CG_iter, train_xf)
            self.symmetrize_filter()
        elif self.frame_num % self.params.train_skipping == 1:
            current_target_metric_feature = np.array(
                self.current_target_metric_feature).squeeze()
            current_target_metric_feature0 = torch.from_numpy(
                current_target_metric_feature).cuda()
            lof_predict, success = lof(current_target_metric_feature,
                                       self.clf,
                                       k=5,
                                       thresh=self.lof_thresh)
            last_id = -1

            # During the first update round every buffered sample is
            # accepted and the LOF threshold is calibrated from its scores.
            first_round = self.frame_num <= self.params.train_skipping + 1
            if first_round:
                self.lof_thresh = lof_predict.mean() * self.params.lof_rate
                print('lof_thresh:', self.lof_thresh)

            for ii, buffered_xf in enumerate(self.train_xf):
                if not first_round and success[ii]:
                    # Reject the sample if it is closer than self.similar
                    # to any stored feature (newest first). One match is
                    # enough, so stop scanning as soon as it is found.
                    for stored_feature in reversed(self.target_features_all):
                        dist = torch.norm(
                            stored_feature -
                            current_target_metric_feature0[ii].reshape(
                                [1, 1024]),
                            2,
                            dim=1).view(-1)
                        if dist < self.similar:
                            success[ii] = 0
                            break
                if first_round or success[ii]:
                    self.target_features_all.append(
                        current_target_metric_feature0[ii].reshape([1, 1024]))
                    last_id = ii
                    self.update_memory(buffered_xf)

            # Re-optimize only when at least one sample was accepted.
            if last_id > -1:
                self.filter_optimizer.run(self.params.CG_iter,
                                          self.train_xf[last_id])
                self.symmetrize_filter()

            # Start a fresh buffer for the next update round.
            self.current_target_metric_feature = []
            self.train_xf = []

        # Return new state
        new_state = torch.cat(
            (self.pos[[1, 0]] - (self.target_sz[[1, 0]] - 1) / 2,
             self.target_sz[[1, 0]]))

        out = {'target_bbox': new_state.tolist()}
        return out

    def apply_filter(self, sample_xf: TensorList) -> torch.Tensor:
        """Multiply the filter with the sample and sum over dimension 1.

        Args:
            sample_xf: Sample passed to ``complex.mult`` together with
                ``self.filter``.

        Returns:
            The product summed over dimension 1 with that dimension kept.
        """
        filtered = complex.mult(self.filter, sample_xf)
        return filtered.sum(1, keepdim=True)

    def localize_target(self, sf: TensorList):
        """Find the score maximum and turn it into a translation.

        Args:
            sf: Filter responses, as produced by ``apply_filter``.

        Returns:
            Tuple ``(translation_vec, scale_ind, scores)``: the displacement
            at the best scale multiplied by the support/output size ratio
            and the target scale, the index of the best scale, and the
            sampled score maps.

        Raises:
            ValueError: If ``params.score_fusion_strategy`` is not one of
                ``'sum'``, ``'weightedsum'`` or ``'transcale'``.
        """
        strategy = self.params.score_fusion_strategy

        if strategy == 'sum':
            scores = fourier.sample_fs(fourier.sum_fs(sf), self.output_sz)
        elif strategy == 'weightedsum':
            weight = self.fparams.attribute('translation_weight')
            scores = fourier.sample_fs(fourier.sum_fs(weight * sf),
                                       self.output_sz)
        elif strategy == 'transcale':
            alpha = self.fparams.attribute('scale_weight')
            beta = self.fparams.attribute('translation_weight')

            # One sampling size per scale factor.
            sample_sz = torch.round(
                self.output_sz.view(1, -1) *
                self.params.scale_factors.view(-1, 1))

            scores = 0
            for sfe, a, b in zip(sf, alpha, beta):
                sfe = fourier.shift_fs(sfe, math.pi * torch.ones(2))

                # Sample every scale at its own size and pad the result
                # with (output_sz - size) / 2 on each side.
                padded_scores = []
                for sind, sz in enumerate(sample_sz):
                    pd = (self.output_sz - sz) / 2
                    sampled = fourier.sample_fs(sfe[sind:sind + 1, ...], sz)
                    pad_first = pd[0].item()
                    pad_second = pd[1].item()
                    padding = (math.floor(pad_second), math.ceil(pad_second),
                               math.floor(pad_first), math.ceil(pad_first))
                    padded_scores.append(F.pad(sampled, padding))

                scores_cat = torch.cat(padded_scores)
                scores = scores + (b - a) * scores_cat.mean(
                    dim=0, keepdim=True) + a * scores_cat
        else:
            raise ValueError('Unknown score fusion strategy.')

        # Get maximum
        max_score, max_disp = dcf.max2d(scores)
        _, scale_ind = torch.max(max_score, dim=0)
        max_disp = max_disp.float().cpu()

        # Convert to displacements in the base scale
        is_sum_strategy = strategy in ['sum', 'weightedsum']
        half_output = self.output_sz / 2
        if is_sum_strategy:
            # Wrap the displacement into [-output_sz / 2, output_sz / 2).
            disp = (max_disp + half_output) % self.output_sz - half_output
        elif strategy == 'transcale':
            disp = max_disp - half_output

        # Compute translation vector and scale change factor
        translation_vec = disp[scale_ind, ...].view(-1) * (
            self.img_support_sz / self.output_sz) * self.target_scale
        if is_sum_strategy:
            translation_vec *= self.params.scale_factors[scale_ind]

        return translation_vec, scale_ind, scores

    def extract_sample(self, im: torch.Tensor, pos: torch.Tensor, scales,
                       sz: torch.Tensor):
        return self.params.features.extract(im, pos, scales, sz)[0]

    def extract_fourier_sample(self, im: torch.Tensor, pos: torch.Tensor,
                               scales, sz: torch.Tensor) -> TensorList:
        """Extract a sample, project it and preprocess it.

        Chains ``extract_sample``, ``project_sample`` and
        ``preprocess_sample`` in that order and returns the final result.
        """
        raw_sample = self.extract_sample(im, pos, scales, sz)
        projected = self.project_sample(raw_sample)
        return self.preprocess_sample(projected)

    def preprocess_sample(self, x: TensorList) -> TensorList:
        """Window the sample, transform it and interpolate the result.

        Note that ``x`` is multiplied by ``self.window`` in place.

        Args:
            x: Sample features; modified by the windowing step.

        Returns:
            A ``TensorList`` with ``dcf.interpolate_dft`` applied to every
            ``fourier.cfft2`` output and its entry of ``self.interp_fs``.
        """
        x *= self.window
        spectra = fourier.cfft2(x)
        interpolated = []
        for spectrum, interp in zip(spectra, self.interp_fs):
            interpolated.append(dcf.interpolate_dft(spectrum, interp))
        return TensorList(interpolated)

    def project_sample(self, x: TensorList):
        """Apply ``self.projection_matrix`` to the sample features.

        Args:
            x: Sample features handed to the ``tensor_operation`` helper
                together with the projection matrix.

        Returns:
            The projected features; an entry is returned unchanged when its
            projection matrix is ``None``.
        """
        @tensor_operation
        def _apply_projection(feat: torch.Tensor, proj: torch.Tensor):
            if proj is not None:
                # Move dimension 1 last for the matmul, then restore the
                # original order (the permutation is its own inverse).
                rearranged = feat.permute(2, 3, 0, 1)
                return torch.matmul(rearranged, proj).permute(2, 3, 0, 1)
            return feat

        return _apply_projection(x, self.projection_matrix)

    def generate_init_samples(self, im: torch.Tensor) -> TensorList:
        """Extract the initial samples including the configured augmentations.

        Args:
            im: Image forwarded to ``params.features.extract_transformed``.

        Returns:
            The extracted samples. Features with ``use_augmentation`` off
            keep only their first sample; with ``'dropout'`` configured the
            remaining ones get extra dropout copies of their first sample.
        """
        aug_params = self.params.augmentation

        # The identity transform always comes first, followed by the
        # configured shift, flip, rotation and blur transforms.
        transforms = [augmentation.Identity()]
        if 'shift' in aug_params:
            for shift in aug_params['shift']:
                transforms.append(augmentation.Translation(shift))
        if 'fliplr' in aug_params and aug_params['fliplr']:
            transforms.append(augmentation.FlipHorizontal())
        if 'rotate' in aug_params:
            for angle in aug_params['rotate']:
                transforms.append(augmentation.Rotate(angle))
        if 'blur' in aug_params:
            for sigma in aug_params['blur']:
                transforms.append(augmentation.Blur(sigma))

        init_samples = self.params.features.extract_transformed(
            im, self.pos, self.target_scale, self.img_sample_sz, transforms)

        # Features that do not use augmentation keep the first sample only.
        for i, use_aug in enumerate(
                self.fparams.attribute('use_augmentation')):
            if not use_aug:
                init_samples[i] = init_samples[i][0:1, ...]

        # Append `num` dropout copies of the first sample where enabled.
        if 'dropout' in aug_params:
            num, prob = aug_params['dropout']
            for i, use_aug in enumerate(
                    self.fparams.attribute('use_augmentation')):
                if not use_aug:
                    continue
                repeated = init_samples[i][0:1, ...].expand(num, -1, -1, -1)
                dropped = F.dropout2d(repeated, p=prob, training=True)
                init_samples[i] = torch.cat([init_samples[i], dropped])

        return init_samples

    def update_memory(self, sample_xf: TensorList):

        # Update weights and get index to replace
        replace_ind = self.update_sample_weights()
        for train_samp, xf, ind in zip(self.training_samples, sample_xf,
                                       replace_ind):
            train_samp[:, :, ind:ind + 1, :, :] = xf.permute(2, 3, 0, 1, 4)

    def update_sample_weights(self):

        replace_ind = []
        for sw, prev_ind, num_samp, fparams in zip(self.sample_weights,
                                                   self.previous_replace_ind,
                                                   self.num_stored_samples,
                                                   self.fparams):
            if num_samp == 0 or fparams.learning_rate == 1:
                sw[:] = 0
                sw[0] = 1
                r_ind = 0
            else:
                # Get index to replace
                _, r_ind = torch.min(sw, 0)
                r_ind = r_ind.item()

                # Update weights
                if prev_ind is None:
                    sw /= 1 - fparams.learning_rate
                    sw[r_ind] = fparams.learning_rate
                else:
                    sw[r_ind] = sw[prev_ind] / (1 - fparams.learning_rate)

            sw /= sw.sum()
            replace_ind.append(r_ind)

        self.previous_replace_ind = replace_ind.copy()
        self.num_stored_samples += 1
        return replace_ind

    def update_state(self, new_pos, new_scale):
        # Update scale
        self.target_scale = new_scale.clamp(self.min_scale_factor,
                                            self.max_scale_factor)
        self.target_sz = self.base_target_sz * self.target_scale

        # Update pos
        inside_ratio = 0.2
        inside_offset = (inside_ratio - 0.5) * self.target_sz
        self.pos = torch.max(torch.min(new_pos, self.image_sz - inside_offset),
                             inside_offset)

    def symmetrize_filter(self):
        """Symmetrize index 0 of dimension 3 of every filter, in place.

        The slice is halved and then the conjugate of its copy flipped
        along dimension 2 is added to it.
        """
        for filt in self.filter:
            # Basic indexing yields a view, so the in-place operations
            # below write straight into the filter tensor.
            first_slice = filt[:, :, :, 0, :]
            first_slice /= 2
            first_slice += complex.conj(first_slice.flip((2, )))
Пример #6
0
class ECO(BaseTracker):

    def initialize_features(self):
        """Initialise the feature extractor the first time this is called.

        Later calls skip ``self.params.features.initialize()`` because the
        ``features_initialized`` flag is set to True on every call.
        """
        already_done = getattr(self, 'features_initialized', False)
        if not already_done:
            self.params.features.initialize()
        self.features_initialized = True


    def initialize(self, image, state, *args, **kwargs):
        """Set up the tracker state and learn the initial filter on the first frame.

        Args:
            image: First frame. ``image.shape[2] == 3`` is used as the colour
                test and the frame is converted with ``numpy_to_torch``.
            state: Initial target box, indexed as ``state[0]`` .. ``state[3]``
                (presumably x, y, width, height -- TODO confirm with caller).
            *args: Not used.
            **kwargs: Not used.
        """

        # Initialize some stuff
        self.frame_num = 1
        # Only derive the device from use_gpu when none was configured
        if not hasattr(self.params, 'device'):
            self.params.device = 'cuda' if self.params.use_gpu else 'cpu'

        # Initialize features
        self.initialize_features()

        # Check if image is color
        self.params.features.set_is_color(image.shape[2] == 3)

        # Get feature specific params
        self.fparams = self.params.features.get_fparams('feature_params')

        # Get position and size
        # (box centre and box size, each with the state[1]/state[3] component first)
        self.pos = torch.Tensor([state[1] + (state[3] - 1)/2, state[0] + (state[2] - 1)/2])
        self.target_sz = torch.Tensor([state[3], state[2]])

        # Set search area
        # (target_scale stays 1.0 unless the scaled target area falls outside
        # [min_image_sample_size, max_image_sample_size])
        self.target_scale = 1.0
        search_area = torch.prod(self.target_sz * self.params.search_area_scale).item()
        if search_area > self.params.max_image_sample_size:
            self.target_scale =  math.sqrt(search_area / self.params.max_image_sample_size)
        elif search_area < self.params.min_image_sample_size:
            self.target_scale =  math.sqrt(search_area / self.params.min_image_sample_size)

        # Target size in base scale
        self.base_target_sz = self.target_sz / self.target_scale

        # Use odd square search area and set sizes
        # (the += below turns the side length into an odd multiple of feat_max_stride)
        feat_max_stride = max(self.params.features.stride())
        self.img_sample_sz = torch.round(torch.sqrt(torch.prod(self.base_target_sz * self.params.search_area_scale))) * torch.ones(2)
        self.img_sample_sz += feat_max_stride - self.img_sample_sz % (2 * feat_max_stride)

        # Set other sizes (corresponds to ECO code)
        self.img_support_sz = self.img_sample_sz
        self.feature_sz = self.params.features.size(self.img_sample_sz)
        # Add one to even feature sizes so that every filter size is odd
        self.filter_sz = self.feature_sz + (self.feature_sz + 1) % 2
        self.output_sz = self.params.score_upsample_factor * self.img_support_sz    # Interpolated size of the output
        self.compressed_dim = self.fparams.attribute('compressed_dim')

        # Number of filters
        self.num_filters = len(self.filter_sz)

        # Get window function
        self.window = TensorList([dcf.hann2d(sz).to(self.params.device) for sz in self.feature_sz])

        # Get interpolation function
        self.interp_fs = TensorList([dcf.get_interp_fourier(sz, self.params.interpolation_method,
                                                self.params.interpolation_bicubic_a, self.params.interpolation_centering,
                                                self.params.interpolation_windowing, self.params.device) for sz in self.filter_sz])

        # Get regularization filter
        self.reg_filter = TensorList([dcf.get_reg_filter(self.img_support_sz, self.base_target_sz, fparams).to(self.params.device)
                                      for fparams in self.fparams])
        # Inner product of each flattened regularization filter with itself
        self.reg_energy = self.reg_filter.view(-1) @ self.reg_filter.view(-1)

        # Get label function
        output_sigma_factor = self.fparams.attribute('output_sigma_factor')
        sigma = (self.filter_sz / self.img_support_sz) * torch.sqrt(self.base_target_sz.prod()) * output_sigma_factor
        self.yf = TensorList([dcf.label_function(sz, sig).to(self.params.device) for sz, sig in zip(self.filter_sz, sigma)])

        # Optimization options
        self.params.precond_learning_rate = self.fparams.attribute('learning_rate')
        if self.params.CG_forgetting_rate is None or max(self.params.precond_learning_rate) >= 1:
            self.params.direction_forget_factor = 0
        else:
            self.params.direction_forget_factor = (1 - max(self.params.precond_learning_rate))**self.params.CG_forgetting_rate


        # Convert image
        im = numpy_to_torch(image)

        # Setup bounds
        # (these scale limits are the clamp range used by update_state)
        self.image_sz = torch.Tensor([im.shape[2], im.shape[3]])
        self.min_scale_factor = torch.max(10 / self.base_target_sz)
        self.max_scale_factor = torch.min(self.image_sz / self.base_target_sz)

        # Extract and transform sample
        x = self.generate_init_samples(im)

        # Initialize projection matrix
        # (leading left-singular vectors of the mean-centred feature covariance,
        # one matrix per feature, truncated to compressed_dim columns)
        x_mat = TensorList([e.permute(1,0,2,3).reshape(e.shape[1], -1).clone() for e in x])
        x_mat -= x_mat.mean(dim=1, keepdim=True)
        cov_x = x_mat @ x_mat.t()
        self.projection_matrix = TensorList([torch.svd(C)[0][:,:cdim].clone() for C, cdim in zip(cov_x, self.compressed_dim)])

        # Transform to get the training sample
        # NOTE: preprocess_sample multiplies x by the window in place
        train_xf = self.preprocess_sample(x)

        # Shift the samples back
        # (features holding a single sample are skipped; sample 0 is left as is
        # and samples 1.. are shifted, one per configured 'shift' entry)
        if 'shift' in self.params.augmentation:
            for xf in train_xf:
                if xf.shape[0] == 1:
                    continue
                for i, shift in enumerate(self.params.augmentation['shift']):
                    shift_samp = 2 * math.pi * torch.Tensor(shift) / self.img_support_sz
                    xf[1+i:2+i,...] = fourier.shift_fs(xf[1+i:2+i,...], shift=shift_samp)

        # Shift sample
        # (compensates for extracting the patch at the rounded position)
        shift_samp = 2*math.pi * (self.pos - self.pos.round()) / (self.target_scale * self.img_support_sz)
        train_xf = fourier.shift_fs(train_xf, shift=shift_samp)

        # Initialize first-frame training samples
        num_init_samples = train_xf.size(0)
        self.init_sample_weights = TensorList([xf.new_ones(1) / xf.shape[0] for xf in train_xf])
        self.init_training_samples = train_xf.permute(2, 3, 0, 1, 4)


        # Sample counters and weights
        self.num_stored_samples = num_init_samples
        self.previous_replace_ind = [None]*len(self.num_stored_samples)
        self.sample_weights = TensorList([xf.new_zeros(self.params.sample_memory_size) for xf in train_xf])
        for sw, init_sw, num in zip(self.sample_weights, self.init_sample_weights, num_init_samples):
            sw[:num] = init_sw

        # Initialize memory
        self.training_samples = TensorList(
            [xf.new_zeros(xf.shape[2], xf.shape[3], self.params.sample_memory_size, cdim, 2) for xf, cdim in zip(train_xf, self.compressed_dim)])

        # Initialize filter
        self.filter = TensorList(
            [xf.new_zeros(1, cdim, xf.shape[2], xf.shape[3], 2) for xf, cdim in zip(train_xf, self.compressed_dim)])

        # Do joint optimization
        # (the optimized variable is the filter concatenated with the projection matrix)
        self.joint_problem = FactorizedConvProblem(self.init_training_samples, self.yf, self.reg_filter, self.projection_matrix, self.params, self.init_sample_weights)
        joint_var = self.filter.concat(self.projection_matrix)
        self.joint_optimizer = GaussNewtonCG(self.joint_problem, joint_var, debug=(self.params.debug>=3))

        if self.params.update_projection_matrix:
            self.joint_optimizer.run(self.params.init_CG_iter // self.params.init_GN_iter, self.params.init_GN_iter)

        # Re-project samples with the new projection matrix
        compressed_samples = complex.mtimes(self.init_training_samples, self.projection_matrix)
        for train_samp, init_samp in zip(self.training_samples, compressed_samples):
            train_samp[:,:,:init_samp.shape[2],:,:] = init_samp

        # Initialize optimizer
        self.filter_optimizer = FilterOptim(self.params, self.reg_energy)
        self.filter_optimizer.register(self.filter, self.training_samples, self.yf, self.sample_weights, self.reg_filter)
        self.filter_optimizer.sample_energy = self.joint_problem.sample_energy
        self.filter_optimizer.residuals = self.joint_optimizer.residuals.clone()

        # Without the joint optimization, spend the initial CG iterations on the filter alone
        if not self.params.update_projection_matrix:
            self.filter_optimizer.run(self.params.init_CG_iter)

        # Post optimization
        self.filter_optimizer.run(self.params.post_init_CG_iter)

        self.symmetrize_filter()



    def track(self, image):
        """Localize the target in a new frame, then update the sample memory and filter.

        Args:
            image: The new frame; converted with ``numpy_to_torch``.

        Returns:
            list: Four values, ``self.pos[[1, 0]] - (self.target_sz[[1, 0]] - 1) / 2``
            followed by ``self.target_sz[[1, 0]]``, i.e. the same layout that
            ``initialize`` reads from ``state``.
        """

        self.frame_num += 1

        # Convert image
        im = numpy_to_torch(image)

        # ------- LOCALIZATION ------- #

        # Get sample (one patch per scale factor, centred on the rounded position)
        sample_pos = self.pos.round()
        sample_scales = self.target_scale * self.params.scale_factors
        test_xf = self.extract_fourier_sample(im, sample_pos, sample_scales, self.img_sample_sz)

        # Compute scores
        sf = self.apply_filter(test_xf)
        translation_vec, scale_ind, s = self.localize_target(sf)
        scale_change_factor = self.params.scale_factors[scale_ind]

        # Update position and scale
        self.update_state(sample_pos + translation_vec, self.target_scale * scale_change_factor)

        if self.params.debug >= 2:
            show_tensor(s[scale_ind,...], 5)
        if self.params.debug >= 3:
            for i, hf in enumerate(self.filter):
                show_tensor(fourier.sample_fs(hf).abs().mean(1), 6+i)

        # ------- UPDATE ------- #

        # Get train sample (the test sample at the selected scale)
        train_xf = TensorList([xf[scale_ind:scale_ind+1, ...] for xf in test_xf])

        # Shift the sample so that it is centred on the updated position
        shift_samp = 2*math.pi * (self.pos - sample_pos) / (sample_scales[scale_ind] * self.img_support_sz)
        train_xf = fourier.shift_fs(train_xf, shift=shift_samp)

        # Update memory
        self.update_memory(train_xf)

        # Train filter
        # Written as (n - 1) % k == 0 instead of n % k == 1: both select the
        # same frames for k >= 2, but only this form also trains on every
        # frame when train_skipping == 1 (n % 1 is always 0, never 1).
        if (self.frame_num - 1) % self.params.train_skipping == 0:
            self.filter_optimizer.run(self.params.CG_iter, train_xf)
            self.symmetrize_filter()

        # Return new state
        new_state = torch.cat((self.pos[[1,0]] - (self.target_sz[[1,0]]-1)/2, self.target_sz[[1,0]]))

        return new_state.tolist()


    def apply_filter(self, sample_xf: TensorList) -> torch.Tensor:
        """Multiply the filter with a sample and sum the product over dimension 1.

        Args:
            sample_xf: Sample to score, passed to ``complex.mult`` together
                with ``self.filter``.

        Returns:
            The product summed over dimension 1 with ``keepdim=True``.
        """
        response = complex.mult(self.filter, sample_xf)
        return response.sum(1, keepdim=True)

    def localize_target(self, sf: TensorList):
        """Fuse the per-feature scores and locate their maximum.

        The fusion is selected by ``self.params.score_fusion_strategy``
        (``'sum'``, ``'weightedsum'`` or ``'transcale'``).

        Args:
            sf: Filter responses, one entry per feature.

        Returns:
            tuple: ``(translation_vec, scale_ind, scores)`` -- the displacement
            scaled by ``img_support_sz / output_sz`` and ``target_scale``, the
            index of the best-scoring scale, and the fused score maps.

        Raises:
            ValueError: If the fusion strategy is not one of the three above.
        """
        strategy = self.params.score_fusion_strategy

        if strategy == 'sum':
            scores = fourier.sample_fs(fourier.sum_fs(sf), self.output_sz)
        elif strategy == 'weightedsum':
            weight = self.fparams.attribute('translation_weight')
            scores = fourier.sample_fs(fourier.sum_fs(weight * sf), self.output_sz)
        elif strategy == 'transcale':
            scale_weights = self.fparams.attribute('scale_weight')
            translation_weights = self.fparams.attribute('translation_weight')
            # One output size per scale factor
            per_scale_sz = torch.round(self.output_sz.view(1, -1) * self.params.scale_factors.view(-1, 1))
            scores = 0
            for feat_sf, a, b in zip(sf, scale_weights, translation_weights):
                feat_sf = fourier.shift_fs(feat_sf, math.pi * torch.ones(2))
                padded = []
                for k, sz in enumerate(per_scale_sz):
                    # Pad every sampled map up to the common output size;
                    # F.pad takes the padding of the last dimension first.
                    margin = (self.output_sz - sz) / 2
                    pad_last = margin[1].item()
                    pad_second_last = margin[0].item()
                    sampled = fourier.sample_fs(feat_sf[k:k + 1, ...], sz)
                    padded.append(F.pad(sampled,
                                        (math.floor(pad_last), math.ceil(pad_last),
                                         math.floor(pad_second_last), math.ceil(pad_second_last))))
                stacked = torch.cat(padded)
                scores = scores + (b - a) * stacked.mean(dim=0, keepdim=True) + a * stacked
        else:
            raise ValueError('Unknown score fusion strategy.')

        # Any strategy that reaches this point is either a plain sum or 'transcale'
        plain_sum = strategy in ('sum', 'weightedsum')

        # Get maximum
        max_score, max_disp = dcf.max2d(scores)
        _, scale_ind = torch.max(max_score, dim=0)
        max_disp = max_disp.float().cpu()

        # Convert to displacements in the base scale
        half_output = self.output_sz / 2
        if plain_sum:
            # Wrap the index into the range [-output_sz / 2, output_sz / 2)
            disp = (max_disp + half_output) % self.output_sz - half_output
        else:
            disp = max_disp - half_output

        # Compute translation vector and scale change factor
        to_image = (self.img_support_sz / self.output_sz)
        translation_vec = disp[scale_ind, ...].view(-1) * to_image * self.target_scale
        if plain_sum:
            translation_vec *= self.params.scale_factors[scale_ind]

        return translation_vec, scale_ind, scores


    def extract_sample(self, im: torch.Tensor, pos: torch.Tensor, scales, sz: torch.Tensor):
        """Extract features by delegating to ``self.params.features.extract``.

        All four arguments are forwarded unchanged and the result of the call
        is returned as is.
        """
        extractor = self.params.features
        return extractor.extract(im, pos, scales, sz)

    def extract_fourier_sample(self, im: torch.Tensor, pos: torch.Tensor, scales, sz: torch.Tensor) -> TensorList:
        """Extract a sample, project it and run it through ``preprocess_sample``.

        Chains ``extract_sample`` -> ``project_sample`` -> ``preprocess_sample``
        and returns the result of the last step.
        """
        raw = self.extract_sample(im, pos, scales, sz)
        projected = self.project_sample(raw)
        return self.preprocess_sample(projected)

    def preprocess_sample(self, x: TensorList) -> TensorList:
        """Window a sample, transform it with ``fourier.cfft2`` and interpolate it.

        Note that ``x`` is multiplied by ``self.window`` in place, so the
        caller's object is modified.

        Returns:
            TensorList: One ``dcf.interpolate_dft`` result per feature, using
            the matching entry of ``self.interp_fs``.
        """
        x *= self.window
        spectra = fourier.cfft2(x)

        interpolated = TensorList()
        for spectrum, interp in zip(spectra, self.interp_fs):
            interpolated.append(dcf.interpolate_dft(spectrum, interp))
        return interpolated

    def project_sample(self, x: TensorList):
        """Apply ``self.projection_matrix`` to a sample.

        The inner function is wrapped with ``tensor_operation`` and is called
        with ``x`` and the projection matrix.
        """
        @tensor_operation
        def _project_sample(feat: torch.Tensor, proj: torch.Tensor):
            # No projection matrix: pass the features through unchanged.
            if proj is None:
                return feat
            # Move dims 0/1 last for the matmul, then restore the layout.
            moved = feat.permute(2, 3, 0, 1)
            return torch.matmul(moved, proj).permute(2, 3, 0, 1)

        return _project_sample(x, self.projection_matrix)

    def generate_init_samples(self, im: torch.Tensor) -> TensorList:
        """Extract the initial samples, including the configured augmentations.

        The transform list always starts with the identity, followed by the
        ``'shift'``, ``'fliplr'``, ``'rotate'`` and ``'blur'`` entries found in
        ``self.params.augmentation`` (in that order). Features whose
        ``use_augmentation`` flag is false keep only their first sample.
        With a ``'dropout'`` entry ``(num, prob)``, ``num`` dropout copies of
        the first sample are appended for the features that use augmentation.

        Args:
            im: Image tensor to extract the samples from.

        Returns:
            TensorList: The extracted samples, one entry per feature.
        """
        aug_cfg = self.params.augmentation

        # Do data augmentation
        transforms = [augmentation.Identity()]
        if 'shift' in aug_cfg:
            for shift in aug_cfg['shift']:
                transforms.append(augmentation.Translation(shift))
        if 'fliplr' in aug_cfg and aug_cfg['fliplr']:
            transforms.append(augmentation.FlipHorizontal())
        if 'rotate' in aug_cfg:
            for angle in aug_cfg['rotate']:
                transforms.append(augmentation.Rotate(angle))
        if 'blur' in aug_cfg:
            for sigma in aug_cfg['blur']:
                transforms.append(augmentation.Blur(sigma))

        init_samples = self.params.features.extract_transformed(
            im, self.pos.round(), self.target_scale, self.img_sample_sz, transforms)

        # Features that opt out of augmentation keep only their first sample
        for i, use_aug in enumerate(self.fparams.attribute('use_augmentation')):
            if not use_aug:
                init_samples[i] = init_samples[i][0:1, ...]

        if 'dropout' in aug_cfg:
            num, prob = aug_cfg['dropout']
            for i, use_aug in enumerate(self.fparams.attribute('use_augmentation')):
                if not use_aug:
                    continue
                repeated = init_samples[i][0:1, ...].expand(num, -1, -1, -1)
                dropped = F.dropout2d(repeated, p=prob, training=True)
                init_samples[i] = torch.cat([init_samples[i], dropped])

        return init_samples


    def update_memory(self, sample_xf: TensorList):
        """Write a new sample into the training memory.

        The slot to overwrite is chosen per feature by
        ``update_sample_weights``, which also updates the sample weights.
        """
        slots = self.update_sample_weights()
        for memory, new_xf, slot in zip(self.training_samples, sample_xf, slots):
            # Reorder the sample's dimensions to the memory layout
            memory[:, :, slot:slot + 1, :, :] = new_xf.permute(2, 3, 0, 1, 4)


    def update_sample_weights(self):
        """Re-weight the stored samples and pick the memory slot to overwrite.

        For every feature, either all weight is put on slot 0 (no stored
        samples, or a learning rate of exactly 1) or the slot holding the
        smallest weight is selected and given a new weight. The weights are
        then renormalised to sum to one.

        Side effects: modifies the tensors in ``self.sample_weights`` in place,
        stores a copy of the chosen slots in ``self.previous_replace_ind`` and
        adds one to ``self.num_stored_samples``.

        Returns:
            list: The chosen slot index for each feature.
        """
        slots = []
        per_feature = zip(self.sample_weights, self.previous_replace_ind, self.num_stored_samples, self.fparams)
        for weights, last_slot, stored, feat_params in per_feature:
            if stored == 0 or feat_params.learning_rate == 1:
                # Nothing to keep: put the whole weight on the first slot
                weights[:] = 0
                weights[0] = 1
                slot = 0
            else:
                lr = feat_params.learning_rate

                # The slot with the smallest weight gets replaced
                slot = torch.min(weights, 0)[1].item()

                if last_slot is None:
                    # No earlier replacement recorded for this feature
                    weights /= 1 - lr
                    weights[slot] = lr
                else:
                    # Derive the new weight from the last replaced slot
                    weights[slot] = weights[last_slot] / (1 - lr)

            weights /= weights.sum()
            slots.append(slot)

        self.previous_replace_ind = slots.copy()
        self.num_stored_samples += 1
        return slots

    def update_state(self, new_pos, new_scale):
        """Store a new scale and position after clamping both.

        Args:
            new_pos: Proposed target position (tensor).
            new_scale: Proposed target scale; must provide ``clamp``.

        Side effects: sets ``self.target_scale``, ``self.target_sz`` and ``self.pos``.
        """
        # Scale: restrict to the allowed range, then derive the target size
        clamped_scale = new_scale.clamp(self.min_scale_factor, self.max_scale_factor)
        self.target_scale = clamped_scale
        self.target_sz = self.base_target_sz * clamped_scale

        # Position: the margin is negative (0.2 - 0.5 = -0.3 of the target
        # size), so the position may leave [0, image_sz] by that amount
        inside_ratio = 0.2
        margin = (inside_ratio - 0.5) * self.target_sz
        upper_bound = self.image_sz - margin
        self.pos = torch.max(torch.min(new_pos, upper_bound), margin)

    def symmetrize_filter(self):
        """Symmetrise the index-0 slice (4th dimension) of every filter.

        For each filter the slice ``[:, :, :, 0, :]`` is halved in place and
        the complex conjugate of that halved slice, flipped along dimension 2,
        is added to it.
        """
        for filt in self.filter:
            # Basic indexing gives a view, so in-place ops update the filter
            first_slice = filt[:, :, :, 0, :]
            first_slice /= 2
            first_slice += complex.conj(first_slice.flip((2,)))
Пример #7
0
class ECO(BaseTracker):
    def initialize_features(self):
        """Run the feature extractor's ``initialize()`` at most once.

        The ``features_initialized`` flag is set to True on every call, so
        only the first call reaches ``self.params.features.initialize()``.
        """
        needs_init = not getattr(self, 'features_initialized', False)
        if needs_init:
            self.params.features.initialize()
        self.features_initialized = True

    def initialize(self, image, state, *args, **kwargs):
        """Set up the tracker state, the box-regression branch and the initial filter.

        Args:
            image: First frame. ``image.shape[2] == 3`` is used as the colour
                test and the frame is converted with ``numpy_to_torch``.
            state: Initial target box, indexed as ``state[0]`` .. ``state[3]``
                (presumably x, y, width, height -- TODO confirm with caller).
            *args: Not used.
            **kwargs: Not used.
        """

        # Initialize some stuff
        self.frame_num = 1
        # Only derive the device from use_gpu when none was configured
        if not hasattr(self.params, 'device'):
            self.params.device = 'cuda' if self.params.use_gpu else 'cpu'

        # Initialize features
        self.initialize_features()

        # Check if image is color
        self.params.features.set_is_color(image.shape[2] == 3)

        # Get feature specific params
        self.fparams = self.params.features.get_fparams('feature_params')

        # Get position and size
        # (box centre and box size, each with the state[1]/state[3] component first)
        self.pos = torch.Tensor(
            [state[1] + (state[3] - 1) / 2, state[0] + (state[2] - 1) / 2])
        self.target_sz = torch.Tensor([state[3], state[2]])

        # Set search area
        # (target_scale stays 1.0 unless the scaled target area falls outside
        # [min_image_sample_size, max_image_sample_size])
        self.target_scale = 1.0
        search_area = torch.prod(self.target_sz *
                                 self.params.search_area_scale).item()
        if search_area > self.params.max_image_sample_size:
            self.target_scale = math.sqrt(search_area /
                                          self.params.max_image_sample_size)
        elif search_area < self.params.min_image_sample_size:
            self.target_scale = math.sqrt(search_area /
                                          self.params.min_image_sample_size)

        # Target size in base scale
        self.base_target_sz = self.target_sz / self.target_scale
        # NOTE(review): this flag is set but never read in the visible code
        self.use_iou_net = True
        # Use odd square search area and set sizes
        # (the += below turns the side length into an odd multiple of feat_max_stride)
        feat_max_stride = max(self.params.features.stride())
        self.img_sample_sz = torch.round(
            torch.sqrt(
                torch.prod(self.base_target_sz *
                           self.params.search_area_scale))) * torch.ones(2)
        self.img_sample_sz += feat_max_stride - self.img_sample_sz % (
            2 * feat_max_stride)

        # Set other sizes (corresponds to ECO code)
        self.img_support_sz = self.img_sample_sz
        self.feature_sz = self.params.features.size(self.img_sample_sz)
        # Add one to even feature sizes so that every filter size is odd
        self.filter_sz = self.feature_sz + (self.feature_sz + 1) % 2
        self.output_sz = self.params.score_upsample_factor * self.img_support_sz  # Interpolated size of the output
        self.compressed_dim = self.fparams.attribute('compressed_dim')

        # Number of filters
        self.num_filters = len(self.filter_sz)

        # Get window function
        self.window = TensorList(
            [dcf.hann2d(sz).to(self.params.device) for sz in self.feature_sz])

        # Get interpolation function
        self.interp_fs = TensorList([
            dcf.get_interp_fourier(sz, self.params.interpolation_method,
                                   self.params.interpolation_bicubic_a,
                                   self.params.interpolation_centering,
                                   self.params.interpolation_windowing,
                                   self.params.device) for sz in self.filter_sz
        ])

        # Get regularization filter
        self.reg_filter = TensorList([
            dcf.get_reg_filter(self.img_support_sz, self.base_target_sz,
                               fparams).to(self.params.device)
            for fparams in self.fparams
        ])
        # Inner product of each flattened regularization filter with itself
        self.reg_energy = self.reg_filter.view(-1) @ self.reg_filter.view(-1)

        # Get label function
        output_sigma_factor = self.fparams.attribute('output_sigma_factor')
        sigma = (self.filter_sz / self.img_support_sz) * torch.sqrt(
            self.base_target_sz.prod()) * output_sigma_factor
        self.yf = TensorList([
            dcf.label_function(sz, sig).to(self.params.device)
            for sz, sig in zip(self.filter_sz, sigma)
        ])

        # Optimization options
        self.params.precond_learning_rate = self.fparams.attribute(
            'learning_rate')
        if self.params.CG_forgetting_rate is None or max(
                self.params.precond_learning_rate) >= 1:
            self.params.direction_forget_factor = 0
        else:
            self.params.direction_forget_factor = (
                1 - max(self.params.precond_learning_rate)
            )**self.params.CG_forgetting_rate

        # Convert image
        im = numpy_to_torch(image)

        # Setup bounds
        self.image_sz = torch.Tensor([im.shape[2], im.shape[3]])
        self.min_scale_factor = torch.max(10 / self.base_target_sz)
        self.max_scale_factor = torch.min(self.image_sz / self.base_target_sz)

        # Extract and transform sample
        x = self.generate_init_samples(im)

        # Initialize projection matrix
        # (leading left-singular vectors of the mean-centred feature covariance,
        # one matrix per feature, truncated to compressed_dim columns)
        x_mat = TensorList(
            [e.permute(1, 0, 2, 3).reshape(e.shape[1], -1).clone() for e in x])
        x_mat -= x_mat.mean(dim=1, keepdim=True)
        cov_x = x_mat @ x_mat.t()
        self.projection_matrix = TensorList([
            torch.svd(C)[0][:, :cdim].clone()
            for C, cdim in zip(cov_x, self.compressed_dim)
        ])

        # Transform to get the training sample
        train_xf = self.preprocess_sample(x)

        # Shift the samples back
        # (features holding a single sample are skipped; sample 0 is left as is
        # and samples 1.. are shifted, one per configured 'shift' entry)
        if 'shift' in self.params.augmentation:
            for xf in train_xf:
                if xf.shape[0] == 1:
                    continue
                for i, shift in enumerate(self.params.augmentation['shift']):
                    shift_samp = 2 * math.pi * torch.Tensor(
                        shift) / self.img_support_sz
                    xf[1 + i:2 + i, ...] = fourier.shift_fs(xf[1 + i:2 + i,
                                                               ...],
                                                            shift=shift_samp)

        # Shift sample
        # (compensates for extracting the patch at the rounded position)
        shift_samp = 2 * math.pi * (self.pos - self.pos.round()) / (
            self.target_scale * self.img_support_sz)
        train_xf = fourier.shift_fs(train_xf, shift=shift_samp)

        # Initialize first-frame training samples
        num_init_samples = train_xf.size(0)
        self.init_sample_weights = TensorList(
            [xf.new_ones(1) / xf.shape[0] for xf in train_xf])
        self.init_training_samples = train_xf.permute(2, 3, 0, 1, 4)

        ## Initialize the regression branch
        # (iou_img_sample_sz must be set first: init_dr_net reads it through
        # get_iounet_box)
        self.iou_img_sample_sz = self.img_sample_sz
        self.init_dr_net()

        # Sample counters and weights
        self.num_stored_samples = num_init_samples
        self.previous_replace_ind = [None] * len(self.num_stored_samples)
        self.sample_weights = TensorList(
            [xf.new_zeros(self.params.sample_memory_size) for xf in train_xf])
        for sw, init_sw, num in zip(self.sample_weights,
                                    self.init_sample_weights,
                                    num_init_samples):
            sw[:num] = init_sw

        # Initialize memory
        self.training_samples = TensorList([
            xf.new_zeros(xf.shape[2], xf.shape[3],
                         self.params.sample_memory_size, cdim, 2)
            for xf, cdim in zip(train_xf, self.compressed_dim)
        ])

        # Initialize filter
        self.filter = TensorList([
            xf.new_zeros(1, cdim, xf.shape[2], xf.shape[3], 2)
            for xf, cdim in zip(train_xf, self.compressed_dim)
        ])

        # Do joint optimization
        # (the optimized variable is the filter concatenated with the projection matrix)
        self.joint_problem = FactorizedConvProblem(self.init_training_samples,
                                                   self.yf, self.reg_filter,
                                                   self.projection_matrix,
                                                   self.params,
                                                   self.init_sample_weights)
        joint_var = self.filter.concat(self.projection_matrix)
        self.joint_optimizer = GaussNewtonCG(self.joint_problem,
                                             joint_var,
                                             debug=(self.params.debug >= 3))

        if self.params.update_projection_matrix:
            self.joint_optimizer.run(
                self.params.init_CG_iter // self.params.init_GN_iter,
                self.params.init_GN_iter)

        # Re-project samples with the new projection matrix
        compressed_samples = complex.mtimes(self.init_training_samples,
                                            self.projection_matrix)
        for train_samp, init_samp in zip(self.training_samples,
                                         compressed_samples):
            train_samp[:, :, :init_samp.shape[2], :, :] = init_samp

        # Initialize optimizer
        self.filter_optimizer = FilterOptim(self.params, self.reg_energy)
        self.filter_optimizer.register(self.filter, self.training_samples,
                                       self.yf, self.sample_weights,
                                       self.reg_filter)
        self.filter_optimizer.sample_energy = self.joint_problem.sample_energy
        self.filter_optimizer.residuals = self.joint_optimizer.residuals.clone(
        )

        # Without the joint optimization, spend the initial CG iterations on the filter alone
        if not self.params.update_projection_matrix:
            self.filter_optimizer.run(self.params.init_CG_iter)

        # Post optimization
        self.filter_optimizer.run(self.params.post_init_CG_iter)

        self.symmetrize_filter()

    def get_iou_features(self):
        """Return the ``'iounet_features'`` attribute of the feature extractor."""
        features = self.params.features
        return features.get_unique_attribute('iounet_features')

    def get_iou_backbone_features(self):
        """Return the ``'iounet_backbone_features'`` attribute of the feature extractor."""
        features = self.params.features
        return features.get_unique_attribute('iounet_backbone_features')

    def init_dr_net(self):
        """Set up the box predictor and compute the reference target features.

        Side effects: sets ``self.iou_predictor`` (with gradients disabled on
        its parameters), ``self.iou_target_box`` and ``self.target_feat``, and
        forces ``self.params.iounet_augmentation`` to False.
        """
        # Fetch the predictor network and freeze its parameters
        predictor = self.params.features.get_unique_attribute('iou_predictor')
        self.iou_predictor = predictor
        for weight in predictor.parameters():
            weight.requires_grad = False

        # Target box of the initial sample
        self.iou_target_box = self.get_iounet_box(self.pos, self.target_sz,
                                                  self.pos.round(),
                                                  self.target_scale)

        # NOTE(review): the parameter is overwritten here, so the
        # augmentation branch below never runs -- confirm this is intended.
        self.params.iounet_augmentation = False

        target_boxes = TensorList()
        if not self.params.iounet_augmentation:
            target_boxes.append(self.iou_target_box.clone())
        else:
            # One shifted box per transform, stopping at the first transform
            # that is not one of the listed types
            for T in self.transforms:
                if not isinstance(
                        T, (augmentation.Identity, augmentation.Translation,
                            augmentation.FlipHorizontal,
                            augmentation.FlipVertical, augmentation.Blur)):
                    break
                shift_offset = torch.Tensor([T.shift[1], T.shift[0], 0, 0])
                target_boxes.append(self.iou_target_box + shift_offset)
        stacked_boxes = torch.cat(target_boxes.view(1, 4), 0)
        target_boxes = stacked_boxes.to(self.params.device)

        # Backbone features, limited to as many samples as there are boxes
        num_boxes = target_boxes.shape[0]
        iou_backbone_features = TensorList(
            [feat[:num_boxes, ...] for feat in self.get_iou_backbone_features()])

        # Extract the target features and average them over dimension 0
        with torch.no_grad():
            target_feat = predictor.get_filter(iou_backbone_features,
                                               target_boxes)
        self.target_feat = TensorList(
            [feat.detach().mean(0) for feat in target_feat])

        if getattr(self.params, 'iounet_not_use_reference', False):
            # Replace each feature with a constant tensor of value norm / numel
            constant_feat = TensorList()
            for tf in self.target_feat:
                constant_feat.append(
                    torch.full_like(tf, tf.norm() / tf.numel()))
            self.target_feat = constant_feat

    def offset2box(self, init_box, offset):
        """Apply regression offsets to reference boxes.

        Args:
            init_box: Tensor indexed as ``[:, 0..3]``; columns 0/1 are a
                corner, columns 2/3 the width and height.
            offset: Tensor of ``(dx, dy, dw, dh)`` groups of four columns.

        Returns:
            Tensor shaped like ``offset`` whose groups of four columns hold
            ``(centre_x, centre_y, width, height)`` of the predicted boxes.
        """
        # Sizes and centres of the reference boxes
        widths = init_box[:, 2]
        heights = init_box[:, 3]
        ctr_x = init_box[:, 0] + 0.5 * widths
        ctr_y = init_box[:, 1] + 0.5 * heights

        # Unit weights for the four offset components
        wx, wy, ww, wh = 1, 1, 1, 1
        dx = offset[:, 0::4] / wx
        dy = offset[:, 1::4] / wy

        # Cap the size deltas so that exp() cannot produce huge values
        max_delta = np.log(1000. / 16.)
        dw = torch.clamp(offset[:, 2::4] / ww, max=max_delta)
        dh = torch.clamp(offset[:, 3::4] / wh, max=max_delta)

        pred_boxes = offset.new_zeros(offset.shape)
        # Centre: shifted by the offset times the box size
        pred_boxes[:, 0::4] = dx * widths[:, None] + ctr_x[:, None]
        pred_boxes[:, 1::4] = dy * heights[:, None] + ctr_y[:, None]
        # Size: scaled by the exponential of the offset
        pred_boxes[:, 2::4] = torch.exp(dw) * widths[:, None]
        pred_boxes[:, 3::4] = torch.exp(dh) * heights[:, None]
        return pred_boxes

    def get_iounet_box(self, pos, sz, sample_pos, sample_scale):
        """Express a target box in the coordinates of a sampled patch.

        All inputs are in original image coordinates.

        Returns:
            Tensor: The upper-left corner followed by the box size, each with
            its two components in reversed order.
        """
        half_patch = (self.iou_img_sample_sz - 1) / 2
        box_sz = sz / sample_scale
        box_center = (pos - sample_pos) / sample_scale + half_patch
        upper_left = box_center - (box_sz - 1) / 2
        return torch.cat([upper_left.flip((0, )), box_sz.flip((0, ))])

    def predict_target_box(self,
                           sample_pos,
                           sample_scale,
                           scale_ind,
                           update_scale=True):
        """Refine the target position and size with the box-regression network.

        Args:
            sample_pos: Position at which the current sample was extracted.
            sample_scale: Scale of the sample selected by ``scale_ind``.
            scale_ind: Index of the scale whose features are used.
            update_scale: Accepted for interface compatibility; not read by
                this method (the scale is always updated).

        Side effects: overwrites ``self.pos``, ``self.target_sz`` and
        ``self.target_scale`` with the prediction and stores a copy of the
        new position in ``self.pos_drnet``.
        """
        # Current box estimate in patch coordinates, with two leading
        # singleton dimensions added for the predictor
        init_box = self.get_iounet_box(self.pos, self.target_sz, sample_pos,
                                       sample_scale)
        init_box = init_box.unsqueeze(0).unsqueeze(0)
        # Use the configured device instead of unconditionally calling
        # .cuda(), so that CPU-only runs work as well
        init_box = init_box.to(self.params.device)

        # Features of the selected scale only
        iou_features = self.get_iou_features()
        iou_features = TensorList(
            [x[scale_ind:scale_ind + 1, ...] for x in iou_features])

        # Predict the regression offsets
        reg = self.iou_predictor.predict_box(self.target_feat, iou_features,
                                             init_box)
        init_box = init_box.view(-1, 4)
        reg = reg.view(-1, 4)

        # offset2box returns (centre_x, centre_y, width, height)
        predicted_box = self.offset2box(init_box, reg)
        predicted_box = predicted_box[0, :].cpu()

        # Map the box from patch coordinates back to image coordinates; the
        # flips swap the two components of the centre and of the size
        new_pos = predicted_box[:2] - (self.iou_img_sample_sz - 1) / 2
        new_pos = new_pos.flip((0, )) * sample_scale + sample_pos
        new_target_sz = predicted_box[2:].flip((0, )) * sample_scale
        new_scale = torch.sqrt(new_target_sz.prod() /
                               self.base_target_sz.prod())

        self.pos_drnet = new_pos.clone()
        self.pos = new_pos.clone()
        self.target_sz = new_target_sz
        self.target_scale = new_scale

    def track(self, image):
        """Process one frame: localize the target, refine it, update the model.

        Steps, in order:
          1. extract a Fourier-domain sample at the current state,
          2. apply the filter and pick the best displacement / scale,
          3. call ``update_state`` and then ``predict_target_box`` (the
             latter overwrites position, size and scale again),
          4. store the selected-scale sample in memory and periodically
             re-run the filter optimizer.

        Args:
            image: frame handed to ``numpy_to_torch``.

        Returns:
            list: four numbers built from ``self.pos`` and
            ``self.target_sz`` with their two entries swapped:
            ``pos - (size - 1) / 2`` followed by ``size`` (presumably
            ``[x, y, w, h]`` -- confirm against the caller).
        """
        self.frame_num += 1

        # Convert image
        im = numpy_to_torch(image)

        # ------- LOCALIZATION ------- #

        # Get sample
        sample_pos = self.pos.round()
        sample_scales = self.target_scale * self.params.scale_factors
        test_xf = self.extract_fourier_sample(im, self.pos, sample_scales,
                                              self.img_sample_sz)

        # Compute scores
        sf = self.apply_filter(test_xf)
        translation_vec, scale_ind, s = self.localize_target(sf)

        scale_change_factor = self.params.scale_factors[scale_ind]

        # Update position and scale from the filter response, then let the
        # box predictor refine the state.
        self.update_state(sample_pos + translation_vec,
                          self.target_scale * scale_change_factor)
        self.predict_target_box(sample_pos, sample_scales[scale_ind],
                                scale_ind)

        if self.params.debug >= 2:
            show_tensor(s[scale_ind, ...], 5)
        if self.params.debug >= 3:
            for i, hf in enumerate(self.filter):
                show_tensor(fourier.sample_fs(hf).abs().mean(1), 6 + i)

        # ------- UPDATE ------- #

        # Get train sample (only the selected scale)
        train_xf = TensorList(
            [xf[scale_ind:scale_ind + 1, ...] for xf in test_xf])

        # Shift the sample by the offset between the refined position and
        # the position the sample was extracted at.
        shift_samp = 2 * math.pi * (self.pos - sample_pos) / (
            sample_scales[scale_ind] * self.img_support_sz)
        train_xf = fourier.shift_fs(train_xf, shift=shift_samp)

        # Update memory
        self.update_memory(train_xf)

        # Train filter every ``train_skipping`` frames. Written as
        # ``(n - 1) % k == 0`` instead of ``n % k == 1``: both select the
        # same frames for k >= 2, but the latter is never true for k == 1,
        # which silently disabled training when "every frame" was requested.
        if (self.frame_num - 1) % self.params.train_skipping == 0:
            self.filter_optimizer.run(self.params.CG_iter, train_xf)
            self.symmetrize_filter()

        # Return new state
        new_state = torch.cat(
            (self.pos[[1, 0]] - (self.target_sz[[1, 0]] - 1) / 2,
             self.target_sz[[1, 0]]))

        return new_state.tolist()

    def apply_filter(self, sample_xf: TensorList) -> torch.Tensor:
        """Return the response of ``self.filter`` on ``sample_xf``.

        The filter and the sample are multiplied with ``complex.mult`` and
        the product is summed over dimension 1, which is kept as a
        singleton dimension.
        """
        product = complex.mult(self.filter, sample_xf)
        return product.sum(1, keepdim=True)

    def localize_target(self, sf: TensorList):
        """Fuse the filter responses and locate the score maximum.

        The fusion rule is chosen by ``self.params.score_fusion_strategy``
        (``'sum'``, ``'weightedsum'`` or ``'transcale'``).

        Args:
            sf: Fourier-domain filter responses, one entry per feature.

        Returns:
            tuple: ``(translation_vec, scale_ind, scores)`` -- the
            displacement of the best peak scaled to image units, the index
            of the scale with the highest peak, and the fused score maps.

        Raises:
            ValueError: if the fusion strategy is not one of the three
                supported names.
        """
        strategy = self.params.score_fusion_strategy
        wraps_around = strategy in ['sum', 'weightedsum']

        if strategy == 'sum':
            scores = fourier.sample_fs(fourier.sum_fs(sf), self.output_sz)
        elif strategy == 'weightedsum':
            weight = self.fparams.attribute('translation_weight')
            scores = fourier.sample_fs(fourier.sum_fs(weight * sf),
                                       self.output_sz)
        elif strategy == 'transcale':
            alpha = self.fparams.attribute('scale_weight')
            beta = self.fparams.attribute('translation_weight')
            # One output size per scale factor.
            sample_sz = torch.round(
                self.output_sz.view(1, -1) *
                self.params.scale_factors.view(-1, 1))
            scores = 0
            for sfe, a, b in zip(sf, alpha, beta):
                sfe = fourier.shift_fs(sfe, math.pi * torch.ones(2))
                scores_scales = []
                for sind, sz in enumerate(sample_sz):
                    # Pad every per-scale map back up to ``output_sz``;
                    # F.pad takes the last dimension's padding first.
                    pd = (self.output_sz - sz) / 2
                    pad_last = pd[1].item()
                    pad_prev = pd[0].item()
                    sampled = fourier.sample_fs(sfe[sind:sind + 1, ...], sz)
                    padding = (math.floor(pad_last), math.ceil(pad_last),
                               math.floor(pad_prev), math.ceil(pad_prev))
                    scores_scales.append(F.pad(sampled, padding))
                scores_cat = torch.cat(scores_scales)
                scale_mean = scores_cat.mean(dim=0, keepdim=True)
                scores = scores + (b - a) * scale_mean + a * scores_cat
        else:
            raise ValueError('Unknown score fusion strategy.')

        # Get maximum
        max_score, max_disp = dcf.max2d(scores)
        _, scale_ind = torch.max(max_score, dim=0)
        max_disp = max_disp.float().cpu()

        # Convert to displacements in the base scale
        half_sz = self.output_sz / 2
        if wraps_around:
            disp = (max_disp + half_sz) % self.output_sz - half_sz
        else:
            # Only 'transcale' can reach this point.
            disp = max_disp - half_sz

        # Compute translation vector (and apply the scale factor when the
        # score maps were not already resampled per scale).
        cell_sz = self.img_support_sz / self.output_sz
        translation_vec = disp[scale_ind, ...].view(-1) * (
            cell_sz) * self.target_scale
        if wraps_around:
            translation_vec *= self.params.scale_factors[scale_ind]

        return translation_vec, scale_ind, scores

    def extract_sample(self, im: torch.Tensor, pos: torch.Tensor, scales,
                       sz: torch.Tensor):
        """Forward the arguments unchanged to ``self.params.features.extract``.

        Returns whatever that call returns.
        """
        feature_set = self.params.features
        return feature_set.extract(im, pos, scales, sz)

    def extract_fourier_sample(self, im: torch.Tensor, pos: torch.Tensor,
                               scales, sz: torch.Tensor) -> TensorList:
        """Extract a sample and return its preprocessed Fourier representation.

        Chains ``extract_sample`` -> ``project_sample`` ->
        ``preprocess_sample``.
        """
        raw_sample = self.extract_sample(im, pos, scales, sz)
        projected = self.project_sample(raw_sample)
        return self.preprocess_sample(projected)

    def preprocess_sample(self, x: TensorList) -> TensorList:
        """Window the sample, transform it and interpolate it in Fourier domain.

        NOTE: ``x`` is multiplied by ``self.window`` in place, so the
        caller's object is modified.

        Returns:
            TensorList: one ``dcf.interpolate_dft`` result per entry,
            pairing each spectrum with the matching ``self.interp_fs``.
        """
        x *= self.window
        spectra = fourier.cfft2(x)

        interpolated = TensorList()
        for spectrum, interp in zip(spectra, self.interp_fs):
            interpolated.append(dcf.interpolate_dft(spectrum, interp))
        return interpolated

    def project_sample(self, x: TensorList):
        """Apply ``self.projection_matrix`` to the sample.

        The work is done by an inner function wrapped with
        ``tensor_operation``. An entry whose matrix is ``None`` is returned
        untouched; otherwise dimensions (2, 3, 0, 1) are permuted to the
        front, multiplied by the matrix and permuted back.
        """
        @tensor_operation
        def _apply_projection(feat: torch.Tensor, proj: torch.Tensor):
            if proj is not None:
                moved = feat.permute(2, 3, 0, 1)
                return torch.matmul(moved, proj).permute(2, 3, 0, 1)
            return feat

        return _apply_projection(x, self.projection_matrix)

    def generate_init_samples(self, im: torch.Tensor) -> TensorList:
        """Build the augmented first-frame samples.

        The transform list always starts with ``augmentation.Identity`` and
        is extended according to the keys present in
        ``self.params.augmentation`` (``'shift'``, ``'fliplr'``,
        ``'rotate'``, ``'blur'``). Features whose ``use_augmentation`` flag
        is false keep only their first sample; when ``'dropout'`` is
        configured, the remaining features get ``num`` extra dropout copies
        of their first sample.
        """
        aug_cfg = self.params.augmentation

        # Do data augmentation
        transforms = [augmentation.Identity()]
        if 'shift' in aug_cfg:
            for shift in aug_cfg['shift']:
                transforms.append(augmentation.Translation(shift))
        if 'fliplr' in aug_cfg and aug_cfg['fliplr']:
            transforms.append(augmentation.FlipHorizontal())
        if 'rotate' in aug_cfg:
            for angle in aug_cfg['rotate']:
                transforms.append(augmentation.Rotate(angle))
        if 'blur' in aug_cfg:
            for sigma in aug_cfg['blur']:
                transforms.append(augmentation.Blur(sigma))

        init_samples = self.params.features.extract_transformed(
            im, self.pos, self.target_scale, self.img_sample_sz, transforms)

        # Features that opt out of augmentation keep the first sample only.
        for idx, use_aug in enumerate(
                self.fparams.attribute('use_augmentation')):
            if use_aug:
                continue
            init_samples[idx] = init_samples[idx][0:1, ...]

        if 'dropout' not in aug_cfg:
            return init_samples

        # Append dropout-perturbed copies of the first sample.
        num, prob = aug_cfg['dropout']
        for idx, use_aug in enumerate(
                self.fparams.attribute('use_augmentation')):
            if not use_aug:
                continue
            repeated = init_samples[idx][0:1, ...].expand(num, -1, -1, -1)
            dropped = F.dropout2d(repeated, p=prob, training=True)
            init_samples[idx] = torch.cat([init_samples[idx], dropped])

        return init_samples

    def update_memory(self, sample_xf: TensorList):
        """Write the new sample into the training memory.

        ``update_sample_weights`` decides, per feature, which memory slot
        is replaced; the sample is permuted with ``(2, 3, 0, 1, 4)`` before
        it is stored along dimension 2 of the memory tensor.
        """
        # Update weights and get index to replace
        slots = self.update_sample_weights()
        per_feature = zip(self.training_samples, sample_xf, slots)
        for memory, new_xf, slot in per_feature:
            memory[:, :, slot:slot + 1, :, :] = new_xf.permute(2, 3, 0, 1, 4)

    def update_sample_weights(self):
        """Update the per-feature sample weights and pick the slots to replace.

        For every feature the slot with the smallest weight is chosen
        (slot 0 when nothing is stored yet or the learning rate is 1), the
        weights are adjusted in place and renormalised to sum to one.
        Afterwards ``self.previous_replace_ind`` holds a copy of the chosen
        slots and ``self.num_stored_samples`` is incremented by one.

        Returns:
            list: the chosen slot index for each feature.
        """
        chosen = []
        per_feature = zip(self.sample_weights, self.previous_replace_ind,
                          self.num_stored_samples, self.fparams)
        for weights, last_slot, stored, feat_params in per_feature:
            start_over = stored == 0 or feat_params.learning_rate == 1
            if start_over:
                # All weight goes to the first slot.
                weights[:] = 0
                weights[0] = 1
                slot = 0
            else:
                # Get index to replace
                slot = torch.min(weights, 0)[1].item()

                # Update weights
                if last_slot is None:
                    weights /= 1 - feat_params.learning_rate
                    weights[slot] = feat_params.learning_rate
                else:
                    weights[slot] = weights[last_slot] / (
                        1 - feat_params.learning_rate)

            weights /= weights.sum()
            chosen.append(slot)

        self.previous_replace_ind = chosen.copy()
        self.num_stored_samples += 1
        return chosen

    def update_state(self, new_pos, new_scale):
        """Store a new scale and position, both limited to allowed ranges.

        The scale is clamped to ``[min_scale_factor, max_scale_factor]`` and
        the target size is recomputed from ``base_target_sz``. The position
        is clamped so that it lies between ``-0.3 * target_sz`` and
        ``image_sz + 0.3 * target_sz``.
        """
        # Update scale
        clamped_scale = new_scale.clamp(self.min_scale_factor,
                                        self.max_scale_factor)
        self.target_scale = clamped_scale
        self.target_sz = self.base_target_sz * clamped_scale

        # Update pos
        inside_ratio = 0.2
        lower_bound = (inside_ratio - 0.5) * self.target_sz
        upper_bound = self.image_sz - lower_bound
        capped = torch.min(new_pos, upper_bound)
        self.pos = torch.max(capped, lower_bound)

    def symmetrize_filter(self):
        """Symmetrize index 0 of dimension 3 of every filter, in place.

        The slice is halved and then the complex conjugate of its copy
        flipped along dimension 2 is added to it.
        """
        for hf in self.filter:
            # Basic indexing yields a view, so in-place ops update ``hf``.
            edge = hf[:, :, :, 0, :]
            edge /= 2
            edge += complex.conj(edge.flip((2, )))
Пример #8
0
class CCOT(BaseTracker):
    def initialize_features(self, im):
        if not getattr(self, 'features_initialized', False):
            self.params.features.initialize(im)
        self.features_initialized = True

    def initialize(self, image, info: dict, gpu_device) -> dict:
        """Set up the tracker on the first frame and train one filter set per point.

        Args:
            image: first frame; converted with ``numpy_to_torch``. Its third
                axis length is compared against 3 to decide whether it is a
                colour image.
            info: must provide ``info['points']`` (iterable of entries with
                at least two elements) and ``info['target_sz']`` (at least
                two elements).
            gpu_device: index formatted into ``'cuda:{index}'`` when
                ``self.params.use_gpu`` is truthy; otherwise ``'cpu'`` is
                used.

        NOTE(review): the signature is annotated ``-> dict`` but the method
        has no ``return`` statement, so callers receive ``None``.
        """
        # Initialize some stuff
        self.frame_num = 1
        self.params.device = 'cuda:{0}'.format(
            gpu_device) if self.params.use_gpu else 'cpu'

        # Convert image
        im = numpy_to_torch(image)
        # Image size is taken from dimensions 2 and 3 of the converted tensor.
        self.image_sz = torch.Tensor([im.shape[2], im.shape[3]])

        # Initialize features
        self.initialize_features(im)

        # Check if image is color
        self.params.features.set_is_color(image.shape[2] == 3)

        # Get feature specific params
        self.fparams = self.params.features.get_fparams('feature_params')

        # Get position and size
        self.points = TensorList(
            [torch.Tensor([p[0], p[1]]) for p in info['points']])
        self.org_points = self.points.clone()
        self.target_sz = torch.Tensor(
            [info['target_sz'][0], info['target_sz'][1]])

        # The sample covers the whole image; each side is snapped to an odd
        # multiple of the largest feature stride.
        feat_max_stride = max(self.params.features.stride())
        self.img_sample_sz = self.image_sz.clone()
        self.img_sample_sz += feat_max_stride - self.img_sample_sz % (
            2 * feat_max_stride)

        # Set other sizes (corresponds to ECO code)
        # NOTE: img_support_sz and output_sz below are aliases of
        # img_sample_sz (same tensor object, not copies).
        self.img_support_sz = self.img_sample_sz
        self.mid_point = self.img_support_sz // 2
        self.feature_sz = self.params.features.size(self.img_sample_sz)
        # Adds 1 to every even feature size so the filter sizes are odd.
        self.filter_sz = self.feature_sz + (self.feature_sz + 1) % 2
        self.output_sz = self.img_support_sz  # Interpolated size of the output

        # Number of filters
        self.num_filters = len(self.filter_sz)

        # Get window function
        # An all-ones window is used; the Hann and Tukey alternatives are
        # kept commented out.
        #self.window = TensorList([dcf.hann2d(sz).to(self.params.device) for sz in self.feature_sz])
        self.window = TensorList([
            torch.ones((1, 1, int(sz[0].item()),
                        int(sz[1].item()))).to(self.params.device)
            for sz in self.feature_sz
        ])
        #self.window = TensorList([dcf.tukey2d(sz).to(self.params.device) for sz in self.feature_sz])

        # Get interpolation function
        self.interp_fs = TensorList([
            dcf.get_interp_fourier(sz, self.params.interpolation_method,
                                   self.params.interpolation_bicubic_a,
                                   self.params.interpolation_centering,
                                   self.params.interpolation_windowing,
                                   self.params.device) for sz in self.filter_sz
        ])

        # Get label function
        output_sigma_factor = self.fparams.attribute('output_sigma_factor')
        sigma = (self.filter_sz / self.img_support_sz) * torch.sqrt(
            self.target_sz.prod()) * output_sigma_factor
        yf_zero = TensorList([
            dcf.label_function(sz, sig).to(self.params.device)
            for sz, sig in zip(self.filter_sz, sigma)
        ])
        yf_zero = complex.complex(yf_zero)
        # One shifted copy of the label set per tracked point; the shift is
        # derived from the offset between the sample mid point and the point.
        self.yf = TensorList()
        for p in self.points:
            shift_sample = 2 * math.pi * (self.mid_point -
                                          p) / self.img_support_sz
            self.yf.append(
                TensorList(
                    [fourier.shift_fs(yfs, shift_sample) for yfs in yf_zero]))

        # Optimization options
        self.params.precond_learning_rate = self.fparams.attribute(
            'learning_rate')
        if self.params.CG_forgetting_rate is None or max(
                self.params.precond_learning_rate) >= 1:
            self.params.direction_forget_factor = 0
        else:
            self.params.direction_forget_factor = (
                1 - max(self.params.precond_learning_rate)
            )**self.params.CG_forgetting_rate

        # Extract and transform sample
        x = self.generate_init_samples(im).to(self.params.device)
        self.x = x
        # Transform to get the training sample
        # NOTE: preprocess_sample multiplies its argument by the window in
        # place, so ``self.x`` refers to the windowed sample afterwards.
        train_xf = self.preprocess_sample(x)

        # Shift the samples back
        if 'shift' in self.params.augmentation:
            for xf in train_xf:
                # A single sample means augmentation was removed for this
                # feature, so there is nothing to shift back.
                if xf.shape[0] == 1:
                    continue
                # Sample 0 is the identity transform; shifted samples follow
                # at indices 1..len(shift).
                for i, shift in enumerate(self.params.augmentation['shift']):
                    shift_samp = 2 * math.pi * torch.Tensor(
                        shift) / self.img_support_sz
                    xf[1 + i:2 + i, ...] = fourier.shift_fs(xf[1 + i:2 + i,
                                                               ...],
                                                            shift=shift_samp)

        # Initialize first-frame training samples
        num_init_samples = train_xf.size(0)

        self.init_training_samples = train_xf.permute(2, 3, 0, 1, 4)

        # Initialize memory
        # Initialize filter
        self.training_samples = TensorList([
            xf.new_zeros(xf.shape[2], xf.shape[3],
                         self.params.sample_memory_size, xf.shape[1], 2)
            for xf in train_xf
        ])
        # One zero-initialised filter set (one filter per feature) per point.
        self.filters = TensorList([
            TensorList([
                xf.new_zeros(1, xf.shape[1], xf.shape[2], xf.shape[3], 2)
                for xf in train_xf
            ]) for i in range(len(self.points))
        ])

        # Uniform weights over the initial samples; the remaining memory
        # slots start at zero.
        self.init_sample_weights = TensorList(
            [xf.new_ones(1) / xf.shape[0] for xf in train_xf])
        self.sample_weights = TensorList(
            [xf.new_zeros(self.params.sample_memory_size) for xf in train_xf])
        for sw, init_sw, num in zip(self.sample_weights,
                                    self.init_sample_weights,
                                    num_init_samples):
            sw[:num] = init_sw

        # Get regularization filter
        self.reg_filter = TensorList([
            dcf.get_reg_filter(self.img_support_sz, self.target_sz,
                               fparams).to(self.params.device)
            for fparams in self.fparams
        ])
        self.reg_energy = self.reg_filter.view(-1) @ self.reg_filter.view(-1)

        # Sample counters and weights
        self.num_stored_samples = num_init_samples
        self.previous_replace_ind = [None] * len(self.num_stored_samples)

        # Copy the initial samples into the first memory slots.
        for train_samp, init_samp in zip(self.training_samples,
                                         self.init_training_samples):
            train_samp[:, :, :init_samp.shape[2], :, :] = init_samp

        sample_energy = complex.abs_sqr(self.training_samples).mean(
            dim=2, keepdim=True).permute(2, 3, 0, 1)
        # Do joint optimization
        # Every point gets its own optimizer run on cloned copies of the
        # shared training data.
        for i in range(len(self.points)):
            # Progress output: index of the point being optimized.
            print('{0}'.format(i), end=', ')
            ts = self.training_samples.clone()
            yf = self.yf[i]
            filters = self.filters[i]
            # NOTE(review): i_sw is not used after this assignment.
            i_sw = self.init_sample_weights.clone()
            re = self.reg_energy.clone()
            sw = self.sample_weights.clone()
            rf = self.reg_filter.clone()
            filter_optimizer = FilterOptim(self.params, re)
            filter_optimizer.register(filters, ts, yf, sw, rf)
            filter_optimizer.sample_energy = sample_energy.clone()

            filter_optimizer.run(self.params.init_CG_iter)

            # Post optimization
            filter_optimizer.run(self.params.post_init_CG_iter)
            self.filters[i] = filter_optimizer.filter
        self.symmetrize_filter()
        # Terminate the progress line.
        print()

    def track(self, image, update=False) -> dict:
        """Localize every tracked point in a new frame.

        Args:
            image: frame handed to ``numpy_to_torch``.
            update: not read by this method.

        Returns:
            TensorList with one ``localize_and_update_target`` result per
            point. NOTE(review): the ``-> dict`` annotation does not match
            the returned type.
        """
        self.debug_info = {}

        self.frame_num += 1
        self.debug_info['frame_num'] = self.frame_num

        # Convert image
        im = numpy_to_torch(image)

        # ------- LOCALIZATION ------- #

        # Get sample and compute the per-point scores
        test_xf = self.extract_fourier_sample(im)
        sfs = self.apply_filters(test_xf)

        out = TensorList()
        for i in range(len(self.points)):
            out.append(self.localize_and_update_target(sfs[i], i))
        return out

    def apply_filters(self, sample_xf: TensorList) -> torch.Tensor:
        """Return the response of every filter set on ``sample_xf``.

        For each entry of ``self.filters`` the product with the sample
        (``complex.mult``) is summed over dimension 1, keeping that
        dimension. The results are collected in a ``TensorList``.
        """
        responses = TensorList()
        for filter_set in self.filters:
            product = complex.mult(filter_set, sample_xf)
            responses.append(product.sum(1, keepdim=True))
        return responses

    def apply_filter(self, sample_xf: TensorList) -> torch.Tensor:
        """Return the response of ``self.filter`` on ``sample_xf``.

        NOTE(review): the visible part of this class only assigns
        ``self.filters`` (plural); confirm that ``self.filter`` is set
        somewhere before relying on this method.
        """
        product = complex.mult(self.filter, sample_xf)
        return product.sum(1, keepdim=True)

    def localize_and_update_target(self, sf: TensorList, i):
        """Locate the score peak for point ``i`` and store its new position.

        Only the ``'weightedsum'`` fusion strategy is supported.

        Args:
            sf: Fourier-domain filter responses for this point, one entry
                per feature.
            i: index into ``self.points`` of the point being updated.

        Returns:
            tuple: ``(rounded new position, max_score, scores)``.

        Raises:
            ValueError: if ``self.params.score_fusion_strategy`` is not
                ``'weightedsum'``.
        """
        if self.params.score_fusion_strategy != 'weightedsum':
            raise ValueError('Unknown score fusion strategy.')

        weight = self.fparams.attribute('translation_weight')
        sf = fourier.sum_fs(weight * sf)
        scores = fourier.sample_fs(sf, self.output_sz)

        # Get maximum
        max_score, max_disp = dcf.max2d(scores)
        max_disp = max_disp.float().cpu()

        # Wrap the peak index into [-output_sz / 2, output_sz / 2). The
        # former 'sum' / 'transcale' branches could never run after the
        # guard above and were removed.
        half_sz = self.output_sz / 2
        disp = (max_disp + half_sz) % self.output_sz - half_sz

        # Compute translation vector (no scale estimation happens here)
        translation_vec = disp.view(-1) * (self.img_support_sz /
                                           self.output_sz)

        # Update pos: displacements are measured from the sample mid point
        new_pos = self.mid_point.round() + translation_vec

        # Clamp the position to [-0.3 * target_sz, image_sz + 0.3 * target_sz]
        inside_ratio = 0.2
        inside_offset = (inside_ratio - 0.5) * self.target_sz
        self.points[i] = torch.max(
            torch.min(new_pos, self.image_sz - inside_offset), inside_offset)

        return self.points[i].round(), max_score, scores

    def extract_fourier_sample(self, im: torch.Tensor) -> TensorList:
        """Resize the frame, extract features and preprocess them.

        The image is bilinearly resized to ``self.output_sz``, every
        feature in ``self.params.features.features`` is evaluated on it,
        and the unrolled result is moved to ``self.params.device`` before
        ``preprocess_sample`` is applied.
        """
        target_size = self.output_sz.long().tolist()
        resized = F.interpolate(im, target_size, mode='bilinear')

        feats = TensorList()
        for feature in self.params.features.features:
            feats.append(feature.get_feature(resized))
        feats = feats.unroll().to(self.params.device)

        return self.preprocess_sample(feats)

    def preprocess_sample(self, x: TensorList) -> TensorList:
        """Window the sample, transform it and interpolate it in Fourier domain.

        NOTE: ``x`` is multiplied by ``self.window`` in place, so the
        caller's object is modified.

        Returns:
            TensorList: one ``dcf.interpolate_dft`` result per entry,
            pairing each spectrum with the matching ``self.interp_fs``.
        """
        x *= self.window
        spectra = fourier.cfft2(x)

        interpolated = TensorList()
        for spectrum, interp in zip(spectra, self.interp_fs):
            interpolated.append(dcf.interpolate_dft(spectrum, interp))
        return interpolated

    def generate_init_samples(self, im: torch.Tensor) -> TensorList:
        """Build the augmented first-frame samples from the resized image.

        The transform list always starts with ``augmentation.Identity`` and
        is extended according to the keys present in
        ``self.params.augmentation`` (``'shift'``, ``'fliplr'``,
        ``'rotate'``, ``'blur'``). The image is bilinearly resized to
        ``self.output_sz``, every transform is applied to it and features
        are computed on the stacked patches. Features whose
        ``use_augmentation`` flag is false keep only their first sample;
        when ``'dropout'`` is configured, the remaining features get
        ``num`` extra dropout copies of their first sample.
        """
        aug_cfg = self.params.augmentation

        # Do data augmentation
        transforms = [augmentation.Identity()]
        if 'shift' in aug_cfg:
            for shift in aug_cfg['shift']:
                transforms.append(augmentation.Translation(shift))
        if 'fliplr' in aug_cfg and aug_cfg['fliplr']:
            transforms.append(augmentation.FlipHorizontal())
        if 'rotate' in aug_cfg:
            for angle in aug_cfg['rotate']:
                transforms.append(augmentation.Rotate(angle))
        if 'blur' in aug_cfg:
            for sigma in aug_cfg['blur']:
                transforms.append(augmentation.Blur(sigma))

        target_size = self.output_sz.long().tolist()
        im_patch = F.interpolate(im, target_size, mode='bilinear')
        im_patches = torch.cat([T(im_patch) for T in transforms])

        init_samples = TensorList()
        for feature in self.params.features.features:
            init_samples.append(feature.get_feature(im_patches))
        init_samples = init_samples.unroll()

        # Features that opt out of augmentation keep the first sample only.
        for idx, use_aug in enumerate(
                self.fparams.attribute('use_augmentation')):
            if use_aug:
                continue
            init_samples[idx] = init_samples[idx][0:1, ...]

        if 'dropout' not in aug_cfg:
            return init_samples

        # Append dropout-perturbed copies of the first sample.
        num, prob = aug_cfg['dropout']
        for idx, use_aug in enumerate(
                self.fparams.attribute('use_augmentation')):
            if not use_aug:
                continue
            repeated = init_samples[idx][0:1, ...].expand(num, -1, -1, -1)
            dropped = F.dropout2d(repeated, p=prob, training=True)
            init_samples[idx] = torch.cat([init_samples[idx], dropped])

        return init_samples

    def symmetrize_filter(self):
        """Symmetrize index 0 of dimension 3 of every filter, in place.

        Applied to each filter of each per-point filter set: the slice is
        halved and then the complex conjugate of its copy flipped along
        dimension 2 is added to it.
        """
        for filter_set in self.filters:
            for hf in filter_set:
                # Basic indexing yields a view, so in-place ops update ``hf``.
                edge = hf[:, :, :, 0, :]
                edge /= 2
                edge += complex.conj(edge.flip((2, )))