Example #1
def test_grad_corr(dataloader, net, ssh, ext):
    """Per batch, return the inner product between the gradient of the
    classification loss (through net) and the gradient of the rotation
    self-supervision loss (through ssh), both taken w.r.t. the shared
    feature extractor ext."""
    criterion = nn.CrossEntropyLoss().cuda()
    net.eval()
    ssh.eval()
    corr = []
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        net.zero_grad()
        ssh.zero_grad()
        inputs_cls, labels_cls = inputs.cuda(), labels.cuda()
        outputs_cls = net(inputs_cls)
        loss_cls = criterion(outputs_cls, labels_cls)
        grad_cls = torch.autograd.grad(loss_cls, ext.parameters())
        grad_cls = flat_grad(grad_cls)

        ext.zero_grad()
        inputs, labels = rotate_batch(inputs, 'expand')
        inputs_ssh, labels_ssh = inputs.cuda(), labels.cuda()
        outputs_ssh = ssh(inputs_ssh)
        loss_ssh = criterion(outputs_ssh, labels_ssh)
        grad_ssh = torch.autograd.grad(loss_ssh, ext.parameters())
        grad_ssh = flat_grad(grad_ssh)

        corr.append(torch.dot(grad_cls, grad_ssh).item())
    net.train()
    ssh.train()
    return corr
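The flat_grad helper used above is not shown on this page. A minimal sketch of what it presumably does, namely flattening the per-parameter gradient tuple into one vector so the two task gradients can be compared with a single dot product (hypothetical reconstruction; the repo's actual helper may differ):

import torch

def flat_grad(grad_tuple):
    # Flatten each per-parameter gradient and concatenate into a 1-D vector.
    return torch.cat([g.reshape(-1) for g in grad_tuple])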
Example #2
def test_ttt(dataloader, model, sslabel=None):
    """Evaluate model on dataloader; when sslabel is given, score the rotation
    pretext task instead. Returns (error rate, per-sample correctness mask,
    per-sample losses)."""
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
    model.eval()
    correct = []
    losses = []
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        if sslabel is not None:
            inputs, labels = rotate_batch(inputs, sslabel)
        inputs, labels = inputs.cuda(), labels.cuda()
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            losses.append(loss.cpu())
            _, predicted = outputs.max(1)
            correct.append(predicted.eq(labels).cpu())
    correct = torch.cat(correct).numpy()
    losses = torch.cat(losses).numpy()
    model.train()
    return 1 - correct.mean(), correct, losses
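Both functions above lean on rotate_batch, the rotation-prediction pretext transform from test-time training. A hedged reconstruction, assuming 'expand' replicates the batch under all four rotations and 'rand' draws one random rotation per image (the repo's actual helper may differ in ordering or naming):

import torch

def rotate_batch(batch, label):
    # batch: NCHW tensor; returns (rotated images, rotation labels in {0,1,2,3})
    if label == 'expand':
        labels = torch.arange(4).repeat_interleave(len(batch))
        batch = batch.repeat(4, 1, 1, 1)
    elif label == 'rand':
        labels = torch.randint(4, (len(batch),))
    else:  # a fixed integer rotation
        labels = torch.full((len(batch),), int(label))
    rotated = torch.stack([img.rot90(int(k), dims=(1, 2))
                           for img, k in zip(batch, labels)])
    return rotated, labels.long()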
Example #3
def ttt_test(train_loader, model_kq, model, val_loader, config_lsvm, args, ssh,
             teset, head):
    if ',' in args.aug:
        aug_name, aug_level = args.aug.split(',')
        tr_transform = transforms.Compose(aug(aug_name, int(aug_level)))
    else:
        tr_transform = transforms.Compose(aug(args.aug))
    # stliu: load ckpt first
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        if args.gpu is None:
            checkpoint = torch.load(args.resume)
        else:
            # Map model to be loaded to specified single gpu.
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.resume, map_location=loc)
    else:
        # checkpoint is used unconditionally below, so fail loudly here
        # instead of hitting a NameError later
        raise FileNotFoundError(
            "=> no checkpoint found at '{}'".format(args.resume))
    err_ssh = 0 if args.shared is None else test_ttt(
        val_loader, ssh, sslabel='expand')[0]
    print('SSH ERROR:', err_ssh)

    # stliu: get SVM classifier
    feats_bank = []
    with torch.no_grad():
        # generate feature bank
        for (images, _) in tqdm(train_loader, desc='Feature extracting'):
            feats = model(images.cuda(args.gpu, non_blocking=True), 'r')
            feats_bank.append(feats)
    feats_bank = torch.cat(feats_bank, dim=0)
    label_bank = torch.tensor(train_loader.dataset.targets)
    model_lsvm = liblinearutil.train(label_bank.cpu().numpy(),
                                     feats_bank.cpu().numpy(), config_lsvm)

    # stliu: test time training
    if args.frozen:
        model_kq = FrozenBatchNorm2d.convert_frozen_batchnorm(model_kq)
        model = FrozenBatchNorm2d.convert_frozen_batchnorm(model)
    top1 = AverageMeter('Acc@1', ':4.2f')
    criterion_ssh = nn.CrossEntropyLoss().cuda()
    if args.bn_only:
        # lr=0 leaves the weights fixed; only BN running statistics adapt
        optimizer_ssh = torch.optim.SGD(ssh.parameters(), lr=0)
    else:
        optimizer_ssh = torch.optim.SGD(ssh.parameters(), lr=args.lr)
    ttt_bar = tqdm(range(1, len(teset) + 1))
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    for i in ttt_bar:
        pretrained_dict = checkpoint['state_dict']
        model_dict = model_kq.state_dict()
        # keep only the checkpoint entries that exist in model_kq
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        # load the merged dict; loading the filtered pretrained_dict directly
        # would fail a strict load whenever keys are missing
        model_kq.load_state_dict(model_dict)
        head.load_state_dict(checkpoint['head'])
        _, label = teset[i - 1]  # stliu: get the label for the image
        image = Image.fromarray(teset.data[i - 1])
        ssh.train()
        inputs = [tr_transform(image) for _ in range(args.batch_size)]
        inputs = torch.stack(inputs)
        inputs_ssh, labels_ssh = rotate_batch(inputs, 'rand')
        inputs_ssh, labels_ssh = inputs_ssh.cuda(), labels_ssh.cuda()
        optimizer_ssh.zero_grad()
        outputs_ssh = ssh(inputs_ssh)
        loss_ssh = criterion_ssh(outputs_ssh, labels_ssh)
        loss_ssh.backward()
        optimizer_ssh.step()

        # test again
        state_dict = model_kq.state_dict()
        for k in list(state_dict.keys()):
            # retain only encoder_q weights, minus its final fc layer
            if k.startswith('module.encoder_q') and \
                    not k.startswith('module.encoder_q.fc'):
                # strip the 'module.encoder_q.' prefix
                state_dict[k[len('module.encoder_q.'):]] = state_dict[k]
            # delete the renamed or unused key
            del state_dict[k]
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        ssh.eval()

        inputs = [test_transform(image) for _ in range(args.batch_size)]
        inputs = torch.stack(inputs)
        inputs = inputs.cuda(args.gpu, non_blocking=True)
        feats = model(inputs, 'r')
        targets = np.array([label for _ in range(args.batch_size)])
        _, top1_acc, _ = liblinearutil.predict(targets,
                                               feats.cpu().detach().numpy(),
                                               model_lsvm, '-q')

        # measure accuracy and record
        top1.update(top1_acc[0])
        ttt_bar.set_description('New Acc@SVM:{:.2f}%'.format(top1.avg))
    return top1.avg
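The linear-SVM step above uses LIBLINEAR's Python bindings. A toy illustration of the two calls on synthetic data, assuming the liblinearutil module from the LIBLINEAR distribution accepts NumPy arrays (it does when SciPy is installed); '-s 2' selects an L2-regularized L2-loss SVC and '-q' silences output:

import numpy as np
import liblinearutil

feats = np.random.randn(100, 16)             # synthetic features
labels = np.random.randint(0, 10, size=100)  # synthetic labels
model_lsvm = liblinearutil.train(labels, feats, '-s 2 -q')
# predict returns (predicted labels, (accuracy, MSE, SCC), decision values)
_, (acc, mse, scc), _ = liblinearutil.predict(labels, feats, model_lsvm, '-q')
print('accuracy: {:.2f}%'.format(acc))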
Example #4
def train(train_loader, model, criterion, optimizer, epoch, args, ssh):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    # stliu: new progress meter that also tracks estimated epoch time
    epoch_time = AverageMeter('Epoch Time', ':6.3f')
    progress = ProgressMeter(len(train_loader),
                             [epoch_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    if args.norm != 'bnf':
        model.train()
        ssh.train()

    end = time.time()

    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if args.gpu is not None:
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

        # compute output
        output, target = model(im_q=images[0], im_k=images[1])
        loss = criterion(output, target)
        if args.shared is not None:
            inputs_ssh, labels_ssh = rotate_batch(images[0],
                                                  args.rotation_type)
            inputs_ssh, labels_ssh = inputs_ssh.cuda(
                args.gpu,
                non_blocking=True), labels_ssh.cuda(args.gpu,
                                                    non_blocking=True)
            outputs_ssh = ssh(inputs_ssh)
            loss_ssh = criterion(outputs_ssh, labels_ssh)
            loss += loss_ssh
        # acc1/acc5 are (K+1)-way contrast classifier accuracy
        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images[0].size(0))
        top1.update(acc1[0], images[0].size(0))
        top5.update(acc5[0], images[0].size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        epoch_time.update(batch_time.avg * len(train_loader))
        end = time.time()

        if (i + 1) % args.print_freq == 0:  # stliu: change i to i+1
            progress.display(i)

    return losses.avg
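The loop above relies on the accuracy helper familiar from the official PyTorch ImageNet example; for reference, that standard implementation of batched top-k accuracy:

import torch

def accuracy(output, target, topk=(1,)):
    # Computes precision@k over a batch for each k in topk.
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res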