Exemplo n.º 1
0
    # Build the evaluation set and loader for the current OOD entry `ood`.
    # NOTE(review): this fragment appears to be the body of a loop over OOD
    # dataset identifiers whose header is outside the visible chunk — confirm.
    if P.one_class_idx is not None:
        # One-class mode: `ood` indexes into `cls_list`, and the selected
        # classes are carved out of `full_test_set`.
        ood_test_set = get_subclass_dataset(full_test_set, classes=cls_list[ood])
        ood = f'one_class_{ood}'  # rename so the loader below is keyed as 'one_class_<ood>'
    else:
        # Otherwise `ood` is handed to get_dataset as the dataset identifier.
        ood_test_set = get_dataset(P, dataset=ood, test_only=True, image_size=P.image_size)

    # Register the loader under the (possibly renamed) `ood` key.
    if P.multi_gpu:
        # Multi-GPU run: a DistributedSampler configured from P.n_gpus and
        # P.local_rank is passed to the loader.
        ood_sampler = DistributedSampler(ood_test_set, num_replicas=P.n_gpus, rank=P.local_rank)
        ood_test_loader[ood] = DataLoader(ood_test_set, sampler=ood_sampler, batch_size=P.test_batch_size, **kwargs)
    else:
        # Single-process run: no sampler, fixed order (shuffle=False).
        ood_test_loader[ood] = DataLoader(ood_test_set, shuffle=False, batch_size=P.test_batch_size, **kwargs)

### Initialize model ###

# Augmentation module obtained from C.get_simclr_augmentation, moved to `device`.
simclr_aug = C.get_simclr_augmentation(P, image_size=P.image_size).to(device)
# get_shift_module returns a pair: the shift transform and K_shift. Both are
# stored on P; K_shift is passed to get_shift_classifer below.
P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)
P.shift_trans = P.shift_trans.to(device)

# Base classifier for P.model with P.n_classes outputs, then passed through
# get_shift_classifer together with P.K_shift.
# NOTE(review): presumably this attaches a K_shift-way shift-prediction head —
# confirm against the implementation in C.
model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
model = C.get_shift_classifer(model, P.K_shift).to(device)

criterion = nn.CrossEntropyLoss().to(device)

# Optimizer selection. Both visible branches build SGD with momentum 0.9 and
# P.weight_decay, and set the same lr_decay_gamma; 'lars' additionally wraps
# the SGD optimizer in torchlars' LARS.
# NOTE(review): no fallback branch is visible in this fragment; if P.optimizer
# is neither value, `optimizer` and `lr_decay_gamma` are left undefined here.
if P.optimizer == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
    lr_decay_gamma = 0.1
elif P.optimizer == 'lars':
    # Imported inside the branch, so torchlars is only required for this option.
    from torchlars import LARS
    base_optimizer = optim.SGD(model.parameters(), lr=P.lr_init, momentum=0.9, weight_decay=P.weight_decay)
    optimizer = LARS(base_optimizer, eps=1e-8, trust_coef=0.001)
    lr_decay_gamma = 0.1
Exemplo n.º 2
0
def _load_pickle(path):
    """Load one pickled object from *path*, closing the file afterwards."""
    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # files that come from a trusted source.
    with open(path, "rb") as handle:
        return pickle.load(handle)


def main(P):
    """Score every image in ``P.image_dir`` and print threshold accuracies.

    The function overwrites a fixed set of configuration attributes on ``P``
    (``no_strict``, ``shift_trans_type``, ``mode``, ``n_classes``, ``model``,
    ``image_size``, ``resize_factor``, ``resize_fix``, ``layers``), loads the
    axis / weight data and the model checkpoint, then runs ``get_features``
    and ``get_scores`` on each image that can be opened.

    Attributes read from ``P``: ``use_cuda``, ``axis_path``, ``w_sim_path``,
    ``w_shi_path`` (only when ``w_sim_path`` is truthy), ``load_path``,
    ``image_dir``, ``is_positive`` and ``score_thres``.

    Output is printed only; nothing is returned. For thresholds 0.20..0.99 and
    for ``P.score_thres`` it prints the fraction of scores ``>=`` the
    threshold when ``P.is_positive`` is truthy, otherwise the fraction ``<``
    the threshold. If no image could be scored, a single message is printed
    instead of the accuracy report.
    """
    P.no_strict = False

    P.shift_trans_type = 'rotation'
    P.mode = 'ood_pre'
    P.n_classes = 2
    P.model = 'resnet18_imagenet'
    P.image_size = (224, 224, 3)

    P.resize_factor = 0.54
    P.resize_fix = True
    P.layers = ['simclr', 'shift']

    device = torch.device("cuda" if P.use_cuda else "cpu")

    P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)

    P.axis = _load_pickle(P.axis_path)
    if P.w_sim_path:
        P.weight_sim = _load_pickle(P.w_sim_path)
        P.weight_shi = _load_pickle(P.w_shi_path)
    else:
        # Built-in fallback weights used when no weight files are supplied.
        P.weight_sim = [
            0.007519226599080519, 0.007939391391667395, 0.008598049328054363,
            0.015014530319964874
        ]
        P.weight_shi = [
            0.04909334419285857, 0.052858438675397496, 0.05840793893796496,
            0.11790745570891596
        ]

    hflip = TL.HorizontalFlipLayer().to(device)

    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    simclr_aug = get_simclr_augmentation(P, P.image_size).to(device)

    model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
    model = C.get_shift_classifer(model, P.K_shift).to(device)

    # On CUDA keep torch.load's default placement; on CPU remap every tensor
    # so a checkpoint saved on a GPU can still be loaded.
    checkpoint = torch.load(
        P.load_path, map_location=None if P.use_cuda else 'cpu')
    model.load_state_dict(checkpoint, strict=not P.no_strict)

    kwargs = {
        'simclr_aug': simclr_aug,
        'layers': P.layers,
        'hflip': hflip,
        'device': device
    }

    image_files = glob.glob(os.path.join(P.image_dir, "*"))
    total_scores = []
    for i, image_file in enumerate(image_files):
        print(i, len(image_files), image_file)
        start = time.time()
        try:
            img = Image.open(image_file).convert("RGB")
        except Exception as exc:
            # Best-effort: the glob matches everything in the directory, so
            # unreadable or non-image entries are skipped. Catching Exception
            # (not a bare except) lets Ctrl-C / SystemExit still propagate.
            print('skipping', image_file, repr(exc))
            continue
        img = test_transform(img)
        features = get_features(P, model, img, **kwargs)
        scores = get_scores(P, features, device).numpy()
        print(time.time() - start)
        print(scores)
        total_scores.extend(scores)
    print(total_scores)
    total_scores = np.array(total_scores)

    if total_scores.size == 0:
        # Without this guard every ratio below would be 0/0 (NaN plus a
        # RuntimeWarning per threshold).
        print('no images could be scored; skipping accuracy report')
        return

    label = 'true accuracy' if P.is_positive else 'false accuracy'

    def hit_rate(thres):
        """Fraction of scores on the expected side of *thres*."""
        if P.is_positive:
            hits = total_scores >= thres
        else:
            hits = total_scores < thres
        return hits.sum() / len(total_scores)

    for i in range(20, 100):
        print(label, i, hit_rate(i / 100))

    print(label + ' thres', P.score_thres, hit_rate(P.score_thres))