Example #1
0
def main():
    """Compute a 1D or 2D loss-landscape slice for a (pretrained) model.

    Reads the command-line configuration via ``setup`` (with seeding disabled
    so that repeated runs draw different loss samples), creates the
    ``directions`` and ``loss1d``/``loss2d`` output directories, loads the
    requested model and dataset, and dispatches to ``loss1d`` or ``loss2d``.

    Raises:
        ValueError: If ``args.data`` is not a supported dataset.
    """
    args = setup(seed=None)  # Disable seed to get random loss samples

    # Fail fast: without --loss1d/--loss2d there is nothing to compute, so do
    # not create output directories or load the model and data first.
    if not (args.loss1d or args.loss2d):
        print("You need to specify either --loss1d or --loss2d.")
        return

    print("Preparing directories")
    filename = f"{args.prefix}{args.model}_{args.data}{args.suffix}"
    directions_dir = os.path.join(args.root_dir, "directions")
    os.makedirs(directions_dir, exist_ok=True)
    directions_path = os.path.join(directions_dir, filename)
    loss_dir = os.path.join(args.results_dir, "loss1d" if args.loss1d else "loss2d")
    os.makedirs(loss_dir, exist_ok=True)
    results_path = os.path.join(loss_dir, filename)

    print("Loading model")
    if args.model == 'lenet5':
        model = lenet5.lenet5(pretrained=args.data, device=args.device)
    elif args.model == 'resnet18' and args.data != 'imagenet':
        # NOTE(review): this loads a '.pt' checkpoint while the other scripts
        # here use '.pth' for the same weights — confirm the name on disk.
        model = resnet.resnet18(pretrained=os.path.join(args.root_dir, 'weights', f"{args.model}_{args.data}.pt"),
                                num_classes=43 if args.data == 'gtsrb' else 10, device=args.device)
    else:
        # ImageNet-pretrained torchvision model; GoogLeNet and Inception-v3
        # are built with aux_logits=False.
        model_class = getattr(torchvision.models, args.model)
        if args.model in ['googlenet', 'inception_v3']:
            model = model_class(pretrained=True, aux_logits=False)
        else:
            model = model_class(pretrained=True)
    model.to(args.device).eval()
    if args.parallel:
        model = torch.nn.parallel.DataParallel(model)

    print("Loading data")
    data_dir = os.path.join(args.torch_dir, "datasets")
    if args.data == 'cifar10':
        train_data, val_data = datasets.cifar10(data_dir, args.batch_size, args.workers, augment=False)
    elif args.data == 'mnist':
        train_data, val_data = datasets.mnist(data_dir, args.batch_size, args.workers, augment=False)
    elif args.data == 'gtsrb':
        data_dir = os.path.join(args.root_dir, "datasets", "gtsrb")
        train_data, val_data = datasets.gtsrb(data_dir, batch_size=args.batch_size, workers=args.workers)
    elif args.data == 'tiny':
        img_size = 64
        data_dir = os.path.join(args.root_dir, "datasets", "imagenet")
        train_data, val_data = datasets.imagenet(data_dir, img_size, args.batch_size, augment=False,
                                                 workers=args.workers, tiny=True)
    elif args.data == 'imagenet':
        # GoogLeNet / Inception-v3 are fed 299x299 inputs here, others 224x224.
        img_size = 299 if args.model in ['googlenet', 'inception_v3'] else 224
        data_dir = os.path.join(args.root_dir, "datasets", "imagenet")
        # NOTE(review): unlike the 'tiny' branch this calls a bare `imagenet`
        # (not `datasets.imagenet`) and passes no `workers` — confirm intended.
        train_data, val_data = imagenet(data_dir, img_size, args.batch_size, augment=False, shuffle=False)
    else:
        raise ValueError(f"Unsupported dataset: {args.data!r}")
    cudnn.benchmark = True

    if args.loss1d:
        loss1d(args, model, train_data, val_data, directions_path, results_path)
    else:
        loss2d(args, model, train_data, directions_path, results_path)
