def test(train_loader, model_kq, model, val_loader, config_lsvm, args, ssh):
    """Score ``model`` features with a liblinear SVM fit on the training set.

    When ``args.shared`` is set, the self-supervised head error from
    ``test_ttt(val_loader, ssh, sslabel='expand')`` is printed first
    (0 otherwise). Features are then extracted with ``model(x, 'r')`` in eval
    mode. An SVM is trained on the ``train_loader`` features using
    ``config_lsvm``, and its accuracy on ``val_loader`` is accumulated batch
    by batch.

    Args:
        train_loader: loader whose batches are ``(images, _)``; its
            ``dataset.targets`` supply the SVM training labels.
        model_kq: unused here; kept for signature parity with ``ttt_test``.
        model: feature extractor called as ``model(images, 'r')``.
        val_loader: loader yielding ``(images, target)`` batches.
        config_lsvm: option string/params forwarded to ``liblinearutil.train``.
        args: namespace providing ``shared`` and ``gpu``.
        ssh: self-supervised network passed to ``test_ttt``.

    Returns:
        Sample-weighted mean of the per-batch accuracies reported by
        ``liblinearutil.predict``.
    """
    if args.shared is None:
        err_ssh = 0
    else:
        err_ssh = test_ttt(val_loader, ssh, sslabel='expand')[0]
    print('SSH ERROR:', err_ssh)

    model.eval()
    top1 = AverageMeter('Acc@1', ':4.2f')

    with torch.no_grad():
        # Build the SVM training set from the whole train loader.
        # NOTE(review): labels are read from dataset.targets, which assumes
        # train_loader yields samples in dataset order (no shuffling) —
        # confirm at the call site.
        train_feats = torch.cat(
            [
                model(batch.cuda(args.gpu, non_blocking=True), 'r')
                for batch, _ in tqdm(train_loader, desc='Feature extracting')
            ],
            dim=0,
        )
        train_labels = torch.tensor(train_loader.dataset.targets)

    svm = liblinearutil.train(
        train_labels.cpu().numpy(), train_feats.cpu().numpy(), config_lsvm
    )

    with torch.no_grad():
        progress = tqdm(val_loader)
        for batch, target in progress:
            batch = batch.cuda(args.gpu, non_blocking=True)
            batch_feats = model(batch, 'r')
            # predict returns (labels, (ACC, MSE, SCC), values); keep ACC.
            _, stats, _ = liblinearutil.predict(
                target.cpu().numpy(), batch_feats.cpu().numpy(), svm, '-q'
            )
            top1.update(stats[0], batch.size(0))
            progress.set_description('Acc@SVM:{:.2f}%'.format(top1.avg))

    return top1.avg
