Ejemplo n.º 1
0
def main():
    """Evaluate checkpoint ensembles of growing size on the test set.

    For every architecture listed for ``args.dataset`` and for runs 1..5,
    the checkpoints returned by ``read_models`` are loaded one at a time in
    ascending order of the integer found between the last ``-`` and the
    following ``.`` of the file name.  After each checkpoint is added, the
    accumulated per-member predictions are handed to the logger and the log
    is saved.  When a run is finished, the log of the mean of the
    exponentiated member predictions is written with ``np.save`` to
    ``.megacache/logits_<dataset>-<model>-<method>-<members>-<run>``
    (``np.save`` appends ``.npy``).

    A run for which no checkpoint is found is reported and skipped; the
    member count used for the file name and for the normalisation is taken
    from the number of predictions actually collected.
    """
    parser = get_parser_ens()
    args = parser.parse_args()
    # Method name: second '-'-separated part of this script's file name,
    # minus its last three characters (the '.py' suffix).
    args.method = os.path.basename(__file__).split('-')[1][:-3]
    if args.aug_test:
        args.method = args.method + '_augment'
    torch.backends.cudnn.benchmark = True

    compute = {
        'CIFAR10':
        ['VGG16BN', 'PreResNet110', 'PreResNet164', 'WideResNet28x10'],
        'CIFAR100':
        ['VGG16BN', 'PreResNet110', 'PreResNet164', 'WideResNet28x10'],
        'ImageNet': ['ResNet50']
    }

    # Upper bound on ensemble size: 50 members on ImageNet, 100 otherwise.
    max_members = 100 if args.dataset != 'ImageNet' else 50

    for model_name in compute[args.dataset]:
        args.model = model_name
        logger = Logger(base='./logs/')
        print('-' * 5, 'Computing results of', model_name, 'on',
              args.dataset + '.', '-' * 5)

        loaders, num_classes = get_data(args)
        targets = get_targets(loaders['test'], args)
        args.num_classes = num_classes
        model = get_model(args)

        for run in range(1, 6):
            log_probs = []

            # ImageNet passes run=-1 instead of the run index
            # (presumably "no per-run split" -- confirm in read_models).
            fnames = read_models(args,
                                 base=os.path.expanduser(args.models_dir),
                                 run=run if args.dataset != 'ImageNet' else -1)
            fnames = sorted(fnames,
                            key=lambda a: int(a.split('-')[-1].split('.')[0]))

            for ns, fname in enumerate(fnames[:max_members]):
                start = time.time()
                model.load_state_dict(get_sd(fname, args))
                ones_log_prob = one_sample_pred(loaders['test'], model)
                log_probs.append(ones_log_prob)
                # The logger receives every prediction gathered so far, so
                # metrics are tracked as a function of ensemble size.
                logger.add_metrics_ts(ns,
                                      log_probs,
                                      targets,
                                      args,
                                      time_=start)
                logger.save(args)

            # Do not rely on the loop variable after the loop: with no
            # checkpoints it would be unbound (or stale from an earlier
            # run) and np.dstack([]) would raise.
            n_members = len(log_probs)
            if n_members == 0:
                print('No checkpoints found for %s on %s (run %s); skipping.'
                      % (model_name, args.dataset, run),
                      end='\n\n')
                continue

            os.makedirs('.megacache', exist_ok=True)
            logits_pth = '.megacache/logits_%s-%s-%s-%s-%s'
            logits_pth = logits_pth % (args.dataset, args.model, args.method,
                                       n_members, run)
            # log(mean_i exp(member_i)): stack members along a third axis,
            # reduce over it, then subtract log of the member count.
            log_prob = (logsumexp(np.dstack(log_probs), axis=2)
                        - np.log(n_members))
            print('Save final logprobs to %s' % logits_pth, end='\n\n')
            np.save(logits_pth, log_prob)
Ejemplo n.º 2
0
def main():
    """Collect 100 test-set predictions from one randomly chosen checkpoint.

    For each architecture listed for ``args.dataset`` the available
    checkpoints are read once.  Then, five times over, the checkpoint list
    is shuffled, its first entry is loaded, and ``one_sample_pred`` is
    called 100 times on the test loader (presumably each call is
    stochastic -- confirm in ``one_sample_pred``).  After every call the
    predictions gathered so far are passed to the logger and the log is
    saved.  The log of the mean of the exponentiated predictions is finally
    stored with ``np.save`` under ``.megacache/``.
    """
    args = get_parser_ens().parse_args()
    script_name = os.path.basename(__file__)
    # Second '-'-separated part of the file name without its last three
    # characters (the '.py' suffix).
    args.method = script_name.split('-')[1][:-3]
    torch.backends.cudnn.benchmark = True

    if args.aug_test:
        args.method += '_augment'

    print('Computing for all datasets!')

    architectures = {
        dataset: ['VGG16BN', 'WideResNet28x10do']
        for dataset in ('CIFAR10', 'CIFAR100')
    }

    for arch in architectures[args.dataset]:
        args.model = arch
        logger = Logger()
        print('-' * 5, 'Computing results of', arch, 'on', args.dataset + '.',
              '-' * 5)

        loaders, num_classes = get_data(args)
        test_loader = loaders['test']
        targets = get_targets(test_loader, args)

        fnames = read_models(args, base=os.path.expanduser(args.models_dir))
        args.num_classes = num_classes
        net = get_model(args)

        for try_ in range(1, 6):
            # Each attempt reshuffles the (already shuffled) list and uses
            # whichever checkpoint ends up first.
            fnames = np.random.permutation(fnames)
            chosen = fnames[0]
            net.load_state_dict(get_sd(chosen, args))

            log_probs = []
            for ns in range(100):
                start = time.time()
                log_probs.append(one_sample_pred(test_loader, net))
                logger.add_metrics_ts(ns, log_probs, targets, args,
                                      time_=start)
                logger.save(args)

            os.makedirs('./.megacache', exist_ok=True)
            n_samples = len(log_probs)
            logits_pth = '.megacache/logits_%s-%s-%s-%s-%s' % (
                args.dataset, args.model, args.method, n_samples, try_)
            # log(mean_i exp(sample_i)) over the stacked third axis.
            stacked = np.dstack(log_probs)
            log_prob = logsumexp(stacked, axis=2) - np.log(n_samples)
            print('Save final logprobs to %s' % logits_pth)
            np.save(logits_pth, log_prob)
            print('Used weights from %s' % chosen, end='\n\n')
