Example #1
def test(train_loader, model_kq, model, val_loader, config_lsvm, args, ssh):
	err_ssh = 0 if args.shared is None else test_ttt(val_loader, ssh, sslabel='expand')[0]
	print('SSH ERROR:', err_ssh)

	model.eval()
	top1, feats_bank = AverageMeter('Acc@1', ':4.2f'), []

	with torch.no_grad():
		# generate feature bank
		for (images, _) in tqdm(train_loader, desc='Feature extracting'):
			feats = model(images.cuda(args.gpu, non_blocking=True), 'r')
			feats_bank.append(feats)
	feats_bank = torch.cat(feats_bank, dim=0)
	label_bank = torch.tensor(train_loader.dataset.targets)
	# fit a linear SVM (LIBLINEAR) on the cached training features
	model_lsvm = liblinearutil.train(label_bank.cpu().numpy(), feats_bank.cpu().numpy(), config_lsvm)

	with torch.no_grad():
		val_bar = tqdm(val_loader)
		for (images, target) in val_bar:
			images = images.cuda(args.gpu, non_blocking=True)

			# compute output
			feats = model(images, 'r')
			_, top1_acc, _ = liblinearutil.predict(target.cpu().numpy(), feats.cpu().numpy(), model_lsvm, '-q')

			# measure accuracy and record
			top1.update(top1_acc[0], images.size(0))
			val_bar.set_description('Acc@SVM:{:.2f}%'.format(top1.avg))
	return top1.avg
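
Both examples update an AverageMeter that is defined elsewhere in the repo and not shown here. A minimal sketch compatible with the calls above, modeled on the PyTorch ImageNet reference utility (an assumption; the repo's own class may differ), would be:

# Minimal AverageMeter sketch (assumed helper, not shown in the examples).
class AverageMeter(object):
    """Tracks the current value, running sum, count, and average."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

With this definition, top1.update(top1_acc[0], images.size(0)) accumulates a batch-size-weighted accuracy and top1.avg reports the running mean shown in the progress bar.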
Example #2
def ttt_test(train_loader, model_kq, model, val_loader, config_lsvm, args, ssh,
             teset, head):
    if ',' in args.aug:
        tr_transform = transforms.Compose(
            aug(args.aug.split(',')[0], int(args.aug.split(',')[1])))
    else:
        tr_transform = transforms.Compose(aug(args.aug))
    # stliu: load ckpt first
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        if args.gpu is None:
            checkpoint = torch.load(args.resume)
        else:
            # Map model to be loaded to specified single gpu.
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.resume, map_location=loc)
    else:
        # the checkpoint is re-loaded before every test sample below, so fail fast if it is missing
        raise FileNotFoundError("=> no checkpoint found at '{}'".format(args.resume))
    err_ssh = 0 if args.shared is None else test_ttt(
        val_loader, ssh, sslabel='expand')[0]
    print('SSH ERROR:', err_ssh)

    # stliu: get SVM classifier
    feats_bank = []
    with torch.no_grad():
        # generate feature bank
        for (images, _) in tqdm(train_loader, desc='Feature extracting'):
            feats = model(images.cuda(args.gpu, non_blocking=True), 'r')
            feats_bank.append(feats)
    feats_bank = torch.cat(feats_bank, dim=0)
    label_bank = torch.tensor(train_loader.dataset.targets)
    model_lsvm = liblinearutil.train(label_bank.cpu().numpy(),
                                     feats_bank.cpu().numpy(), config_lsvm)

    # stliu: test time training
    if args.frozen:
        model_kq = FrozenBatchNorm2d.convert_frozen_batchnorm(model_kq)
        model = FrozenBatchNorm2d.convert_frozen_batchnorm(model)
    top1 = AverageMeter('Acc@1', ':4.2f')
    criterion_ssh = nn.CrossEntropyLoss().cuda()
    if args.bn_only:
        # lr=0: SGD leaves the weights unchanged, but BatchNorm running statistics
        # still adapt because ssh is put in train() mode before the forward pass
        optimizer_ssh = torch.optim.SGD(ssh.parameters(), lr=0)
    else:
        optimizer_ssh = torch.optim.SGD(ssh.parameters(), lr=args.lr)
    ttt_bar = tqdm(range(1, len(teset) + 1))
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    for i in ttt_bar:
        # reset the query encoder to the pretrained checkpoint before adapting on this sample
        pretrained_dict = checkpoint['state_dict']
        model_dict = model_kq.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model_kq.load_state_dict(model_dict)
        head.load_state_dict(checkpoint['head'])
        _, label = teset[i - 1]  # stliu: get the label for the image
        image = Image.fromarray(teset.data[i - 1])
        ssh.train()
        inputs = [tr_transform(image) for _ in range(args.batch_size)]
        inputs = torch.stack(inputs)
        inputs_ssh, labels_ssh = rotate_batch(inputs, 'rand')
        inputs_ssh, labels_ssh = inputs_ssh.cuda(), labels_ssh.cuda()
        optimizer_ssh.zero_grad()
        outputs_ssh = ssh(inputs_ssh)
        loss_ssh = criterion_ssh(outputs_ssh, labels_ssh)
        loss_ssh.backward()
        optimizer_ssh.step()

        # test again: copy the adapted query-encoder weights (minus its fc head) into the eval model
        state_dict = model_kq.state_dict()
        for k in list(state_dict.keys()):
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            del state_dict[k]
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        ssh.eval()

        inputs = [test_transform(image) for _ in range(args.batch_size)]
        inputs = torch.stack(inputs)
        inputs = inputs.cuda(args.gpu, non_blocking=True)
        feats = model(inputs, 'r')
        targets = np.array([label for _ in range(args.batch_size)])
        _, top1_acc, _ = liblinearutil.predict(targets,
                                               feats.cpu().detach().numpy(),
                                               model_lsvm, '-q')

        # measure accuracy and record
        top1.update(top1_acc[0])
        ttt_bar.set_description('New Acc@SVM:{:.2f}%'.format(top1.avg))
    return top1.avg
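
Example #2 also relies on a rotate_batch helper for the rotation-prediction self-supervised task, which is not shown here. A minimal sketch under the standard 0°/90°/180°/270° setup (the name and behavior are assumptions; the repo's actual helper may differ) would be:

import torch

# Sketch of a rotate_batch helper (assumed; the source repo's version may differ).
# With label='rand', each CHW image is rotated by a random multiple of 90 degrees
# and the rotation index (0-3) is returned as the self-supervised target.
def rotate_batch(batch, label='rand'):
    if label == 'rand':
        labels = torch.randint(4, (batch.size(0),), dtype=torch.long)
    else:
        labels = torch.full((batch.size(0),), int(label), dtype=torch.long)
    rotated = torch.stack([
        torch.rot90(img, k=int(r), dims=(1, 2))
        for img, r in zip(batch, labels)
    ])
    return rotated, labels

In ttt_test, the rotated batch is fed to the ssh rotation head and its cross-entropy loss drives the single test-time SGD step before the adapted encoder is re-evaluated with the linear SVM.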