Ejemplo n.º 1
0
def run(args: DictConfig) -> None:
    """Restore a trained SDIM model and run out-of-distribution detection.

    The classifier named by ``args.classifier_name`` is built, wrapped in an
    SDIM model, and initialised from
    ``logs/sdim/<dataset>/SDIM_<classifier_name>.pth`` before being passed to
    ``ood_detection``.

    Args:
        args: Hydra configuration. Reads ``seed``, ``device``, ``dataset``,
            ``classifier_name`` and ``mi_units``, plus ``n_classes``,
            ``rep_size`` and ``margin`` from the node named after
            ``args.dataset``.

    Side effects:
        ``args.device`` is rewritten to ``'cpu'`` when a CUDA device is
        requested but CUDA is not available.
    """
    torch.manual_seed(args.seed)

    # The CPU fallback used to be computed into a local that nothing read, so
    # requesting 'cuda' on a CUDA-less machine crashed in ``.to(args.device)``.
    # The fallback is stored on ``args`` because ``ood_detection`` receives the
    # config and presumably reads the device from it (TODO confirm).
    if str(args.device).startswith('cuda') and not torch.cuda.is_available():
        args.device = 'cpu'
    device = args.device

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    classifier = get_model(name=args.classifier_name,
                           n_classes=n_classes).to(device)

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin).to(device)

    base_dir = hydra.utils.to_absolute_path('logs/sdim/{}'.format(
        args.dataset))
    save_name = 'SDIM_{}.pth'.format(args.classifier_name)
    # Deserialize on CPU; load_state_dict copies into the model's own device.
    state = torch.load(os.path.join(base_dir, save_name), map_location='cpu')
    sdim.load_state_dict(state)

    ood_detection(sdim, args)
