Esempio n. 1
0
def run(args: DictConfig) -> None:
    """Load a trained SDIM checkpoint and run out-of-distribution detection.

    Per-dataset hyper-parameters (``n_classes``, ``rep_size``, ``margin``) are
    read from ``args.<args.dataset>``. The checkpoint is read from
    ``logs/sdim/<dataset>/SDIM_<classifier_name>.pth``, resolved against the
    original working directory via ``hydra.utils.to_absolute_path``.

    Args:
        args: Hydra/OmegaConf config. If CUDA was requested but is not
            available, ``args.device`` is overwritten with ``'cpu'``.
    """
    cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)

    # Fall back to CPU when CUDA was requested but is unavailable (a specific
    # 'cuda:N' is honoured as-is). Write the resolved value back only when it
    # differs, so ood_detection -- which receives `args` and presumably reads
    # args.device (TODO confirm) -- agrees with where the model lives.
    requested = str(args.device)
    device = requested if cuda_available and requested.startswith('cuda') else 'cpu'
    if requested != device:
        args.device = device

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    classifier = get_model(name=args.classifier_name,
                           n_classes=n_classes).to(device)

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin).to(device)

    base_dir = hydra.utils.to_absolute_path('logs/sdim/{}'.format(
        args.dataset))
    save_name = 'SDIM_{}.pth'.format(args.classifier_name)
    # Keep loaded storages on CPU so a GPU-saved checkpoint also restores on a
    # CPU-only host; load_state_dict copies them onto the model's device.
    sdim.load_state_dict(
        torch.load(os.path.join(base_dir, save_name),
                   map_location=lambda storage, loc: storage))

    ood_detection(sdim, args)
def run(args: DictConfig) -> None:
    """Load a trained SDIM with a ResNet-18 backbone and attack or sample it.

    The checkpoint ``logs/sdim/<dataset>/SDIM_resnet18<suffix>.pth`` is chosen
    via ``suffix_dict[args.base_type]``. If ``args.sample_likelihood`` is true,
    ``sample_cases`` is run; otherwise ``pgd_attack``.

    Args:
        args: Hydra/OmegaConf config; per-dataset values are read from
            ``args.<args.dataset>``.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    # Explicit check rather than `assert`: assertions are removed under
    # `python -O`, which would silently skip this precondition.
    if not torch.cuda.is_available():
        raise RuntimeError(
            'CUDA is required but torch.cuda.is_available() returned False.')
    torch.manual_seed(args.seed)

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    classifier = resnet18(n_classes=n_classes).to(args.device)

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin).to(args.device)

    base_dir = hydra.utils.to_absolute_path('logs/sdim/{}'.format(
        args.dataset))
    save_name = 'SDIM_resnet18{}.pth'.format(suffix_dict[args.base_type])
    # Load storages onto CPU first; load_state_dict copies into the model.
    sdim.load_state_dict(
        torch.load(os.path.join(base_dir, save_name),
                   map_location=lambda storage, loc: storage))

    if args.sample_likelihood:
        sample_cases(sdim, args)
    else:
        pgd_attack(sdim, args)
Esempio n. 3
0
def run(args: DictConfig) -> None:
    """Train an SDIM on a pretrained classifier, or evaluate a saved one.

    When ``args.inference`` is true, loads ``SDIM_<classifier_name>.pth`` from
    the current working directory, extracts rejection thresholds and runs
    ``clean_eval``. Otherwise trains with Adam at ``args.learning_rate``.

    Args:
        args: Hydra/OmegaConf config. It may be mutated: ``args.device`` falls
            back to ``'cpu'`` when CUDA was requested but is unavailable, and
            ``args.data_dir`` is forced to ``'tiny_imagenet'`` for that dataset.
    """
    cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)

    # Resolve the device before load_pretrained_model(args) runs, so it and
    # later helpers (which presumably read args.device -- TODO confirm) see
    # the same value the SDIM is moved to.
    requested = str(args.device)
    device = requested if cuda_available and requested.startswith('cuda') else 'cpu'
    if requested != device:
        args.device = device

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    classifier = load_pretrained_model(args)
    if args.dataset == 'tiny_imagenet':
        # Override the data directory for this dataset; presumably consumed
        # by the data-loading helpers called from train/eval.
        args.data_dir = 'tiny_imagenet'

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin,
                alpha=args.alpha,
                beta=args.beta,
                gamma=args.gamma).to(device)

    optimizer = Adam(sdim.parameters(), lr=args.learning_rate)

    if args.inference:
        save_name = 'SDIM_{}.pth'.format(args.classifier_name)
        # Load storages onto CPU first; load_state_dict copies into the model.
        sdim.load_state_dict(
            torch.load(save_name, map_location=lambda storage, loc: storage))

        thresholds1, thresholds2 = extract_thresholds(sdim, args)
        clean_eval(sdim, args, thresholds1, thresholds2)
    else:
        train(sdim, optimizer, args)
Esempio n. 4
0
def run(args: DictConfig) -> None:
    """Train an SDIM on a pretrained classifier, or evaluate a saved one.

    When ``args.inference`` is true, loads ``SDIM_resnet18<suffix>.pth``
    (suffix from ``suffix_dict[args.base_type]``) from the current working
    directory, extracts rejection thresholds and runs ``clean_eval``.
    Otherwise trains with Adam at ``args.learning_rate``.

    Args:
        args: Hydra/OmegaConf config; per-dataset values are read from
            ``args.<args.dataset>``.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    # Explicit check rather than `assert`, which `python -O` strips.
    if not torch.cuda.is_available():
        raise RuntimeError(
            'CUDA is required but torch.cuda.is_available() returned False.')
    torch.manual_seed(args.seed)

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    classifier = load_pretrained_model(args)

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin,
                alpha=args.alpha,
                beta=args.beta,
                gamma=args.gamma).to(args.device)

    optimizer = Adam(sdim.parameters(), lr=args.learning_rate)

    if args.inference:
        save_name = 'SDIM_resnet18{}.pth'.format(suffix_dict[args.base_type])
        # Load storages onto CPU first; load_state_dict copies into the model.
        sdim.load_state_dict(
            torch.load(save_name, map_location=lambda storage, loc: storage))

        thresholds1, thresholds2 = extract_thresholds(sdim, args)
        clean_eval(sdim, args, thresholds1, thresholds2)
    else:
        train(sdim, optimizer, args)
