Ejemplo n.º 1
0
# Echo the positional constructor arguments taken from the model config.
print(*model_cfg.args)

# Wrap the base architecture in SWAG with a 'pca' subspace. The config's
# positional arguments follow the base class; its keyword arguments are
# merged in last.
swag_model = SWAG(
    model_cfg.base,
    *model_cfg.args,
    num_classes=num_classes,
    subspace_type='pca',
    subspace_kwargs={'max_rank': 20, 'pca_rank': args.rank},
    **model_cfg.kwargs
)
swag_model.to(args.device)

# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be passed via args.checkpoint.
print('Loading: %s' % args.checkpoint)
ckpt = torch.load(args.checkpoint)
# strict=False: presumably the torch.nn.Module flag that tolerates missing or
# unexpected keys -- confirm SWAG does not override load_state_dict.
swag_model.load_state_dict(ckpt['state_dict'], strict=False)

# Switch to the SWA solution and report its metrics on the training loader.
swag_model.set_swa()
print(
    "SWA:",
    utils.eval(loaders["train"], swag_model, criterion=losses.cross_entropy),
)

# Export the subspace description; only the mean and the covariance factor
# are used to build the Subspace below.
mean, var, cov_factor = swag_model.get_space()
subspace = Subspace(mean, cov_factor)

# Norm of each row of the covariance factor.
print(torch.norm(cov_factor, dim=1))

# The flow is sized by the number of rows of the covariance factor.
# NOTE(review): the flow is placed on the current CUDA device while the model
# uses args.device -- confirm these are meant to coincide.
nvp_flow = construct_flow(
    cov_factor.shape[0],
    device=torch.cuda.current_device(),
)

vi_model = VINFModel(base=model_cfg.base,
                     subspace=subspace,
# Load the checkpoint once; its state dict is re-applied to the model before
# every evaluation below.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load(args.checkpoint)

criterion = losses.cross_entropy

# args.N log-spaced values running from 1 / (0.005 * n_train) up to 1.0,
# where n_train is the size of the training dataset.
fractions = np.logspace(-np.log10(0.005 * len(loaders['train'].dataset)), 0.0,
                        args.N)
# One result slot per fraction.
swa_accuracies = np.zeros(args.N)
swa_nlls = np.zeros(args.N)
swag_accuracies = np.zeros(args.N)
swag_nlls = np.zeros(args.N)

# Column names; presumably the header of a results table produced later
# (not used in the visible part of the script).
columns = ['fraction', 'swa_acc', 'swa_loss', 'swag_acc', 'swag_loss', 'time']

for i, fraction in enumerate(fractions):
    start_time = time.time()
    # Reset the model from the checkpoint before touching it.
    swag_model.load_state_dict(ckpt['state_dict'])

    # sample(0.0): presumably a zero scale yields the SWA mean weights --
    # TODO confirm against SWAG.sample.
    swag_model.sample(0.0)
    # presumably recomputes batch-norm statistics on `fraction` of the
    # training data -- verify against utils.bn_update.
    utils.bn_update(loaders['train'], swag_model, subset=fraction)
    swa_res = utils.eval(loaders['test'], swag_model, criterion)
    swa_accuracies[i] = swa_res['accuracy']
    swa_nlls[i] = swa_res['loss']

    # Running sum of the per-sample predictions: one row per test example,
    # one column per class.
    predictions = np.zeros((len(loaders['test'].dataset), num_classes))

    for j in range(args.S):
        # Each of the args.S draws starts again from the checkpoint state.
        swag_model.load_state_dict(ckpt['state_dict'])
        swag_model.sample(scale=0.5, cov=args.cov_mat)
        utils.bn_update(loaders['train'], swag_model, subset=fraction)
        sample_res = utils.predict(loaders['test'], swag_model)
        predictions += sample_res['predictions']
Ejemplo n.º 3
0
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint["epoch"]
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

# Optionally restore a previously saved SWAG model when SWA is enabled and a
# SWAG checkpoint path was supplied.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
if args.swa and args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    swag_model = SWAG(
        model_cfg.base,
        *model_cfg.args,
        no_cov_mat=args.no_cov_mat,
        max_num_models=args.max_num_models,
        loading=True,
        num_classes=num_classes,
        **model_cfg.kwargs
    )
    swag_model.to(args.device)
    swag_model.load_state_dict(checkpoint["state_dict"])

# Result-table columns: the SWA test metrics, when enabled, sit between the
# plain test metrics and the trailing time / memory columns.
columns = ["ep", "lr", "tr_loss", "tr_acc", "te_loss", "te_acc"]
if args.swa:
    columns += ["swa_te_loss", "swa_te_acc"]
    # Placeholder SWAG results until an evaluation fills them in.
    swag_res = {"loss": None, "accuracy": None}
columns += ["time", "mem_usage"]

# Save a checkpoint of the current model and optimizer state at start_epoch.
utils.save_checkpoint(
    args.dir,
    start_epoch,
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
)
Ejemplo n.º 4
0
)

