예제 #1
0
 def A(self, x):
     """Apply the linear operator used by the CG solver to ``x``.

     Two chained autograd calls are made:
       1. differentiate ``self.dfdxt_g`` w.r.t. ``self.g`` with ``x`` as
          grad_outputs,
       2. differentiate ``self.f0`` w.r.t. ``self.x`` with the result of
          step 1 as grad_outputs.
     Presumably this evaluates (df/dx)^T (df/dx) x - confirm against the
     code that builds ``self.dfdxt_g``.

     args:
         x: Vector (TensorList-like) to apply the operator to.

     returns:
         TensorList holding the gradients from the second autograd call.
     """
     # Both graphs are retained because this is called repeatedly.
     forward_prod = torch.autograd.grad(outputs=self.dfdxt_g,
                                        inputs=self.g,
                                        grad_outputs=x,
                                        retain_graph=True)
     backward_prod = torch.autograd.grad(outputs=self.f0,
                                         inputs=self.x,
                                         grad_outputs=forward_prod,
                                         retain_graph=True)
     return TensorList(backward_prod)
예제 #2
0
    def extract_transformed(self, im, pos, scale, image_sz, transforms):
        """Compute features for several transformed copies of one image patch.

        args:
            im: Image.
            pos: Center position for extraction.
            scale: Image scale to extract features from.
            image_sz: Size to resize the image samples to before extraction.
            transforms: Iterable of callables, each applied to the sampled patch.

        returns:
            The unrolled TensorList of the outputs of every feature in
            self.features, evaluated on the batch of transformed patches.
        """
        # Sample one patch covering a region of size scale * image_sz
        base_patch = sample_patch(im, pos, scale * image_sz, image_sz)

        # One transformed copy per transform, concatenated into a single batch
        transformed = [transform(base_patch) for transform in transforms]
        im_patches = torch.cat(transformed)

        # Run every feature extractor on the whole batch
        per_feature = [feat.get_feature(im_patches) for feat in self.features]
        return TensorList(per_feature).unroll()
예제 #3
0
    def evaluate_CG_iteration(self, delta_x):
        """Record loss and gradient magnitude at ``self.x + delta_x``.

        Does nothing unless ``self.analyze_convergence`` is set. Otherwise the
        loss ``ip_output(f, f)`` with ``f = self.problem(x)`` and the norm of
        its gradient are appended to ``self.losses`` and
        ``self.gradient_mags``.

        args:
            delta_x: Candidate step added to the current estimate.
        """
        if not self.analyze_convergence:
            return

        # Work on a detached copy so the optimizer's own graph is untouched
        x = (self.x + delta_x).detach()
        x.requires_grad_(True)

        # Loss and its gradient at the candidate point
        residuals = self.problem(x)
        loss = self.problem.ip_output(residuals, residuals)
        grad = TensorList(torch.autograd.grad(loss, x))

        loss_entry = loss.detach().cpu().view(-1)
        flat_grad = grad.view(-1)
        grad_mag_entry = sum(flat_grad @ flat_grad).cpu().sqrt().detach().view(-1)

        # Append to the running histories
        self.losses = torch.cat((self.losses, loss_entry))
        self.gradient_mags = torch.cat((self.gradient_mags, grad_mag_entry))
예제 #4
0
    def init_classifier(self, init_backbone_feat):
        """Initialize the classifier from the first-frame backbone features.

        Extracts classification features, optionally appends dropout-augmented
        copies of the first sample, derives the feature/kernel/output sizes,
        builds the optional output window, computes the initial target filter
        and (unless disabled) initializes the sample memory.

        args:
            init_backbone_feat: Backbone features of the initial samples.
        """
        # Classification features of the initial samples
        x = self.get_classification_features(init_backbone_feat)

        # Dropout augmentation is done here because it operates on the
        # classification features rather than on image patches
        wants_dropout = 'dropout' in self.p.augmentation
        if wants_dropout and getattr(self.p, 'use_augmentation', True):
            num, prob = self.p.augmentation['dropout']
            # Repeat the first transform so the transform list keeps one
            # entry per sample
            self.transforms.extend(self.transforms[:1] * num)
            first_sample = x[0:1, ...].expand(num, -1, -1, -1)
            dropped = F.dropout2d(first_sample, p=prob, training=True)
            x = torch.cat([x, dropped])

        # Feature size and the sizes derived from it
        self.feature_sz = torch.Tensor(list(x.shape[-2:]))
        ksz = self.net.classifier.filter_size
        if isinstance(ksz, (int, float)):
            ksz = [ksz, ksz]
        self.kernel_size = torch.Tensor(ksz)
        # Adds one along every dimension where the kernel size is even
        self.output_sz = self.feature_sz + (self.kernel_size + 1) % 2

        # Optional window applied to the output
        self.output_window = None
        if getattr(self.p, 'window_output', False):
            window_sz = self.output_sz.long()
            if getattr(self.p, 'use_clipped_window', False):
                effective_sz = (window_sz * self.p.effective_search_area /
                                self.p.search_area_scale)
                window = dcf.hann2d_clipped(window_sz,
                                            effective_sz,
                                            centered=False)
            else:
                window = dcf.hann2d(window_sz, centered=True)
            self.output_window = window.to(self.p.device).squeeze(0)

        # Target boxes for the different augmentations
        target_boxes = self.init_target_boxes()

        # Losses are only requested from the optimizer in debug mode
        plot_loss = self.p.debug > 0
        num_iter = getattr(self.p, 'net_opt_iter', None)

        # Initial target filter from the model prediction module
        with torch.no_grad():
            self.target_filter, _, losses = self.net.classifier.get_filter(
                x, target_boxes, num_iter=num_iter, compute_losses=plot_loss)

        # Sample memory, only needed when the classifier is updated online
        if getattr(self.p, 'update_classifier', True):
            self.init_memory(TensorList([x]))

        if not plot_loss:
            return

        train_losses = losses['train'] if isinstance(losses, dict) else losses
        self.losses = torch.stack(train_losses)