Ejemplo n.º 2
0
def run(args: DictConfig) -> None:
    """Restore a trained resnet18-based SDIM model and attack or sample it.

    Loads ``logs/sdim/<dataset>/SDIM_resnet18<suffix>.pth``, where the suffix
    comes from ``suffix_dict[args.base_type]``. It then runs ``sample_cases``
    when ``args.sample_likelihood`` is truthy, otherwise ``pgd_attack``.

    Args:
        args: Hydra configuration. Reads ``seed``, ``device``, ``dataset``,
            ``mi_units``, ``base_type`` and ``sample_likelihood``, plus
            ``n_classes``, ``rep_size`` and ``margin`` from the node named
            after ``args.dataset``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    # Explicit check rather than ``assert``: assertions are removed under
    # ``python -O``, which would let the run continue without a GPU.
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA is required to run this script.')
    torch.manual_seed(args.seed)

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    classifier = resnet18(n_classes=n_classes).to(args.device)

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin).to(args.device)

    base_dir = hydra.utils.to_absolute_path('logs/sdim/{}'.format(
        args.dataset))
    save_name = 'SDIM_resnet18{}.pth'.format(suffix_dict[args.base_type])
    # Deserialize on CPU; load_state_dict copies into the model's own device.
    state = torch.load(os.path.join(base_dir, save_name), map_location='cpu')
    sdim.load_state_dict(state)

    if args.sample_likelihood:
        sample_cases(sdim, args)
    else:
        pgd_attack(sdim, args)
Ejemplo n.º 3
0
def run(args: DictConfig) -> None:
    """Train an SDIM model on a pretrained classifier, or evaluate a saved one.

    When ``args.inference`` is truthy, ``SDIM_<classifier_name>.pth`` is
    loaded from the current working directory, thresholds are extracted and
    ``clean_eval`` is run. Otherwise the model is trained with Adam.

    Args:
        args: Hydra configuration. Reads ``seed``, ``device``, ``dataset``,
            ``mi_units``, ``alpha``, ``beta``, ``gamma``, ``learning_rate``,
            ``inference`` and ``classifier_name``, plus ``n_classes``,
            ``rep_size`` and ``margin`` from the node named after
            ``args.dataset``.

    Side effects:
        ``args.device`` is rewritten to ``'cpu'`` when a CUDA device is
        requested but CUDA is not available. ``args.data_dir`` is set to
        ``'tiny_imagenet'`` when that dataset is selected.
    """
    torch.manual_seed(args.seed)

    # The CPU fallback used to be computed into a local that nothing read.
    # It is stored on ``args`` because the helpers below receive the config
    # and presumably read the device from it (TODO confirm).
    if str(args.device).startswith('cuda') and not torch.cuda.is_available():
        args.device = 'cpu'
    device = args.device

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    classifier = load_pretrained_model(args)
    # Kept after the classifier is loaded, as in the original ordering.
    if args.dataset == 'tiny_imagenet':
        args.data_dir = 'tiny_imagenet'

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin,
                alpha=args.alpha,
                beta=args.beta,
                gamma=args.gamma).to(device)

    if args.inference:
        save_name = 'SDIM_{}.pth'.format(args.classifier_name)
        # Deserialize on CPU; load_state_dict copies into the model's device.
        sdim.load_state_dict(torch.load(save_name, map_location='cpu'))

        thresholds1, thresholds2 = extract_thresholds(sdim, args)
        clean_eval(sdim, args, thresholds1, thresholds2)
    else:
        # The optimizer is only needed for training.
        optimizer = Adam(sdim.parameters(), lr=args.learning_rate)
        train(sdim, optimizer, args)
Ejemplo n.º 4
0
def run(args: DictConfig) -> None:
    """Train a resnet18-based SDIM model, or evaluate a saved one.

    When ``args.inference`` is truthy, ``SDIM_resnet18<suffix>.pth`` is loaded
    from the current working directory, with the suffix taken from
    ``suffix_dict[args.base_type]``. Thresholds are then extracted and
    ``clean_eval`` is run. Otherwise the model is trained with Adam.

    Args:
        args: Hydra configuration. Reads ``seed``, ``device``, ``dataset``,
            ``mi_units``, ``alpha``, ``beta``, ``gamma``, ``learning_rate``,
            ``inference`` and ``base_type``, plus ``n_classes``, ``rep_size``
            and ``margin`` from the node named after ``args.dataset``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    # Explicit check rather than ``assert``: assertions are removed under
    # ``python -O``, which would let the run continue without a GPU.
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA is required to run this script.')
    torch.manual_seed(args.seed)

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    classifier = load_pretrained_model(args)

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin,
                alpha=args.alpha,
                beta=args.beta,
                gamma=args.gamma).to(args.device)

    if args.inference:
        save_name = 'SDIM_resnet18{}.pth'.format(suffix_dict[args.base_type])
        # Deserialize on CPU; load_state_dict copies into the model's device.
        sdim.load_state_dict(torch.load(save_name, map_location='cpu'))

        thresholds1, thresholds2 = extract_thresholds(sdim, args)
        clean_eval(sdim, args, thresholds1, thresholds2)
    else:
        # The optimizer is only needed for training.
        optimizer = Adam(sdim.parameters(), lr=args.learning_rate)
        train(sdim, optimizer, args)
