import copy
import pickle
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Repo-local helpers (to_var, accuracy, ScalarLogger, WLogger, AverageMeter,
# get_model, get_val_samples, update_ema_variables, check_mean_teacher, ...)
# are assumed to be imported from elsewhere in this project.


def val(model, val_loader, criterion, epoch, writer, use_CUDA=True):
    model.eval()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    losses_logger = ScalarLogger(prefix='loss')
    with torch.no_grad():
        for (input, label, _) in val_loader:
            input = to_var(input, requires_grad=False)
            label = to_var(label, requires_grad=False).long()
            output = model(input)
            loss = criterion(output, label)
            prediction = torch.softmax(output, 1)
            top1 = accuracy(prediction, label)
            accuracy_logger.update(top1)
            losses_logger.update(loss)
    accuracy_logger.write(writer, 'val', epoch)
    losses_logger.write(writer, 'val', epoch)
    accuracy_ = accuracy_logger.avg()
    losses = losses_logger.avg()
    print("Validation Epoch: {}, Accuracy: {}, Losses: {}".format(epoch, accuracy_, losses))
    return accuracy_, losses
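# The loops in this file lean on a handful of small helpers defined elsewhere
# in the repo. As a point of reference, here is a minimal sketch of what
# ScalarLogger and to_var could look like; the bodies below are assumptions
# inferred from the call sites above, not the project's actual definitions.
class ScalarLoggerSketch(object):
    """Running average of a scalar with TensorBoard export (assumed API)."""

    def __init__(self, prefix='scalar'):
        self.prefix = prefix
        self.sum = 0.0
        self.count = 0

    def update(self, value):
        # accepts python floats or 0-dim tensors
        self.sum += float(value)
        self.count += 1

    def avg(self):
        return self.sum / max(self.count, 1)

    def write(self, writer, split, epoch):
        # e.g. tag 'accuracy/val' logged once per epoch
        writer.add_scalar('{}/{}'.format(self.prefix, split), self.avg(), epoch)


def to_var_sketch(x, requires_grad=True):
    # Variable-era idiom: move to GPU when available, set requires_grad
    if torch.cuda.is_available():
        x = x.cuda()
    return x.clone().detach().requires_grad_(requires_grad)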
def train(model, input_channel, optimizer, criterion, train_loader, val_loader,
          epoch, writer, args, use_CUDA=True, clamp=False, num_classes=10):
    model.train()
    accs = []          # unused in this baseline variant
    losses_w1 = []     # unused
    losses_w2 = []     # unused
    iter_val_loader = iter(val_loader)
    # reduction='none' keeps per-example losses (reduce=False is deprecated)
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    index = 0
    noisy_labels = []
    true_labels = []
    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    for (input, label, real) in train_loader:
        noisy_labels.append(label)
        true_labels.append(real)
        input = to_var(input, requires_grad=False)
        label = to_var(label, requires_grad=False).long()
        index += 1
        output = model(input)
        loss = meta_criterion(output, label).sum() / input.shape[0]
        prediction = torch.softmax(output, 1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        top1 = accuracy(prediction, label)
        accuracy_logger.update(top1)
    noisy_labels = torch.cat(noisy_labels)
    true_labels = torch.cat(true_labels)
    mask = (noisy_labels != true_labels).cpu().numpy()
    accuracy_logger.write(writer, 'train', epoch)
    print("Training Epoch: {}, Accuracy: {}".format(epoch, accuracy_logger.avg()))
    return accuracy_logger.avg()
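# A hypothetical per-epoch driver for the plain train/val pair above, to show
# how the pieces are meant to fit together. get_loaders, the SGD settings, and
# the log directory are illustrative assumptions, not the repo's actual setup.
from torch.utils.tensorboard import SummaryWriter

def run_baseline(args, epochs=100):
    writer = SummaryWriter(log_dir='runs/baseline')  # assumed log location
    model = get_model(args, num_classes=10, input_channel=3).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    train_loader, val_loader = get_loaders(args)  # assumed helper
    best_acc = 0.0
    for epoch in range(epochs):
        train(model, 3, optimizer, criterion, train_loader, val_loader,
              epoch, writer, args)
        acc, _ = val(model, val_loader, criterion, epoch, writer)
        best_acc = max(best_acc, acc)
    writer.close()
    return best_acc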
def train(model, input_channel, optimizers, criterion, components, train_loader,
          val_loader, epoch, writer, args, use_CUDA=True, clamp=False,
          num_classes=10):
    # Diagnostic variant: the first batch is stored and re-used every iteration,
    # so w_all traces how the meta weights of one fixed batch evolve over an epoch.
    model.train()
    accs = []          # unused
    losses_w1 = []     # unused
    losses_w2 = []     # unused
    iter_val_loader = iter(val_loader)
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    index = 0
    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    for c in components:
        w[c] = None
        w_logger[c] = WLogger()
        losses_logger[c] = ScalarLogger(prefix='loss')
    w_all = []
    store_input = None
    store_label = None
    store_real = None
    for (input, label, real) in train_loader:
        if store_input is None:
            store_input = input
            store_label = label
        meta_model = get_model(args, num_classes=num_classes, input_channel=input_channel)
        meta_model.load_state_dict(model.state_dict())
        if use_CUDA:
            meta_model = meta_model.cuda()
        val_input, val_label, iter_val_loader = get_val_samples(iter_val_loader, val_loader)
        store_input = to_var(store_input, requires_grad=False)
        store_label = to_var(store_label, requires_grad=False).long()
        val_input = to_var(val_input, requires_grad=False)
        val_label = to_var(val_label, requires_grad=False).long()
        meta_output = meta_model(store_input)
        cost = meta_criterion(meta_output, store_label)
        eps = to_var(torch.zeros(cost.size()))
        meta_loss = (cost * eps).sum()
        meta_model.zero_grad()
        if 'all' in components:
            grads = torch.autograd.grad(meta_loss, meta_model.parameters(), create_graph=True)
            meta_model.update_params(0.001, source_params=grads)
            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs=True)[0]
            if clamp:
                w['all'] = torch.clamp(-grad_eps, min=0)
            else:
                w['all'] = -grad_eps
            norm = torch.sum(abs(w['all']))
            assert (clamp and len(components) == 1) or (len(components) > 1), "Error combination"
            w['all'] = w['all'] / norm
            if 'fc' in components:
                w['fc'] = copy.deepcopy(w['all'])
                w['fc'] = torch.clamp(w['fc'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)
            elif 'backbone' in components:
                w['backbone'] = copy.deepcopy(w['all'])
                w['backbone'] = torch.clamp(w['backbone'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)
        else:
            assert ('backbone' in components) and ('fc' in components)
            grads_backbone = torch.autograd.grad(meta_loss, meta_model.backbone.parameters(),
                                                 create_graph=True, retain_graph=True)
            grads_fc = torch.autograd.grad(meta_loss, meta_model.fc.parameters(), create_graph=True)
            # Backbone: virtual step, then weights from the validation gradient
            meta_model.backbone.update_params(0.001, source_params=grads_backbone)
            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs=True, retain_graph=True)[0]
            if clamp:
                w['backbone'] = torch.clamp(-grad_eps, min=0)
            else:
                w['backbone'] = -grad_eps
            norm = torch.sum(abs(w['backbone']))
            w['backbone'] = w['backbone'] / norm
            # FC: reset the copy, then a virtual step on the head only
            meta_model.load_state_dict(model.state_dict())
            meta_model.fc.update_params(0.001, source_params=grads_fc)
            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs=True, retain_graph=True)[0]
            if clamp:
                w['fc'] = torch.clamp(-grad_eps, min=0)
            else:
                w['fc'] = -grad_eps
            norm = torch.sum(abs(w['fc']))
            w['fc'] = w['fc'] / norm
        # one weight per example of the stored batch; assumes batch size 128
        w_all.append(w['all'].detach().cpu().numpy().reshape(128, 1))
    w_all = np.concatenate(w_all, axis=1)
    assert w_all.shape[0] == 128
    with open('w.npy', 'wb') as f:
        pickle.dump(w_all, f)
    print(np.std(w_all, axis=1))
    print(np.mean(w_all, axis=1))
    print(np.std(w_all, axis=1) / np.mean(w_all, axis=1))
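# The functions above all use the learning-to-reweight pattern (Ren et al.,
# 2018): give each example a zero weight eps, take one virtual SGD step on the
# eps-weighted loss, and read a weight off the gradient of a clean validation
# loss with respect to eps. A minimal, self-contained sketch of that core idea
# for a plain linear model; the names and learning rate are illustrative only.
def reweight_batch_sketch(W, b, x, y, x_val, y_val, lr=0.001):
    # per-example training losses, perturbed by learnable eps
    cost = F.cross_entropy(x @ W + b, y, reduction='none')
    eps = torch.zeros_like(cost, requires_grad=True)
    meta_loss = (cost * eps).sum()
    # one virtual SGD step, kept inside the graph (create_graph=True)
    gW, gb = torch.autograd.grad(meta_loss, [W, b], create_graph=True)
    W1, b1 = W - lr * gW, b - lr * gb
    # validation loss after the virtual step; its gradient w.r.t. eps tells us
    # which training examples would have helped
    val_loss = F.cross_entropy(x_val @ W1 + b1, y_val)
    grad_eps = torch.autograd.grad(val_loss, eps)[0]
    w = torch.clamp(-grad_eps, min=0)       # keep only helpful examples
    return w / (w.abs().sum() + 1e-12)      # L1-normalize, as in the code above

# e.g. W = torch.randn(d, c, requires_grad=True); b = torch.zeros(c, requires_grad=True)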
def train_with_meta(clean_loader, noisy_loader, test_loader, model, ema_model,
                    criterion, consistency_criterion, optimizer, scheduler,
                    epoch, warmup=False, self_paced_pick=0):
    global step, switched, single_epoch_steps
    pacc = AverageMeter()
    nacc = AverageMeter()
    pnacc = AverageMeter()
    cov0 = AverageMeter()
    cov1 = AverageMeter()
    wilcoxon = AverageMeter()
    model.train()
    ema_model.train()
    consistency_weight = get_current_consistency_weight(epoch - 30)
    resultt = np.zeros(61000)
    if not warmup:
        scheduler.step()
    print("Learning rate is {}".format(optimizer.param_groups[0]['lr']))
    if clean_loader:
        for i, (X, _, Y, T, ids, _) in enumerate(clean_loader):
            # epsilon added to the first probability column to avoid log(0)
            xeps = torch.ones((X.shape[0], 2)).cuda() * 1e-10
            xeps[:, 1] = 0
            X = X.cuda()
            Y = Y.cuda().float()
            T = T.cuda().long()
            if args.dataset == 'mnist':
                X = X.view(X.shape[0], -1)
            # compute output
            output = model(X)
            with torch.no_grad():
                ema_output = ema_model(X)
            consistency_loss = consistency_weight * \
                consistency_criterion(output, ema_output) / X.shape[0]
            criterion.update_p(0.5)
            _, loss = criterion(output, Y, eps=1)  # loss on the PU labels
            # measure accuracy and record loss
            if check_mean_teacher(epoch):
                predictions = torch.sign(ema_output).long()  # predict with the teacher
            else:
                predictions = torch.sign(output).long()  # otherwise with the student
            smx = torch.sigmoid(output)  # sigmoid probabilities
            smx = torch.cat([1 - smx, smx], dim=1)  # two-class probabilities
            smxY = ((Y + 1) // 2).long()  # map labels from {-1, +1} to {0, 1}
            if args.soft_label:
                xent = -torch.sum(smx * torch.log(smx + xeps), dim=1)
                aux = xent.mean()
            else:
                aux = F.cross_entropy(smx + xeps, smxY)  # cross-entropy loss
            loss = aux
            if check_mean_teacher(epoch):
                loss += consistency_loss
            if args.soft_label:
                detach = xent.detach().cpu().numpy()
            else:
                detach = 0
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if np.any(np.isnan(detach)):
                for param in model.parameters():
                    print('clean_data')
                    print(param)
                    print('clean_grad')
                    print(param.grad)
            if check_mean_teacher(epoch) and ((i + 1) % int(single_epoch_steps / 2 - 1)) == 0:
                update_ema_variables(model, ema_model, args.ema_decay, step)  # update the EMA (teacher) parameters
                step += 1
            pacc_, nacc_, pnacc_, psize = accuracy(predictions, T)  # accuracy against the true labels T
            if np.any(torch.isnan(output).cpu().numpy()):
                print(output)
                print("clean interrupt")
                raise NotImplementedError
            pacc.update(pacc_, psize)
            nacc.update(nacc_, X.size(0) - psize)
            pnacc.update(pnacc_, X.size(0))
        print('Epoch Clean : [{0}]\t'
              'PACC {pacc.val:.3f} ({pacc.avg:.3f})\t'
              'NACC {nacc.val:.3f} ({nacc.avg:.3f})\t'
              'PNACC {pnacc.val:.3f} ({pnacc.avg:.3f})\t'.format(
                  epoch, pacc=pacc, nacc=nacc, pnacc=pnacc))
    if epoch <= args.noisy_stop:
        meta_step = 0
        for i, (X, Y, _, T, ids, p) in enumerate(noisy_loader):
            xeps = torch.ones((X.shape[0], 2)).cuda() * 1e-10
            xeps[:, 1] = 0
            meta_step += 1
            if args.dataset == 'cifar':
                meta_net = create_cifar_model()
            else:
                meta_net = create_model()
            meta_net.load_state_dict(model.state_dict())
            if torch.cuda.is_available():
                meta_net.cuda()
            X = X.cuda()
            Y = Y.cuda().float()
            T = T.cuda().long()
            p = p.cuda().float()
            if args.dataset == 'mnist':
                X = X.view(X.shape[0], -1)
            y_f_hat = meta_net(X)
            prob = torch.sigmoid(y_f_hat)
            prob = torch.cat([1 - prob, prob], dim=1)
            # negative entropy of the predictions; eps[:, 1] weights it per example
            cost1 = torch.sum(prob * torch.log(prob + xeps), dim=1)
            eps = to_var(torch.zeros(cost1.shape[0], 2))
            cost2 = criterion(y_f_hat, Y, eps=eps[:, 0])
            l_f_meta = (cost1 * eps[:, 1]).mean() + cost2[1]
            meta_net.zero_grad()
            grads = torch.autograd.grad(l_f_meta, meta_net.parameters(), create_graph=True)
            meta_net.update_params(0.001, source_params=grads)
            val_data, val_Y, _, val_labels, val_ids, val_p = next(iter(test_loader))
            veps = torch.ones((val_data.shape[0], 2)).cuda() * 1e-10
            veps[:, 1] = 0
            val_data = to_var(val_data, requires_grad=False)
            if args.dataset == 'mnist':
                val_data = val_data.view(-1, 784)
            val_labels = to_var(val_labels, requires_grad=False).float()
            y_g_hat = meta_net(val_data)
            val_prob = torch.sigmoid(y_g_hat)
            val_prob = torch.cat([1 - val_prob, val_prob], dim=1)
            val_xent = -torch.sum(val_prob * torch.log(val_prob + veps), dim=1)
            val_xent[torch.isnan(val_xent)] = 0
            l_g_meta = val_xent.mean()
            grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
            w = torch.clamp(-grad_eps, min=1e-6)
            del meta_net
            # unlabeled (Y == -1) examples split their weight between the PU loss
            # and the entropy term; labeled ones keep full weight on the PU loss
            for j in range(w.shape[0]):
                if Y[j] == -1:
                    if torch.sum(w[:, 1]) >= 4:
                        w[j, 0] = 1
                        w[j, 1] = 0
                    else:
                        w[j, :] = w[j, :] / torch.sum(w[j, :])
                else:
                    w[j, 0] = 1
                    w[j, 1] = 0
            w = w.cuda().detach()
            output = model(X)
            with torch.no_grad():
                ema_output = ema_model(X)
            consistency_loss = consistency_weight * \
                consistency_criterion(output, ema_output) / X.shape[0]
            criterion.update_p(0.5)
            _, loss = criterion(output, Y, eps=w[:, 0])
            prob = torch.sigmoid(output)
            prob = torch.cat([1 - prob, prob], dim=1)
            xent = -torch.sum(prob * torch.log(prob + xeps), dim=1)
            detach = xent.detach().cpu().numpy()
            if check_mean_teacher(epoch) and not warmup:
                total_loss = loss + consistency_loss + (xent * w[:, 1]).mean()
                predictions = torch.sign(ema_output).long()
            else:
                predictions = torch.sign(output).long()
                total_loss = (xent * w[:, 1]).mean() + loss
            # correlation between the predicted negative probability and each weight column
            cov_0 = ((prob[:, 0] * w[:, 0]).mean() - prob[:, 0].float().mean() * w[:, 0].mean()) / torch.sqrt(
                torch.var(prob[:, 0].float()) * torch.var(w[:, 0].float() + 1e-6))
            cov_1 = ((prob[:, 0] * w[:, 1]).mean() - prob[:, 0].float().mean() * w[:, 1].mean()) / torch.sqrt(
                torch.var(prob[:, 0].float()) * torch.var(w[:, 1].float() + 1e-6))
            if check_noisy(epoch):
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
            if np.any(np.isnan(detach)):
                for param in model.parameters():
                    print('noisy_data')
                    print(param)
                    print('noisy_grad')
                    print(param.grad)
                print(w)
                print(prob)
                print(xent)
            if check_mean_teacher(epoch) and ((i + 1) % int(single_epoch_steps / 2 - 1)) == 0 and not warmup:
                update_ema_variables(model, ema_model, args.ema_decay, step)
                step += 1
            pacc_, nacc_, pnacc_, psize = accuracy(predictions, T)
            if np.any(torch.isnan(output).cpu().numpy()):
                print('noisy_interrupt')
                print(output)
                raise NotImplementedError
            last_loss = loss
            last_total_loss = total_loss
            last_predictions = predictions
            last_output = output
            last_prob = prob
            last_xent = xent
            pacc.update(pacc_, psize)
            nacc.update(nacc_, X.size(0) - psize)
            pnacc.update(pnacc_, X.size(0))
            cov0.update(cov_0)
            cov1.update(cov_1)
        print('Noisy Epoch: [{0}]\t'
              'PACC {pacc.val:.3f} ({pacc.avg:.3f})\t'
              'NACC {nacc.val:.3f} ({nacc.avg:.3f})\t'
              'PNACC {pnacc.val:.3f} ({pnacc.avg:.3f})\t'
              'COV0 {cov0.val:.3f} ({cov0.avg:.3f})\t'
              'COV1 {cov1.val:.3f} ({cov1.avg:.3f})\t'.format(
                  epoch, pacc=pacc, nacc=nacc, pnacc=pnacc, cov0=cov0, cov1=cov1))
    return pacc.avg, nacc.avg, pnacc.avg
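# update_ema_variables above maintains the teacher as an exponential moving
# average of the student, as in Mean Teacher (Tarvainen & Valpola, 2017). The
# repo's implementation is not shown here; a standard sketch would be:
def update_ema_variables_sketch(model, ema_model, alpha, global_step):
    # ramp the decay up from 0 toward alpha during the first steps
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)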
def train(model, vnet, input_channel, optimizers, optimizer_vnet, components,
          criterion, train_loader, val_loader, epoch, writer, args,
          use_CUDA=True, clamp=False, num_classes=10):
    model.train()
    accs = []          # unused
    losses_w1 = []     # unused
    losses_w2 = []     # unused
    iter_val_loader = iter(val_loader)
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    index = 0
    noisy_labels = []
    true_labels = []
    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    for c in components:
        w[c] = None
        w_logger[c] = WLogger()
        losses_logger[c] = ScalarLogger(prefix='loss')
    for (input, label, real) in train_loader:
        noisy_labels.append(label)
        true_labels.append(real)
        meta_model = get_model(args, num_classes=num_classes, input_channel=input_channel)
        meta_model.load_state_dict(model.state_dict())
        if use_CUDA:
            meta_model = meta_model.cuda()
        val_input, val_label, iter_val_loader = get_val_samples(iter_val_loader, val_loader)
        input = to_var(input, requires_grad=False)
        label = to_var(label, requires_grad=False).long()
        val_input = to_var(val_input, requires_grad=False)
        val_label = to_var(val_label, requires_grad=False).long()
        meta_output = meta_model(input)
        cost = meta_criterion(meta_output, label)
        cost_v = torch.reshape(cost, (len(cost), 1))
        eps = vnet(cost_v.data)  # shape: (N, 2)
        meta_loss_backbone = (cost * eps[:, 0]).sum()
        meta_loss_fc = (cost * eps[:, 1]).sum()
        meta_model.zero_grad()
        grads_backbone = torch.autograd.grad(meta_loss_backbone, meta_model.backbone.parameters(),
                                             create_graph=True, retain_graph=True)
        grads_fc = torch.autograd.grad(meta_loss_fc, meta_model.fc.parameters(), create_graph=True)
        # Backbone: virtual step, then train the vnet on the validation loss
        meta_model.backbone.update_params(0.001, source_params=grads_backbone)
        meta_val_feature = torch.flatten(meta_model.backbone(val_input), 1)
        meta_val_output = meta_model.fc(meta_val_feature)
        meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
        '''
        TODO: temporarily removed
        if args.with_kl and args.reg_start <= epoch:
            train_feature = torch.flatten(meta_model.backbone(input), 1)
            meta_val_loss -= sample_wise_kl(train_feature, meta_val_feature)
        grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs=True, retain_graph=True)[0]
        if clamp:
            w['backbone'] = torch.clamp(-grad_eps, min=0)
        else:
            w['backbone'] = -grad_eps
        norm = torch.sum(abs(w['backbone']))
        w['backbone'] = w['backbone'] / norm
        '''
        optimizer_vnet.zero_grad()
        meta_val_loss.backward(retain_graph=True)
        optimizer_vnet.step()
        # FC: reset the copy, virtual step on the head, second vnet update
        meta_model.load_state_dict(model.state_dict())
        meta_model.fc.update_params(0.001, source_params=grads_fc)
        meta_val_output = meta_model(val_input)
        meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
        '''
        grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs=True, retain_graph=True)[0]
        if clamp:
            w['fc'] = torch.clamp(-grad_eps, min=0)
        else:
            w['fc'] = -grad_eps
        norm = torch.sum(abs(w['fc']))
        w['fc'] = w['fc'] / norm
        '''
        optimizer_vnet.zero_grad()
        meta_val_loss.backward(retain_graph=True)
        optimizer_vnet.step()
        index += 1
        output = model(input)
        losses = defaultdict()
        loss = meta_criterion(output, label)
        loss_v = torch.reshape(loss, (len(loss), 1))
        with torch.no_grad():
            w_ = vnet(loss_v)
        if clamp:
            w_ = torch.clamp(w_, min=0)
        # normalize each weight column to unit L1 norm
        for i in range(w_.shape[1]):
            w_[:, i] = w_[:, i] / torch.sum(torch.abs(w_[:, i]))
        w['backbone'] = w_[:, 0]
        w['fc'] = w_[:, 1]
        prediction = torch.softmax(output, 1)
        for c in components:
            w_logger[c].update(w[c])
            losses[c] = (loss * w[c]).sum()
            optimizers[c].zero_grad()
            losses[c].backward(retain_graph=True)
            optimizers[c].step()
            losses_logger[c].update(losses[c])
        top1 = accuracy(prediction, label)
        accuracy_logger.update(top1)
    noisy_labels = torch.cat(noisy_labels)
    true_labels = torch.cat(true_labels)
    mask = (noisy_labels != true_labels).cpu().numpy()
    for c in components:
        w_logger[c].write(writer, c, epoch)
        w_logger[c].mask_write(writer, c, epoch, mask)
        losses_logger[c].write(writer, c, epoch)
    accuracy_logger.write(writer, 'train', epoch)
    print("Training Epoch: {}, Accuracy: {}".format(epoch, accuracy_logger.avg()))
    return accuracy_logger.avg()
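# vnet above maps each per-example loss to a pair of weights, one for the
# backbone loss and one for the fc loss. Its definition is not in this file;
# in the spirit of Meta-Weight-Net it could be a small MLP like this sketch
# (the hidden size and the final Sigmoid are assumptions):
class VNetSketch(nn.Module):
    def __init__(self, hidden=100, out_dim=2):
        super(VNetSketch, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),    # input: per-example loss, shape (N, 1)
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_dim),
            nn.Sigmoid(),            # keep raw weights in (0, 1)
        )

    def forward(self, loss_values):
        return self.net(loss_values)  # shape (N, 2): (backbone, fc) weights

# e.g. vnet = VNetSketch().cuda()
#      optimizer_vnet = torch.optim.Adam(vnet.parameters(), lr=1e-3)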
def train_with_meta(clean1_loader, noisy1_loader, clean2_loader, noisy2_loader,
                    test_loader, model1, model2, ema_model1, ema_model2,
                    criterion, consistency_criterion, optimizer1, scheduler1,
                    optimizer2, scheduler2, epoch, warmup=False, self_paced_pick=0):
    global step, switched
    batch_time = AverageMeter()
    data_time = AverageMeter()
    pacc1 = AverageMeter()
    nacc1 = AverageMeter()
    pnacc1 = AverageMeter()
    pacc2 = AverageMeter()
    nacc2 = AverageMeter()
    pnacc2 = AverageMeter()
    pacc3 = AverageMeter()
    nacc3 = AverageMeter()
    pnacc3 = AverageMeter()
    pacc4 = AverageMeter()
    nacc4 = AverageMeter()
    pnacc4 = AverageMeter()
    model1.train()
    model2.train()
    ema_model1.train()
    ema_model2.train()
    end = time.time()
    consistency_weight = get_current_consistency_weight(epoch - 30)
    if not warmup:
        scheduler1.step()
        scheduler2.step()
    resultt = np.zeros(61000)
    if clean1_loader:
        for i, (X, left, right, _, Y, T, ids) in enumerate(clean1_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            X = X.cuda(args.gpu)
            left = left.cuda(args.gpu)
            right = right.cuda(args.gpu)
            Y = Y.cuda(args.gpu).float()
            T = T.cuda(args.gpu).long()
            # compute output
            output1 = model1(X, left, right)
            output2 = model2(X, left, right)
            with torch.no_grad():
                ema_output1 = ema_model1(X, left, right)
            consistency_loss = consistency_weight * \
                consistency_criterion(output1, ema_output1) / X.shape[0]
            predictiont1 = torch.sign(ema_output1).long()  # teacher predictions
            predictions1 = torch.sign(output1).long()      # student predictions
            smx1 = torch.sigmoid(output1)                  # sigmoid probabilities
            smx1 = torch.cat([1 - smx1, smx1], dim=1)      # two-class probabilities
            smxY = ((Y + 1) // 2).long()                   # map {-1, +1} labels to {0, 1}
            smx2 = torch.sigmoid(output2)
            smx2 = torch.cat([1 - smx2, smx2], dim=1)
            # clamp probabilities away from 0 and 1 so log() stays finite
            smx1[smx1 > 1 - 1e-7] = 1 - 1e-7
            smx1[smx1 < 1e-7] = 1e-7
            smx2[smx2 > 1 - 1e-7] = 1 - 1e-7
            smx2[smx2 < 1e-7] = 1e-7
            if args.soft_label:
                aux1 = -torch.sum(smx1 * torch.log(smx1)) / smx1.shape[0]
                aux2 = -torch.sum(smx2 * torch.log(smx2)) / smx2.shape[0]
            else:
                smxY = smxY.float()
                smxY = smxY.view(-1, 1)
                smxY = torch.cat([1 - smxY, smxY], dim=1)
                aux1 = -torch.sum(smxY * torch.log(smx1)) / smxY.shape[0]  # cross-entropy loss
                aux2 = -torch.sum(smxY * torch.log(smx2)) / smxY.shape[0]  # cross-entropy loss
            loss = aux1
            if check_mean_teacher(epoch):
                loss += consistency_loss
            optimizer1.zero_grad()
            loss.backward()
            optimizer1.step()
            pacc_1, nacc_1, pnacc_1, psize = accuracy(predictions1, T)  # accuracy against the true labels T
            pacc_3, nacc_3, pnacc_3, psize = accuracy(predictiont1, T)
            pacc1.update(pacc_1, psize)
            nacc1.update(nacc_1, X.size(0) - psize)
            pnacc1.update(pnacc_1, X.size(0))
            pacc3.update(pacc_3, psize)
            nacc3.update(nacc_3, X.size(0) - psize)
            pnacc3.update(pnacc_3, X.size(0))
    if clean2_loader:
        for i, (X, left, right, _, Y, T, ids) in enumerate(clean2_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            X = X.cuda(args.gpu)
            left = left.cuda(args.gpu)
            right = right.cuda(args.gpu)
            Y = Y.cuda(args.gpu).float()
            T = T.cuda(args.gpu).long()
            # compute output
            output1 = model1(X, left, right)
            output2 = model2(X, left, right)
            with torch.no_grad():
                ema_output2 = ema_model2(X, left, right)
            consistency_loss = consistency_weight * \
                consistency_criterion(output2, ema_output2) / X.shape[0]
            predictiont2 = torch.sign(ema_output2).long()
            predictions2 = torch.sign(output2).long()
            smx1 = torch.sigmoid(output1)
            smx1 = torch.cat([1 - smx1, smx1], dim=1)
            smxY = ((Y + 1) // 2).long()
            smx2 = torch.sigmoid(output2)
            smx2 = torch.cat([1 - smx2, smx2], dim=1)
            smx1[smx1 > 1 - 1e-7] = 1 - 1e-7
            smx1[smx1 < 1e-7] = 1e-7
            smx2[smx2 > 1 - 1e-7] = 1 - 1e-7
            smx2[smx2 < 1e-7] = 1e-7
            if args.soft_label:
                aux1 = -torch.sum(smx1 * torch.log(smx1)) / smx1.shape[0]
                aux2 = -torch.sum(smx2 * torch.log(smx2)) / smx2.shape[0]
            else:
                smxY = smxY.float()
                smxY = smxY.view(-1, 1)
                smxY = torch.cat([1 - smxY, smxY], dim=1)
                aux1 = -torch.sum(smxY * torch.log(smx1)) / smxY.shape[0]
                aux2 = -torch.sum(smxY * torch.log(smx2)) / smxY.shape[0]
            loss = aux2
            if check_mean_teacher(epoch):
                loss += consistency_loss
            optimizer2.zero_grad()
            loss.backward()
            optimizer2.step()
            pacc_2, nacc_2, pnacc_2, psize = accuracy(predictions2, T)
            pacc_4, nacc_4, pnacc_4, psize = accuracy(predictiont2, T)
            pacc2.update(pacc_2, psize)
            nacc2.update(nacc_2, X.size(0) - psize)
            pnacc2.update(pnacc_2, X.size(0))
            pacc4.update(pacc_4, psize)
            nacc4.update(nacc_4, X.size(0) - psize)
            pnacc4.update(pnacc_4, X.size(0))
    if check_mean_teacher(epoch):
        update_ema_variables(model1, ema_model1, args.ema_decay, step)  # update the EMA (teacher) parameters
        update_ema_variables(model2, ema_model2, args.ema_decay, step)
        step += 1
    print('Epoch Clean : [{0}]\t'
          'PACC1 {pacc1.val:.3f} ({pacc1.avg:.3f})\t'
          'NACC1 {nacc1.val:.3f} ({nacc1.avg:.3f})\t'
          'PNACC1 {pnacc1.val:.3f} ({pnacc1.avg:.3f})\t'
          'PACC2 {pacc2.val:.3f} ({pacc2.avg:.3f})\t'
          'NACC2 {nacc2.val:.3f} ({nacc2.avg:.3f})\t'
          'PNACC2 {pnacc2.val:.3f} ({pnacc2.avg:.3f})\t'
          'PACC3 {pacc3.val:.3f} ({pacc3.avg:.3f})\t'
          'NACC3 {nacc3.val:.3f} ({nacc3.avg:.3f})\t'
          'PNACC3 {pnacc3.val:.3f} ({pnacc3.avg:.3f})\t'
          'PACC4 {pacc4.val:.3f} ({pacc4.avg:.3f})\t'
          'NACC4 {nacc4.val:.3f} ({nacc4.avg:.3f})\t'
          'PNACC4 {pnacc4.val:.3f} ({pnacc4.avg:.3f})\t'.format(
              epoch, pacc1=pacc1, nacc1=nacc1, pnacc1=pnacc1,
              pacc2=pacc2, nacc2=nacc2, pnacc2=pnacc2,
              pacc3=pacc3, nacc3=nacc3, pnacc3=pnacc3,
              pacc4=pacc4, nacc4=nacc4, pnacc4=pnacc4))
    for i, (X, left, right, Y, _, T, ids) in enumerate(noisy1_loader):
        X = X.cuda(args.gpu)
        left = left.cuda(args.gpu)
        right = right.cuda(args.gpu)
        Y = Y.cuda(args.gpu).float()
        T = T.cuda(args.gpu).long()
        meta_net = create_lenet_model()
        meta_net.load_state_dict(model1.state_dict())
        if torch.cuda.is_available():
            meta_net.cuda()
        y_f_hat = meta_net(X, left, right)
        prob = torch.sigmoid(y_f_hat)
        prob = torch.cat([1 - prob, prob], dim=1)
        cost1 = torch.sum(prob * torch.log(prob + 1e-10), dim=1)
        eps = to_var(torch.zeros(cost1.shape[0], 2))
        cost2 = criterion(y_f_hat, Y, eps=eps[:, 0])
        l_f_meta = (cost1 * eps[:, 1]).mean() + cost2[1]
        meta_net.zero_grad()
        grads = torch.autograd.grad(l_f_meta, meta_net.parameters(),
                                    create_graph=True, allow_unused=True)
        meta_net.update_params(0.001, source_params=grads)
        val_data, val_left, val_right, val_Y, _, val_labels, val_ids = next(iter(test_loader))
        val_data = to_var(val_data, requires_grad=False)
        val_left = to_var(val_left, requires_grad=False)
        val_right = to_var(val_right, requires_grad=False)
        val_labels = to_var(val_labels, requires_grad=False).float()
        y_g_hat = meta_net(val_data, val_left, val_right)
        val_prob = torch.sigmoid(y_g_hat)
        val_prob = torch.cat([1 - val_prob, val_prob], dim=1)
        l_g_meta = -torch.mean(torch.sum(val_prob * torch.log(val_prob + 1e-10), dim=1)) * 2
        grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        w = torch.clamp(-grad_eps, min=1e-10)
        acount = bcount = ccount = dcount = 0  # unused counters
        # unlabeled (Y == -1) examples split their weight between the PU loss
        # and the entropy term; labeled ones keep full weight on the PU loss
        for j in range(w.shape[0]):
            if Y[j] == -1:
                if torch.sum(w[:, 1]) >= 4:
                    w[j, 0] = 1
                    w[j, 1] = 0
                else:
                    w[j, :] = w[j, :] / torch.sum(w[j, :])
            else:
                w[j, 0] = 1
                w[j, 1] = 0
        w = w.cuda().detach()
        # compute output
        output1 = model1(X, left, right)
        output2 = model2(X, left, right)
        with torch.no_grad():
            ema_output1 = ema_model1(X, left, right)
        criterion.update_p(0.5)
        _, loss = criterion(output1, Y, eps=w[:, 0])
        consistency_loss = consistency_weight * \
            consistency_criterion(output1, ema_output1) / X.shape[0]
        predictions1 = torch.sign(output1).long()
        predictiont1 = torch.sign(ema_output1).long()
        smx1 = torch.sigmoid(output1)
        smx1 = torch.cat([1 - smx1, smx1], dim=1)
        smxY = ((Y + 1) // 2).long()
        smx2 = torch.sigmoid(output2)
        smx2 = torch.cat([1 - smx2, smx2], dim=1)
        xent = -torch.sum(smx1 * torch.log(smx1 + 1e-10), dim=1)
        if args.type_noisy == 'mu' and check_mean_teacher(epoch):
            # mutual-learning term: pull model1 toward model2's predictions
            aux = F.mse_loss(smx1[:, 0], smx2[:, 0].detach())
            if aux < loss * args.alpha:
                loss += aux
        if check_mean_teacher(epoch):
            loss += consistency_loss
        loss += (xent * w[:, 1]).mean()
        optimizer1.zero_grad()
        loss.backward()
        optimizer1.step()
        pacc_3, nacc_3, pnacc_3, psize = accuracy(predictiont1, T)
        pacc_1, nacc_1, pnacc_1, psize = accuracy(predictions1, T)  # accuracy against the true labels T
        pacc1.update(pacc_1, psize)
        nacc1.update(nacc_1, X.size(0) - psize)
        pnacc1.update(pnacc_1, X.size(0))
        pacc3.update(pacc_3, psize)
        nacc3.update(nacc_3, X.size(0) - psize)
        pnacc3.update(pnacc_3, X.size(0))
    for i, (X, left, right, Y, _, T, ids) in enumerate(noisy2_loader):
        X = X.cuda(args.gpu)
        left = left.cuda(args.gpu)
        right = right.cuda(args.gpu)
        Y = Y.cuda(args.gpu).float()
        T = T.cuda(args.gpu).long()
        meta_net = create_lenet_model()
        meta_net.load_state_dict(model2.state_dict())
        if torch.cuda.is_available():
            meta_net.cuda()
        y_f_hat = meta_net(X, left, right)
        prob = torch.sigmoid(y_f_hat)
        prob = torch.cat([1 - prob, prob], dim=1)
        cost1 = torch.sum(prob * torch.log(prob + 1e-10), dim=1)
        eps = to_var(torch.zeros(cost1.shape[0], 2))
        cost2 = criterion(y_f_hat, Y, eps=eps[:, 0])
        l_f_meta = (cost1 * eps[:, 1]).mean() + cost2[1]
        meta_net.zero_grad()
        grads = torch.autograd.grad(l_f_meta, meta_net.parameters(), create_graph=True)
        meta_net.update_params(0.001, source_params=grads)
        val_data, val_left, val_right, val_Y, _, val_labels, val_ids = next(iter(test_loader))
        val_data = to_var(val_data, requires_grad=False)
        val_left = to_var(val_left, requires_grad=False)
        val_right = to_var(val_right, requires_grad=False)
        val_labels = to_var(val_labels, requires_grad=False).float()
        y_g_hat = meta_net(val_data, val_left, val_right)
        val_prob = torch.sigmoid(y_g_hat)
        val_prob = torch.cat([1 - val_prob, val_prob], dim=1)
        l_g_meta = -torch.mean(torch.sum(val_prob * torch.log(val_prob + 1e-10), dim=1)) * 2
        grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        w = torch.clamp(-grad_eps, min=1e-10)
        acount = bcount = ccount = dcount = 0  # unused counters
        for j in range(w.shape[0]):
            if Y[j] == -1:
                if torch.sum(w[:, 1]) >= 4:
                    w[j, 0] = 1
                    w[j, 1] = 0
                else:
                    w[j, :] = w[j, :] / torch.sum(w[j, :])
            else:
                w[j, 0] = 1
                w[j, 1] = 0
        w = w.cuda().detach()
        # compute output
        output1 = model1(X, left, right)
        output2 = model2(X, left, right)
        with torch.no_grad():
            ema_output2 = ema_model2(X, left, right)
        _, loss = criterion(output2, Y, eps=w[:, 0])
        consistency_loss = consistency_weight * \
            consistency_criterion(output2, ema_output2) / X.shape[0]
        predictions2 = torch.sign(output2).long()
        predictiont2 = torch.sign(ema_output2).long()
        smx1 = torch.sigmoid(output1)
        smx1 = torch.cat([1 - smx1, smx1], dim=1)
        smxY = ((Y + 1) // 2).long()
        smx2 = torch.sigmoid(output2)
        smx2 = torch.cat([1 - smx2, smx2], dim=1)
        xent = -torch.sum(smx2 * torch.log(smx2 + 1e-10), dim=1)
        if args.type_noisy == 'mu' and check_mean_teacher(epoch):
            # mutual-learning term: pull model2 toward model1's predictions
            aux = F.mse_loss(smx2[:, 0], smx1[:, 0].detach())
            if aux < loss * args.alpha:
                loss += aux
        if check_mean_teacher(epoch):
            loss += consistency_loss
        loss += (xent * w[:, 1]).mean()
        optimizer2.zero_grad()
        loss.backward()
        optimizer2.step()
        pacc_2, nacc_2, pnacc_2, psize = accuracy(predictions2, T)
        pacc_4, nacc_4, pnacc_4, psize = accuracy(predictiont2, T)
        pacc2.update(pacc_2, psize)
        nacc2.update(nacc_2, X.size(0) - psize)
        pnacc2.update(pnacc_2, X.size(0))
        pacc4.update(pacc_4, psize)
        nacc4.update(nacc_4, X.size(0) - psize)
        pnacc4.update(pnacc_4, X.size(0))
    if check_mean_teacher(epoch):
        update_ema_variables(model1, ema_model1, args.ema_decay, step)  # update the EMA (teacher) parameters
        update_ema_variables(model2, ema_model2, args.ema_decay, step)
        step += 1
    print('Epoch Noisy : [{0}]\t'
          'PACC1 {pacc1.val:.3f} ({pacc1.avg:.3f})\t'
          'NACC1 {nacc1.val:.3f} ({nacc1.avg:.3f})\t'
          'PNACC1 {pnacc1.val:.3f} ({pnacc1.avg:.3f})\t'
          'PACC2 {pacc2.val:.3f} ({pacc2.avg:.3f})\t'
          'NACC2 {nacc2.val:.3f} ({nacc2.avg:.3f})\t'
          'PNACC2 {pnacc2.val:.3f} ({pnacc2.avg:.3f})\t'
          'PACC3 {pacc3.val:.3f} ({pacc3.avg:.3f})\t'
          'NACC3 {nacc3.val:.3f} ({nacc3.avg:.3f})\t'
          'PNACC3 {pnacc3.val:.3f} ({pnacc3.avg:.3f})\t'
          'PACC4 {pacc4.val:.3f} ({pacc4.avg:.3f})\t'
          'NACC4 {nacc4.val:.3f} ({nacc4.avg:.3f})\t'
          'PNACC4 {pnacc4.val:.3f} ({pnacc4.avg:.3f})\t'.format(
              epoch, pacc1=pacc1, nacc1=nacc1, pnacc1=pnacc1,
              pacc2=pacc2, nacc2=nacc2, pnacc2=pnacc2,
              pacc3=pacc3, nacc3=nacc3, pnacc3=pnacc3,
              pacc4=pacc4, nacc4=nacc4, pnacc4=pnacc4))
    return pacc1.avg, nacc1.avg, pnacc1.avg
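# get_current_consistency_weight(epoch - 30) above ramps the consistency term
# in over training. Its definition lives elsewhere in the repo; the usual Mean
# Teacher sigmoid ramp-up looks like this sketch (the weight and ramp length
# here are illustrative assumptions):
def sigmoid_rampup_sketch(current, rampup_length):
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

def get_current_consistency_weight_sketch(epoch, consistency=1.0, rampup_length=5):
    # epochs before the ramp starts (negative inputs) get weight ~0
    return consistency * sigmoid_rampup_sketch(max(epoch, 0), rampup_length)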
def train(model, input_channel, optimizers, criterion, components, train_loader,
          val_loader, epoch, writer, args, use_CUDA=True, clamp=False,
          num_classes=10):
    model.train()
    accs = []          # unused
    losses_w1 = []     # unused
    losses_w2 = []     # unused
    iter_val_loader = iter(val_loader)
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    index = 0
    noisy_labels = []
    true_labels = []
    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    for c in components:
        w[c] = None
        w_logger[c] = WLogger()
        losses_logger[c] = ScalarLogger(prefix='loss')
    for (input, label, real) in train_loader:
        noisy_labels.append(label)
        true_labels.append(real)
        meta_model = get_model(args, num_classes=num_classes, input_channel=input_channel)
        meta_model.load_state_dict(model.state_dict())
        if use_CUDA:
            meta_model = meta_model.cuda()
        val_input, val_label, iter_val_loader = get_val_samples(iter_val_loader, val_loader)
        input = to_var(input, requires_grad=False)
        label = to_var(label, requires_grad=False).long()
        val_input = to_var(val_input, requires_grad=False)
        val_label = to_var(val_label, requires_grad=False).long()
        meta_output = meta_model(input)
        cost = meta_criterion(meta_output, label)
        eps = to_var(torch.zeros(cost.size()))
        meta_loss = (cost * eps).sum()
        meta_model.zero_grad()
        if 'all' in components:
            grads = torch.autograd.grad(meta_loss, meta_model.parameters(), create_graph=True)
            meta_model.update_params(0.001, source_params=grads)
            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs=True)[0]
            if clamp:
                w['all'] = torch.clamp(-grad_eps, min=0)
            else:
                w['all'] = -grad_eps
            norm = torch.sum(abs(w['all']))
            assert (clamp and len(components) == 1) or (len(components) > 1), "Error combination"
            w['all'] = w['all'] / norm
            if 'fc' in components:
                w['fc'] = copy.deepcopy(w['all'])
                w['fc'] = torch.clamp(w['fc'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)
            elif 'backbone' in components:
                w['backbone'] = copy.deepcopy(w['all'])
                w['backbone'] = torch.clamp(w['backbone'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)
        else:
            assert ('backbone' in components) and ('fc' in components)
            grads_backbone = torch.autograd.grad(meta_loss, meta_model.backbone.parameters(),
                                                 create_graph=True, retain_graph=True)
            grads_fc = torch.autograd.grad(meta_loss, meta_model.fc.parameters(), create_graph=True)
            # Backbone: virtual step, then weights from the validation gradient
            meta_model.backbone.update_params(0.001, source_params=grads_backbone)
            meta_val_feature = torch.flatten(meta_model.backbone(val_input), 1)
            # the head consumes flattened backbone features, not raw inputs
            meta_val_output = meta_model.fc(meta_val_feature)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            if args.with_kl and args.reg_start <= epoch:
                train_feature = torch.flatten(meta_model.backbone(input), 1)
                meta_val_loss -= sample_wise_kl(train_feature, meta_val_feature)
            grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs=True, retain_graph=True)[0]
            if clamp:
                w['backbone'] = torch.clamp(-grad_eps, min=0)
            else:
                w['backbone'] = -grad_eps
            norm = torch.sum(abs(w['backbone']))
            w['backbone'] = w['backbone'] / norm
            # FC: reset the copy, then a virtual step on the head only
            meta_model.load_state_dict(model.state_dict())
            meta_model.fc.update_params(0.001, source_params=grads_fc)
            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs=True, retain_graph=True)[0]
            if clamp:
                w['fc'] = torch.clamp(-grad_eps, min=0)
            else:
                w['fc'] = -grad_eps
            norm = torch.sum(abs(w['fc']))
            w['fc'] = w['fc'] / norm
        index += 1
        output = model(input)
        loss = defaultdict()
        prediction = torch.softmax(output, 1)
        for c in components:
            w_logger[c].update(w[c])
            loss[c] = (meta_criterion(output, label) * w[c]).sum()
            optimizers[c].zero_grad()
            loss[c].backward(retain_graph=True)
            optimizers[c].step()
            losses_logger[c].update(loss[c])
        top1 = accuracy(prediction, label)
        accuracy_logger.update(top1)
    noisy_labels = torch.cat(noisy_labels)
    true_labels = torch.cat(true_labels)
    mask = (noisy_labels != true_labels).cpu().numpy()
    for c in components:
        w_logger[c].write(writer, c, epoch)
        w_logger[c].mask_write(writer, c, epoch, mask)
        losses_logger[c].write(writer, c, epoch)
    accuracy_logger.write(writer, 'train', epoch)
    print("Training Epoch: {}, Accuracy: {}".format(epoch, accuracy_logger.avg()))
    return accuracy_logger.avg()
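# Every meta step above relies on update_params(lr, source_params=grads): a
# differentiable SGD step on the copied model, so the validation loss can be
# differentiated through the update. That helper is defined elsewhere in the
# repo (a MetaModule-style class). A modern functional sketch of the same idea
# using torch.func (PyTorch >= 2.0); all names below are illustrative:
from torch.func import functional_call

def virtual_sgd_step_sketch(model, x, y, x_val, y_val, lr=0.001):
    params = dict(model.named_parameters())
    out = functional_call(model, params, (x,))
    loss = F.cross_entropy(out, y)
    grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
    # param <- param - lr * grad, kept inside the autograd graph
    new_params = {name: p - lr * g
                  for (name, p), g in zip(params.items(), grads)}
    # validation loss evaluated at the virtually-updated parameters; gradients
    # can now flow back through new_params into anything that shaped `loss`
    val_out = functional_call(model, new_params, (x_val,))
    return F.cross_entropy(val_out, y_val)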