Пример #1
0
criterion = losses.cross_entropy


def _flatten_parameters(net):
    """Return all parameters of ``net`` concatenated into one flat numpy vector.

    Parameters are detached, moved to the CPU and raveled in the order
    yielded by ``net.parameters()``.
    """
    return np.concatenate(
        [param.detach().cpu().numpy().ravel() for param in net.parameters()])


# W collects one flattened weight vector per checkpoint; it is turned into a
# 2-D array (one row per checkpoint) once every file has been processed.
W = []
num_checkpoints = len(args.checkpoint)
for path in args.checkpoint:
    print('Loading %s' % path)
    # NOTE: torch.load unpickles the file -- only pass trusted checkpoints.
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['state_dict'])

    # Register the restored weights with the SWAG model, then keep a flat copy.
    swag_model.collect_model(model)
    W.append(_flatten_parameters(model))
W = np.array(W)

mean, _, cov_mat_list = swag_model.export_numpy_params(export_cov_mat=True)
# Reshape every exported factor to ``args.swag_rank`` rows and place them side
# by side, so cov_mat has swag_rank rows spanning all flattened parameters.
# NOTE(review): assumes each factor holds exactly swag_rank rows (i.e. enough
# models were collected) -- confirm against export_numpy_params.
cov_mat = np.hstack([mat.reshape(args.swag_rank, -1) for mat in cov_mat_list])

tsvd = sklearn.decomposition.TruncatedSVD(n_components=args.swag_rank,
                                          n_iter=7)
tsvd.fit(cov_mat)

# Coordinates of each row of cov_mat in the fitted TSVD basis:
# shape (cov_mat.shape[0], n_components).
projections = np.dot(cov_mat, tsvd.components_.T)
# (n_components x n_components) second-moment matrix of those coordinates,
# normalised by (rows - 1) like a sample covariance. The product
# components_ @ cov_mat.T is just projections.T, so the expensive projection
# over all parameters is computed once instead of twice.
component_variances = np.dot(projections.T, projections) / (cov_mat.shape[0] - 1)

# Principal-component indices to inspect: the first five, three around
# swag_rank // 2, and the last two. Clip into [0, swag_rank - 1] so small ranks
# never produce out-of-range or negative indices (swag_rank == 1 would
# otherwise yield -1 from swag_rank // 2 - 1), then deduplicate;
# np.unique already returns its result sorted.
pc_idx = [
    0, 1, 2, 3, 4, args.swag_rank // 2 - 1, args.swag_rank // 2,
    args.swag_rank // 2 + 1, args.swag_rank - 2, args.swag_rank - 1
]
pc_idx = np.unique(np.clip(pc_idx, 0, args.swag_rank - 1))
Пример #2
0
        print("updating sgd_ens")
        if sgd_ens_preds is None:
            sgd_ens_preds = sgd_preds.copy()
        else:
            # TODO: rewrite in a numerically stable way
            sgd_ens_preds = sgd_ens_preds * n_ensembled / (n_ensembled + 1) + sgd_preds / (n_ensembled + 1)
        n_ensembled += 1
        swag_model.collect_model(model)
        if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
            swag_model.sample(0.0)
            utils.bn_update(loaders['train'], swag_model)
            swag_res = utils.eval(loaders['test'], swag_model, criterion)
        else:
            swag_res = {'loss': None, 'accuracy': None}

    # Append the sum of the variance vector returned by export_numpy_params to
    # the per-pass history, then rewrite the whole history to disk so the
    # saved file always reflects every pass so far.
    _, swag_var = swag_model.export_numpy_params(export_cov_mat=False)
    variances.append(np.sum(swag_var))
    variances_file = os.path.join(args.dir, 'variances.npz')
    np.savez(variances_file, variances=np.array(variances))

    if (epoch + 1) % args.save_freq == 0:
        utils.save_checkpoint(
            args.dir,
            epoch + 1,
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict()
        )

        utils.save_checkpoint(
            args.dir,
            epoch + 1,
            name='swag',