def run(args: DictConfig) -> None:
    """Restore a trained SDIM model and sample from it or run corruption eval.

    The classifier is built with ``get_model_for_tiny_imagenet`` when
    ``args.dataset`` is ``'tiny_imagenet'`` and with ``get_model`` otherwise.
    The SDIM weights are read from
    ``logs/sdim/<dataset>/SDIM_<classifier_name>.pth``.

    Args:
        args: Hydra configuration. Reads ``seed``, ``device``, ``dataset``,
            ``classifier_name``, ``mi_units`` and ``sample_likelihood``, plus
            ``n_classes``, ``rep_size`` and ``margin`` from the node named
            after ``args.dataset``.

    Side effects:
        ``args.device`` is rewritten to ``'cpu'`` when a CUDA device is
        requested but CUDA is not available.
    """
    torch.manual_seed(args.seed)

    # The CPU fallback used to be computed into a local that nothing read.
    # It is stored on ``args`` because the helpers below receive the config
    # and presumably read the device from it (TODO confirm).
    if str(args.device).startswith('cuda') and not torch.cuda.is_available():
        args.device = 'cpu'
    device = args.device

    dataset_cfg = args.get(args.dataset)
    n_classes = dataset_cfg.n_classes
    rep_size = dataset_cfg.rep_size
    margin = dataset_cfg.margin

    # Both builders are called with the same keyword arguments.
    if args.dataset == 'tiny_imagenet':
        build_classifier = get_model_for_tiny_imagenet
    else:
        build_classifier = get_model
    classifier = build_classifier(name=args.classifier_name,
                                  n_classes=n_classes).to(device)

    sdim = SDIM(disc_classifier=classifier,
                n_classes=n_classes,
                rep_size=rep_size,
                mi_units=args.mi_units,
                margin=margin).to(device)

    base_dir = hydra.utils.to_absolute_path('logs/sdim/{}'.format(
        args.dataset))
    save_name = 'SDIM_{}.pth'.format(args.classifier_name)
    # Deserialize on CPU; load_state_dict copies into the model's own device.
    state = torch.load(os.path.join(base_dir, save_name), map_location='cpu')
    sdim.load_state_dict(state)

    if args.sample_likelihood:
        sample_cases(sdim, args)
    else:
        thresholds1, thresholds2 = extract_thresholds(sdim, args)
        corruption_eval(sdim, args, thresholds1, thresholds2)
    # NOTE: this is an excerpt whose enclosing ``def`` is not part of the
    # listing; ``hps`` and ``use_cuda`` are defined outside it.
    hps.device = torch.device("cuda" if use_cuda else "cpu")

    # Input channels per problem; other problems leave image_channel as is.
    if hps.problem == 'cifar10':
        hps.image_channel = 3
    elif hps.problem == 'svhn':
        hps.image_channel = 3
    elif hps.problem == 'mnist':
        hps.image_channel = 1

    prefix = ''
    if hps.encoder_name.startswith('sdim_'):
        prefix = 'sdim_'
        # Drop the leading 'sdim_' by slicing. ``str.strip('sdim_')`` removes
        # any of the characters s/d/i/m/_ from BOTH ends, which mangles names
        # such as 'sdim_densenet' (-> 'ensenet').
        hps.encoder_name = hps.encoder_name[len(prefix):]
        model = SDIM(rep_size=hps.rep_size,
                     mi_units=hps.mi_units,
                     encoder_name=hps.encoder_name,
                     image_channel=hps.image_channel).to(hps.device)

        checkpoint_path = os.path.join(
            hps.log_dir,
            'sdim_{}_{}_d{}.pth'.format(hps.encoder_name, hps.problem,
                                        hps.rep_size))
        model.load_state_dict(
            torch.load(checkpoint_path,
                       map_location=lambda storage, loc: storage))
    else:
        # Encoder names here look like 'resnet<depth>'; take the depth by
        # removing the literal prefix instead of stripping a character set.
        depth_text = hps.encoder_name
        if depth_text.startswith('resnet'):
            depth_text = depth_text[len('resnet'):]
        n_encoder_layers = int(depth_text)
        model = build_resnet_32x32(n=n_encoder_layers,
                                   fc_size=hps.n_classes,
                                   image_channel=hps.image_channel).to(
                                       hps.device)
Ejemplo n.º 7
0
    # NOTE: excerpt of a larger function; ``hps`` and ``use_cuda`` are defined
    # outside the listing, and the final ``elif`` below is cut off.
    hps.device = torch.device("cuda" if use_cuda else "cpu")

    # Input channels per problem: 3 for cifar10/svhn, 1 for mnist/fashion.
    # Any other problem leaves hps.image_channel untouched.
    if hps.problem == 'cifar10':
        hps.image_channel = 3
    elif hps.problem == 'svhn':
        hps.image_channel = 3
    elif hps.problem == 'mnist':
        hps.image_channel = 1
    elif hps.problem == 'fashion':
        hps.image_channel = 1

    # Build the SDIM model from the hyper-parameters and move it to the
    # selected device.
    model = SDIM(rep_size=hps.rep_size,
                 mi_units=hps.mi_units,
                 encoder_name=hps.encoder_name,
                 image_channel=hps.image_channel,
                 margin=hps.margin,
                 alpha=hps.alpha,
                 beta=hps.beta,
                 gamma=hps.gamma
                 ).to(hps.device)
    # Created unconditionally; none of the branches visible below receive it.
    optimizer = Adam(model.parameters(), lr=hps.lr)

    print('==>  # Model parameters: {}.'.format(cal_parameters(model)))

    # Mode dispatch: the first truthy flag wins.
    if hps.noise_attack:
        noise_attack(model, hps)
    elif hps.inference:
        inference(model, hps)
    elif hps.ood_inference:
        ood_inference(model, hps)
    elif hps.rejection_inference:
