Example #1
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: MnistImblanceParams, device: torch.device):
        meter = Meter()

        (xs, ys), (vxs, vys) = batch_data  # type:torch.Tensor

        w = self.meta_optimizer(xs, ys, vxs, vys, meter=meter)

        logits = self.model(xs).squeeze()
        _net_costs = F.binary_cross_entropy_with_logits(logits,
                                                        ys,
                                                        reduction='none')
        l_f = torch.sum(_net_costs * w.detach())

        self.optim.zero_grad()
        l_f.backward()
        self.optim.step()

        meter.l_f = l_f
        if (ys == 1).any():
            meter.w1 = w[ys == 1].mean()  # minority will have larger weight
        if (ys == 0).any():
            meter.w0 = w[ys == 0].mean()

        return meter
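The example weights `w` above come from `self.meta_optimizer`, whose body is not shown here; Example #20 below spells out the same learning-to-reweight step inline. A self-contained toy sketch of that step, with made-up shapes and a plain linear model standing in for the real network, is:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
xs, ys = torch.randn(8, 5), torch.randint(0, 2, (8,)).float()    # noisy / imbalanced train batch
vxs, vys = torch.randn(4, 5), torch.randint(0, 2, (4,)).float()  # small clean validation batch
theta = torch.zeros(5, requires_grad=True)                       # toy linear model
lr = 0.1

cost = F.binary_cross_entropy_with_logits(xs @ theta, ys, reduction='none')
eps = torch.zeros_like(cost, requires_grad=True)   # one epsilon per training example
meta_loss = torch.sum(cost * eps)

# one virtual SGD step, keeping the graph so eps stays differentiable
grad_theta, = torch.autograd.grad(meta_loss, theta, create_graph=True)
theta_hat = theta - lr * grad_theta

# validation loss of the perturbed model, differentiated w.r.t. the epsilons
v_loss = F.binary_cross_entropy_with_logits(vxs @ theta_hat, vys)
grad_eps, = torch.autograd.grad(v_loss, eps)

w = torch.clamp(-grad_eps, min=0)        # keep only examples whose gradient helps validation
w = w / w.sum() if w.sum() > 0 else w    # normalize, as Example #20 does
print(w)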
Example #2
    def train_batch(self, eidx, idx, global_step, batch_data, params: MetaSSLParams, device: torch.device):
        meter = Meter()

        sup, unsup = batch_data
        xs, ys = sup
        ids, un_xs, un_axs, un_ys = unsup

        mid_logits = self.to_mid(un_xs)

        self.meta_optimizer(un_xs, xs, ys, meter=meter)

        left, right = tricks.cartesian_product(mid_logits, self.cls_center)

        dist_ = F.pairwise_distance(left, right).reshape(mid_logits.shape[0], -1)
        dist_targets = torch.softmax(dist_, dim=-1)
        logits = self.to_logits(un_axs)

        meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, dist_targets, meter=meter, name='Ldce')  # Lw*

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), un_ys, meter, name='un_acc')

        return meter
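The body of `tricks.cartesian_product` is not shown; a plausible reading, sketched below with toy shapes, is that it pairs every embedding with every class center so that `F.pairwise_distance` returns one distance per (sample, center) pair, which the reshape then turns into an `[N, C]` matrix:

import torch
import torch.nn.functional as F

mid_logits = torch.randn(6, 16)   # [N, D] embeddings of the unlabeled batch
cls_center = torch.randn(10, 16)  # [C, D], one center per class

n, c = mid_logits.shape[0], cls_center.shape[0]
left = mid_logits.repeat_interleave(c, dim=0)  # [N*C, D]: each embedding repeated C times
right = cls_center.repeat(n, 1)                # [N*C, D]: the centers cycled N times

dist_ = F.pairwise_distance(left, right).reshape(n, -1)  # [N, C] distance matrix
dist_targets = torch.softmax(dist_, dim=-1)              # soft target per sample
print(dist_targets.shape)  # torch.Size([6, 10])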
Example #3
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: IEGParams, device: torch.device):
        meter = Meter()
        train_data, (vxs, vys) = batch_data  # type: List[torch.Tensor], (torch.Tensor, torch.Tensor)

        ys, nys = train_data[-2:]
        xs = train_data[1]
        axs = torch.cat(train_data[2:2 + params.K])

        logits = self.to_logits(xs)
        aug_logits = self.to_logits(axs)
        n_targets = tricks.onehot(nys, params.n_classes)

        guess_targets = self.unsupervised_loss(xs,
                                               axs,
                                               vxs,
                                               vys,
                                               logits,
                                               aug_logits,
                                               meter=meter)

        weight, eps_k = self.meta_optimizer(xs,
                                            guess_targets,
                                            n_targets,
                                            vxs,
                                            vys,
                                            meter=meter)

        meter.tw = eps_k[ys == nys].mean()
        meter.fw = eps_k[ys != nys].mean()

        mixed_targets = eps_k * n_targets + (1 - eps_k) * guess_targets

        # Superficially this is cross-entropy against the targets; in fact the targets are all in one-hot form

        # loss with initial eps
        init_eps = torch.ones([guess_targets.shape[0]],
                              dtype=torch.float,
                              device=self.device) * params.grad_eps_init
        init_mixed_labels = (tricks.elementwise_mul(init_eps, n_targets) +
                             tricks.elementwise_mul(1 - init_eps, guess_targets))
        # loss with initial weight

        meter.Lws = self.loss_ce_with_targets_(logits, mixed_targets)  # Lw*
        meter.Llamda = -torch.mean(
            torch.sum(F.log_softmax(logits, dim=1) * init_mixed_labels, dim=1)
            * weight)  # Lλ*
        meter.Lall = meter.Lall + (meter.Lws + meter.Llamda) / 2

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')

        return meter
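The per-example `eps_k` produced by the meta step acts as a trust score for the given (possibly noisy) label: `eps_k` near 1 keeps the noisy one-hot target, near 0 defers to the guessed distribution. A toy illustration of the convex combination used above:

import torch

n_targets = torch.tensor([[0., 1., 0.]])         # one-hot noisy label
guess_targets = torch.tensor([[0.7, 0.2, 0.1]])  # guessed distribution from the model
eps_k = torch.tensor([[0.8]])                    # learned per-example trust

mixed_targets = eps_k * n_targets + (1 - eps_k) * guess_targets
print(mixed_targets)  # tensor([[0.1400, 0.8400, 0.0200]])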
Example #4
    def test_eval_logic(self, dataloader, param: Params):
        from thexp.calculate import accuracy as acc
        from sklearn import metrics
        import numpy as np
        with torch.no_grad():
            meter = Meter()
            y_trues = []
            y_preds = []
            for xs, y_true in dataloader:
                preds = self.predict(xs)  # type:torch.Tensor

                y_pred = preds.argmax(dim=-1).cpu().numpy()
                y_true = y_true.cpu().numpy()

                y_preds.extend(y_pred)
                y_trues.extend(y_true)

            meter.nmi = metrics.normalized_mutual_info_score(
                np.array(y_trues),
                np.array(y_preds),
                average_method="arithmetic")

            meter.ari = metrics.adjusted_rand_score(np.array(y_trues),
                                                    np.array(y_preds))

        return meter
Example #5
    def meter(self):
        from thexp import Meter

        meter = Meter()
        for key, offset in self.times.items():
            meter[key] = offset
        return meter
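`Meter` is thexp's metric container; its implementation is not part of this listing. The attribute- and item-style accesses used throughout (including reading a key before it has been set, as in `meter.Lall = meter.Lall + ...`) behave roughly like the following minimal stand-in, which is an illustrative assumption rather than thexp's actual code:

class MiniMeter:
    def __init__(self):
        self._values = {}

    def __setattr__(self, key, value):
        if key == '_values':
            super().__setattr__(key, value)
        else:
            self._values[key] = value

    def __getattr__(self, key):
        # unknown keys default to 0, so `meter.Lall = meter.Lall + loss` works on first use
        return self._values.get(key, 0)

    def __setitem__(self, key, value):
        self._values[key] = value

    def __getitem__(self, key):
        return self._values.get(key, 0)

    def __repr__(self):
        return ' | '.join('{}: {}'.format(k, v) for k, v in self._values.items())


m = MiniMeter()
m.Lall = m.Lall + 1.5  # reads the default 0, then stores 1.5
m['top1'] += 3         # item access with the same default
print(m)               # Lall: 1.5 | top1: 3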
Example #6
    def test_eval_logic(self, dataloader, param: FixMatchParams):
        from thexp.calculate import accuracy as acc

        param.topk = param.default([1, 5])

        with torch.no_grad():
            noisy_mem = torch.zeros(50000,
                                    device=self.device,
                                    dtype=torch.long)
            count_dict = Meter()
            for batch_data in dataloader:
                ids, xs, labels = batch_data
                preds = self.predict(xs)
                noisy_ys = preds.argmax(dim=1)
                noisy_mem[ids] = noisy_ys
                total, topk_res = acc.classify(preds, labels, topk=param.topk)
                count_dict["total"] += total
                for i, topi_res in zip(param.topk, topk_res):
                    count_dict["top{}".format(i)] += topi_res

        import numpy as np
        noisy_mem = noisy_mem.detach().cpu().numpy()
        np.save('noisy_{}.npy'.format(count_dict['top1']), noisy_mem)
        self.logger.info()
        return count_dict
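The returned `count_dict` only accumulates raw counts; turning them into accuracies is presumably left to the caller, roughly as follows (toy numbers, not from the source):

count_dict = {'total': 10000, 'top1': 9321, 'top5': 9976}  # toy values

top1_acc = count_dict['top1'] / count_dict['total']
top5_acc = count_dict['top5'] / count_dict['total']
print('top1 {:.2%}, top5 {:.2%}'.format(top1_acc, top5_acc))  # top1 93.21%, top5 99.76%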
Example #7
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        # Note: ys are the ground-truth labels of the training set, used only to compute accuracy, never in training
        (ids, xs, axs, ys, nys), (vxs, vys) = batch_data

        w_logits = self.model(xs)
        logits = self.model(axs)  # type:torch.Tensor

        p_targets = self.sharpen_(torch.softmax(w_logits, dim=1))

        weight = self.meta_optimizer(xs, nys, vxs, vys, meter=meter)

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(
            logits, nys, weight, meter=meter, name='Lwce')
        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits, p_targets, (1 - weight), meter=meter, name='Lwce')

        meter.tw = weight[ys == nys].mean()
        meter.fw = weight[ys != nys].mean()

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1),
                          ys,
                          meter=meter,
                          name='true_acc')
        self.acc_precise_(logits.argmax(dim=1),
                          nys,
                          meter=meter,
                          name='noisy_acc')

        return meter
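`loss_ce_with_masked_` and `loss_ce_with_targets_masked_` are thexp helpers whose bodies are not shown. A plausible reading of the combination above, sketched with toy tensors, is a per-example cross-entropy on the noisy labels scaled by `weight`, plus a soft-target cross-entropy on the sharpened predictions scaled by `1 - weight`:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
nys = torch.randint(0, 10, (4,))                          # noisy labels
p_targets = torch.softmax(torch.randn(4, 10) * 2, dim=1)  # sharpened soft targets
weight = torch.tensor([1.0, 0.8, 0.2, 0.0])               # per-example trust in nys

l_noisy = (F.cross_entropy(logits, nys, reduction='none') * weight).mean()
l_pseudo = (-(p_targets * F.log_softmax(logits, dim=1)).sum(dim=1) * (1 - weight)).mean()
print(l_noisy + l_pseudo)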
Example #8
 def train_batch(self, eidx, idx, global_step, batch_data, params: DivideMixParams, device: torch.device):
     meter = Meter()
     if eidx <= params.warm_up:
         self.warmup_model(batch_data, self.model, self.optim, meter)
     else:
         self.train_model(self.model, self.optim, batch_data, params, meter)
     return meter
Example #9
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: MentorNetParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        # Basic parameter for mentornet
        _epoch_step = params.epoch_step(eidx)
        _zero_labels = torch.zeros_like(nys)
        _loss_p_percentile = torch.ones(
            100, dtype=torch.float) * params.loss_p_percentile
        _dropout_rates = self.parse_dropout_rate_list_()

        logits = self.to_logits(xs)
        _basic_losses = F.cross_entropy(logits, ys, reduction='none')

        # TODO replace with mentornet; random per-example weights as a placeholder
        # (rand_like needs a floating-point reference, so use the losses, not nys)
        weight = torch.rand_like(_basic_losses).detach()

        # keep the graph on _basic_losses so Lall.backward() below can update the model
        meter.Lall = torch.mean(_basic_losses * weight)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')

        return meter
Example #10
    def train_batch(self, eidx, idx, global_step, batch_data, params: O2UParams, device: torch.device):
        meter = Meter()
        (ids, xs, axs, ys, nys) = batch_data  # type:torch.Tensor
        # ids, xs, nys = batch_data  # type:(torch.Tensor,torch.Tensor,torch.Tensor)
        if params.acc_stage[1] < eidx < params.acc_stage[2]:
            mask = self.mask[ids] > 0.5
            if not mask.any():
                return meter

            # mask quality: fraction of clean / noisy samples that survive the mask
            meter.tw = mask[nys == ys].float().mean()
            meter.fw = mask[nys != ys].float().mean()

            # mask ids and ys as well, so the indexing below stays consistent
            ids, xs, ys, nys = ids[mask], xs[mask], ys[mask], nys[mask]

        logits = self.to_logits(xs)
        Lces = F.cross_entropy(logits, nys, reduction='none')

        self.example_loss[ids] = Lces

        self.acc_precise_(logits.argmax(dim=1), ys, meter=meter, name='true_acc')
        self.acc_precise_(logits.argmax(dim=1), nys, meter=meter, name='noisy_acc')
        self.globals_loss = self.globals_loss + Lces.sum().cpu().data.item()
        meter.Lce = Lces.mean()
        self.optim.zero_grad()
        meter.Lce.backward()
        self.optim.step()
        return meter
Example #11
 def train_batch(self, eidx, idx, global_step, batch_data, params: CoRandomParams, device: torch.device):
     meter = Meter()
     if eidx <= params.warm_up:
         self.warmup_model(batch_data, self.model, self.head, self.optim, self.probs, meter)
     else:
         self.train_model(batch_data, self.model, self.head, self.optim, self.probs, self.pyss, meter)
     return meter
Example #12
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: SupervisedParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()
        ids, xs, axs, ys = batch_data  # type:torch.Tensor
        logits = self.to_logits(xs)

        if eidx < 10:
            # mask = (ids % 1000) < eidx
            targets = torch.ones(xs.shape[0], params.n_classes,
                                 device=device) / params.n_classes
            # targets[mask] = tricks.onehot(ys[mask], params.n_classes)

            meter.Lall = meter.Lall + self.loss_ce_with_targets_(
                logits, targets, meter=meter)
        else:
            meter.Lall = meter.Lall + self.loss_ce_(logits, ys, meter=meter)

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        return meter
Example #13
    def on_train_epoch_end(self, trainer: 'IEGTrainer', func,
                           params: IEGParams, meter: Meter, *args, **kwargs):
        with torch.no_grad():
            from sklearn import metrics
            f_mean = self.false_pred_mem[:, max(params.eidx - params.gmm_burnin, 0):params.eidx] \
                .mean(dim=1).cpu().numpy()
            f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
            feature = np.stack([f_mean, f_cur], axis=1)

            model = tricks.group_fit(feature)
            noisy_cls = model.predict_proba(feature)[:, 0]  # type:np.ndarray

            if params.eidx == params.gmm_burnin:
                self.noisy_cls_mem = torch.tensor(noisy_cls,
                                                  device=self.device)

            # true_ncls = (self.true_pred_mem == self.false_pred_mem[:, params.eidx - 1]).cpu().numpy()
            true_ncls = (
                self.true_pred_mem == self.false_pred_mem[:, params.eidx - 1])
            m = Meter()
            if params.eidx > params.gmm_burnin:
                self.noisy_cls_mem = torch.tensor(
                    noisy_cls,
                    device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                # self.noisy_cls = 1 / (1 + torch.exp(-(10 * self.noisy_cls_mem.clone() - 5)))
                self.noisy_cls = self.noisy_cls_mem.clone()

                # self.noisy_cls[self.noisy_cls] = 0

                # As training progresses, the samples that remain hard to separate should arguably be dropped outright rather than kept with fuzzy weights (perhaps)
                # self.noisy_cls[self.noisy_cls >= 0.5].clamp_min_(1)
                m.gmm_t = self.noisy_cls[true_ncls].float().mean()
                m.percent(m.gmm_t_)
                m.gmm_f = self.noisy_cls[
                    true_ncls.logical_not()].float().mean()
                m.percent(m.gmm_f_)

            if params.eidx > params.burnin:
                x = self.false_pred_mem[:, 1:params.eidx - 1].cpu()
                y = torch.arange(x.shape[-1]).repeat([x.shape[0], 1]).float()
                corrcoefs = tricks.bcorrcoef(x, y).to(self.device)
                corrcoefs[torch.isnan(corrcoefs)] = 0
                mask = corrcoefs > 0
                corrcoefs[mask] = corrcoefs[mask] / corrcoefs[mask].max()
                mask.logical_not_()
                corrcoefs[mask] = corrcoefs[mask] / -corrcoefs[mask].min()

                self.corrcoefs = 1 - (1 / (1 + torch.exp(
                    -(params.corr_sigmoid_sche(params.eidx) * corrcoefs))))

                m.cor_t = self.corrcoefs[true_ncls].mean()
                m.percent(m.cor_t_)
                m.cor_f = self.corrcoefs[true_ncls.logical_not()].mean()
                m.percent(m.cor_f_)

            meter.update(m)
            self.logger.info(m)
Example #14
    def train_batch(self, eidx, idx, global_step, batch_data, params: FewshotParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()
        losses_q = [0 for _ in range(params.update_step + 1)]
        corrects = [0 for _ in range(params.update_step + 1)]
        x_spt, y_spt, x_qry, y_qry = batch_data
        for i in range(params.task_num):
            metanet = self.create_metanet()
            logits = metanet(x_spt[i])
            loss = self.loss_ce_(logits, y_spt[i])

            grad = torch.autograd.grad(loss, metanet.params())
            metanet.update_params(params.meta_lr, grad)

            # this is the loss and accuracy before first update
            with torch.no_grad():
                # [setsz, nway]
                logits_q = self.model(x_qry[i])
                loss_q = self.loss_ce_(logits_q, y_qry[i])
                losses_q[0] += loss_q

                pred_q = logits_q.argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            # this is the loss and accuracy after the first update
            with torch.no_grad():
                # [setsz, nway]
                logits_q = metanet(x_qry[i])
                loss_q = self.loss_ce_(logits_q, y_qry[i])
                losses_q[1] += loss_q
                # [setsz]
                pred_q = logits_q.argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, params.update_step):
                # 1. run the i-th task and compute loss for k=1~K-1
                logits = metanet(x_spt[i])
                loss = self.loss_ce_(logits, y_spt[i])
                # 2. compute grad on theta_pi
                grad = torch.autograd.grad(loss, metanet.params())
                # 3. theta_pi = theta_pi - train_lr * grad
                metanet.update_params(params.meta_lr, grad)

                logits_q = metanet(x_qry[i])
                # loss_q will be overwritten and just keep the loss_q on last update step.
                loss_q = self.loss_ce_(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = logits_q.argmax(dim=1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()  # convert to numpy
                    corrects[k + 1] = corrects[k + 1] + correct

        return meter
Example #15
    def acc_mixture_(self,
                     true_cls: np.ndarray,
                     noisy_cls: np.ndarray,
                     pre='mix'):
        meter = Meter()
        t_n, f_n = '{}t'.format(pre), '{}f'.format(pre)
        meter[t_n] = noisy_cls[true_cls].mean()
        meter[f_n] = noisy_cls[np.logical_not(true_cls)].mean()

        meter.percent(t_n)
        meter.percent(f_n)
        return meter
Example #16
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: ImblanceParams, device: torch.device):
        meter = Meter()
        (images, labels), (_, _) = batch_data

        logits = self.model(images).squeeze()
        meter.ce_loss = F.binary_cross_entropy_with_logits(logits, labels)

        self.optim.zero_grad()
        meter.ce_loss.backward()
        self.optim.step()
        return meter
Example #17
 def test_eval_logic(self, dataloader, param: Params):
     from thexp.calculate import accuracy as acc
     param.topk = param.default([1, 5])
     with torch.no_grad():
         count_dict = Meter()
         for xs, labels in dataloader:
             preds = self.predict(xs)
             total, topk_res = acc.classify(preds, labels, topk=param.topk)
             count_dict["total"] += total
             for i, topi_res in zip(param.topk, topk_res):
                 count_dict["top{}".format(i)] += topi_res
     return count_dict
Example #18
 def train_batch(self, eidx, idx, global_step, batch_data, params: GmaParams, device: torch.device):
     meter = Meter()
     if eidx <= 10 or self.ssl_dataloader is None:
         self.warn_up(eidx, batch_data, meter)
     else:
         try:
             batch_data = next(self.ssl_loaderiter)
         except StopIteration:
             self.ssl_loaderiter = iter(self.ssl_dataloader)
             batch_data = next(self.ssl_loaderiter)
         self.ssl(eidx, batch_data, meter)
     return meter
Example #19
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(xs)
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

        self.acc_precise_(logits.argmax(dim=1), nys, meter, name='acc')
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')

        # _err = (ys != nys)
        # ids, xs, axs, ys, nys = ids[_err], xs[_err], axs[_err], ys[_err], nys[_err]

        for i in range(params.n_classes):
            _cls_mask = (ys == 0)
            _fcls_mask = (nys == 1)
            _tloss = F.cross_entropy(logits[_cls_mask],
                                     ys[_cls_mask],
                                     reduction='none')
            _floss = F.cross_entropy(logits[_fcls_mask],
                                     nys[_fcls_mask],
                                     reduction='none')

            tcls_grads = [
                i for i in autograd.grad(_tloss,
                                         self.model.parameters(),
                                         grad_outputs=torch.ones_like(_tloss),
                                         retain_graph=True,
                                         allow_unused=True) if i is not None
            ]

            fcls_grads = [
                i for i in autograd.grad(_floss,
                                         self.model.parameters(),
                                         grad_outputs=torch.ones_like(_floss),
                                         retain_graph=True,
                                         allow_unused=True) if i is not None
            ]
            res = 0
            for tgrad, fgrad in zip(tcls_grads, fcls_grads):
                # res += (tgrad - fgrad).abs().pow(2).sum()
                # res = max(tricks.cos_similarity(tgrad, fgrad).mean(), res)
                res += tricks.cos_similarity(tgrad, fgrad).mean()

            meter.res = res
            break

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        return meter
Example #20
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: ImblanceParams, device: torch.device):
        meter = Meter()

        (images, labels), (val_images, val_labels) = batch_data

        metanet = MetaLeNet(1).to(device)
        metanet.load_state_dict(self.model.state_dict())
        y_f_hat = metanet(images).squeeze()
        cost = F.binary_cross_entropy_with_logits(y_f_hat,
                                                  labels,
                                                  reduction='none')
        eps = torch.zeros_like(labels, device=device, requires_grad=True)
        l_f_meta = torch.sum(cost * eps)
        metanet.zero_grad()

        grads = torch.autograd.grad(l_f_meta, (metanet.params()),
                                    create_graph=True)
        metanet.update_params(params.optim.lr, grads=grads)

        y_g_hat = metanet(val_images).squeeze()
        v_meta_loss = F.binary_cross_entropy_with_logits(y_g_hat, val_labels)
        grad_eps = torch.autograd.grad(v_meta_loss, eps, only_inputs=True)[0]
        w_tilde = torch.clamp(-grad_eps, min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde

        y_f_hat = self.model(images).squeeze()
        cost = F.binary_cross_entropy_with_logits(y_f_hat,
                                                  labels,
                                                  reduction='none')
        l_f = torch.sum(cost * w)

        self.optim.zero_grad()
        l_f.backward()
        self.optim.step()

        meter.l_f = l_f
        meter.meta_l = v_meta_loss
        if (labels == 0).sum() > 0:
            meter.grad_0 = grad_eps[labels == 0].mean() * 1e5
            meter.grad_0_max = grad_eps[labels == 0].max() * 1e5
            meter.grad_0_min = grad_eps[labels == 0].min() * 1e5
        meter.grad_1 = grad_eps[labels == 1].mean() * 1e5
        meter.grad_1_max = grad_eps[labels == 1].max() * 1e5
        meter.grad_1_min = grad_eps[labels == 1].min() * 1e5

        return meter
Example #21
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: PencilParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(xs)
        if eidx < params.stage1:
            # lc is classification loss
            Lce = self.loss_ce_(logits, nys, meter=meter, name='Lce')

            # init y_tilde, let softmax(y_tilde) is noisy labels
            noisy_targets = tricks.onehot(nys, params.n_classes)
            self.target_mem[ids] = noisy_targets
        else:
            yy = self.target_mem[ids]
            yy = torch.autograd.Variable(yy, requires_grad=True)

            # obtain label distributions (y_hat)
            last_y_var = torch.softmax(yy, dim=1)

            Lce = self.loss_ce_with_lc_targets_(logits,
                                                last_y_var,
                                                meter=meter,
                                                name='Lce')
            Lco = self.loss_ce_(last_y_var, nys, meter=meter, name='Lco')

        # le is entropy loss
        Lent = self.loss_ent_(logits, meter=meter, name='Lent')

        if eidx < params.stage1:
            meter.Lall = Lce
        elif eidx < params.stage2:
            meter.Lall = Lce + params.alpha * Lco + params.beta * Lent
        else:
            meter.Lall = Lce

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')

        if eidx >= params.stage1 and eidx < params.stage2:
            self.target_mem[ids] = self.target_mem[ids] - params.lambda1 * yy.grad.data
            self.acc_precise_(self.target_mem[ids].argmax(dim=1),
                              ys,
                              meter,
                              name='check_acc')

        return meter
Example #22
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(xs)
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        preds_ = logits.detach().clone()  # torch.softmax(logits, dim=-1)
        # old_preds = self.pred_mem[ids]
        # residual = preds - old_preds

        _mask = torch.arange(params.n_classes).unsqueeze(0).repeat(
            [nys.shape[0], 1]).to(device)
        mask = _mask[_mask == nys.unsqueeze(1)].view(nys.shape[0], -1)

        # residual = residual.gather(1, mask)
        # old_preds = torch.softmax(old_preds.gather(1, mask), dim=-1)
        # preds = torch.softmax(preds_.gather(1, mask), dim=-1)
        # residual = preds - old_preds

        # residual = torch.pow(residual, 2)

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        self.pred_mem[ids] += preds_.detach()

        # self.sum_mem[ids] += residual.detach()
        # if eidx == 1:
        # else:
        #     self.pred_mem[ids] = self.pred_mem[ids] * 0.9 + logits.detach() * 0.1

        # meter.nres = residual[n_mask].mean()
        # meter.nsum = self.sum_mem[ids][n_mask].mean()
        # meter.tres = residual[n_mask.logical_not()].mean()
        # meter.tsum = self.sum_mem[ids][n_mask.logical_not()].mean()
        self.acc_precise_(self.pred_mem[ids].argmax(dim=1),
                          ys,
                          meter,
                          name='sacc')

        return meter
Example #23
    def meter(self, ratio=True):
        from thexp import Meter

        meter = Meter()
        for key, offset in self.times.items():
            if ratio:
                if isinstance(offset, str):
                    continue
                meter[key] = offset / self.times['use']
            else:
                meter[key] = offset

        return meter
Example #24
 def train_batch(self, eidx, idx, global_step, batch_data, params: CoRandomParams, device: torch.device):
     meter = Meter()
     if eidx < params.warm_up:
         if eidx % 2 == 0:
             self.warmup_model(batch_data, self.model, self.optim, self.probs1, meter)
         else:
             self.warmup_model(batch_data, self.model2, self.optim2, self.probs2, meter)
     else:
         if eidx % 2 == 0:
             self.train_model(batch_data, self.model, self.optim, self.pys2, self.probs1, meter)
         else:
             self.train_model(batch_data, self.model2, self.optim2, self.pys1, self.probs2, meter)
     return meter
Example #25
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: OffsetParams, device: torch.device):
        meter = Meter()
        if params.epoch - eidx > params.offset_epoch:
            self.train_first(eidx, self.model, self.optim, batch_data, meter)
        if eidx - params.offset_epoch > 0:
            if params.epoch - eidx > params.offset_epoch:
                self.train_second(eidx, self.model2, batch_data, meter)
            else:
                self.train_first(eidx, self.model2, self.optim2, batch_data,
                                 meter)

        return meter
Example #26
    def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(xs)
        if params.smooth:
            if params.smooth_ratio != -1:
                ns_targets = tricks.onehot(nys, params.n_classes)

                if params.smooth_argmax:
                    targets = tricks.onehot(logits.argmax(dim=-1), params.n_classes)
                else:
                    rys = ys.clone()
                    rids = torch.randperm(len(rys))
                    if params.smooth_ratio > 0:
                        rids = rids[:int((len(rids) * params.smooth_ratio))]
                        rys[rids] = torch.randint(0, params.n_classes, [len(rids)], device=device)
                    targets = tricks.onehot(rys, params.n_classes)

                if params.smooth_mixup:
                    l = np.random.beta(0.75, 0.75, size=targets.shape[0])
                    l = np.max([l, 1 - l], axis=0)
                    l = torch.tensor(l, device=device, dtype=torch.float).unsqueeze(1)
                else:
                    l = 0.9

                ns_targets = ns_targets * l + targets * (1 - l)
            else:
                ns_targets = tricks.label_smoothing(tricks.onehot(nys, params.n_classes))
            meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, ns_targets, meter=meter)
        else:
            meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        with torch.no_grad():
            nlogits = self.to_logits(xs)
            res = torch.softmax(nlogits, dim=-1) - torch.softmax(logits, dim=-1)
            meter.nres = res.gather(1, nys.unsqueeze(1)).mean() * 10
            meter.tres = res.gather(1, ys.unsqueeze(1)).mean() * 100
            meter.rres = res.gather(1, torch.randint_like(ys, 0, params.n_classes).unsqueeze(1)).mean() * 100
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

        return meter
Example #27
    def train_batch(self, eidx, idx, global_step, batch_data, params: GmaParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, nys = batch_data  # type:torch.Tensor
        logits = self.to_logits(axs)
        w_logits = self.to_logits(xs)

        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()
        self.acc_precise_(w_logits.argmax(dim=1), nys, meter, name='acc')

        return meter
Example #28
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: FixMatchParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()

        sup, unsup = batch_data
        xs, ys = sup
        _, un_xs, un_axs, un_ys = unsup

        logits_list = self.to_logits(torch.cat([
            xs, un_xs, un_axs
        ])).split_with_sizes([xs.shape[0], un_xs.shape[0], un_xs.shape[0]])
        logits, un_w_logits, un_s_logits = logits_list  # type:torch.Tensor

        pseudo_targets = torch.softmax(un_w_logits, dim=-1)
        max_probs, un_pseudo_labels = torch.max(pseudo_targets, dim=-1)
        mask = (max_probs > params.pred_thresh)

        if mask.any():
            self.acc_precise_(un_w_logits.argmax(dim=1)[mask],
                              un_ys[mask],
                              meter,
                              name='umacc')

        meter.Lall = meter.Lall + self.loss_ce_(
            logits, ys, meter=meter, name='Lx')

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(
            un_s_logits,
            un_pseudo_labels,
            mask.float(),
            meter=meter,
            name='Lu') * params.lambda_u

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        meter.masked = mask.float().mean()
        self.acc_precise_(logits.argmax(dim=1), ys, meter)
        self.acc_precise_(un_w_logits.argmax(dim=1),
                          un_ys,
                          meter,
                          name='uwacc')
        self.acc_precise_(un_s_logits.argmax(dim=1),
                          un_ys,
                          meter,
                          name='usacc')

        return meter
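A toy illustration of the FixMatch-style masking above: the weak-augmentation probabilities become hard pseudo-labels only where the maximum probability clears `params.pred_thresh` (0.95 below is a made-up value); everything else is masked out of the unsupervised loss:

import torch

pseudo_targets = torch.tensor([[0.97, 0.02, 0.01],
                               [0.40, 0.35, 0.25]])
max_probs, un_pseudo_labels = torch.max(pseudo_targets, dim=-1)
mask = (max_probs > 0.95).float()
print(un_pseudo_labels, mask)  # tensor([0, 0]) tensor([1., 0.])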
Example #29
    def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(axs)

        # assumed missing piece: a plain cross-entropy on the noisy labels, so that
        # meter.Lall carries a differentiable loss (the original snippet calls
        # backward() without assigning any loss)
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

        return meter
Example #30
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        if eidx == 1:
            self.smooth_cls_mem[ids] = self.smooth_cls_mem[ids].scatter(
                1, nys.unsqueeze(1), 0.8) + 0.1
            # self.smooth_cls_mem[ids] += 0.1

        logits = self.to_logits(xs)
        ns_targets = self.smooth_cls_mem[ids]
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            logits, ns_targets, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        with torch.no_grad():
            nlogits = self.to_logits(xs)
            res = (torch.softmax(nlogits, dim=-1) -
                   torch.softmax(logits, dim=-1)) / ns_targets
            # res = res.scatter(1, res.argsort(dim=-1, descending=True)[:, 2:], 0)

            ns_targets = torch.clamp(ns_targets + res * 0.1, 0, 1)
            self.smooth_cls_mem[ids] = ns_targets

            # ns_max, _ = ns_targets.max(dim=-1, keepdim=True)
            # ns_min, _ = ns_targets.min(dim=-1, keepdim=True)
            # ns_targets = (ns_targets - ns_min)
            # ns_targets = ns_targets / ns_targets.sum(dim=1, keepdims=True)
            targets = tricks.onehot(ys, params.n_classes)

            meter.cm = ((targets * res) > 0).any(dim=-1).float().mean()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        self.acc_precise_(ns_targets.argmax(dim=1), ys, meter, name='cacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        return meter