Code example #1
# Evaluation-time setup fragment: seed RNGs, build non-augmented data loaders,
# construct the model, and load a saved checkpoint.
# NOTE(review): relies on `torch`, `args`, `models`, and `data` being in scope
# from earlier in the file — confirm against the full script.
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

print("Using model %s" % args.model)
# Look up the model-config object by name (e.g. models.<args.model>).
model_cfg = getattr(models, args.model)

# only use testing data augmentation (e.g. scaling etc.)
# no random flipping
print("Loading dataset %s from %s" % (args.dataset, args.data_path))
loaders, num_classes = data.loaders(
    args.dataset,
    args.data_path,
    args.batch_size,
    args.num_workers,
    model_cfg.transform_test,  # test transform applied to the train split too (no augmentation)
    model_cfg.transform_test,
    use_validation=False,      # no held-out validation split for evaluation
    split_classes=args.split_classes,
    shuffle_train=False,       # deterministic iteration order for evaluation
)

# Build the network; `c`/`max_depth` are extra architecture knobs
# (presumably channel width and network depth — TODO confirm against model_cfg.base).
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs,
                       c=args.num_channels,
                       max_depth=args.depth)
model.cuda()

print("Loading model %s" % args.file)
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
checkpoint = torch.load(args.file)
Code example #2
File: train.py — Project: ml-lab/hessian-eff-dim
def main():
    """Train ``n_trials`` copies of the configured model, one per subdirectory.

    For each trial this creates ``<args.dir>/trial_<i>``, records the launch
    command in ``command.sh`` for reproducibility, seeds the RNGs, builds the
    data loaders and model from the CLI config, and runs SGD training for
    ``args.epochs`` epochs via ``train_epoch``.
    """
    args = parser()
    args.device = None

    # Prefer the GPU when available; args.cuda mirrors the choice for helpers.
    if torch.cuda.is_available():
        args.device = torch.device("cuda")
        args.cuda = True
    else:
        args.device = torch.device("cpu")
        args.cuda = False

    n_trials = 1

    print("Preparing base directory %s" % args.dir)
    os.makedirs(args.dir, exist_ok=True)

    for trial in range(n_trials):
        # Build the per-trial output path once instead of re-concatenating it at
        # every use. (The original print also relied on `%` binding tighter than
        # `+`, which happened to work but read like a bug.)
        trial_dir = os.path.join(args.dir, 'trial_' + str(trial))
        print("Preparing directory %s" % trial_dir)

        os.makedirs(trial_dir, exist_ok=True)
        # Record the exact launch command so the run can be reproduced.
        with open(os.path.join(trial_dir, "command.sh"), "w") as f:
            f.write(" ".join(sys.argv))
            f.write("\n")

        torch.backends.cudnn.benchmark = True
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        print("Using model %s" % args.model)
        # Model-config object looked up by name (e.g. models.<args.model>).
        model_cfg = getattr(models, args.model)

        print("Loading dataset %s from %s" % (args.dataset, args.data_path))
        loaders, num_classes = data.loaders(
            args.dataset,
            args.data_path,
            args.batch_size,
            args.num_workers,
            model_cfg.transform_train,
            model_cfg.transform_test,
            use_validation=not args.use_test,
            split_classes=args.split_classes,
        )

        print("Preparing model")
        print(*model_cfg.args)

        # Extra constructor kwargs vary by architecture: ResNet18 only takes a
        # channel count, ConvNet also takes a depth cap.
        if args.model == 'ResNet18':
            extra_args = {'init_channels': args.num_channels}
        elif args.model == 'ConvNet':
            extra_args = {
                'init_channels': args.num_channels,
                'max_depth': args.depth
            }
        else:
            extra_args = {}

        model = model_cfg.base(*model_cfg.args,
                               num_classes=num_classes,
                               **model_cfg.kwargs,
                               **extra_args)
        model.to(args.device)

        ## train ##
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr_init,
                                    momentum=args.momentum,
                                    weight_decay=args.wd)

        for epoch in range(0, args.epochs):
            train_epoch(model,
                        loaders,
                        swag.losses.cross_entropy,
                        optimizer,
                        epoch=epoch,
                        end_epoch=args.epochs,
                        eval_freq=args.eval_freq,
                        save_freq=args.save_freq,
                        output_dir=trial_dir,
                        lr_init=args.lr_init)

        print("model ", trial, " done")