Ejemplo n.º 8
0
    # NOTE: excerpt of a larger function; ``hps``, ``use_cuda`` and the two
    # transforms are defined outside the listing.
    # ImageNet 'train' split, shuffled, read from ~/data.
    train_set = datasets.ImageNet(root='~/data', split='train', download=True, transform=train_transform)
    train_loader = DataLoader(dataset=train_set, batch_size=hps.n_batch_train, shuffle=True,
                              pin_memory=True, num_workers=16)

    # ImageNet 'val' split, kept in order.
    test_set = datasets.ImageNet(root='~/data', split='val', download=True, transform=test_transform)
    test_loader = DataLoader(dataset=test_set, batch_size=hps.n_batch_test, shuffle=False,
                             pin_memory=True, num_workers=16)

    # Build the base classifier and wrap it in SDIM.
    print('Classifier name: {}'.format(hps.classifier_name))
    classifier = get_model(model_name=hps.classifier_name).to(hps.device)

    # NOTE(review): margin is fixed at 3 here rather than read from hps.
    sdim = SDIM(disc_classifier=classifier,
                rep_size=hps.rep_size,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=3).to(hps.device)

    # Replicate across GPUs 0..n_gpu-1 when more than one is requested.
    if use_cuda and hps.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    # Only parameters that require gradients are handed to the optimizer.
    optimizer = Adam(filter(lambda param: param.requires_grad is True, sdim.parameters()), lr=hps.lr)

    # Seeds are set after the model has been constructed.
    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    # Create the log directory if missing. os.mkdir creates only the last
    # path component, so the parent directory must already exist.
    logdir = os.path.abspath(hps.log_dir) + "/"
    if not os.path.exists(logdir):
        os.mkdir(logdir)
    # NOTE: excerpt of a larger function; ``args`` and ``use_cuda`` are
    # defined outside the listing, and the nested ``train`` is cut off.
    # Make sure the checkpoint directory exists.
    if not os.path.isdir(args.save):
        os.makedirs(args.save)

    # Build the ResNeXt classifier and report its parameter count.
    classifier = CifarResNeXt(args.cardinality, args.depth, args.n_classes, args.base_width, args.widen_factor).to(args.device)
    print('# Classifier parameters: ', cal_parameters(classifier))

    # Restore the pretrained classifier weights from args.save.
    # NOTE(review): torch.load is called without map_location here.
    save_name = 'ResNeXt{}_{}x{}d.pth'.format(args.depth, args.cardinality, args.base_width)
    check_point = torch.load(os.path.join(args.save, save_name))
    classifier.load_state_dict(check_point['model_state'])

    # Accuracies stored in the checkpoint, printed for reference.
    train_acc = check_point['train_acc']
    test_acc = check_point['test_acc']
    print('Original Discriminative Classifier, train acc: {:.4f}, test acc: {:.4f}'.format(train_acc, test_acc))

    # Wrap the restored classifier in SDIM.
    sdim = SDIM(disc_classifier=classifier, rep_size=args.rep_size, mi_units=args.mi_units, n_classes=args.n_classes).to(args.device)

    # Only parameters that require gradients are handed to the optimizer.
    optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad is True, sdim.parameters()),
                                 lr=args.learning_rate)

    # Replicate across GPUs 0..n_gpu-1 when more than one is requested.
    if use_cuda and args.n_gpu > 1:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(args.n_gpu)))

    best_train_loss = np.inf

    # train function (forward, backward, update); its body continues beyond
    # this excerpt.
    def train():
        sdim.train()
        loss_list = []
        mi_list = []
        nll_list = []