Example #2
0
def main():
    """Train a ResNet-18 on GTSRB and save its weights.

    Builds a 43-class ResNet-18 and trains it with cross-entropy loss. The
    optimizer is SGD or Adam, chosen by ``args.optimizer``. The resulting
    ``state_dict`` is written to ``<root_dir>/weights/<model>_<data>.pth``.

    Raises:
        ValueError: If ``args.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
    """
    args = setup()

    # Validate before any expensive work; an unknown value would otherwise
    # only surface later as an unbound `optimizer`.
    if args.optimizer not in ('sgd', 'adam'):
        raise ValueError(f"Unsupported optimizer: {args.optimizer!r}")

    print("Loading model")
    model = resnet18(num_classes=43)  # GTSRB has 43 classes.
    model.to(args.device).train()
    if args.parallel:
        model = torch.nn.parallel.DataParallel(model)
    train_loader, val_loader = gtsrb(args.data_dir,
                                     batch_size=args.batch_size,
                                     workers=args.workers)

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    criterion = torch.nn.CrossEntropyLoss()

    train(model, train_loader, val_loader, optimizer, criterion, args.epochs,
          args.lr, args.device)

    weights_dir = os.path.join(args.root_dir, 'weights')
    os.makedirs(weights_dir, exist_ok=True)
    # DataParallel prefixes every state_dict key with 'module.'; save the
    # wrapped network so the checkpoint loads into a plain (unwrapped) model.
    net = model.module if isinstance(model, torch.nn.parallel.DataParallel) else model
    path = os.path.join(weights_dir, f"{args.model}_{args.data}.pth")
    torch.save(net.state_dict(), path)
Example #3
0
def summary(args):
    """Print a LaTeX table with the storage cost of each curvature estimator.

    For every ``Linear``/``Conv2d`` layer of the model, the table lists the
    number of stored factor elements and their size in MB (assuming 32 bits
    per element). The 'diag', 'kfac', 'efb' and 'inf' estimators appear side
    by side, and a final "Total" row sums each column.

    Args:
        args: Parsed arguments; reads ``model``, ``data``, ``device``,
            ``root_dir`` and ``rank``.
    """
    if args.model == 'lenet5':
        model = lenet5(pretrained=args.data, device=args.device)
    elif args.model == 'resnet18' and args.data != 'imagenet':
        # Build the head with the same number of classes as the checkpoint
        # (43 for GTSRB, 10 otherwise), as the other loaders of this file do.
        model = resnet18(pretrained=os.path.join(
            args.root_dir, 'weights', f"{args.model}_{args.data}.pth"),
                         num_classes=43 if args.data == 'gtsrb' else 10,
                         device=args.device)
    else:
        model_class = getattr(torchvision.models, args.model)
        if args.model in ['googlenet', 'inception_v3']:
            model = model_class(pretrained=True, aux_logits=False)
        else:
            model = model_class(pretrained=True)

    # Layer types that are tabulated; assumed to be in the same order as the
    # saved per-layer factor lists (zip below pairs them positionally).
    module_classes = [
        module.__class__.__name__ for module in model.modules()
        if module.__class__.__name__ in ('Linear', 'Conv2d')
    ]

    def factors_file(est, suffix='.pth'):
        # Built from components (not str.replace on the full path) so that
        # directory names containing e.g. "efb" are left untouched.
        return os.path.join(args.root_dir, "factors",
                            f"{args.model}_{args.data}_{est}{suffix}")

    def size_mb(numel):
        # 32 bits per element -> bytes -> MiB.
        return (numel * 32) / (8 * 1024**2)

    estimators = ['diag', 'kfac', 'efb', 'inf']
    columns = {est: [] for est in estimators}
    for est in estimators:
        if est in ('diag', 'kfac'):
            factors = torch.load(factors_file(est), map_location='cpu')
        elif est == 'efb':
            # Load on CPU like every other factor file here.
            kfac_factors = torch.load(factors_file('kfac'), map_location='cpu')
            lambdas = torch.load(factors_file('efb'), map_location='cpu')
            eigvecs = get_eigenvectors(kfac_factors)
            factors = [(eigvec[0], eigvec[1], lambda_)
                       for eigvec, lambda_ in zip(eigvecs, lambdas)]
        else:  # 'inf'
            try:
                factors = torch.load(factors_file('inf', f"{args.rank}.pth"),
                                     map_location='cpu')
            except FileNotFoundError:
                factors = np.load(
                    factors_file('inf', f"{args.rank}.npz"),
                    allow_pickle=True)['sif_list']  # Todo: Remove

        rows = columns[est]
        numel_sum = 0
        for index, (cls, factor) in enumerate(zip(module_classes, factors)):
            numel = np.sum([f.numel() for f in factor]).astype(int)
            # Only the first column block ('diag') carries the row label.
            label = [f"{cls} {index}"] if est == 'diag' else []
            rows.append(label + [numel, size_mb(numel)])
            numel_sum += numel
        label = ["Total"] if est == 'diag' else []
        rows.append(label + [numel_sum, size_mb(numel_sum)])

    factors_list = np.concatenate([columns[est] for est in estimators],
                                  axis=1)

    print(
        tabulate.tabulate(factors_list,
                          headers=[
                              'Layer #', '#Parameters', 'Size (MB)',
                              '#Parameters', 'Size (MB)', '#Parameters',
                              'Size (MB)', '#Parameters', 'Size (MB)'
                          ],
                          floatfmt=".1f",
                          numalign='right',
                          tablefmt='latex'))