def run(args: DictConfig) -> None:
    """Load a trained SDIM checkpoint and evaluate it.

    Uses the Tiny-ImageNet model factory when ``args.dataset`` is
    ``'tiny_imagenet'`` and ``get_model`` otherwise. Loads
    ``logs/sdim/<dataset>/SDIM_<classifier_name>.pth``. Then it either runs
    ``sample_cases`` (``args.sample_likelihood``) or extracts rejection
    thresholds and runs ``corruption_eval``.

    Args:
        args: Hydra/OmegaConf config. If CUDA was requested but is not
            available, ``args.device`` is overwritten with ``'cpu'``.
    """
    cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)

    # Fall back to CPU when CUDA was requested but is unavailable; write the
    # value back only when it changed so helpers receiving `args` (presumably
    # reading args.device -- TODO confirm) match the model's device.
    requested = str(args.device)
    device = requested if cuda_available and requested.startswith('cuda') else 'cpu'
    if requested != device:
        args.device = device

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    if args.dataset == 'tiny_imagenet':
        classifier = get_model_for_tiny_imagenet(name=args.classifier_name,
                                                 n_classes=n_classes)
    else:
        classifier = get_model(name=args.classifier_name,
                               n_classes=n_classes)
    classifier = classifier.to(device)

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin).to(device)

    base_dir = hydra.utils.to_absolute_path('logs/sdim/{}'.format(
        args.dataset))
    save_name = 'SDIM_{}.pth'.format(args.classifier_name)
    # Load storages onto CPU first; load_state_dict copies into the model.
    sdim.load_state_dict(
        torch.load(os.path.join(base_dir, save_name),
                   map_location=lambda storage, loc: storage))

    if args.sample_likelihood:
        sample_cases(sdim, args)
    else:
        thresholds1, thresholds2 = extract_thresholds(sdim, args)
        corruption_eval(sdim, args, thresholds1, thresholds2)
    # NOTE(review): these statements read `hps`, which is not defined in the
    # enclosing `run(args)` visible here; they look like the body of a
    # different function whose header is missing -- confirm before use.
    prefix = ''
    if hps.encoder_name.startswith('sdim_'):
        prefix = 'sdim_'
        # Slice off the literal 'sdim_' prefix. str.strip('sdim_') would
        # remove any of the characters s, d, i, m, _ from BOTH ends
        # (e.g. 'sdim_mlp' -> 'lp').
        hps.encoder_name = hps.encoder_name[len(prefix):]
        model = SDIM(rep_size=hps.rep_size,
                     mi_units=hps.mi_units,
                     encoder_name=hps.encoder_name,
                     image_channel=hps.image_channel).to(hps.device)

        checkpoint_path = os.path.join(
            hps.log_dir,
            'sdim_{}_{}_d{}.pth'.format(hps.encoder_name, hps.problem,
                                        hps.rep_size))
    else:
        # Parse the depth from names such as 'resnet18' by removing the
        # literal 'resnet' prefix (str.strip('resnet') strips a character
        # set from both ends, not a prefix).
        depth_str = hps.encoder_name
        if depth_str.startswith('resnet'):
            depth_str = depth_str[len('resnet'):]
        n_encoder_layers = int(depth_str)
        model = build_resnet_32x32(n=n_encoder_layers,
                                   fc_size=hps.n_classes,
                                   image_channel=hps.image_channel).to(
                                       hps.device)

        checkpoint_path = os.path.join(
            hps.log_dir, '{}_{}.pth'.format(hps.encoder_name, hps.problem))

    # Both branches restore weights identically; storages are kept on CPU so
    # GPU-saved checkpoints load anywhere, then copied into `model`.
    model.load_state_dict(
        torch.load(checkpoint_path,
                   map_location=lambda storage, loc: storage))

    print('Model name: {}'.format(hps.encoder_name))