Ejemplo n.º 10
0
def train(hps: DictConfig) -> None:
    """Train an SDIM model and checkpoint it whenever the logged loss improves.

    Batches are drawn from ``Loader('train', ...)``. Every
    ``hps.log_interval`` iterations the running averages are logged and then
    reset. If the average loss is the lowest seen so far, a checkpoint with
    keys ``'model_state'``, ``'train_acc_top1'`` and ``'train_acc_top5'`` is
    written to ``<hps.log_dir>/SDIM_<base_classifier>.pth``.

    Args:
        hps: Hydra configuration. Reads ``seed``, ``device``,
            ``base_classifier``, ``rep_size``, ``mi_units``, ``n_classes``,
            ``margin``, ``n_batch_train``, ``n_gpu``, ``lr``, ``log_dir``,
            ``training_iters`` and ``log_interval``, plus
            ``last_conv_channel`` from the node named after
            ``hps.base_classifier``.
    """
    # This enables a ctrl-C without triggering errors
    import signal

    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    cuda_available = torch.cuda.is_available()

    torch.manual_seed(hps.seed)

    # Resolve the device once and use it for both the model and the loader.
    # Previously the model went to ``hps.device`` while the loader got a
    # separately derived value, so the two could disagree (e.g. 'cuda:0').
    cuda_requested = str(hps.device).startswith('cuda')
    device = 'cpu' if cuda_requested and not cuda_available else hps.device

    # Models
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier,
                           in_size=local_channel,
                           out_size=hps.rep_size).to(device)
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=hps.margin,
                rep_size=hps.rep_size,
                local_channel=local_channel).to(device)

    # logging the SDIM desc.
    for desc in sdim.desc():
        logger.info(desc)

    train_loader = Loader('train', batch_size=hps.n_batch_train, device=device)

    multi_gpu = cuda_available and hps.n_gpu > 1
    if multi_gpu:
        sdim = torch.nn.DataParallel(sdim, device_ids=list(range(hps.n_gpu)))

    optimizer = Adam(filter(lambda param: param.requires_grad is True,
                            sdim.parameters()),
                     lr=hps.lr)

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    # Create log dir (including missing parents; no error if it exists).
    logdir = os.path.abspath(hps.log_dir) + "/"
    os.makedirs(logdir, exist_ok=True)

    loss_optimal = 1e5
    n_iters = 0

    losses = AverageMeter('Loss')
    MIs = AverageMeter('MI')
    nlls = AverageMeter('NLL')
    margins = AverageMeter('Margin')
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    meters = (losses, MIs, nlls, margins, top1, top5)

    for x, y in train_loader:
        n_iters += 1
        # Kept as in the original: the loop stops before running iteration
        # number ``training_iters``.
        if n_iters == hps.training_iters:
            break

        # backward
        optimizer.zero_grad()
        loss, mi_loss, nll_loss, ll_margin, log_lik = sdim(x, y)
        # Under DataParallel each scalar output comes back as one value per
        # replica; reduce before ``.item()``, which needs a single element.
        loss = loss.mean()
        loss.backward()
        optimizer.step()

        batch_size = x.size(0)
        acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(acc1, batch_size)
        top5.update(acc5, batch_size)

        MIs.update(mi_loss.mean().item(), batch_size)
        nlls.update(nll_loss.mean().item(), batch_size)
        margins.update(ll_margin.mean().item(), batch_size)

        if n_iters % hps.log_interval == hps.log_interval - 1:
            logger.info(
                'Train loss: {:.4f}, mi: {:.4f}, nll: {:.4f}, ll_margin: {:.4f}'
                .format(losses.avg, MIs.avg, nlls.avg, margins.avg))
            logger.info('Train Acc@1: {:.3f}, Acc@5: {:.3f}'.format(
                top1.avg, top5.avg))

            if losses.avg < loss_optimal:
                loss_optimal = losses.avg
                model_path = 'SDIM_{}.pth'.format(hps.base_classifier)

                # Save the bare module so the checkpoint keys carry no
                # DataParallel 'module.' prefix.
                if multi_gpu:
                    state = sdim.module.state_dict()
                else:
                    state = sdim.state_dict()

                check_point = {
                    'model_state': state,
                    'train_acc_top1': top1.avg,
                    'train_acc_top5': top5.avg
                }

                torch.save(check_point, os.path.join(hps.log_dir, model_path))

            # Start a fresh averaging window after every log.
            for meter in meters:
                meter.reset()