# Training starts at epoch 0 unless a checkpoint to resume from is given, in
# which case the epoch counter, model weights and optimizer state are restored.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
start_epoch = 0
if args.resume is not None:
    print('Resume training from %s' % args.resume)
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

# Optionally restore a previously saved SWAG model as well.
if args.swa and args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    swag_model = SWAG(
        model_cfg.base,
        *model_cfg.args,
        no_cov_mat=args.no_cov_mat,
        max_num_models=args.max_num_models,
        loading=True,
        num_classes=num_classes,
        **model_cfg.kwargs
    )
    swag_model.to(args.device)
    swag_model.load_state_dict(checkpoint['state_dict'])

# Result-table columns: the SWA test metrics, when enabled, sit between the
# plain test metrics and the trailing time / memory columns.
columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc']
if args.swa:
    columns += ['swa_te_loss', 'swa_te_acc']
    # Placeholder SWAG results until an evaluation fills them in.
    swag_res = {'loss': None, 'accuracy': None}
columns += ['time', 'mem_usage']

# Save a checkpoint of the current model and optimizer state at start_epoch.
utils.save_checkpoint(
    args.dir,
    start_epoch,
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
)

# Not yet populated; both start out as None.
sgd_ens_preds = sgd_targets = None
Ejemplo n.º 5
0
                                momentum=0.9,
                                weight_decay=1e-4)
    loader = generate_dataloaders(N=10)

    # Snapshot of the SWAG state dict, taken at epoch index 4 and loaded back
    # into the model at the very end.
    # NOTE(review): if num_epochs < 5 this stays None and the final
    # load_state_dict call receives None -- num_epochs is not visible here.
    state_dict = None

    for epoch in range(num_epochs):
        model.train()

        # One optimisation pass over the loader with a summed squared-error loss.
        for x, y in loader:
            model.zero_grad()
            pred = model(x)
            loss = ((pred - y)**2.0).sum()
            loss.backward()
            optimizer.step()
        # Hand the model to the SWAG wrapper once per epoch, after the pass.
        small_swag_model.collect_model(model)

        if epoch == 4:
            state_dict = small_swag_model.state_dict()

    small_swag_model.fit()
    with torch.no_grad():
        # Twelve inputs from -6 to 5 in steps of 1, shaped as a column.
        x = torch.arange(-6., 6., 1.0).unsqueeze(1)
        # Draw ten samples and run a forward pass with each; the outputs are
        # discarded, so this only checks that the calls go through.
        for i in range(10):
            small_swag_model.sample(0.5)
            small_swag_model(x)

    # get_space is unpacked into two values without the covariance factor and
    # into three values with it.
    _, _ = small_swag_model.get_space(export_cov_factor=False)
    _, _, _ = small_swag_model.get_space(export_cov_factor=True)
    # Restore the snapshot captured during training.
    small_swag_model.load_state_dict(state_dict)
Ejemplo n.º 6
0
# Load the checkpoint once; its state dict is re-applied to the model before
# every evaluation below.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load(args.checkpoint)

criterion = losses.cross_entropy

# args.N log-spaced values running from 1 / (0.005 * n_train) up to 1.0,
# where n_train is the size of the training dataset.
fractions = np.logspace(-np.log10(0.005 * len(loaders["train"].dataset)), 0.0,
                        args.N)
# One result slot per fraction.
swa_accuracies = np.zeros(args.N)
swa_nlls = np.zeros(args.N)
swag_accuracies = np.zeros(args.N)
swag_nlls = np.zeros(args.N)

# Column names; presumably the header of a results table produced later
# (not used in the visible part of the script).
columns = ["fraction", "swa_acc", "swa_loss", "swag_acc", "swag_loss", "time"]

for i, fraction in enumerate(fractions):
    start_time = time.time()
    # Reset the model from the checkpoint before touching it.
    swag_model.load_state_dict(ckpt["state_dict"])

    # sample(0.0): presumably a zero scale yields the SWA mean weights --
    # TODO confirm against SWAG.sample.
    swag_model.sample(0.0)
    # presumably recomputes batch-norm statistics on `fraction` of the
    # training data -- verify against utils.bn_update.
    utils.bn_update(loaders["train"], swag_model, subset=fraction)
    swa_res = utils.eval(loaders["test"], swag_model, criterion)
    swa_accuracies[i] = swa_res["accuracy"]
    swa_nlls[i] = swa_res["loss"]

    # Running sum of the per-sample predictions: one row per test example,
    # one column per class.
    predictions = np.zeros((len(loaders["test"].dataset), num_classes))

    for j in range(args.S):
        # Each of the args.S draws starts again from the checkpoint state.
        swag_model.load_state_dict(ckpt["state_dict"])
        swag_model.sample(scale=0.5, cov=args.cov_mat)
        utils.bn_update(loaders["train"], swag_model, subset=fraction)
        sample_res = utils.predict(loaders["test"], swag_model)
        predictions += sample_res["predictions"]