Ejemplo n.º 3
0
def main():
    """Collect repeated test-set predictions from the first checkpoint.

    For each architecture listed for ``args.dataset`` the checkpoint list is
    read once.  In each of five repeats the first checkpoint is loaded again
    and ``one_sample_pred`` is called 100 times (50 on ImageNet) on the test
    loader (presumably each call draws a fresh stochastic prediction --
    confirm in ``one_sample_pred``).  After every call the predictions
    gathered so far are passed to the logger and the log is saved.  The log
    of the mean of the exponentiated predictions is stored with ``np.save``
    under ``.megacache/``.
    """
    args = get_parser_ens().parse_args()
    # Second '-'-separated part of this script's file name without its last
    # three characters (the '.py' suffix).
    args.method = os.path.basename(__file__).split('-')[1][:-3]
    torch.backends.cudnn.benchmark = True

    if args.aug_test:
        args.method += '_augment'

    os.makedirs('./logs', exist_ok=True)

    cifar_models = [
        'BayesVGG16BN',
        'BayesPreResNet110',
        'BayesPreResNet164',
        'BayesWideResNet28x10',
    ]
    compute = {
        'CIFAR10': list(cifar_models),
        'CIFAR100': list(cifar_models),
        'ImageNet': ['BayesResNet50'],
    }

    # Number of predictions per repeat: 50 on ImageNet, 100 otherwise.
    if args.dataset != 'ImageNet':
        n_samples = 100
    else:
        n_samples = 50

    for arch in compute[args.dataset]:
        args.model = arch
        logger = Logger()
        print('-'*5, 'Computing results of', arch, 'on', args.dataset + '.',
              '-'*5)

        loaders, num_classes = get_data(args)
        targets = get_targets(loaders['test'], args)

        fnames = read_models(args, base=os.path.expanduser(args.models_dir))
        args.num_classes = num_classes
        net = get_model(args)

        for run in range(1, 6):
            print('Repeat num. %s' % run)

            # The same (first) checkpoint is reloaded at the start of every
            # repeat.
            net.load_state_dict(get_sd(fnames[0], args))

            log_probs = []
            for ns in range(n_samples):
                start = time.time()
                log_probs.append(one_sample_pred(loaders['test'], net))
                logger.add_metrics_ts(ns, log_probs, targets, args,
                                      time_=start)
                logger.save(args)

            os.makedirs('.megacache', exist_ok=True)
            logits_pth = '.megacache/logits_%s-%s-%s-%s-%s' % (
                args.dataset, args.model, args.method, n_samples, run)
            # log(mean_i exp(sample_i)) over the stacked third axis.
            stacked = np.dstack(log_probs)
            log_prob = logsumexp(stacked, axis=2) - np.log(n_samples)
            np.save(logits_pth, log_prob)
Ejemplo n.º 4
0
                             lr=args.lr_init,
                             momentum=args.momentum,
                             weight_decay=args.wd)
print(optimizer1)

# Second optimizer: Adam over only those parameters whose name contains
# 'wlog_sigma', with its own learning rate (args.lr_vars) and no weight decay.
params = [p for n, p in model.named_parameters() if 'wlog_sigma' in n]
optimizer2 = torch.optim.Adam(
    params,
    lr=args.lr_vars,
    weight_decay=0,
)
print(optimizer2)

start_epoch = 0
# Temporarily remove 'Bayes' from the model name while looking up checkpoint
# files, then put the prefix back (presumably so read_models finds the
# checkpoints of the non-Bayesian counterpart -- confirm).
# Note: replace() drops every occurrence of 'Bayes', so the name is restored
# exactly only when 'Bayes' appears once, as a prefix.
args.model = args.model.replace('Bayes', '')
# NOTE(review): the checkpoint directory is a hard-coded absolute path.
fnames = utils.read_models(args,
                           base=os.path.abspath('/home/aashukha/megares'))
args.model = 'Bayes' + args.model

# Initialise the model from the first checkpoint found. strict=False lets
# load_state_dict tolerate keys that are missing from, or unexpected in, the
# checkpoint instead of raising.
print('Resume training from %s' % fnames[0])
checkpoint = torch.load(fnames[0])
model.load_state_dict(checkpoint, strict=False)

# Column names; their use is outside the visible part of the file.
# NOTE(review): 'te_ac' sits next to an earlier 'te_acc' -- confirm whether it
# is an intentional separate column or a typo.
columns = [
    'ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time', 'nll', 'kl',
    'te_ac', 'te_nll'
]

for epoch in range(start_epoch, args.epochs):
    time_ep = time.time()

    lr = schedule(epoch)