예제 #5
0
 def get_fparams(self, name: str = None):
     """Collect feature parameters of the features selected by _return_feature.

     args:
         name: Attribute of each feature's ``fparams`` to collect. If None,
             the ``fparams`` objects themselves are returned.

     returns:
         A plain list of ``fparams`` objects when ``name`` is None, otherwise
         the unrolled TensorList of the named attribute.
     """
     if name is not None:
         named_values = [
             getattr(feat.fparams, name) for feat in self.features
             if self._return_feature(feat)
         ]
         return TensorList(named_values).unroll()

     return [
         feat.fparams for feat in self.features
         if self._return_feature(feat)
     ]
예제 #6
0
파일: models.py 프로젝트: swan2015/TracKit
    def extract(self, im: torch.Tensor):
        """Run the network's online feature extraction on an image tensor.

        The image is moved to the CUDA device and passed on unchanged; no
        scaling or mean/std normalization is applied here.

        args:
            im: Input image tensor.

        returns:
            Single-element TensorList with the output of
            ``self.net.extract_for_online``.
        """
        gpu_im = im.cuda()

        # Inference only, no autograd graph is needed
        with torch.no_grad():
            online_feat = self.net.extract_for_online(gpu_im)

        return TensorList([online_feat])
예제 #7
0
    def run(self, num_iter, dummy=None):
        """Run ``num_iter`` gradient descent steps with momentum on ``self.x``.

        Each step evaluates ``self.problem`` at the current estimate, forms
        the loss ``ip_output(f0, f0)``, updates the direction as
        ``grad + momentum * dir`` and moves ``self.x`` in place. In debug mode
        the loss and gradient magnitude before every step and after the last
        one are appended to ``self.losses`` / ``self.gradient_mags``.

        args:
            num_iter: Number of descent steps. Nothing is done if it is 0.
            dummy: Unused.
        """
        if num_iter == 0:
            return

        lossvec = None
        if self.debug:
            # One extra slot for the evaluation after the final step
            lossvec = torch.zeros(num_iter + 1)
            grad_mags = torch.zeros(num_iter + 1)

        for step in range(num_iter):
            self.x.requires_grad_(True)

            # Residuals, loss and gradient at the current estimate
            self.f0 = self.problem(self.x)
            loss = self.problem.ip_output(self.f0, self.f0)
            grad = TensorList(torch.autograd.grad(loss, self.x))

            # Momentum update of the search direction
            if self.dir is not None:
                self.dir = grad + self.momentum * self.dir
            else:
                self.dir = grad

            # Take the step outside of the autograd graph
            self.x.detach_()
            self.x -= self.step_legnth * self.dir

            if self.debug:
                lossvec[step] = loss.item()
                flat_grad = grad.view(-1)
                grad_mags[step] = sum(flat_grad @ flat_grad).sqrt().item()

        if self.debug:
            # Evaluate once more at the final estimate
            self.x.requires_grad_(True)
            self.f0 = self.problem(self.x)
            loss = self.problem.ip_output(self.f0, self.f0)
            grad = TensorList(torch.autograd.grad(loss, self.x))
            # The loss is evaluated a second time here instead of reusing `loss`
            lossvec[-1] = self.problem.ip_output(self.f0, self.f0).item()
            flat_grad = grad.view(-1)
            grad_mags[-1] = sum(flat_grad @ flat_grad).cpu().sqrt().item()

            self.losses = torch.cat((self.losses, lossvec))
            self.gradient_mags = torch.cat((self.gradient_mags, grad_mags))

            if self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.gradient_mags,
                           self.fig_num[1],
                           title='Gradient magnitude')

        self.x.detach_()
        self.clear_temp()
