def train_model(self, batch_data, model, head, optim, probs, pys, meter: Meter):
    (ids, xs, axs, ys, nys) = batch_data  # type:torch.Tensor
    optim.zero_grad()
    pys = pys[ids]
    features = model(xs)  # type:torch.Tensor
    n_mask = (ys != nys)

    outs = head(features)[:params.head]
    meter.Lall = meter.Lall + self.loss_ce_(outs[0], nys, meter=meter, name='Lce')

    n_targets = tricks.onehot(nys, params.n_classes)
    p_targets = tricks.onehot(pys, params.n_classes)
    # note: these mixed pairs are computed but not used in the losses below
    mixed_xs1, mixed_targets1 = self.mixup_(xs, n_targets, target_b=p_targets)
    mixed_xs2, mixed_targets2 = self.mixup_(xs, p_targets, target_b=n_targets)

    for i in range(1, params.head):
        meter.Lall = meter.Lall + self.loss_ce_(outs[i], pys[:, i - 1],
                                                meter=meter, name='Lfce{}'.format(i))

    meter.Lall.backward()
    optim.step()

    with torch.no_grad():  # EMA update of the per-sample probability memory
        probs[ids] = probs[ids] * 0.3 + torch.softmax(outs[0], dim=-1) * 0.7

    self.acc_precise_(outs[0].argmax(dim=-1), ys, meter=meter, name='acc')
    self.acc_precise_(outs[0][n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')
    self.acc_precise_(outs[1].argmax(dim=-1), ys, meter=meter, name='pacc')
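# `tricks.onehot` is used throughout this file but not defined here. A minimal
# sketch of what it is assumed to do (turn an int64 label vector [B] into a
# float [B, n_classes] one-hot matrix); the real `tricks` module may differ:
import torch

def _onehot_sketch(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    out = torch.zeros(labels.shape[0], n_classes, device=labels.device)
    return out.scatter_(1, labels.unsqueeze(1), 1.0)
# e.g. _onehot_sketch(torch.tensor([0, 2]), 3) -> [[1, 0, 0], [0, 0, 1]]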
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    # Note: ys are the ground-truth labels of the training set; they are only
    # used to compute accuracies and never enter the training loss.
    (ids, xs, axs, ys, nys), (vxs, vys) = batch_data

    w_logits = self.model(xs)
    logits = self.model(axs)  # type:torch.Tensor
    p_targets = self.sharpen_(torch.softmax(w_logits, dim=1))

    weight = self.meta_optimizer(xs, nys, vxs, vys, meter=meter)

    meter.Lall = meter.Lall + self.loss_ce_with_masked_(
        logits, nys, weight, meter=meter, name='Lwce')
    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
        logits, p_targets, (1 - weight), meter=meter, name='Lwce')
    meter.tw = weight[ys == nys].mean()
    meter.fw = weight[ys != nys].mean()

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter=meter, name='true_acc')
    self.acc_precise_(logits.argmax(dim=1), nys, meter=meter, name='noisy_acc')
    return meter
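# `self.sharpen_` is not defined in this file. In MixMatch-style methods,
# sharpening raises each probability to the power 1/T and renormalizes; a
# minimal sketch under that assumption (T < 1 makes the distribution peakier):
import torch

def _sharpen_sketch(probs: torch.Tensor, T: float = 0.5) -> torch.Tensor:
    p = probs.pow(1.0 / T)
    return p / p.sum(dim=1, keepdim=True)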
def train_batch(self, eidx, idx, global_step, batch_data, params: SupervisedParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()
    ids, xs, axs, ys = batch_data  # type:torch.Tensor

    logits = self.to_logits(xs)
    if eidx < 10:
        # mask = (ids % 1000) < eidx
        targets = torch.ones(xs.shape[0], params.n_classes, device=device) / params.n_classes
        # targets[mask] = tricks.onehot(ys[mask], params.n_classes)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, targets, meter=meter)
    else:
        meter.Lall = meter.Lall + self.loss_ce_(logits, ys, meter=meter)

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')
    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()
    return meter
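# `self.loss_ce_with_targets_` is assumed to be soft-target cross-entropy,
# i.e. the batch mean of -sum_c t_c * log softmax(logits)_c; sketch only:
import torch
import torch.nn.functional as F

def _ce_with_targets_sketch(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    return -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()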
def train_batch(self, eidx, idx, global_step, batch_data, params: PencilParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
    logits = self.to_logits(xs)

    if eidx < params.stage1:
        # lc is the classification loss
        Lce = self.loss_ce_(logits, nys, meter=meter, name='Lce')
        # init y_tilde so that softmax(y_tilde) reproduces the noisy labels
        noisy_targets = tricks.onehot(nys, params.n_classes)
        self.target_mem[ids] = noisy_targets
    else:
        # torch.autograd.Variable is deprecated; detach the stored label logits
        # and re-enable grad so the loss can be backpropagated into them
        yy = self.target_mem[ids].detach().requires_grad_(True)
        # obtain label distributions (y_hat)
        last_y_var = torch.softmax(yy, dim=1)
        Lce = self.loss_ce_with_lc_targets_(logits, last_y_var, meter=meter, name='Lce')
        Lco = self.loss_ce_(last_y_var, nys, meter=meter, name='Lco')

    # le is the entropy loss
    Lent = self.loss_ent_(logits, meter=meter, name='Lent')

    if eidx < params.stage1:
        meter.Lall = Lce
    elif eidx < params.stage2:
        meter.Lall = Lce + params.alpha * Lco + params.beta * Lent
    else:
        meter.Lall = Lce

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')

    if eidx >= params.stage1 and eidx < params.stage2:
        # manual SGD step on the stored label logits, as in PENCIL
        self.target_mem[ids] = self.target_mem[ids] - params.lambda1 * yy.grad.data
        self.acc_precise_(self.target_mem[ids].argmax(dim=1), ys, meter, name='check_acc')
    return meter
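# A self-contained toy of the PENCIL label-correction step above: keep label
# logits in a memory table, backprop the loss into them, then apply a manual
# SGD step with rate lambda1. All names and numbers here are illustrative:
import torch
import torch.nn.functional as F

target_mem = torch.randn(8, 10)        # y_tilde (label logits) for 8 samples
ids = torch.arange(4)                  # samples in the current batch
logits = torch.randn(4, 10)            # model outputs for the batch
yy = target_mem[ids].detach().requires_grad_(True)
last_y_var = torch.softmax(yy, dim=1)  # label distribution y_hat
loss = -(last_y_var * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
loss.backward()                        # gradient flows into yy, not the model
target_mem[ids] = target_mem[ids] - 500.0 * yy.grad  # PENCIL uses a large label lr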
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
    logits = self.to_logits(xs)

    if params.smooth:
        if params.smooth_ratio != -1:
            ns_targets = tricks.onehot(nys, params.n_classes)
            if params.smooth_argmax:
                targets = tricks.onehot(logits.argmax(dim=-1), params.n_classes)
            else:
                rys = ys.clone()
                rids = torch.randperm(len(rys))
                if params.smooth_ratio > 0:
                    rids = rids[:int((len(rids) * params.smooth_ratio))]
                rys[rids] = torch.randint(0, params.n_classes, [len(rids)], device=device)
                targets = tricks.onehot(rys, params.n_classes)

            if params.smooth_mixup:
                l = np.random.beta(0.75, 0.75, size=targets.shape[0])
                l = np.max([l, 1 - l], axis=0)
                l = torch.tensor(l, device=device, dtype=torch.float).unsqueeze(1)
            else:
                l = 0.9
            ns_targets = ns_targets * l + targets * (1 - l)
        else:
            ns_targets = tricks.label_smoothing(tricks.onehot(nys, params.n_classes))
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, ns_targets, meter=meter)
    else:
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    with torch.no_grad():
        nlogits = self.to_logits(xs)
        res = torch.softmax(nlogits, dim=-1) - torch.softmax(logits, dim=-1)
        meter.nres = res.gather(1, nys.unsqueeze(1)).mean() * 10
        meter.tres = res.gather(1, ys.unsqueeze(1)).mean() * 100
        meter.rres = res.gather(1, torch.randint_like(ys, 0, params.n_classes).unsqueeze(1)).mean() * 100

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')
    return meter
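# `tricks.label_smoothing` is assumed to be standard label smoothing over
# one-hot targets: keep 1 - eps on the labeled class and spread eps uniformly.
# A sketch (the real helper may choose eps differently):
import torch

def _label_smoothing_sketch(onehot_targets: torch.Tensor, eps: float = 0.1) -> torch.Tensor:
    n_classes = onehot_targets.shape[1]
    return onehot_targets * (1.0 - eps) + eps / n_classes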
def train_batch(self, eidx, idx, global_step, batch_data, params: FixMatchParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()
    sup, unsup = batch_data
    xs, ys = sup
    _, un_xs, un_axs, un_ys = unsup

    logits_list = self.to_logits(torch.cat([xs, un_xs, un_axs])).split_with_sizes(
        [xs.shape[0], un_xs.shape[0], un_xs.shape[0]])
    logits, un_w_logits, un_s_logits = logits_list  # type:torch.Tensor

    pseudo_targets = torch.softmax(un_w_logits, dim=-1)
    max_probs, un_pseudo_labels = torch.max(pseudo_targets, dim=-1)
    mask = (max_probs > params.pred_thresh)
    if mask.any():
        self.acc_precise_(un_w_logits.argmax(dim=1)[mask], un_ys[mask], meter, name='umacc')

    meter.Lall = meter.Lall + self.loss_ce_(logits, ys, meter=meter, name='Lx')
    meter.Lall = meter.Lall + self.loss_ce_with_masked_(
        un_s_logits, un_pseudo_labels, mask.float(),
        meter=meter, name='Lu') * params.lambda_u

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    meter.masked = mask.float().mean()
    self.acc_precise_(logits.argmax(dim=1), ys, meter)
    self.acc_precise_(un_w_logits.argmax(dim=1), un_ys, meter, name='uwacc')
    self.acc_precise_(un_s_logits.argmax(dim=1), un_ys, meter, name='usacc')
    return meter
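# `self.loss_ce_with_masked_` is not shown here. For FixMatch it should reduce
# to per-sample cross-entropy weighted by the confidence mask; a sketch:
import torch
import torch.nn.functional as F

def _ce_masked_sketch(logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # mask holds per-sample weights (0/1 for FixMatch's threshold mask)
    return (F.cross_entropy(logits, labels, reduction='none') * mask).mean()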
def warmup_model(self, batch_data, model, optim, meter: Meter):
    (ids, xs, axs, ys, nys) = batch_data  # type:torch.Tensor
    optim.zero_grad()
    logits = model(xs)  # type:torch.Tensor
    meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter, name='Lce')
    if params.noisy_type == 'asymmetric':
        # penalize confident predictions for asymmetric noise
        meter.Lall = meter.Lall + self.loss_minent_(logits, meter=meter, name='Lpen')
    self.acc_precise_(logits.argmax(dim=-1), ys, meter=meter, name='acc')
    meter.Lall.backward()
    optim.step()
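# `self.loss_minent_` ("penalize confident prediction") is presumably the mean
# negative entropy of the predictions, so that adding it to the loss pushes the
# model away from over-confident outputs; a sketch under that assumption:
import torch
import torch.nn.functional as F

def _minent_sketch(logits: torch.Tensor) -> torch.Tensor:
    p = torch.softmax(logits, dim=1)
    return (p * F.log_softmax(logits, dim=1)).sum(dim=1).mean()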
def train_batch(self, eidx, idx, global_step, batch_data, params: MetaSSLParams, device: torch.device):
    meter = Meter()
    sup, unsup = batch_data
    xs, ys = sup
    ids, un_xs, un_axs, un_ys = unsup

    mid_logits = self.to_mid(un_xs)
    self.meta_optimizer(un_xs, xs, ys, meter=meter)

    left, right = tricks.cartesian_product(mid_logits, self.cls_center)
    dist_ = F.pairwise_distance(left, right).reshape(mid_logits.shape[0], -1)
    dist_targets = torch.softmax(dist_, dim=-1)

    logits = self.to_logits(un_axs)
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(
        logits, dist_targets, meter=meter, name='Ldce')  # Lw*

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), un_ys, meter, name='un_acc')
    return meter
def train_batch(self, eidx, idx, global_step, batch_data, params: IEGParams, device: torch.device):
    meter = Meter()
    train_data, (vxs, vys) = batch_data  # type:List[torch.Tensor],(torch.Tensor,torch.Tensor)
    ys, nys = train_data[-2:]
    xs = train_data[1]
    axs = torch.cat(train_data[2:2 + params.K])

    logits = self.to_logits(xs)
    aug_logits = self.to_logits(axs)

    n_targets = tricks.onehot(nys, params.n_classes)
    guess_targets = self.unsupervised_loss(xs, axs, vxs, vys, logits, aug_logits, meter=meter)
    weight, eps_k = self.meta_optimizer(xs, guess_targets, n_targets, vxs, vys, meter=meter)
    meter.tw = eps_k[ys == nys].mean()
    meter.fw = eps_k[ys != nys].mean()

    mixed_targets = eps_k * n_targets + (1 - eps_k) * guess_targets
    # On the surface this is cross-entropy against soft targets; in fact the
    # targets here are all in one-hot form.

    # loss with the initial eps
    init_eps = torch.ones([guess_targets.shape[0]], dtype=torch.float,
                          device=self.device) * params.grad_eps_init
    init_mixed_labels = tricks.elementwise_mul(init_eps, n_targets) + \
                        tricks.elementwise_mul(1 - init_eps, guess_targets)

    # loss with the initial weight
    meter.Lws = self.loss_ce_with_targets_(logits, mixed_targets)  # Lw*
    meter.Llamda = -torch.mean(
        torch.sum(F.log_softmax(logits, dim=1) * init_mixed_labels, dim=1) * weight)  # Lλ*
    meter.Lall = meter.Lall + (meter.Lws + meter.Llamda) / 2

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')
    return meter
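# `tricks.elementwise_mul` is assumed to broadcast a per-sample weight vector
# [B] over class targets [B, C]; a one-line sketch of that assumed behavior:
import torch

def _elementwise_mul_sketch(w: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    return w.unsqueeze(-1) * targets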
def train_batch(self, eidx, idx, global_step, batch_data, params: MentorNetParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    # Basic parameters for mentornet
    _epoch_step = params.epoch_step(eidx)
    _zero_labels = torch.zeros_like(nys)
    _loss_p_percentile = torch.ones(100, dtype=torch.float) * params.loss_p_percentile
    _dropout_rates = self.parse_dropout_rate_list_()

    logits = self.to_logits(xs)
    # train on the noisy labels (ys is reserved for accuracy bookkeeping); the
    # losses must stay attached to the graph so that Lall.backward() works
    _basic_losses = F.cross_entropy(logits, nys, reduction='none')

    # torch.rand_like requires a float tensor, so sample in the loss's shape
    weight = torch.rand_like(_basic_losses).detach()  # TODO: replace with mentornet
    meter.Lall = torch.mean(_basic_losses * weight)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    self.acc_precise_(logits.argmax(dim=1), nys, meter, name='noisy_acc')
    return meter
def ssl(self, eidx, batch_data, meter: Meter):
    sup, unsup = batch_data
    _ids, xs, _ys, nys = sup
    _uids, uxs1, uxs2, _uys, unys, unprob = unsup

    targets = tricks.onehot(nys, params.n_classes)
    logits = self.to_logits(xs).detach()
    un_logits = self.to_logits(torch.cat([uxs1, uxs2]))
    un_targets = self.label_guesses_(*un_logits.chunk(2))
    un_targets = self.sharpen_(un_targets, params.T)

    mixed_input, mixed_target = self.mixmatch_up_(xs, [uxs1, uxs2], targets, un_targets)
    sup_mixed_target, unsup_mixed_target = mixed_target.detach().split_with_sizes(
        [xs.shape[0], mixed_input.shape[0] - xs.shape[0]])
    sup_mixed_logits, unsup_mixed_logits = self.to_logits(mixed_input).split_with_sizes(
        [xs.shape[0], mixed_input.shape[0] - xs.shape[0]])

    meter.Lall = meter.Lall + self.loss_ce_with_targets_(
        sup_mixed_logits, sup_mixed_target, meter=meter, name='Lx')
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(
        unsup_mixed_logits, unsup_mixed_target, meter=meter, name='Lu') * params.w_sche(eidx)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    preds = torch.softmax(torch.cat([logits, un_logits.chunk(2)[0]]), dim=1).detach()
    label_pred = preds.gather(1, torch.cat([nys, unys]).unsqueeze(dim=1)).squeeze()
    ids = torch.cat([_ids, _uids])

    if params.local_filter:
        weight = label_pred - self.target_mem[ids]
        weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
        weight_mask = weight < 0

    fweight = torch.ones(ids.shape[0], dtype=torch.float, device=self.device)
    if eidx >= params.burnin:
        if params.local_filter:
            fweight[weight_mask] -= params.gmm_w_sche(eidx)
        fweight -= self.noisy_cls[ids]
        fweight = torch.relu(fweight)
    self.filter_mem[ids] = fweight
    return meter
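# `self.mixmatch_up_` is assumed to follow the MixMatch recipe: concatenate the
# labeled batch with the unlabeled views, shuffle, and mix each sample with its
# shuffled partner, keeping the original sample dominant. A sketch:
import numpy as np
import torch

def _mixmatch_up_sketch(xs, un_xs_list, targets, un_targets, beta=0.75):
    all_x = torch.cat([xs] + un_xs_list)
    all_t = torch.cat([targets] + [un_targets] * len(un_xs_list))
    l = np.random.beta(beta, beta)
    l = max(l, 1 - l)                      # lambda >= 0.5 keeps the left sample dominant
    idx = torch.randperm(all_x.shape[0])
    mixed_x = l * all_x + (1 - l) * all_x[idx]
    mixed_t = l * all_t + (1 - l) * all_t[idx]
    return mixed_x, mixed_t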
def unsupervised_loss(self, xs: torch.Tensor, axs: torch.Tensor,
                      vxs: torch.Tensor, vys: torch.Tensor,
                      logits: torch.Tensor, aug_logits: torch.Tensor,
                      meter: Meter):
    '''create Lub, Lpb, Lkl'''
    re_ids = torch.randperm(vxs.shape[0])
    re_vxs = vxs[re_ids]
    re_vys = vys[re_ids]

    p_target = self.label_guesses_(self.logit_norm_(logits), self.logit_norm_(aug_logits))
    p_target = self.sharpen_(p_target, params.T)

    re_v_targets = tricks.onehot(re_vys, params.n_classes)
    mixed_input, mixed_target = self.mixmatch_up_(re_vxs, [xs, axs], re_v_targets, p_target,
                                                  beta=params.mix_beta)
    mixed_logits = self.to_logits(mixed_input)
    mixed_logits_lis = mixed_logits.split_with_sizes([re_vxs.shape[0], xs.shape[0], axs.shape[0]])
    (mixed_v_logits, mixed_n_logits, mixed_an_logits) = [
        self.logit_norm_(l) for l in mixed_logits_lis]  # type:torch.Tensor

    mixed_nn_logits = torch.cat([mixed_n_logits, mixed_an_logits], dim=0)
    mixed_v_targets, mixed_nn_targets = mixed_target.split_with_sizes(
        [mixed_v_logits.shape[0], mixed_nn_logits.shape[0]])

    # Lpβ: the validation set acts as the labeled set of the semi-supervised problem
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(
        mixed_v_logits, mixed_v_targets, meter=meter, name='Lpb')
    # p * Luβ: the training set acts as the unlabeled set
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(
        mixed_nn_logits, mixed_nn_targets, meter=meter, name='Lub')
    # Lkl: consistency loss across the multiple augmentations
    meter.Lall = meter.Lall + self.loss_kl_ieg_(
        logits, aug_logits, n_classes=params.n_classes,
        consistency_factor=params.consistency_factor, meter=meter)
    return p_target
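# `self.label_guesses_` presumably implements MixMatch label guessing: average
# the softmax predictions of several views of the same batch; a sketch:
import torch

def _label_guesses_sketch(*logits_list: torch.Tensor) -> torch.Tensor:
    probs = [torch.softmax(l, dim=1) for l in logits_list]
    return torch.stack(probs).mean(dim=0).detach()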
def warmup_model(self, batch_data, model, optim, meter: Meter):
    (ids, xs, nys) = batch_data  # type:torch.Tensor
    optim.zero_grad()
    logits = model(xs)  # type:torch.Tensor
    meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter, name='Lce')
    self.acc_precise_(logits.argmax(dim=-1), nys, meter=meter, name='acc')
    meter.Lall.backward()
    optim.step()
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
    logits = self.to_logits(xs)
    meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

    self.acc_precise_(logits.argmax(dim=1), nys, meter, name='acc')
    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')

    # _err = (ys != nys)
    # ids, xs, axs, ys, nys = ids[_err], xs[_err], axs[_err], ys[_err], nys[_err]

    # gradient-similarity probe: compare the gradient of the clean-label loss on
    # class 0 with that of the noisy-label loss on class 1 (the loop breaks after
    # the first iteration, so only this single pair is measured)
    for i in range(params.n_classes):
        _cls_mask = (ys == 0)
        _fcls_mask = (nys == 1)
        _tloss = F.cross_entropy(logits[_cls_mask], ys[_cls_mask], reduction='none')
        _floss = F.cross_entropy(logits[_fcls_mask], nys[_fcls_mask], reduction='none')
        tcls_grads = [g for g in autograd.grad(_tloss, self.model.parameters(),
                                               grad_outputs=torch.ones_like(_tloss),
                                               retain_graph=True, allow_unused=True)
                      if g is not None]
        fcls_grads = [g for g in autograd.grad(_floss, self.model.parameters(),
                                               grad_outputs=torch.ones_like(_floss),
                                               retain_graph=True, allow_unused=True)
                      if g is not None]
        res = 0
        for tgrad, fgrad in zip(tcls_grads, fcls_grads):
            # res += (tgrad - fgrad).abs().pow(2).sum()
            # res = max(tricks.cos_similarity(tgrad, fgrad).mean(), res)
            res += tricks.cos_similarity(tgrad, fgrad).mean()
        meter.res = res
        break

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()
    return meter
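# `tricks.cos_similarity` on two same-shaped gradient tensors is assumed to
# flatten the trailing dimensions and return a per-row cosine similarity (the
# call sites take .mean() over the result); a sketch of that assumed behavior:
import torch
import torch.nn.functional as F

def _cos_similarity_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return F.cosine_similarity(a.reshape(a.shape[0], -1), b.reshape(b.shape[0], -1), dim=1)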
def unsupervised_loss(self, xs: torch.Tensor, axs: torch.Tensor,
                      vxs: torch.Tensor, vys: torch.Tensor,
                      logits_lis: List[torch.Tensor], meter: Meter):
    '''create Lub, Lpb, Lkl'''
    logits_lis = [self.logit_norm_(logits) for logits in logits_lis]
    p_target = self.label_guesses_(*logits_lis)
    p_target = self.sharpen_(p_target, params.T)

    re_v_targets = tricks.onehot(vys, params.n_classes)
    mixed_input, mixed_target = self.mixmatch_up_(vxs, [axs], re_v_targets, p_target,
                                                  beta=params.mix_beta)
    mixed_logits = self.to_logits(mixed_input)
    mixed_logits_lis = mixed_logits.split_with_sizes([vxs.shape[0], axs.shape[0]])
    (mixed_v_logits, mixed_nn_logits) = [
        self.logit_norm_(l) for l in mixed_logits_lis]  # type:torch.Tensor
    # mixed_nn_logits = torch.cat([mixed_n_logits, mixed_an_logits], dim=0)

    mixed_v_targets, mixed_nn_targets = mixed_target.split_with_sizes(
        [mixed_v_logits.shape[0], mixed_nn_logits.shape[0]])

    # Lpβ: the validation set acts as the labeled set of the semi-supervised problem
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(
        mixed_v_logits, mixed_v_targets, meter=meter,
        name='Lpb') * params.semi_sche(params.eidx)
    # p * Luβ: the training set acts as the unlabeled set
    if params.lub:
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_nn_logits, mixed_nn_targets, meter=meter,
            name='Lub') * params.semi_sche(params.eidx)
    # Lkl (consistency across multiple augmentations) is omitted in this variant
    return p_target
def train_batch(self, eidx, idx, global_step, batch_data, params: GmaParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, nys = batch_data  # type:torch.Tensor
    logits = self.to_logits(axs)
    w_logits = self.to_logits(xs)
    meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)
    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()
    self.acc_precise_(w_logits.argmax(dim=1), nys, meter, name='acc')
    return meter
def train_batch(self, eidx, idx, global_step, batch_data, params: MetaSSLParams, device: torch.device):
    meter = Meter()
    sup, unsup = batch_data
    xs, ys = sup
    ids, un_xs, un_axs, un_ys = unsup

    if eidx < 2:
        logits = self.to_logits(xs)
        meter.Lall = meter.Lall + self.loss_ce_(logits, ys, meter=meter, name='Lce')
        self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')
    else:
        un_logits = self.to_logits(un_xs)
        weight = self.meta_optimizer(un_xs, xs, ys, meter=meter)
        # meter.Lall = meter.Lall + self.loss_ce_(self.to_logits(xs), ys, meter=meter, name='Lce')
        meter.Lall = meter.Lall + self.loss_ce_with_masked_(
            un_logits, un_logits.argmax(dim=1), weight, meter=meter, name='Ldce')  # Lw*
        self.acc_precise_(un_logits.argmax(dim=1), un_ys, meter, name='un_acc')

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    # self.acc_precise_(dist_targets.argmax(dim=1), un_ys, meter=meter, name='dist_acc')
    return meter
def train_batch(self, eidx, idx, global_step, batch_data, params: ICTParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()
    sup, unsup = batch_data
    xs, ys = sup
    _, un_xs, _, un_ys = unsup

    logits_list = self.to_logits(torch.cat([xs, un_xs])).split_with_sizes(
        [xs.shape[0], un_xs.shape[0]])
    logits, un_logits = logits_list  # type:torch.Tensor

    mixed_xs, ys_a, ys_b, lam_sup = self.ict_mixup_(xs, ys)
    logits = self.to_logits(xs)
    mixed_logits = self.to_logits(mixed_xs)

    ema_un_logits = self.predict(un_xs)
    mixed_un_xs, ema_mixed_un_logits, _ = self.mixup_unsup_(un_xs, ema_un_logits, mix=True)
    mixed_un_logits = self.to_logits(mixed_un_xs)

    meter.Lall = meter.Lall + self.loss_mixup_sup_ce_(mixed_logits, ys_a, ys_b, lam_sup,
                                                      meter=meter)
    meter.Lall = meter.Lall + self.loss_mixup_unsup_mse_(
        mixed_un_logits, ema_un_logits,
        decay=params.mixup_consistency_sche(eidx), meter=meter)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter)
    self.acc_precise_(un_logits.argmax(dim=1), un_ys, meter, name='unacc')
    return meter
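# `self.ict_mixup_` is assumed to be standard input mixup that also returns
# both label vectors and the mixing coefficient, as in the ICT paper, so the
# supervised loss can be lam * CE(., ys_a) + (1 - lam) * CE(., ys_b). Sketch:
import numpy as np
import torch

def _ict_mixup_sketch(xs: torch.Tensor, ys: torch.Tensor, alpha: float = 1.0):
    lam = float(np.random.beta(alpha, alpha))
    idx = torch.randperm(xs.shape[0], device=xs.device)
    mixed_xs = lam * xs + (1 - lam) * xs[idx]
    return mixed_xs, ys, ys[idx], lam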
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    if eidx == 1:
        # initialize the smoothed class memory: mass on the noisy label plus a
        # uniform floor over all classes
        self.smooth_cls_mem[ids] = self.smooth_cls_mem[ids].scatter(
            1, nys.unsqueeze(1), 0.8) + 0.1
        # self.smooth_cls_mem[ids] += 0.1

    logits = self.to_logits(xs)
    ns_targets = self.smooth_cls_mem[ids]
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits, ns_targets, meter=meter)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    with torch.no_grad():
        nlogits = self.to_logits(xs)
        res = (torch.softmax(nlogits, dim=-1) - torch.softmax(logits, dim=-1)) / ns_targets
        # res = res.scatter(1, res.argsort(dim=-1, descending=True)[:, 2:], 0)
        ns_targets = torch.clamp(ns_targets + res * 0.1, 0, 1)
        # ns_max, _ = ns_targets.max(dim=-1, keepdim=True)
        # ns_min, _ = ns_targets.min(dim=-1, keepdim=True)
        # ns_targets = (ns_targets - ns_min)
        # ns_targets = ns_targets / ns_targets.sum(dim=1, keepdims=True)
        targets = tricks.onehot(ys, params.n_classes)
        meter.cm = ((targets * res) > 0).any(dim=-1).float().mean()
        self.smooth_cls_mem[ids] = ns_targets

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
    self.acc_precise_(ns_targets.argmax(dim=1), ys, meter, name='cacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')
    return meter
def train_batch(self, eidx, idx, global_step, batch_data, params: GlobalParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()
    xs, ys = batch_data
    logits = self.to_logits(xs)
    meter.Lall = meter.Lall + self.loss_ce_(logits, ys, meter=meter, name='Lce')
    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()
    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='acc')
    return meter
def train_batch(self, eidx, idx, global_step, batch_data, params: SimCLRParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()
    _, xs, axs, _ = batch_data  # type:torch.Tensor

    b, c, h, w = xs.size()
    input_ = torch.cat([xs.unsqueeze(1), axs.unsqueeze(1)], dim=1)
    input_ = input_.view(-1, c, h, w)
    output = self.to_logits(input_).view(b, 2, -1)

    meter.Lall = meter.Lall + self.loss_sim_(output, params.temperature,
                                             device=device, meter=meter)
    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()
    return meter
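# `self.loss_sim_` on a [b, 2, d] tensor is presumably SimCLR's NT-Xent loss:
# each view's positive is the other view of the same image, and everything
# else in the batch is a negative. A compact sketch under that assumption:
import torch
import torch.nn.functional as F

def _nt_xent_sketch(output: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    b = output.shape[0]
    z = F.normalize(output.reshape(2 * b, -1), dim=1)  # rows interleave the two views
    sim = z @ z.t() / temperature
    sim.fill_diagonal_(float('-inf'))                  # a sample is never its own positive
    pos = torch.arange(2 * b, device=sim.device) ^ 1   # partner = neighbouring row
    return F.cross_entropy(sim, pos)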
def train_model(self, batch_data, model, optim, pys, probs, meter: Meter):
    (ids, xs, axs, ys, _) = batch_data  # type:torch.Tensor
    nys = pys[ids]
    optim.zero_grad()
    logits = model(xs)  # type:torch.Tensor
    meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter, name='Lce')
    meter.Lall.backward()
    optim.step()

    with torch.no_grad():  # EMA update of the per-sample probability memory
        logits = model(xs)
        probs[ids] = probs[ids] * 0.3 + torch.softmax(logits, dim=-1) * 0.7

    n_mask = (nys != ys)
    self.acc_precise_(logits.argmax(dim=-1), ys, meter=meter, name='acc')
    self.acc_precise_(logits[n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')
    self.acc_precise_(probs[ids].argmax(dim=-1), ys, meter=meter, name='pacc')
def warmup_model(self, batch_data, model, head, optim, probs, meter: Meter):
    """Initial experiment: observe how much of the correct signal the model can
    learn from ys and nys."""
    (ids, xs, axs, ys, nys) = batch_data  # type:torch.Tensor
    optim.zero_grad()
    features = model(xs)  # type:torch.Tensor
    n_mask = (ys != nys)

    outs = head(features)[:3]
    meter.Lall = meter.Lall + self.loss_ce_(outs[0], nys, meter=meter, name='Lce')
    meter.Lall.backward()
    optim.step()

    with torch.no_grad():  # EMA update of the per-sample probability memory
        probs[ids] = probs[ids] * 0.3 + torch.softmax(outs[0], dim=-1) * 0.7

    self.acc_precise_(outs[0].argmax(dim=-1), ys, meter=meter, name='acc')
    self.acc_precise_(outs[0][n_mask].argmax(dim=-1), nys[n_mask], meter=meter, name='nacc')
def train_batch(self, eidx, idx, global_step, batch_data, params: SupervisedParams, device: torch.device):
    super().train_batch(eidx, idx, global_step, batch_data, params, device)
    meter = Meter()
    ids, xs, axs, ys = batch_data  # type:torch.Tensor

    targets = tricks.onehot(ys, params.n_classes)
    mixed_xs, mixed_targets = self.mixup_(xs, targets)
    mixed_logits = self.to_logits(mixed_xs)
    meter.Lall = meter.Lall + self.loss_ce_with_targets_(mixed_logits, mixed_targets, meter=meter)

    with torch.no_grad():
        self.acc_precise_(self.to_logits(xs).argmax(dim=1), ys, meter=meter, name='acc')

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()
    return meter
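# `self.mixup_` is not defined in this file; judging from its call sites
# (optional `target_b=` and `reids=` keywords elsewhere in the file), it mixes
# the batch with a shuffled copy and returns already-mixed soft targets. A
# sketch with Beta(0.75, 0.75) as a guessed default:
import numpy as np
import torch

def _mixup_sketch(xs, targets, beta=0.75, target_b=None, reids=None):
    idx = torch.randperm(xs.shape[0]) if reids is None else torch.as_tensor(reids)
    other_t = targets[idx] if target_b is None else target_b[idx]
    l = np.random.beta(beta, beta)
    l = max(l, 1 - l)                      # keep the original sample dominant
    return l * xs + (1 - l) * xs[idx], l * targets + (1 - l) * other_t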
def train_first(self, eidx, model, optim, batch_data, meter: Meter):
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
    logits = model(axs)
    w_logits = model(xs)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
    if params.local_filter:
        weight = label_pred - self.target_mem[ids]
        weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
        weight_mask = weight < 0
        meter.tl = weight_mask[ys == nys].float().mean()
        meter.fl = weight_mask[ys != nys].float().mean()

    fweight = torch.ones(w_logits.shape[0], dtype=torch.float, device=self.device)
    if eidx >= params.mix_burnin:
        if params.local_filter:
            fweight[weight_mask] -= params.gmm_w_sche(eidx)
        fweight -= self.noisy_cls[ids]
        fweight = torch.relu(fweight)
    self.filter_mem[ids] = fweight

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets
    top_values, top_indices = targets.topk(2, dim=-1)
    p_labels = top_indices[:, 0]
    values = top_values[:, 0]

    # ratio = params.smooth_ratio_sche(eidx)
    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')

    n_targets = tricks.onehot(nys, params.n_classes)
    p_targets = tricks.onehot(p_labels, params.n_classes)

    mask = mask.float()
    meter.pm = mask.mean()
    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
        logits, n_targets, fweight, meter=meter)
    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
        logits, p_targets, (1 - fweight) * values * mask,
        meter=meter, name='Lpce') * params.plabel_sche(eidx)
    meter.tw = fweight[ys == nys].mean()
    meter.fw = fweight[ys != nys].mean()

    if params.local_filter:
        with torch.no_grad():
            ids_mask = weight_mask.logical_not()
            alpha = params.filter_ema
            self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + \
                                             label_pred[ids_mask] * (1 - alpha)
            # gradually regress the samples that did not take part back to the prior
            # self.target_mem[ids[weight_mask]] = self.target_mem[ids[weight_mask]] * alpha + \
            #     (1 / params.n_classes) * (1 - alpha)

    if 'Lall' in meter:
        optim.zero_grad()
        meter.Lall.backward()
        optim.step()

    self.acc_precise_(w_logits.argmax(dim=1), ys, meter, name='tacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
    # mem_mask = ids < self.pred_mem_size - 1
    # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = p_labels
def train_batch(self, eidx, idx, global_step, batch_data, params: IEGParams, device: torch.device):
    meter = Meter()
    train_data, (vxs, vys) = batch_data  # type:List[torch.Tensor],(torch.Tensor,torch.Tensor)
    ids = train_data[0]
    axs = train_data[1]
    xs = torch.cat(train_data[2:2 + params.K])
    ys, nys = train_data[-2:]  # type:torch.Tensor

    w_logits = self.to_logits(xs)
    aug_logits = self.to_logits(axs)
    logits = w_logits.chunk(params.K)[0]
    # logits = aug_logits  # .chunk(params.K)[0]
    w_targets = torch.softmax(w_logits.chunk(params.K)[0], dim=1).detach()

    guess_targets = self.unsupervised_loss(xs, axs, vxs, vys,
                                           logits_lis=[*w_logits.detach().chunk(params.K)],
                                           meter=meter)
    # guess_targets = self.sharpen_(torch.softmax(logits, dim=1))
    label_pred = guess_targets.gather(1, nys.unsqueeze(dim=1)).squeeze()

    weight = label_pred - self.target_mem[ids]
    weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes

    fweight = torch.ones_like(weight)
    if eidx >= params.burnin:
        fweight -= self.noisy_cls[ids]
        fweight[weight <= 0] -= params.gmm_w_sche(eidx)
        fweight = torch.relu(fweight)

    raw_targets = w_targets  # guess_targets  # torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets
    values, p_labels = targets.max(dim=1)
    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
    mask = mask.float()
    meter.pm = mask.mean()

    meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, nys, fweight, meter=meter)
    meter.Lall = meter.Lall + self.loss_ce_with_masked_(
        logits, p_labels, (1 - fweight) * mask,
        meter=meter, name='Lpce') * params.plabel_sche(eidx)
    if params.lkl:
        meter.Lall = meter.Lall + self.loss_kl_ieg_(
            logits, aug_logits, n_classes=params.n_classes,
            consistency_factor=params.consistency_factor, meter=meter)

    meter.tw = fweight[ys == nys].mean()
    meter.fw = fweight[ys != nys].mean()

    with torch.no_grad():
        # note: .bool() marks every nonzero weight, negative ones included
        ids_mask = weight.bool()
        alpha = params.filter_ema
        if eidx < params.burnin:
            alpha = 0.99
        self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + \
                                         label_pred[ids_mask] * (1 - alpha)
        self.weight_mem[ids] = weight < 0

    if 'Lall' in meter:
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

    self.acc_precise_(w_targets.argmax(dim=1), ys, meter, name='true_acc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(w_targets.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')

    with torch.no_grad():
        false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
        self.true_pred_mem[ids] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
    return meter
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor
    _right_mask = (ys == nys)
    _error_mask = _right_mask.logical_not()

    logits = self.to_logits(axs)
    w_logits = self.to_logits(xs)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)
    if params.local_filter:
        weight_mask = self.count_mem[ids] < 0
        # weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
        # weight_mask = weight < 0
        meter.tl = weight_mask[_right_mask].float().mean()
        meter.fl = weight_mask[_error_mask].float().mean()

    fweight = torch.ones(w_logits.shape[0], dtype=torch.float, device=device)
    if eidx >= params.burnin:
        fweight -= self.noisy_cls[ids]
        if params.local_filter:
            fweight[weight_mask] -= params.gmm_w_sche(eidx)
        fweight = torch.relu(fweight)
    if params.right_n > 0:
        right_mask = ids < params.right_n
        if right_mask.any():
            fweight[right_mask] = 1

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets
    values, p_labels = targets.max(dim=1)
    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
    mask = mask.float()

    # uda_mask = label_pred > params.uda_sche(eidx)
    # meter.uda = uda_mask.float().mean()
    meter.pm = mask.mean()

    meter.Lall = meter.Lall + self.loss_ce_with_masked_(logits, nys, fweight, meter=meter)
    meter.Lall = meter.Lall + self.loss_ce_with_masked_(
        logits, p_labels, (1 - fweight) * mask,
        meter=meter, name='Lpce') * params.plabel_sche(eidx)
    meter.tw = fweight[_right_mask].mean()
    meter.fw = fweight[_error_mask].mean()

    if params.local_filter:
        with torch.no_grad():
            ids_mask = weight_mask.logical_not()
            alpha = params.filter_ema
            # if eidx == 1:
            #     self.target_mem[ids] = label_pred
            # else:
            if eidx < params.burnin:
                alpha = 0.99
            self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + \
                                             label_pred[ids_mask] * (1 - alpha)
            # gradually regress the samples that did not take part back to the prior
            # self.target_mem[ids[weight_mask]] = (self.target_mem[ids[weight_mask]] * 0.9 +
            #                                      (1 / params.n_classes) * 0.1)
            # self.count_mem[ids] += ((label_pred - self.target_mem[ids]) > 0).long() * 3 - 2
            self.count_mem[ids] += (label_pred - self.target_mem[ids])

    if 'Lall' in meter:
        self.optim.zero_grad()
        meter.Lall.backward()
        self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
    # mem_mask = ids < self.pred_mem_size - 1
    # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = p_labels
    return meter
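# `self.noisy_cls[ids]` is consumed above as a per-sample probability of being
# noisy. In DivideMix-style pipelines this comes from fitting a two-component
# Gaussian mixture to the per-sample losses and taking the posterior of the
# high-loss component; a sketch of that fit (sklearn assumed available):
import numpy as np
from sklearn.mixture import GaussianMixture

def _fit_noisy_prob_sketch(losses: np.ndarray) -> np.ndarray:
    x = losses.reshape(-1, 1)
    gmm = GaussianMixture(n_components=2, reg_covar=5e-4).fit(x)
    noisy_comp = int(np.argmax(gmm.means_))       # the high-mean mode is the noisy one
    return gmm.predict_proba(x)[:, noisy_comp]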
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    if params.mixup:
        id_dict = defaultdict(list)
        for i, (y, ny) in enumerate(zip(ys, nys)):
            y, ny = int(y), int(ny)
            if y != ny:
                pass
                id_dict[int(ny)].append(i)  # make sure only noisy labels get added
            else:
                pass
                # id_dict[int(ny)].append(i)  # make sure only clean labels get added
        for i in range(params.n_classes):
            id_dict[i] = cycle(id_dict[i])

        if params.ideal_mixup:
            # the ideal case
            re_id = []
            for i, (y, ny) in enumerate(zip(ys, nys)):
                y, ny = int(y), int(ny)
                try:
                    # If the label is already clean, it does not matter what gets
                    # mixed in (the clean label dominates); otherwise, mix in a
                    # noisy label.
                    re_id.append(next(id_dict.get(y)))
                except:
                    re_id.append(np.random.randint(0, len(nys)))
        elif params.worst_mixup:
            re_id = []
            for i, (y, ny) in enumerate(zip(ys, nys)):
                y, ny = int(y), int(ny)
                try:
                    # Mix in a label from a different class: the worst case for
                    # mixup. If even the worst case beats no mixup, the model can
                    # learn from noise; otherwise mixup's benefit comes from the
                    # correct-label part of the label smoothing.
                    neg_cls = np.random.choice(self.neg_dict[y])
                    re_id.append(next(id_dict.get(neg_cls)))
                except:
                    re_id.append(np.random.randint(0, len(nys)))
        else:
            # re_id = torch.randperm(len(nys))  # fully random
            # schedule a fixed chance (here 10%) of the worst case
            re_id = []
            rand = np.random.rand(len(nys))
            for i, (y, ny, r) in enumerate(zip(ys, nys, rand)):
                y, ny = int(y), int(ny)
                try:
                    if r < 0.1:
                        # mix in a label from a different class (worst case, see above)
                        neg_cls = np.random.choice(self.neg_dict[y])
                        re_id.append(next(id_dict.get(neg_cls)))
                    else:
                        re_id.append(next(id_dict.get(y)))
                except:
                    re_id.append(np.random.randint(0, len(nys)))

        ntargets = tricks.onehot(nys, params.n_classes)
        mixed_input, mixed_target = self.mixup_(xs, ntargets, reids=re_id)
        mixed_logits = self.to_logits(mixed_input)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(mixed_logits, mixed_target, meter=meter)
        logits = self.predict(xs)
    else:
        logits = self.to_logits(xs)
        # was `meter.Lall = meter.Lall = ...`, which dropped the accumulation
        meter.Lall = meter.Lall + self.loss_ce_(logits, nys, meter=meter)

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='true_acc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='noisy_acc')
    return meter
def train_batch(self, eidx, idx, global_step, batch_data, params: GmaParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, nys = batch_data  # type:torch.Tensor

    n_targets = tricks.onehot(nys, params.n_classes)
    mixed_xs, mixed_ntargets = self.mixup_(xs, n_targets)
    mixed_logits = self.to_logits(mixed_xs)
    # logits = self.to_logits(xs)
    w_logits = self.to_logits(xs).detach()
    # w_logits = logits.detach()

    fweight = self.filter_mem[ids]
    # fweight -= self.noisy_cls[ids]
    fweight = torch.relu(fweight)

    raw_targets = torch.softmax(w_logits, dim=1)
    if eidx == 1:
        targets = raw_targets
    else:
        targets = self.plabel_mem[ids]
    values, p_labels = targets.max(dim=-1)
    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], nys[mask], meter, name='pacc')
    p_targets = tricks.onehot(p_labels, params.n_classes)
    mask = mask.float()
    meter.pm = mask.mean()
    meter.nm = self.noisy_cls[ids].mean()
    meter.fm = fweight.float().mean()

    p_targets[mask.logical_not()] = n_targets[mask.logical_not()]
    # mixed_input, mixed_target = self.mixup_(axs, n_targets, beta=0.75, target_b=p_targets)
    # mixed_logits = self.to_logits(mixed_input)
    # meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(mixed_logits, mixed_target,
    #                                                             fweight, meter=meter)

    fmask = fweight > 0.5
    # if eidx > 5:
    #     meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
    #         logits, p_targets, (fmask.logical_not().float()),
    #         meter=meter, name='Lpce') * params.plabel_sche(eidx)
    if fmask.any():
        # meter.Lall = meter.Lall + self.loss_ce_with_targets_(logits[fmask], n_targets[fmask],
        #                                                      fweight[fweight > 0.5].float(),
        #                                                      meter=meter)
        meter.Lall = meter.Lall + self.loss_ce_with_targets_(
            mixed_logits[fmask], mixed_ntargets[fmask],
            # fweight[fweight > 0.5].float(),
            meter=meter)
        self.acc_precise_(w_logits.argmax(dim=-1)[fmask], nys[fmask], meter, name='tacc')
    else:
        self.acc_precise_(w_logits.argmax(dim=-1)[fmask.logical_not()],
                          nys[fmask.logical_not()], meter, name='facc')

    # prior = torch.ones(params.n_classes, device=self.device) / params.n_classes
    # pred_mean = torch.softmax(logits, dim=1).mean(0)
    # meter.Lpen = torch.sum(prior * torch.log(prior / pred_mean))  # penalty
    # meter.Lall = meter.Lall + meter.Lpen

    self.optim.zero_grad()
    if 'Lall' in meter:
        meter.Lall.backward()
        self.optim.step()

    with torch.no_grad():
        preds = torch.softmax(w_logits, dim=1).detach()
        label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
        weight = label_pred - self.target_mem[ids]
        weight = weight + label_pred * 0.5 / params.n_classes - 0.25 / params.n_classes
        weight_mask = weight < 0
        meter.wm = weight_mask.float().mean()

        ids_mask = weight_mask.logical_not()
        alpha = params.filter_ema_sche(eidx)
        if eidx == 1:
            alpha = 0.2
        self.target_mem[ids[ids_mask]] = self.target_mem[ids[ids_mask]] * alpha + \
                                         label_pred[ids_mask] * (1 - alpha)
        self.acc_precise_(w_logits.argmax(dim=1), nys, meter, name='acc')

        fweight = torch.ones(nys.shape[0], dtype=torch.float, device=device)
        # fweight[weight_mask] -= params.gmm_w_sche(eidx)
        self.local_noisy_cls[ids] = fweight

        raw_targets = torch.softmax(w_logits, dim=1)
        if eidx == 1:
            targets = raw_targets
        else:
            targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets

        false_pred = raw_targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [nys != nys]
        self.false_pred_mem[ids, eidx - 1] = false_pred
    return meter
def train_batch(self, eidx, idx, global_step, batch_data, params: NoisyParams, device: torch.device):
    meter = Meter()
    ids, xs, axs, ys, nys = batch_data  # type:torch.Tensor

    logits = self.to_logits(axs)
    w_logits = self.to_logits(xs)

    preds = torch.softmax(logits, dim=1).detach()
    label_pred = preds.gather(1, nys.unsqueeze(dim=1)).squeeze()
    # weight = torch.softmax(label_pred - self.target_mem[ids], dim=0)

    # ideal filter: clean samples get weight 1, noisy ones follow the schedule
    fweight = (nys == ys).float()
    fweight[nys != ys] = params.ideal_nyw_sche(eidx)

    raw_targets = torch.softmax(w_logits, dim=1)
    with torch.no_grad():
        targets = self.plabel_mem[ids] * params.targets_ema + raw_targets * (1 - params.targets_ema)
        self.plabel_mem[ids] = targets
    top_values, top_indices = targets.topk(2, dim=-1)
    p_labels = top_indices[:, 0]
    values = top_values[:, 0]
    p_targets = (torch.zeros_like(targets)
                 .scatter(-1, top_indices[:, 0:1], 0.9)
                 .scatter(-1, top_indices[:, 1:2], 0.1))

    mask = values > params.pred_thresh
    if mask.any():
        self.acc_precise_(p_labels[mask], ys[mask], meter, name='pacc')
    mask = mask.float()
    meter.pm = mask.mean()

    n_targets = tricks.onehot(nys, params.n_classes)
    n_targets = n_targets * 0.9 + p_targets * 0.1

    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
        logits, n_targets, fweight, meter=meter)
    meter.Lall = meter.Lall + self.loss_ce_with_targets_masked_(
        logits, p_targets, (1 - fweight) * values * mask,
        meter=meter, name='Lpce') * params.plabel_sche(eidx)
    meter.tw = fweight[ys == nys].mean()
    meter.fw = fweight[ys != nys].mean()

    self.optim.zero_grad()
    meter.Lall.backward()
    self.optim.step()

    self.acc_precise_(logits.argmax(dim=1), ys, meter, name='tacc')
    n_mask = nys != ys
    if n_mask.any():
        self.acc_precise_(logits.argmax(dim=1)[n_mask], nys[n_mask], meter, name='nacc')

    false_pred = targets.gather(1, nys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    true_pred = targets.gather(1, ys.unsqueeze(dim=1)).squeeze()  # [ys != nys]
    with torch.no_grad():
        self.true_pred_mem[ids, eidx - 1] = true_pred
        self.false_pred_mem[ids, eidx - 1] = false_pred
        self.loss_mem[ids, eidx - 1] = F.cross_entropy(w_logits, nys, reduction='none')
    # mem_mask = ids < self.pred_mem_size - 1
    # self.pred_mem[ids[mem_mask], :, eidx - 1] = targets[ids[mem_mask]]

    if eidx == 1:
        self.cls_mem[ids, 0] = ys
    elif eidx == 2:
        self.cls_mem[ids, 1] = nys
    else:
        self.cls_mem[ids, eidx - 1] = p_labels
    return meter