Esempio n. 7
0
def inference(hps: DictConfig) -> None:
    """Evaluate a pretrained SDIM checkpoint and log top-1 / top-5 accuracy.

    Builds the base classifier named by ``hps.base_classifier`` and wraps it
    in SDIM. Loads ``SDIM_<base_classifier>.pth`` (its ``'model_state'``
    entry) from a fixed directory. Scores the ``'eval'`` loader with
    ``sdim.infer``.

    Args:
        hps: Hydra/OmegaConf config. It must provide ``seed``, ``device``,
            ``base_classifier``, ``rep_size``, ``mi_units``, ``n_classes``,
            ``margin`` and ``n_batch_test``, plus a
            ``<base_classifier>.last_conv_channel`` entry.
    """
    # This enables a ctrl-C without triggering errors
    import signal

    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    cuda_available = torch.cuda.is_available()

    torch.manual_seed(hps.seed)

    # One device for both the model and the data loader. The model used to go
    # to hps.device while the loader received a separately computed value, so
    # they disagreed whenever CUDA was requested but unavailable.
    requested = str(hps.device)
    device = requested if cuda_available and requested.startswith('cuda') else 'cpu'

    # Models
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier,
                           in_size=local_channel,
                           out_size=hps.rep_size).to(device)
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=hps.margin,
                rep_size=hps.rep_size,
                local_channel=local_channel).to(device)

    model_path = 'SDIM_{}.pth'.format(hps.base_classifier)
    # NOTE(review): machine-specific absolute path; consider moving it into
    # the config.
    base_dir = '/userhome/cs/u3003679/generative-classification-with-rejection'
    path = os.path.join(base_dir, model_path)
    # Keep storages on CPU (as the other checkpoint loaders do) so a
    # GPU-saved checkpoint loads on a CPU-only host; load_state_dict then
    # copies the tensors onto the model's device.
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    sdim.load_state_dict(checkpoint['model_state'])

    # logging the SDIM desc.
    for desc in sdim.desc():
        logger.info(desc)

    eval_loader = Loader('eval', batch_size=hps.n_batch_test, device=device)

    # No nn.DataParallel wrap here: DataParallel only parallelises forward()
    # and does not expose custom methods such as infer(), so calling
    # sdim.infer on a wrapped model raises AttributeError. Evaluation
    # therefore runs on a single device regardless of hps.n_gpu.

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    sdim.eval()
    for n_iters, (x, y) in enumerate(eval_loader, start=1):
        # NOTE(review): this stops before the final batch, so the last batch
        # is never scored. Kept as-is -- presumably to skip a partial batch;
        # confirm.
        if n_iters == len(eval_loader):
            break

        with torch.no_grad():
            log_lik = sdim.infer(x)

        acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))

        top1.update(acc1, x.size(0))
        top5.update(acc5, x.size(0))

    logger.info('Test Acc@1: {:.3f}, Acc@5: {:.3f}'.format(top1.avg, top5.avg))