def test_grad_corr(dataloader, net, ssh, ext):
    """For every batch, compute the dot product between the gradient of the
    classification loss and the gradient of the self-supervised rotation
    loss, both taken w.r.t. the shared extractor ``ext``'s parameters.

    Returns a list with one scalar correlation value per batch.
    """
    criterion = nn.CrossEntropyLoss().cuda()
    net.eval()
    ssh.eval()
    correlations = []
    for inputs, labels in dataloader:
        net.zero_grad()
        ssh.zero_grad()

        # Gradient of the supervised classification loss w.r.t. ext.
        cls_inputs, cls_targets = inputs.cuda(), labels.cuda()
        cls_loss = criterion(net(cls_inputs), cls_targets)
        cls_grad = flat_grad(torch.autograd.grad(cls_loss, ext.parameters()))

        ext.zero_grad()

        # Gradient of the rotation-prediction (self-supervised) loss w.r.t. ext.
        rot_inputs, rot_labels = rotate_batch(inputs, 'expand')
        ssh_inputs, ssh_targets = rot_inputs.cuda(), rot_labels.cuda()
        ssh_loss = criterion(ssh(ssh_inputs), ssh_targets)
        ssh_grad = flat_grad(torch.autograd.grad(ssh_loss, ext.parameters()))

        correlations.append(torch.dot(cls_grad, ssh_grad).item())
    net.train()
    ssh.train()
    return correlations
def test_ttt(dataloader, model, sslabel=None):
    """Evaluate ``model`` on ``dataloader``.

    If ``sslabel`` is given, each batch is first transformed with
    ``rotate_batch`` so the rotation-prediction head can be scored.

    Returns a tuple ``(error_rate, per_sample_correct, per_sample_loss)``
    where the last two are numpy arrays over the whole dataset.
    """
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
    model.eval()
    hits, losses = [], []
    for inputs, labels in dataloader:
        if sslabel is not None:
            inputs, labels = rotate_batch(inputs, sslabel)
        inputs, labels = inputs.cuda(), labels.cuda()
        with torch.no_grad():
            outputs = model(inputs)
            losses.append(criterion(outputs, labels).cpu())
            predicted = outputs.max(1)[1]
            hits.append(predicted.eq(labels).cpu())
    hits = torch.cat(hits).numpy()
    losses = torch.cat(losses).numpy()
    model.train()
    return 1 - hits.mean(), hits, losses
