def on_train_epoch_end(self, trainer: 'IEGTrainer', func, params: NoisyParams, meter: Meter, *args, **kwargs):
    with torch.no_grad():
        from sklearn import metrics
        f_mean = self.false_pred_mem[:, max(params.eidx - params.gmm_burnin, 0):params.eidx].mean(
            dim=1).cpu().numpy()
        f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
        feature = np.stack([f_mean, f_cur], axis=1)

        model = tricks.group_fit(feature)
        noisy_cls = model.predict_proba(feature)[:, 0]  # type:np.ndarray
        if params.eidx == params.gmm_burnin:
            self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device)
        if params.eidx > params.gmm_burnin:
            self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
            self.noisy_cls = self.noisy_cls_mem.clone()
            self.noisy_cls[self.noisy_cls < 0.5] = 0

            # As training goes on, samples that remain hard to separate should be dropped
            # outright rather than fuzzily re-weighted back and forth (perhaps).
            # Note: boolean-mask indexing returns a copy, so the clamp must be written back.
            high = self.noisy_cls >= 0.5
            self.noisy_cls[high] = self.noisy_cls[high].clamp_min(params.gmm_w_sche(params.eidx))

        true_ncls = (self.true_pred_mem == self.false_pred_mem[:, params.eidx - 1]).all(dim=1).cpu().numpy()
        self.logger.info('gmm accuracy',
                         metrics.confusion_matrix(true_ncls, self.noisy_cls.cpu().numpy() == 0,
                                                  labels=None, sample_weight=None))
        error_idx = set(np.where(true_ncls == 0)[0])
        self.logger.info('err set', len(set(np.where(noisy_cls != 0)[0]) & error_idx) / len(error_idx))
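# --- illustration (not part of the trainer) ---
# `tricks.group_fit` is a project-local helper; a minimal sketch of what a
# two-component mixture fit over the (f_mean, f_cur) features could look like,
# built on sklearn. Which predict_proba column means "noisy" depends on the
# component ordering, so a real helper would sort the components first.
import numpy as np
from sklearn.mixture import GaussianMixture


def group_fit_sketch(feature: np.ndarray) -> GaussianMixture:
    """Fit a 2-component GMM on per-sample features of shape [n_samples, 2]."""
    gmm = GaussianMixture(n_components=2, covariance_type='full', reg_covar=5e-4)
    gmm.fit(feature)
    return gmm


# e.g. take the component with the lower feature mean as the "noisy" one:
#   gmm = group_fit_sketch(feature)
#   noisy_col = int(np.argmin(gmm.means_[:, 0]))
#   noisy_prob = gmm.predict_proba(feature)[:, noisy_col]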
def on_train_epoch_end(self, trainer: 'NoisyTrainer', func, params: NoisyParams, meter: Meter, *args, **kwargs):
    true_f = os.path.join(self.experiment.test_dir, 'true.pth')
    false_f = os.path.join(self.experiment.test_dir, 'false.pth')
    loss_f = os.path.join(self.experiment.test_dir, 'loss.pth')
    cls_f = os.path.join(self.experiment.test_dir, 'cls.pth')
    if params.eidx % 10 == 0:
        torch.save(self.true_pred_mem, true_f)
        torch.save(self.false_pred_mem, false_f)
        torch.save(self.loss_mem, loss_f)
        torch.save(self.cls_mem, cls_f)

    with torch.no_grad():
        f_mean = self.false_pred_mem[:, max(params.eidx - params.gmm_burnin, 0):params.eidx].mean(
            dim=1).cpu().numpy()
        f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
        feature = self.create_feature(f_mean, f_cur)
        noisy_cls = self.gmm_predict(feature)

        true_cls = (self.true_pred_mem == self.false_pred_mem).all(dim=1).cpu().numpy()
        m = self.acc_mixture_(true_cls, noisy_cls)
        meter.update(m)
        self.logger.info(m)

        if params.eidx > params.gmm_burnin:
            self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
            self.noisy_cls = self.noisy_cls_mem.clone()
            self.noisy_cls[self.noisy_cls < 0.5] = 0

            # As training goes on, samples that remain hard to separate should be dropped
            # outright rather than fuzzily re-weighted back and forth (perhaps).
            # Note: boolean-mask indexing returns a copy, so the clamp must be written back.
            high = self.noisy_cls >= 0.5
            self.noisy_cls[high] = self.noisy_cls[high].clamp_min(params.gmm_w_sche(params.eidx))
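# --- illustration (not part of the trainer) ---
# Why the clamp above is written back explicitly: boolean-mask indexing returns
# a copy, so an in-place op on the indexed result never touches the original.
import torch

t = torch.tensor([0.2, 0.6, 0.8])
t[t >= 0.5].clamp_min_(0.7)           # no-op: clamps a temporary copy
assert torch.allclose(t, torch.tensor([0.2, 0.6, 0.8]))

mask = t >= 0.5
t[mask] = t[mask].clamp_min(0.7)      # correct: assign the clamped result back
assert torch.allclose(t, torch.tensor([0.2, 0.7, 0.8]))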
        # meter.nres = residual[n_mask].mean()
        # meter.nsum = self.sum_mem[ids][n_mask].mean()
        # meter.tres = residual[n_mask.logical_not()].mean()
        # meter.tsum = self.sum_mem[ids][n_mask.logical_not()].mean()
        self.acc_precise_(self.pred_mem[ids].argmax(dim=1), ys, meter, name='sacc')
        return meter

    def to_logits(self, xs) -> torch.Tensor:
        return self.model(xs)


if __name__ == '__main__':
    params = NoisyParams()
    params.optim.args.lr = 0.06
    params.epoch = 400
    params.device = 'cuda:0'
    params.filter_ema = 0.999
    params.mixup = True
    params.ideal_mixup = True
    params.worst_mixup = False
    params.noisy_ratio = 0.8
    params.from_args()
    params.initial()
    trainer = MixupEvalTrainer(params)
    trainer.train()
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
    _right_mask = (ys == nys)
    _error_mask = _right_mask.logical_not()

    logits = self.to_logits(axs)
    w_logits = self.to_logits(xs)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
    if params.local_filter:
        weight_mask = self.count_mem[ids] < 0
        # weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
        # weight_mask = weight < 0
        meter.tl = weight_mask[_right_mask].float().mean()
        meter.fl = weight_mask[_error_mask].float().mean()

    fweight = torch.ones(w_logits.shape[0], dtype=torch.float, device=device)
    if eidx >= params.burnin:
        fweight -= self.noisy_cls[ids]
        if params.local_filter:
            fweight[weight_mask] -= params.gmm_w_sche(eidx)
        fweight = torch.relu(fweight)

    if params.right_n > 0:
        right_mask = ids < params.right_n
        if right_mask.any():
            fweight[right_mask] = 1

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets
    values, p_labels = targets.max(dim=1)
    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
    mask = mask.float()

    # uda_mask = label_pred > params.uda_sche(eidx)
    # meter.uda = uda_mask.float().mean()

    meter.pm = mask.mean()

    meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, nys, fweight, meter=meter)
    meter.Lall = meter.Lall + self.loss_ce_with_masked_(
        logits, p_labels, (1 - fweight) * mask,
        meter=meter, name='Lpce') * params.plabel_sche(eidx)

    meter.tw = fweight[_right_mask].mean()
    meter.fw = fweight[_error_mask].mean()

    if params.local_filter:
        with torch.no_grad():
            ids_mask = weight_mask.logical_not()
            alpha = params.filter_ema
            # if eidx == 1:
            #     self.target_mem[ids] = label_pred
            # else:
            if eidx < params.burnin:
                alpha = 0.99
            self.target_mem[ids[ids_mask]] = (self.target_mem[ids[ids_mask]] * alpha +
                                              label_pred[ids_mask] * (1 - alpha))
            # Gradually regress the samples that did not take part back (toward the prior).
            # self.target_mem[ids[weight_mask]] = (self.target_mem[ids[weight_mask]] * 0.9 +
            #                                      (1 / params.n_classes) * 0.1)

            # self.count_mem[ids] += ((label_pred - self.target_mem[ids]) > 0).long() * 3 - 2
            self.count_mem[ids] += (label_pred - self.target_mem[ids])

    if 'Lall' in meter:
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
        # mem_mask = ids < self.pred_mem_size - 1
        # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]
    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = p_labels

    return meter
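# --- illustration (not part of the trainer) ---
# `loss_ce_with_masked_` is a project helper; the call sites above imply a
# per-sample cross-entropy re-weighted by `fweight` (0 drops a sample, 1 keeps
# it fully). A minimal stand-in under that assumption:
import torch
import torch.nn.functional as F


def loss_ce_with_masked_sketch(logits: torch.Tensor, labels: torch.Tensor,
                               weights: torch.Tensor) -> torch.Tensor:
    per_sample = F.cross_entropy(logits, labels, reduction='none')
    return (per_sample * weights).mean()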
            self.noisy_cls = torch.tensor(noisy_cls, device=self.device)
            # self.noisy_cls[self.noisy_cls < 0.5] = 0
            # As training goes on, samples that remain hard to separate should be dropped
            # outright rather than fuzzily re-weighted back and forth (perhaps).
            # self.noisy_cls[self.noisy_cls >= 0.5].clamp_min_(params.gmm_w_sche(params.eidx))

            m2 = self.acc_mixture_(true_cls, (self.count_mem >= 0).cpu().numpy(), pre='con')
            meter.update(m2)
            self.logger.info(m2)


if __name__ == '__main__':
    # for retry, params in enumerate(params.grid_range(5)):
    # for retry in range(5):
    #     retry = retry + 1
    params = NoisyParams()
    params.right_n = 10
    params.use_right_label = False
    params.optim.args.lr = 0.06
    params.epoch = 500
    params.device = 'cuda:2'
    params.filter_ema = 0.99
    params.burnin = 20
    params.mix_burnin = 20
    params.targets_ema = 0.3
    params.pred_thresh = 0.9
    params.feature_mean = False
    params.local_filter = True  # local filtering method
    params.mixt_ema = True  # whether to smooth the BMM predictions with EMA
    params.from_args()
    params.initial()
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    mid, logits = self.to_logits(axs, with_mid=True)
    w_mid, w_logits = self.to_logits(xs, with_mid=True)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
    if params.local_filter:
        weight = label_pred - self.target_mem[ids]
        weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
        weight_mask = weight < 0
        meter.tl = weight_mask[ys == nys].float().mean()
        meter.fl = weight_mask[ys != nys].float().mean()

    fweight = torch.ones(w_logits.shape[0], dtype=torch.float, device=device)
    if eidx >= params.burnin:
        if params.local_filter:
            fweight[weight_mask] -= params.gmm_w_sche(eidx)
        fweight -= self.noisy_cls[ids]
        fweight = torch.relu(fweight)

    remain = ((fweight > 0.6) | (fweight < 0.4)).float()
    meter.rem = remain.mean()

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets

    top_values, top_indices = targets.topk(2, dim=-1)
    p_labels = top_indices[:, 0]
    values = top_values[:, 0]

    ratio = params.smooth_ratio_sche(eidx)

    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')

    n_targets = tricks.onehot(nys, params.n_classes)
    p_targets = tricks.onehot(p_labels, params.n_classes)
    n_targets = n_targets * (1 - ratio) + p_targets * ratio
    p_targets[mask.logical_not()] = p_targets.scatter(
        -1, top_indices[:, 1:2], ratio)[mask.logical_not()]

    mask = mask.float()
    meter.pm = mask.mean()

    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(logits, n_targets, fweight, meter=meter)
    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
        logits, p_targets, (1 - fweight) * values,
        meter=meter, name='Lpce') * params.plabel_sche(eidx)

    meter.tw = fweight[ys == nys].mean()
    meter.fw = fweight[ys != nys].mean()

    if params.local_filter:
        with torch.no_grad():
            ids_mask = weight_mask.logical_not()
            alpha = params.filter_ema
            if eidx < params.burnin:
                alpha = 0.99
            self.target_mem[ids[ids_mask]] = (self.target_mem[ids[ids_mask]] * alpha +
                                              label_pred[ids_mask] * (1 - alpha))
            # Gradually regress the samples that did not take part back (toward the prior).
            # self.target_mem[ids[weight_mask]] = (self.target_mem[ids[weight_mask]] * alpha +
            #                                      (1 / params.n_classes) * (1 - alpha))

    if 'Lall' in meter:
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
        # mem_mask = ids < self.pred_mem_size - 1
        # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]
    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = p_labels

    return meter
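# --- illustration (not part of the trainer) ---
# `loss_ce_with_targets_masked_` takes soft targets instead of hard labels; a
# hedged stand-in with the same assumed per-sample weighting convention:
import torch


def loss_ce_with_targets_masked_sketch(logits: torch.Tensor, targets: torch.Tensor,
                                       weights: torch.Tensor) -> torch.Tensor:
    per_sample = -(targets * torch.log_softmax(logits, dim=1)).sum(dim=1)
    return (per_sample * weights).mean()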
            if params.mixt_ema:
                self.noisy_cls = self.noisy_cls_mem.clone()
            else:
                self.noisy_cls = torch.tensor(noisy_cls, device=self.device)

            # self.noisy_cls[self.noisy_cls < 0.5] = 0
            # As training goes on, samples that remain hard to separate should be dropped
            # outright rather than fuzzily re-weighted back and forth (perhaps).
            # self.noisy_cls[self.noisy_cls >= 0.5].clamp_min_(params.gmm_w_sche(params.eidx))


if __name__ == '__main__':
    # for retry, params in enumerate(params.grid_range(5)):
    # for retry in range(5):
    #     retry = retry + 1
    params = NoisyParams()
    params.large_model = False
    params.right_n = 10
    params.use_right_label = False
    params.optim.args.lr = 0.03
    params.epoch = 500
    params.device = 'cuda:2'
    params.filter_ema = 0.999
    params.burnin = 2
    params.mix_burnin = 20
    params.with_fc = False
    params.smooth_ratio_sche = params.SCHE.Exp(0.1, 0, right=200)
    params.targets_ema = 0.3
    params.pred_thresh = 0.85
    params.feature_mean = False
    params.local_filter = True  # local filtering method
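# --- illustration (not part of the trainer) ---
# `params.SCHE.Exp(0.1, 0, right=200)` comes from thexp's schedule helpers and
# its exact curve is framework-defined. As a loosely assumed stand-in, an
# exponential interpolation from `start` to `end` over [left, right]:
import math


def exp_sche_sketch(start: float, end: float, left: float = 0.0, right: float = 200.0):
    def call(cur: float) -> float:
        cur = min(max(cur, left), right)
        ratio = (math.exp(5.0 * (cur - left) / (right - left)) - 1.0) / (math.exp(5.0) - 1.0)
        return start + (end - start) * ratio
    return call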
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    logits = self.to_logits(axs)
    w_logits = self.to_logits(xs)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)

    fweight = (nys == ys).float()
    fweight[nys != ys] = params.ideal_nyw_sche(eidx)

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets

    top_values, top_indices = targets.topk(2, dim=-1)
    p_labels = top_indices[:, 0]
    values = top_values[:, 0]
    p_targets = (torch.zeros_like(targets)
                 .scatter(-1, top_indices[:, 0:1], 0.9)
                 .scatter(-1, top_indices[:, 1:2], 0.1))

    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
    mask = mask.float()
    meter.pm = mask.mean()

    n_targets = tricks.onehot(nys, params.n_classes)
    n_targets = n_targets * 0.9 + p_targets * 0.1

    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(logits, n_targets, fweight, meter=meter)
    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
        logits, p_targets, (1 - fweight) * values * mask,
        meter=meter, name='Lpce') * params.plabel_sche(eidx)

    meter.tw = fweight[ys == nys].mean()
    meter.fw = fweight[ys != nys].mean()

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
        # mem_mask = ids < self.pred_mem_size - 1
        # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]
    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = p_labels

    return meter
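# --- illustration (not part of the trainer) ---
# `tricks.onehot` is assumed to build dense one-hot float targets, equivalent to:
import torch
import torch.nn.functional as F


def onehot_sketch(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    return F.one_hot(labels, num_classes=n_classes).float()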
def datasets(self, params: NoisyParams):
    self.rnd.mark('kk')
    params.noisy_type = params.default('symmetric', True)
    params.noisy_ratio = params.default(0.2, True)

    from data.constant import norm_val
    mean, std = norm_val.get(params.dataset, [None, None])
    from data.transforms import ToNormTensor
    toTensor = ToNormTensor(mean, std)
    from data.transforms import Weak
    weak = Weak(mean, std)
    from data.transforms import Strong

    dataset_fn = datasets.datasets[params.dataset]
    train_x, train_y = dataset_fn(True)
    train_y = np.array(train_y)

    from thexp import DatasetBuilder
    from data.noisy import symmetric_noisy

    noisy_y = symmetric_noisy(train_y, params.noisy_ratio, n_classes=params.n_classes)
    clean_mask = (train_y == noisy_y)
    noisy_mask = np.logical_not(clean_mask)
    noisy_mask = np.where(noisy_mask)[0]
    nmask_a = noisy_mask[:len(noisy_mask) // 2]
    nmask_b = noisy_mask[len(noisy_mask) // 2:]

    clean_x, clean_y = train_x[clean_mask], noisy_y[clean_mask]
    clean_true_y = train_y[clean_mask]
    raw_x, raw_true_y = train_x[nmask_a], train_y[nmask_a]
    raw_y = noisy_y[nmask_a]
    change_x, change_true_y, change_y = train_x[nmask_b], train_y[nmask_b], noisy_y[nmask_b]

    # `+` on ndarrays would add pixel values element-wise; stack the clean and
    # noisy halves along the sample axis instead.
    first_x, first_y, first_true_y = (
        np.concatenate([clean_x, raw_x]),
        np.concatenate([clean_y, raw_y]),
        np.concatenate([clean_true_y, raw_true_y]),
    )
    second_x, second_y, second_true_y = (
        np.concatenate([clean_x, change_x]),
        np.concatenate([clean_y, change_y]),
        np.concatenate([clean_true_y, change_true_y]),
    )

    first_set = (DatasetBuilder(first_x, first_true_y)
                 .add_labels(first_y, 'noisy_y')
                 .toggle_id()
                 .add_x(transform=weak)
                 .add_y()
                 .add_y(source='noisy_y'))
    second_set = (DatasetBuilder(second_x, second_true_y)
                  .add_labels(second_y, 'noisy_y')
                  .toggle_id()
                  .add_x(transform=weak)
                  .add_y()
                  .add_y(source='noisy_y'))

    self.first_dataloader = first_set.DataLoader(batch_size=params.batch_size,
                                                 num_workers=params.num_workers,
                                                 drop_last=True,
                                                 shuffle=True)
    self.second_dataloader = second_set.DataLoader(batch_size=params.batch_size,
                                                   num_workers=params.num_workers,
                                                   drop_last=True,
                                                   shuffle=True)
    self.second = False
    self.regist_databundler(train=self.first_dataloader)
    self.cur_set = 0
    self.to(self.device)
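# --- illustration (not part of the trainer) ---
# `data.noisy.symmetric_noisy` is the project's corruption routine. A common
# definition of symmetric label noise (a `ratio` fraction of samples is
# re-labelled uniformly among the *other* classes; some papers allow keeping
# the same class) can be sketched as:
import numpy as np


def symmetric_noisy_sketch(train_y: np.ndarray, ratio: float, n_classes: int,
                           seed: int = 1) -> np.ndarray:
    rng = np.random.RandomState(seed)
    noisy_y = train_y.copy()
    idx = rng.choice(len(train_y), int(len(train_y) * ratio), replace=False)
    for i in idx:
        # draw from the other classes so the label actually changes
        noisy_y[i] = rng.choice([c for c in range(n_classes) if c != train_y[i]])
    return noisy_y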
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
        n_mask = nys != ys
        if n_mask.any():
            self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

        return meter

    def to_logits(self, xs) -> torch.Tensor:
        return self.model(xs)


if __name__ == '__main__':
    params = NoisyParams()
    params.optim.args.lr = 0.06
    params.epoch = 400
    params.device = 'cuda:0'
    params.filter_ema = 0.999
    params.echange = False  # swap the dataset split every epoch
    params.change = True
    params.noisy_ratio = 0.8
    params.from_args()
    params.initial()
    trainer = MixupEvalTrainer(params)
    trainer.train()
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

        self.acc_precise_(logits.argmax(dim=1), ys, meter=meter, name='true_acc')
        self.acc_precise_(logits.argmax(dim=1), nys, meter=meter, name='noisy_acc')
        return meter


if __name__ == '__main__':
    params = NoisyParams()
    params.ema = True  # L2R itself uses no EMA for the model
    params.epoch = 120
    params.batch_size = 100
    params.device = 'cuda:3'
    params.optim.args.lr = 0.1
    params.meta_optim = {
        'lr': 0.1,
        'momentum': 0.9,
    }
    params.from_args()
    trainer = L2RTrainer(params)
    trainer.train()
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    logits = self.to_logits(axs)
    w_logits = self.to_logits(xs)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
    weight = label_pred - self.target_mem[ids]
    weight = weight + label_pred * 0.5 / params.n_classes - 0.375 / params.n_classes

    fweight = torch.ones_like(weight)
    if eidx >= params.burnin:
        fweight -= self.noisy_cls[ids]
        fweight[weight <= 0] -= params.gmm_w_sche(eidx)
        fweight = torch.relu(fweight)

    if params.right_n > 0:
        right_mask = ids < params.right_n
        if right_mask.any():
            fweight[right_mask] = 1

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets
    values, p_labels = targets.max(dim=1)
    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
    mask = mask.float()

    # uda_mask = label_pred > params.uda_sche(eidx)
    # meter.uda = uda_mask.float().mean()

    meter.m0 = (fweight == 0).float().mean()
    meter.m1 = (fweight == 1).float().mean()
    meter.pm = mask.mean()

    meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, nys, fweight, meter=meter)
    meter.Lall = meter.Lall + self.loss_ce_with_masked_(
        logits, p_labels, (1 - fweight) * mask,
        meter=meter, name='Lpce') * params.plabel_sche(eidx)

    meter.tw = fweight[ys == nys].mean()
    meter.fw = fweight[ys != nys].mean()

    with torch.no_grad():
        ids_mask = weight.bool()
        alpha = params.filter_ema
        if eidx < params.burnin:
            alpha = 0.99
        self.target_mem[ids[ids_mask]] = (self.target_mem[ids[ids_mask]] * alpha +
                                          label_pred[ids_mask] * (1 - alpha))

    if 'Lall' in meter:
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = p_labels

    return meter
                self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                self.noisy_cls = self.noisy_cls_mem.clone()
                self.noisy_cls[self.noisy_cls < 0.5] = 0

                # As training goes on, samples that remain hard to separate should be dropped
                # outright rather than fuzzily re-weighted back and forth (perhaps).
                # Note: boolean-mask indexing returns a copy, so the clamp must be written back.
                high = self.noisy_cls >= 0.5
                self.noisy_cls[high] = self.noisy_cls[high].clamp_min(params.gmm_w_sche(params.eidx))


if __name__ == '__main__':
    # for retry, params in enumerate(params.grid_range(5)):
    # for retry in range(5):
    #     retry = retry + 1
    params = NoisyParams()
    params.right_n = 10
    params.use_right_label = False
    params.optim.args.lr = 0.06
    params.epoch = 300
    params.device = 'cuda:2'
    params.filter_ema = 0.999
    params.burnin = 2
    params.gmm_burnin = 20
    params.targets_ema = 0.3
    # params.tolerance_type = 'exp'
    params.pred_thresh = 0.9
    # params.widen_factor = 10
    params.from_args()
    params.initial()
    trainer = MultiHeadTrainer(params)
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    logits = self.to_logits(axs)
    w_logits = self.to_logits(xs)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
    if params.local_filter:
        weight = label_pred - self.target_mem[ids]
        weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
        weight_mask = weight < 0
        meter.tl = weight_mask[ys == nys].float().mean()
        meter.fl = weight_mask[ys != nys].float().mean()

    fweight = torch.ones(w_logits.shape[0], dtype=torch.float, device=device)
    if eidx >= params.burnin:
        if params.local_filter:
            fweight[weight_mask] -= params.gmm_w_sche(eidx)
        fweight -= self.noisy_cls[ids]
        fweight = torch.relu(fweight)

    if params.right_n > 0:
        right_mask = ids < params.right_n
        if right_mask.any():
            fweight[right_mask] = 1

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets
        # targets = self.sharpen_(targets)
        # targets = raw_targets

    n_targets = tricks.onehot(nys, params.n_classes)
    t_weight = fweight.unsqueeze(dim=1)
    # cur_targets = n_targets * t_weight + targets * (1 - t_weight)
    cur_ys = targets.argmax(dim=-1)
    # cur_targets = tricks.onehot(cur_ys, params.n_classes)

    id_dict = defaultdict(list)
    for i, ny in enumerate(nys):
        id_dict[int(ny)].append(i)
    # guarantee that every pooled index carries a noisy label
    for i in range(params.n_classes):
        id_dict[i] = cycle(id_dict[i])

    re_id = []
    for i, cur_y in enumerate(cur_ys):
        cur_y = int(cur_y)
        try:
            # If the original label is already clean, it does not matter what gets
            # mixed in (the clean label dominates); otherwise, mix in a sample whose
            # noisy label matches the current prediction.
            re_id.append(next(id_dict.get(cur_y)))
        except (StopIteration, TypeError):  # empty or missing pool for this class
            re_id.append(np.random.randint(0, len(nys)))

    mixed_input, mixed_target = self.mixup_(xs, n_targets, reids=re_id)
    mixed_logits = self.to_logits(mixed_input)

    meter.Lall = meter.Lall + self.loss_ce_with_targets_(mixed_logits, mixed_target, meter=meter)

    meter.tw = fweight[ys == nys].mean()
    meter.fw = fweight[ys != nys].mean()

    if params.local_filter:
        with torch.no_grad():
            ids_mask = weight_mask.logical_not()
            alpha = params.filter_ema
            if eidx < params.burnin:
                alpha = 0.99
            self.target_mem[ids[ids_mask]] = (self.target_mem[ids[ids_mask]] * alpha +
                                              label_pred[ids_mask] * (1 - alpha))
            # Gradually regress the samples that did not take part back (toward the prior).
            # self.target_mem[ids[weight_mask]] = (self.target_mem[ids[weight_mask]] * alpha +
            #                                      (1 / params.n_classes) * (1 - alpha))

    if 'Lall' in meter:
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = cur_ys

    return meter
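# --- illustration (not part of the trainer) ---
# `self.mixup_` is a trainer helper; its call site suggests mixing the batch
# with itself re-indexed by `reids`, with a Beta-sampled coefficient kept on
# the original sample's side. A hedged stand-in (the beta parameter is an
# assumption, not taken from this codebase):
import numpy as np
import torch


def mixup_sketch(xs: torch.Tensor, targets: torch.Tensor, reids, beta: float = 0.75):
    lam = float(np.random.beta(beta, beta))
    lam = max(lam, 1 - lam)  # keep the original sample dominant
    mixed_x = xs * lam + xs[reids] * (1 - lam)
    mixed_t = targets * lam + targets[reids] * (1 - lam)
    return mixed_x, mixed_t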
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    logits = self.to_logits(axs)
    w_logits = self.to_logits(xs)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    weight = label_pred - self.target_mem[ids]
    if params.tolerance_type == 'linear':
        weight = weight + label_pred * 0.5 / params.n_classes - 0.375 / params.n_classes
    elif params.tolerance_type == 'exp':
        exp_ratio = (torch.exp((self.target_mem[ids] - 1) * 5) -
                     np.exp(-5) * (1 - self.target_mem[ids]))
        weight = weight + (params.tol_start * (1 - exp_ratio) + params.tol_end * exp_ratio)
    elif params.tolerance_type == 'log':
        log_ratio = 1 - torch.exp(-self.target_mem[ids] * 5) + np.exp(-5) * self.target_mem[ids]
        weight = weight + (params.tol_start * (1 - log_ratio) + params.tol_end * log_ratio)

    # weight[self.noisy_cls[ids] == 0] -= params.gmm_sche(eidx)
    weight[self.noisy_cls[ids] == 0] = 0

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets
    values, p_labels = targets.max(dim=1)
    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')

    weight[weight > 0] = 1
    if eidx < params.burnin:
        weight = torch.ones_like(weight)

    meter.m0 = (weight == 0).float().mean()
    meter.pm = mask.float().mean()

    weight_mask = weight.bool()
    pweight_mask = weight_mask.logical_not() & mask
    if weight_mask.any():
        meter.Lall = meter.Lall + self.loss_ce_(logits[weight_mask], nys[weight_mask], meter=meter)
    if pweight_mask.any():
        meter.Lall = meter.Lall + self.loss_ce_(
            logits[pweight_mask], p_labels[pweight_mask],
            meter=meter, name='Lpce') * params.plabel_sche(eidx)

    meter.tw = weight[ys == nys].mean()
    meter.fw = weight[ys != nys].mean()

    with torch.no_grad():
        ids_mask = weight.bool()
        alpha = params.filter_ema
        if eidx < params.burnin:
            alpha = 0.99
        self.target_mem[ids[ids_mask]] = (self.target_mem[ids[ids_mask]] * alpha +
                                          label_pred[ids_mask] * (1 - alpha))

    if 'Lall' in meter:
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = p_labels

    return meter
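# --- illustration (not part of the trainer) ---
# The three tolerance branches above map the per-sample memory t = target_mem[ids]
# (roughly in [0, 1]) to an additive slack on `weight`. The same curves as a
# standalone function, convenient for inspecting them outside the training loop:
import numpy as np


def tolerance_sketch(t: np.ndarray, kind: str, n_classes: int,
                     tol_start: float, tol_end: float,
                     label_pred: np.ndarray = None) -> np.ndarray:
    if kind == 'linear':
        return label_pred * 0.5 / n_classes - 0.375 / n_classes
    if kind == 'exp':
        ratio = np.exp((t - 1) * 5) - np.exp(-5) * (1 - t)
        return tol_start * (1 - ratio) + tol_end * ratio
    if kind == 'log':
        ratio = 1 - np.exp(-t * 5) + np.exp(-5) * t
        return tol_start * (1 - ratio) + tol_end * ratio
    raise ValueError(kind)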
        m = self.acc_mixture_(true_cls, noisy_cls)
        meter.update(m)
        self.logger.info(m)
        if params.eidx > params.gmm_burnin:
            self.noisy_cls = torch.tensor(noisy_cls, device=self.device)

    def to_logits(self, xs) -> torch.Tensor:
        return self.model(xs)


if __name__ == '__main__':
    # for retry, params in enumerate(params.grid_range(5)):
    # for retry in range(5):
    #     retry = retry + 1
    params = NoisyParams()
    params.right_n = 10
    params.optim.args.lr = 0.06
    params.epoch = 400
    params.device = 'cuda:0'
    params.filter_ema = 0.999
    params.burnin = 2
    params.gmm_burnin = 10
    params.targets_ema = 0.3
    params.pred_thresh = 0.9
    params.from_args()
    params.initial()
    trainer = MultiHeadTrainer(params)
    if params.ss_pretrain: