# Measure the per-batch correlation (inner product) between the gradients of
# the classification loss and the self-supervised rotation loss w.r.t. the
# shared feature extractor `ext`.
def test_grad_corr(dataloader, net, ssh, ext):
    criterion = nn.CrossEntropyLoss().cuda()
    net.eval()
    ssh.eval()
    corr = []
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        net.zero_grad()
        ssh.zero_grad()
        # Gradient of the classification loss w.r.t. the shared extractor.
        inputs_cls, labels_cls = inputs.cuda(), labels.cuda()
        outputs_cls = net(inputs_cls)
        loss_cls = criterion(outputs_cls, labels_cls)
        grad_cls = torch.autograd.grad(loss_cls, ext.parameters())
        grad_cls = flat_grad(grad_cls)
        # Gradient of the self-supervised rotation loss w.r.t. the same extractor.
        ext.zero_grad()
        inputs, labels = rotate_batch(inputs, 'expand')
        inputs_ssh, labels_ssh = inputs.cuda(), labels.cuda()
        outputs_ssh = ssh(inputs_ssh)
        loss_ssh = criterion(outputs_ssh, labels_ssh)
        grad_ssh = torch.autograd.grad(loss_ssh, ext.parameters())
        grad_ssh = flat_grad(grad_ssh)
        # Correlation = dot product of the two flattened gradient vectors.
        corr.append(torch.dot(grad_cls, grad_ssh).item())
    net.train()
    ssh.train()
    return corr
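# `flat_grad` is used above but not defined in this section. A minimal sketch,
# assuming it simply flattens and concatenates a sequence of per-parameter
# gradients into one vector (consistent with the torch.dot call above):
def flat_grad(grad_tuple):
    return torch.cat([g.contiguous().view(-1) for g in grad_tuple])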
def adapt_single(model, image, optimizer, criterion, niter, batch_size):
    model.train()
    for iteration in range(niter):
        # Build a batch of randomly augmented copies of the single test image.
        inputs = [rotation_tr_transforms(image) for _ in range(batch_size)]
        inputs, labels = rotate_batch(inputs)
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
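# `rotate_batch` and `rotate_single_with_label` are called throughout but not
# defined in this section. A minimal sketch of the assumed semantics: each CHW
# image tensor is rotated by label * 90 degrees; 'expand' replicates the batch
# at all four rotations, 'rand' (the default) draws one random rotation per
# image, and an integer label applies that fixed rotation to every image.
def rotate_single_with_label(img, label):
    # Rotate a CHW tensor counter-clockwise by label * 90 degrees.
    return torch.rot90(img, k=int(label), dims=(1, 2))

def rotate_batch(batch, label='rand'):
    if label == 'expand':
        labels = torch.arange(4, dtype=torch.long).repeat_interleave(len(batch))
        batch = list(batch) * 4
    elif label == 'rand':
        labels = torch.randint(4, (len(batch),), dtype=torch.long)
    else:
        labels = torch.full((len(batch),), label, dtype=torch.long)
    rotated = [rotate_single_with_label(img, lbl) for img, lbl in zip(batch, labels)]
    return torch.stack(rotated), labels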
def adapt_single_tensor(model, tensor, optimizer, criterion, niter, batch_size):
    model.train()
    for iteration in range(niter):
        # Replicate the single tensor to form a batch, then rotate it.
        inputs = [tensor for _ in range(batch_size)]
        inputs, labels = rotate_batch(inputs)
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        _, ssh = model(inputs)  # second output is the self-supervised head
        loss = criterion(ssh, labels)
        loss.backward()
        optimizer.step()
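# Hypothetical usage sketch (`adapt_then_restore`, the learning rate, and the
# batch size are assumptions, not from the original script): adapt on a single
# test tensor with a fresh optimizer, then restore the pre-adaptation weights
# so that each test sample is adapted independently.
import copy
import torch.optim as optim

def adapt_then_restore(net, x, niter=1, batch_size=16, lr=0.001):
    state = copy.deepcopy(net.state_dict())
    optimizer = optim.SGD(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss().to(device)
    adapt_single_tensor(net, x, optimizer, criterion, niter, batch_size)
    net.load_state_dict(state)  # undo the sample-specific update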
def test(dataloader, model, sslabel=None):
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
    model.eval()  # switch the model to evaluation mode
    correct = []
    losses = []
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        if sslabel is not None:
            inputs, labels = rotate_batch(inputs, sslabel)  # apply the rotation here
        inputs, labels = inputs.cuda(), labels.cuda()
        with torch.no_grad():
            outputs = model(inputs)  # forward pass: see whether the inputs go through
            loss = criterion(outputs, labels)  # the recognition task: check whether the correct label comes out
            losses.append(loss.cpu())
            _, predicted = outputs.max(1)
            correct.append(predicted.eq(labels).cpu())  # whether the prediction is correct
    correct = torch.cat(correct).numpy()  # TODO: test this later
    losses = torch.cat(losses).numpy()
    model.train()  # switch back to training mode
    return 1 - correct.mean(), correct, losses
def train(trloader, epoch):
    net.train()
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(trloader), batch_time, data_time, losses, top1,
                             prefix="Epoch: [{}]".format(epoch))
    end = time.time()
    for i, dl in enumerate(trloader):
        data_time.update(time.time() - end)
        optimizer.zero_grad()
        # Main classification task.
        inputs_cls, labels_cls = dl[0].to(device), dl[1].to(device)
        outputs_cls, _ = net(inputs_cls)
        loss = criterion(outputs_cls, labels_cls)
        losses.update(loss.item(), len(labels_cls))
        _, predicted = outputs_cls.max(1)
        acc1 = predicted.eq(labels_cls).sum().item() / len(labels_cls)
        top1.update(acc1, len(labels_cls))
        # Auxiliary self-supervised rotation task on the same batch.
        rot_inputs, rot_labels = rotate_batch(dl[0])
        inputs_ssh, labels_ssh = rot_inputs.to(device), rot_labels.to(device)
        _, outputs_ssh = net(inputs_ssh)
        loss_ssh = criterion(outputs_ssh, labels_ssh)
        loss += loss_ssh
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.print(i)
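# The loops above unpack two outputs from `net`. A minimal sketch of such a
# two-headed model, assuming a shared feature extractor and hypothetical
# layer sizes (512-dim features would match the split-point check further
# below; 4 outputs for the rotation head, one per 90-degree rotation):
class TwoHeadedNet(nn.Module):
    def __init__(self, ext, feat_dim=512, num_classes=10):
        super().__init__()
        self.ext = ext  # shared feature extractor
        self.head_cls = nn.Linear(feat_dim, num_classes)  # main classifier
        self.head_ssh = nn.Linear(feat_dim, 4)  # self-supervised rotation head

    def forward(self, x):
        feat = self.ext(x)
        return self.head_cls(feat), self.head_ssh(feat)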
all_err_cls = []
all_err_ssh = []
print('Running...')
print('Error (%)\t\ttest\t\tself-supervised')
for epoch in range(1, args.nepoch + 1):
    net.train()
    ssh.train()
    for batch_idx, (inputs, labels) in enumerate(trloader):
        optimizer.zero_grad()
        inputs_cls, labels_cls = inputs.cuda(), labels.cuda()
        outputs_cls = net(inputs_cls)
        loss = criterion(outputs_cls, labels_cls)
        if args.shared is not None:
            # Train the self-supervised branch on randomly rotated inputs.
            inputs_ssh, labels_ssh = rotate_batch(inputs, args.rotation_type)
            inputs_ssh, labels_ssh = inputs_ssh.cuda(), labels_ssh.cuda()
            outputs_ssh = ssh(inputs_ssh)
            loss_ssh = criterion(outputs_ssh, labels_ssh)
            loss += loss_ssh
        loss.backward()
        optimizer.step()
    err_cls = test(teloader, net)[0]
    err_ssh = 0 if args.shared is None else test(teloader, ssh, sslabel='expand')[0]
    all_err_cls.append(err_cls)
    all_err_ssh.append(err_ssh)
    scheduler.step()
if epoch < args.epochs_pre:
    # Pre-training phase: joint classification + self-supervised rotation task.
    net.train()
    ssh.train()
    c_err_cls = []
    avg_err_cls = 0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        X, label = inputs.to(device), labels.to(device)
        y = net(X)
        loss = criterion(y, label)
        if args.shared is not None:
            inputs_ssh, labels_ssh = rotate_batch(inputs, args.rotation_type)
            inputs_ssh, labels_ssh = inputs_ssh.to(device), labels_ssh.to(device)
            outputs_ssh = ssh(inputs_ssh)
            loss_ssh = criterion(outputs_ssh, labels_ssh)
            loss += loss_ssh
        loss.backward()
        optimizer.step()
    scheduler.step()
else:
    # Augmentation phase: periodically advance the augmentation step.
    if epoch % args.aug_freq == 0 or epoch == args.epochs_pre:
        AUG_STEP += 1
    net.eval()
trset, trloader = prepare_train_data(args)
teset, teloader = prepare_test_data(args)
corrs = []
print("Gradient Correlation")
for i in range(args.epochs):
    # Sample a random training image and rotate it by a random multiple of 90 degrees.
    idx = random.randint(0, len(trset) - 1)
    img, lbl = trset[idx]
    random_rot = random.randint(1, 3)
    rot_img = rotate_single_with_label(img, random_rot)

    # get gradient loss for auxiliary head
    d_aux_loss = []
    inputs = [rot_img for _ in range(args.batch_size)]
    inputs, labels = rotate_batch(inputs)
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    _, ssh = net(inputs)
    loss = criterion(ssh, labels)
    loss.backward(retain_graph=True)
    for p in net.parameters():
        if p.grad is None:
            continue
        # split point: stop at the 512-dim layer between the shared extractor and the heads
        if list(p.grad.size())[0] == 512:
            break
        d_aux_loss.append(p.grad.data.clone())
    # get gradient loss for main head