def ttt_test(train_loader, model_kq, model, val_loader, config_lsvm, args, ssh, teset, head):
    """Run per-image test-time training (TTT) and score each image with an SVM.

    Procedure:
      1. Load the checkpoint at ``args.resume``. It is required, because it
         is used to reset ``model_kq`` and ``head`` before every test image.
      2. Print the self-supervised head error from ``test_ttt`` when
         ``args.shared`` is set (0 otherwise).
      3. Fit a liblinear SVM (``config_lsvm``) on features extracted in eval
         mode by ``model(x, 'r')`` over ``train_loader``.
      4. For every image of ``teset``:
         - restore ``model_kq``/``head`` from the checkpoint;
         - take one SGD step on ``ssh`` using cross-entropy over
           ``rotate_batch(..., 'rand')`` applied to ``args.batch_size``
           augmented copies;
         - copy the adapted query-encoder weights into ``model``;
         - classify the image with the SVM.

    Args:
        train_loader: loader yielding ``(images, _)``; ``dataset.targets``
            provides the SVM training labels.
        model_kq: model whose state dict uses the ``module.encoder_q.*``
            prefix; restored from ``checkpoint['state_dict']`` per image.
        model: feature extractor called as ``model(x, 'r')``; receives the
            adapted encoder weights (non-strict load).
        val_loader: loader passed to ``test_ttt`` for the SSH error report.
        config_lsvm: options forwarded to ``liblinearutil.train``.
        args: namespace providing ``aug``, ``resume``, ``gpu``, ``shared``,
            ``frozen``, ``bn_only``, ``lr`` and ``batch_size``.
        ssh: network optimized at test time (its parameters are optimized).
        teset: test set indexed as ``teset[i] -> (_, label)``, with raw images
            in ``teset.data[i]`` (passed to ``Image.fromarray``).
        head: module restored from ``checkpoint['head']`` per image.

    Returns:
        Running mean over ``teset`` of the accuracy reported by
        ``liblinearutil.predict``.

    Raises:
        FileNotFoundError: if ``args.resume`` is not an existing file.
    """
    # Augmentation used for the TTT batch: "name" or "name,int_arg".
    if ',' in args.aug:
        aug_parts = args.aug.split(',')
        tr_transform = transforms.Compose(aug(aug_parts[0], int(aug_parts[1])))
    else:
        tr_transform = transforms.Compose(aug(args.aug))

    # stliu: load ckpt first. It is mandatory: every test image restarts from
    # it. Fail here rather than with a NameError after the (expensive) SVM fit.
    if not os.path.isfile(args.resume):
        raise FileNotFoundError("=> no checkpoint found at '{}'".format(args.resume))
    print("=> loading checkpoint '{}'".format(args.resume))
    if args.gpu is None:
        checkpoint = torch.load(args.resume)
    else:
        # Map model to be loaded to specified single gpu.
        checkpoint = torch.load(args.resume, map_location='cuda:{}'.format(args.gpu))

    err_ssh = 0 if args.shared is None else test_ttt(
        val_loader, ssh, sslabel='expand')[0]
    print('SSH ERROR:', err_ssh)

    # stliu: get SVM classifier. Extract features in eval mode (as `test`
    # does) so BN/dropout behave deterministically and BN running stats
    # are not updated by the training data.
    # NOTE(review): labels come from dataset.targets, which assumes
    # train_loader iterates in dataset order (no shuffling) — confirm.
    model.eval()
    feats_bank = []
    with torch.no_grad():
        for images, _ in tqdm(train_loader, desc='Feature extracting'):
            feats_bank.append(model(images.cuda(args.gpu, non_blocking=True), 'r'))
        feats_bank = torch.cat(feats_bank, dim=0)
    label_bank = torch.tensor(train_loader.dataset.targets)
    model_lsvm = liblinearutil.train(label_bank.cpu().numpy(),
                                     feats_bank.cpu().numpy(), config_lsvm)

    # stliu: test time training
    if args.frozen:
        model_kq = FrozenBatchNorm2d.convert_frozen_batchnorm(model_kq)
        model = FrozenBatchNorm2d.convert_frozen_batchnorm(model)
    top1 = AverageMeter('Acc@1', ':4.2f')
    criterion_ssh = nn.CrossEntropyLoss().cuda()
    # With bn_only, lr=0: the SGD step leaves weights unchanged, and only the
    # train-mode forward pass affects the model (e.g. BN running statistics).
    if args.bn_only:
        optimizer_ssh = torch.optim.SGD(ssh.parameters(), lr=0)
    else:
        optimizer_ssh = torch.optim.SGD(ssh.parameters(), lr=args.lr)
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    # model_kq's key set is fixed inside the loop, so filter the checkpoint
    # once (after any frozen-BN conversion, which may change the key set).
    kq_keys = set(model_kq.state_dict().keys())
    pretrained_kq = {k: v for k, v in checkpoint['state_dict'].items() if k in kq_keys}
    encoder_prefix = 'module.encoder_q.'

    ttt_bar = tqdm(range(len(teset)))
    for i in ttt_bar:
        # Reset to the checkpoint so every test image adapts from the same
        # starting point. Load the merged dict, not just the checkpoint
        # subset: keys missing from the checkpoint keep their current values
        # instead of making the strict load fail.
        model_dict = model_kq.state_dict()
        model_dict.update(pretrained_kq)
        model_kq.load_state_dict(model_dict)
        head.load_state_dict(checkpoint['head'])

        _, label = teset[i]  # stliu: get the label for the image
        image = Image.fromarray(teset.data[i])

        # One self-supervised adaptation step on augmented copies of the image.
        ssh.train()
        inputs = torch.stack([tr_transform(image) for _ in range(args.batch_size)])
        inputs_ssh, labels_ssh = rotate_batch(inputs, 'rand')
        inputs_ssh, labels_ssh = inputs_ssh.cuda(), labels_ssh.cuda()
        optimizer_ssh.zero_grad()
        loss_ssh = criterion_ssh(ssh(inputs_ssh), labels_ssh)
        loss_ssh.backward()
        optimizer_ssh.step()

        # test again: copy the adapted query encoder (minus its fc layer) into
        # `model` by stripping the prefix; other keys keep their names and
        # are ignored by the non-strict load.
        state_dict = model_kq.state_dict()
        for k in list(state_dict.keys()):
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                state_dict[k[len(encoder_prefix):]] = state_dict[k]
                del state_dict[k]
        model.load_state_dict(state_dict, strict=False)

        model.eval()
        ssh.eval()
        # Inference only: no autograd graph is needed for the SVM features.
        # NOTE(review): test_transform looks deterministic (ToTensor +
        # normalize), so these copies are presumably identical — confirm.
        with torch.no_grad():
            inputs = torch.stack([test_transform(image) for _ in range(args.batch_size)])
            inputs = inputs.cuda(args.gpu, non_blocking=True)
            feats = model(inputs, 'r')
        targets = np.array([label] * args.batch_size)
        _, top1_acc, _ = liblinearutil.predict(
            targets, feats.cpu().detach().numpy(), model_lsvm, '-q')
        # measure accuracy and record
        top1.update(top1_acc[0])
        ttt_bar.set_description('New Acc@SVM:{:.2f}%'.format(top1.avg))
    return top1.avg