def ttt_test(train_loader, model_kq, model, val_loader, config_lsvm, args, ssh, teset, head):
    """Test-time training evaluation.

    For every test image: reset ``model_kq``/``head`` to the pretrained
    checkpoint, adapt on augmented rotated copies of the single image via the
    self-supervised head ``ssh``, copy the adapted encoder into ``model``,
    then classify its features with a linear SVM fit on ``train_loader``
    features. Returns the running average top-1 SVM accuracy.

    Raises:
        FileNotFoundError: if ``args.resume`` does not point to a checkpoint.
    """
    if ',' in args.aug:
        aug_name, aug_arg = args.aug.split(',')
        tr_transform = transforms.Compose(aug(aug_name, int(aug_arg)))
    else:
        tr_transform = transforms.Compose(aug(args.aug))

    # stliu: load ckpt first
    if not os.path.isfile(args.resume):
        # Bug fix: the original silently fell through here and crashed later
        # with a NameError on `checkpoint`; fail fast with a clear error.
        raise FileNotFoundError("=> no checkpoint found at '{}'".format(args.resume))
    print("=> loading checkpoint '{}'".format(args.resume))
    if args.gpu is None:
        checkpoint = torch.load(args.resume)
    else:
        # Map model to be loaded to specified single gpu.
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.resume, map_location=loc)

    err_ssh = 0 if args.shared is None else test_ttt(
        val_loader, ssh, sslabel='expand')[0]
    print('SSH ERROR:', err_ssh)

    # stliu: get SVM classifier
    feats_bank = []
    with torch.no_grad():
        # generate feature bank
        for (images, _) in tqdm(train_loader, desc='Feature extracting'):
            feats = model(images.cuda(args.gpu, non_blocking=True), 'r')
            feats_bank.append(feats)
        feats_bank = torch.cat(feats_bank, dim=0)
        label_bank = torch.tensor(train_loader.dataset.targets)
    model_lsvm = liblinearutil.train(label_bank.cpu().numpy(),
                                     feats_bank.cpu().numpy(), config_lsvm)

    # stliu: test time training
    if args.frozen:
        model_kq = FrozenBatchNorm2d.convert_frozen_batchnorm(model_kq)
        model = FrozenBatchNorm2d.convert_frozen_batchnorm(model)
    top1 = AverageMeter('Acc@1', ':4.2f')
    criterion_ssh = nn.CrossEntropyLoss().cuda()
    if args.bn_only:
        optimizer_ssh = torch.optim.SGD(ssh.parameters(), lr=0)
    else:
        optimizer_ssh = torch.optim.SGD(ssh.parameters(), lr=args.lr)

    # Loop-invariant hoist: the checkpoint entries relevant to model_kq never
    # change across test samples, so filter them once instead of every step.
    pretrained_dict = {
        k: v
        for k, v in checkpoint['state_dict'].items()
        if k in model_kq.state_dict()
    }

    ttt_bar = tqdm(range(1, len(teset) + 1))
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    for i in ttt_bar:
        # Reset weights so each test sample starts from the pretrained model.
        model_dict = model_kq.state_dict()
        model_dict.update(pretrained_dict)
        # Bug fix: load the merged dict. The original loaded only the filtered
        # `pretrained_dict`, which made the update above a no-op and would
        # fail a (default strict) load whenever any model key was missing.
        model_kq.load_state_dict(model_dict)
        head.load_state_dict(checkpoint['head'])

        _, label = teset[i - 1]  # stliu: get the label for the image
        image = Image.fromarray(teset.data[i - 1])

        # Adapt on the self-supervised rotation task for this one image.
        ssh.train()
        inputs = torch.stack([tr_transform(image) for _ in range(args.batch_size)])
        inputs_ssh, labels_ssh = rotate_batch(inputs, 'rand')
        inputs_ssh, labels_ssh = inputs_ssh.cuda(), labels_ssh.cuda()
        optimizer_ssh.zero_grad()
        loss_ssh = criterion_ssh(ssh(inputs_ssh), labels_ssh)
        loss_ssh.backward()
        optimizer_ssh.step()

        # test again: copy the adapted encoder_q weights into `model`,
        # stripping the 'module.encoder_q.' prefix and dropping the fc head.
        state_dict = model_kq.state_dict()
        for k in list(state_dict.keys()):
            if k.startswith('module.encoder_q'
                            ) and not k.startswith('module.encoder_q.fc'):
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            del state_dict[k]
        model.load_state_dict(state_dict, strict=False)

        # Evaluate: extract features and classify with the linear SVM.
        model.eval()
        ssh.eval()
        inputs = torch.stack([test_transform(image) for _ in range(args.batch_size)])
        inputs = inputs.cuda(args.gpu, non_blocking=True)
        feats = model(inputs, 'r')
        targets = np.array([label for _ in range(args.batch_size)])
        _, top1_acc, _ = liblinearutil.predict(targets,
                                               feats.cpu().detach().numpy(),
                                               model_lsvm, '-q')

        # measure accuracy and record
        top1.update(top1_acc[0])
        ttt_bar.set_description('New Acc@SVM:{:.2f}%'.format(top1.avg))
    return top1.avg
def train(train_loader, model, criterion, optimizer, epoch, args, ssh):
    """Run one training epoch of the contrastive model, optionally adding a
    shared self-supervised rotation loss, and return the average total loss."""
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    # stliu: design new pregress
    epoch_time = AverageMeter('Epoch Time', ':6.3f')
    progress = ProgressMeter(len(train_loader),
                             [epoch_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    if args.norm != 'bnf':
        model.train()
    ssh.train()

    end = time.time()
    for batch_idx, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

        # contrastive forward pass
        output, target = model(im_q=images[0], im_k=images[1])
        loss = criterion(output, target)

        # optional shared self-supervised rotation loss
        if args.shared is not None:
            inputs_ssh, labels_ssh = rotate_batch(images[0], args.rotation_type)
            inputs_ssh = inputs_ssh.cuda(args.gpu, non_blocking=True)
            labels_ssh = labels_ssh.cuda(args.gpu, non_blocking=True)
            loss = loss + criterion(ssh(inputs_ssh), labels_ssh)

        # acc1/acc5 are (K+1)-way contrast classifier accuracy
        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        n = images[0].size(0)
        losses.update(loss.item(), n)
        top1.update(acc1[0], n)
        top5.update(acc5[0], n)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time; epoch_time is a running full-epoch estimate
        batch_time.update(time.time() - end)
        epoch_time.update(batch_time.avg * len(train_loader))
        end = time.time()

        if (batch_idx + 1) % args.print_freq == 0:  # stliu: change i to i+1
            progress.display(batch_idx)
    return losses.avg