Ejemplo n.º 1
0
def main():
    """Evaluate growing ensembles of saved snapshots and cache the result.

    For every architecture listed for ``args.dataset`` and for runs 1..5,
    the snapshots returned by ``read_models`` are ordered by the integer
    found between the last ``-`` and the following ``.`` of their name and
    loaded one at a time into a single model instance. After each snapshot
    the predictions collected so far are passed to
    ``Logger.add_metrics_ts`` and the log is saved. At the end of a run the
    per-member predictions are combined as ``logsumexp(...) - log(n)``
    (the log of the mean probability, assuming ``one_sample_pred`` returns
    log-probabilities -- TODO confirm) and written to ``.megacache/``.

    A run for which no snapshot is found is skipped with a message instead
    of failing on an undefined or stale member counter.

    Raises:
        KeyError: if ``args.dataset`` has no entry in the ``compute`` table.
    """
    parser = get_parser_ens()
    args = parser.parse_args()
    # The method name is taken from the script name: "<prefix>-<method>.py".
    args.method = os.path.basename(__file__).split('-')[1][:-3]
    if args.aug_test:
        args.method = args.method + '_augment'
    torch.backends.cudnn.benchmark = True

    compute = {
        'CIFAR10':
        ['VGG16BN', 'PreResNet110', 'PreResNet164', 'WideResNet28x10'],
        'CIFAR100':
        ['VGG16BN', 'PreResNet110', 'PreResNet164', 'WideResNet28x10'],
        'ImageNet': ['ResNet50']
    }

    is_imagenet = args.dataset == 'ImageNet'
    # Upper bound on the ensemble size: 50 members on ImageNet, 100 otherwise.
    max_members = 50 if is_imagenet else 100

    for arch in compute[args.dataset]:
        args.model = arch
        logger = Logger(base='./logs/')
        print('-' * 5, 'Computing results of', arch, 'on', args.dataset + '.',
              '-' * 5)

        loaders, num_classes = get_data(args)
        targets = get_targets(loaders['test'], args)
        args.num_classes = num_classes
        model = get_model(args)

        for run in range(1, 6):
            log_probs = []

            # ImageNet snapshots are requested with run=-1 rather than a
            # per-run index.
            fnames = read_models(args,
                                 base=os.path.expanduser(args.models_dir),
                                 run=-1 if is_imagenet else run)
            fnames = sorted(fnames,
                            key=lambda a: int(a.split('-')[-1].split('.')[0]))

            for ns, fname in enumerate(fnames[:max_members]):
                start = time.time()
                model.load_state_dict(get_sd(fname, args))
                log_probs.append(one_sample_pred(loaders['test'], model))
                logger.add_metrics_ts(ns,
                                      log_probs,
                                      targets,
                                      args,
                                      time_=start)
                logger.save(args)

            # Count the members actually evaluated instead of reusing the
            # loop index, which is undefined (or left over from the previous
            # run) when no snapshot was found.
            n_members = len(log_probs)
            if n_members == 0:
                print('No snapshots found for %s on %s (run %s); skipping.' %
                      (arch, args.dataset, run),
                      end='\n\n')
                continue

            os.makedirs('.megacache', exist_ok=True)
            logits_pth = '.megacache/logits_%s-%s-%s-%s-%s'
            logits_pth = logits_pth % (args.dataset, args.model, args.method,
                                       n_members, run)
            log_prob = (logsumexp(np.dstack(log_probs), axis=2) -
                        np.log(n_members))
            print('Save final logprobs to %s' % logits_pth, end='\n\n')
            np.save(logits_pth, log_prob)
