def meta_optimizer(self, xs: torch.Tensor, vxs: torch.Tensor, vys: torch.Tensor, meter: Meter):
    """
    First take one update step on the training batch with the initial per-sample
    weights wrapped inside the MetaNet, then compute the loss on the meta-validation
    batch and differentiate it back through the updated parameters to obtain the
    gradient w.r.t. the initial weights.

    :param xs: training batch inputs
    :param vxs: meta-validation batch inputs
    :param vys: meta-validation batch labels
    :param meter: metric recorder
    :return: normalized per-sample weights, detached
    """
    metanet, metasgd = self.create_metanet()
    metanet.zero_grad()
    logits = metanet(xs)

    # per-sample weight hyperparameter; requires_grad must be True so the
    # meta gradient can flow back to it
    weight_0 = torch.ones(xs.shape[0], dtype=torch.float, device=self.device) * params.init_eps_val
    weight_0.requires_grad_(True)

    # self-training loss: cross entropy against the current predictions,
    # reweighted per sample by weight_0
    dist_loss = self.loss_ce_with_masked_(logits, logits.argmax(dim=1), weight_0)

    # one differentiable inner SGD step; create_graph keeps the graph so the
    # updated parameters still depend on weight_0
    var_grads = autograd.grad(dist_loss, metanet.params(), create_graph=True)
    # metanet.update_params(0.1, var_grads)
    metasgd.meta_step(var_grads)

    m_v_logits = metanet(vxs)  # type: torch.Tensor
    meta_loss = self.loss_ce_(m_v_logits, vys)

    # method A
    # grad_meta_vars = autograd.grad(meta_loss, metanet.params(), create_graph=True)
    # grad_target, = autograd.grad(metanet.params(), [weight_0], grad_outputs=grad_meta_vars)
    # method B, equivalent
    grad_target, = autograd.grad(meta_loss, [weight_0])

    # one gradient step on the weights, then shift, clip, and normalize
    raw_weight = weight_0 - grad_target
    raw_weight = raw_weight - params.init_eps_val
    unorm_weight = raw_weight.clamp_min(0)
    norm_c = unorm_weight.sum()
    weight = torch.div(unorm_weight, norm_c + 1e-5).detach()

    self.acc_precise_(m_v_logits.argmax(dim=-1), vys, meter=meter, name='Macc')
    meter.LMce = meta_loss.detach()
    return weight
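# A minimal sketch (assumptions, not the original implementation) of what
# `create_metanet` is expected to return: a network whose parameters are plain
# tensors, so that `meta_step` can replace them with non-leaf, graph-connected
# tensors. That keeps the inner SGD step differentiable, which is what lets
# `autograd.grad(meta_loss, [weight_0])` above flow back through the update.
# `MetaLinearNet`, `MetaSGD`, and the 0.1 learning rate are illustrative.
import torch
import torch.nn.functional as F
from torch import autograd, nn

class MetaLinearNet(nn.Module):
    def __init__(self, n_in: int, n_out: int):
        super().__init__()
        # plain tensors instead of nn.Parameter, so they can later be swapped
        # for non-leaf tensors produced by the differentiable update
        self.w = (torch.randn(n_out, n_in) * 0.01).requires_grad_()
        self.b = torch.zeros(n_out, requires_grad=True)

    def params(self):
        return [self.w, self.b]

    def forward(self, xs):
        return F.linear(xs, self.w, self.b)

class MetaSGD:
    def __init__(self, net: MetaLinearNet, lr: float = 0.1):
        self.net, self.lr = net, lr

    def meta_step(self, grads):
        # w' = w - lr * g *without* detaching: w' keeps the graph linking it
        # to everything the inner loss depended on, including weight_0
        self.net.w = self.net.w - self.lr * grads[0]
        self.net.b = self.net.b - self.lr * grads[1]

# quick check that the meta gradient flows through the inner step:
net = MetaLinearNet(8, 3)
sgd = MetaSGD(net, lr=0.1)
w0 = torch.full((4,), 0.1, requires_grad=True)           # per-sample weights
xs, ys = torch.randn(4, 8), torch.tensor([0, 1, 2, 0])
inner = (w0 * F.cross_entropy(net(xs), ys, reduction='none')).mean()
grads = autograd.grad(inner, net.params(), create_graph=True)
sgd.meta_step(grads)                                      # differentiable update
meta = F.cross_entropy(net(torch.randn(4, 8)), torch.tensor([1, 2, 0, 1]))
g_w0, = autograd.grad(meta, [w0])                         # gradient reaches w0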
def meta_optimizer(self, xs: torch.Tensor, vxs: torch.Tensor, vys: torch.Tensor, meter: Meter):
    """
    Same lookahead scheme as above, but here the hyperparameter is the set of
    class centers: update once on the training batch, compute the meta-validation
    loss, then differentiate it back to the class centers and update them in place.

    :param xs: training batch inputs
    :param vxs: meta-validation batch inputs
    :param vys: meta-validation batch labels
    :param meter: metric recorder
    :return: (updates ``self.cls_center`` in place)
    """
    metanet, metasgd = self.create_metanet()
    metanet.zero_grad()
    mid_logits = metanet(xs)

    # class centers as the hyperparameter; requires_grad must be True
    cls_center = self.cls_center.detach().requires_grad_(True)

    # pair every feature with every class center and compute pairwise distances
    left, right = tricks.cartesian_product(mid_logits, cls_center)
    dist_ = F.pairwise_distance(left, right).reshape(mid_logits.shape[0], -1)
    # note: softmax over raw distances puts the most mass on the *farthest*
    # center; `torch.softmax(-dist_, dim=-1)` may be what is intended here
    dist_targets = torch.softmax(dist_, dim=-1)
    dist_loss = self.loss_ce_with_targets_(metanet.fc(mid_logits), dist_targets)

    var_grads = autograd.grad(dist_loss, metanet.params(), create_graph=True)
    # metanet.update_params(0.1, var_grads)
    metasgd.meta_step(var_grads)

    m_v_logits = metanet.fc(metanet(vxs))  # type: torch.Tensor
    meta_loss = self.loss_ce_(m_v_logits, vys)

    # method A
    # grad_meta_vars = autograd.grad(meta_loss, metanet.params(), create_graph=True)
    # grad_target, = autograd.grad(metanet.params(), [cls_center], grad_outputs=grad_meta_vars)
    # method B, equivalent
    grad_target, = autograd.grad(meta_loss, [cls_center])

    # plain gradient step on the class centers (step size 0.4)
    with torch.no_grad():
        self.cls_center = self.cls_center - grad_target * 0.4

    self.acc_precise_(m_v_logits.argmax(dim=-1), vys, meter=meter, name='Macc')
    meter.LMce = meta_loss.detach()
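# A standalone sketch of the distance-to-center soft targets used above.
# `cartesian_product` is assumed to pair every feature row with every class
# center as below; the name mirrors the `tricks` helper, but this is an
# illustrative reimplementation, not the original module.
import torch
import torch.nn.functional as F

def cartesian_product(feats: torch.Tensor, centers: torch.Tensor):
    # feats: [N, D], centers: [C, D] -> two [N*C, D] tensors pairing every
    # feature with every center
    n, c = feats.shape[0], centers.shape[0]
    left = feats.repeat_interleave(c, dim=0)
    right = centers.repeat(n, 1)
    return left, right

feats = torch.randn(4, 8)     # N=4 samples, D=8 features
centers = torch.randn(3, 8)   # C=3 class centers
left, right = cartesian_product(feats, centers)
dist_ = F.pairwise_distance(left, right).reshape(feats.shape[0], -1)  # [N, C]
# softmax over *negative* distances makes the nearest center most likely
dist_targets = torch.softmax(-dist_, dim=-1)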
def meta_optimizer(self, xs: torch.Tensor, guess_targets: torch.Tensor, n_targets: torch.Tensor,
                   vxs: torch.Tensor, vys: torch.Tensor, meter: Meter):
    """
    First take one update step on the training batch with the initial hyperparameters
    (per-sample weights and label-mixing coefficients) wrapped inside the MetaNet,
    then compute the loss on the meta-validation batch and differentiate it back
    through the updated parameters to obtain the gradients w.r.t. those hyperparameters.

    :param xs: training batch inputs
    :param guess_targets: guessed (pseudo) soft targets of the training batch
    :param n_targets: original, possibly noisy, targets of the training batch
    :param vxs: meta-validation batch inputs
    :param vys: meta-validation batch labels
    :param meter: metric recorder
    :return: normalized per-sample weights and new mixing coefficients, both detached
    """
    ## 0. create a metanet which can carry hyperparameter gradients
    metanet, metasgd = self.create_metanet()

    ## 1. compute the training loss under the current hyperparameters
    # per-sample reweighting hyperparameter; requires_grad must be True
    weight_0 = torch.ones(guess_targets.shape[0], dtype=torch.float,
                          device=self.device) * params.init_eps_val
    weight_0.requires_grad_(True)
    # per-sample label-mixing hyperparameter
    eps = torch.ones([guess_targets.shape[0]], dtype=torch.float,
                     device=self.device) * params.grad_eps_init
    eps.requires_grad_(True)
    eps_k = eps.unsqueeze(dim=-1)

    logits = metanet(xs)
    mixed_labels = eps_k * n_targets + (1 - eps_k) * guess_targets
    # cross entropy with soft targets, no reduction over the batch
    _net_cost = -torch.mean(mixed_labels * torch.log_softmax(logits, dim=1), dim=1)
    lookahead_loss = torch.mul(weight_0, _net_cost).mean()

    ## 2. take one differentiable gradient step on the training loss
    var_grads = autograd.grad(lookahead_loss, metanet.params(), create_graph=True)
    metanet.update_params(0.1, var_grads)  # or: metasgd.meta_step(var_grads)

    ## 3. compute the gradient of the meta-validation loss
    m_v_logits = metanet(vxs)  # type: torch.Tensor
    v_targets = tricks.onehot(vys, params.n_classes)
    meta_loss = self.loss_ce_with_targets_(m_v_logits, v_targets)
    # method A
    # grad_meta_vars = autograd.grad(meta_loss, metanet.params(), create_graph=True)
    # grad_target, grad_eps = autograd.grad(metanet.params(), [weight_0, eps],
    #                                       grad_outputs=grad_meta_vars)
    # method B, equivalent; differentiate w.r.t. the leaf `eps` ([N]) rather than
    # the view `eps_k` ([N, 1]) so the unsqueeze below yields [N, 1], not [N, 1, 1]
    grad_target, grad_eps = autograd.grad(meta_loss, [weight_0, eps])

    ## 4. build the sample weights from the meta-validation gradient
    raw_weight = weight_0 - grad_target
    raw_weight = raw_weight - params.init_eps_val
    unorm_weight = raw_weight.clamp_min(0)
    norm_c = unorm_weight.sum()
    weight = torch.div(unorm_weight, norm_c + 1e-5).detach()
    # keep the noisy target only where increasing eps would decrease the meta loss
    new_eps = (grad_eps < 0).float().unsqueeze(dim=-1).detach()

    self.acc_precise_(m_v_logits.argmax(dim=-1), vys, meter=meter, name='Macc')
    meter.LMce = meta_loss.detach()
    return weight, new_eps
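# A hedged sketch of how the returned `weight` and `new_eps` would plug into the
# outer training step: `new_eps` re-mixes the noisy and guessed targets, and
# `weight` reweights the per-sample soft-target cross entropy. `outer_step`,
# `trainer`, `model`, and `optimizer` are assumed names; the original outer loop
# is not shown in this code.
import torch

def outer_step(trainer, model, optimizer, xs, guess_targets, n_targets, vxs, vys, meter):
    weight, new_eps = trainer.meta_optimizer(xs, guess_targets, n_targets, vxs, vys, meter)
    logits = model(xs)
    # new_eps is [N, 1] and broadcasts over the class dimension
    targets = new_eps * n_targets + (1 - new_eps) * guess_targets
    per_sample = -torch.mean(targets * torch.log_softmax(logits, dim=1), dim=1)
    loss = torch.sum(weight * per_sample)  # weights already sum to ~1, so sum not mean
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()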