예제 #8
0
    def init_memory(self, train_x: TensorList):
        """Allocate the sample memory and fill it with the initial samples.

        For every entry of ``train_x`` a weight vector and a sample buffer of
        length ``self.p.sample_memory_size`` are created. The initial samples
        occupy the first slots and share uniform weights (1 / number of
        initial samples); the remaining slots are zero.

        args:
            train_x: TensorList of initial training samples. Each entry is
                indexed along four dimensions - presumably
                (sample, channel, height, width); confirm against the caller.
        """
        # Sample counters
        self.num_init_samples = train_x.size(0)
        self.num_stored_samples = self.num_init_samples.copy()
        self.previous_replace_ind = [None] * len(self.num_stored_samples)

        # Sample weights: uniform over the initial samples, zero elsewhere
        weights = []
        for x, num in zip(train_x, self.num_init_samples):
            weight_buf = x.new_zeros(self.p.sample_memory_size)
            weight_buf[:num] = x.new_ones(1) / x.shape[0]
            weights.append(weight_buf)
        self.sample_weights = TensorList(weights)

        # Sample buffers: initial samples first, zero padding after
        buffers = []
        for x in train_x:
            sample_buf = x.new_zeros(self.p.sample_memory_size, x.shape[1],
                                     x.shape[2], x.shape[3])
            sample_buf[:x.shape[0], ...] = x
            buffers.append(sample_buf)
        self.training_samples = TensorList(buffers)
예제 #9
0
    def extract(self, im, pos, scales, image_sz):
        """Extract features at one or more scales and join them along dim 1.

        args:
            im: Image.
            pos: Center position for extraction.
            scales: A single scale (int or float) or an iterable of scales.
            image_sz: Size to resize the image samples to before extraction.

        returns:
            A single tensor: the outputs of all features in self.features
            concatenated along dimension 1.
        """
        # Accept a bare number as a one-element scale list
        if isinstance(scales, (int, float)):
            scales = [scales]

        # One patch per scale, stacked into a batch
        patches = [
            sample_patch(im, pos, s * image_sz, image_sz) for s in scales
        ]
        im_patches = torch.cat(patches)

        # Evaluate every feature on the batch and merge the results
        per_feature = TensorList(
            [feat.get_feature(im_patches) for feat in self.features]).unroll()
        return torch.cat(per_feature, dim=1)
예제 #10
0
    def update_classifier(self,
                          train_x,
                          target_box,
                          learning_rate=None,
                          scores=None):
        """Store a new training sample and, when scheduled, re-optimize the filter.

        The number of optimizer iterations is chosen by the first rule that
        applies:
          1. hard negative (an explicit ``learning_rate`` was passed):
             ``self.p.net_opt_hn_iter``
          2. low score (``self.p.low_score_opt_threshold`` is set and exceeds
             the maximum of ``scores``): ``self.p.net_opt_low_iter``
          3. regular update every ``self.p.train_skipping`` frames:
             ``self.p.net_opt_update_iter``
        If no rule applies, the filter is left unchanged.

        args:
            train_x: Features of the new training sample.
            target_box: Target box belonging to ``train_x``.
            learning_rate: Learning rate for the memory update. Passing a
                value also marks the update as a hard negative. Defaults to
                ``self.p.learning_rate``.
            scores: Score tensor of the current frame, used only by the
                low-score rule. If None, that rule is skipped.
        """
        # An explicitly supplied learning rate signals a hard-negative update
        hard_negative_flag = learning_rate is not None
        if learning_rate is None:
            learning_rate = self.p.learning_rate

        # Update the tracker memory
        self.update_memory(TensorList([train_x]), target_box, learning_rate)

        # Decide the number of iterations to run
        num_iter = 0
        low_score_th = getattr(self.p, 'low_score_opt_threshold', None)
        if hard_negative_flag:
            num_iter = getattr(self.p, 'net_opt_hn_iter', None)
        elif (low_score_th is not None and scores is not None
              and low_score_th > scores.max().item()):
            # `scores` defaults to None, so it is checked before being used
            num_iter = getattr(self.p, 'net_opt_low_iter', None)
        elif (self.frame_num - 1) % self.p.train_skipping == 0:
            num_iter = getattr(self.p, 'net_opt_update_iter', None)

        plot_loss = self.p.debug > 0

        # A missing iteration parameter yields None, and `None > 0` raises
        # TypeError on Python 3. None is therefore forwarded to the optimizer,
        # as init_classifier does with 'net_opt_iter'.
        # NOTE(review): presumably the optimizer treats num_iter=None as "use
        # its own default" - confirm against filter_optimizer.
        if num_iter is not None and num_iter <= 0:
            return

        # Get inputs for the ONLINE filter optimizer module
        num_stored = self.num_stored_samples[0]
        samples = self.training_samples[0][:num_stored, ...]
        target_boxes = self.target_boxes[:num_stored, :].clone()
        sample_weights = self.sample_weights[0][:num_stored]

        # Run the filter optimizer module
        with torch.no_grad():
            self.target_filter, _, losses = self.net.classifier.filter_optimizer(
                self.target_filter,
                samples,
                target_boxes,
                sample_weight=sample_weights,
                num_iter=num_iter,
                compute_losses=plot_loss)

        if plot_loss:
            if isinstance(losses, dict):
                losses = losses['train']
            self.losses = torch.cat((self.losses, torch.stack(losses)))