Code example #3
def main():
    """Measure Hessian eigenspectra of masked models before and after training.

    Runs ``n_trials`` independent trials. Each trial builds a masked network,
    computes the top Hessian eigenvalues at initialization, trains with SGD via
    ``train_epoch``, recomputes the eigenvalues, and pickles the accumulated
    eigenvalue lists into the trial's output directory.
    """
    args = parser()
    args.device = None

    # Prefer the GPU when available; args.cuda mirrors the choice for helpers.
    if torch.cuda.is_available():
        args.device = torch.device("cuda")
        args.cuda = True
    else:
        args.device = torch.device("cpu")
        args.cuda = False

    n_trials = 10
    init_eigs = []   # per-trial eigenvalues at initialization (cumulative list)
    final_eigs = []  # per-trial eigenvalues after training (cumulative list)

    print("Preparing base directory %s" % args.dir)
    os.makedirs(args.dir, exist_ok=True)

    for trial in range(n_trials):
        # Build the per-trial output path once with os.path.join instead of
        # repeated string concatenation.
        trial_dir = os.path.join(args.dir, 'trial_' + str(trial))
        print("Preparing directory %s" % trial_dir)

        os.makedirs(trial_dir, exist_ok=True)
        # Record the exact launch command so the run can be reproduced.
        with open(os.path.join(trial_dir, "command.sh"), "w") as f:
            f.write(" ".join(sys.argv))
            f.write("\n")

        torch.backends.cudnn.benchmark = True
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        print("Using model %s" % args.model)
        # Model-config object looked up by name (e.g. models.<args.model>).
        model_cfg = getattr(models, args.model)

        print("Loading dataset %s from %s" % (args.dataset, args.data_path))
        loaders, num_classes = data.loaders(
            args.dataset,
            args.data_path,
            args.batch_size,
            args.num_workers,
            model_cfg.transform_train,
            model_cfg.transform_test,
            use_validation=not args.use_test,
            split_classes=args.split_classes,
        )

        print("Preparing model")
        print(*model_cfg.args)
        model = model_cfg.base(*model_cfg.args,
                               num_classes=num_classes,
                               **model_cfg.kwargs,
                               use_masked=True)
        model.to(args.device)

        # Move mask buffers onto the same device as their weights —
        # `.to(device)` on the module does not relocate them.
        for m in model.modules():
            if isinstance(m, (hess.nets.MaskedConv2d, hess.nets.MaskedLinear)):
                if m.mask is not None and m.weight is not None:
                    m.mask = m.mask.to(m.weight.device)
                if m.has_bias and m.bias_mask is not None and m.bias is not None:
                    m.bias_mask = m.bias_mask.to(m.bias.device)

        mask = hess.utils.get_mask(model)

        criterion = torch.nn.functional.cross_entropy

        ## compute hessian pre-training ##
        initial_evals = utils.get_hessian_eigs(loss=criterion,
                                               model=model,
                                               mask=mask,
                                               use_cuda=args.cuda,
                                               n_eigs=100,
                                               loader=loaders['train'])
        init_eigs.append(initial_evals)

        ## train ##
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr_init,
                                    momentum=args.momentum,
                                    weight_decay=args.wd)

        for epoch in range(0, args.epochs):
            train_epoch(model,
                        loaders,
                        swag.losses.cross_entropy,
                        optimizer,
                        epoch=epoch,
                        end_epoch=args.epochs,
                        eval_freq=args.eval_freq,
                        save_freq=args.save_freq,
                        output_dir=trial_dir,
                        lr_init=args.lr_init)

        ## compute final hessian ##
        final_evals = utils.get_hessian_eigs(loss=criterion,
                                             model=model,
                                             use_cuda=args.cuda,
                                             n_eigs=100,
                                             mask=mask,
                                             loader=loaders['train'])
        final_eigs.append(final_evals)

        print("model ", trial, " done")

        # BUG FIX: the original wrote to `fpath + fname`, concatenating the
        # directory and filename with no separator, so results landed in files
        # like ".../trial_0init_eigs.P" beside the trial directory instead of
        # inside it. os.path.join writes them where intended.
        with open(os.path.join(trial_dir, "init_eigs.P"), 'wb') as fp:
            pickle.dump(init_eigs, fp)

        with open(os.path.join(trial_dir, "final_eigs.P"), 'wb') as fp:
            pickle.dump(final_eigs, fp)