Exemplo n.º 1
0
args = parser.parse_args()

# Run on the GPU whenever CUDA is visible; otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
args.device = torch.device('cuda' if use_cuda else 'cpu')

# cuDNN autotuning is on, favouring speed; it may pick non-deterministic
# kernels. (Setting torch.backends.cudnn.deterministic = True was left
# disabled here.)
torch.backends.cudnn.benchmark = True

print('Using model %s' % args.model)
# Resolve the model constructor by its attribute name in torchvision.models.
model_class = getattr(torchvision.models, args.model)

# Class count shared by the network and the SWAG wrapper below
# (presumably ImageNet-1k, matching torchvision's pretrained heads — confirm).
num_classes = 1000

print('Preparing model')
model = model_class(num_classes=num_classes, pretrained=args.pretrained)
model.to(args.device)

print('SWAG training')
swag_model = SWAG(
    model_class,
    num_classes=num_classes,
    max_num_models=20,
    no_cov_mat=False,
)
# The SWAG container itself is kept on the CPU.
swag_model.to('cpu')

# Snapshot the same network into SWAG 100 times, printing a 1-based counter.
# NOTE(review): the weights never change between snapshots — this looks like
# a smoke/timing run of collect_model; confirm intent.
for k in range(100):
    swag_model.collect_model(model)
    print(k + 1)
Exemplo n.º 2
0
# Select the compute device: GPU when CUDA is available, else CPU.
# NOTE(review): `args` is assumed to come from parser.parse_args() earlier in
# the script (not visible in this snippet).
if torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

print('Using model %s' % args.model)
# Model configuration object looked up by name; it supplies the network class
# (`base`) plus the positional/keyword args used to construct it.
model_cfg = getattr(models, args.model)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args,
                       num_classes=args.num_classes,
                       **model_cfg.kwargs)
model.to(args.device)

# *model_cfg.args bind positionally right after model_cfg.base. Python does
# this even when the star-expression is written after keyword arguments, so
# it is listed first here to make the binding explicit.
swag_model = SWAG(model_cfg.base,
                  *model_cfg.args,
                  subspace_type=args.subspace,
                  subspace_kwargs={'max_rank': args.max_num_models},
                  num_classes=args.num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

# Load each checkpoint into `model` and add it to the SWAG collection.
for path in args.checkpoint:
    print(path)
    # map_location remaps stored tensors onto the selected device. Without it,
    # a checkpoint saved from a GPU cannot be loaded when the CPU branch above
    # was taken.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(path, map_location=args.device)
    model.load_state_dict(ckpt['state_dict'])
    swag_model.collect_model(model)

torch.save({'state_dict': swag_model.state_dict()}, args.path)
Exemplo n.º 3
0
args = parser.parse_args()

# Run on the GPU whenever CUDA is visible; otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
args.device = torch.device("cuda" if use_cuda else "cpu")

# cuDNN autotuning is on, favouring speed; it may pick non-deterministic
# kernels. (Setting torch.backends.cudnn.deterministic = True was left
# disabled here.)
torch.backends.cudnn.benchmark = True

print("Using model %s" % args.model)
# Resolve the model constructor by its attribute name in torchvision.models.
model_class = getattr(torchvision.models, args.model)

# Class count shared by the network and the SWAG wrapper below
# (presumably ImageNet-1k, matching torchvision's pretrained heads — confirm).
num_classes = 1000

print("Preparing model")
model = model_class(num_classes=num_classes, pretrained=args.pretrained)
model.to(args.device)

print("SWAG training")
swag_model = SWAG(
    model_class,
    num_classes=num_classes,
    max_num_models=20,
    no_cov_mat=False,
)
# The SWAG container itself is kept on the CPU.
swag_model.to("cpu")

# Snapshot the same network into SWAG 100 times, printing a 1-based counter.
# NOTE(review): the weights never change between snapshots — this looks like
# a smoke/timing run of collect_model; confirm intent.
for k in range(100):
    swag_model.collect_model(model)
    print(k + 1)