Exemple #1
0
# Pick the compute device: CUDA when PyTorch reports it as available,
# the CPU otherwise.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Look up the model description stored on `models` under the name args.model.
print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

# Build the base network from the description's positional and keyword
# arguments, then move it to the chosen device.
print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(
    *model_cfg.args, num_classes=args.num_classes, **model_cfg.kwargs
)
model.to(args.device)

# SWAG wrapper built around the same base architecture as `model`.
# `*model_cfg.args` is listed directly after `model_cfg.base` because Python
# binds unpacked positional arguments before any keyword argument, no matter
# where the star expression is written in the call. Writing it first makes the
# call read the way it is actually evaluated.
swag_model = SWAG(model_cfg.base,
                  *model_cfg.args,
                  subspace_type=args.subspace,
                  subspace_kwargs={'max_rank': args.max_num_models},
                  num_classes=args.num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

# Feed every checkpoint, in the order given, through `model` and into the
# SWAG model via collect_model.
for path in args.checkpoint:
    print(path)
    # map_location remaps the stored tensors onto args.device, so a checkpoint
    # written on a GPU can still be read when this script runs on the CPU.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(path, map_location=args.device)
    model.load_state_dict(ckpt['state_dict'])
    swag_model.collect_model(model)

# Persist the SWAG model's state dict under the same 'state_dict' key that
# the input checkpoints are read with.
torch.save({'state_dict': swag_model.state_dict()}, args.path)
Exemple #2
0
                                momentum=0.9,
                                weight_decay=1e-4)
    loader = generate_dataloaders(N=10)

    # Filled in on the fifth epoch (epoch == 4) and reloaded at the very end.
    # NOTE(review): stays None when num_epochs < 5, in which case the final
    # load_state_dict call receives None -- confirm num_epochs is always >= 5.
    state_dict = None

    for epoch in range(num_epochs):
        model.train()

        # One pass over the loader, minimising the summed squared error
        # between the prediction and the target.
        for x, y in loader:
            model.zero_grad()
            pred = model(x)
            loss = ((pred - y)**2.0).sum()
            loss.backward()
            optimizer.step()
        # Hand the current model to the SWAG model once per epoch.
        small_swag_model.collect_model(model)

        if epoch == 4:
            # NOTE(review): state_dict() normally returns references to the
            # module's tensors rather than copies; if later collect_model
            # calls update them in place, this is not a frozen snapshot --
            # confirm whether a deep copy is intended.
            state_dict = small_swag_model.state_dict()

    small_swag_model.fit()
    with torch.no_grad():
        # Twelve inputs -6, -5, ..., 5 arranged as a single column.
        x = torch.arange(-6., 6., 1.0).unsqueeze(1)
        # Draw ten samples and run a forward pass with each; the outputs are
        # discarded, so this only checks that sampling and the forward pass run.
        for i in range(10):
            # presumably 0.5 is a scale for the sampled weights -- TODO confirm
            small_swag_model.sample(0.5)
            small_swag_model(x)

    # get_space is unpacked into two values without the covariance factor and
    # three values with it; a different arity would raise ValueError here.
    _, _ = small_swag_model.get_space(export_cov_factor=False)
    _, _, _ = small_swag_model.get_space(export_cov_factor=True)
    # Reload the state captured at epoch 4 into the SWAG model.
    small_swag_model.load_state_dict(state_dict)
Exemple #3
0
            swag_res = utils.eval(loaders["test"], swag_model, criterion)
        else:
            swag_res = {"loss": None, "accuracy": None}

    # Every args.save_freq epochs (counting epochs from 1), checkpoint the
    # model and optimizer state.
    if (epoch + 1) % args.save_freq == 0:
        utils.save_checkpoint(
            args.dir,
            epoch + 1,
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
        )
        # With args.swa set, additionally checkpoint the SWAG model under the
        # name "swag".
        if args.swa:
            utils.save_checkpoint(args.dir,
                                  epoch + 1,
                                  name="swag",
                                  state_dict=swag_model.state_dict())

    # time_ep is rebound to the elapsed time: now minus its previous value
    # (presumably the epoch start timestamp, set outside this excerpt).
    time_ep = time.time() - time_ep
    # Allocated CUDA memory converted from bytes to GiB.
    memory_usage = torch.cuda.memory_allocated() / (1024.0**3)
    # One row of per-epoch statistics: epoch number, learning rate, train and
    # test loss/accuracy, elapsed time and memory usage.
    values = [
        epoch + 1,
        lr,
        train_res["loss"],
        train_res["accuracy"],
        test_res["loss"],
        test_res["accuracy"],
        time_ep,
        memory_usage,
    ]
    if args.swa:
        values = values[:-2] + [swag_res["loss"], swag_res["accuracy"]
        ):
            swag_res = {"loss": None, "accuracy": None}
            swag_model.to(args.device)
            swag_model.sample(0.0)
            print("EPOCH %d. SWAG BN" % (epoch + 1))
            utils.bn_update(loaders["train"], swag_model, verbose=True, subset=0.1)
            print("EPOCH %d. SWAG EVAL" % (epoch + 1))
            swag_res = utils.eval(loaders["test"], swag_model, criterion, verbose=True)
            swag_model.to(args.swa_device)
        else:
            swag_res = {"loss": None, "accuracy": None}

    # Every args.save_freq epochs (counting epochs from 1), write a checkpoint:
    # the SWAG model under the name "swag" when args.swa is set, otherwise
    # the model and optimizer state.
    # NOTE(review): with args.swa set, the model and optimizer state are not
    # checkpointed by this branch -- confirm that is intended.
    if (epoch + 1) % args.save_freq == 0:
        if args.swa:
            utils.save_checkpoint(
                args.dir, epoch + 1, name="swag", state_dict=swag_model.state_dict()
            )
        else:
            utils.save_checkpoint(
                args.dir,
                epoch + 1,
                state_dict=model.state_dict(),
                optimizer=optimizer.state_dict(),
            )

    # time_ep is rebound to the elapsed time: now minus its previous value
    # (presumably the epoch start timestamp, set outside this excerpt).
    time_ep = time.time() - time_ep
    # Allocated CUDA memory converted from bytes to GiB.
    memory_usage = torch.cuda.memory_allocated() / (1024.0 ** 3)
    values = [
        epoch + 1,
        lr,
        train_res["loss"],