예제 #11
0
    def init_iou_net(self, backbone_feat):
        """Initialize the IoU prediction branch from first-frame features.

        Freezes the box regressor parameters, builds the target boxes of the
        initial (augmented) samples and stores the sample-averaged, detached
        modulation vectors in ``self.iou_modulation``.

        args:
            backbone_feat: Backbone features of the initial samples.
        """
        # The box regressor is not trained here
        for param in self.net.bb_regressor.parameters():
            param.requires_grad = False

        # Reference box in the coordinates of the initial sample
        self.classifier_target_box = self.get_iounet_box(
            self.pos, self.target_sz, self.init_sample_pos,
            self.init_sample_scale)

        def shifted_box(transform):
            # transform.shift is indexed [1], [0], i.e. swapped w.r.t. the box
            offset = torch.Tensor(
                [transform.shift[1], transform.shift[0], 0, 0])
            return self.classifier_target_box + offset

        # One box per usable augmentation, or only the first transform
        target_boxes = TensorList()
        if not self.p.iounet_augmentation:
            target_boxes.append(shifted_box(self.transforms[0]))
        else:
            for transform in self.transforms:
                # Stop at the first transform of any other type; only the
                # leading run of these transforms is used
                if not isinstance(
                        transform,
                        (augmentation.Identity, augmentation.Translation,
                         augmentation.FlipHorizontal,
                         augmentation.FlipVertical, augmentation.Blur)):
                    break
                target_boxes.append(shifted_box(transform))
        target_boxes = torch.cat(target_boxes.view(1, 4), 0).to(self.p.device)

        # IoU features, restricted to the samples that have a box
        iou_backbone_feat = self.get_iou_backbone_features(backbone_feat)
        num_boxes = target_boxes.shape[0]
        iou_backbone_feat = TensorList(
            [feat[:num_boxes, ...] for feat in iou_backbone_feat])

        # Modulation vectors, averaged over the samples
        self.iou_modulation = self.get_iou_modulation(iou_backbone_feat,
                                                      target_boxes)
        self.iou_modulation = TensorList(
            [vec.detach().mean(0) for vec in self.iou_modulation])
예제 #12
0
    def extract(self, im, pos, scales, image_sz):
        """Extract features at one or more scales.

        args:
            im: Image.
            pos: Center position for extraction.
            scales: Image scales to extract features from; a single number is
                treated as a one-element list.
            image_sz: Size to resize the image samples to before extraction.

        returns:
            The unrolled TensorList of the outputs of every feature in
            self.features.
        """
        scale_list = [scales] if isinstance(scales, (int, float)) else scales

        # One patch per scale, stacked into a batch
        patches = [
            sample_patch(im, pos, s * image_sz, image_sz) for s in scale_list
        ]
        im_patches = torch.cat(patches)

        # Evaluate every feature on the batch
        per_feature = [feat.get_feature(im_patches) for feat in self.features]
        return TensorList(per_feature).unroll()
예제 #13
0
 def init_target_boxes(self):
     """Build the target boxes of the initial augmented samples.

     Stores the reference box in ``self.classifier_target_box`` and allocates
     ``self.target_boxes`` with ``self.p.sample_memory_size`` rows, whose
     first rows hold the initial boxes.

     returns:
         Tensor with one row of 4 values per transform in self.transforms.
     """
     self.classifier_target_box = self.get_iounet_box(
         self.pos, self.target_sz, self.init_sample_pos,
         self.init_sample_scale)

     # Shift the reference box by each transform's shift
     # (shift is indexed [1], [0], i.e. swapped w.r.t. the box)
     boxes = TensorList()
     for transform in self.transforms:
         offset = torch.Tensor(
             [transform.shift[1], transform.shift[0], 0, 0])
         boxes.append(self.classifier_target_box + offset)
     boxes = torch.cat(boxes.view(1, 4), 0).to(self.p.device)

     # Box memory: initial boxes first, zero rows after
     num_init = boxes.shape[0]
     self.target_boxes = boxes.new_zeros(self.p.sample_memory_size, 4)
     self.target_boxes[:num_init, :] = boxes
     return boxes
예제 #14
0
 def A(self, x):
     """Apply the regularized second-order operator to ``x``.

     Differentiates ``self.g`` w.r.t. ``self.x`` with ``x`` as grad_outputs
     and adds ``self.hessian_reg * x``.

     args:
         x: Vector (TensorList-like) to apply the operator to.

     returns:
         TensorList with the result.
     """
     # The graph is retained because this is called repeatedly
     second_order = torch.autograd.grad(self.g,
                                        self.x,
                                        x,
                                        retain_graph=True)
     return TensorList(second_order) + self.hessian_reg * x
