Example #1
    def train_batch(self, eidx, idx, global_step, batch_data, params: _trainer_name_Param, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)

        meter = Meter()
        model = self.model
        optim = self.optim

        xs, ys = batch_data
        xs, ys = xs.to(device), ys.to(device)

        logits = model(xs)

        meter.loss = self.lossf(logits, ys)

        optim.zero_grad()
        meter.loss.backward()
        optim.step()

        return meter
    def meta_optimizer(self, xs: torch.Tensor, vxs: torch.Tensor,
                       vys: torch.Tensor, meter: Meter):
        """
        先使用训练集数据和初始化权重更新一次,将权重封在 MetaNet 中,随后计算验证集梯度,然后求参数对初始化权重的梯度
        :param xs:
        :param guess_targets:
        :param n_targets:
        :param vxs:
        :param vys:
        :param meter:
        :return:
        """

        metanet, metasgd = self.create_metanet()
        metanet.zero_grad()

        mid_logits = metanet(xs)

        cls_center = autograd.Variable(self.cls_center, requires_grad=True)

        left, right = tricks.cartesian_product(mid_logits, cls_center)
        dist_ = F.pairwise_distance(left,
                                    right).reshape(mid_logits.shape[0], -1)
        dist_targets = torch.softmax(dist_, dim=-1)

        dist_loss = self.loss_ce_with_targets_(metanet.fc(mid_logits),
                                               dist_targets)

        var_grads = autograd.grad(dist_loss,
                                  metanet.params(),
                                  create_graph=True)

        # metanet.update_params(0.1, var_grads)
        metasgd.meta_step(var_grads)

        m_v_logits = metanet.fc(metanet(vxs))  # type:torch.Tensor
        meta_loss = self.loss_ce_(m_v_logits, vys)

        # method A
        # grad_meta_vars = autograd.grad(meta_loss, metanet.params(), create_graph=True)
        # grad_target, grad_eps = autograd.grad(
        #     metanet.params(), [cls_center, eps_k], grad_outputs=grad_meta_vars)

        # method B
        grad_target, = autograd.grad(meta_loss, [cls_center])

        with torch.no_grad():
            self.cls_center = self.cls_center - grad_target * 0.4

        self.acc_precise_(m_v_logits.argmax(dim=-1),
                          vys,
                          meter=meter,
                          name='Macc')

        meter.LMce = meta_loss.detach()
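The routine above follows the standard one-step meta-gradient pattern: take one differentiable SGD step on the meta-network, evaluate a validation loss with the updated weights, and back-propagate that loss to an outer tensor (here the class centers). Below is a minimal, self-contained sketch of the same pattern with a single linear layer standing in for the MetaNet; `one_step_meta_grad`, `inner_lr`, and all shapes are illustrative, not the library's API.

import torch
import torch.nn.functional as F

def one_step_meta_grad(weight, bias, cls_center, xs, vxs, vys, inner_lr=0.1):
    # inner objective: soft targets derived from distances to the class centers
    logits = xs @ weight.t() + bias                                   # [B, C]
    dist_targets = torch.softmax(torch.cdist(logits, cls_center), dim=-1)
    inner_loss = -(dist_targets * F.log_softmax(logits, dim=-1)).sum(dim=1).mean()

    # one differentiable SGD step; create_graph keeps the dependence on cls_center
    g_w, g_b = torch.autograd.grad(inner_loss, [weight, bias], create_graph=True)
    new_w, new_b = weight - inner_lr * g_w, bias - inner_lr * g_b

    # validation loss with the updated weights, then its gradient w.r.t. cls_center
    meta_loss = F.cross_entropy(vxs @ new_w.t() + new_b, vys)
    grad_center, = torch.autograd.grad(meta_loss, [cls_center])
    return grad_center, meta_loss.detach()

# illustrative shapes: 8 input features, 3 classes, centers living in logit space
w = torch.randn(3, 8, requires_grad=True)
b = torch.zeros(3, requires_grad=True)
center = torch.randn(3, 3, requires_grad=True)
g, _ = one_step_meta_grad(w, b, center, torch.randn(16, 8),
                          torch.randn(4, 8), torch.randint(0, 3, (4,)))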
    def on_train_epoch_end(self, trainer: 'NoisyTrainer', func,
                           params: GmaParams, meter: Meter, *args, **kwargs):
        filter_mem = os.path.join(self.experiment.test_dir,
                                  'filter_mem_{}.pth'.format(params.eidx))
        local_noisy_cls = os.path.join(
            self.experiment.test_dir,
            'local_noisy_cls_{}.pth'.format(params.eidx))
        noisy_cls = os.path.join(self.experiment.test_dir,
                                 'noisy_cls_{}.pth'.format(params.eidx))
        false_pred_mem = os.path.join(self.experiment.test_dir,
                                      'false_pred_mem.pth')

        torch.save(self.filter_mem, filter_mem)
        torch.save(self.local_noisy_cls, local_noisy_cls)
        torch.save(self.noisy_cls, noisy_cls)
        torch.save(self.false_pred_mem, false_pred_mem)

        with torch.no_grad():
            id_ = 0  # max(0, params.eidx - 5)
            self.logger.info('f_mean left id', id_)
            f_mean = self.false_pred_mem[:, id_:params.eidx].mean(
                dim=1).cpu().numpy()
            f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
            feature = self.create_feature(f_mean, f_cur)

            noisy_cls = self.bmm_predict(feature,
                                         mean=params.feature_mean,
                                         offset=params.offset_sche(
                                             params.eidx))

            # noisy_cls = self.gmm_predict(feature)
            self.noisy_cls_mem = torch.tensor(
                noisy_cls, device=self.device) * 0.2 + self.noisy_cls_mem * 0.8

            if params.eidx <= 40 and False:
                self.noisy_cls = torch.tensor(noisy_cls, device=self.device)
                self.noisy_cls_mem = torch.tensor(noisy_cls,
                                                  device=self.device)
            else:
                self.noisy_cls = self.noisy_cls_mem.clone()

            meter.mm = (self.noisy_cls > 0.5).float().mean()
            self.logger.info('mm raw ratio', (noisy_cls > 0.5).mean(),
                             'mm ratio', meter.mm)

            if params.eidx > params.burnin:
                # self.filter_mem = self.local_noisy_cls - self.noisy_cls
                self.filter_mem = torch.ones(
                    self.train_size, dtype=torch.float,
                    device=self.device) - self.noisy_cls
                self.logger.info('filter',
                                 (self.filter_mem > 0.5).float().mean())
                values, _ = self.plabel_mem.max(dim=-1)
                self.logger.info('plabel filter',
                                 (values > params.pred_thresh).float().mean())
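The epoch-end hook above keeps `noisy_cls_mem` as an exponential moving average of the per-sample noise predictions (new estimate * 0.2 + old memory * 0.8). A minimal sketch of that smoothing step, with an illustrative `momentum` argument:

import torch

def ema_update(memory: torch.Tensor, new_value: torch.Tensor, momentum: float = 0.8) -> torch.Tensor:
    """Blend a new per-sample estimate into a running memory; momentum keeps the old value."""
    return new_value * (1.0 - momentum) + memory * momentum

# usage: smooth the per-sample "noisy" probabilities across epochs
noisy_mem = torch.zeros(10)
for _ in range(3):
    epoch_pred = torch.rand(10)          # stand-in for this epoch's bmm/gmm output
    noisy_mem = ema_update(noisy_mem, epoch_pred, momentum=0.8)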
Example #4
    def train_batch(self, eidx, idx, global_step, batch_data, params: SupConParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()
        _, xs, axs, ys = batch_data  # type:torch.Tensor

        b, c, h, w = xs.size()
        input_ = torch.cat([xs.unsqueeze(1), axs.unsqueeze(1)], dim=1)
        input_ = input_.view(-1, c, h, w)

        features = self.to_logits(input_).view(b, 2, -1)

        meter.Lall = meter.Lall + self.loss_supcon_(features, ys,
                                                    temperature=params.temperature,
                                                    meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        return meter
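`loss_supcon_` itself is not part of the snippet; the sketch below shows a supervised contrastive loss in the spirit of Khosla et al. (2020) for features shaped [batch, n_views, dim], which matches the shape built above. It is an illustrative stand-in under those assumptions, not the library's implementation.

import torch
import torch.nn.functional as F

def supcon_loss(features: torch.Tensor, labels: torch.Tensor, temperature: float = 0.07):
    """features: [B, V, D] with V views per sample; labels: [B]. Returns a scalar loss."""
    b, v, d = features.shape
    feats = F.normalize(features, dim=-1).reshape(b * v, d)          # [BV, D]
    labels = labels.repeat_interleave(v)                             # [BV]

    sim = feats @ feats.t() / temperature                            # [BV, BV] similarities
    self_mask = torch.eye(b * v, dtype=torch.bool, device=feats.device)
    # normalise each row over every other sample (exclude self-similarity)
    log_prob = sim - torch.logsumexp(sim.masked_fill(self_mask, float('-inf')),
                                     dim=1, keepdim=True)

    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
    pos_counts = pos_mask.sum(dim=1).clamp(min=1)
    # mean negative log-likelihood over each anchor's positive pairs
    loss = -(log_prob * pos_mask.float()).sum(dim=1) / pos_counts
    return loss.mean()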
Example #5
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: GlobalParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()
        xs, ys = batch_data

        logits = self.to_logits(xs)

        meter.Lall = meter.Lall + self.loss_ce_(
            logits, ys, meter=meter, name='Lce')

        self.any_()

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')

        return meter
Example #6
    def warmup_model(self, batch_data, model, optim, meter: Meter):
        (ids, xs, nys) = batch_data  # type:torch.Tensor
        optim.zero_grad()

        logits = model(xs)  # type:torch.Tensor
        meter.Lall = meter.Lall + self.loss_ce_(
            logits, nys, meter=meter, name='Lce')

        self.acc_precise_(logits.argmax(dim=-1), nys, meter=meter, name='acc')
        meter.Lall.backward()
        optim.step()
Example #7
 def train_batch(self, eidx, idx, global_step, batch_data, params: GmaParams, device: torch.device):
     meter = Meter()
     if eidx <= 10 or self.ssl_dataloader is None:
         self.warn_up(eidx, batch_data, meter)
     else:
         try:
             batch_data = next(self.ssl_loaderiter)
         except:
             self.ssl_loaderiter = iter(self.ssl_dataloader)
             batch_data = next(self.ssl_loaderiter)
         self.ssl(eidx, batch_data, meter)
     return meter
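The try/except block above restarts the SSL loader whenever its iterator is exhausted. A small reusable sketch of the same idea, with the bare `except` narrowed to `StopIteration`:

def next_batch(loader, loader_iter=None):
    """Return (batch, iterator), re-creating the iterator when the loader runs out."""
    if loader_iter is None:
        loader_iter = iter(loader)
    try:
        batch = next(loader_iter)
    except StopIteration:
        loader_iter = iter(loader)
        batch = next(loader_iter)
    return batch, loader_iter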
Example #8
 def test_eval_logic(self, dataloader, param: Params):
     from thexp.calculate import accuracy as acc
     param.topk = param.default([1, 5])
     with torch.no_grad():
         count_dict = Meter()
         for xs, labels in dataloader:
             preds = self.predict(xs)
             total, topk_res = acc.classify(preds, labels, topk=param.topk)
             count_dict["total"] += total
             for i, topi_res in zip(param.topk, topk_res):
                 count_dict["top{}".format(i)] += topi_res
     return count_dict
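`acc.classify` from `thexp.calculate` accumulates the total count and the top-k hit counts. An equivalent top-k count can be computed with plain PyTorch, as in this sketch:

import torch

def topk_counts(preds: torch.Tensor, labels: torch.Tensor, topk=(1, 5)):
    """preds: [B, C] scores, labels: [B]. Returns (total, [correct@k for each k])."""
    maxk = max(topk)
    _, top_idx = preds.topk(maxk, dim=1)                 # [B, maxk] predicted classes
    hits = top_idx.eq(labels.unsqueeze(1))               # [B, maxk] bool
    return labels.size(0), [hits[:, :k].any(dim=1).sum().item() for k in topk]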
Example #9
    def on_train_epoch_end(self, trainer: 'NoisyTrainer', func, params: NoisyParams, meter: Meter, *args, **kwargs):

        with torch.no_grad():
            f_mean = self.false_pred_mem[:, :params.eidx].mean(
                dim=1).cpu().numpy()
            f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
            feature = self.create_feature(f_mean, f_cur)

            noisy_cls = self.bmm_predict(feature, mean=params.feature_mean)

            true_cls = (self.true_pred_mem == self.false_pred_mem).all(dim=1).cpu().numpy()
            m = self.acc_mixture_(true_cls, noisy_cls)
            meter.update(m)
            self.logger.info(m)

            if params.eidx > params.mix_burnin:
                if params.mixt_ema:
                    self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                    self.noisy_cls = self.noisy_cls_mem.clone()
                else:
                    self.noisy_cls = torch.tensor(noisy_cls, device=self.device)
Example #10
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        # Note: ys are the ground-truth training labels; they are only used to compute accuracy, never in the training loss
        (ids, xs, axs, ys, nys), (vxs, vys) = batch_data

        w_logits = self.model(xs)
        logits = self.model(axs)  # type:torch.Tensor

        p_targets = self.sharpen_(torch.softmax(w_logits, dim=1))

        weight = self.meta_optimizer(xs, nys, vxs, vys, meter=meter)

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(
            logits, nys, weight, meter=meter, name='Lwce')
        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits, p_targets, (1 - weight), meter=meter, name='Lwce')

        meter.tw = weight[ys == nys].mean()
        meter.fw = weight[ys != nys].mean()

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1),
                          ys,
                          meter=meter,
                          name='true_acc')
        self.acc_precise_(logits.argmax(dim=1),
                          nys,
                          meter=meter,
                          name='noisy_acc')

        return meter
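`loss_ce_with_masked_` and `loss_ce_with_targets_masked_` are not shown here; a plausible sketch of per-sample weighted cross-entropy, for hard labels and for soft targets, assuming the weight tensor has shape [batch] (the actual signatures in the library may differ):

import torch
import torch.nn.functional as F

def ce_with_mask(logits, ys, weight):
    """Hard-label CE, scaled per sample by `weight` (e.g. a meta-learned clean probability)."""
    return (F.cross_entropy(logits, ys, reduction='none') * weight).mean()

def ce_with_targets_mask(logits, targets, weight):
    """Soft-target CE (targets: [B, C] probabilities), with the same per-sample weighting."""
    per_sample = -(targets * F.log_softmax(logits, dim=-1)).sum(dim=1)
    return (per_sample * weight).mean()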
Example #11
    def unsupervised_loss(self, xs: torch.Tensor, axs: torch.Tensor,
                          vxs: torch.Tensor, vys: torch.Tensor,
                          logits_lis: List[torch.Tensor], meter: Meter):
        '''create Lub, Lpb, Lkl'''

        logits_lis = [self.logit_norm_(logits) for logits in logits_lis]

        p_target = self.label_guesses_(*logits_lis)
        p_target = self.sharpen_(p_target, params.T)

        re_v_targets = tricks.onehot(vys, params.n_classes)
        mixed_input, mixed_target = self.mixmatch_up_(vxs, [axs],
                                                      re_v_targets,
                                                      p_target,
                                                      beta=params.mix_beta)

        mixed_logits = self.to_logits(mixed_input)
        mixed_logits_lis = mixed_logits.split_with_sizes(
            [vxs.shape[0], axs.shape[0]])
        (mixed_v_logits,
         mixed_nn_logits) = [self.logit_norm_(l)
                             for l in mixed_logits_lis]  # type:torch.Tensor

        # mixed_nn_logits = torch.cat([mixed_n_logits, mixed_an_logits], dim=0)
        mixed_v_targets, mixed_nn_targets = mixed_target.split_with_sizes(
            [mixed_v_logits.shape[0], mixed_nn_logits.shape[0]])

        # Lpβ: the validation set serves as the labeled set in the semi-supervised objective
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_v_logits, mixed_v_targets, meter=meter,
            name='Lpb') * params.semi_sche(params.eidx)
        # p * Luβ: the training set serves as the unlabeled set in the semi-supervised objective
        if params.lub:
            meter.Lall = meter.Lall + self.loss_ce_with_targets_(
                mixed_nn_logits, mixed_nn_targets, meter=meter,
                name='Lub') * params.semi_sche(params.eidx)

        # Lkl: consistency loss across multiple augmentations

        return p_target
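`label_guesses_` and `sharpen_` follow the MixMatch recipe: average the predictions across augmentations, then sharpen with a temperature. A sketch of both helpers, assuming probability inputs; these are illustrations rather than the trainer's own code:

import torch

def label_guess(*prob_list):
    """Average softmax predictions from several augmentations of the same batch."""
    return torch.stack(prob_list, dim=0).mean(dim=0)

def sharpen(probs: torch.Tensor, T: float = 0.5) -> torch.Tensor:
    """Raise probabilities to 1/T and renormalise; smaller T gives harder targets."""
    p = probs.pow(1.0 / T)
    return p / p.sum(dim=-1, keepdim=True)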
Example #12
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(xs)
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        preds_ = logits.detach().clone()  # torch.softmax(logits, dim=-1)
        # old_preds = self.pred_mem[ids]
        # residual = preds - old_preds

        _mask = torch.arange(params.n_classes).unsqueeze(0).repeat(
            [nys.shape[0], 1]).to(device)
        mask = _mask[_mask == nys.unsqueeze(1)].view(nys.shape[0], -1)

        # residual = residual.gather(1, mask)
        # old_preds = torch.softmax(old_preds.gather(1, mask), dim=-1)
        # preds = torch.softmax(preds_.gather(1, mask), dim=-1)
        # residual = preds - old_preds

        # residual = torch.pow(residual, 2)

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        self.pred_mem[ids] += preds_.detach()

        # self.sum_mem[ids] += residual.detach()
        # if eidx == 1:
        # else:
        #     self.pred_mem[ids] = self.pred_mem[ids] * 0.9 + logits.detach() * 0.1

        # meter.nres = residual[n_mask].mean()
        # meter.nsum = self.sum_mem[ids][n_mask].mean()
        # meter.tres = residual[n_mask.logical_not()].mean()
        # meter.tsum = self.sum_mem[ids][n_mask.logical_not()].mean()
        self.acc_precise_(self.pred_mem[ids].argmax(dim=1),
                          ys,
                          meter,
                          name='sacc')

        return meter
Example #13
 def train_batch(self, eidx, idx, global_step, batch_data, params: CoRandomParams, device: torch.device):
     meter = Meter()
     if eidx < params.warm_up:
         if eidx % 2 == 0:
             self.warmup_model(batch_data, self.model, self.optim, self.probs1, meter)
         else:
             self.warmup_model(batch_data, self.model2, self.optim2, self.probs2, meter)
     else:
         if eidx % 2 == 0:
             self.train_model(batch_data, self.model, self.optim, self.pys2, self.probs1, meter)
         else:
             self.train_model(batch_data, self.model2, self.optim2, self.pys1, self.probs2, meter)
     return meter
Example #14
    def meter(self, ratio=True):
        from thexp import Meter

        meter = Meter()
        for key, offset in self.times.items():
            if ratio:
                if isinstance(offset, str):
                    continue
                meter[key] = offset / self.times['use']
            else:
                meter[key] = offset

        return meter
Example #15
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: OffsetParams, device: torch.device):
        meter = Meter()
        if params.epoch - eidx > params.offset_epoch:
            self.train_first(eidx, self.model, self.optim, batch_data, meter)
        if eidx - params.offset_epoch > 0:
            if params.epoch - eidx > params.offset_epoch:
                self.train_second(eidx, self.model2, batch_data, meter)
            else:
                self.train_first(eidx, self.model2, self.optim2, batch_data,
                                 meter)

        return meter
Example #16
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: MentorNetParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        # Basic parameter for mentornet
        _epoch_step = params.epoch_step(eidx)
        _zero_labels = torch.zeros_like(nys)
        _loss_p_percentile = torch.ones(
            100, dtype=torch.float) * params.loss_p_percentile
        _dropout_rates = self.parse_dropout_rate_list_()

        logits = self.to_logits(xs)
        _basic_losses = F.cross_entropy(logits, ys, reduction='none').detach()

        weight = torch.rand_like(nys, dtype=torch.float).detach()  # TODO: replace with the MentorNet output

        ntargets = tricks.onehot(nys, params.n_classes)
        mixed_xs, mixed_targets = self.mentor_mixup_(
            xs, ntargets,
            weight.detach().cpu().numpy())

        mixed_logits = self.to_logits(mixed_xs)
        _mixed_losses = -torch.sum(
            F.log_softmax(mixed_logits, dim=1) * mixed_targets, dim=1)

        mix_weight = torch.rand_like(nys, dtype=torch.float).detach()  # TODO: replace with the MentorNet output

        meter.Lall = torch.mean(_mixed_losses * mix_weight)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')

        return meter
    def on_train_epoch_end(self, trainer: 'NoisyTrainer', func,
                           params: NoisyParams, meter: Meter, *args, **kwargs):
        true_f = os.path.join(self.experiment.test_dir, 'true.pth')
        false_f = os.path.join(self.experiment.test_dir, 'false.pth')
        loss_f = os.path.join(self.experiment.test_dir, 'loss.pth')
        cls_f = os.path.join(self.experiment.test_dir, 'cls.pth')
        if params.eidx % 10 == 0:
            torch.save(self.true_pred_mem, true_f)
            torch.save(self.false_pred_mem, false_f)
            torch.save(self.loss_mem, loss_f)
            torch.save(self.cls_mem, cls_f)

        with torch.no_grad():
            start = max(params.eidx - params.gmm_burnin, 0)
            f_mean = self.false_pred_mem[:, start:params.eidx].mean(dim=1).cpu().numpy()
            f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
            feature = self.create_feature(f_mean, f_cur)

            noisy_cls = self.gmm_predict(feature)

            true_cls = (self.true_pred_mem == self.false_pred_mem).all(
                dim=1).cpu().numpy()
            m = self.acc_mixture_(true_cls, noisy_cls)
            meter.update(m)
            self.logger.info(m)

            if params.eidx > params.gmm_burnin:
                self.noisy_cls_mem = torch.tensor(
                    noisy_cls,
                    device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                self.noisy_cls = self.noisy_cls_mem.clone()
                self.noisy_cls[self.noisy_cls < 0.5] = 0

                # As training goes on, the hardest-to-separate samples should be dropped outright rather than softly re-weighted back and forth (perhaps)
                self.noisy_cls[self.noisy_cls >= 0.5].clamp_min_(
                    params.gmm_w_sche(params.eidx))
    def on_train_epoch_end(self, trainer: 'NoisyTrainer', func,
                           params: GmaParams, meter: Meter, *args, **kwargs):
        true_f = os.path.join(self.experiment.test_dir, 'true.pth')
        false_f = os.path.join(self.experiment.test_dir, 'false.pth')
        loss_f = os.path.join(self.experiment.test_dir, 'loss.pth')
        if params.eidx % 10 == 0:
            torch.save(self.true_pred_mem, true_f)
            torch.save(self.false_pred_mem, false_f)
            torch.save(self.loss_mem, loss_f)

        with torch.no_grad():
            f_mean = self.false_pred_mem[:, :params.eidx].mean(
                dim=1).cpu().numpy()
            f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
            feature = self.create_feature(f_mean, f_cur)

            noisy_cls = self.bmm_predict(feature,
                                         mean=params.feature_mean,
                                         offset=params.offset_sche(
                                             params.eidx))
            if params.eidx > 1:
                self.noisy_cls_mem = torch.tensor(
                    noisy_cls,
                    device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                true_cls = (self.true_pred_mem == self.false_pred_mem).all(
                    dim=1).cpu().numpy()
                m = self.acc_mixture_(true_cls,
                                      self.noisy_cls_mem.cpu().numpy())
                meter.update(m)
                self.logger.info(m)

                if params.eidx > params.mix_burnin:
                    if params.mixt_ema:
                        self.noisy_cls = self.noisy_cls_mem.clone()
                    else:
                        self.noisy_cls = torch.tensor(noisy_cls,
                                                      device=self.device)
    def on_train_epoch_end(self, trainer: 'NoisyTrainer', func,
                           params: NoisyParams, meter: Meter, *args, **kwargs):

        with torch.no_grad():
            f_mean = self.false_pred_mem[:, :params.eidx].mean(
                dim=1).cpu().numpy()
            f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
            feature = self.create_feature(f_mean, f_cur)

            noisy_cls = self.bmm_predict(feature, mean=params.feature_mean)

            true_cls = (self.true_pred_mem == self.false_pred_mem).all(
                dim=1).cpu().numpy()
            m = self.acc_mixture_(true_cls, noisy_cls)
            meter.update(m)
            self.logger.info(m)

            if params.eidx > params.mix_burnin:
                if params.mixt_ema:
                    self.noisy_cls_mem = torch.tensor(
                        noisy_cls,
                        device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                    self.noisy_cls = self.noisy_cls_mem.clone()
                else:
                    self.noisy_cls = torch.tensor(noisy_cls,
                                                  device=self.device)
                # self.noisy_cls[self.noisy_cls < 0.5] = 0

                # As training goes on, the hardest-to-separate samples should be dropped outright rather than softly re-weighted back and forth (perhaps)
                # self.noisy_cls[self.noisy_cls >= 0.5].clamp_min_(params.gmm_w_sche(params.eidx))

            m2 = self.acc_mixture_(true_cls,
                                   (self.count_mem >= 0).cpu().numpy(),
                                   pre='con')
            meter.update(m2)
            self.logger.info(m2)
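The `bmm_predict` / `gmm_predict` calls above fit a two-component mixture to per-sample statistics and return, per sample, the probability of belonging to the "noisy" component (as in DivideMix). The sketch below shows the general technique with scikit-learn's GaussianMixture on a 1-D per-sample loss feature; it is an illustration under that simplification, not the trainer's own feature construction:

import numpy as np
from sklearn.mixture import GaussianMixture

def gmm_noisy_prob(per_sample_loss: np.ndarray) -> np.ndarray:
    """Fit a 2-component GMM to the losses; return P(high-loss component) per sample."""
    feats = per_sample_loss.reshape(-1, 1)
    gmm = GaussianMixture(n_components=2, max_iter=100, reg_covar=1e-4)
    gmm.fit(feats)
    probs = gmm.predict_proba(feats)                    # [N, 2]
    noisy_component = gmm.means_.argmax()               # higher mean loss -> likely noisy
    return probs[:, noisy_component]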
Example #20
    def train_batch(self, eidx, idx, global_step, batch_data, params: ICTParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()

        sup, unsup = batch_data
        xs, ys = sup
        _, un_xs, _, un_ys = unsup

        logits_list = self.to_logits(torch.cat([xs, un_xs])).split_with_sizes(
            [xs.shape[0], un_xs.shape[0]])
        logits, un_logits = logits_list  # type:torch.Tensor

        mixed_xs, ys_a, ys_b, lam_sup = self.ict_mixup_(xs, ys)

        logits = self.to_logits(xs)
        mixed_logits = self.to_logits(mixed_xs)

        ema_un_logits = self.predict(un_xs)

        mixed_un_xs, ema_mixed_un_logits, _ = self.mixup_unsup_(un_xs, ema_un_logits, mix=True)

        mixed_un_logits = self.to_logits(mixed_un_xs)

        meter.Lall = meter.Lall + self.loss_mixup_sup_ce_(mixed_logits, ys_a, ys_b, lam_sup, meter=meter)
        meter.Lall = meter.Lall + self.loss_mixup_unsup_mse_(mixed_un_logits, ema_un_logits,
                                                             decay=params.mixup_consistency_sche(eidx),
                                                             meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter)
        self.acc_precise_(un_logits.argmax(dim=1), un_ys, meter, name='unacc')

        return meter
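`ict_mixup_` mixes inputs within the batch and returns both label vectors plus the mixing coefficient, so the supervised loss can be interpolated. A minimal mixup sketch in that spirit; the function name and the Beta parameter are illustrative:

import numpy as np
import torch

def mixup(xs: torch.Tensor, ys: torch.Tensor, alpha: float = 1.0):
    """Return mixed inputs, the two label vectors, and lambda (ICT/mixup style)."""
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(xs.size(0), device=xs.device)
    mixed = lam * xs + (1.0 - lam) * xs[perm]
    return mixed, ys, ys[perm], lam

# the corresponding supervised loss interpolates the two cross-entropies:
# loss = lam * F.cross_entropy(logits, ys_a) + (1 - lam) * F.cross_entropy(logits, ys_b)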
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: MetaSSLParams, device: torch.device):
        meter = Meter()

        sup, unsup = batch_data
        xs, ys = sup
        ids, un_xs, un_axs, un_ys = unsup

        if eidx < 2:
            logits = self.to_logits(xs)
            meter.Lall = meter.Lall + self.loss_ce_(
                logits, ys, meter=meter, name='Lce')
            self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')
        else:
            un_logits = self.to_logits(un_xs)
            weight = self.meta_optimizer(un_xs, xs, ys, meter=meter)

            # meter.Lall = meter.Lall + self.loss_ce_(self.to_logits(xs), ys, meter=meter, name='Lce')
            meter.Lall = meter.Lall + self.loss_ce_with_masked_(
                un_logits,
                un_logits.argmax(dim=1),
                weight,
                meter=meter,
                name='Ldce')  # Lw*
            self.acc_precise_(un_logits.argmax(dim=1),
                              un_ys,
                              meter,
                              name='un_acc')

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        # self.acc_precise_(dist_targets.argmax(dim=1), un_ys, meter=meter, name='dist_acc')

        return meter
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: MetaSSLParams, device: torch.device):
        meter = Meter()

        sup, unsup = batch_data
        xs, ys = sup
        ids, un_xs, un_axs, un_ys = unsup

        mid_logits = self.to_mid(un_xs)

        self.meta_optimizer(un_xs, xs, ys, meter=meter)

        left, right = tricks.cartesian_product(mid_logits, self.cls_center)

        dist_ = F.pairwise_distance(left,
                                    right).reshape(mid_logits.shape[0], -1)
        dist_targets = torch.softmax(dist_, dim=-1)
        dist_targets = self.sharpen_(dist_targets)
        logits = self.to_logits(un_axs)

        # meter.Lall = meter.Lall + self.loss_ce_(self.to_logits(xs), ys, meter=meter, name='Lce')
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            logits, dist_targets, meter=meter, name='Ldce')  # Lw*

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(dist_targets.argmax(dim=1),
                          un_ys,
                          meter=meter,
                          name='dist_acc')
        self.acc_precise_(logits.argmax(dim=1), un_ys, meter, name='un_acc')

        return meter
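The distance-to-center block above (cartesian product, pairwise distance, softmax) can be written more compactly with `torch.cdist`. This sketch mirrors the snippet, including taking the softmax over the raw distances; the helper name and shapes are illustrative:

import torch

def dist_soft_targets(mid_feats: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    """mid_feats: [B, D], centers: [K, D]; returns [B, K] soft targets from pairwise distances."""
    dist = torch.cdist(mid_feats, centers)        # [B, K] Euclidean distances
    return torch.softmax(dist, dim=-1)            # same convention as the snippet above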
    def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(xs)
        if params.smooth:
            if params.smooth_ratio != -1:
                ns_targets = tricks.onehot(nys, params.n_classes)

                if params.smooth_argmax:
                    targets = tricks.onehot(logits.argmax(dim=-1), params.n_classes)
                else:
                    rys = ys.clone()
                    rids = torch.randperm(len(rys))
                    if params.smooth_ratio > 0:
                        rids = rids[:int((len(rids) * params.smooth_ratio))]
                        rys[rids] = torch.randint(0, params.n_classes, [len(rids)], device=device)
                    targets = tricks.onehot(rys, params.n_classes)

                if params.smooth_mixup:
                    l = np.random.beta(0.75, 0.75, size=targets.shape[0])
                    l = np.max([l, 1 - l], axis=0)
                    l = torch.tensor(l, device=device, dtype=torch.float).unsqueeze(1)
                else:
                    l = 0.9

                ns_targets = ns_targets * l + targets * (1 - l)
            else:
                ns_targets = tricks.label_smoothing(tricks.onehot(nys, params.n_classes))
            meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, ns_targets, meter=meter)
        else:
            meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        with torch.no_grad():
            nlogits = self.to_logits(xs)
            res = torch.softmax(nlogits, dim=-1) - torch.softmax(logits, dim=-1)
            meter.nres = res.gather(1, nys.unsqueeze(1)).mean() * 10
            meter.tres = res.gather(1, ys.unsqueeze(1)).mean() * 100
            meter.rres = res.gather(1, torch.randint_like(ys, 0, params.n_classes).unsqueeze(1)).mean() * 100
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

        return meter
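When `params.smooth_ratio == -1`, the snippet falls back to plain label smoothing via `tricks.label_smoothing`. A sketch of that transform on one-hot targets, with an illustrative epsilon:

import torch

def label_smoothing(onehot_targets: torch.Tensor, eps: float = 0.1) -> torch.Tensor:
    """Mix one-hot targets with the uniform distribution: (1 - eps) * onehot + eps / C."""
    n_classes = onehot_targets.size(-1)
    return onehot_targets * (1.0 - eps) + eps / n_classes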
Example #24
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: SupervisedParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()
        ids, xs, axs, ys = batch_data  # type:torch.Tensor

        targets = tricks.onehot(ys, params.n_classes)
        mixed_xs, mixed_targets = self.mixup_(xs, targets)

        mixed_logits = self.to_logits(mixed_xs)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_logits, mixed_targets, meter=meter)

        with torch.no_grad():
            self.acc_precise_(self.to_logits(xs).argmax(dim=1),
                              ys,
                              meter=meter,
                              name='acc')

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        return meter
Example #25
    def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(axs)
        # The original snippet calls backward() without assigning any loss to meter.Lall;
        # a plain cross-entropy on the noisy labels is assumed here so the step is runnable.
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter, name='Lce')

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

        return meter
Example #26
 def train_batch(self, eidx, idx, global_step, batch_data,
                 params: DivideMixParams, device: torch.device):
     meter = Meter()
     if eidx < params.warm_up:
         if eidx % 2 == 0:
             self.warmup_model(batch_data, self.model, self.optim, meter)
         else:
             self.warmup_model(batch_data, self.model2, self.optim2, meter)
     else:
         if eidx % 2 == 0:
             self.train_model(self.model, self.model2, self.optim,
                              batch_data, params, meter)
         else:
             self.train_model(self.model2, self.model, self.optim2,
                              batch_data, params, meter)
     return meter
Example #27
    def train_model(self, batch_data, model, optim, pys, probs, meter: Meter):
        (ids, xs, axs, ys, _) = batch_data  # type:torch.Tensor
        nys = pys[ids]

        optim.zero_grad()

        logits = model(xs)  # type:torch.Tensor
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter, name='Lce')
        meter.Lall.backward()
        optim.step()

        with torch.no_grad():
            # EMA
            logits = model(xs)
            probs[ids] = probs[ids] * 0.3 + torch.softmax(logits, dim=-1) * 0.7

        n_mask = (nys != ys)
        self.acc_precise_(logits.argmax(dim=-1), ys, meter=meter, name='acc')
        self.acc_precise_(logits[n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')
        self.acc_precise_(probs[ids].argmax(dim=-1), ys, meter=meter, name='pacc')
    def warmup_model(self, batch_data, model, head, optim, probs, meter: Meter):
        """最初的测试,观察模型能够从 ys 和 nys 中学到正确的部份"""
        (ids, xs, axs, ys, nys) = batch_data  # type:torch.Tensor
        optim.zero_grad()

        features = model(xs)  # type:torch.Tensor

        n_mask = (ys != nys)

        outs = head(features)[:3]

        meter.Lall = meter.Lall + self.loss_ce_(outs[0], nys, meter=meter, name='Lce')
        meter.Lall.backward()
        optim.step()

        with torch.no_grad():
            # EMA
            probs[ids] = probs[ids] * 0.3 + torch.softmax(outs[0], dim=-1) * 0.7
        #
        # n_mask = (nys != ys)
        self.acc_precise_(outs[0].argmax(dim=-1), ys, meter=meter, name='acc')
        self.acc_precise_(outs[0][n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')
Example #29
    def train_first(self, eidx, model, optim, batch_data, meter: Meter):
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        logits = model(axs)

        w_logits = model(xs)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
        if params.local_filter:
            weight = label_pred - self.target_mem[ids]
            weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
            weight_mask = weight < 0
            meter.tl = weight_mask[ys == nys].float().mean()
            meter.fl = weight_mask[ys != nys].float().mean()

        fweight = torch.ones(w_logits.shape[0],
                             dtype=torch.float,
                             device=self.device)
        if eidx >= params.mix_burnin:
            if params.local_filter:
                fweight[weight_mask] -= params.gmm_w_sche(eidx)
            fweight -= self.noisy_cls[ids]
            fweight = torch.relu(fweight)

        self.filter_mem[ids] = fweight

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = self.plabel_mem[
                ids] * params.targets_ema + raw_targets * (1 -
                                                           params.targets_ema)
            self.plabel_mem[ids] = targets

        top_values, top_indices = targets.topk(2, dim=-1)
        p_labels = top_indices[:, 0]
        values = top_values[:, 0]

        # ratio = params.smooth_ratio_sche(eidx)

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')

        n_targets = tricks.onehot(nys, params.n_classes)
        p_targets = tricks.onehot(p_labels, params.n_classes)

        mask = mask.float()
        meter.pm = mask.mean()

        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits, n_targets, fweight, meter=meter)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits,
            p_targets, (1 - fweight) * values * mask,
            meter=meter,
            name='Lpce') * params.plabel_sche(eidx)

        meter.tw = fweight[ys == nys].mean()
        meter.fw = fweight[ys != nys].mean()

        if params.local_filter:
            with torch.no_grad():
                ids_mask = weight_mask.logical_not()
                alpha = params.filter_ema

                self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + label_pred[ids_mask] * \
                                                 (1 - alpha)

                # gradually regress the entries that were not selected back towards the uniform prior:
                # self.target_mem[ids[weight_mask]] = self.target_mem[ids[weight_mask]] * alpha + (1 / params.n_classes) * \
                #                                  (1 - alpha)

        if 'Lall' in meter:
            optim.zero_grad()
            meter.Lall.backward()
            optim.step()
        self.acc_precise_(w_logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        false_pred = targets.gather(
            1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(
            1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids, eidx - 1] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits,
                                                           nys,
                                                           reduction='none')
            # mem_mask = ids < self.pred_mem_size - 1
            # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = p_labels
Example #30
import sys
sys.path.insert(0,"../")
from thexp import __VERSION__
print(__VERSION__)


from thexp import Meter,AvgMeter
import torch

m = Meter()
m.a = 1
m.b = "2"
m.c = torch.rand(1)[0]

m.c1 = torch.rand(1)
m.c2 = torch.rand(2)
m.c3 = torch.rand(4, 4)
print(m)

m = Meter()
m.a = 0.236
m.b = 3.236
m.c = 0.23612312
m.percent(m.a_)
m.int(m.b_)