Example #4
0
def main():
    """Compute curvature factors for a pretrained model and save them.

    Depending on ``args.estimator``, computes 'diag', 'kfac', 'efb' or 'inf'
    factors; 'efb' requires previously saved 'kfac' factors. The results are
    written to ``<root_dir>/factors/``.

    Raises:
        ValueError: If ``args.data`` is not a supported dataset. This is only
            checked when training data is needed (every estimator but 'inf').
    """
    args = setup()

    print("Preparing directories")
    factors_dir = os.path.join(args.root_dir, "factors")
    os.makedirs(factors_dir, exist_ok=True)

    def factors_file(est):
        # Path (without extension) for the factors of estimator `est`. Built
        # from components rather than with str.replace on the full path, which
        # would also rewrite e.g. directory names containing "efb".
        return os.path.join(
            factors_dir,
            f"{args.prefix}{args.model}_{args.data}_{est}{args.suffix}")

    factors_path = factors_file(args.estimator)

    print("Loading model")
    if args.model == 'lenet5':
        model = lenet5.lenet5(pretrained=args.data, device=args.device)
    elif args.model == 'resnet18' and args.data != 'imagenet':
        model = resnet.resnet18(pretrained=os.path.join(
            args.root_dir, 'weights', f"{args.model}_{args.data}.pth"),
                                num_classes=43 if args.data == 'gtsrb' else 10,
                                device=args.device)
    else:
        model_class = getattr(torchvision.models, args.model)
        if args.model in ['googlenet', 'inception_v3']:
            model = model_class(pretrained=True, aux_logits=False)
        else:
            model = model_class(pretrained=True)
    model.to(args.device).train()
    if args.parallel:
        model = torch.nn.parallel.DataParallel(model)

    # The 'inf' estimator is computed from saved factors and needs no data.
    if args.estimator != 'inf':
        print("Loading data")
        if args.data == 'cifar10':
            data = datasets.cifar10(args.torch_data,
                                    args.batch_size,
                                    args.workers,
                                    args.augment,
                                    splits='train')
        elif args.data == 'mnist':
            data = datasets.mnist(args.torch_data,
                                  args.batch_size,
                                  args.workers,
                                  args.augment,
                                  splits='train')
        elif args.data == 'gtsrb':
            data_dir = os.path.join(args.root_dir, "datasets", "gtsrb")
            data = datasets.gtsrb(data_dir,
                                  batch_size=args.batch_size,
                                  workers=args.workers,
                                  splits='train')
        elif args.data == 'tiny':
            img_size = 64
            data_dir = os.path.join(args.root_dir, "datasets", "imagenet")
            data = datasets.imagenet(data_dir,
                                     img_size,
                                     args.batch_size,
                                     workers=args.workers,
                                     splits='train',
                                     tiny=True)
        elif args.data == 'imagenet':
            img_size = 299 if args.model in ['googlenet', 'inception_v3'] else 224
            data_dir = os.path.join(args.root_dir, "datasets", "imagenet")
            data = datasets.imagenet(data_dir,
                                     img_size,
                                     args.batch_size,
                                     workers=args.workers,
                                     splits='train')
        else:
            raise ValueError(f"Unsupported dataset: {args.data!r}")
    torch.backends.cudnn.benchmark = True

    print("Computing factors")
    if args.estimator == 'inf':
        est = compute_inf(args)
    elif args.estimator == 'efb':
        factors = torch.load(factors_file('kfac') + '.pth')
        est = compute_factors(args, model, data, factors)
    else:
        est = compute_factors(args, model, data)

    print("Saving factors")
    if args.estimator == "inf":
        torch.save(est.state, f"{factors_path}{args.rank}.pth")
    elif args.estimator == "efb":
        torch.save(list(est.state.values()), factors_path + '.pth')
        # The EFB estimator's `diags` are stored under the 'diag' file name.
        torch.save(list(est.diags.values()), factors_file('diag') + '.pth')
    else:
        torch.save(list(est.state.values()), factors_path + '.pth')