Ejemplo n.º 11
0
def inference(hps: DictConfig) -> None:
    """Evaluate a saved SDIM checkpoint and log top-1 / top-5 accuracy.

    The model is rebuilt from ``hps`` and initialised from the
    ``'model_state'`` entry of ``SDIM_<base_classifier>.pth``. It is then run
    over ``Loader('eval', ...)`` with gradients disabled.

    Args:
        hps: Hydra configuration. Reads ``seed``, ``device``,
            ``base_classifier``, ``rep_size``, ``mi_units``, ``n_classes``,
            ``margin`` and ``n_batch_test``, plus ``last_conv_channel`` from
            the node named after ``hps.base_classifier``.
    """
    # This enables a ctrl-C without triggering errors
    import signal

    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    logger = logging.getLogger(__name__)

    torch.manual_seed(hps.seed)

    # Resolve the device once and use it for both the model and the loader.
    # Previously the model went to ``hps.device`` while the loader got a
    # separately derived value, so the two could disagree (e.g. 'cuda:0').
    cuda_requested = str(hps.device).startswith('cuda')
    if cuda_requested and not torch.cuda.is_available():
        device = 'cpu'
    else:
        device = hps.device

    # Models
    local_channel = hps.get(hps.base_classifier).last_conv_channel
    classifier = get_model(model_name=hps.base_classifier,
                           in_size=local_channel,
                           out_size=hps.rep_size).to(device)
    logger.info('Base classifier name: {}, # parameters: {}'.format(
        hps.base_classifier, cal_parameters(classifier)))

    sdim = SDIM(disc_classifier=classifier,
                mi_units=hps.mi_units,
                n_classes=hps.n_classes,
                margin=hps.margin,
                rep_size=hps.rep_size,
                local_channel=local_channel).to(device)

    model_path = 'SDIM_{}.pth'.format(hps.base_classifier)
    # NOTE(review): machine-specific absolute path; consider making this
    # configurable.
    base_dir = '/userhome/cs/u3003679/generative-classification-with-rejection'
    path = os.path.join(base_dir, model_path)
    # Deserialize on CPU so a GPU-saved checkpoint also loads on CPU-only
    # hosts; load_state_dict copies into the model's own device.
    check_point = torch.load(path, map_location='cpu')
    sdim.load_state_dict(check_point['model_state'])

    # logging the SDIM desc.
    for desc in sdim.desc():
        logger.info(desc)

    eval_loader = Loader('eval', batch_size=hps.n_batch_test, device=device)

    # The model is deliberately not wrapped in DataParallel: the wrapper only
    # dispatches forward(), so ``sdim.infer`` would raise AttributeError on it.

    torch.manual_seed(hps.seed)
    np.random.seed(hps.seed)

    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    n_batches = len(eval_loader)

    sdim.eval()
    for n_iters, (x, y) in enumerate(eval_loader, start=1):
        with torch.no_grad():
            log_lik = sdim.infer(x)

        acc1, acc5 = accuracy(log_lik, y, topk=(1, 5))

        top1.update(acc1, x.size(0))
        top5.update(acc5, x.size(0))

        # Stop after the final batch has been scored. The old check ran
        # before scoring and dropped the last batch. The explicit stop is
        # kept in case the loader keeps yielding past len() (TODO confirm).
        if n_iters >= n_batches:
            break

    logger.info('Test Acc@1: {:.3f}, Acc@5: {:.3f}'.format(top1.avg, top5.avg))