Ejemplo n.º 2
0
def main():
    """Average repeated predictions of one randomly chosen snapshot.

    For each architecture listed for ``args.dataset`` the available
    snapshots are read once. Then, five times over, the snapshot list is
    randomly permuted, the first entry is loaded into the model, and
    ``one_sample_pred`` is called 100 times on the test loader. Metrics are
    logged and saved after every call. The 100 predictions are finally
    combined as ``logsumexp(...) - log(100)`` and stored under
    ``.megacache/``.

    Raises:
        KeyError: if ``args.dataset`` is neither ``CIFAR10`` nor
            ``CIFAR100``.
    """
    args = get_parser_ens().parse_args()
    # The method name is taken from the script name: "<prefix>-<method>.py".
    script_name = os.path.basename(__file__)
    args.method = script_name.split('-')[1][:-3]
    torch.backends.cudnn.benchmark = True

    if args.aug_test:
        args.method += '_augment'

    print('Computing for all datasets!')

    compute = {
        dataset: ['VGG16BN', 'WideResNet28x10do']
        for dataset in ('CIFAR10', 'CIFAR100')
    }

    num_samples = 100
    rule = '-' * 5

    for arch in compute[args.dataset]:
        args.model = arch
        logger = Logger()
        print(rule, 'Computing results of', arch, 'on', args.dataset + '.',
              rule)

        loaders, num_classes = get_data(args)
        test_loader = loaders['test']
        targets = get_targets(test_loader, args)

        fnames = read_models(args, base=os.path.expanduser(args.models_dir))
        args.num_classes = num_classes
        net = get_model(args)

        for try_ in range(1, 6):
            # A fresh shuffle per attempt decides which snapshot is used.
            fnames = np.random.permutation(fnames)
            chosen = fnames[0]
            net.load_state_dict(get_sd(chosen, args))

            log_probs = []
            for ns in range(num_samples):
                start = time.time()
                log_probs.append(one_sample_pred(test_loader, net))
                logger.add_metrics_ts(ns,
                                      log_probs,
                                      targets,
                                      args,
                                      time_=start)
                logger.save(args)

            os.makedirs('./.megacache', exist_ok=True)
            logits_pth = '.megacache/logits_%s-%s-%s-%s-%s' % (
                args.dataset, args.model, args.method, num_samples, try_)
            stacked = np.dstack(log_probs)
            log_prob = logsumexp(stacked, axis=2) - np.log(num_samples)
            print('Save final logprobs to %s' % logits_pth)
            np.save(logits_pth, log_prob)
            print('Used weights from %s' % chosen, end='\n\n')
Ejemplo n.º 3
0
def main():
    """Average repeated predictions of the first snapshot of a Bayesian model.

    For each architecture listed for ``args.dataset`` the snapshot list is
    read once. Then, for repeats 1..5, the first snapshot is (re)loaded and
    ``one_sample_pred`` is called 100 times on the test loader (50 times on
    ImageNet). Metrics are logged and saved after every call, and the
    collected predictions are combined as ``logsumexp(...) - log(n)`` and
    stored under ``.megacache/``.

    Raises:
        KeyError: if ``args.dataset`` has no entry in the ``compute`` table.
    """
    args = get_parser_ens().parse_args()
    # The method name is taken from the script name: "<prefix>-<method>.py".
    script_name = os.path.basename(__file__)
    args.method = script_name.split('-')[1][:-3]
    torch.backends.cudnn.benchmark = True

    if args.aug_test:
        args.method += '_augment'

    os.makedirs('./logs', exist_ok=True)

    cifar_archs = (
        'BayesVGG16BN',
        'BayesPreResNet110',
        'BayesPreResNet164',
        'BayesWideResNet28x10',
    )
    compute = {
        'CIFAR10': list(cifar_archs),
        'CIFAR100': list(cifar_archs),
        'ImageNet': ['BayesResNet50'],
    }

    rule = '-'*5

    for arch in compute[args.dataset]:
        args.model = arch
        logger = Logger()
        print(rule, 'Computing results of', arch, 'on', args.dataset + '.',
              rule)

        loaders, num_classes = get_data(args)
        test_loader = loaders['test']
        targets = get_targets(test_loader, args)

        fnames = read_models(args, base=os.path.expanduser(args.models_dir))
        args.num_classes = num_classes
        net = get_model(args)

        for run in range(1, 6):
            print('Repeat num. %s' % run)
            log_probs = []

            # The same snapshot is reloaded at the start of every repeat.
            net.load_state_dict(get_sd(fnames[0], args))

            num_samples = 50 if args.dataset == 'ImageNet' else 100
            for ns in range(num_samples):
                start = time.time()
                log_probs.append(one_sample_pred(test_loader, net))
                logger.add_metrics_ts(ns, log_probs, targets, args,
                                      time_=start)
                logger.save(args)

            os.makedirs('.megacache', exist_ok=True)
            logits_pth = '.megacache/logits_%s-%s-%s-%s-%s' % (
                args.dataset, args.model, args.method, num_samples, run)
            stacked = np.dstack(log_probs)
            log_prob = logsumexp(stacked, axis=2) - np.log(num_samples)
            np.save(logits_pth, log_prob)
