# Build a SWAG model from a list of saved checkpoints and write its state to
# ``args.path``.
#
# Reads from ``args``: model, num_classes, subspace, max_num_models,
# checkpoint (iterable of checkpoint paths) and path (output file).
# Writes to ``args``: device.

# Prefer the GPU when one is present; otherwise everything runs on the CPU.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(
    *model_cfg.args,
    num_classes=args.num_classes,
    **model_cfg.kwargs
)
model.to(args.device)

# SWAG receives the same constructor arguments as the base model, plus the
# subspace configuration taken from the command-line arguments.
swag_model = SWAG(
    model_cfg.base,
    *model_cfg.args,
    subspace_type=args.subspace,
    subspace_kwargs={'max_rank': args.max_num_models},
    num_classes=args.num_classes,
    **model_cfg.kwargs
)
swag_model.to(args.device)

for path in args.checkpoint:
    print(path)
    # map_location remaps the stored tensors onto the device selected above,
    # so checkpoints written on a GPU machine still load on a CPU-only host.
    ckpt = torch.load(path, map_location=args.device)
    model.load_state_dict(ckpt['state_dict'])
    # Hand the freshly loaded weights to the SWAG model.
    swag_model.collect_model(model)

torch.save({'state_dict': swag_model.state_dict()}, args.path)
momentum=0.9, weight_decay=1e-4) loader = generate_dataloaders(N=10) state_dict = None for epoch in range(num_epochs): model.train() for x, y in loader: model.zero_grad() pred = model(x) loss = ((pred - y)**2.0).sum() loss.backward() optimizer.step() small_swag_model.collect_model(model) if epoch == 4: state_dict = small_swag_model.state_dict() small_swag_model.fit() with torch.no_grad(): x = torch.arange(-6., 6., 1.0).unsqueeze(1) for i in range(10): small_swag_model.sample(0.5) small_swag_model(x) _, _ = small_swag_model.get_space(export_cov_factor=False) _, _, _ = small_swag_model.get_space(export_cov_factor=True) small_swag_model.load_state_dict(state_dict)
swag_res = utils.eval(loaders["test"], swag_model, criterion) else: swag_res = {"loss": None, "accuracy": None} if (epoch + 1) % args.save_freq == 0: utils.save_checkpoint( args.dir, epoch + 1, state_dict=model.state_dict(), optimizer=optimizer.state_dict(), ) if args.swa: utils.save_checkpoint(args.dir, epoch + 1, name="swag", state_dict=swag_model.state_dict()) time_ep = time.time() - time_ep memory_usage = torch.cuda.memory_allocated() / (1024.0**3) values = [ epoch + 1, lr, train_res["loss"], train_res["accuracy"], test_res["loss"], test_res["accuracy"], time_ep, memory_usage, ] if args.swa: values = values[:-2] + [swag_res["loss"], swag_res["accuracy"]
): swag_res = {"loss": None, "accuracy": None} swag_model.to(args.device) swag_model.sample(0.0) print("EPOCH %d. SWAG BN" % (epoch + 1)) utils.bn_update(loaders["train"], swag_model, verbose=True, subset=0.1) print("EPOCH %d. SWAG EVAL" % (epoch + 1)) swag_res = utils.eval(loaders["test"], swag_model, criterion, verbose=True) swag_model.to(args.swa_device) else: swag_res = {"loss": None, "accuracy": None} if (epoch + 1) % args.save_freq == 0: if args.swa: utils.save_checkpoint( args.dir, epoch + 1, name="swag", state_dict=swag_model.state_dict() ) else: utils.save_checkpoint( args.dir, epoch + 1, state_dict=model.state_dict(), optimizer=optimizer.state_dict(), ) time_ep = time.time() - time_ep memory_usage = torch.cuda.memory_allocated() / (1024.0 ** 3) values = [ epoch + 1, lr, train_res["loss"],