예제 #15
0
class NewtonCG(ConjugateGradientBase):
    """Newton with Conjugate Gradient. Handles general minimization problems.

    Every Newton iteration evaluates the problem at the current estimate,
    takes its gradient with a differentiable graph, and calls ``run_CG``
    (not defined in this class; presumably inherited from
    ConjugateGradientBase) to obtain the step added to ``self.x``.
    """
    def __init__(self,
                 problem: MinimizationProblem,
                 variable: TensorList,
                 init_hessian_reg=0.0,
                 hessian_reg_factor=1.0,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 debug=False,
                 analyze=False,
                 plotting=False,
                 fig_num=(10, 11, 12)):
        """Set up the optimizer.

        args:
            problem: Callable problem; ``problem(x)`` gives the value that is
                differentiated. Also provides ip_input, M1 and M2.
            variable: Variable to optimize; updated in place.
            init_hessian_reg: Initial weight of the ``hessian_reg * x`` term in A.
            hessian_reg_factor: Factor applied to the regularization weight
                after every Newton iteration.
            cg_eps: Passed to run_CG as ``eps``.
            fletcher_reeves, standard_alpha, direction_forget_factor:
                Forwarded unchanged to the base class.
            debug: Enables loss/residual logging.
            analyze: Enables per-CG-iteration loss and gradient logging.
            plotting: Enables plotting of the logged quantities.
            fig_num: Figure numbers for loss, CG residuals and gradient magnitude.
        """
        # The base class debug flag is on if any diagnostic option is on
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or analyze or plotting)

        self.problem = problem
        self.x = variable

        self.analyze_convergence = analyze
        self.plotting = plotting
        self.fig_num = fig_num

        self.hessian_reg = init_hessian_reg
        self.hessian_reg_factor = hessian_reg_factor
        self.cg_eps = cg_eps
        # Temporaries that hold autograd graphs; released by clear_temp
        self.f0 = None
        self.g = None

        # Histories, grown by concatenation
        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)
        self.gradient_mags = torch.zeros(0)

    def clear_temp(self):
        """Drop the references to the per-iteration value and gradient."""
        self.f0 = None
        self.g = None

    def run(self, num_cg_iter, num_newton_iter=None):
        """Run the optimizer.

        args:
            num_cg_iter: Number of CG iterations per Newton iteration. An int
                is repeated ``num_newton_iter`` times; a list gives one entry
                per Newton iteration.
            num_newton_iter: Number of Newton iterations when ``num_cg_iter``
                is an int. Defaults to 1.

        returns:
            ``(self.losses, self.residuals)``.
            NOTE(review): the early exits for zero iterations return None
            instead of this tuple.
        """

        if isinstance(num_cg_iter, int):
            if num_cg_iter == 0:
                return
            if num_newton_iter is None:
                num_newton_iter = 1
            num_cg_iter = [num_cg_iter] * num_newton_iter

        num_newton_iter = len(num_cg_iter)
        if num_newton_iter == 0:
            return

        # Log the starting point (delta_x = 0)
        if self.analyze_convergence:
            self.evaluate_CG_iteration(0)

        for cg_iter in num_cg_iter:
            self.run_newton_iter(cg_iter)
            # Rescale the regularization weight after every Newton iteration
            self.hessian_reg *= self.hessian_reg_factor

        if self.debug:
            # Without analysis the final loss has not been logged yet
            if not self.analyze_convergence:
                loss = self.problem(self.x)
                self.losses = torch.cat(
                    (self.losses, loss.detach().cpu().view(-1)))

            if self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals,
                           self.fig_num[1],
                           title='CG residuals')
                if self.analyze_convergence:
                    plot_graph(self.gradient_mags, self.fig_num[2],
                               'Gradient magnitude')

        self.x.detach_()
        self.clear_temp()

        return self.losses, self.residuals

    def run_newton_iter(self, num_cg_iter):
        """Run a single Newton iteration with ``num_cg_iter`` CG iterations."""

        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        if self.debug and not self.analyze_convergence:
            self.losses = torch.cat(
                (self.losses, self.f0.detach().cpu().view(-1)))

        # Gradient of loss
        # create_graph=True keeps the gradient differentiable, which A needs
        self.g = TensorList(
            torch.autograd.grad(self.f0, self.x, create_graph=True))

        # Get the right hand side
        self.b = -self.g.detach()

        # Run CG
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Apply the step outside of the autograd graph
        self.x.detach_()
        self.x += delta_x

        if self.debug:
            self.residuals = torch.cat((self.residuals, res))

    def A(self, x):
        """Differentiate ``self.g`` w.r.t. ``self.x`` with ``x`` as
        grad_outputs (a Hessian-vector product, since ``self.g`` is the
        gradient of ``self.f0``) and add ``self.hessian_reg * x``."""
        return TensorList(
            torch.autograd.grad(self.g, self.x, x,
                                retain_graph=True)) + self.hessian_reg * x

    def ip(self, a, b):
        # Implements the inner product
        return self.problem.ip_input(a, b)

    def M1(self, x):
        """Delegate to the problem's M1."""
        return self.problem.M1(x)

    def M2(self, x):
        """Delegate to the problem's M2."""
        return self.problem.M2(x)

    def evaluate_CG_iteration(self, delta_x):
        """Log loss and gradient magnitude at ``self.x + delta_x``.

        Does nothing unless ``self.analyze_convergence`` is set.
        """
        if self.analyze_convergence:
            # Detached copy so the optimizer's own graph is untouched
            x = (self.x + delta_x).detach()
            x.requires_grad_(True)

            # compute loss and gradient
            loss = self.problem(x)
            grad = TensorList(torch.autograd.grad(loss, x))

            # store in the vectors
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))
            self.gradient_mags = torch.cat(
                (self.gradient_mags,
                 sum(grad.view(-1)
                     @ grad.view(-1)).cpu().sqrt().detach().view(-1)))
예제 #16
0
 def stride(self):
     """Return the strides of the selected features as one torch.Tensor.

     Only features for which ``self._return_feature`` is truthy are included.
     """
     selected_strides = [
         feat.stride() for feat in self.features
         if self._return_feature(feat)
     ]
     return torch.Tensor(TensorList(selected_strides).unroll())
