Example 1
    def on_train_epoch_end(self, trainer: 'IEGTrainer', func, params: NoisyParams, meter: Meter, *args, **kwargs):

        with torch.no_grad():
            from sklearn import metrics
            f_mean = self.false_pred_mem[:, max(params.eidx - params.gmm_burnin, 0):params.eidx].mean(
                dim=1).cpu().numpy()
            f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
            feature = np.stack([f_mean, f_cur], axis=1)

            model = tricks.group_fit(feature)
            noisy_cls = model.predict_proba(feature)[:, 0]  # type:np.ndarray

            if params.eidx == params.gmm_burnin:
                self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device)

            if params.eidx > params.gmm_burnin:
                self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                self.noisy_cls = self.noisy_cls_mem.clone()
                self.noisy_cls[self.noisy_cls < 0.5] = 0

                # Over time, samples that remain hard to distinguish should be dropped
                # outright rather than kept under fuzzy re-weighting (perhaps).
                # Boolean-mask indexing returns a copy, so an in-place clamp_min_ on it
                # would be a no-op; write the clamped values back explicitly.
                hard_mask = self.noisy_cls >= 0.5
                self.noisy_cls[hard_mask] = self.noisy_cls[hard_mask].clamp_min(params.gmm_w_sche(params.eidx))

                true_ncls = (self.true_pred_mem == self.false_pred_mem[:, params.eidx - 1]).all(dim=1).cpu().numpy()
                self.logger.info('gmm accuracy',
                                 metrics.confusion_matrix(true_ncls, self.noisy_cls.cpu().numpy() == 0,
                                                          labels=None, sample_weight=None))

                error_idx = set(np.where(true_ncls == 0)[0])
                self.logger.info('err set', len(set(np.where(noisy_cls != 0)[0]) & error_idx) / len(error_idx))
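
Note: `tricks.group_fit` is not defined in this snippet. A minimal sketch consistent with how it is called above (fit a two-component Gaussian mixture over the per-sample features, then read `predict_proba(feature)[:, 0]`) is given below; which component index corresponds to the noisy mode is an assumption, not something the snippet confirms.

import numpy as np
from sklearn.mixture import GaussianMixture

def group_fit(feature: np.ndarray) -> GaussianMixture:
    # Two modes: samples that agree with their (possibly noisy) label and
    # samples that do not. reg_covar keeps the fit stable on the 2-D features.
    model = GaussianMixture(n_components=2, max_iter=100, reg_covar=5e-4)
    model.fit(feature)
    return model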
Example 2
    def on_train_epoch_end(self, trainer: 'NoisyTrainer', func,
                           params: NoisyParams, meter: Meter, *args, **kwargs):
        true_f = os.path.join(self.experiment.test_dir, 'true.pth')
        false_f = os.path.join(self.experiment.test_dir, 'false.pth')
        loss_f = os.path.join(self.experiment.test_dir, 'loss.pth')
        cls_f = os.path.join(self.experiment.test_dir, 'cls.pth')
        if params.eidx % 10 == 0:
            torch.save(self.true_pred_mem, true_f)
            torch.save(self.false_pred_mem, false_f)
            torch.save(self.loss_mem, loss_f)
            torch.save(self.cls_mem, cls_f)

        with torch.no_grad():
            f_mean = self.false_pred_mem[:, max(params.eidx - params.gmm_burnin, 0):params.eidx].mean(
                dim=1).cpu().numpy()
            f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
            feature = self.create_feature(f_mean, f_cur)

            noisy_cls = self.gmm_predict(feature)

            true_cls = (self.true_pred_mem == self.false_pred_mem).all(
                dim=1).cpu().numpy()
            m = self.acc_mixture_(true_cls, noisy_cls)
            meter.update(m)
            self.logger.info(m)

            if params.eidx > params.gmm_burnin:
                self.noisy_cls_mem = torch.tensor(
                    noisy_cls,
                    device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                self.noisy_cls = self.noisy_cls_mem.clone()
                self.noisy_cls[self.noisy_cls < 0.5] = 0

                # Over time, samples that remain hard to distinguish should be dropped
                # outright rather than kept under fuzzy re-weighting (perhaps).
                # Boolean-mask indexing returns a copy, so write the clamped values back.
                hard_mask = self.noisy_cls >= 0.5
                self.noisy_cls[hard_mask] = self.noisy_cls[hard_mask].clamp_min(
                    params.gmm_w_sche(params.eidx))
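
The epoch-wise smoothing above is a plain exponential moving average over the mixture-model posteriors, with the fixed weights 0.1/0.9 taken directly from the code. As a standalone sketch:

import torch

def ema_update(mem: torch.Tensor, new: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    # mem <- alpha * new + (1 - alpha) * mem
    return new * alpha + mem * (1 - alpha)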
Example 3
        # meter.nres = residual[n_mask].mean()
        # meter.nsum = self.sum_mem[ids][n_mask].mean()
        # meter.tres = residual[n_mask.logical_not()].mean()
        # meter.tsum = self.sum_mem[ids][n_mask.logical_not()].mean()
        self.acc_precise_(self.pred_mem[ids].argmax(dim=1),
                          ys,
                          meter,
                          name='sacc')

        return meter

    def to_logits(self, xs) -> torch.Tensor:
        return self.model(xs)


if __name__ == '__main__':
    params = NoisyParams()
    params.optim.args.lr = 0.06
    params.epoch = 400
    params.device = 'cuda:0'
    params.filter_ema = 0.999

    params.mixup = True
    params.ideal_mixup = True
    params.worst_mixup = False
    params.noisy_ratio = 0.8
    params.from_args()
    params.initial()
    trainer = MixupEvalTrainer(params)
    trainer.train()
Example 4
    def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        _right_mask = (ys == nys)
        _error_mask = _right_mask.logical_not()

        logits = self.to_logits(axs)

        w_logits = self.to_logits(xs)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
        if params.local_filter:
            weight_mask = self.count_mem[ids] < 0
            # weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
            # weight_mask = weight < 0
            meter.tl = weight_mask[_right_mask].float().mean()
            meter.fl = weight_mask[_error_mask].float().mean()

        fweight = torch.ones(w_logits.shape[0], dtype=torch.float, device=device)
        if eidx >= params.burnin:
            fweight -= self.noisy_cls[ids]
            if params.local_filter:
                fweight[weight_mask] -= params.gmm_w_sche(eidx)
            fweight = torch.relu(fweight)

        if params.right_n > 0:
            right_mask = ids < params.right_n
            if right_mask.any():
                fweight[right_mask] = 1

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
            self.plabel_mem[ids] = targets
        values, p_labels = targets.max(dim=1)

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
        mask = mask.float()
        # uda_mask = label_pred > params.uda_sche(eidx)
        # meter.uda = uda_mask.float().mean()

        meter.pm = mask.mean()

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, nys, fweight, meter=meter)

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, p_labels,
                                                            (1 - fweight) * mask,
                                                            meter=meter,
                                                            name='Lpce') * params.plabel_sche(eidx)

        meter.tw = fweight[_right_mask].mean()
        meter.fw = fweight[_error_mask].mean()

        if params.local_filter:
            with torch.no_grad():
                ids_mask = weight_mask.logical_not()
                alpha = params.filter_ema
                # if eidx == 1:
                #     self.target_mem[ids] = label_pred
                # else:
                if eidx < params.burnin:
                    alpha = 0.99
                self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + label_pred[ids_mask] * \
                                                 (1 - alpha)

                # Gradually regress the samples that did not participate back
                # (toward 1 / n_classes, per the commented-out update below)
                # self.target_mem[ids[weight_mask]] = (self.target_mem[ids[weight_mask]] * 0.9 +
                #                                      (1 / params.n_classes) * 0.1)
                # self.count_mem[ids] += ((label_pred - self.target_mem[ids]) > 0).long() * 3 - 2
                self.count_mem[ids] += (label_pred - self.target_mem[ids])

        if 'Lall' in meter:
            self.optim.zero_grad()
            meter.Lall.backward()
            self.optim.step()
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

        false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids, eidx - 1] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
            # mem_mask = ids < self.pred_mem_size - 1
            # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = p_labels

        return meter
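
`loss_ce_with_masked_` is inherited from the trainer's loss mixins and is not shown here. A plausible per-sample weighted cross-entropy, consistent with how it is called above, is sketched below (the real helper presumably also records the loss value into `meter`):

import torch
import torch.nn.functional as F

def loss_ce_with_masked(logits: torch.Tensor,
                        labels: torch.Tensor,
                        weight: torch.Tensor) -> torch.Tensor:
    # Per-sample cross-entropy, scaled by a weight in [0, 1], then averaged.
    loss = F.cross_entropy(logits, labels, reduction='none')
    return (loss * weight).mean()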
Example 5
                    self.noisy_cls = torch.tensor(noisy_cls, device=self.device)
                # self.noisy_cls[self.noisy_cls < 0.5] = 0

                # Over time, samples that remain hard to distinguish should be dropped
                # outright rather than kept under fuzzy re-weighting (perhaps).
                # self.noisy_cls[self.noisy_cls >= 0.5].clamp_min_(params.gmm_w_sche(params.eidx))

            m2 = self.acc_mixture_(true_cls, (self.count_mem >= 0).cpu().numpy(), pre='con')
            meter.update(m2)
            self.logger.info(m2)


if __name__ == '__main__':
    # for retry,params in enumerate(params.grid_range(5)):
    # for retry in range(5):
    #     retry = retry + 1
    params = NoisyParams()
    params.right_n = 10
    params.use_right_label = False
    params.optim.args.lr = 0.06
    params.epoch = 500
    params.device = 'cuda:2'
    params.filter_ema = 0.99
    params.burnin = 20
    params.mix_burnin = 20
    params.targets_ema = 0.3
    params.pred_thresh = 0.9
    params.feature_mean = False
    params.local_filter = True  # local (per-sample) filtering method
    params.mixt_ema = True  # whether to smooth the BMM predictions with an EMA
    params.from_args()
    params.initial()
Example 6
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        mid, logits = self.to_logits(axs, with_mid=True)

        w_mid, w_logits = self.to_logits(xs, with_mid=True)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
        if params.local_filter:
            weight = label_pred - self.target_mem[ids]
            weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
            weight_mask = weight < 0
            meter.tl = weight_mask[ys == nys].float().mean()
            meter.fl = weight_mask[ys != nys].float().mean()

        fweight = torch.ones(w_logits.shape[0],
                             dtype=torch.float,
                             device=device)
        if eidx >= params.burnin:
            if params.local_filter:
                fweight[weight_mask] -= params.gmm_w_sche(eidx)
            fweight -= self.noisy_cls[ids]
            fweight = torch.relu(fweight)

        remain = ((fweight > 0.6) | (fweight < 0.4)).float()

        meter.rem = remain.mean()

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = (self.plabel_mem[ids] * params.targets_ema +
                       raw_targets * (1 - params.targets_ema))
            self.plabel_mem[ids] = targets

        top_values, top_indices = targets.topk(2, dim=-1)
        p_labels = top_indices[:, 0]
        values = top_values[:, 0]

        ratio = params.smooth_ratio_sche(eidx)

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')

        n_targets = tricks.onehot(nys, params.n_classes)
        p_targets = tricks.onehot(p_labels, params.n_classes)
        n_targets = n_targets * (1 - ratio) + p_targets * ratio

        p_targets[mask.logical_not()] = p_targets.scatter(
            -1, top_indices[:, 1:2], ratio)[mask.logical_not()]

        mask = mask.float()
        meter.pm = mask.mean()

        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits, n_targets, fweight, meter=meter)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits,
            p_targets, (1 - fweight) * values,
            meter=meter,
            name='Lpce') * params.plabel_sche(eidx)

        meter.tw = fweight[ys == nys].mean()
        meter.fw = fweight[ys != nys].mean()

        if params.local_filter:
            with torch.no_grad():
                ids_mask = weight_mask.logical_not()
                alpha = params.filter_ema
                if eidx < params.burnin:
                    alpha = 0.99
                self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + label_pred[ids_mask] * \
                                                 (1 - alpha)

                # Gradually regress the samples that did not participate back
                # (toward 1 / n_classes, per the commented-out update below)
                # self.target_mem[ids[weight_mask]] = self.target_mem[ids[weight_mask]] * alpha + (1 / params.n_classes) * \
                #                                  (1 - alpha)

        if 'Lall' in meter:
            self.optim.zero_grad()
            meter.Lall.backward()
            self.optim.step()
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        false_pred = targets.gather(
            1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(
            1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids, eidx - 1] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits,
                                                           nys,
                                                           reduction='none')
            # mem_mask = ids < self.pred_mem_size - 1
            # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = p_labels

        return meter
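
`tricks.onehot` builds dense one-hot targets from integer labels; a minimal equivalent (assumed, since the helper itself is not shown):

import torch

def onehot(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    # labels: LongTensor of shape [batch]; returns FloatTensor [batch, n_classes]
    out = torch.zeros(labels.shape[0], n_classes, device=labels.device)
    return out.scatter_(1, labels.unsqueeze(1), 1.0)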
Example 7
                if params.mixt_ema:
                    self.noisy_cls = self.noisy_cls_mem.clone()
                else:
                    self.noisy_cls = torch.tensor(noisy_cls,
                                                  device=self.device)
                # self.noisy_cls[self.noisy_cls < 0.5] = 0

                # Over time, samples that remain hard to distinguish should be dropped
                # outright rather than kept under fuzzy re-weighting (perhaps).
                # self.noisy_cls[self.noisy_cls >= 0.5].clamp_min_(params.gmm_w_sche(params.eidx))


if __name__ == '__main__':
    # for retry,params in enumerate(params.grid_range(5)):
    # for retry in range(5):
    #     retry = retry + 1
    params = NoisyParams()
    params.large_model = False
    params.right_n = 10
    params.use_right_label = False
    params.optim.args.lr = 0.03
    params.epoch = 500
    params.device = 'cuda:2'
    params.filter_ema = 0.999
    params.burnin = 2
    params.mix_burnin = 20
    params.with_fc = False
    params.smooth_ratio_sche = params.SCHE.Exp(0.1, 0, right=200)
    params.targets_ema = 0.3
    params.pred_thresh = 0.85
    params.feature_mean = False
    params.local_filter = True  # local (per-sample) filtering method
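
`params.SCHE.Exp(0.1, 0, right=200)` uses the schedule helpers of the thexp library. Assuming it interpolates exponentially from 0.1 down to 0 over the first 200 epochs (an assumption about the library's semantics, not a confirmed API), a standalone equivalent might look like:

import math

def exp_schedule(eidx: int, start: float = 0.1, end: float = 0.0,
                 left: int = 0, right: int = 200) -> float:
    # Assumed semantics: exponential interpolation from start to end on [left, right].
    if eidx >= right:
        return end
    ratio = (1 - math.exp(-5 * (eidx - left) / (right - left))) / (1 - math.exp(-5))
    return start + (end - start) * ratio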
Example 8
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        logits = self.to_logits(axs)

        w_logits = self.to_logits(xs)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
        fweight = (nys == ys).float()
        fweight[nys != ys] = params.ideal_nyw_sche(eidx)

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = (self.plabel_mem[ids] * params.targets_ema +
                       raw_targets * (1 - params.targets_ema))
            self.plabel_mem[ids] = targets

        top_values, top_indices = targets.topk(2, dim=-1)
        p_labels = top_indices[:, 0]
        values = top_values[:, 0]

        p_targets = (torch.zeros_like(targets).scatter(
            -1, top_indices[:, 0:1], 0.9).scatter(-1, top_indices[:, 1:2],
                                                  0.1))

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
        mask = mask.float()
        meter.pm = mask.mean()

        n_targets = tricks.onehot(nys, params.n_classes)
        n_targets = n_targets * 0.9 + p_targets * 0.1

        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits, n_targets, fweight, meter=meter)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
            logits,
            p_targets, (1 - fweight) * values * mask,
            meter=meter,
            name='Lpce') * params.plabel_sche(eidx)

        meter.tw = fweight[ys == nys].mean()
        meter.fw = fweight[ys != nys].mean()

        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        false_pred = targets.gather(
            1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(
            1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids, eidx - 1] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits,
                                                           nys,
                                                           reduction='none')
            # mem_mask = ids < self.pred_mem_size - 1
            # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = p_labels

        return meter
Example 9
    def datasets(self, params: NoisyParams):
        self.rnd.mark('kk')
        params.noisy_type = params.default('symmetric', True)
        params.noisy_ratio = params.default(0.2, True)

        from data.constant import norm_val
        mean, std = norm_val.get(params.dataset, [None, None])
        from data.transforms import ToNormTensor
        toTensor = ToNormTensor(mean, std)
        from data.transforms import Weak
        weak = Weak(mean, std)
        from data.transforms import Strong

        dataset_fn = datasets.datasets[params.dataset]
        train_x, train_y = dataset_fn(True)
        train_y = np.array(train_y)
        from thexp import DatasetBuilder

        from data.noisy import symmetric_noisy
        noisy_y = symmetric_noisy(train_y,
                                  params.noisy_ratio,
                                  n_classes=params.n_classes)
        clean_mask = (train_y == noisy_y)

        noisy_mask = np.logical_not(clean_mask)
        noisy_mask = np.where(noisy_mask)[0]

        nmask_a = noisy_mask[:len(noisy_mask) // 2]
        nmask_b = noisy_mask[len(noisy_mask) // 2:]

        clean_x, clean_y = train_x[clean_mask], noisy_y[clean_mask]
        clean_true_y = train_y[clean_mask]

        raw_x, raw_true_y = train_x[nmask_a], train_y[nmask_a]
        raw_y = noisy_y[nmask_a]

        change_x, change_true_y, change_y = (train_x[nmask_b],
                                             train_y[nmask_b], noisy_y[nmask_b])

        first_x, first_y, first_true_y = (
            clean_x + raw_x,
            np.concatenate([clean_y, raw_y]),
            np.concatenate([clean_true_y, raw_true_y]),
        )

        second_x, second_y, second_true_y = (
            clean_x + change_x,
            np.concatenate([clean_y, change_y]),
            np.concatenate([clean_true_y, change_true_y]),
        )

        first_set = (DatasetBuilder(first_x, first_true_y).add_labels(
            first_y, 'noisy_y').toggle_id().add_x(
                transform=weak).add_y().add_y(source='noisy_y'))
        second_set = (DatasetBuilder(second_x, second_true_y).add_labels(
            second_y, 'noisy_y').toggle_id().add_x(
                transform=weak).add_y().add_y(source='noisy_y'))

        self.first_dataloader = first_set.DataLoader(
            batch_size=params.batch_size,
            num_workers=params.num_workers,
            drop_last=True,
            shuffle=True)

        self.second_dataloader = second_set.DataLoader(
            batch_size=params.batch_size,
            num_workers=params.num_workers,
            drop_last=True,
            shuffle=True)

        self.second = False
        self.regist_databundler(train=self.first_dataloader)
        self.cur_set = 0
        self.to(self.device)
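
`symmetric_noisy` from `data.noisy` injects symmetric (uniform) label noise; the exact flipping rule is not shown, so the following is a minimal sketch matching the call signature above:

import numpy as np

def symmetric_noisy(labels: np.ndarray, noisy_ratio: float,
                    n_classes: int, seed: int = 1) -> np.ndarray:
    # Flip a noisy_ratio fraction of labels to a uniformly random class.
    # Sampling over all classes means a flip can land on the original class,
    # so the effective noise rate is slightly below noisy_ratio.
    rng = np.random.RandomState(seed)
    noisy = labels.copy()
    idx = rng.choice(len(labels), int(len(labels) * noisy_ratio), replace=False)
    noisy[idx] = rng.randint(0, n_classes, size=len(idx))
    return noisy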
Example 10
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='noisy_acc')

        return meter

    def to_logits(self, xs) -> torch.Tensor:
        return self.model(xs)


if __name__ == '__main__':
    params = NoisyParams()
    params.optim.args.lr = 0.06
    params.epoch = 400
    params.device = 'cuda:0'
    params.filter_ema = 0.999
    params.echange = False  # swap every epoch
    params.change = True
    params.noisy_ratio = 0.8
    params.from_args()
    params.initial()
    trainer = MixupEvalTrainer(params)
    trainer.train()
Example 11
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1),
                          ys,
                          meter=meter,
                          name='true_acc')
        self.acc_precise_(logits.argmax(dim=1),
                          nys,
                          meter=meter,
                          name='noisy_acc')

        return meter


if __name__ == '__main__':
    params = NoisyParams()
    params.ema = True  # L2R has no EMA for the model
    params.epoch = 120
    params.batch_size = 100
    params.device = 'cuda:3'
    params.optim.args.lr = 0.1
    params.meta_optim = {
        'lr': 0.1,
        'momentum': 0.9,
    }
    params.from_args()
    trainer = L2RTrainer(params)
    trainer.train()
Example 12
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        logits = self.to_logits(axs)

        w_logits = self.to_logits(xs)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
        weight = label_pred - self.target_mem[ids]
        weight = weight + label_pred * 0.5 / params.n_classes - 0.375 / params.n_classes

        fweight = torch.ones_like(weight)
        if eidx >= params.burnin:
            fweight -= self.noisy_cls[ids]
            fweight[weight <= 0] -= params.gmm_w_sche(eidx)
            fweight = torch.relu(fweight)

        if params.right_n > 0:
            right_mask = ids < params.right_n
            if right_mask.any():
                fweight[right_mask] = 1

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = (self.plabel_mem[ids] * params.targets_ema +
                       raw_targets * (1 - params.targets_ema))
            self.plabel_mem[ids] = targets
        values, p_labels = targets.max(dim=1)

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
        mask = mask.float()
        # uda_mask = label_pred > params.uda_sche(eidx)
        # meter.uda = uda_mask.float().mean()

        meter.m0 = (fweight == 0).float().mean()
        meter.m1 = (fweight == 1).float().mean()
        meter.pm = mask.mean()

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(
            logits, nys, fweight, meter=meter)

        meter.Lall = meter.Lall + self.loss_ce_with_masked_(
            logits, p_labels, (1 - fweight) * mask, meter=meter,
            name='Lpce') * params.plabel_sche(eidx)

        meter.tw = fweight[ys == nys].mean()
        meter.fw = fweight[ys != nys].mean()

        with torch.no_grad():
            # .bool() is True for any nonzero value, so negative weights also count
            ids_mask = weight.bool()
            alpha = params.filter_ema
            if eidx < params.burnin:
                alpha = 0.99
            self.target_mem[ids[ids_mask]] = (self.target_mem[ids[ids_mask]] * alpha +
                                              label_pred[ids_mask] * (1 - alpha))

        if 'Lall' in meter:
            self.optim.zero_grad()
            meter.Lall.backward()
            self.optim.step()
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='noisy_acc')

        false_pred = targets.gather(
            1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(
            1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids, eidx - 1] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits,
                                                           nys,
                                                           reduction='none')

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = p_labels

        return meter
Example 13
                self.noisy_cls_mem = torch.tensor(
                    noisy_cls,
                    device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                self.noisy_cls = self.noisy_cls_mem.clone()
                self.noisy_cls[self.noisy_cls < 0.5] = 0

                # Over time, samples that remain hard to distinguish should be dropped
                # outright rather than kept under fuzzy re-weighting (perhaps).
                # Boolean-mask indexing returns a copy, so write the clamped values back.
                hard_mask = self.noisy_cls >= 0.5
                self.noisy_cls[hard_mask] = self.noisy_cls[hard_mask].clamp_min(
                    params.gmm_w_sche(params.eidx))


if __name__ == '__main__':
    # for retry,params in enumerate(params.grid_range(5)):
    # for retry in range(5):
    #     retry = retry + 1
    params = NoisyParams()
    params.right_n = 10
    params.use_right_label = False
    params.optim.args.lr = 0.06
    params.epoch = 300
    params.device = 'cuda:2'
    params.filter_ema = 0.999
    params.burnin = 2
    params.gmm_burnin = 20
    params.targets_ema = 0.3
    # params.tolerance_type = 'exp'
    params.pred_thresh = 0.9
    # params.widen_factor = 10
    params.from_args()
    params.initial()
    trainer = MultiHeadTrainer(params)
Example 14
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        logits = self.to_logits(axs)
        w_logits = self.to_logits(xs)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
        if params.local_filter:
            weight = label_pred - self.target_mem[ids]
            weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
            weight_mask = weight < 0
            meter.tl = weight_mask[ys == nys].float().mean()
            meter.fl = weight_mask[ys != nys].float().mean()

        fweight = torch.ones(w_logits.shape[0],
                             dtype=torch.float,
                             device=device)
        if eidx >= params.burnin:
            if params.local_filter:
                fweight[weight_mask] -= params.gmm_w_sche(eidx)
            fweight -= self.noisy_cls[ids]
            fweight = torch.relu(fweight)

        if params.right_n > 0:
            right_mask = ids < params.right_n
            if right_mask.any():
                fweight[right_mask] = 1

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = (self.plabel_mem[ids] * params.targets_ema +
                       raw_targets * (1 - params.targets_ema))
            self.plabel_mem[ids] = targets

        # targets = self.sharpen_(targets)
        # targets = raw_targets
        n_targets = tricks.onehot(nys, params.n_classes)

        t_weight = fweight.unsqueeze(dim=1)
        # cur_targets = n_targets * t_weight + targets * (1 - t_weight)
        cur_ys = targets.argmax(dim=-1)
        # cur_targets = tricks.onehot(cur_ys, params.n_classes)

        id_dict = defaultdict(list)
        for i, ny in enumerate(nys):
            id_dict[int(ny)].append(i)  # guarantee that everything added here is a noisy label

        for i in range(params.n_classes):
            id_dict[i] = cycle(id_dict[i])

        re_id = []
        for i, cur_y in enumerate(cur_ys):
            cur_y = int(cur_y)
            try:
                # If the original label is already clean, it does not matter which
                # label is mixed in (the clean label dominates); otherwise, mix in
                # a noisy label.
                re_id.append(next(id_dict.get(cur_y)))
            except StopIteration:  # no sample carries this noisy label in the batch
                re_id.append(np.random.randint(0, len(nys)))

        mixed_input, mixed_target = self.mixup_(xs, n_targets, reids=re_id)

        mixed_logits = self.to_logits(mixed_input)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_logits, mixed_target, meter=meter)

        meter.tw = fweight[ys == nys].mean()
        meter.fw = fweight[ys != nys].mean()

        if params.local_filter:
            with torch.no_grad():
                ids_mask = weight_mask.logical_not()
                alpha = params.filter_ema
                if eidx < params.burnin:
                    alpha = 0.99
                self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + label_pred[ids_mask] * \
                                                 (1 - alpha)

                # Gradually regress the samples that did not participate back
                # (toward 1 / n_classes, per the commented-out update below)
                # self.target_mem[ids[weight_mask]] = self.target_mem[ids[weight_mask]] * alpha + (1 / params.n_classes) * \
                #                                     (1 - alpha)

        if 'Lall' in meter:
            self.optim.zero_grad()
            meter.Lall.backward()
            self.optim.step()
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='nacc')

        false_pred = targets.gather(
            1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(
            1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids, eidx - 1] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits,
                                                           nys,
                                                           reduction='none')

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = cur_ys

        return meter
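
`self.mixup_` comes from the trainer's mixins and is not shown. A standalone sketch of mixup with an explicitly supplied pairing index (`reids`), consistent with the call above; the Beta(0.75, 0.75) mixing distribution is an assumption:

import numpy as np
import torch

def mixup(xs: torch.Tensor, targets: torch.Tensor, reids, alpha: float = 0.75):
    # Mix each sample with the partner selected by reids; keeping lam >= 0.5
    # ensures the original sample stays dominant in the mixture.
    lam = float(np.random.beta(alpha, alpha))
    lam = max(lam, 1 - lam)
    mixed_x = lam * xs + (1 - lam) * xs[reids]
    mixed_t = lam * targets + (1 - lam) * targets[reids]
    return mixed_x, mixed_t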
Example 15
    def train_batch(self, eidx, idx, global_step, batch_data,
                    params: NoisyParams, device: torch.device):
        meter = Meter()
        ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
        logits = self.to_logits(axs)

        w_logits = self.to_logits(xs)

        preds = torch.softmax(logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()

        weight = label_pred - self.target_mem[ids]
        if params.tolerance_type == 'linear':
            weight = weight + label_pred * 0.5 / params.n_classes - 0.375 / params.n_classes

        elif params.tolerance_type == 'exp':
            exp_ratio = (torch.exp((self.target_mem[ids] - 1) * 5) -
                         np.exp(-5) * (1 - self.target_mem[ids]))
            weight = weight + (params.tol_start *
                               (1 - exp_ratio) + params.tol_end * exp_ratio)

        elif params.tolerance_type == 'log':
            log_ratio = 1 - torch.exp(
                -self.target_mem[ids] * 5) + np.exp(-5) * self.target_mem[ids]
            weight = weight + (params.tol_start *
                               (1 - log_ratio) + params.tol_end * log_ratio)

        # weight[self.noisy_cls[ids] == 0] -= params.gmm_sche(eidx)
        weight[self.noisy_cls[ids] == 0] = 0

        raw_targets = torch.softmax(w_logits, dim=1)

        with torch.no_grad():
            targets = (self.plabel_mem[ids] * params.targets_ema +
                       raw_targets * (1 - params.targets_ema))
            self.plabel_mem[ids] = targets
        values, p_labels = targets.max(dim=1)

        mask = values > params.pred_thresh
        if mask.any():
            self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')

        weight[weight > 0] = 1

        if eidx < params.burnin:
            weight = torch.ones_like(weight)

        meter.m0 = (weight == 0).float().mean()
        meter.pm = mask.float().mean()

        weight_mask = weight.bool()
        pweight_mask = weight_mask.logical_not() & mask
        if weight_mask.any():
            meter.Lall = meter.Lall + self.loss_ce_(
                logits[weight_mask], nys[weight_mask], meter=meter)

        if pweight_mask.any():
            meter.Lall = meter.Lall + self.loss_ce_(
                logits[pweight_mask],
                p_labels[pweight_mask],
                meter=meter,
                name='Lpce') * params.plabel_sche(eidx)

        meter.tw = weight[ys == nys].mean()
        meter.fw = weight[ys != nys].mean()

        with torch.no_grad():
            # .bool() is True for any nonzero value, so negative weights also count
            ids_mask = weight.bool()
            alpha = params.filter_ema
            if eidx < params.burnin:
                alpha = 0.99
            self.target_mem[ids[ids_mask]] = (self.target_mem[ids[ids_mask]] * alpha +
                                              label_pred[ids_mask] * (1 - alpha))

        if 'Lall' in meter:
            self.optim.zero_grad()
            meter.Lall.backward()
            self.optim.step()
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask],
                              nys[n_mask],
                              meter,
                              name='noisy_acc')

        false_pred = targets.gather(
            1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(
            1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        with torch.no_grad():
            self.true_pred_mem[ids] = true_pred
            self.false_pred_mem[ids, eidx - 1] = false_pred
            self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits,
                                                           nys,
                                                           reduction='none')

        if eidx == 1:
            self.cls_mem[ids, 0] = ys
        elif eidx == 2:
            self.cls_mem[ids, 1] = nys
        else:
            self.cls_mem[ids, eidx - 1] = p_labels

        return meter
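
The 'exp' and 'log' tolerance branches above are smooth interpolants in [0, 1] of the running confidence `target_mem`. Written as standalone functions of t = target_mem:

import numpy as np

def exp_ratio(t: np.ndarray) -> np.ndarray:
    # 0 at t = 0, rising sharply as t approaches 1
    return np.exp((t - 1) * 5) - np.exp(-5) * (1 - t)

def log_ratio(t: np.ndarray) -> np.ndarray:
    # rises sharply for small t, saturating toward 1 at t = 1
    return 1 - np.exp(-t * 5) + np.exp(-5) * t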
Example 16
            m = self.acc_mixture_(true_cls, noisy_cls)
            meter.update(m)
            self.logger.info(m)

            if params.eidx > params.gmm_burnin:
                self.noisy_cls = torch.tensor(noisy_cls, device=self.device)

    def to_logits(self, xs) -> torch.Tensor:
        return self.model(xs)


if __name__ == '__main__':
    # for retry,params in enumerate(params.grid_range(5)):
    # for retry in range(5):
    #     retry = retry + 1
    params = NoisyParams()
    params.right_n = 10
    params.optim.args.lr = 0.06
    params.epoch = 400
    params.device = 'cuda:0'
    params.filter_ema = 0.999
    params.burnin = 2
    params.gmm_burnin = 10
    params.targets_ema = 0.3

    params.pred_thresh = 0.9
    params.from_args()
    params.initial()
    trainer = MultiHeadTrainer(params)

    if params.ss_pretrain: