def run(self, num_iter, new_xf: TensorList = None):
    """Run num_iter preconditioned CG iterations on the filter, optionally
    updating the sample energy estimate from the new samples new_xf first."""
    if num_iter == 0:
        return

    if new_xf is not None:
        new_sample_energy = complex.abs_sqr(new_xf)
        if self.sample_energy is None:
            self.sample_energy = new_sample_energy
        else:
            self.sample_energy = (1 - self.params.precond_learning_rate) * self.sample_energy + \
                                 self.params.precond_learning_rate * new_sample_energy

    # Compute right hand side
    self.b = complex.mtimes(self.sample_weights.view(1, 1, 1, -1), self.training_samples).permute(2, 3, 0, 1, 4)
    self.b = complex.mult_conj(self.yf, self.b)

    # Diagonal preconditioner, mixing data energy and regularization energy
    self.diag_M = (1 - self.params.precond_reg_param) * \
                  (self.params.precond_data_param * self.sample_energy +
                   (1 - self.params.precond_data_param) * self.sample_energy.mean(1, keepdim=True)) + \
                  self.params.precond_reg_param * self.reg_energy

    _, res = self.run_CG(num_iter, self.filter)

    if self.debug:
        self.residuals = torch.cat((self.residuals, res))
        plot_graph(self.residuals, 9)
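
# Minimal standalone sketch of the diagonal (Jacobi) preconditioning idea
# behind diag_M above: running CG on the preconditioned system converges
# faster when M approximates the diagonal of A. Toy SPD system; all names
# here are illustrative, not this class's run_CG.
import torch

A = torch.diag(torch.arange(1.0, 6.0)) + 0.1 * torch.ones(5, 5)   # SPD matrix
b = torch.ones(5)
M = torch.diagonal(A)                                             # preconditioner

x = torch.zeros(5)
r = b - A @ x
z = r / M
p = z.clone()
for _ in range(10):
    Ap = A @ p
    alpha = (r @ z) / (p @ Ap)
    x = x + alpha * p
    r_new = r - alpha * Ap
    z_new = r_new / M
    beta = (r_new @ z_new) / (r @ z)
    p = z_new + beta * p
    r, z = r_new, z_new
print(torch.linalg.norm(A @ x - b))   # residual is tiny after a few iterations
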
def update_classifier(self, train_x, target_box, learning_rate=None, scores=None):
    """Update the classifier memory and, when scheduled, run the filter optimizer."""
    # Set flags and learning rate
    hard_negative_flag = learning_rate is not None
    if learning_rate is None:
        learning_rate = self.params.learning_rate

    # Update the tracker memory
    if hard_negative_flag or self.frame_num % self.params.get('train_sample_interval', 1) == 0:
        self.update_memory(TensorList([train_x]), target_box, learning_rate)

    # Decide the number of iterations to run
    num_iter = 0
    low_score_th = self.params.get('low_score_opt_threshold', None)
    if hard_negative_flag:
        num_iter = self.params.get('net_opt_hn_iter', None)
    elif low_score_th is not None and low_score_th > scores.max().item():
        num_iter = self.params.get('net_opt_low_iter', None)
    elif (self.frame_num - 1) % self.params.train_skipping == 0:
        num_iter = self.params.get('net_opt_update_iter', None)

    plot_loss = self.params.debug > 0

    if num_iter > 0:
        # Get inputs for the DiMP filter optimizer module
        samples = self.training_samples[0][:self.num_stored_samples[0], ...]
        target_boxes = self.target_boxes[:self.num_stored_samples[0], :].clone()
        sample_weights = self.sample_weights[0][:self.num_stored_samples[0]]

        # Run the filter optimizer module
        with torch.no_grad():
            self.target_filter, _, losses = self.net.classifier.filter_optimizer(
                self.target_filter, num_iter=num_iter, feat=samples, bb=target_boxes,
                sample_weight=sample_weights, compute_losses=plot_loss)

        if plot_loss:
            if isinstance(losses, dict):
                losses = losses['train']
            self.losses = torch.cat((self.losses, torch.cat(losses)))
            if self.visdom is not None:
                self.visdom.register((self.losses, torch.arange(self.losses.numel())),
                                     'lineplot', 3, 'Training Loss' + self.id_str)
            elif self.params.debug >= 3:
                plot_graph(self.losses, 10, title='Training Loss' + self.id_str)
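
# Usage sketch (illustrative, not part of this class): the tracking loop feeds
# each frame's classification features back here. `backbone_feat`, `target_box`,
# `scores`, and `hn_lr` are assumed to come from the surrounding tracker code.
#
#     train_x = self.get_classification_features(backbone_feat)
#     self.update_classifier(train_x, target_box, scores=scores)        # scheduled update
#     self.update_classifier(train_x, target_box, learning_rate=hn_lr)  # hard-negative frame
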
def init_classifier(self, init_backbone_feat_rgb, init_backbone_feat_d):
    """Initialize the RGB and depth classifiers from the first-frame backbone features."""
    # Get classification features
    x_rgb, x_d = self.get_classification_features(init_backbone_feat_rgb, init_backbone_feat_d)

    # Overwrite some parameters in the classifier. (These are not generally changed)
    self._overwrite_classifier_params(feature_dim=x_rgb.shape[-3])

    # Add the dropout augmentation here, since it requires extraction of the classification features
    if 'dropout' in self.params.augmentation and self.params.get('use_augmentation', True):
        num, prob = self.params.augmentation['dropout']
        self.transforms.extend(self.transforms[:1] * num)
        x_rgb = torch.cat([x_rgb, F.dropout2d(x_rgb[0:1, ...].expand(num, -1, -1, -1), p=prob, training=True)])
        x_d = torch.cat([x_d, F.dropout2d(x_d[0:1, ...].expand(num, -1, -1, -1), p=prob, training=True)])

    # Set feature size and other related sizes
    self.feature_sz = torch.Tensor(list(x_rgb.shape[-2:]))
    ksz = self.net_rgb.classifier.filter_size
    self.kernel_size = torch.Tensor([ksz, ksz] if isinstance(ksz, (int, float)) else ksz)
    self.output_sz = self.feature_sz + (self.kernel_size + 1) % 2

    # Construct output window
    self.output_window = None
    if self.params.get('window_output', False):
        if self.params.get('use_clipped_window', False):
            self.output_window = dcf.hann2d_clipped(
                self.output_sz.long(),
                (self.output_sz * self.params.effective_search_area / self.params.search_area_scale).long(),
                centered=True).to(self.params.device)
        else:
            self.output_window = dcf.hann2d(self.output_sz.long(), centered=True).to(self.params.device)
        self.output_window = self.output_window.squeeze(0)

    # Get target boxes for the different augmentations
    target_boxes = self.init_target_boxes()

    # Set number of iterations
    plot_loss = self.params.debug > 0
    num_iter = self.params.get('net_opt_iter', None)

    # Get target filters by running the discriminative model prediction module on each modality
    with torch.no_grad():
        self.target_filter_rgb, _, losses_rgb = self.net_rgb.classifier.get_filter(
            x_rgb, target_boxes, num_iter=num_iter, compute_losses=plot_loss)
        self.target_filter_d, _, losses_d = self.net_d.classifier.get_filter(
            x_d, target_boxes, num_iter=num_iter, compute_losses=plot_loss)

    # Init memory
    if self.params.get('update_classifier', True):
        self.init_memory(TensorList([x_rgb]), TensorList([x_d]))

    if plot_loss:
        if isinstance(losses_rgb, dict):
            losses_rgb = losses_rgb['train']
            losses_d = losses_d['train']
        self.losses_rgb = torch.cat(losses_rgb)
        self.losses_d = torch.cat(losses_d)
        if self.visdom is not None:
            self.visdom.register((self.losses_rgb, torch.arange(self.losses_rgb.numel())),
                                 'lineplot', 3, 'Training Loss_RGB' + self.id_str)
            self.visdom.register((self.losses_d, torch.arange(self.losses_d.numel())),
                                 'lineplot', 3, 'Training Loss_D' + self.id_str)
        elif self.params.debug >= 3:
            plot_graph(self.losses_rgb, 10, title='Training Loss_RGB' + self.id_str)
            plot_graph(self.losses_d, 10, title='Training Loss_D' + self.id_str)
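
# Usage sketch (hypothetical wiring; `backbone_feat_rgb` / `backbone_feat_d`
# are assumed to be the first-frame backbone features of each modality):
#
#     self.init_classifier(backbone_feat_rgb, backbone_feat_d)
#     # afterwards self.target_filter_rgb / self.target_filter_d hold the
#     # initial filters, and the sample memory is populated for both modalities
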
def run(self, num_iter, dummy=None):
    """Run gradient descent with momentum for num_iter iterations."""
    if num_iter == 0:
        return

    lossvec = None
    if self.debug:
        lossvec = torch.zeros(num_iter + 1)
        grad_mags = torch.zeros(num_iter + 1)

    for i in range(num_iter):
        self.x.requires_grad_(True)

        # Evaluate function at current estimate
        self.f0 = self.problem(self.x)

        # Compute loss
        loss = self.problem.ip_output(self.f0, self.f0)

        # Compute grad
        grad = TensorList(torch.autograd.grad(loss, self.x))

        # Update direction
        if self.dir is None:
            self.dir = grad
        else:
            self.dir = grad + self.momentum * self.dir

        self.x.detach_()
        self.x -= self.step_legnth * self.dir   # note: 'step_legnth' matches the attribute name set in __init__

        if self.debug:
            lossvec[i] = loss.item()
            grad_mags[i] = sum(grad.view(-1) @ grad.view(-1)).sqrt().item()

    if self.debug:
        self.x.requires_grad_(True)
        self.f0 = self.problem(self.x)
        loss = self.problem.ip_output(self.f0, self.f0)
        grad = TensorList(torch.autograd.grad(loss, self.x))
        lossvec[-1] = loss.item()
        grad_mags[-1] = sum(grad.view(-1) @ grad.view(-1)).cpu().sqrt().item()
        self.losses = torch.cat((self.losses, lossvec))
        self.gradient_mags = torch.cat((self.gradient_mags, grad_mags))

        if self.visdom is not None:
            self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
            self.visdom.register(self.gradient_mags, 'lineplot', 4, 'Gradient magnitude')
        elif self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.gradient_mags, self.fig_num[1], title='Gradient magnitude')

    self.x.detach_()
    self.clear_temp()
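
# Standalone sketch of the heavy-ball update used above (dir = grad +
# momentum * dir, then x -= step * dir) on a toy quadratic. All names and
# hyperparameter values here are illustrative, not this class's attributes.
import torch

x = torch.tensor([3.0, -2.0], requires_grad=True)
direction = torch.zeros(2)
step_length, momentum = 0.1, 0.9

for _ in range(100):
    loss = (x ** 2).sum()                      # f(x) = ||x||^2
    grad, = torch.autograd.grad(loss, x)
    direction = grad + momentum * direction    # accumulate momentum
    with torch.no_grad():
        x -= step_length * direction           # descend along the direction
print(x)                                       # close to the minimizer at 0
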
def run(self, num_cg_iter, joint_var):
    """Run the optimizer with the provided number of iterations."""
    if num_cg_iter == 0:
        return

    self.x = joint_var

    lossvec = None
    if self.debug:
        lossvec = torch.zeros(2)

    self.x.requires_grad_(True)

    # Evaluate function at current estimate
    self.f0 = self.problem(self.x)

    # Create copy with graph detached
    self.g = self.f0.detach()

    if self.debug:
        lossvec[0] = self.problem.ip_output(self.g, self.g)

    self.g.requires_grad_(True)

    # Get df/dx^T @ f0
    self.dfdxt_g = TensorList(torch.autograd.grad(self.f0, self.x, self.g, create_graph=True))

    # Get the right hand side
    self.b = -self.dfdxt_g.detach()

    # Run CG
    delta_x, res = self.run_CG(num_cg_iter, eps=self.cg_eps)

    self.x.detach_()
    self.x += delta_x

    if self.debug:
        self.f0 = self.problem(self.x)
        lossvec[-1] = self.problem.ip_output(self.f0, self.f0)
        self.residuals = torch.cat((self.residuals, res))
        self.losses = torch.cat((self.losses, lossvec))

        if self.visdom is not None:
            self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
            self.visdom.register(self.residuals, 'lineplot', 3, 'CG residuals')
        elif self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.residuals, self.fig_num[1], title='CG residuals')

    self.x.detach_()
    self.clear_temp()
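
# Standalone sketch of the double-backward trick used above (toy residual,
# illustrative names): one grad() with create_graph=True yields J^T g, and
# differentiating that result w.r.t. g gives the J p products that CG needs.
import torch

x = torch.randn(3, requires_grad=True)
f = torch.stack([x[0] ** 2 + x[1], x[1] * x[2], x[2] - 1.0])   # toy residual f(x)

g = f.detach().requires_grad_(True)                            # detached copy of f
dfdxt_g, = torch.autograd.grad(f, x, g, create_graph=True)     # J^T g, graph kept

p = torch.randn(3)
Jp, = torch.autograd.grad(dfdxt_g, g, p, retain_graph=True)    # J p
JtJp, = torch.autograd.grad(f, x, Jp, retain_graph=True)       # J^T J p (GN product)
print(JtJp)
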
def init_classifier(self, init_backbone_feat):
    """Initialize the classifier from the first-frame backbone features."""
    # Get classification features
    x = self.get_classification_features(init_backbone_feat)

    # Add the dropout augmentation here, since it requires extraction of the classification features
    if 'dropout' in self.params.augmentation and getattr(self.params, 'use_augmentation', True):
        num, prob = self.params.augmentation['dropout']
        self.transforms.extend(self.transforms[:1] * num)
        x = torch.cat([x, F.dropout2d(x[0:1, ...].expand(num, -1, -1, -1), p=prob, training=True)])

    # Set feature size and other related sizes (e.g. 18x18)
    self.feature_sz = torch.Tensor(list(x.shape[-2:]))
    ksz = self.net.classifier.filter_size
    self.kernel_size = torch.Tensor([ksz, ksz] if isinstance(ksz, (int, float)) else ksz)
    self.output_sz = self.feature_sz + (self.kernel_size + 1) % 2

    # Construct output window
    self.output_window = None
    if getattr(self.params, 'window_output', False):
        if getattr(self.params, 'use_clipped_window', False):
            self.output_window = dcf.hann2d_clipped(
                self.output_sz.long(),
                self.output_sz.long() * self.params.effective_search_area / self.params.search_area_scale,
                centered=False).to(self.params.device)
        else:
            self.output_window = dcf.hann2d(self.output_sz.long(), centered=True).to(self.params.device)
        self.output_window = self.output_window.squeeze(0)

    # Get target boxes for the different augmentations
    target_boxes = self.init_target_boxes()

    # Set number of iterations
    plot_loss = self.params.debug > 0
    num_iter = getattr(self.params, 'net_opt_iter', None)

    # Get target filter by running the discriminative model prediction module
    with torch.no_grad():
        self.target_filter, _, losses = self.net.classifier.get_filter(
            x, target_boxes, num_iter=num_iter, compute_losses=plot_loss)

    # Init memory
    if getattr(self.params, 'update_classifier', True):
        self.init_memory(TensorList([x]))

    if plot_loss:
        if isinstance(losses, dict):
            losses = losses['train']
        self.losses = torch.stack(losses)
        if self.visdom is not None:
            self.visdom.register((self.losses, torch.arange(self.losses.numel())),
                                 'lineplot', 3, 'Training Loss')
        elif self.params.debug >= 3:
            plot_graph(self.losses, 10, title='Training loss')
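
# Worked example of the size bookkeeping above (standalone, illustrative
# values): one pixel is added for an even filter so that the score map stays
# centered, while an odd filter leaves the size unchanged.
import torch

feature_sz = torch.Tensor([18., 18.])
for ksz in (4, 5):
    kernel_size = torch.Tensor([ksz, ksz])
    output_sz = feature_sz + (kernel_size + 1) % 2
    print(ksz, output_sz)   # 4 -> tensor([19., 19.]), 5 -> tensor([18., 18.])
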
def run(self, num_cg_iter, num_newton_iter=None):
    """Run Newton iterations, each with an inner CG loop.

    args:
        num_cg_iter: Number of CG iterations per Newton iteration. If an int, the same
            count is used for num_newton_iter iterations; if a list, its length gives
            the number of Newton iterations.
        num_newton_iter: Number of Newton iterations. Only used when num_cg_iter is an int.
    """
    if isinstance(num_cg_iter, int):
        if num_cg_iter == 0:
            return
        if num_newton_iter is None:
            num_newton_iter = 1
        num_cg_iter = [num_cg_iter] * num_newton_iter

    num_newton_iter = len(num_cg_iter)
    if num_newton_iter == 0:
        return

    if self.analyze_convergence:
        self.evaluate_CG_iteration(0)

    for cg_iter in num_cg_iter:
        self.run_newton_iter(cg_iter)
        self.hessian_reg *= self.hessian_reg_factor

    if self.debug:
        if not self.analyze_convergence:
            loss = self.problem(self.x)
            self.losses = torch.cat((self.losses, loss.detach().cpu().view(-1)))

        if self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.residuals, self.fig_num[1], title='CG residuals')
            if self.analyze_convergence:
                plot_graph(self.gradient_mags, self.fig_num[2], 'Gradient magnitude')

    self.x.detach_()
    self.clear_temp()

    return self.losses, self.residuals
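
# Standalone sketch of the iteration-count normalization above: an int plus an
# optional outer count expands to a list whose length drives the Newton loop.
def normalize_iters(num_cg_iter, num_newton_iter=None):
    if isinstance(num_cg_iter, int):
        if num_newton_iter is None:
            num_newton_iter = 1
        num_cg_iter = [num_cg_iter] * num_newton_iter
    return num_cg_iter

assert normalize_iters(5, 3) == [5, 5, 5]     # constant schedule, 3 Newton iters
assert normalize_iters([5, 10]) == [5, 10]    # explicit per-iteration schedule
assert normalize_iters(7) == [7]              # defaults to a single Newton iter
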
def run(self, num_cg_iter, num_gn_iter=None):
    """Run the optimizer.

    args:
        num_cg_iter: Number of CG iterations per GN iter. If list, then each entry
            specifies the number of CG iterations, and the number of GN iterations
            is given by the length of the list.
        num_gn_iter: Number of GN iterations. Shall only be given if num_cg_iter
            is an integer.
    """
    if isinstance(num_cg_iter, int):
        if num_gn_iter is None:
            raise ValueError('Must specify number of GN iter if CG iter is constant')
        num_cg_iter = [num_cg_iter] * num_gn_iter

    num_gn_iter = len(num_cg_iter)
    if num_gn_iter == 0:
        return

    if self.analyze_convergence:
        self.evaluate_CG_iteration(0)

    # Outer loop for running the GN iterations.
    for cg_iter in num_cg_iter:
        self.run_GN_iter(cg_iter)

    if self.debug:
        if not self.analyze_convergence:
            self.f0 = self.problem(self.x)
            loss = self.problem.ip_output(self.f0, self.f0)
            self.losses = torch.cat((self.losses, loss.detach().cpu().view(-1)))

        if self.visdom is not None:
            self.visdom.register(self.losses, 'lineplot', 3, 'Loss')
            self.visdom.register(self.residuals, 'lineplot', 3, 'CG residuals')
            if self.analyze_convergence:
                self.visdom.register(self.gradient_mags, 'lineplot', 4, 'Gradient magnitude')
        elif self.plotting:
            plot_graph(self.losses, self.fig_num[0], title='Loss')
            plot_graph(self.residuals, self.fig_num[1], title='CG residuals')
            if self.analyze_convergence:
                plot_graph(self.gradient_mags, self.fig_num[2], 'Gradient magnitude')

    self.x.detach_()
    self.clear_temp()

    return self.losses, self.residuals
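
# Usage sketch (hypothetical `problem` and `variable` objects): callers pass
# either a per-iteration CG schedule or a constant count plus num_gn_iter.
#
#     optimizer = GaussNewtonCG(problem, variable)
#     losses, residuals = optimizer.run([10, 10, 5])        # three GN iterations
#     losses, residuals = optimizer.run(10, num_gn_iter=3)  # constant schedule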