예제 #17
0
 def size(self, input_sz):
     """Return the unrolled TensorList of sizes of the selected features.

     args:
         input_sz: Input size forwarded to each feature's ``size`` method.

     Only features for which ``self._return_feature`` is truthy are included.
     """
     selected_sizes = [
         feat.size(input_sz) for feat in self.features
         if self._return_feature(feat)
     ]
     return TensorList(selected_sizes).unroll()
예제 #18
0
class GaussNewtonCG(ConjugateGradientBase):
    """Gauss-Newton with Conjugate Gradient optimizer.

    Every GN iteration evaluates the problem at the current estimate, builds
    the autograd graphs used by the operator ``A`` and calls ``run_CG`` (not
    defined in this class; presumably inherited from ConjugateGradientBase)
    to obtain the step added to ``self.x``.
    """
    def __init__(self,
                 problem: L2Problem,
                 variable: TensorList,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 debug=False,
                 analyze=False,
                 plotting=False,
                 fig_num=(10, 11, 12)):
        """Set up the optimizer.

        args:
            problem: Callable problem; ``problem(x)`` gives the value whose
                ``ip_output`` with itself is the loss. Also provides ip_input,
                M1 and M2.
            variable: Variable to optimize; updated in place.
            cg_eps: Passed to run_CG as ``eps``.
            fletcher_reeves, standard_alpha, direction_forget_factor:
                Forwarded unchanged to the base class.
            debug: Enables loss/residual logging.
            analyze: Enables per-CG-iteration loss and gradient logging.
            plotting: Enables plotting of the logged quantities.
            fig_num: Figure numbers for loss, CG residuals and gradient magnitude.
        """
        # The base class debug flag is on if any diagnostic option is on
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or analyze or plotting)

        self.problem = problem
        self.x = variable

        self.analyze_convergence = analyze
        self.plotting = plotting
        self.fig_num = fig_num

        self.cg_eps = cg_eps
        # Temporaries that hold autograd graphs; released by clear_temp
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        # Histories, grown by concatenation
        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)
        self.gradient_mags = torch.zeros(0)

    def clear_temp(self):
        """Drop the references to the per-iteration temporaries."""
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run_GN(self, *args, **kwargs):
        """Alias of run."""
        return self.run(*args, **kwargs)

    def run(self, num_cg_iter, num_gn_iter=None):
        """Run the optimizer.
        args:
            num_cg_iter: Number of CG iterations per GN iter. If list, then each entry specifies number of CG iterations
                         and number of GN iterations is given by the length of the list.
            num_gn_iter: Number of GN iterations. Shall only be given if num_cg_iter is an integer.
        returns:
            ``(self.losses, self.residuals)``.
            NOTE(review): the early exit for zero GN iterations returns None
            instead of this tuple.
        raises:
            ValueError: If num_cg_iter is an int and num_gn_iter is not given.
        """

        if isinstance(num_cg_iter, int):
            if num_gn_iter is None:
                raise ValueError(
                    'Must specify number of GN iter if CG iter is constant')
            num_cg_iter = [num_cg_iter] * num_gn_iter

        num_gn_iter = len(num_cg_iter)
        if num_gn_iter == 0:
            return

        # Log the starting point (delta_x = 0)
        if self.analyze_convergence:
            self.evaluate_CG_iteration(0)

        # Outer loop for running the GN iterations.
        for cg_iter in num_cg_iter:
            self.run_GN_iter(cg_iter)

        if self.debug:
            # Without analysis the final loss has not been logged yet
            if not self.analyze_convergence:
                self.f0 = self.problem(self.x)
                loss = self.problem.ip_output(self.f0, self.f0)
                self.losses = torch.cat(
                    (self.losses, loss.detach().cpu().view(-1)))

            if self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals,
                           self.fig_num[1],
                           title='CG residuals')
                if self.analyze_convergence:
                    plot_graph(self.gradient_mags, self.fig_num[2],
                               'Gradient magnitude')

        self.x.detach_()
        self.clear_temp()

        return self.losses, self.residuals

    def run_GN_iter(self, num_cg_iter):
        """Runs a single GN iteration."""

        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Create copy with graph detached
        self.g = self.f0.detach()

        if self.debug and not self.analyze_convergence:
            loss = self.problem.ip_output(self.g, self.g)
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))

        # g becomes a leaf that requires grad, so A can differentiate
        # dfdxt_g with respect to it
        self.g.requires_grad_(True)

        # Get df/dx^t @ f0
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

        # Get the right hand side
        self.b = -self.dfdxt_g.detach()

        # Run CG
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Apply the step outside of the autograd graph
        self.x.detach_()
        self.x += delta_x

        if self.debug:
            self.residuals = torch.cat((self.residuals, res))

    def A(self, x):
        """Apply the CG operator to ``x`` with two chained autograd calls.

        First differentiates ``self.dfdxt_g`` w.r.t. ``self.g`` with ``x`` as
        grad_outputs, then differentiates ``self.f0`` w.r.t. ``self.x`` with
        that result as grad_outputs.
        """
        dfdx_x = torch.autograd.grad(self.dfdxt_g,
                                     self.g,
                                     x,
                                     retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        """Inner product, delegated to the problem's ip_input."""
        return self.problem.ip_input(a, b)

    def M1(self, x):
        """Delegate to the problem's M1."""
        return self.problem.M1(x)

    def M2(self, x):
        """Delegate to the problem's M2."""
        return self.problem.M2(x)

    def evaluate_CG_iteration(self, delta_x):
        """Log loss and gradient magnitude at ``self.x + delta_x``.

        Does nothing unless ``self.analyze_convergence`` is set.
        """
        if self.analyze_convergence:
            # Detached copy so the optimizer's own graph is untouched
            x = (self.x + delta_x).detach()
            x.requires_grad_(True)

            # compute loss and gradient
            f = self.problem(x)
            loss = self.problem.ip_output(f, f)
            grad = TensorList(torch.autograd.grad(loss, x))

            # store in the vectors
            self.losses = torch.cat(
                (self.losses, loss.detach().cpu().view(-1)))
            self.gradient_mags = torch.cat(
                (self.gradient_mags,
                 sum(grad.view(-1)
                     @ grad.view(-1)).cpu().sqrt().detach().view(-1)))
예제 #19
0
 def dim(self):
     """Return the unrolled TensorList of ``dim()`` of the selected features.

     Only features for which ``self._return_feature`` is truthy are included.
     """
     selected_dims = [
         feat.dim() for feat in self.features if self._return_feature(feat)
     ]
     return TensorList(selected_dims).unroll()
예제 #20
0
class ConjugateGradient(ConjugateGradientBase):
    """Conjugate Gradient optimizer, performing single linearization of the residuals in the start.

    ``run_CG`` is not defined in this class; presumably it is inherited from
    ConjugateGradientBase.
    """
    def __init__(self,
                 problem: L2Problem,
                 variable: TensorList,
                 cg_eps=0.0,
                 fletcher_reeves=True,
                 standard_alpha=True,
                 direction_forget_factor=0,
                 debug=False,
                 plotting=False,
                 fig_num=(10, 11)):
        """Set up the optimizer.

        args:
            problem: Callable problem; ``problem(x)`` gives the value whose
                ``ip_output`` with itself is the loss. Also provides ip_input,
                M1 and M2.
            variable: Variable to optimize; updated in place.
            cg_eps: Passed to run_CG as ``eps``.
            fletcher_reeves, standard_alpha, direction_forget_factor:
                Forwarded unchanged to the base class.
            debug: Enables loss/residual logging.
            plotting: Enables plotting of the logged quantities.
            fig_num: Figure numbers for loss and CG residuals.
        """
        # The base class debug flag is on if debug or plotting is on
        super().__init__(fletcher_reeves, standard_alpha,
                         direction_forget_factor, debug or plotting)

        self.problem = problem
        self.x = variable

        self.plotting = plotting
        self.fig_num = fig_num

        self.cg_eps = cg_eps
        # Temporaries that hold autograd graphs; released by clear_temp
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

        # Histories, grown by concatenation
        self.residuals = torch.zeros(0)
        self.losses = torch.zeros(0)

    def clear_temp(self):
        """Drop the references to the per-run temporaries."""
        self.f0 = None
        self.g = None
        self.dfdxt_g = None

    def run(self, num_cg_iter):
        """Run the optimizer with the provided number of iterations."""

        if num_cg_iter == 0:
            return

        lossvec = None
        if self.debug:
            # Two entries: loss before and after the CG run
            lossvec = torch.zeros(2)

        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Create copy with graph detached
        self.g = self.f0.detach()

        if self.debug:
            lossvec[0] = self.problem.ip_output(self.g, self.g)

        # g becomes a leaf that requires grad, so A can differentiate
        # dfdxt_g with respect to it
        self.g.requires_grad_(True)

        # Get df/dx^t @ f0
        self.dfdxt_g = TensorList(
            torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

        # Get the right hand side
        self.b = -self.dfdxt_g.detach()

        # Run CG
        delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

        # Apply the step outside of the autograd graph
        self.x.detach_()
        self.x += delta_x

        if self.debug:
            # Re-evaluate at the updated estimate for the final loss entry
            self.f0 = self.problem(self.x)
            lossvec[-1] = self.problem.ip_output(self.f0, self.f0)
            self.residuals = torch.cat((self.residuals, res))
            self.losses = torch.cat((self.losses, lossvec))
            if self.plotting:
                plot_graph(self.losses, self.fig_num[0], title='Loss')
                plot_graph(self.residuals,
                           self.fig_num[1],
                           title='CG residuals')

        self.x.detach_()
        self.clear_temp()

    def A(self, x):
        """Apply the CG operator to ``x`` with two chained autograd calls.

        First differentiates ``self.dfdxt_g`` w.r.t. ``self.g`` with ``x`` as
        grad_outputs, then differentiates ``self.f0`` w.r.t. ``self.x`` with
        that result as grad_outputs.
        """
        dfdx_x = torch.autograd.grad(self.dfdxt_g,
                                     self.g,
                                     x,
                                     retain_graph=True)
        return TensorList(
            torch.autograd.grad(self.f0, self.x, dfdx_x, retain_graph=True))

    def ip(self, a, b):
        """Inner product, delegated to the problem's ip_input."""
        return self.problem.ip_input(a, b)

    def M1(self, x):
        """Delegate to the problem's M1."""
        return self.problem.M1(x)

    def M2(self, x):
        """Delegate to the problem's M2."""
        return self.problem.M2(x)
예제 #21
0
파일: models.py 프로젝트: swan2015/TracKit
 def stride(self):
     """Return one stride per output layer as a TensorList.

     Each entry is the pool stride multiplied by ``self.layer_stride`` of the
     corresponding layer; ``self.output_layers`` and ``self.pool_stride`` are
     paired element-wise.
     """
     strides = []
     for layer, pool in zip(self.output_layers, self.pool_stride):
         strides.append(pool * self.layer_stride[layer])
     return TensorList(strides)
예제 #22
0
파일: models.py 프로젝트: swan2015/TracKit
 def dim(self):
     """Return ``self.layer_dim`` of every output layer as a TensorList."""
     dims = []
     for layer in self.output_layers:
         dims.append(self.layer_dim[layer])
     return TensorList(dims)
예제 #23
0
    def refine_target_box(self,
                          backbone_feat,
                          sample_pos,
                          sample_scale,
                          scale_ind,
                          update_scale=True):
        """Refine the target box by optimizing a set of candidate boxes.

        Starts from the current box plus optional random perturbations of it,
        optimizes them with ``self.optimize_boxes``, drops boxes with an
        extreme aspect ratio and averages the top-scoring ones. The result
        updates ``self.pos_iounet``, ``self.target_sz`` and, depending on the
        settings, ``self.pos`` and ``self.target_scale``. If no box survives
        the filtering, nothing is updated.

        args:
            backbone_feat: Backbone features of the current samples.
            sample_pos: Position of the sample the features come from.
            sample_scale: Scale of that sample.
            scale_ind: Index of the sample (scale) to use.
            update_scale: Whether to also update ``self.target_scale``.
        """
        # Current box expressed in the sample's coordinates
        init_box = self.get_iounet_box(self.pos, self.target_sz, sample_pos,
                                       sample_scale)

        # Keep only the features of the selected scale (slice keeps the dim)
        iou_features = self.get_iou_features(backbone_feat)
        iou_features = TensorList(
            [feat[scale_ind:scale_ind + 1, ...] for feat in iou_features])

        # Candidate boxes: the initial box plus random perturbations of it
        init_boxes = init_box.view(1, 4).clone()
        num_rand = self.p.num_init_random_boxes
        if num_rand > 0:
            init_wh = init_box[2:]
            square_box_sz = init_wh.prod().sqrt()
            jitter = torch.cat([
                self.p.box_jitter_pos * torch.ones(2),
                self.p.box_jitter_sz * torch.ones(2)
            ])
            rand_factor = square_box_sz * jitter

            # Perturbed sizes are clamped from below
            minimal_edge_size = init_wh.min() / 3
            rand_bb = (torch.rand(num_rand, 4) - 0.5) * rand_factor
            new_sz = (init_wh + rand_bb[:, 2:]).clamp(minimal_edge_size)
            new_center = (init_box[:2] + init_wh / 2) + rand_bb[:, :2]
            rand_boxes = torch.cat([new_center - new_sz / 2, new_sz], 1)
            init_boxes = torch.cat([init_box.view(1, 4), rand_boxes])

        # Optimize the candidates
        output_boxes, output_iou = self.optimize_boxes(iou_features,
                                                       init_boxes)

        # Discard degenerate boxes: sizes at least 1, bounded aspect ratio
        output_boxes[:, 2:].clamp_(1)
        aspect_ratio = output_boxes[:, 2] / output_boxes[:, 3]
        max_ratio = self.p.maximal_aspect_ratio
        keep_ind = (aspect_ratio < max_ratio) * (aspect_ratio > 1 / max_ratio)
        output_boxes = output_boxes[keep_ind, :]
        output_iou = output_iou[keep_ind]

        # Nothing usable left
        if output_boxes.shape[0] == 0:
            return

        # Average the k best candidates
        k = getattr(self.p, 'iounet_k', 5)
        topk = min(k, output_boxes.shape[0])
        _, inds = torch.topk(output_iou, topk)
        predicted_box = output_boxes[inds, :].mean(0)
        # Computed but not used further in this method
        predicted_iou = output_iou.view(-1, 1)[inds, :].mean(0)

        # Convert the box back to a position and size; the flips swap the
        # order of the two coordinates
        pred_corner, pred_wh = predicted_box[:2], predicted_box[2:]
        pred_center = pred_corner + pred_wh / 2
        new_pos = (pred_center.flip(
            (0, )) - (self.img_sample_sz - 1) / 2) * sample_scale + sample_pos
        new_target_sz = pred_wh.flip((0, )) * sample_scale
        new_scale = torch.sqrt(new_target_sz.prod() /
                               self.base_target_sz.prod())

        # Store the results
        self.pos_iounet = new_pos.clone()
        if getattr(self.p, 'use_iounet_pos_for_learning', True):
            self.pos = new_pos.clone()
        self.target_sz = new_target_sz
        if update_scale:
            self.target_scale = new_scale