def forward(self, meta_parameter: TensorList, num_iter=None, *args, **kwargs):
    # Wrap a bare tensor in a TensorList so the loop below is uniform
    input_is_list = True
    if not isinstance(meta_parameter, TensorList):
        meta_parameter = TensorList([meta_parameter])
        input_is_list = False

    # Make sure grad is enabled
    torch_grad_enabled = torch.is_grad_enabled()
    torch.set_grad_enabled(True)

    num_iter = self.num_iter if num_iter is None else num_iter

    meta_parameter_iterates = []

    def _add_iterate(meta_par):
        if input_is_list:
            meta_parameter_iterates.append(meta_par)
        else:
            meta_parameter_iterates.append(meta_par[0])

    _add_iterate(meta_parameter)

    losses = []

    for i in range(num_iter):
        # Periodically truncate the backward graph to bound memory usage
        if i > 0 and i % self.detach_length == 0:
            meta_parameter = meta_parameter.detach()

        meta_parameter.requires_grad_(True)

        # Compute residual vector
        r = self.residual_module(meta_parameter, **kwargs)

        if self.compute_losses:
            losses.append(self._compute_loss(r))

        # Compute gradient of loss: g = J^T r, where J is the Jacobian of the
        # residual. Cloning r keeps u in the graph, so g stays differentiable
        # w.r.t. u.
        u = r.clone()
        g = TensorList(torch.autograd.grad(r, meta_parameter, u, create_graph=True))

        # Multiply gradient with Jacobian: h = J g via a second backward pass
        h = TensorList(torch.autograd.grad(g, u, g, create_graph=True))

        # Compute squared norms
        ip_gg = self._sqr_norm(g, batch_dim=self._parameter_batch_dim)
        ip_hh = self._sqr_norm(h, batch_dim=self._residual_batch_dim)

        # Compute step length: alpha = ||g||^2 / (||J g||^2 + reg * ||g||^2)
        alpha = ip_gg / (ip_hh + self.steplength_reg * ip_gg).clamp(1e-8)

        # Compute optimization step, broadcasting alpha over the batch dimension
        step = g.apply(lambda e: alpha.reshape([-1 if d == self._parameter_batch_dim else 1
                                                for d in range(e.dim())]) * e)

        # Add step to parameter
        meta_parameter = meta_parameter - step

        _add_iterate(meta_parameter)

    if self.compute_losses:
        losses.append(self._compute_loss(self.residual_module(meta_parameter, **kwargs)))

    # Reset the grad enabled flag
    torch.set_grad_enabled(torch_grad_enabled)
    if not torch_grad_enabled:
        meta_parameter.detach_()
        for w in meta_parameter_iterates:
            w.detach_()
        for l in losses:
            l.detach_()

    if not input_is_list:
        meta_parameter = meta_parameter[0]

    return meta_parameter, meta_parameter_iterates, losses
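# Illustrative sketch (not part of the original module): the double-backward
# trick used in forward() above, shown on a plain tensor problem assuming only
# PyTorch. For a linear residual r = A x the Jacobian is J = A; the first grad
# call yields g = J^T u, and because u was passed as grad_outputs with
# create_graph=True, a second grad w.r.t. u yields h = J g without ever
# forming J. All names here (A, x, r, u) are local to this example.
def _example_gauss_newton_step_length():
    import torch

    A = torch.randn(4, 3)
    x = torch.randn(3, requires_grad=True)
    r = A @ x                         # residual with Jacobian J = A

    u = r.clone()                     # kept in the graph, as in forward() above
    g = torch.autograd.grad(r, x, u, create_graph=True)[0]   # g = J^T r
    h = torch.autograd.grad(g, u, g, create_graph=True)[0]   # h = J g

    assert torch.allclose(g, A.t() @ r, atol=1e-5)
    assert torch.allclose(h, A @ g, atol=1e-5)

    # Steepest-descent step length in the Gauss-Newton quadratic model
    alpha = (g * g).sum() / (h * h).sum().clamp(1e-8)
    return alpha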
def forward(self, meta_parameter: TensorList, num_iter=None, **kwargs):
    if not isinstance(meta_parameter, TensorList):
        meta_parameter = TensorList([meta_parameter])

    _residual_batch_dim = 1

    # Make sure grad is enabled
    torch_grad_enabled = torch.is_grad_enabled()
    torch.set_grad_enabled(True)

    num_iter = self.num_iter if num_iter is None else num_iter
    step_length_factor = torch.exp(self.log_step_length)

    label_density, sample_weight, reg_weight = self.score_predictor.init_data(
        meta_parameter, **kwargs)

    exp_reg = 0 if self.softmax_reg is None else math.exp(self.softmax_reg)

    def _compute_loss(scores, weights):
        # Negative log-likelihood of the label density under the (regularized)
        # softmax of the scores, plus a squared-norm weight regularizer
        num_sequences = scores.shape[_residual_batch_dim]
        return torch.sum(sample_weight.reshape(sample_weight.shape[0], -1) *
                         (torch.log(scores.exp().sum(dim=(-2, -1)) + exp_reg) -
                          (label_density * scores).sum(dim=(-2, -1)))) / num_sequences + \
               reg_weight * sum((weights * weights).sum()) / num_sequences

    meta_parameter_iterates = [meta_parameter]
    losses = []

    for i in range(num_iter):
        # Periodically truncate the backward graph to bound memory usage
        if i > 0 and i % self.detach_length == 0:
            meta_parameter = meta_parameter.detach()

        meta_parameter.requires_grad_(True)

        # Compute residual vector
        scores = self.score_predictor(meta_parameter, **kwargs)

        if self.compute_losses:
            losses.append(_compute_loss(scores, meta_parameter))

        scores_softmax = activation.softmax_reg(scores.reshape(*scores.shape[:2], -1),
                                                dim=2, reg=self.softmax_reg).reshape(scores.shape)
        dLds = sample_weight * (scores_softmax - label_density)

        # Compute gradient of loss
        weights_grad = TensorList(torch.autograd.grad(scores, meta_parameter, dLds, create_graph=True)) + \
                       meta_parameter * reg_weight

        # Multiply gradient with Jacobian
        scores_grad = torch.autograd.grad(weights_grad, dLds, weights_grad, create_graph=True)[0]

        # Hessian-vector product of the softmax loss in closed form:
        # H v = p*v - p*(p.v) with p = softmax(s), plus Hessian regularization
        sm_scores_grad = scores_softmax * scores_grad
        hes_scores_grad = sm_scores_grad - scores_softmax * torch.sum(sm_scores_grad, dim=(-2, -1), keepdim=True) + \
                          self.hessian_reg * scores_grad
        grad_hes_grad = (scores_grad * hes_scores_grad).reshape(*scores.shape[:2], -1).sum(dim=2).clamp(min=0)
        grad_hes_grad = (sample_weight.reshape(sample_weight.shape[0], -1) * grad_hes_grad).sum(dim=0)

        # Compute optimal step length
        gg = (weights_grad * weights_grad).reshape(scores.shape[1], -1).sum(dim=1)
        alpha_num = sum(gg)
        alpha_den = (grad_hes_grad + sum(gg * reg_weight) + self.steplength_reg * alpha_num).clamp(1e-8)
        alpha = step_length_factor * (alpha_num / alpha_den)

        # Compute optimization step, broadcasting alpha over the batch dimension
        step = weights_grad.apply(lambda e: alpha.reshape([-1 if d == self._parameter_batch_dim else 1
                                                           for d in range(e.dim())]) * e)

        # Add step to parameter
        meta_parameter = meta_parameter - step

        meta_parameter_iterates.append(meta_parameter)

    if self.compute_losses:
        losses.append(_compute_loss(self.score_predictor(meta_parameter, **kwargs), meta_parameter))

    # Reset the grad enabled flag
    torch.set_grad_enabled(torch_grad_enabled)
    if not torch_grad_enabled:
        meta_parameter.detach_()
        for w in meta_parameter_iterates:
            w.detach_()
        for l in losses:
            l.detach_()

    return meta_parameter, meta_parameter_iterates, losses
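# Illustrative sketch (not part of the original module): the closed-form
# softmax Hessian-vector product applied in forward() above. For
# L(s) = log(sum_j exp(s_j)) the Hessian is diag(p) - p p^T with
# p = softmax(s), so H v = p*v - p*(p.v); this is what the sm_scores_grad /
# hes_scores_grad lines compute, before the Hessian regularization term and
# ignoring the softmax regularizer. All names below are local to this example;
# it checks the identity against an autograd reference.
def _example_softmax_hessian_vector_product():
    import torch

    s = torch.randn(5, requires_grad=True)
    v = torch.randn(5)

    # Closed form: H v = p*v - p*(p.v)
    p = torch.softmax(s, dim=0)
    hv_closed = p * v - p * (p * v).sum()

    # Autograd reference: differentiate (grad L . v) w.r.t. s
    L = torch.logsumexp(s, dim=0)
    g = torch.autograd.grad(L, s, create_graph=True)[0]
    hv_auto = torch.autograd.grad((g * v).sum(), s)[0]

    assert torch.allclose(hv_closed, hv_auto, atol=1e-5)
    return hv_closed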