Ejemplo n.º 4
0
def main():
    """Evaluate test-time augmentation with RandAugment policies.

    The dataset and architecture are parsed from the first snapshot's file
    name, which is expected to look like ``<dataset>-<model>-...``. For each
    of ``num_tta`` augmentation steps a test loader is built with the step's
    transform, every snapshot in ``args.models`` is evaluated
    ``samples_per_policy`` times, and the predictions are combined as
    ``logsumexp(...) - log(n)``. The per-step result is optionally saved as
    an ``.npz`` file in ``args.log_dir``, and metrics of the ensemble over
    all steps so far are added to the logger.

    With ``--valid`` the evaluation uses 5000 samples held out from the
    training set by a stratified split. With ``--bn_update`` a training
    loader is built so BatchNorm statistics can be recomputed; the original
    code read ``loaders['train']`` without ever creating it.

    Raises:
        NotImplementedError: if the dataset is neither CIFAR10 nor CIFAR100.
    """
    torch.backends.cudnn.benchmark = True
    args = get_parser_ens().parse_args()
    args.method = 'randaugment'

    print(args.models)
    print('Using the following snapshots:')
    print('\n'.join(args.models))

    # Snapshot names carry the dataset and model as their first two
    # '-'-separated fields.
    first_snapshot = args.models[0].split('/')[-1]
    args.dataset = first_snapshot.split('-')[0]
    args.model = first_snapshot.split('-')[1]
    print(args.model, args.dataset)

    num_tta = args.num_tta
    samples_per_policy = 1
    if args.policy is not None:
        # NOTE(review): allow_pickle=True executes pickled objects on load;
        # only use policy files from trusted sources.
        policy = np.load(args.policy, allow_pickle=True)['arr_0']
        if args.num_tta > len(policy):
            # More samples requested than policies available: use every
            # policy once per step and repeat each one several times.
            num_tta = len(policy)
            samples_per_policy = args.num_tta // num_tta

    path = os.path.join(args.data_path, args.dataset.lower())

    ds = getattr(torchvision.datasets, args.dataset)

    if args.dataset == 'CIFAR10':
        args.num_classes = 10
    elif args.dataset == 'CIFAR100':
        args.num_classes = 100
    else:
        raise NotImplementedError

    model_cfg = getattr(models, args.model)
    print('WARNING: using random M')

    if args.no_tta:
        print('\033[93m' + 'TTA IS DISABLED!' + '\033[0m')

    logger = Logger(base=args.log_dir)
    model = get_model(args)

    if args.valid:
        train_set = ds(path,
                       train=True,
                       download=True,
                       transform=model_cfg.transform_train)
        sss = StratifiedShuffleSplit(n_splits=1,
                                     test_size=5000,
                                     random_state=0)
        train_idx = np.array(list(range(len(train_set.data))))
        sss = sss.split(train_idx, train_set.targets)
        train_idx, valid_idx = next(sss)

    # Loader used to recompute BatchNorm statistics. It is only built on
    # request, since it is not needed otherwise.
    train_loader = None
    if args.bn_update:
        bn_set = ds(path,
                    train=True,
                    download=True,
                    transform=model_cfg.transform_train)
        if args.valid:
            # Keep the held-out validation samples out of the BN statistics.
            bn_set.data = bn_set.data[train_idx]
            bn_set.targets = list(np.array(bn_set.targets)[train_idx])
        train_loader = torch.utils.data.DataLoader(
            bn_set,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)

    full_ens_preds = []
    for try_ in range(num_tta):
        start = time.time()

        current_policy = None
        if args.policy is not None:
            current_policy = policy[try_]
            if current_policy is None:
                current_policy = []

        if args.no_tta:
            transform = model_cfg.transform_test
            current_transform = 'None'
        else:
            transform = transforms.Compose([
                BetterRandAugment(args.N,
                                  args.M,
                                  True,
                                  False,
                                  transform=current_policy,
                                  verbose=args.verbose,
                                  true_m0=args.true_m0,
                                  randomize_sign=not args.fix_sign,
                                  used_transforms=args.transforms),
                model_cfg.transform_train
            ])
            current_transform = transform.transforms[0].get_transform_str()
        print('\033[93m' + 'Using the following transform:' + '\033[0m')
        print('\033[93m' + current_transform + '\033[0m')

        if args.valid:
            print('\033[93m' + 'Using the following objects for validation:' +
                  '\033[0m')
            print(train_idx, valid_idx)

            # Evaluate on the held-out part of the training set.
            test_set = ds(path,
                          train=True,
                          download=True,
                          transform=transform)
            test_set.data = test_set.data[valid_idx]
            test_set.targets = list(np.array(test_set.targets)[valid_idx])
            test_set.train = False
        else:
            test_set = ds(path,
                          train=False,
                          download=True,
                          transform=transform)

        loaders = {
            'test':
            torch.utils.data.DataLoader(test_set,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
        }
        if train_loader is not None:
            loaders['train'] = train_loader

        # Load the model and update BN statistics (if run with a single model)
        if len(args.models) == 1:
            try:
                model.load_state_dict(get_sd(args.models[0], args))
            except RuntimeError:
                # Retry with the model wrapped in DataParallel.
                model = torch.nn.DataParallel(model).cuda()
                model.load_state_dict(get_sd(args.models[0], args))
            if args.bn_update:
                bn_update(loaders['train'], model)
                print('BatchNorm statistics updated!')

        log_probs = []
        for fname in args.models:
            # Load the model and update BN if several models are supplied
            if len(args.models) > 1:
                try:
                    model.load_state_dict(get_sd(fname, args))
                except RuntimeError:
                    if hasattr(model, 'module'):
                        # Already wrapped: load into the inner module.
                        model.module.load_state_dict(get_sd(fname, args))
                    else:
                        model = torch.nn.DataParallel(model).cuda()
                        model.load_state_dict(get_sd(fname, args))
                if args.bn_update:
                    bn_update(loaders['train'], model)
                    print('BatchNorm statistics updated!')
            for _ in range(samples_per_policy):
                log_probs.append(one_sample_pred(loaders['test'], model))

        log_prob = (logsumexp(np.dstack(log_probs), axis=2) -
                    np.log(len(log_probs)))
        full_ens_preds.append(log_prob)

        # Trailing part of the output name, shared by the long and the
        # shortened variant below.
        transforms_tag = ('' if args.transforms is None else ''.join(
            str(args.transforms).split()))
        suffix = (args.fname + transforms_tag + '#' + current_transform +
                  '#' + 'N%d-M%d' % (args.N, args.M))

        snapshot_names = [os.path.basename(f) for f in args.models]
        fname = '%s-%s-%s-%s.npz' % (args.dataset, args.model, args.method,
                                     '-'.join(snapshot_names) + suffix)
        if len(fname) > 255:
            # Name is too long: keep the first snapshot name in full and only
            # the last five characters of the others.
            short_names = '-'.join(name[-5:] for name in snapshot_names[1:])
            fname = '%s-%s-%s-%s.npz' % (
                args.dataset, args.model, args.method,
                snapshot_names[0] + '-' + short_names + suffix)
        fname = os.path.join(args.log_dir, fname)
        if not args.silent:
            np.savez(fname, log_prob)
            print('\033[93m' + 'Saved to ' + fname + '\033[0m')
        print('Full ens metrics: ', end='')
        logger.add_metrics_ts(try_,
                              full_ens_preds,
                              np.array(test_set.targets),
                              args,
                              time_=start)

        print('---%s--- ends' % try_, flush=True)

    logger.save(args)