def train_batch(self, eidx, idx, global_step, batch_data, params: _trainer_name_Param, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()
    model = self.model
    optim = self.optim

    xs, ys = batch_data
    xs, ys = xs.to(device), ys.to(device)

    logits = model(xs)
    meter.loss = self.lossf(logits, ys)

    optim.zero_grad()
    meter.loss.backward()
    optim.step()
    return meter

def meta_optimizer(self, xs: torch.Tensor, vxs: torch.Tensor, vys: torch.Tensor, meter: Meter):
    """
    First take one update step on the training data with the initial weights, wrapping
    the weights inside MetaNet; then compute the validation-set gradient and finally
    differentiate the updated parameters with respect to the initial weights.
    :param xs: training batch inputs
    :param vxs: validation batch inputs
    :param vys: validation batch labels
    :param meter: metric recorder
    :return:
    """
    metanet, metasgd = self.create_metanet()
    metanet.zero_grad()

    mid_logits = metanet(xs)

    cls_center = autograd.Variable(self.cls_center, requires_grad=True)
    left, right = tricks.cartesian_product(mid_logits, cls_center)
    dist_ = F.pairwise_distance(left, right).reshape(mid_logits.shape[0], -1)
    dist_targets = torch.softmax(dist_, dim=-1)

    dist_loss = self.loss_ce_with_targets_(metanet.fc(mid_logits), dist_targets)
    var_grads = autograd.grad(dist_loss, metanet.params(), create_graph=True)
    # metanet.update_params(0.1, var_grads)
    metasgd.meta_step(var_grads)

    m_v_logits = metanet.fc(metanet(vxs))  # type:torch.Tensor
    meta_loss = self.loss_ce_(m_v_logits, vys)

    # method A
    # grad_meta_vars = autograd.grad(meta_loss, metanet.params(), create_graph=True)
    # grad_target, grad_eps = autograd.grad(
    #     metanet.params(), [cls_center, eps_k], grad_outputs=grad_meta_vars)

    # method B
    grad_target, = autograd.grad(meta_loss, [cls_center])

    with torch.no_grad():
        self.cls_center = self.cls_center - grad_target * 0.4

    self.acc_precise_(m_v_logits.argmax(dim=-1), vys, meter=meter, name='Macc')
    meter.LMce = meta_loss.detach()

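# A minimal, self-contained sketch of the meta-gradient pattern used by
# meta_optimizer above (in the spirit of learning-to-reweight / meta-gradient
# methods): one differentiable inner SGD step whose loss depends on a
# hyper-parameter, then a validation loss differentiated back through the
# updated weights. The tiny linear model, step sizes, and shapes below are
# illustrative assumptions, not the repo's MetaNet/MetaSGD implementation.
import torch
import torch.nn.functional as F
from torch import autograd

cls_center = torch.randn(10, 8, requires_grad=True)  # hyper-parameter (leaf)
w = torch.randn(10, 8, requires_grad=True)           # stand-in for the MetaNet weights
xs, vxs = torch.randn(32, 8), torch.randn(16, 8)     # train / validation inputs
vys = torch.randint(0, 10, (16,))                    # validation labels

# inner loss: soft targets derived from distances to cls_center
inner_logits = xs @ w.t()
dist_targets = torch.softmax(torch.cdist(xs, cls_center), dim=-1)
inner_loss = -(dist_targets * F.log_softmax(inner_logits, dim=-1)).sum(dim=-1).mean()

# one differentiable SGD step (the metasgd.meta_step(var_grads) analogue)
g_w, = autograd.grad(inner_loss, w, create_graph=True)
w_prime = w - 0.1 * g_w

# validation loss on the updated weights; differentiate back to cls_center
meta_loss = F.cross_entropy(vxs @ w_prime.t(), vys)
grad_center, = autograd.grad(meta_loss, [cls_center])
with torch.no_grad():
    cls_center -= grad_center * 0.4                  # mirrors the 0.4 step above
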
def on_train_epoch_end(self, trainer: 'NoisyTrainer', func, params: GmaParams, meter: Meter, *args, **kwargs):
    filter_mem = os.path.join(self.experiment.test_dir, 'filter_mem_{}.pth'.format(params.eidx))
    local_noisy_cls = os.path.join(self.experiment.test_dir, 'local_noisy_cls_{}.pth'.format(params.eidx))
    noisy_cls = os.path.join(self.experiment.test_dir, 'noisy_cls_{}.pth'.format(params.eidx))
    false_pred_mem = os.path.join(self.experiment.test_dir, 'false_pred_mem.pth')

    torch.save(self.filter_mem, filter_mem)
    torch.save(self.local_noisy_cls, local_noisy_cls)
    torch.save(self.noisy_cls, noisy_cls)
    torch.save(self.false_pred_mem, false_pred_mem)

    with torch.no_grad():
        id_ = 0  # max(0, params.eidx - 5)
        self.logger.info('f_mean left id', id_)
        f_mean = self.false_pred_mem[:, id_:params.eidx].mean(dim=1).cpu().numpy()
        f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
        feature = self.create_feature(f_mean, f_cur)
        noisy_cls = self.bmm_predict(feature, mean=params.feature_mean, offset=params.offset_sche(params.eidx))
        # noisy_cls = self.gmm_predict(feature)

        self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.2 + self.noisy_cls_mem * 0.8
        if params.eidx <= 40 and False:  # disabled branch, kept for reference
            self.noisy_cls = torch.tensor(noisy_cls, device=self.device)
            self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device)
        else:
            self.noisy_cls = self.noisy_cls_mem.clone()

        meter.mm = (self.noisy_cls > 0.5).float().mean()
        self.logger.info('mm raw ratio', (noisy_cls > 0.5).mean(), 'mm ratio', meter.mm)

        if params.eidx > params.burnin:
            # self.filter_mem = self.local_noisy_cls - self.noisy_cls
            self.filter_mem = torch.ones(self.train_size, dtype=torch.float, device=self.device) - self.noisy_cls
            self.logger.info('filter', (self.filter_mem > 0.5).float().mean())

        values, _ = self.plabel_mem.max(dim=-1)
        self.logger.info('plabel filter', (values > params.pred_thresh).float().mean())

def train_batch(self, eidx, idx, global_step, batch_data, params: SupConParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()

    _, xs, axs, ys = batch_data  # type:torch.Tensor
    b, c, h, w = xs.size()

    # stack the two augmented views: (b, 2, c, h, w) -> (2b, c, h, w)
    input_ = torch.cat([xs.unsqueeze(1), axs.unsqueeze(1)], dim=1)
    input_ = input_.view(-1, c, h, w)
    features = self.to_logits(input_).view(b, 2, -1)

    meter.Lall = meter.Lall + self.loss_supcon_(features, ys, temperature=params.temperature, meter=meter)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()
    return meter

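# A hedged sketch of the supervised contrastive loss presumably computed by
# self.loss_supcon_ above (SupCon, Khosla et al. 2020). `features` has shape
# (batch, n_views, dim) exactly as built in train_batch; the masking and
# normalization below follow the paper and may differ from this repo's code.
import torch
import torch.nn.functional as F

def supcon_loss(features: torch.Tensor, labels: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    b, v, d = features.shape
    feats = F.normalize(features, dim=-1).reshape(b * v, d)
    labels = labels.repeat_interleave(v)                      # one label per view

    sim = feats @ feats.t() / temperature
    sim = sim - sim.max(dim=1, keepdim=True).values.detach()  # numerical stability

    self_mask = torch.eye(b * v, dtype=torch.bool, device=feats.device)
    pos_mask = labels.unsqueeze(0).eq(labels.unsqueeze(1)) & ~self_mask

    exp_sim = torch.exp(sim).masked_fill(self_mask, 0.)
    log_prob = sim - exp_sim.sum(dim=1, keepdim=True).log()

    # average log-likelihood over each anchor's positives
    mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / pos_mask.sum(dim=1).clamp_min(1)
    return -mean_log_prob_pos.mean()
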
def train_batch(self, eidx, idx, global_step, batch_data, params: GlobalParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()

    xs, ys = batch_data
    logits = self.to_logits(xs)
    meter.Lall = meter.Lall + self.loss_ce_(logits, ys, meter=meter, name='Lce')
    self.any_()

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')
    return meter

def warmup_model(self, batch_data, model, optim, meter: Meter):
    (ids, xs, nys) = batch_data  # type:torch.Tensor
    optim.zero_grad()
    logits = model(xs)  # type:torch.Tensor
    meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter, name='Lce')
    self.acc_precise_(logits.argmax(dim=-1), nys, meter=meter, name='acc')
    meter.Lall.backward()
    optim.step()

def train_batch(self, eidx, idx, global_step, batch_data, params: GmaParams, device: torch.device):
    meter = Meter()
    if eidx <= 10 or self.ssl_dataloader is None:
        self.warn_up(eidx, batch_data, meter)
    else:
        try:
            batch_data = next(self.ssl_loaderiter)
        except StopIteration:
            self.ssl_loaderiter = iter(self.ssl_dataloader)
            batch_data = next(self.ssl_loaderiter)
        self.ssl(eidx, batch_data, meter)
    return meter

def test_eval_logic(self, dataloader, param: Params):
    from thexp.calculate import accuracy as acc

    param.topk = param.default([1, 5])
    with torch.no_grad():
        count_dict = Meter()
        for xs, labels in dataloader:
            preds = self.predict(xs)
            total, topk_res = acc.classify(preds, labels, topk=param.topk)
            count_dict["total"] += total
            for i, topi_res in zip(param.topk, topk_res):
                count_dict["top{}".format(i)] += topi_res
    return count_dict

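# Hedged sketch of what thexp.calculate.accuracy.classify is assumed to
# compute above: for each k in topk, how many samples have the true label
# among the k highest-scoring classes. The helper name is illustrative.
import torch

def topk_correct(preds: torch.Tensor, labels: torch.Tensor, topk=(1, 5)):
    maxk = max(topk)
    _, idx = preds.topk(maxk, dim=1)            # (batch, maxk) class indices
    hits = idx.eq(labels.unsqueeze(1))          # broadcast compare against labels
    total = labels.size(0)
    return total, [hits[:, :k].any(dim=1).sum().item() for k in topk]

# usage: total, (top1, top5) = topk_correct(preds, labels)
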
def on_train_epoch_end(self, trainer: 'NoisyTrainer', func, params: NoisyParams, meter: Meter, *args, **kwargs):
    with torch.no_grad():
        f_mean = self.false_pred_mem[:, :params.eidx].mean(dim=1).cpu().numpy()
        f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
        feature = self.create_feature(f_mean, f_cur)
        noisy_cls = self.bmm_predict(feature, mean=params.feature_mean)

        true_cls = (self.true_pred_mem == self.false_pred_mem).all(dim=1).cpu().numpy()
        m = self.acc_mixture_(true_cls, noisy_cls)
        meter.update(m)
        self.logger.info(m)

        if params.eidx > params.mix_burnin:
            if params.mixt_ema:
                self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                self.noisy_cls = self.noisy_cls_mem.clone()
            else:
                self.noisy_cls = torch.tensor(noisy_cls, device=self.device)

def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    # Note: ys are the ground-truth labels of the training set; they are used only
    # to report accuracy, never in the training procedure itself.
    (ids, xs, axs, ys, nys), (vxs, vys) = batch_data

    w_logits = self.model(xs)
    logits = self.model(axs)  # type:torch.Tensor
    p_targets = self.sharpen_(torch.softmax(w_logits, dim=1))

    weight = self.meta_optimizer(xs, nys, vxs, vys, meter=meter)

    meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, nys, weight, meter=meter, name='Lwce')
    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(logits, p_targets, (1 - weight), meter=meter, name='Lwce')
    meter.tw = weight[ys == nys].mean()
    meter.fw = weight[ys != nys].mean()

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter=meter, name='true_acc')
    self.acc_precise_(logits.argmax(dim=1), nys, meter=meter, name='noisy_acc')
    return meter

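# Hedged sketch of the two masked losses used above: per-sample cross entropy
# re-weighted by the meta-learned weight, for hard labels and for soft
# targets. These mirror loss_ce_with_masked_ / loss_ce_with_targets_masked_
# in intent only; the actual trainer mixins may differ.
import torch
import torch.nn.functional as F

def ce_with_weight(logits: torch.Tensor, labels: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    loss = F.cross_entropy(logits, labels, reduction='none')
    return (loss * weight).mean()

def soft_ce_with_weight(logits: torch.Tensor, targets: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    loss = -(targets * F.log_softmax(logits, dim=-1)).sum(dim=-1)
    return (loss * weight).mean()
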
def unsupervised_loss(self, xs: torch.Tensor, axs: torch.Tensor, vxs: torch.Tensor, vys: torch.Tensor,
                      logits_lis: List[torch.Tensor], meter: Meter):
    """create Lub, Lpb, Lkl"""
    logits_lis = [self.logit_norm_(logits) for logits in logits_lis]
    p_target = self.label_guesses_(*logits_lis)
    p_target = self.sharpen_(p_target, params.T)

    re_v_targets = tricks.onehot(vys, params.n_classes)
    mixed_input, mixed_target = self.mixmatch_up_(vxs, [axs], re_v_targets, p_target, beta=params.mix_beta)

    mixed_logits = self.to_logits(mixed_input)
    mixed_logits_lis = mixed_logits.split_with_sizes([vxs.shape[0], axs.shape[0]])
    (mixed_v_logits, mixed_nn_logits) = [self.logit_norm_(l) for l in mixed_logits_lis]  # type:torch.Tensor
    # mixed_nn_logits = torch.cat([mixed_n_logits, mixed_an_logits], dim=0)

    mixed_v_targets, mixed_nn_targets = mixed_target.split_with_sizes(
        [mixed_v_logits.shape[0], mixed_nn_logits.shape[0]])

    # Lpβ: the validation set acts as the labeled split of the semi-supervised task
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(
        mixed_v_logits, mixed_v_targets, meter=meter, name='Lpb') * params.semi_sche(params.eidx)

    # p * Luβ: the training set acts as the unlabeled split of the semi-supervised task
    if params.lub:
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_nn_logits, mixed_nn_targets, meter=meter, name='Lub') * params.semi_sche(params.eidx)

    # Lkl: consistency loss across multiple augmentations
    return p_target

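# A minimal sketch of the MixMatch-style mix-up assumed by self.mixmatch_up_
# above (Berthelot et al. 2019): concatenate the labeled and unlabeled
# batches, draw lambda from Beta(beta, beta), and keep lambda >= 0.5 so every
# mixed sample stays dominated by its own batch. The signature is an assumption.
import numpy as np
import torch

def mixmatch_up(sup_xs, un_xs, sup_targets, un_targets, beta=0.75):
    xs = torch.cat([sup_xs, un_xs], dim=0)
    targets = torch.cat([sup_targets, un_targets], dim=0)
    lam = np.random.beta(beta, beta)
    lam = max(lam, 1 - lam)                          # original sample stays dominant
    perm = torch.randperm(xs.size(0), device=xs.device)
    mixed_xs = lam * xs + (1 - lam) * xs[perm]
    mixed_targets = lam * targets + (1 - lam) * targets[perm]
    return mixed_xs, mixed_targets
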
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    logits = self.to_logits(xs)
    meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    preds_ = logits.detach().clone()  # torch.softmax(logits, dim=-1)
    # old_preds = self.pred_mem[ids]
    # residual = preds - old_preds

    _mask = torch.arange(params.n_classes).unsqueeze(0).repeat([nys.shape[0], 1]).to(device)
    mask = _mask[_mask == nys.unsqueeze(1)].view(nys.shape[0], -1)
    # residual = residual.gather(1, mask)
    # old_preds = torch.softmax(old_preds.gather(1, mask), dim=-1)
    # preds = torch.softmax(preds_.gather(1, mask), dim=-1)
    # residual = preds - old_preds
    # residual = torch.pow(residual, 2)

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

    self.pred_mem[ids] += preds_.detach()
    # self.sum_mem[ids] += residual.detach()
    # if eidx == 1:
    # else:
    #     self.pred_mem[ids] = self.pred_mem[ids] * 0.9 + logits.detach() * 0.1

    # meter.nres = residual[n_mask].mean()
    # meter.nsum = self.sum_mem[ids][n_mask].mean()
    # meter.tres = residual[n_mask.logical_not()].mean()
    # meter.tsum = self.sum_mem[ids][n_mask.logical_not()].mean()
    self.acc_precise_(self.pred_mem[ids].argmax(dim=1), ys, meter, name='sacc')
    return meter

def train_batch(self, eidx, idx, global_step, batch_data, params: CoRandomParams, device: torch.device):
    meter = Meter()
    if eidx < params.warm_up:
        if eidx % 2 == 0:
            self.warmup_model(batch_data, self.model, self.optim, self.probs1, meter)
        else:
            self.warmup_model(batch_data, self.model2, self.optim2, self.probs2, meter)
    else:
        if eidx % 2 == 0:
            self.train_model(batch_data, self.model, self.optim, self.pys2, self.probs1, meter)
        else:
            self.train_model(batch_data, self.model2, self.optim2, self.pys1, self.probs2, meter)
    return meter

def meter(self, ratio=True):
    from thexp import Meter

    meter = Meter()
    for key, offset in self.times.items():
        if ratio:
            if isinstance(offset, str):
                continue
            meter[key] = offset / self.times['use']
        else:
            meter[key] = offset
    return meter

def train_batch(self, eidx, idx, global_step, batch_data, params: OffsetParams, device: torch.device):
    meter = Meter()
    if params.epoch - eidx > params.offset_epoch:
        self.train_first(eidx, self.model, self.optim, batch_data, meter)
    if eidx - params.offset_epoch > 0:
        if params.epoch - eidx > params.offset_epoch:
            self.train_second(eidx, self.model2, batch_data, meter)
        else:
            self.train_first(eidx, self.model2, self.optim2, batch_data, meter)
    return meter

def train_batch(self, eidx, idx, global_step, batch_data, params: MentorNetParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    # Basic parameters for MentorNet (currently unused placeholders)
    _epoch_step = params.epoch_step(eidx)
    _zero_labels = torch.zeros_like(nys)
    _loss_p_percentile = torch.ones(100, dtype=torch.float) * params.loss_p_percentile
    _dropout_rates = self.parse_dropout_rate_list_()

    logits = self.to_logits(xs)
    _basic_losses = F.cross_entropy(logits, ys, reduction='none').detach()
    weight = torch.rand_like(_basic_losses).detach()  # TODO replace with mentornet

    ntargets = tricks.onehot(nys, params.n_classes)
    mixed_xs, mixed_targets = self.mentor_mixup_(xs, ntargets, weight.detach().cpu().numpy())
    mixed_logits = self.to_logits(mixed_xs)
    _mixed_losses = -torch.sum(F.log_softmax(mixed_logits, dim=1) * mixed_targets, dim=1)
    mix_weight = torch.rand_like(_basic_losses).detach()  # TODO replace with mentornet

    meter.Lall = torch.mean(_mixed_losses * mix_weight)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')
    return meter

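# The rand_like weights above are placeholders (see the TODOs). A minimal
# self-paced stand-in for the missing MentorNet: weight 1 for samples whose
# loss falls below the current percentile threshold, 0 otherwise. The helper
# name is hypothetical; MentorNet proper learns this weighting from data.
import torch

def self_paced_weight(losses: torch.Tensor, p_percentile: float) -> torch.Tensor:
    thresh = torch.quantile(losses, p_percentile)    # p_percentile in [0, 1]
    return (losses <= thresh).float().detach()

# e.g. weight = self_paced_weight(_basic_losses, params.loss_p_percentile)
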
def on_train_epoch_end(self, trainer: 'NoisyTrainer', func, params: NoisyParams, meter: Meter, *args, **kwargs):
    true_f = os.path.join(self.experiment.test_dir, 'true.pth')
    false_f = os.path.join(self.experiment.test_dir, 'false.pth')
    loss_f = os.path.join(self.experiment.test_dir, 'loss.pth')
    cls_f = os.path.join(self.experiment.test_dir, 'cls.pth')
    if params.eidx % 10 == 0:
        torch.save(self.true_pred_mem, true_f)
        torch.save(self.false_pred_mem, false_f)
        torch.save(self.loss_mem, loss_f)
        torch.save(self.cls_mem, cls_f)

    with torch.no_grad():
        f_mean = self.false_pred_mem[:, max(params.eidx - params.gmm_burnin, 0):params.eidx].mean(dim=1).cpu().numpy()
        f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
        feature = self.create_feature(f_mean, f_cur)
        noisy_cls = self.gmm_predict(feature)

        true_cls = (self.true_pred_mem == self.false_pred_mem).all(dim=1).cpu().numpy()
        m = self.acc_mixture_(true_cls, noisy_cls)
        meter.update(m)
        self.logger.info(m)

        if params.eidx > params.gmm_burnin:
            self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
            self.noisy_cls = self.noisy_cls_mem.clone()
            self.noisy_cls[self.noisy_cls < 0.5] = 0
            # As training goes on, samples that stay hard to separate should be dropped
            # outright rather than endlessly re-weighted back and forth (probably).
            high = self.noisy_cls >= 0.5
            self.noisy_cls[high] = self.noisy_cls[high].clamp_min(params.gmm_w_sche(params.eidx))

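# Hedged sketch of the two-component mixture step behind gmm_predict /
# bmm_predict (cf. DivideMix, Li et al. 2020): fit a two-component mixture on
# per-sample prediction statistics and read the posterior of the noisy
# component. sklearn-based and illustrative; the repo's own fit (and which
# component counts as "noisy") may differ.
import numpy as np
from sklearn.mixture import GaussianMixture

def gmm_predict(feature: np.ndarray) -> np.ndarray:
    feature = feature.reshape(len(feature), -1)      # (N, D)
    gmm = GaussianMixture(n_components=2, max_iter=100, reg_covar=5e-4)
    gmm.fit(feature)
    posts = gmm.predict_proba(feature)               # (N, 2) posteriors
    # assume the component with the lower mean on the first feature dim is the
    # noisy one (low probability assigned to the given label)
    noisy_comp = int(np.argmin(gmm.means_[:, 0]))
    return posts[:, noisy_comp]
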
def on_train_epoch_end(self, trainer: 'NoisyTrainer', func, params: GmaParams, meter: Meter, *args, **kwargs):
    true_f = os.path.join(self.experiment.test_dir, 'true.pth')
    false_f = os.path.join(self.experiment.test_dir, 'false.pth')
    loss_f = os.path.join(self.experiment.test_dir, 'loss.pth')
    if params.eidx % 10 == 0:
        torch.save(self.true_pred_mem, true_f)
        torch.save(self.false_pred_mem, false_f)
        torch.save(self.loss_mem, loss_f)

    with torch.no_grad():
        f_mean = self.false_pred_mem[:, :params.eidx].mean(dim=1).cpu().numpy()
        f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
        feature = self.create_feature(f_mean, f_cur)
        noisy_cls = self.bmm_predict(feature, mean=params.feature_mean, offset=params.offset_sche(params.eidx))

        if params.eidx > 1:
            self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.1 + self.noisy_cls_mem * 0.9

        true_cls = (self.true_pred_mem == self.false_pred_mem).all(dim=1).cpu().numpy()
        m = self.acc_mixture_(true_cls, self.noisy_cls_mem.cpu().numpy())
        meter.update(m)
        self.logger.info(m)

        if params.eidx > params.mix_burnin:
            if params.mixt_ema:
                self.noisy_cls = self.noisy_cls_mem.clone()
            else:
                self.noisy_cls = torch.tensor(noisy_cls, device=self.device)

def on_train_epoch_end(self, trainer: 'NoisyTrainer', func, params: NoisyParams, meter: Meter, *args, **kwargs):
    with torch.no_grad():
        f_mean = self.false_pred_mem[:, :params.eidx].mean(dim=1).cpu().numpy()
        f_cur = self.false_pred_mem[:, params.eidx - 1].cpu().numpy()
        feature = self.create_feature(f_mean, f_cur)
        noisy_cls = self.bmm_predict(feature, mean=params.feature_mean)

        true_cls = (self.true_pred_mem == self.false_pred_mem).all(dim=1).cpu().numpy()
        m = self.acc_mixture_(true_cls, noisy_cls)
        meter.update(m)
        self.logger.info(m)

        if params.eidx > params.mix_burnin:
            if params.mixt_ema:
                self.noisy_cls_mem = torch.tensor(noisy_cls, device=self.device) * 0.1 + self.noisy_cls_mem * 0.9
                self.noisy_cls = self.noisy_cls_mem.clone()
            else:
                self.noisy_cls = torch.tensor(noisy_cls, device=self.device)
            # self.noisy_cls[self.noisy_cls < 0.5] = 0
            # As training goes on, samples that stay hard to separate should be dropped
            # outright rather than endlessly re-weighted back and forth (probably).
            # self.noisy_cls[self.noisy_cls >= 0.5].clamp_min_(params.gmm_w_sche(params.eidx))

        m2 = self.acc_mixture_(true_cls, (self.count_mem >= 0).cpu().numpy(), pre='con')
        meter.update(m2)
        self.logger.info(m2)

def train_batch(self, eidx, idx, global_step, batch_data, params: ICTParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()

    sup, unsup = batch_data
    xs, ys = sup
    _, un_xs, _, un_ys = unsup

    logits_list = self.to_logits(torch.cat([xs, un_xs])).split_with_sizes([xs.shape[0], un_xs.shape[0]])
    logits, un_logits = logits_list  # type:torch.Tensor

    mixed_xs, ys_a, ys_b, lam_sup = self.ict_mixup_(xs, ys)
    logits = self.to_logits(xs)
    mixed_logits = self.to_logits(mixed_xs)

    ema_un_logits = self.predict(un_xs)
    mixed_un_xs, ema_mixed_un_logits, _ = self.mixup_unsup_(un_xs, ema_un_logits, mix=True)
    mixed_un_logits = self.to_logits(mixed_un_xs)

    meter.Lall = meter.Lall + self.loss_mixup_sup_ce_(mixed_logits, ys_a, ys_b, lam_sup, meter=meter)
    meter.Lall = meter.Lall + self.loss_mixup_unsup_mse_(mixed_un_logits, ema_un_logits,
                                                         decay=params.mixup_consistency_sche(eidx),
                                                         meter=meter)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter)
    self.acc_precise_(un_logits.argmax(dim=1), un_ys, meter, name='unacc')
    return meter

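# Hedged sketch of the mixup variant assumed by self.ict_mixup_ above
# (ICT, Verma et al. 2019): mix the inputs only and return both label sets
# plus the mixing coefficient, so the supervised loss can be interpolated as
# lam * CE(logits, ys_a) + (1 - lam) * CE(logits, ys_b). Names are assumptions.
import numpy as np
import torch

def ict_mixup(xs: torch.Tensor, ys: torch.Tensor, alpha: float = 1.0):
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(xs.size(0), device=xs.device)
    mixed_xs = lam * xs + (1 - lam) * xs[perm]
    return mixed_xs, ys, ys[perm], lam
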
def train_batch(self, eidx, idx, global_step, batch_data, params: MetaSSLParams, device: torch.device):
    meter = Meter()

    sup, unsup = batch_data
    xs, ys = sup
    ids, un_xs, un_axs, un_ys = unsup

    if eidx < 2:
        logits = self.to_logits(xs)
        meter.Lall = meter.Lall + self.loss_ce_(logits, ys, meter=meter, name='Lce')
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')
    else:
        un_logits = self.to_logits(un_xs)
        weight = self.meta_optimizer(un_xs, xs, ys, meter=meter)

        # meter.Lall = meter.Lall + self.loss_ce_(self.to_logits(xs), ys, meter=meter, name='Lce')
        meter.Lall = meter.Lall + self.loss_ce_with_masked_(
            un_logits, un_logits.argmax(dim=1), weight, meter=meter, name='Ldce')  # Lw*
        self.acc_precise_(un_logits.argmax(dim=1), un_ys, meter, name='un_acc')

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    # self.acc_precise_(dist_targets.argmax(dim=1), un_ys, meter=meter, name='dist_acc')
    return meter

def train_batch(self, eidx, idx, global_step, batch_data, params: MetaSSLParams, device: torch.device):
    meter = Meter()

    sup, unsup = batch_data
    xs, ys = sup
    ids, un_xs, un_axs, un_ys = unsup

    mid_logits = self.to_mid(un_xs)
    self.meta_optimizer(un_xs, xs, ys, meter=meter)

    left, right = tricks.cartesian_product(mid_logits, self.cls_center)
    dist_ = F.pairwise_distance(left, right).reshape(mid_logits.shape[0], -1)
    dist_targets = torch.softmax(dist_, dim=-1)
    dist_targets = self.sharpen_(dist_targets)

    logits = self.to_logits(un_axs)
    # meter.Lall = meter.Lall + self.loss_ce_(self.to_logits(xs), ys, meter=meter, name='Lce')
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, dist_targets, meter=meter, name='Ldce')  # Lw*

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(dist_targets.argmax(dim=1), un_ys, meter=meter, name='dist_acc')
    self.acc_precise_(logits.argmax(dim=1), un_ys, meter, name='un_acc')
    return meter

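# The cartesian_product + pairwise_distance pattern above builds an
# (N, n_classes) distance matrix. Assuming tricks.cartesian_product enumerates
# pairs sample-major, torch.cdist produces the same matrix directly; a tiny
# self-consistency check under that assumption:
import torch
import torch.nn.functional as F

feats = torch.randn(5, 8)
centers = torch.randn(3, 8)
left = feats.repeat_interleave(3, dim=0)     # sample-major pair enumeration
right = centers.repeat(5, 1)
d1 = F.pairwise_distance(left, right).reshape(5, 3)
d2 = torch.cdist(feats, centers)
assert torch.allclose(d1, d2, atol=1e-4)     # equal up to pairwise_distance's eps
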
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    logits = self.to_logits(xs)
    if params.smooth:
        if params.smooth_ratio != -1:
            ns_targets = tricks.onehot(nys, params.n_classes)
            if params.smooth_argmax:
                targets = tricks.onehot(logits.argmax(dim=-1), params.n_classes)
            else:
                rys = ys.clone()
                rids = torch.randperm(len(rys))
                if params.smooth_ratio > 0:
                    rids = rids[:int(len(rids) * params.smooth_ratio)]
                rys[rids] = torch.randint(0, params.n_classes, [len(rids)], device=device)
                targets = tricks.onehot(rys, params.n_classes)

            if params.smooth_mixup:
                l = np.random.beta(0.75, 0.75, size=targets.shape[0])
                l = np.max([l, 1 - l], axis=0)
                l = torch.tensor(l, device=device, dtype=torch.float).unsqueeze(1)
            else:
                l = 0.9
            ns_targets = ns_targets * l + targets * (1 - l)
        else:
            ns_targets = tricks.label_smoothing(tricks.onehot(nys, params.n_classes))
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, ns_targets, meter=meter)
    else:
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    with torch.no_grad():
        nlogits = self.to_logits(xs)
        res = torch.softmax(nlogits, dim=-1) - torch.softmax(logits, dim=-1)
        meter.nres = res.gather(1, nys.unsqueeze(1)).mean() * 10
        meter.tres = res.gather(1, ys.unsqueeze(1)).mean() * 100
        meter.rres = res.gather(1, torch.randint_like(ys, 0, params.n_classes).unsqueeze(1)).mean() * 100

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')
    return meter

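# Hedged sketch of the plain label-smoothing branch (tricks.label_smoothing):
# standard smoothing mixes the one-hot target with a uniform distribution,
# targets = (1 - eps) * onehot + eps / n_classes. eps=0.1 is an assumption.
import torch

def label_smoothing(onehot: torch.Tensor, eps: float = 0.1) -> torch.Tensor:
    n_classes = onehot.size(-1)
    return onehot * (1 - eps) + eps / n_classes
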
def train_batch(self, eidx, idx, global_step, batch_data, params: SupervisedParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()
    ids, xs, axs, ys = batch_data  # type:torch.Tensor

    targets = tricks.onehot(ys, params.n_classes)
    mixed_xs, mixed_targets = self.mixup_(xs, targets)
    mixed_logits = self.to_logits(mixed_xs)

    meter.Lall = meter.Lall + self.loss_ce_with_targets_(mixed_logits, mixed_targets, meter=meter)

    with torch.no_grad():
        self.acc_precise_(self.to_logits(xs).argmax(dim=1), ys, meter=meter, name='acc')

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()
    return meter

def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    logits = self.to_logits(axs)
    # cross entropy on the (noisy) labels; without a loss term Lall stays empty
    # and backward() would fail
    meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')
    return meter

def train_batch(self, eidx, idx, global_step, batch_data, params: DivideMixParams, device: torch.device):
    meter = Meter()
    if eidx < params.warm_up:
        if eidx % 2 == 0:
            self.warmup_model(batch_data, self.model, self.optim, meter)
        else:
            self.warmup_model(batch_data, self.model2, self.optim2, meter)
    else:
        if eidx % 2 == 0:
            self.train_model(self.model, self.model2, self.optim, batch_data, params, meter)
        else:
            self.train_model(self.model2, self.model, self.optim2, batch_data, params, meter)
    return meter

def train_model(self, batch_data, model, optim, pys, probs, meter: Meter):
    (ids, xs, axs, ys, _) = batch_data  # type:torch.Tensor
    nys = pys[ids]
    optim.zero_grad()
    logits = model(xs)  # type:torch.Tensor
    meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter, name='Lce')
    meter.Lall.backward()
    optim.step()

    with torch.no_grad():
        # EMA of the per-sample probabilities
        logits = model(xs)
        probs[ids] = probs[ids] * 0.3 + torch.softmax(logits, dim=-1) * 0.7
        n_mask = (nys != ys)
        self.acc_precise_(logits.argmax(dim=-1), ys, meter=meter, name='acc')
        self.acc_precise_(logits[n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')
        self.acc_precise_(probs[ids].argmax(dim=-1), ys, meter=meter, name='pacc')

def warmup_model(self, batch_data, model, head, optim, probs, meter: Meter):
    """Initial probe: observe how much of the correct signal the model can learn from ys and nys."""
    (ids, xs, axs, ys, nys) = batch_data  # type:torch.Tensor
    optim.zero_grad()
    features = model(xs)  # type:torch.Tensor
    n_mask = (ys != nys)

    outs = head(features)[:3]
    meter.Lall = meter.Lall + self.loss_ce_(outs[0], nys, meter=meter, name='Lce')
    meter.Lall.backward()
    optim.step()

    with torch.no_grad():
        # EMA of the per-sample probabilities
        probs[ids] = probs[ids] * 0.3 + torch.softmax(outs[0], dim=-1) * 0.7
        # n_mask = (nys != ys)
        self.acc_precise_(outs[0].argmax(dim=-1), ys, meter=meter, name='acc')
        self.acc_precise_(outs[0][n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')

def train_first(self, eidx, model, optim, batch_data, meter: Meter):
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
    logits = model(axs)
    w_logits = model(xs)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)

    if params.local_filter:
        weight = label_pred - self.target_mem[ids]
        weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
        weight_mask = weight < 0
        meter.tl = weight_mask[ys == nys].float().mean()
        meter.fl = weight_mask[ys != nys].float().mean()

    fweight = torch.ones(w_logits.shape[0], dtype=torch.float, device=self.device)
    if eidx >= params.mix_burnin:
        if params.local_filter:
            fweight[weight_mask] -= params.gmm_w_sche(eidx)
        fweight -= self.noisy_cls[ids]
        fweight = torch.relu(fweight)
        self.filter_mem[ids] = fweight

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets

    top_values, top_indices = targets.topk(2, dim=-1)
    p_labels = top_indices[:, 0]
    values = top_values[:, 0]

    # ratio = params.smooth_ratio_sche(eidx)
    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')

    n_targets = tricks.onehot(nys, params.n_classes)
    p_targets = tricks.onehot(p_labels, params.n_classes)

    mask = mask.float()
    meter.pm = mask.mean()

    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(logits, n_targets, fweight, meter=meter)
    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
        logits, p_targets, (1 - fweight) * values * mask,
        meter=meter, name='Lpce') * params.plabel_sche(eidx)

    meter.tw = fweight[ys == nys].mean()
    meter.fw = fweight[ys != nys].mean()

    if params.local_filter:
        with torch.no_grad():
            ids_mask = weight_mask.logical_not()
            alpha = params.filter_ema
            self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + label_pred[ids_mask] * (1 - alpha)
            # Gradually regress the samples that did not take part back toward the prior:
            # self.target_mem[ids[weight_mask]] = self.target_mem[ids[weight_mask]] * alpha + (1 / params.n_classes) * \
            #                                     (1 - alpha)

    if 'Lall' in meter:
        optim.zero_grad()
        meter.Lall.backward()
        optim.step()

    self.acc_precise_(w_logits.argmax(dim=1), ys, meter, name='tacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
        # mem_mask = ids < self.pred_mem_size - 1
        # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = p_labels

"""
distributed under the GNU General Public License, please contact
[email protected] to purchase a commercial license.
"""
import sys

sys.path.insert(0, "../")

from thexp import __VERSION__

print(__VERSION__)

from thexp import Meter, AvgMeter
import torch

m = Meter()
m.a = 1
m.b = "2"
m.c = torch.rand(1)[0]
m.c1 = torch.rand(1)
m.c2 = torch.rand(2)
m.c3 = torch.rand(4, 4)
print(m)

m = Meter()
m.a = 0.236
m.b = 3.236
m.c = 0.23612312
m.percent(m.a_)
m.int(m.b_)