Example #5
0
def main():
    """Search the norm/scale hyperparameters of a curvature estimator.

    Loads a model, validation data and precomputed factors. It then minimises
    ``error (%) + ECE (%)`` of the predictions returned by ``eval_bnn`` over
    log10(norm) and log10(scale) in [-10, 10]. The search uses the
    scikit-optimize strategy chosen by ``args.optimizer`` (or a coarse grid).

    Per-evaluation statistics are saved after every call and reloaded on
    start, so a search can continue from previous evaluations. Unless
    ``args.no_results`` is set, the best parameters over all saved runs for
    this model/estimator are written to ``<filename>_best_params.npy``.

    Raises:
        ValueError: For an unsupported dataset, estimator or optimizer.
    """
    args = setup()

    print("Preparing directories")
    filename = f"{args.prefix}{args.model}_{args.data}{args.suffix}"
    weights_path = os.path.join(args.root_dir, "weights",
                                f"{args.model}_{args.data}.pth")
    # <results_dir>/<model>/{data,figures}/<estimator>/<optimizer>[/<exp_id>]
    run_dirs = [args.estimator, args.optimizer]
    if args.exp_id != -1:
        # exp_id is compared against an int; os.path.join needs str parts.
        run_dirs.append(str(args.exp_id))
    run_data_dir = os.path.join(args.results_dir, args.model, "data", *run_dirs)
    if not args.no_results:
        os.makedirs(run_data_dir, exist_ok=True)
    if args.plot:
        os.makedirs(os.path.join(args.results_dir, args.model, "figures",
                                 *run_dirs),
                    exist_ok=True)
    results_path = os.path.join(run_data_dir, filename)
    # One stats file per run, used for loading, per-call saving and the
    # final save alike.
    stats_path = results_path + f"_hyperopt_stats{'_layer.npy' if args.layer else '.npy'}"

    print("Loading model")
    if args.model == 'lenet5':
        model = lenet5(pretrained=args.data, device=args.device)
    elif args.model == 'resnet18' and args.data != 'imagenet':
        model = resnet18(pretrained=weights_path,
                         num_classes=43 if args.data == 'gtsrb' else 10,
                         device=args.device)
    else:
        model_class = getattr(torchvision.models, args.model)
        if args.model in ['googlenet', 'inception_v3']:
            model = model_class(pretrained=True, aux_logits=False)
        else:
            model = model_class(pretrained=True)
    model.to(args.device).eval()
    if args.parallel:
        model = torch.nn.parallel.DataParallel(model)

    print("Loading data")
    if args.data == 'mnist':
        val_loader = datasets.mnist(args.torch_data, splits='val')
    elif args.data == 'cifar10':
        val_loader = datasets.cifar10(args.torch_data, splits='val')
    elif args.data == 'gtsrb':
        val_loader = datasets.gtsrb(args.data_dir,
                                    batch_size=args.batch_size,
                                    splits='val')
    elif args.data == 'imagenet':
        img_size = 299 if args.model in ['googlenet', 'inception_v3'] else 224
        data_dir = os.path.join(args.root_dir, "datasets", "imagenet")
        val_loader = datasets.imagenet(data_dir,
                                       img_size,
                                       args.batch_size,
                                       splits="val")
    else:
        raise ValueError(f"Unsupported dataset: {args.data!r}")

    print("Loading factors")

    # NOTE(review): factor files are looked up without args.prefix /
    # args.suffix — confirm this matches how they were saved.
    def factors_file(est, suffix='.pth'):
        # Built from components rather than with str.replace on the full
        # path, which would also rewrite directories containing e.g. "inf".
        return os.path.join(args.root_dir, "factors",
                            f"{args.model}_{args.data}_{est}{suffix}")

    if args.estimator == 'diag':
        estimator = Diagonal(model)
        estimator.state = torch.load(factors_file('diag'))
    elif args.estimator == 'kfac':
        estimator = KFAC(model)
        estimator.state = torch.load(factors_file('kfac'))
    elif args.estimator == 'efb':
        kfac_factors = torch.load(factors_file('kfac'))
        estimator = EFB(model, kfac_factors)
        estimator.state = torch.load(factors_file('efb'))
    elif args.estimator == 'inf':
        diags = torch.load(factors_file('diag'))
        kfac_factors = torch.load(factors_file('kfac'))
        lambdas = torch.load(factors_file('efb'))
        try:
            factors = torch.load(factors_file('inf', f"{args.rank}.pth"))
        except FileNotFoundError:
            factors = np.load(
                factors_file('inf', f"{args.rank}.npz"),
                allow_pickle=True)['sif_list']  # Todo: Remove
        estimator = INF(model, diags, kfac_factors, lambdas)
        estimator.state = factors
    else:
        raise ValueError(f"Unsupported estimator: {args.estimator!r}")
    torch.backends.cudnn.benchmark = True

    # Number of per-layer factor entries; every layer gets the same norm and
    # scale. Taken from the estimator so it is defined for all estimators.
    num_layers = len(estimator.state)

    norm_min = -10
    norm_max = 10
    scale_min = -10
    scale_max = 10
    if args.boundaries:
        # Seed the search with corners/edges of the (log10) search box.
        x0 = [[norm_min, scale_min], [norm_max, scale_max],
              [norm_min, scale_max], [norm_max, scale_min],
              [norm_min / 2., scale_min], [norm_max / 2., scale_max],
              [norm_min, scale_max / 2.], [norm_max, scale_min / 2.],
              [norm_min / 2., scale_min / 2.], [norm_max / 2., scale_max / 2.],
              [norm_min / 2., scale_max / 2.], [norm_max / 2., scale_min / 2.]]
    else:
        x0 = None

    space = [
        skopt.space.Real(norm_min, norm_max, name="norm", prior='uniform'),
        skopt.space.Real(scale_min, scale_max, name="scale", prior='uniform'),
    ]

    stat_keys = ("norms", "scales", "acc", "ece", "nll", "ent", "cost")
    try:
        stats = np.load(stats_path, allow_pickle=True).item()
        print(f"Found {len(stats['cost'])} Previous evaluations.")
    except FileNotFoundError:
        stats = {key: [] for key in stat_keys}

    @skopt.utils.use_named_args(dimensions=space)
    def objective(**params):
        # The search runs in log10 space.
        norms = [10**params["norm"]] * num_layers
        scales = [10**params["scale"]] * num_layers
        print("Norm:", norms[0], "Scale:", scales[0])
        try:
            # Scale element-wise: `pre_scale * scales` on a list would repeat
            # it (int pre_scale) or raise TypeError (float pre_scale).
            estimator.invert(norms, [args.pre_scale * s for s in scales])
        except (RuntimeError, np.linalg.LinAlgError):
            print("Error: Singular matrix")
            # A large constant cost steers the optimizer away from this point.
            return 200

        predictions, labels, _ = eval_bnn(model,
                                          val_loader,
                                          estimator,
                                          args.samples,
                                          stats=False,
                                          device=args.device,
                                          verbose=args.verbose)

        err = 100 - accuracy(predictions, labels)
        ece = 100 * expected_calibration_error(predictions, labels)[0]
        nll = negative_log_likelihood(predictions, labels)
        ent = predictive_entropy(predictions, mean=True)
        stats["norms"].append(norms)
        stats["scales"].append(scales)
        stats["acc"].append(100 - err)
        stats["ece"].append(ece)
        stats["nll"].append(nll)
        stats["ent"].append(ent)
        stats["cost"].append(err + ece)
        print(
            f"Err.: {err:.2f}% | ECE: {ece:.2f}% | NLL: {nll:.3f} | Ent.: {ent:.3f}"
        )
        # Persist after every evaluation so previous evaluations are reloaded
        # on the next start.
        np.save(stats_path, stats)

        return err + ece

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)

        # acq_func: EI (neg. expected improvement), LCB (lower confidence
        # bound), PI (neg. prob. of improvement; usually favours exploitation
        # over exploration), gp_hedge (choose probabilistically between all).
        # xi: how much improvement one wants over the previous best values.
        # kappa: importance of variance of predicted values (high: exploration
        # > exploitation). base_estimator: RF (random forest), ET (extra trees).
        # This is a single if/elif chain so each optimizer takes one branch.
        if args.optimizer == "gbrt":
            res = skopt.gbrt_minimize(func=objective,
                                      dimensions=space,
                                      n_calls=args.calls,
                                      x0=x0,
                                      verbose=True,
                                      n_jobs=args.workers,
                                      n_random_starts=0 if x0 else 10,
                                      acq_func='EI')
        elif args.optimizer == "gp":
            res = skopt.gp_minimize(func=objective,
                                    dimensions=space,
                                    n_calls=args.calls,
                                    x0=x0,
                                    verbose=True,
                                    n_jobs=args.workers,
                                    n_random_starts=0 if x0 else 1,
                                    acq_func='gp_hedge')
        elif args.optimizer == "forest":
            res = skopt.forest_minimize(func=objective,
                                        dimensions=space,
                                        n_calls=args.calls,
                                        x0=x0,
                                        verbose=True,
                                        n_jobs=args.workers,
                                        n_random_starts=0 if x0 else 1,
                                        acq_func='EI')
        elif args.optimizer == "random":
            res = skopt.dummy_minimize(func=objective,
                                       dimensions=space,
                                       n_calls=args.calls,
                                       x0=x0,
                                       verbose=True)
        elif args.optimizer == "grid":
            grid_space = [
                np.arange(norm_min, norm_max + 1, 10),
                np.arange(scale_min, scale_max + 1, 10)
            ]
            res = grid(func=objective, dimensions=grid_space)
        else:
            raise ValueError(f"Unsupported optimizer: {args.optimizer!r}")

        if stats['cost']:
            best = int(np.argmin(stats['cost']))
            print(f"Minimal cost of {stats['cost'][best]} found at:")
            print("Norm:", stats['norms'][best][0], "Scale:",
                  stats['scales'][best][0])
        else:
            print("No successful evaluations recorded.")

    if not args.no_results:
        print("Saving results")
        # NOTE(review): the objective is presumably dropped because the
        # closure cannot be pickled by skopt.dump — confirm.
        del res.specs['args']['func']
        np.save(stats_path, stats)
        skopt.dump(res, f"{results_path}_hyperopt_dump.pkl")

        # Aggregate this run's stats file name across every optimizer /
        # experiment directory of this model and estimator, and keep the
        # overall best parameters.
        all_stats = {key: [] for key in stat_keys}
        path = os.path.join(args.results_dir, args.model, "data",
                            args.estimator)
        stats_name = os.path.basename(stats_path)
        for dirpath, _, _ in os.walk(path):
            try:
                tmp_stats = np.load(os.path.join(dirpath, stats_name),
                                    allow_pickle=True).item()
            except FileNotFoundError:
                continue
            for key, value in tmp_stats.items():
                all_stats[key].extend(value)
        if all_stats['cost']:
            best = int(np.argmin(all_stats['cost']))
            np.save(os.path.join(path, f"{filename}_best_params.npy"), [
                all_stats['norms'][best],
                all_stats['scales'][best]
            ])

    if args.plot:
        print("Plotting results")
        hyperparameters(args)
