def train_model(self, batch_data, model, head, optim, probs, pys, meter: Meter):
        (ids, xs, axs, ys, nys) = batch_data  # type:torch.Tensor
        optim.zero_grad()
        pys = pys[ids]
        features = model(xs)  # type:torch.Tensor

        n_mask = (ys != nys)

        outs = head(features)[:params.head]

        meter.Lall = meter.Lall + self.loss_ce_(outs[0], nys, meter=meter, name='Lce')

        n_targets = tricks.onehot(nys, params.n_classes)
        p_targets = tricks.onehot(pys, params.n_classes)

        # mixup in both directions; note the mixed batches are not used below
        mixed_xs1, mixed_targets1 = self.mixup_(xs, n_targets, target_b=p_targets)
        mixed_xs2, mixed_targets2 = self.mixup_(xs, p_targets, target_b=n_targets)
        for i in range(1, params.head):
            meter.Lall = meter.Lall + self.loss_ce_(outs[i], pys[:, i - 1], meter=meter, name='Lfce{}'.format(i))

        meter.Lall.backward()
        optim.step()

        with torch.no_grad():
            # EMA
            probs[ids] = probs[ids] * 0.3 + torch.softmax(outs[0], dim=-1) * 0.7
        self.acc_precise_(outs[0].argmax(dim=-1), ys, meter=meter, name='acc')
        self.acc_precise_(outs[0][n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')
        self.acc_precise_(outs[1].argmax(dim=-1), ys, meter=meter, name='pacc')
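# tricks.onehot(labels, n_classes) is used throughout this listing but never shown;
# a minimal sketch consistent with how it is called:
import torch

def onehot(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Convert integer labels of shape [B] into one-hot targets of shape [B, n_classes]."""
    targets = torch.zeros(labels.shape[0], n_classes, device=labels.device)
    return targets.scatter_(1, labels.unsqueeze(1), 1.0)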
Example #2
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        # Note: ys holds the ground-truth training labels; it is only used to compute accuracy, never in the training loss
        (ids, xs, axs, ys, nys), (vxs, vys) = batch_data

        w_logits = self.model(xs)
        logits = self.model(axs)  # type:torch.Tensor

        p_targets = self.sharpen_(torch.softmax(w_logits, dim=1))

        weight = self.meta_optimizer(xs, nys, vxs, vys, meter=meter)

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(
            logits, nys, weight, meter=meter, name='Lwce')
        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits, p_targets, (1 - weight), meter=meter, name='Lwtce')

        meter.tw = weight[ys == nys].mean()
        meter.fw = weight[ys != nys].mean()

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1),
                          ys,
                          meter=meter,
                          name='true_acc')
        self.acc_precise_(logits.argmax(dim=1),
                          nys,
                          meter=meter,
                          name='noisy_acc')

        return meter
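# self.sharpen_ is not shown; given the MixMatch-style usage above, a plausible sketch is
# standard temperature sharpening (the default T here is an assumption):
import torch

def sharpen(probs: torch.Tensor, T: float = 0.5) -> torch.Tensor:
    """Sharpen a [B, C] probability distribution by raising it to 1/T and renormalizing."""
    sharpened = probs.pow(1.0 / T)
    return sharpened / sharpened.sum(dim=1, keepdim=True)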
Example #3
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: SupervisedParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()
        ids, xs, axs, ys = batch_data  # type:torch.Tensor
        logits = self.to_logits(xs)

        if eidx < 10:
            # mask = (ids % 1000) < eidx
            targets = torch.ones(xs.shape[0], params.n_classes,
                                 device=device) / params.n_classes
            # targets[mask] = tricks.onehot(ys[mask], params.n_classes)

            meter.Lall = meter.Lall + self.loss_ce_with_targets_(
                logits, targets, meter=meter)
        else:
            meter.Lall = meter.Lall + self.loss_ce_(logits, ys, meter=meter)

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        return meter
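# self.acc_precise_ only records a named top-1 accuracy into the meter; a minimal sketch
# consistent with every call site in this listing:
import torch

def acc_precise(preds: torch.Tensor, labels: torch.Tensor, meter, name: str = 'acc'):
    """Store the top-1 accuracy of integer predictions against integer labels under `name`."""
    setattr(meter, name, (preds == labels).float().mean())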
Example #4
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: PencilParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(xs)
        if eidx < params.stage1:
            # Lce is the classification loss
            Lce = self.loss_ce_(logits, nys, meter=meter, name='Lce')

            # initialize y_tilde so that softmax(y_tilde) matches the noisy labels
            noisy_targets = tricks.onehot(nys, params.n_classes)
            self.target_mem[ids] = noisy_targets
        else:
            yy = self.target_mem[ids]
            yy = yy.detach().requires_grad_(True)  # torch.autograd.Variable is deprecated

            # obtain label distributions (y_hat)
            last_y_var = torch.softmax(yy, dim=1)

            Lce = self.loss_ce_with_lc_targets_(logits,
                                                last_y_var,
                                                meter=meter,
                                                name='Lce')
            Lco = self.loss_ce_(last_y_var, nys, meter=meter, name='Lco')

        # Lent is the entropy loss
        Lent = self.loss_ent_(logits, meter=meter, name='Lent')

        if eidx < params.stage1:
            meter.Lall = Lce
        elif eidx < params.stage2:
            meter.Lall = Lce + params.alpha * Lco + params.beta * Lent
        else:
            meter.Lall = Lce

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')

        if eidx >= params.stage1 and eidx < params.stage2:
            self.target_mem[
                ids] = self.target_mem[ids] - params.lambda1 * yy.grad.data
            self.acc_precise_(self.target_mem[ids].argmax(dim=1),
                              ys,
                              meter,
                              name='check_acc')

        return meter
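# loss_ent_ above penalizes flat predictions (PENCIL's entropy term). A sketch of such an
# entropy loss on raw logits, assuming the usual mean-entropy formulation:
import torch
import torch.nn.functional as F

def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    """Mean entropy of the softmax predictions; low when predictions are confident."""
    log_probs = F.log_softmax(logits, dim=1)
    return -(log_probs.exp() * log_probs).sum(dim=1).mean()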
Example #5
    def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(xs)
        if params.smooth:
            if params.smooth_ratio != -1:
                ns_targets = tricks.onehot(nys, params.n_classes)

                if params.smooth_argmax:
                    targets = tricks.onehot(logits.argmax(dim=-1), params.n_classes)
                else:
                    rys = ys.clone()
                    rids = torch.randperm(len(rys))
                    if params.smooth_ratio > 0:
                        rids = rids[:int((len(rids) * params.smooth_ratio))]
                        rys[rids] = torch.randint(0, params.n_classes, [len(rids)], device=device)
                    targets = tricks.onehot(rys, params.n_classes)

                if params.smooth_mixup:
                    l = np.random.beta(0.75, 0.75, size=targets.shape[0])
                    l = np.max([l, 1 - l], axis=0)
                    l = torch.tensor(l, device=device, dtype=torch.float).unsqueeze(1)
                else:
                    l = 0.9

                ns_targets = ns_targets * l + targets * (1 - l)
            else:
                ns_targets = tricks.label_smoothing(tricks.onehot(nys, params.n_classes))
            meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, ns_targets, meter=meter)
        else:
            meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        with torch.no_grad():
            nlogits = self.to_logits(xs)
            res = torch.softmax(nlogits, dim=-1) - torch.softmax(logits, dim=-1)
            meter.nres = res.gather(1, nys.unsqueeze(1)).mean() * 10
            meter.tres = res.gather(1, ys.unsqueeze(1)).mean() * 100
            meter.rres = res.gather(1, torch.randint_like(ys, 0, params.n_classes).unsqueeze(1)).mean() * 100
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

        return meter
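# tricks.label_smoothing is not shown; assuming standard uniform label smoothing over
# one-hot targets (the epsilon default is an assumption):
import torch

def label_smoothing(targets: torch.Tensor, epsilon: float = 0.1) -> torch.Tensor:
    """Blend [B, C] one-hot targets with the uniform distribution."""
    n_classes = targets.shape[1]
    return targets * (1.0 - epsilon) + epsilon / n_classes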
Example #6
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: FixMatchParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()

        sup, unsup = batch_data
        xs, ys = sup
        _, un_xs, un_axs, un_ys = unsup

        logits_list = self.to_logits(torch.cat([
            xs, un_xs, un_axs
        ])).split_with_sizes([xs.shape[0], un_xs.shape[0], un_axs.shape[0]])
        logits, un_w_logits, un_s_logits = logits_list  # type:torch.Tensor

        pseudo_targets = torch.softmax(un_w_logits, dim=-1)
        max_probs, un_pseudo_labels = torch.max(pseudo_targets, dim=-1)
        mask = (max_probs > params.pred_thresh)

        if mask.any():
            self.acc_precise_(un_w_logits.argmax(dim=1)[mask],
                              un_ys[mask],
                              meter,
                              name='umacc')

        meter.Lall = meter.Lall + self.loss_ce_(
            logits, ys, meter=meter, name='Lx')

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(
            un_s_logits,
            un_pseudo_labels,
            mask.float(),
            meter=meter,
            name='Lu') * params.lambda_u

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        meter.masked = mask.float().mean()
        self.acc_precise_(logits.argmax(dim=1), ys, meter)
        self.acc_precise_(un_w_logits.argmax(dim=1),
                          un_ys,
                          meter,
                          name='uwacc')
        self.acc_precise_(un_s_logits.argmax(dim=1),
                          un_ys,
                          meter,
                          name='usacc')

        return meter
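# The unlabeled FixMatch loss goes through loss_ce_with_masked_: per-sample cross-entropy
# scaled by the confidence mask. A sketch of that pattern:
import torch
import torch.nn.functional as F

def ce_with_mask(logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Cross-entropy per sample, weighted by a [B] float mask (0 drops the sample)."""
    losses = F.cross_entropy(logits, labels, reduction='none')
    return (losses * mask).mean()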
Example #7
    def warmup_model(self, batch_data, model, optim, meter: Meter):
        (ids, xs, axs, ys, nys) = batch_data  # type:torch.Tensor
        optim.zero_grad()

        logits = model(xs)  # type:torch.Tensor
        meter.Lall = meter.Lall + self.loss_ce_(
            logits, nys, meter=meter, name='Lce')
        if params.noisy_type == 'asymmetric':  # penalize confident prediction for asymmetric noise
            meter.Lall = meter.Lall + self.loss_minent_(
                logits, meter=meter, name='Lpen')

        self.acc_precise_(logits.argmax(dim=-1), ys, meter=meter, name='acc')
        meter.Lall.backward()
        optim.step()
Example #8
    def train_batch(self, eidx, idx, global_step, batch_data, params: MetaSSLParams, device: torch.device):
        meter = Meter()

        sup, unsup = batch_data
        xs, ys = sup
        ids, un_xs, un_axs, un_ys = unsup

        mid_logits = self.to_mid(un_xs)

        self.meta_optimizer(un_xs, xs, ys, meter=meter)

        left, right = tricks.cartesian_product(mid_logits, self.cls_center)

        dist_ = F.pairwise_distance(left, right).reshape(mid_logits.shape[0], -1)
        dist_targets = torch.softmax(dist_, dim=-1)
        logits = self.to_logits(un_axs)

        meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, dist_targets, meter=meter, name='Ldce')  # Lw*

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), un_ys, meter, name='un_acc')

        return meter
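# tricks.cartesian_product pairs every mid-level feature with every class center so that
# F.pairwise_distance followed by the reshape yields a [B, n_classes] distance matrix.
# A sketch consistent with that usage:
import torch

def cartesian_product(a: torch.Tensor, b: torch.Tensor):
    """Pair every row of a [N, D] with every row of b [M, D]; returns two [N*M, D] tensors."""
    left = a.repeat_interleave(b.shape[0], dim=0)  # a0, a0, ..., a1, a1, ...
    right = b.repeat(a.shape[0], 1)                # b0, b1, ..., b0, b1, ...
    return left, right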
Example #9
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: IEGParams, device: torch.device):
        meter = Meter()
        train_data, (
            vxs, vys
        ) = batch_data  # type:List[torch.Tensor],(torch.Tensor,torch.Tensor)

        ys, nys = train_data[-2:]
        xs = train_data[1]
        axs = torch.cat(train_data[2:2 + params.K])

        logits = self.to_logits(xs)
        aug_logits = self.to_logits(axs)
        n_targets = tricks.onehot(nys, params.n_classes)

        guess_targets = self.unsupervised_loss(xs,
                                               axs,
                                               vxs,
                                               vys,
                                               logits,
                                               aug_logits,
                                               meter=meter)

        weight, eps_k = self.meta_optimizer(xs,
                                            guess_targets,
                                            n_targets,
                                            vxs,
                                            vys,
                                            meter=meter)

        meter.tw = eps_k[ys == nys].mean()
        meter.fw = eps_k[ys != nys].mean()

        mixed_targets = eps_k * n_targets + (1 - eps_k) * guess_targets

        # on the surface this is cross-entropy against soft targets, but in practice the targets are all one-hot

        # loss with initial eps
        init_eps = torch.ones([guess_targets.shape[0]],
                              dtype=torch.float,
                              device=self.device) * params.grad_eps_init
        init_mixed_labels = tricks.elementwise_mul(
            init_eps, n_targets) + tricks.elementwise_mul(
                1 - init_eps, guess_targets)
        # loss with initial weight

        meter.Lws = self.loss_ce_with_targets_(logits, mixed_targets)  # Lw*
        meter.Llamda = -torch.mean(
            torch.sum(F.log_softmax(logits, dim=1) * init_mixed_labels, dim=1)
            * weight)  # Lλ*
        meter.Lall = meter.Lall + (meter.Lws + meter.Llamda) / 2

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')

        return meter
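# tricks.elementwise_mul apparently scales each row of a [B, C] target matrix by the
# matching entry of a [B] weight vector; a one-line sketch under that assumption:
import torch

def elementwise_mul(weights: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Broadcast a [B] weight vector across a [B, C] matrix row by row."""
    return weights.unsqueeze(1) * targets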
Example #10
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: MentorNetParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        # Basic parameters for MentorNet
        _epoch_step = params.epoch_step(eidx)
        _zero_labels = torch.zeros_like(nys)
        _loss_p_percentile = torch.ones(
            100, dtype=torch.float) * params.loss_p_percentile
        _dropout_rates = self.parse_dropout_rate_list_()

        logits = self.to_logits(xs)
        _basic_losses = F.cross_entropy(logits, nys, reduction='none').detach()  # train on noisy labels; ys is eval-only

        weight = torch.rand_like(nys, dtype=torch.float).detach()  # TODO replace with MentorNet output

        meter.Lall = torch.mean(_basic_losses * weight)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')

        return meter
Example #11
    def ssl(self, eidx, batch_data, meter: Meter):
        """MixMatch-style semi-supervised step with per-sample filtering."""
        sup, unsup = batch_data
        _ids, xs, _ys, nys = sup
        _uids, uxs1, uxs2, _uys, unys, unprob = unsup

        targets = tricks.onehot(nys, params.n_classes)

        logits = self.to_logits(xs).detach()
        un_logits = self.to_logits(torch.cat([uxs1, uxs2]))
        un_targets = self.label_guesses_(*un_logits.chunk(2))
        un_targets = self.sharpen_(un_targets, params.T)

        mixed_input, mixed_target = self.mixmatch_up_(xs, [uxs1, uxs2], targets, un_targets)

        sup_mixed_target, unsup_mixed_target = mixed_target.detach().split_with_sizes(
            [xs.shape[0], mixed_input.shape[0] - xs.shape[0]])

        sup_mixed_logits, unsup_mixed_logits = self.to_logits(mixed_input).split_with_sizes(
            [xs.shape[0], mixed_input.shape[0] - xs.shape[0]])

        meter.Lall = meter.Lall + self.loss_ce_with_targets_(sup_mixed_logits, sup_mixed_target,
                                                             meter=meter, name='Lx')
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(unsup_mixed_logits, unsup_mixed_target,
                                                             meter=meter, name='Lu') * params.w_sche(eidx)
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        preds = torch.softmax(torch.cat([logits, un_logits.chunk(2)[0]]), dim=1).detach()
        label_pred = preds.gather(1, torch.cat([nys, unys]).unsqueeze(dim=1)).squeeze()
        ids = torch.cat([_ids, _uids])
        if params.local_filter:
            weight = label_pred - self.target_mem[ids]
            weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
            weight_mask = weight < 0

        fweight = torch.ones(ids.shape[0], dtype=torch.float, device=self.device)
        if eidx >= params.burnin:
            if params.local_filter:
                fweight[weight_mask] -= params.gmm_w_sche(eidx)
            fweight -= self.noisy_cls[ids]
            fweight = torch.relu(fweight)
        self.filter_mem[ids] = fweight

        return meter
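# label_guesses_ follows the MixMatch recipe here: average the softmax predictions of the
# augmented views of the same batch. A sketch under that assumption:
import torch

def label_guesses(*logits_list: torch.Tensor) -> torch.Tensor:
    """Average softmax predictions over several views; detached so no gradients flow back."""
    probs = [torch.softmax(logits, dim=1) for logits in logits_list]
    return torch.stack(probs).mean(dim=0).detach()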
Example #12
    def unsupervised_loss(self, xs: torch.Tensor, axs: torch.Tensor,
                          vxs: torch.Tensor, vys: torch.Tensor,
                          logits: torch.Tensor, aug_logits: torch.Tensor,
                          meter: Meter):
        '''create Lub, Lpb, Lkl'''
        re_ids = torch.randperm(vxs.shape[0])
        re_vxs = vxs[re_ids]
        re_vys = vys[re_ids]

        p_target = self.label_guesses_(self.logit_norm_(logits),
                                       self.logit_norm_(aug_logits))
        p_target = self.sharpen_(p_target, params.T)

        re_v_targets = tricks.onehot(re_vys, params.n_classes)
        mixed_input, mixed_target = self.mixmatch_up_(re_vxs, [xs, axs],
                                                      re_v_targets,
                                                      p_target,
                                                      beta=params.mix_beta)

        mixed_logits = self.to_logits(mixed_input)
        mixed_logits_lis = mixed_logits.split_with_sizes(
            [re_vxs.shape[0], xs.shape[0], axs.shape[0]])
        (mixed_v_logits, mixed_n_logits,
         mixed_an_logits) = [self.logit_norm_(l)
                             for l in mixed_logits_lis]  # type:torch.Tensor

        mixed_nn_logits = torch.cat([mixed_n_logits, mixed_an_logits], dim=0)
        mixed_v_targets, mixed_nn_targets = mixed_target.split_with_sizes(
            [mixed_v_logits.shape[0], mixed_nn_logits.shape[0]])

        # Lpβ: the validation set plays the role of the labeled data in the semi-supervised setup
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_v_logits, mixed_v_targets, meter=meter, name='Lpb')
        # p * Luβ: the training set plays the role of the unlabeled data in the semi-supervised setup
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_nn_logits, mixed_nn_targets, meter=meter, name='Lub')

        # Lkl: consistency loss across the multiple augmentations
        meter.Lall = meter.Lall + self.loss_kl_ieg_(
            logits,
            aug_logits,
            n_classes=params.n_classes,
            consistency_factor=params.consistency_factor,
            meter=meter)
        return p_target
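# logit_norm_ is not shown anywhere in this listing; given the name, an L2 normalization
# of each logit row is one plausible reading (purely an assumption):
import torch

def logit_norm(logits: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """L2-normalize each row of the logits to unit norm."""
    return logits / (logits.norm(dim=1, keepdim=True) + eps)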
Example #13
    def warmup_model(self, batch_data, model, optim, meter: Meter):
        (ids, xs, nys) = batch_data  # type:torch.Tensor
        optim.zero_grad()

        logits = model(xs)  # type:torch.Tensor
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter, name='Lce')

        self.acc_precise_(logits.argmax(dim=-1), nys, meter=meter, name='acc')
        meter.Lall.backward()
        optim.step()
Example #14
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        logits = self.to_logits(xs)
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

        self.acc_precise_(logits.argmax(dim=1), nys, meter, name='acc')
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')

        # _err = (ys != nys)
        # ids, xs, axs, ys, nys = ids[_err], xs[_err], axs[_err], ys[_err], nys[_err]

        # compare gradients between a true-label group and a noisy-label group;
        # classes 0 and 1 are hard-coded and the loop breaks after a single pass
        for i in range(params.n_classes):
            _cls_mask = (ys == 0)
            _fcls_mask = (nys == 1)
            _tloss = F.cross_entropy(logits[_cls_mask],
                                     ys[_cls_mask],
                                     reduction='none')
            _floss = F.cross_entropy(logits[_fcls_mask],
                                     nys[_fcls_mask],
                                     reduction='none')

            tcls_grads = [
                g for g in autograd.grad(_tloss,
                                         self.model.parameters(),
                                         grad_outputs=torch.ones_like(_tloss),
                                         retain_graph=True,
                                         allow_unused=True) if g is not None
            ]

            fcls_grads = [
                g for g in autograd.grad(_floss,
                                         self.model.parameters(),
                                         grad_outputs=torch.ones_like(_floss),
                                         retain_graph=True,
                                         allow_unused=True) if g is not None
            ]
            res = 0
            for tgrad, fgrad in zip(tcls_grads, fcls_grads):
                # res += (tgrad - fgrad).abs().pow(2).sum()
                # res = max(tricks.cos_similarity(tgrad, fgrad).mean(), res)
                res += tricks.cos_similarity(tgrad, fgrad).mean()

            meter.res = res
            break

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        return meter
Example #15
    def unsupervised_loss(self, xs: torch.Tensor, axs: torch.Tensor,
                          vxs: torch.Tensor, vys: torch.Tensor,
                          logits_lis: List[torch.Tensor], meter: Meter):
        '''create Lub, Lpb, Lkl'''

        logits_lis = [self.logit_norm_(logits) for logits in logits_lis]

        p_target = self.label_guesses_(*logits_lis)
        p_target = self.sharpen_(p_target, params.T)

        re_v_targets = tricks.onehot(vys, params.n_classes)
        mixed_input, mixed_target = self.mixmatch_up_(vxs, [axs],
                                                      re_v_targets,
                                                      p_target,
                                                      beta=params.mix_beta)

        mixed_logits = self.to_logits(mixed_input)
        mixed_logits_lis = mixed_logits.split_with_sizes(
            [vxs.shape[0], axs.shape[0]])
        (mixed_v_logits,
         mixed_nn_logits) = [self.logit_norm_(l)
                             for l in mixed_logits_lis]  # type:torch.Tensor

        # mixed_nn_logits = torch.cat([mixed_n_logits, mixed_an_logits], dim=0)
        mixed_v_targets, mixed_nn_targets = mixed_target.split_with_sizes(
            [mixed_v_logits.shape[0], mixed_nn_logits.shape[0]])

        # Lpβ: the validation set plays the role of the labeled data in the semi-supervised setup
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_v_logits, mixed_v_targets, meter=meter,
            name='Lpb') * params.semi_sche(params.eidx)
        # p * Luβ: the training set plays the role of the unlabeled data in the semi-supervised setup
        if params.lub:
            meter.Lall = meter.Lall + self.loss_ce_with_targets_(
                mixed_nn_logits, mixed_nn_targets, meter=meter,
                name='Lub') * params.semi_sche(params.eidx)

        # Lkl (consistency loss across multiple augmentations) is not applied in this variant

        return p_target
Example #16
    def train_batch(self, eidx, idx, global_step, batch_data, params: GmaParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, nys = batch_data  # type:torch.Tensor
        logits = self.to_logits(axs)
        w_logits = self.to_logits(xs)


        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()
        self.acc_precise_(w_logits.argmax(dim=1), nys, meter, name='acc')

        return meter
Example #17
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: MetaSSLParams, device: torch.device):
        meter = Meter()

        sup, unsup = batch_data
        xs, ys = sup
        ids, un_xs, un_axs, un_ys = unsup

        if eidx < 2:
            logits = self.to_logits(xs)
            meter.Lall = meter.Lall + self.loss_ce_(
                logits, ys, meter=meter, name='Lce')
            self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')
        else:
            un_logits = self.to_logits(un_xs)
            weight = self.meta_optimizer(un_xs, xs, ys, meter=meter)

            # meter.Lall = meter.Lall + self.loss_ce_(self.to_logits(xs), ys, meter=meter, name='Lce')
            meter.Lall = meter.Lall + self.loss_ce_with_masked_(
                un_logits,
                un_logits.argmax(dim=1),
                weight,
                meter=meter,
                name='Ldce')  # Lw*
            self.acc_precise_(un_logits.argmax(dim=1),
                              un_ys,
                              meter,
                              name='un_acc')

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        # self.acc_precise_(dist_targets.argmax(dim=1), un_ys, meter=meter, name='dist_acc')

        return meter
Example #18
    def train_batch(self, eidx, idx, global_step, batch_data, params: ICTParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()

        sup, unsup = batch_data
        xs, ys = sup
        _, un_xs, _, un_ys = unsup

        logits_list = self.to_logits(torch.cat([xs, un_xs])).split_with_sizes(
            [xs.shape[0], un_xs.shape[0]])
        logits, un_logits = logits_list  # type:torch.Tensor

        mixed_xs, ys_a, ys_b, lam_sup = self.ict_mixup_(xs, ys)
        mixed_logits = self.to_logits(mixed_xs)  # logits for xs were already computed above

        ema_un_logits = self.predict(un_xs)

        mixed_un_xs, ema_mixed_un_logits, _ = self.mixup_unsup_(un_xs, ema_un_logits, mix=True)

        mixed_un_logits = self.to_logits(mixed_un_xs)

        meter.Lall = meter.Lall + self.loss_mixup_sup_ce_(mixed_logits, ys_a, ys_b, lam_sup, meter=meter)
        meter.Lall = meter.Lall + self.loss_mixup_unsup_mse_(mixed_un_logits, ema_un_logits,
                                                             decay=params.mixup_consistency_sche(eidx),
                                                             meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter)
        self.acc_precise_(un_logits.argmax(dim=1), un_ys, meter, name='unacc')

        return meter
Example #19
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        if eidx == 1:
            self.smooth_cls_mem[ids] = self.smooth_cls_mem[ids].scatter(
                1, nys.unsqueeze(1), 0.8) + 0.1
            # self.smooth_cls_mem[ids] += 0.1

        logits = self.to_logits(xs)
        ns_targets = self.smooth_cls_mem[ids]
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            logits, ns_targets, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        with torch.no_grad():
            nlogits = self.to_logits(xs)
            res = (torch.softmax(nlogits, dim=-1) -
                   torch.softmax(logits, dim=-1)) / ns_targets
            # res = res.scatter(1, res.argsort(dim=-1, descending=True)[:, 2:], 0)

            ns_targets = torch.clamp(ns_targets + res * 0.1, 0, 1)
            self.smooth_cls_mem[ids] = ns_targets

            # ns_max, _ = ns_targets.max(dim=-1, keepdim=True)
            # ns_min, _ = ns_targets.min(dim=-1, keepdim=True)
            # ns_targets = (ns_targets - ns_min)
            # ns_targets = ns_targets / ns_targets.sum(dim=1, keepdims=True)
            targets = tricks.onehot(ys, params.n_classes)

            meter.cm = ((targets * res) > 0).any(dim=-1).float().mean()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        self.acc_precise_(ns_targets.argmax(dim=1), ys, meter, name='cacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        return meter
Example #20
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: GlobalParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()
        xs, ys = batch_data

        logits = self.to_logits(xs)

        meter.Lall = meter.Lall + self.loss_ce_(
            logits, ys, meter=meter, name='Lce')

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')

        return meter
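# Every snippet builds a Meter(), accumulates into meter.Lall starting from nothing, and
# some guard on `'Lall' in meter`. A minimal Meter sketch with exactly those behaviors:
class Meter:
    """Attribute bag for logging; unset attributes read as 0 so `m.Lall = m.Lall + loss` works."""

    def __getattr__(self, name):
        # only reached when the attribute has not been set yet
        if name.startswith('__'):
            raise AttributeError(name)
        return 0

    def __contains__(self, name):
        return name in self.__dict__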
Example #21
    def train_batch(self, eidx, idx, global_step, batch_data, params: SimCLRParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()
        _, xs, axs, _ = batch_data  # type:torch.Tensor

        b, c, h, w = xs.size()
        input_ = torch.cat([xs.unsqueeze(1), axs.unsqueeze(1)], dim=1)
        input_ = input_.view(-1, c, h, w)

        output = self.to_logits(input_).view(b, 2, -1)
        meter.Lall = meter.Lall + self.loss_sim_(output, params.temperature,
                                                 device=device, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        return meter
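# loss_sim_ receives the paired views as a [B, 2, D] tensor plus a temperature, which
# matches a SimCLR-style NT-Xent contrastive loss; a compact sketch under that assumption:
import torch
import torch.nn.functional as F

def nt_xent(output: torch.Tensor, temperature: float) -> torch.Tensor:
    """SimCLR contrastive loss over a [B, 2, D] tensor of paired embeddings."""
    b = output.shape[0]
    z = F.normalize(output.reshape(2 * b, -1), dim=1)
    sim = z @ z.t() / temperature              # [2B, 2B] cosine similarities
    sim.fill_diagonal_(float('-inf'))          # exclude self-similarity
    pos = torch.arange(2 * b, device=output.device) ^ 1  # partner of row i is row i^1
    return F.cross_entropy(sim, pos)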
Example #22
    def train_model(self, batch_data, model, optim, pys, probs, meter: Meter):
        (ids, xs, axs, ys, _) = batch_data  # type:torch.Tensor
        nys = pys[ids]

        optim.zero_grad()

        logits = model(xs)  # type:torch.Tensor
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter, name='Lce')
        meter.Lall.backward()
        optim.step()

        with torch.no_grad():
            # EMA
            logits = model(xs)
            probs[ids] = probs[ids] * 0.3 + torch.softmax(logits, dim=-1) * 0.7

        n_mask = (nys != ys)
        self.acc_precise_(logits.argmax(dim=-1), ys, meter=meter, name='acc')
        self.acc_precise_(logits[n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')
        self.acc_precise_(probs[ids].argmax(dim=-1), ys, meter=meter, name='pacc')
Example #23
    def warmup_model(self, batch_data, model, head, optim, probs, meter: Meter):
        """最初的测试,观察模型能够从 ys 和 nys 中学到正确的部份"""
        (ids, xs, axs, ys, nys) = batch_data  # type:torch.Tensor
        optim.zero_grad()

        features = model(xs)  # type:torch.Tensor

        n_mask = (ys != nys)

        outs = head(features)[:3]

        meter.Lall = meter.Lall + self.loss_ce_(outs[0], nys, meter=meter, name='Lce')
        meter.Lall.backward()
        optim.step()

        with torch.no_grad():
            # EMA
            probs[ids] = probs[ids] * 0.3 + torch.softmax(outs[0], dim=-1) * 0.7
        self.acc_precise_(outs[0].argmax(dim=-1), ys, meter=meter, name='acc')
        self.acc_precise_(outs[0][n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')
Example #24
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: SupervisedParams, device: torch.device):
        super().train_batch(eidx, idx, global_step, batch_data, params, device)
        meter = Meter()
        ids, xs, axs, ys = batch_data  # type:torch.Tensor

        targets = tricks.onehot(ys, params.n_classes)
        mixed_xs, mixed_targets = self.mixup_(xs, targets)

        mixed_logits = self.to_logits(mixed_xs)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_logits, mixed_targets, meter=meter)

        with torch.no_grad():
            self.acc_precise_(self.to_logits(xs).argmax(dim=1),
                              ys,
                              meter=meter,
                              name='acc')

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        return meter
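# self.mixup_ mixes a batch with a shuffled copy of itself (or with target_b, as in the
# first example). A sketch of the in-batch variant; the Beta(0.75, 0.75) prior and the
# max(lam, 1 - lam) trick are assumptions borrowed from the mixup usage visible above:
import numpy as np
import torch

def mixup(xs: torch.Tensor, targets: torch.Tensor, beta: float = 0.75):
    """Mix each sample and its target with a randomly chosen partner from the same batch."""
    lam = np.random.beta(beta, beta)
    lam = max(lam, 1 - lam)  # keep the original sample dominant
    index = torch.randperm(xs.shape[0], device=xs.device)
    mixed_xs = lam * xs + (1 - lam) * xs[index]
    mixed_targets = lam * targets + (1 - lam) * targets[index]
    return mixed_xs, mixed_targets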
Example #25
    def train_first(self, eidx, model, optim, batch_data, meter: Meter):
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        logits = model(axs)

        w_logits = model(xs)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
        if params.local_filter:
            weight = label_pred - self.target_mem[ids]
            weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
            weight_mask = weight < 0
            meter.tl = weight_mask[ys == nys].float().mean()
            meter.fl = weight_mask[ys != nys].float().mean()

        fweight = torch.ones(w_logits.shape[0],
                             dtype=torch.float,
                             device=self.device)
        if eidx >= params.mix_burnin:
            if params.local_filter:
                fweight[weight_mask] -= params.gmm_w_sche(eidx)
            fweight -= self.noisy_cls[ids]
            fweight = torch.relu(fweight)

        self.filter_mem[ids] = fweight

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = self.plabel_mem[
                ids] * params.targets_ema + raw_targets * (1 -
                                                           params.targets_ema)
            self.plabel_mem[ids] = targets

        top_values, top_indices = targets.topk(2, dim=-1)
        p_labels = top_indices[:, 0]
        values = top_values[:, 0]

        # ratio = params.smooth_ratio_sche(eidx)

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')

        n_targets = tricks.onehot(nys, params.n_classes)
        p_targets = tricks.onehot(p_labels, params.n_classes)

        mask = mask.float()
        meter.pm = mask.mean()

        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits, n_targets, fweight, meter=meter)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits,
            p_targets, (1 - fweight) * values * mask,
            meter=meter,
            name='Lpce') * params.plabel_sche(eidx)

        meter.tw = fweight[ys == nys].mean()
        meter.fw = fweight[ys != nys].mean()

        if params.local_filter:
            with torch.no_grad():
                ids_mask = weight_mask.logical_not()
                alpha = params.filter_ema

                self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + label_pred[ids_mask] * \
                                                 (1 - alpha)

                # gradually regress entries that did not take part back toward the uniform prior
                # self.target_mem[ids[weight_mask]] = self.target_mem[ids[weight_mask]] * alpha + (1 / params.n_classes) * \
                #                                  (1 - alpha)

        if 'Lall' in meter:
            optim.zero_grad()
            meter.Lall.backward()
            optim.step()
        self.acc_precise_(w_logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        false_pred = targets.gather(
            1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(
            1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids, eidx - 1] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits,
                                                           nys,
                                                           reduction='none')
            # mem_mask = ids < self.pred_mem_size - 1
            # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = p_labels
Example #26
    def train_batch(self, eidx, idx, global_step, batch_data, params: IEGParams, device: torch.device):
        meter = Meter()
        train_data, (vxs, vys) = batch_data  # type:List[torch.Tensor],(torch.Tensor,torch.Tensor)

        ids = train_data[0]
        axs = train_data[1]
        xs = torch.cat(train_data[2:2 + params.K])
        ys, nys = train_data[-2:]  # type:torch.Tensor

        w_logits = self.to_logits(xs)
        aug_logits = self.to_logits(axs)

        logits = w_logits.chunk(params.K)[0]
        # logits = aug_logits  # .chunk(params.K)[0]

        w_targets = torch.softmax(w_logits.chunk(params.K)[0], dim=1).detach()
        guess_targets = self.unsupervised_loss(xs, axs, vxs, vys,
                                               logits_lis=[*w_logits.detach().chunk(params.K)],
                                               meter=meter)
        # guess_targets = self.sharpen_(torch.softmax(logits, dim=1))

        label_pred = guess_targets.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # label_pred = .gather(1, nys.unsqueeze(dim=1)).squeeze()
        weight = label_pred - self.target_mem[ids]
        weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes

        fweight = torch.ones_like(weight)
        if eidx >= params.burnin:
            fweight -= self.noisy_cls[ids]
            fweight[weight <= 0] -= params.gmm_w_sche(eidx)
            fweight = torch.relu(fweight)

        raw_targets = w_targets  # guess_targets  # torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
            self.plabel_mem[ids] = targets
        values, p_labels = targets.max(dim=1)

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
        mask = mask.float()

        meter.pm = mask.mean()

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, nys, fweight,
                                                            meter=meter)

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, p_labels,
                                                            (1 - fweight) * mask,
                                                            meter=meter,
                                                            name='Lpce') * params.plabel_sche(eidx)
        if params.lkl:
            meter.Lall = meter.Lall + self.loss_kl_ieg_(logits, aug_logits,
                                                        n_classes=params.n_classes,
                                                        consistency_factor=params.consistency_factor,
                                                        meter=meter)

        meter.tw = fweight[ys == nys].mean()
        meter.fw = fweight[ys != nys].mean()

        with torch.no_grad():
            ids_mask = weight >= 0  # update only non-negative weights; weight.bool() selected any nonzero value
            alpha = params.filter_ema
            if eidx < params.burnin:
                alpha = 0.99
            self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + label_pred[ids_mask] * (1 - alpha)
            self.weight_mem[ids] = weight < 0
        if 'Lall' in meter:
            self.optim.zero_grad()
            meter.Lall.backward()
            self.optim.step()

        self.acc_precise_(w_targets.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(w_targets.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

        with torch.no_grad():
            false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
            true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
            self.true_pred_mem[ids] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred

        return meter
Example #27
    def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        _right_mask = (ys == nys)
        _error_mask = _right_mask.logical_not()

        logits = self.to_logits(axs)

        w_logits = self.to_logits(xs)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
        if params.local_filter:
            weight_mask = self.count_mem[ids] < 0
            # weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
            # weight_mask = weight < 0
            meter.tl = weight_mask[_right_mask].float().mean()
            meter.fl = weight_mask[_error_mask].float().mean()

        fweight = torch.ones(w_logits.shape[0], dtype=torch.float, device=device)
        if eidx >= params.burnin:
            fweight -= self.noisy_cls[ids]
            if params.local_filter:
                fweight[weight_mask] -= params.gmm_w_sche(eidx)
            fweight = torch.relu(fweight)

        if params.right_n > 0:
            right_mask = ids < params.right_n
            if right_mask.any():
                fweight[right_mask] = 1

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
            self.plabel_mem[ids] = targets
        values, p_labels = targets.max(dim=1)

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
        mask = mask.float()
        # uda_mask = label_pred > params.uda_sche(eidx)
        # meter.uda = uda_mask.float().mean()

        meter.pm = mask.mean()

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, nys, fweight, meter=meter)

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, p_labels,
                                                            (1 - fweight) * mask,
                                                            meter=meter,
                                                            name='Lpce') * params.plabel_sche(eidx)

        meter.tw = fweight[_right_mask].mean()
        meter.fw = fweight[_error_mask].mean()

        if params.local_filter:
            with torch.no_grad():
                ids_mask = weight_mask.logical_not()
                alpha = params.filter_ema
                # if eidx == 1:
                #     self.target_mem[ids] = label_pred
                # else:
                if eidx < params.burnin:
                    alpha = 0.99
                self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + label_pred[ids_mask] * \
                                                 (1 - alpha)

                # gradually regress entries that did not take part back toward the uniform prior
                # self.target_mem[ids[weight_mask]] = (self.target_mem[ids[weight_mask]] * 0.9 +
                #                                      (1 / params.n_classes) * 0.1)
                # self.count_mem[ids] += ((label_pred - self.target_mem[ids]) > 0).long() * 3 - 2
                self.count_mem[ids] += (label_pred - self.target_mem[ids])

        if 'Lall' in meter:
            self.optim.zero_grad()
            meter.Lall.backward()
            self.optim.step()
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

        false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids, eidx - 1] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
            # mem_mask = ids < self.pred_mem_size - 1
            # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = p_labels

        return meter
Example #28
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

        if params.mixup:
            id_dict = defaultdict(list)
            for i, (y, ny) in enumerate(zip(ys, nys)):
                y, ny = int(y), int(ny)
                if y != ny:
                    id_dict[int(ny)].append(i)  # make sure only noisy labels get added
                else:
                    pass
                    # id_dict[int(ny)].append(i)  # make sure only clean labels get added
            for i in range(params.n_classes):
                id_dict[i] = cycle(id_dict[i])

            if params.ideal_mixup:  # the ideal case
                re_id = []
                for i, (y, ny) in enumerate(zip(ys, nys)):
                    y, ny = int(y), int(ny)
                    try:
                        # if the original label is already clean, it does not matter what gets
                        # mixed in (the clean label dominates); otherwise, mix in a noisy label
                        re_id.append(next(id_dict.get(y)))
                    except:
                        re_id.append(np.random.randint(0, len(nys)))

            elif params.worst_mixup:
                re_id = []
                for i, (y, ny) in enumerate(zip(ys, nys)):
                    y, ny = int(y), int(ny)
                    try:
                        # mix in a label from a different class, i.e. the worst case for mixup;
                        # if even this worst case still beats no-mixup, the model really can learn
                        # from noise; otherwise mixup's benefit comes from the correct-label part
                        # of the label smoothing
                        neg_cls = np.random.choice(self.neg_dict[y])
                        re_id.append(next(id_dict.get(neg_cls)))
                    except:
                        re_id.append(np.random.randint(0, len(nys)))
            else:
                # re_id = torch.randperm(len(nys))  # fully random

                # mix in a negative class with a fixed probability (0.1 below)
                re_id = []
                rand = np.random.rand(len(nys))
                for i, (y, ny, r) in enumerate(zip(ys, nys, rand)):
                    y, ny = int(y), int(ny)
                    try:
                        if r < 0.1:
                            # mix in a label from a different class, i.e. the worst case for
                            # mixup (see the comment in the worst_mixup branch above)
                            neg_cls = np.random.choice(self.neg_dict[y])
                            re_id.append(next(id_dict.get(neg_cls)))
                        else:
                            re_id.append(next(id_dict.get(y)))
                    except:
                        re_id.append(np.random.randint(0, len(nys)))

            ntargets = tricks.onehot(nys, params.n_classes)
            mixed_input, mixed_target = self.mixup_(xs, ntargets, reids=re_id)

            mixed_logits = self.to_logits(mixed_input)
            meter.Lall = meter.Lall + self.loss_ce_with_targets_(
                mixed_logits, mixed_target, meter=meter)
            logits = self.predict(xs)
        else:
            logits = self.to_logits(xs)
            meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='noisy_acc')

        return meter
Example #29
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: GmaParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, nys = batch_data  # type:torch.Tensor
        n_targets = tricks.onehot(nys, params.n_classes)
        mixed_xs, mixed_ntargets = self.mixup_(xs, n_targets)

        mixed_logits = self.to_logits(mixed_xs)
        # logits = self.to_logits(xs)
        w_logits = self.to_logits(xs).detach()
        # w_logits = logits.detach()

        fweight = self.filter_mem[ids]
        # fweight -= self.noisy_cls[ids]
        fweight = torch.relu(fweight)

        raw_targets = torch.softmax(w_logits, dim=1)

        if eidx == 1:
            targets = raw_targets
        else:
            targets = self.plabel_mem[ids]

        values, p_labels = targets.max(dim=-1)
        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], nys[mask], meter, name='pacc')

        p_targets = tricks.onehot(p_labels, params.n_classes)

        mask = mask.float()
        meter.pm = mask.mean()
        meter.nm = self.noisy_cls[ids].mean()
        meter.fm = fweight.float().mean()

        p_targets[mask.logical_not()] = n_targets[mask.logical_not()]

        # mixed_input, mixed_target = self.mixup_(axs, n_targets, beta=0.75, target_b=p_targets)
        # mixed_logits = self.to_logits(mixed_input)

        # meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(mixed_logits, mixed_target,
        #                                                             fweight,
        #                                                             meter=meter)

        fmask = fweight > 0.5
        # if eidx > 5:
        #     meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(logits, p_targets,
        #                                                                 (fmask.logical_not().float()),
        #                                                                 meter=meter,
        #                                                                 name='Lpce') * params.plabel_sche(eidx)

        if fmask.any():
            # meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits[fmask], n_targets[fmask],
            # fweight[fweight > 0.5].float(),
            # meter=meter)
            meter.Lall = meter.Lall + self.loss_ce_with_targets_(
                mixed_logits[fmask],
                mixed_ntargets[fmask],
                # fweight[fweight > 0.5].float(),
                meter=meter)

            # if fmask.any():
            self.acc_precise_(w_logits.argmax(dim=-1)[fmask],
                              nys[fmask],
                              meter,
                              name='tacc')

        else:
            self.acc_precise_(w_logits.argmax(dim=-1)[fmask.logical_not()],
                              nys[fmask.logical_not()],
                              meter,
                              name='facc')

        # prior = torch.ones(params.n_classes, device=self.device) / params.n_classes
        # pred_mean = torch.softmax(logits, dim=1).mean(0)
        # meter.Lpen = torch.sum(prior * torch.log(prior / pred_mean))  # penalty
        # meter.Lall = meter.Lall + meter.Lpen

        self.optim.zero_grad()
        if 'Lall' in meter:
            meter.Lall.backward()
            self.optim.step()

        with torch.no_grad():
            preds = torch.softmax(w_logits, dim=1).detach()
            label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()

            weight = label_pred - self.target_mem[ids]
            weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
            weight_mask = weight < 0
            meter.wm = weight_mask.float().mean()

            ids_mask = weight_mask.logical_not()
            alpha = params.filter_ema_sche(eidx)
            if eidx == 1:
                alpha = 0.2
            self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + label_pred[ids_mask] * \
                                             (1 - alpha)

            self.acc_precise_(w_logits.argmax(dim=1), nys, meter, name='acc')

            fweight = torch.ones(nys.shape[0],
                                 dtype=torch.float,
                                 device=device)
            # fweight[weight_mask] -= params.gmm_w_sche(eidx)

            self.local_noisy_cls[ids] = fweight

            raw_targets = torch.softmax(w_logits, dim=1)

            if eidx == 1:
                targets = raw_targets
            else:
                targets = self.plabel_mem[
                    ids] * params.targets_ema + raw_targets * (
                        1 - params.targets_ema)

            self.plabel_mem[ids] = targets

            false_pred = raw_targets.gather(1, nys.unsqueeze(dim=1)).squeeze()
            self.false_pred_mem[ids, eidx - 1] = false_pred

        return meter
Example #30
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        logits = self.to_logits(axs)

        w_logits = self.to_logits(xs)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
        fweight = (nys == ys).float()
        fweight[nys != ys] = params.ideal_nyw_sche(eidx)

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = self.plabel_mem[
                ids] * params.targets_ema + raw_targets * (1 -
                                                           params.targets_ema)
            self.plabel_mem[ids] = targets

        top_values, top_indices = targets.topk(2, dim=-1)
        p_labels = top_indices[:, 0]
        values = top_values[:, 0]

        p_targets = (torch.zeros_like(targets)
                     .scatter(-1, top_indices[:, 0:1], 0.9)
                     .scatter(-1, top_indices[:, 1:2], 0.1))

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
        mask = mask.float()
        meter.pm = mask.mean()

        n_targets = tricks.onehot(nys, params.n_classes)
        n_targets = n_targets * 0.9 + p_targets * 0.1

        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits, n_targets, fweight, meter=meter)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits,
            p_targets, (1 - fweight) * values * mask,
            meter=meter,
            name='Lpce') * params.plabel_sche(eidx)

        meter.tw = fweight[ys == nys].mean()
        meter.fw = fweight[ys != nys].mean()

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        false_pred = targets.gather(
            1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(
            1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids, eidx - 1] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits,
                                                           nys,
                                                           reduction='none')
            # mem_mask = ids < self.pred_mem_size - 1
            # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = p_labels

        return meter