import time

import torch.optim
import torch.optim.lr_scheduler as lr_scheduler

import models.transform_layers as TL
from utils.utils import AverageMeter, normalize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hflip = TL.HorizontalFlipLayer().to(device)


def train(P, epoch, model, criterion, optimizer, scheduler, loader,
          logger=None, simclr_aug=None, linear=None, linear_optim=None):
    """Look up the shift-classifier and joint-distribution heads of *model*.

    When ``P.multi_gpu`` is truthy the heads are read from ``model.module``;
    otherwise they are read from ``model`` directly.

    NOTE(review): the visible body only binds these two heads to local names
    and then returns ``None``; none of the other parameters are used here.
    Presumably the remainder of the training step lives elsewhere -- confirm.
    """
    # Unwrap once so both heads come from the same underlying network object.
    network = model.module if P.multi_gpu else model
    rotation_linear = network.shift_cls_layer
    joint_linear = network.joint_distribution_layer
def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards."""
    # NOTE(review): unpickling can execute arbitrary code; these paths must
    # only ever point at trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _accuracy(scores, threshold, is_positive):
    """Return the fraction of *scores* on the expected side of *threshold*.

    With ``is_positive`` truthy a score counts when ``score >= threshold``;
    otherwise it counts when ``score < threshold``. *scores* must be a
    non-empty NumPy array.
    """
    hits = scores >= threshold if is_positive else scores < threshold
    return hits.sum() / len(scores)


def main(P):
    """Score every file in ``P.image_dir`` and print threshold accuracies.

    The function overwrites several fields of ``P`` with fixed evaluation
    settings, loads the model weights from ``P.load_path``, computes scores
    for each image that can be opened, and prints:

    * per image: its index, the elapsed time and its scores;
    * the full list of collected scores;
    * the accuracy for thresholds 0.20 .. 0.99 in steps of 0.01;
    * the accuracy at ``P.score_thres``.

    "Accuracy" is the fraction of scores ``>=`` the threshold when
    ``P.is_positive`` is truthy, and the fraction ``<`` the threshold
    otherwise.

    Args:
        P: mutable options object. Fields read here: ``use_cuda``,
            ``axis_path``, ``w_sim_path``, ``w_shi_path``, ``load_path``,
            ``image_dir``, ``is_positive`` and ``score_thres``.

    Returns:
        None. Results are reported on stdout only. If no image produced a
        score, a notice is printed and the accuracy report is skipped.
    """
    # Fixed evaluation configuration; these overwrite whatever P carried in.
    P.no_strict = False
    P.shift_trans_type = 'rotation'
    P.mode = 'ood_pre'
    P.n_classes = 2
    P.model = 'resnet18_imagenet'
    P.image_size = (224, 224, 3)
    P.resize_factor = 0.54
    P.resize_fix = True
    P.layers = ['simclr', 'shift']

    device = torch.device("cuda" if P.use_cuda else "cpu")
    P.shift_trans, P.K_shift = C.get_shift_module(P, eval=True)

    P.axis = _load_pickle(P.axis_path)
    if P.w_sim_path:
        P.weight_sim = _load_pickle(P.w_sim_path)
        P.weight_shi = _load_pickle(P.w_shi_path)
    else:
        # Built-in fallback weights used when no weight files are supplied.
        P.weight_sim = [
            0.007519226599080519,
            0.007939391391667395,
            0.008598049328054363,
            0.015014530319964874,
        ]
        P.weight_shi = [
            0.04909334419285857,
            0.052858438675397496,
            0.05840793893796496,
            0.11790745570891596,
        ]

    hflip = TL.HorizontalFlipLayer().to(device)
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    simclr_aug = get_simclr_augmentation(P, P.image_size).to(device)

    model = C.get_classifier(P.model, n_classes=P.n_classes).to(device)
    model = C.get_shift_classifer(model, P.K_shift).to(device)

    # On CPU-only runs the checkpoint tensors must be remapped to the CPU;
    # with CUDA the default placement stored in the checkpoint is kept.
    map_location = None if P.use_cuda else 'cpu'
    checkpoint = torch.load(P.load_path, map_location=map_location)
    model.load_state_dict(checkpoint, strict=not P.no_strict)

    kwargs = {
        'simclr_aug': simclr_aug,
        'layers': P.layers,
        'hflip': hflip,
        'device': device,
    }

    # Sorted so the processing (and printing) order is reproducible.
    image_files = sorted(glob.glob(os.path.join(P.image_dir, "*")))
    total_scores = []
    for i, image_file in enumerate(image_files):
        print(i, len(image_files), image_file)
        start = time.time()
        try:
            # The context manager releases the underlying file handle.
            with Image.open(image_file) as raw_image:
                img = raw_image.convert("RGB")
        except Exception as exc:
            # Best effort: unreadable entries are skipped, but unlike a bare
            # ``except`` this lets Ctrl-C / SystemExit through and says why.
            print('skipping', image_file, repr(exc))
            continue
        img = test_transform(img)
        features = get_features(P, model, img, **kwargs)
        scores = get_scores(P, features, device).numpy()
        print(time.time() - start)
        print(scores)
        total_scores.extend(scores)

    # Dump the accumulated scores once, after all images were processed.
    print(total_scores)

    if not total_scores:
        # Avoid a 0/0 division (NaN + RuntimeWarning) in the report below.
        print('no scores computed; skipping accuracy report')
        return

    total_scores = np.array(total_scores)
    is_positive = P.is_positive
    label = 'true accuracy' if is_positive else 'false accuracy'

    # Sweep thresholds 0.20 .. 0.99.
    for i in range(20, 100):
        print(label, i, _accuracy(total_scores, i / 100, is_positive))

    print(label + ' thres', P.score_thres,
          _accuracy(total_scores, P.score_thres, is_positive))