def meta_optimizer(self, xs: torch.Tensor, vxs: torch.Tensor,
                   vys: torch.Tensor, meter: Meter):
        """
        One meta-optimization step that learns per-sample weights.

        A fresh meta-network is updated once on the training batch, using its
        own argmax predictions as pseudo-labels weighted by the initial
        per-sample weights ``weight_0``. The updated network is evaluated on
        the validation batch and the validation loss is differentiated w.r.t.
        ``weight_0``; a gradient-descent step on the weights (clamped to
        non-negative and normalized) yields the returned sample weights.

        NOTE(review): this file defines ``meta_optimizer`` several times; if
        these definitions live in the same class, later ones shadow this one —
        confirm which variant is actually in use.

        :param xs: training inputs
        :param vxs: validation (meta) inputs
        :param vys: validation (meta) labels
        :param meter: metric recorder; receives ``Macc`` and ``LMce``
        :return: detached, normalized per-sample weights of shape ``(len(xs),)``
        """

        metanet, metasgd = self.create_metanet()
        metanet.zero_grad()

        logits = metanet(xs)

        # Initial per-sample weight — the hyperparameter we differentiate
        # through — all entries equal to params.init_eps_val.
        weight_0 = torch.ones(xs.shape[0],
                              dtype=torch.float,
                              device=self.device) * params.init_eps_val
        # torch.autograd.Variable is deprecated since PyTorch 0.4; marking the
        # freshly created tensor as requiring grad in place is equivalent.
        weight_0.requires_grad_(True)

        # Self-training loss: the network's own argmax predictions act as
        # pseudo-labels, each sample weighted by weight_0.
        dist_loss = self.loss_ce_with_masked_(logits, logits.argmax(dim=1),
                                              weight_0)

        # create_graph=True keeps the inner-step gradients differentiable so
        # the later grad w.r.t. weight_0 can flow through the update.
        var_grads = autograd.grad(dist_loss,
                                  metanet.params(),
                                  create_graph=True)

        # metanet.update_params(0.1, var_grads)
        metasgd.meta_step(var_grads)

        # Validation loss of the one-step-updated network.
        m_v_logits = metanet(vxs)  # type:torch.Tensor
        meta_loss = self.loss_ce_(m_v_logits, vys)

        # method A (kept for reference): chain rule through metanet.params()
        # grad_meta_vars = autograd.grad(meta_loss, metanet.params(), create_graph=True)
        # grad_target, grad_eps = autograd.grad(
        #     metanet.params(), [cls_center, eps_k], grad_outputs=grad_meta_vars)

        # method B: direct gradient of the meta loss w.r.t. weight_0.
        grad_target, = autograd.grad(meta_loss, [weight_0])

        # Descent step on the weights, shifted so the init value is the zero
        # baseline, clamped to non-negative, then normalized to ~sum to 1
        # (epsilon added to guard against an all-zero weight vector).
        raw_weight = weight_0 - grad_target
        raw_weight = raw_weight - params.init_eps_val
        unorm_weight = raw_weight.clamp_min(0)
        norm_c = unorm_weight.sum()
        weight = torch.div(unorm_weight, norm_c + 0.00001).detach()

        self.acc_precise_(m_v_logits.argmax(dim=-1),
                          vys,
                          meter=meter,
                          name='Macc')

        meter.LMce = meta_loss.detach()

        return weight
    def meta_optimizer(self, xs: torch.Tensor, vxs: torch.Tensor,
                       vys: torch.Tensor, meter: Meter):
        """
        One meta-optimization step that refines the class centers.

        A fresh meta-network is updated once on the training batch, using
        soft targets derived from the distances between its features and
        ``self.cls_center``. The updated network is evaluated on the
        validation batch, the validation loss is differentiated w.r.t. the
        class centers, and ``self.cls_center`` takes one SGD step.

        NOTE(review): this file defines ``meta_optimizer`` several times; if
        these definitions live in the same class, later ones shadow this one —
        confirm which variant is actually in use.

        :param xs: training inputs
        :param vxs: validation (meta) inputs
        :param vys: validation (meta) labels
        :param meter: metric recorder; receives ``Macc`` and ``LMce``
        :return: None — ``self.cls_center`` is updated in place
        """

        metanet, metasgd = self.create_metanet()
        metanet.zero_grad()

        # Features (pre-fc) of the training batch.
        mid_logits = metanet(xs)

        # Wrap the centers so the meta loss can be differentiated w.r.t. them.
        # (autograd.Variable is deprecated; kept as-is to preserve behavior.)
        cls_center = autograd.Variable(self.cls_center, requires_grad=True)

        # Pairwise distance of every sample feature to every class center,
        # reshaped to (batch, n_centers).
        left, right = tricks.cartesian_product(mid_logits, cls_center)
        dist_ = F.pairwise_distance(left,
                                    right).reshape(mid_logits.shape[0], -1)
        # NOTE(review): softmax over raw distances puts the highest probability
        # on the FARTHEST center; usually softmax(-dist_) is used — confirm
        # this is intended.
        dist_targets = torch.softmax(dist_, dim=-1)

        # Cross-entropy of classifier logits against the distance-derived
        # soft targets.
        dist_loss = self.loss_ce_with_targets_(metanet.fc(mid_logits),
                                               dist_targets)

        # create_graph=True keeps the inner-step gradients differentiable so
        # the later grad w.r.t. cls_center can flow through the update.
        var_grads = autograd.grad(dist_loss,
                                  metanet.params(),
                                  create_graph=True)

        # metanet.update_params(0.1, var_grads)
        metasgd.meta_step(var_grads)

        # Validation loss of the one-step-updated network.
        m_v_logits = metanet.fc(metanet(vxs))  # type:torch.Tensor
        meta_loss = self.loss_ce_(m_v_logits, vys)

        # method A (kept for reference): chain rule through metanet.params()
        # grad_meta_vars = autograd.grad(meta_loss, metanet.params(), create_graph=True)
        # grad_target, grad_eps = autograd.grad(
        #     metanet.params(), [cls_center, eps_k], grad_outputs=grad_meta_vars)

        # method B: direct gradient of the meta loss w.r.t. the class centers.
        grad_target, = autograd.grad(meta_loss, [cls_center])

        # One SGD step on the centers with a fixed step size of 0.4.
        with torch.no_grad():
            self.cls_center = self.cls_center - grad_target * 0.4

        self.acc_precise_(m_v_logits.argmax(dim=-1),
                          vys,
                          meter=meter,
                          name='Macc')

        meter.LMce = meta_loss.detach()
    def meta_optimizer(self, xs: torch.Tensor, guess_targets: torch.Tensor,
                       n_targets: torch.Tensor, vxs: torch.Tensor,
                       vys: torch.Tensor, meter: Meter):
        """
        One meta-optimization step learning per-sample weights and per-sample
        label-mixing coefficients.

        A fresh meta-network is updated once on the training batch, whose
        targets are a per-sample convex mix of the noisy labels and the
        guessed labels (mixing coefficient ``eps``), with each sample's loss
        scaled by ``weight_0``. The updated network is evaluated on the
        validation batch, and the validation loss is differentiated w.r.t.
        both hyperparameters to produce new weights and mixing choices.

        :param xs: training inputs
        :param guess_targets: guessed (soft) targets for ``xs``
        :param n_targets: noisy given targets for ``xs``
        :param vxs: validation (meta) inputs
        :param vys: validation (meta) labels
        :param meter: metric recorder; receives ``Macc`` and ``LMce``
        :return: tuple ``(weight, new_eps)`` — detached normalized per-sample
            weights of shape ``(len(xs),)``, and a detached 0/1 indicator of
            shape ``(len(xs), 1)`` selecting samples whose eps-gradient is
            negative (i.e. trusting the noisy label helps validation)
        """
        ## 0. create a metanet which can hold hyperparameter gradients.
        metanet, metasgd = self.create_metanet()

        ## 1. calculate loss of train sample with hyperparameters
        # the hyperparameter used to reweight; requires_grad must be True.
        # (torch.autograd.Variable is deprecated since PyTorch 0.4; marking
        # the fresh tensors as requiring grad in place is equivalent.)
        weight_0 = torch.ones(guess_targets.shape[0],
                              dtype=torch.float,
                              device=self.device) * params.init_eps_val
        weight_0.requires_grad_(True)

        # the hyperparameter mixing noisy and guessed targets per sample.
        eps = torch.ones([guess_targets.shape[0]],
                         dtype=torch.float,
                         device=self.device) * params.grad_eps_init
        eps.requires_grad_(True)
        eps_k = eps.unsqueeze(dim=-1)

        logits = metanet(xs)
        mixed_labels = eps_k * n_targets + (1 - eps_k) * guess_targets
        # ce loss with soft targets and no reduction over the batch.
        _net_cost = -torch.mean(
            mixed_labels * torch.log_softmax(logits, dim=1), dim=1)
        lookahead_loss = torch.mul(weight_0, _net_cost).mean()

        ## 2. update gradient of train samples
        # create_graph=True keeps the inner-step gradients differentiable so
        # the later grad w.r.t. the hyperparameters can flow through it.
        var_grads = autograd.grad(lookahead_loss,
                                  metanet.params(),
                                  create_graph=True)
        metanet.update_params(0.1, var_grads)
        # or metasgd.meta_step(var_grads)

        ## 3. calculate gradient of meta validate sample
        m_v_logits = metanet(vxs)  # type:torch.Tensor
        v_targets = tricks.onehot(vys, params.n_classes)
        meta_loss = self.loss_ce_with_targets_(m_v_logits, v_targets)

        # method A (kept for reference): chain rule through metanet.params()
        # grad_meta_vars = autograd.grad(meta_loss, metanet.params(), create_graph=True)
        # grad_target, grad_eps = autograd.grad(metanet.params(), [weight_0, eps_k],
        #                                       grad_outputs=grad_meta_vars)

        # equal method B: direct gradient w.r.t. both hyperparameters.
        grad_target, grad_eps = autograd.grad(meta_loss, [weight_0, eps_k])

        ## 4. build weight by meta validate gradient
        # Descent step on the weights, shifted so the init value is the zero
        # baseline, clamped to non-negative, then normalized to ~sum to 1
        # (epsilon added to guard against an all-zero weight vector).
        raw_weight = weight_0 - grad_target
        raw_weight = raw_weight - params.init_eps_val
        unorm_weight = raw_weight.clamp_min(0)
        norm_c = unorm_weight.sum()
        weight = torch.div(unorm_weight, norm_c + 0.00001).detach()
        # Hard 0/1 choice per sample: 1 where increasing eps (trusting the
        # noisy label) would decrease the validation loss.
        new_eps = (grad_eps < 0).float().unsqueeze(dim=-1).detach()

        self.acc_precise_(m_v_logits.argmax(dim=-1),
                          vys,
                          meter=meter,
                          name='Macc')

        meter.LMce = meta_loss.detach()
        return weight, new_eps