Example #6
0
def main():
    """Evaluate a model: adversarial (FGSM), out-of-domain, or plain test.

    For ``--fgsm`` / ``--ood``, the precomputed factors of ``args.estimator``
    are loaded and inverted before evaluation. Inversion uses
    ``args.norm``/``args.scale``, or the saved ``_best_params.npy`` values if
    either is -1. ``--fgsm`` takes precedence over ``--ood``; with neither,
    ``test`` is run.

    Raises:
        ValueError: If factors are needed and ``args.estimator`` is unknown.
    """
    args = setup()

    print("Preparing directories")
    data_dir = os.path.join(args.results_dir, args.model, "data", args.estimator)
    figures_dir = os.path.join(args.results_dir, args.model, "figures", args.estimator)
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(figures_dir, exist_ok=True)
    filename = f"{args.prefix}{args.model}_{args.data}{args.suffix}"
    results_path = os.path.join(data_dir, filename)
    fig_path = os.path.join(figures_dir, filename)

    print("Loading model")
    if args.model == 'lenet5':
        model = lenet5(pretrained=args.data, device=args.device)
    elif args.model == 'resnet18' and args.data != 'imagenet':
        model = resnet18(pretrained=os.path.join(args.root_dir, 'weights', f"{args.model}_{args.data}.pth"),
                         num_classes=43 if args.data == 'gtsrb' else 10, device=args.device)
    else:
        model_class = getattr(torchvision.models, args.model)
        if args.model in ['googlenet', 'inception_v3']:
            model = model_class(pretrained=True, aux_logits=False)
        else:
            model = model_class(pretrained=True)
    model.to(args.device).eval()
    if args.parallel:
        model = torch.nn.parallel.DataParallel(model)

    if args.ood or args.fgsm:
        print("Loading factors")

        def factors_file(est, suffix='.pth'):
            # Built from components rather than with str.replace on the full
            # path, which would also rewrite directories containing e.g. "inf".
            return os.path.join(args.root_dir, "factors", f"{args.model}_{args.data}_{est}{suffix}")

        if args.estimator == 'diag':
            estimator = Diagonal(model)
            estimator.state = torch.load(factors_file('diag'))
        elif args.estimator == 'kfac':
            estimator = KFAC(model)
            estimator.state = torch.load(factors_file('kfac'))
        elif args.estimator == 'efb':
            kfac_factors = torch.load(factors_file('kfac'))
            estimator = EFB(model, kfac_factors)
            estimator.state = torch.load(factors_file('efb'))
        elif args.estimator == 'inf':
            diags = torch.load(factors_file('diag'))
            kfac_factors = torch.load(factors_file('kfac'))
            lambdas = torch.load(factors_file('efb'))
            try:
                factors = torch.load(factors_file('inf', f"{args.rank}.pth"))
            except FileNotFoundError:
                factors = np.load(factors_file('inf', f"{args.rank}.npz"), allow_pickle=True)['sif_list']  # Todo: Remove
            estimator = INF(model, diags, kfac_factors, lambdas)
            estimator.state = factors
        else:
            raise ValueError(f"Unsupported estimator: {args.estimator!r}")

        print("Inverting factors")
        if args.norm == -1 or args.scale == -1:
            # -1 means "use the saved best parameters" (presumably written by
            # the hyperparameter search).
            norm, scale = np.load(results_path + "_best_params.npy")
        else:
            norm, scale = args.norm, args.scale
        estimator.invert(norm, args.pre_scale * scale)

    if args.fgsm:
        adversarial_attack(args, model, estimator, results_path, fig_path)
    elif args.ood:
        out_of_domain(args, model, estimator, results_path, fig_path)
    else:
        # Plain test figures are stored outside the per-estimator directory.
        test(args, model, os.path.join(args.results_dir, args.model, "figures", filename))