예제 #1
0
    def test_swag_cov(self, **kwargs):
        """Statistical check of covariance-subspace SWAG sampling.

        Trains a throwaway linear model while SWAG collects statistics, then
        samples parameters at several scales and verifies that roughly 95% of
        the quadratic forms (x - mu)^T Sigma^{-1} (x - mu) fall inside the
        95% chi-square region — as expected if samples ~ N(mu, scale * Sigma).
        """
        model = torch.nn.Linear(300, 3, bias=True)

        # SWAG wraps its own copy of the same architecture and tracks a
        # rank-limited covariance subspace of the collected iterates.
        swag_model = SWAG(torch.nn.Linear, in_features=300, out_features=3, bias=True, subspace_type = 'covariance',
                subspace_kwargs = {'max_rank':140})

        optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            # Random regression problem; inputs and targets are re-drawn each
            # step, so the SGD iterates wander rather than converge exactly.
            input = torch.randn(100, 300)
            output = model(input)
            loss = ((torch.randn(100, 3) - output)**2.0).sum()
            loss.backward()

            optimizer.step()

            # Record the current iterate into SWAG's running statistics.
            swag_model.collect_model(model)

        # check to ensure parameters have the correct sizes
        mean, var = swag_model._get_mean_and_variance()
        cov_mat_sqrt = swag_model.subspace.get_space()

        # Full SWAG covariance: low-rank factor product plus the diagonal.
        true_cov_mat = cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(var)

        test_cutoff = chi2(df = mean.numel()).ppf(0.95) #95% quantile of p dimensional chi-square distribution

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            scaled_cov_mat = true_cov_mat * scale
            scaled_cov_inv = torch.inverse(scaled_cov_mat)
            # now test to ensure that sampling has the correct covariance matrix probabilistically
            all_qforms = []
            for _ in range(3000):
                swag_model.sample(scale=scale)
                curr_pars = []
                # Gather the freshly sampled parameters out of the wrapped model.
                for (module, name, _) in swag_model.base_params:
                    curr_pars.append(getattr(module, name))
                dev = flatten(curr_pars) - mean

                #(x - mu)sigma^{-1}(x - mu)
                qform = dev.matmul(scaled_cov_inv).matmul(dev)

                all_qforms.append(qform.item())

            samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
            print(samples_in_cr)

            #between 94 and 96% of the samples should fall within the threshold
            #this should be very loose
            self.assertTrue(0.94*3000 <= samples_in_cr <= 0.96*3000)
예제 #2
0
def RealNVPTabularSWAG(dim_in,
                       coupling_layers,
                       k,
                       nperlayer=1,
                       subspace='covariance',
                       max_num_models=10):
    """Construct a SWAG wrapper around a RealNVPTabular normalizing flow.

    Args:
        dim_in: input dimensionality of the flow.
        coupling_layers: number of coupling layers.
        k: hidden width of each coupling network.
        nperlayer: accepted for interface compatibility (currently unused).
        subspace: SWAG subspace type, e.g. 'covariance'.
        max_num_models: rank cap for the SWAG subspace.

    Returns:
        The SWAG-wrapped flow model.
    """
    subspace_opts = {'max_rank': max_num_models}
    return SWAG(RealNVPTabular,
                subspace_type=subspace,
                subspace_kwargs=subspace_opts,
                num_coupling_layers=coupling_layers,
                in_dim=dim_in,
                hidden_dim=k,
                num_layers=1,
                dropout=True)
예제 #3
0
    use_validation=not args.use_test,
    split_classes=args.split_classes
)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.to(args.device)

# Single expression instead of the if/else: the SWAG flag is simply the
# negation of the --cov_mat CLI switch.
args.no_cov_mat = not args.cov_mat

print('SWAG training')
# SWAG wraps a second copy of the same architecture for posterior statistics.
swag_model = SWAG(model_cfg.base,
                  no_cov_mat=args.no_cov_mat,
                  max_num_models=args.max_num_models,
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)


def schedule(epoch):
    """SWA learning-rate schedule: flat, then linear decay, then flat at swa_lr.

    Progress is measured as a fraction of args.swa_start; the rate stays at
    lr_init for the first half, decays linearly over (0.5, 0.9], and settles
    at swa_lr thereafter.
    """
    progress = epoch / args.swa_start
    ratio = args.swa_lr / args.lr_init
    if progress <= 0.5:
        return args.lr_init * 1.0
    if progress <= 0.9:
        decay = 1.0 - (1.0 - ratio) * (progress - 0.5) / 0.4
        return args.lr_init * decay
    return args.lr_init * ratio

예제 #4
0
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes)

# Build the plain model from the registered architecture config.
print('Preparing model')
print(*model_cfg.args)

model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.cuda()

# SWAG wrapper over the same architecture with a PCA subspace.
# Presumably max_rank caps the stored snapshots and pca_rank the number of
# retained principal components — confirm against the SWAG implementation.
swag_model = SWAG(model_cfg.base,
                  num_classes=num_classes,
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': 140,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)


def checkpoint_num(filename):
    """Extract the integer checkpoint index from a name like 'ckpt-123.pt'.

    Takes the text between the first '-' and the following '.', e.g.
    'checkpoint-300.pt' -> 300.
    """
    stem = filename.split("-")[1]
    return int(stem.split(".")[0])


for file in os.listdir(args.dir):
print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Loading dataset %s from %s' % (args.dataset, args.data_path))
loaders, num_classes = data.loaders(args.dataset,
                                    args.data_path,
                                    args.batch_size,
                                    args.num_workers,
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test)

print('Preparing model')
swag_model = SWAG(model_cfg.base,
                  no_cov_mat=not args.cov_mat,
                  max_num_models=args.swag_rank,
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

ckpt = torch.load(args.checkpoint)

criterion = losses.cross_entropy

fractions = np.logspace(-np.log10(0.005 * len(loaders['train'].dataset)), 0.0,
                        args.N)
swa_accuracies = np.zeros(args.N)
swa_nlls = np.zeros(args.N)
swag_accuracies = np.zeros(args.N)
swag_nlls = np.zeros(args.N)
예제 #6
0
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
wandb.watch(model)
model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True
if args.swa:
    print("SWAG training")
    swag_model = SWAG(model_cfg.base,
                      no_cov_mat=args.no_cov_mat,
                      max_num_models=args.max_num_models,
                      *model_cfg.args,
                      num_classes=num_classes,
                      **model_cfg.kwargs)
    swag_model.to(args.device)
else:
    print("SGD training")


def schedule(epoch):
    t = (epoch) / (args.swa_start if args.swa else args.epochs)
    lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
    label_arr = np.load(args.label_arr)
    print("Corruption:",
          (loaders['train'].dataset.targets != label_arr).mean())
    loaders['train'].dataset.targets = label_arr

print('Preparing model')
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(args.device)
print("Model has {} parameters".format(
    sum([p.numel() for p in model.parameters()])))

swag_model = SWAG(model_cfg.base,
                  args.subspace, {'max_rank': args.max_num_models},
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

columns = ['swag', 'sample', 'te_loss', 'te_acc', 'ens_loss', 'ens_acc']

if args.warm_start:
    print('estimating initial PI using K-Means')
    kmeans_input = []
    for ckpt_i, ckpt in enumerate(args.swag_ckpts):
        #print("Checkpoint {}".format(ckpt))
        checkpoint = torch.load(ckpt)
        swag_model.subspace.rank = torch.tensor(0)
        swag_model.load_state_dict(checkpoint['state_dict'])
        mean, variance = swag_model._get_mean_and_variance()
예제 #8
0
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(args.device)

if args.swag:
    print('SGLD+SWAG training')
    swag_model = SWAG(model_cfg.base,
                      subspace_type=args.subspace,
                      subspace_kwargs={'max_rank': args.max_num_models},
                      *model_cfg.args,
                      num_classes=num_classes,
                      **model_cfg.kwargs)
    swag_model.to(args.device)
else:
    print('SGLD training')

criterion = losses.cross_entropy

sgld_optimizer = SGLD(model.parameters(),
                      lr=args.lr_init,
                      weight_decay=args.wd,
                      noise_factor=args.noise_factor)

num_batches = len(loaders['train'])
num_iters = num_batches * (args.epochs - args.ens_start + 1)
예제 #9
0
                                    args.data_path,
                                    args.batch_size,
                                    args.num_workers,
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test)

print('Preparing model')
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(args.device)

swag_model = SWAG(model_cfg.base,
                  no_cov_mat=False,
                  max_num_models=args.swag_rank,
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

criterion = losses.cross_entropy

W = []
num_checkpoints = len(args.checkpoint)
for path in args.checkpoint:
    print('Loading %s' % path)
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['state_dict'])

    swag_model.collect_model(model)
    W.append(
예제 #10
0
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes,
                                    shuffle_train=False)
"""if args.split_classes is not None:
    num_classes /= 2
    num_classes = int(num_classes)"""

print('Preparing model')
if args.method in ['SWAG', 'HomoNoise', 'SWAGDrop']:
    model = SWAG(model_cfg.base,
                 num_classes=num_classes,
                 subspace_type='pca',
                 subspace_kwargs={
                     'max_rank': 140,
                     'pca_rank': args.rank,
                 },
                 *model_cfg.args,
                 **model_cfg.kwargs)
elif args.method in ['SGD', 'Dropout', 'KFACLaplace']:
    model = model_cfg.base(*model_cfg.args,
                           num_classes=num_classes,
                           **model_cfg.kwargs)
else:
    assert False
model.cuda()


def train_dropout(m):
    if type(m) == torch.nn.modules.dropout.Dropout:
    args.batch_size,
    args.num_workers,
)

print('Preparing model')
model = model_class(pretrained=args.pretrained, num_classes=num_classes)
model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True
if args.swa:
    print('SWAG training')
    args.swa_device = 'cpu' if args.swa_cpu else args.device
    swag_model = SWAG(model_class, no_cov_mat=args.no_cov_mat, max_num_models=20,
                      num_classes=num_classes)
    swag_model.to(args.swa_device)
    if args.pretrained:
        model.to(args.swa_device)
        swag_model.collect_model(model)
        model.to(args.device)
else:
    print('SGD training')


def schedule(epoch):
    """Step-decay learning rate (/10 every 30 epochs); constant swa_lr once SWA starts."""
    in_swa_phase = args.swa and epoch >= args.swa_start
    if in_swa_phase:
        return args.swa_lr
    decade = epoch // 30
    return args.lr_init * (0.1 ** decade)
예제 #12
0
args = parser.parse_args()

# Pick the compute device once, up front.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

print('Using model %s' % args.model)
model_class = getattr(torchvision.models, args.model)

num_classes = 1000

print('Preparing model')
model = model_class(pretrained=args.pretrained, num_classes=num_classes)
model.to(args.device)

print('SWAG training')
# The SWAG statistics live on the CPU (presumably to spare GPU memory).
swag_model = SWAG(model_class,
                  no_cov_mat=False,
                  max_num_models=20,
                  num_classes=num_classes)
swag_model.to('cpu')

# Collect the same weights 100 times, printing progress 1..100.
for step in range(1, 101):
    swag_model.collect_model(model)
    print(step)
예제 #13
0
loaders, num_classes = data.loaders(args.data_path, args.batch_size, args.num_workers)

print("Preparing model")
model = model_class(pretrained=args.pretrained, num_classes=num_classes)
model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True
if args.swa:
    print("SWAG training")
    args.swa_device = "cpu" if args.swa_cpu else args.device
    swag_model = SWAG(
        model_class,
        no_cov_mat=args.no_cov_mat,
        max_num_models=20,
        num_classes=num_classes,
    )
    swag_model.to(args.swa_device)
    if args.pretrained:
        model.to(args.swa_device)
        swag_model.collect_model(model)
        model.to(args.device)
else:
    print("SGD training")


def schedule(epoch):
    if args.swa and epoch >= args.swa_start:
        return args.swa_lr
    else:
예제 #14
0
                                    args.num_workers,
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes)

model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(args.device)

swag_model = SWAG(model_cfg.base,
                  num_classes=num_classes,
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': args.rank,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
swag_model.load_state_dict(checkpoint['state_dict'])

num_parameters = sum([p.numel() for p in model.parameters()])

offset = 0
for param in model.parameters():
    size = param.numel()
예제 #15
0
    def test_swag_diag(self, **kwargs):
        """Statistical check of diagonal-covariance SWAG sampling.

        Trains a throwaway linear model while SWAG tracks running first and
        second moments, then samples repeatedly with cov=False and checks
        that, on average, ~95% of sampled coordinates fall inside each
        coordinate's 95% interval of N(mean, scale * (sq_mean - mean^2)).
        """
        model = torch.nn.Linear(300, 3, bias=True)

        # no_cov_mat=True restricts SWAG to per-parameter (diagonal) variance.
        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            no_cov_mat=True,
            max_num_models=100,
            loading=False,
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            # Random regression step; data is re-drawn every iteration so the
            # iterates keep moving and the running variance is non-degenerate.
            input = torch.randn(100, 300)
            output = model(input)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # check to ensure parameters have the correct sizes
        mean_list = []
        sq_mean_list = []
        for (module, name), param in zip(swag_model.params, model.parameters()):
            # SWAG stores running moments as "<param>_mean" / "<param>_sq_mean"
            # buffers on the wrapped modules.
            mean = module.__getattr__("%s_mean" % name)
            sq_mean = module.__getattr__("%s_sq_mean" % name)

            self.assertEqual(param.size(), mean.size())
            self.assertEqual(param.size(), sq_mean.size())

            mean_list.append(mean)
            sq_mean_list.append(sq_mean)

        # NOTE(review): the .cuda() calls mean this test requires a CUDA
        # device — confirm the test environment provides one.
        mean = flatten(mean_list).cuda()
        sq_mean = flatten(sq_mean_list).cuda()

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            # Diagonal variance implied by the running moments, at this scale.
            var = scale * (sq_mean - mean ** 2)

            std = torch.sqrt(var)
            dist = torch.distributions.Normal(mean, std)

            # now test to ensure that sampling has the correct covariance matrix probabilistically
            all_qforms = 0
            for _ in range(20):
                swag_model.sample(scale=scale, cov=False)
                curr_pars = []
                for (module, name) in swag_model.params:
                    curr_pars.append(getattr(module, name))

                # Map each sampled coordinate to its CDF value under the
                # diagonal Gaussian, then count coordinates inside the 95% CI.
                curr_probs = dist.cdf(flatten(curr_pars))

                # check if within 95% CI
                num_in_cr = ((curr_probs > 0.025) & (curr_probs < 0.975)).float().sum()
                # all_qforms.append( num_in_cr )
                all_qforms += num_in_cr

            # print(all_qforms/(20 * mean.numel()))
            # now compute average
            avg_prob_in_cr = all_qforms / (20 * mean.numel())

            # CLT should hold a bit tighter here
            self.assertTrue(0.945 <= avg_prob_in_cr <= 0.955)
    label_arr = np.load(args.label_arr)
    print("Corruption:",
          (loaders['train'].dataset.targets != label_arr).mean())
    loaders['train'].dataset.targets = label_arr

print('Preparing model')
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(args.device)
print("Model has {} parameters".format(
    sum([p.numel() for p in model.parameters()])))

swag_model = SWAG(model_cfg.base,
                  args.subspace, {'max_rank': args.max_num_models},
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

columns = ['swag', 'sample', 'te_loss', 'te_acc', 'ens_loss', 'ens_acc']

n_ensembled = 0.
multiswag_probs = None

for ckpt_i, ckpt in enumerate(args.swag_ckpts):
    print("Checkpoint {}".format(ckpt))
    checkpoint = torch.load(ckpt)
    swag_model.subspace.rank = torch.tensor(0)
    swag_model.load_state_dict(checkpoint['state_dict'])
예제 #17
0
# Build dataloaders; the training loader is explicitly not shuffled here.
loaders, num_classes = data.loaders(args.dataset,
                                    args.data_path,
                                    args.batch_size,
                                    args.num_workers,
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes,
                                    shuffle_train=False)

print('Preparing model')
# SWAG-family methods wrap the architecture in a SWAG object (loading=True —
# presumably prepares it for load_state_dict; confirm against SWAG);
# the plain baselines construct the bare architecture.
if args.method in ['SWAG', 'HomoNoise', 'SWAGDrop']:
    model = SWAG(model_cfg.base,
                 no_cov_mat=not args.cov_mat,
                 max_num_models=args.max_num_models,
                 loading=True,
                 *model_cfg.args,
                 num_classes=num_classes,
                 **model_cfg.kwargs)
elif args.method in ['SGD', 'Dropout', 'KFACLaplace']:
    model = model_cfg.base(*model_cfg.args,
                           num_classes=num_classes,
                           **model_cfg.kwargs)
else:
    # NOTE(review): unknown methods abort here; `assert` is stripped under
    # python -O — consider raising ValueError instead.
    assert False
model.cuda()


def train_dropout(m):
    """Switch dropout modules back to train mode (enables MC-dropout at eval time).

    Intended to be applied to every submodule (e.g. via ``model.apply``);
    modules other than Dropout are left untouched.
    """
    # isinstance instead of the exact ``type(m) ==`` comparison: idiomatic,
    # and also covers subclasses of Dropout.
    if isinstance(m, torch.nn.modules.dropout.Dropout):
        m.train()
예제 #18
0
                                    args.num_workers,
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test)

print('Preparing model')
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(args.device)

swag_model = SWAG(model_cfg.base,
                  num_classes=num_classes,
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': 20,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)

criterion = losses.cross_entropy

print('Loading: %s' % args.checkpoint)
ckpt = torch.load(args.checkpoint)
swag_model.load_state_dict(ckpt['state_dict'], strict=False)

mean, _, cov_factor = swag_model.export_numpy_parameters(True)

norms = np.linalg.norm(cov_factor, axis=1)
예제 #19
0
    joint_transform=model_cfg.joint_transform,
    ft_joint_transform=model_cfg.ft_joint_transform,
    target_transform=model_cfg.target_transform,
)

if args.loss == "cross_entropy":
    criterion = losses.seg_cross_entropy
else:
    criterion = losses.seg_ale_cross_entropy

print("Preparing model")
if args.method in ["SWAG", "HomoNoise", "SWAGDrop"]:
    model = SWAG(
        model_cfg.base,
        no_cov_mat=False,
        max_num_models=20,
        num_classes=num_classes,
        use_aleatoric=args.loss == "aleatoric",
    )

elif args.method in ["SGD", "Dropout"]:
    # construct and load model
    model = model_cfg.base(num_classes=num_classes,
                           use_aleatoric=args.loss == "aleatoric")
else:
    assert False
model.cuda()


def train_dropout(m):
    if m.__module__ == torch.nn.modules.dropout.__name__:
torch.cuda.manual_seed(args.seed)

print('Using model %s' % args.model)
model_class = getattr(torchvision.models, args.model)

print('Loading ImageNet from %s' % (args.data_path))
loaders, num_classes = data.loaders(
    args.data_path,
    args.batch_size,
    args.num_workers,
)

print('Preparing model')
swag_model = SWAG(model_class,
                  no_cov_mat=not args.cov_mat,
                  loading=True,
                  max_num_models=20,
                  num_classes=num_classes)
swag_model.to(args.device)

criterion = losses.cross_entropy

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
swag_model.load_state_dict(checkpoint['state_dict'])

print('SWA')
swag_model.sample(0.0)
print('SWA BN update')
utils.bn_update(loaders['train'], swag_model, verbose=True, subset=0.1)
print('SWA EVAL')
예제 #21
0
파일: models.py 프로젝트: yyht/drbayes
class RegressionRunner(RegressionModel):
    """SGD trainer for regression networks with optional SWAG posterior collection.

    Wraps a ``base`` network constructor and exposes scikit-learn-style
    ``fit``/``predict`` methods. ``predict`` returns a predictive mean and
    variance, either from the plain model or by Monte-Carlo averaging over
    SWAG samples.
    """
    def __init__(self,
                 base,
                 epochs,
                 criterion,
                 batch_size=50,
                 lr_init=1e-2,
                 momentum=0.9,
                 wd=1e-4,
                 swag_lr=1e-3,
                 swag_freq=1,
                 swag_start=50,
                 subspace_type='pca',
                 subspace_kwargs=None,
                 use_cuda=False,
                 use_swag=False,
                 double_bias_lr=False,
                 model_variance=True,
                 num_samples=30,
                 scale=0.5,
                 const_lr=False,
                 *args,
                 **kwargs):
        """Build the model, optimizer and (optionally) the SWAG wrapper.

        Args:
            base: callable constructing the network; ``*args``/``**kwargs``
                are forwarded to it (and to the SWAG copy).
            epochs: number of training epochs used by ``fit``.
            criterion: loss factory accepting a ``noise_var`` keyword.
            subspace_kwargs: options for the SWAG subspace; defaults to
                ``{'max_rank': 20}`` when None.
            model_variance: if True the network predicts its own noise level,
                so the criterion gets no fixed noise variance.
        """
        # BUG FIX: the default used to be the mutable literal
        # ``subspace_kwargs={'max_rank': 20}`` — a single dict shared by every
        # instance (and mutable downstream). Use a None sentinel instead;
        # callers that passed a dict explicitly are unaffected.
        if subspace_kwargs is None:
            subspace_kwargs = {'max_rank': 20}

        self.base = base
        self.model = base(*args, **kwargs)
        num_pars = 0
        for p in self.model.parameters():
            num_pars += p.numel()
        print('number of parameters: ', num_pars)

        if use_cuda:
            self.model.cuda()

        if use_swag:
            self.swag_model = SWAG(base,
                                   subspace_type=subspace_type,
                                   subspace_kwargs=subspace_kwargs,
                                   *args,
                                   **kwargs)
            if use_cuda:
                self.swag_model.cuda()
        else:
            self.swag_model = None

        self.use_cuda = use_cuda

        # Optionally train bias parameters at twice the base learning rate.
        if not double_bias_lr:
            pars = self.model.parameters()
        else:
            pars = []
            for name, module in self.model.named_parameters():
                if 'bias' in str(name):
                    print('Doubling lr of ', name)
                    pars.append({'params': module, 'lr': 2.0 * lr_init})
                else:
                    pars.append({'params': module, 'lr': lr_init})

        self.optimizer = torch.optim.SGD(pars,
                                         lr=lr_init,
                                         momentum=momentum,
                                         weight_decay=wd)

        self.const_lr = const_lr
        self.batch_size = batch_size

        # TODO: set up criterions better for classification
        if model_variance:
            self.criterion = criterion(noise_var=None)
        else:
            self.criterion = criterion(noise_var=1.0)

        # self.var is only set here when the criterion carries a fixed noise
        # variance; otherwise fit() estimates it from training residuals.
        if self.criterion.noise_var is not None:
            self.var = self.criterion.noise_var

        self.epochs = epochs

        self.lr_init = lr_init

        self.use_swag = use_swag
        self.swag_start = swag_start
        self.swag_lr = swag_lr
        self.swag_freq = swag_freq

        self.num_samples = num_samples
        self.scale = scale

    def train(self,
              model,
              loader,
              optimizer,
              criterion,
              lr_init=1e-2,
              epochs=3000,
              swag_model=None,
              swag=False,
              swag_start=2000,
              swag_freq=50,
              swag_lr=1e-3,
              print_freq=100,
              use_cuda=False,
              const_lr=False):
        """Run the SGD training loop, optionally collecting SWAG snapshots.

        Returns:
            A list with one ``utils.train_epoch`` result dict per epoch.
        """
        # copied from pavels regression notebook
        if const_lr:
            lr = lr_init

        train_res_list = []
        for epoch in range(epochs):
            if not const_lr:
                # Piecewise schedule: flat, then linear decay to swag_lr
                # (or to 5% of lr_init when SWAG is off), then flat.
                t = (epoch + 1) / swag_start if swag else (epoch + 1) / epochs
                lr_ratio = swag_lr / lr_init if swag else 0.05

                if t <= 0.5:
                    factor = 1.0
                elif t <= 0.9:
                    factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
                else:
                    factor = lr_ratio

                lr = factor * lr_init
                # NOTE(review): the helper receives the *factor*, not lr —
                # assumes adjust_learning_rate rescales lr_init internally;
                # confirm against its definition.
                adjust_learning_rate(optimizer, factor)

            train_res = utils.train_epoch(loader,
                                          model,
                                          criterion,
                                          optimizer,
                                          cuda=use_cuda,
                                          regression=True)
            train_res_list.append(train_res)
            # NOTE(review): swag_freq is accepted but unused — snapshots are
            # collected every epoch once epoch > swag_start.
            if swag and epoch > swag_start:
                swag_model.collect_model(model)

            if (epoch % print_freq == 0 or epoch == epochs - 1):
                print('Epoch %d. LR: %g. Loss: %.4f' %
                      (epoch, lr, train_res['loss']))

        return train_res_list

    def fit(self, features, labels):
        """Train on array-like (features, labels); returns per-epoch results.

        When the criterion carries a fixed noise variance, the variance is
        re-estimated afterwards from the mean squared training residual.
        """
        self.features, self.labels = torch.FloatTensor(
            features), torch.FloatTensor(labels)

        # construct data loader
        self.data_loader = DataLoader(TensorDataset(self.features,
                                                    self.labels),
                                      batch_size=self.batch_size)

        # now train with pre-specified options
        result = self.train(model=self.model,
                            loader=self.data_loader,
                            optimizer=self.optimizer,
                            criterion=self.criterion,
                            lr_init=self.lr_init,
                            swag_model=self.swag_model,
                            swag=self.use_swag,
                            swag_start=self.swag_start,
                            swag_freq=self.swag_freq,
                            swag_lr=self.swag_lr,
                            use_cuda=self.use_cuda,
                            epochs=self.epochs,
                            const_lr=self.const_lr)

        if self.criterion.noise_var is not None:
            # another forwards pass through network to estimate noise variance
            preds, targets = utils.predictions(model=self.model,
                                               test_loader=self.data_loader,
                                               regression=True,
                                               cuda=self.use_cuda)
            self.var = np.power(np.linalg.norm(preds - targets),
                                2.0) / targets.shape[0]
            print(self.var)

        return result

    def predict(self, features, swag_model=None):
        """Return (mean, variance) predictions as numpy arrays.

        default prediction method is to use built in Low rank Gaussian
        SWA: scale = 0.0, num_samples = 1
        """
        swag_model = swag_model if swag_model is not None else self.swag_model

        if self.use_cuda:
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')

        with torch.no_grad():

            if swag_model is None:
                # Plain model: single forward pass.
                self.model.eval()
                preds = self.model(
                    torch.FloatTensor(features).to(device)).data.cpu()

                # One output column => homoscedastic model: use the stored
                # noise variance. Two columns => second column is the
                # predicted variance.
                if preds.size(1) == 1:
                    var = torch.ones_like(preds[:, 0]).unsqueeze(1) * self.var
                else:
                    var = preds[:, 1].view(-1, 1)
                    preds = preds[:, 0].view(-1, 1)

                print(var.mean())

            else:
                # Monte-Carlo average over SWAG samples.
                prediction = 0
                sq_prediction = 0
                for _ in range(self.num_samples):
                    swag_model.sample(scale=self.scale)
                    current_prediction = swag_model(
                        torch.FloatTensor(features).to(device)).data.cpu()
                    prediction += current_prediction
                    if current_prediction.size(1) == 2:
                        #convert to standard deviation
                        current_prediction[:, 1] = current_prediction[:,
                                                                      1]**0.5

                    sq_prediction += current_prediction**2.0
                # preds = bma/(self.num_samples)

                # compute mean of prediction
                # \mu^*
                preds = (prediction[:, 0] / self.num_samples).view(-1, 1)

                # 1/M \sum(\sigma^2(x) + \mu^2(x)) - \mu*^2
                var = torch.sum(sq_prediction, 1, keepdim=True
                                ) / self.num_samples - preds.pow(2.0)

                # add variance if not heteroscedastic
                if prediction.size(1) == 1:
                    var = var + self.var

            return preds.numpy(), var.numpy()
예제 #22
0
        'max_rank': 20,
        'pca_rank': 'mle',
    }
}, {
    'name': 'freq_dir',
    'kwargs': {
        'max_rank': 20,
    }
}]

for item in subspaces:
    name, kwargs = item['name'], item['kwargs']
    print('Now running %s %r' % (name, kwargs))
    model = model_generator()
    small_swag_model = SWAG(base=model_generator,
                            subspace_type=name,
                            subspace_kwargs=kwargs)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=1e-3,
                                momentum=0.9,
                                weight_decay=1e-4)
    loader = generate_dataloaders(N=10)

    state_dict = None

    for epoch in range(num_epochs):
        model.train()

        for x, y in loader:
            model.zero_grad()
예제 #23
0
파일: models.py 프로젝트: yyht/drbayes
    def __init__(self,
                 base,
                 epochs,
                 criterion,
                 batch_size=50,
                 lr_init=1e-2,
                 momentum=0.9,
                 wd=1e-4,
                 swag_lr=1e-3,
                 swag_freq=1,
                 swag_start=50,
                 subspace_type='pca',
                 subspace_kwargs=None,
                 use_cuda=False,
                 use_swag=False,
                 double_bias_lr=False,
                 model_variance=True,
                 num_samples=30,
                 scale=0.5,
                 const_lr=False,
                 *args,
                 **kwargs):
        """Build the base model, SGD optimizer and (optionally) a SWAG wrapper.

        Args:
            base: callable constructing the network; ``*args``/``**kwargs``
                are forwarded to it (and to the SWAG copy).
            epochs: number of training epochs.
            criterion: loss factory accepting a ``noise_var`` keyword.
            subspace_kwargs: options for the SWAG subspace; defaults to
                ``{'max_rank': 20}`` when None.
            model_variance: if True the network predicts its own noise level,
                so the criterion gets no fixed noise variance.
        """
        # BUG FIX: the default used to be the mutable literal
        # ``subspace_kwargs={'max_rank': 20}`` — a single dict shared by every
        # instance (and mutable downstream). Use a None sentinel instead;
        # callers that passed a dict explicitly are unaffected.
        if subspace_kwargs is None:
            subspace_kwargs = {'max_rank': 20}

        self.base = base
        self.model = base(*args, **kwargs)
        num_pars = 0
        for p in self.model.parameters():
            num_pars += p.numel()
        print('number of parameters: ', num_pars)

        if use_cuda:
            self.model.cuda()

        if use_swag:
            self.swag_model = SWAG(base,
                                   subspace_type=subspace_type,
                                   subspace_kwargs=subspace_kwargs,
                                   *args,
                                   **kwargs)
            if use_cuda:
                self.swag_model.cuda()
        else:
            self.swag_model = None

        self.use_cuda = use_cuda

        # Optionally train bias parameters at twice the base learning rate.
        if not double_bias_lr:
            pars = self.model.parameters()
        else:
            pars = []
            for name, module in self.model.named_parameters():
                if 'bias' in str(name):
                    print('Doubling lr of ', name)
                    pars.append({'params': module, 'lr': 2.0 * lr_init})
                else:
                    pars.append({'params': module, 'lr': lr_init})

        self.optimizer = torch.optim.SGD(pars,
                                         lr=lr_init,
                                         momentum=momentum,
                                         weight_decay=wd)

        self.const_lr = const_lr
        self.batch_size = batch_size

        # TODO: set up criterions better for classification
        if model_variance:
            self.criterion = criterion(noise_var=None)
        else:
            self.criterion = criterion(noise_var=1.0)

        # self.var is only set when the criterion carries a fixed noise
        # variance; otherwise it is estimated later from training residuals.
        if self.criterion.noise_var is not None:
            self.var = self.criterion.noise_var

        self.epochs = epochs

        self.lr_init = lr_init

        self.use_swag = use_swag
        self.swag_start = swag_start
        self.swag_lr = swag_lr
        self.swag_freq = swag_freq

        self.num_samples = num_samples
        self.scale = scale
예제 #24
0
print(*model_cfg.args)

# Instantiate the base network from the registered model config.
model = model_cfg.base(
    *model_cfg.args,
    num_classes=num_classes,
    input_shape=input_shape,
    **model_cfg.kwargs,
)

# SWAG stores a low-rank covariance factor only when --cov_mat is set.
args.no_cov_mat = not args.cov_mat

if args.swa:
    print("SWAG training")
    # Wrap the base model so SWAG can collect weight statistics.
    swag_model = SWAG(
        model,
        no_cov_mat=args.no_cov_mat,
        max_num_models=args.max_num_models,
        *model_cfg.args,
        num_classes=num_classes,
        **model_cfg.kwargs,
    )
else:
    print("SGD training")

def schedule(epoch):
    t = (epoch) / (args.swa_start if args.swa else args.epochs)
    lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
예제 #25
0
    target_transform=model_cfg.target_transform,
)

# criterion = nn.NLLLoss(weight=camvid.class_weight[:-1].cuda(), reduction='none').cuda()
# Select the segmentation loss: plain cross-entropy, or the aleatoric
# (learned-noise) variant for any other --loss value.
if args.loss == "cross_entropy":
    criterion = losses.seg_cross_entropy
else:
    criterion = losses.seg_ale_cross_entropy

# construct and load model
# Two restore paths:
#  * --swa_resume: rebuild a SWAG wrapper around the base architecture and
#    load the saved posterior state.
#  * otherwise: build the plain base model and load an SGD checkpoint.
if args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    model = SWAG(
        model_cfg.base,
        no_cov_mat=False,
        max_num_models=20,
        num_classes=num_classes,
        use_aleatoric=args.loss == "aleatoric",
    )
    model.cuda()
    model.load_state_dict(checkpoint["state_dict"])

    # NOTE(review): scale=0.0 presumably sets the weights to the SWA mean
    # (no sampled deviation) — confirm against SWAG.sample.
    model.sample(0.0)
    # Refresh batch-norm running statistics under the restored weights.
    bn_update(loaders["fine_tune"], model)
else:
    model = model_cfg.base(num_classes=num_classes,
                           use_aleatoric=args.loss == "aleatoric").cuda()
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint["epoch"]
    print(start_epoch)
    model.load_state_dict(checkpoint["state_dict"])
예제 #26
0
                                    args.data_path,
                                    args.batch_size,
                                    args.num_workers,
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes)

print('Preparing model')
print(*model_cfg.args)

# Build a SWAG wrapper over the base architecture using a PCA subspace
# (up to 20 stored deviation vectors, projected down to args.rank components
# — presumably; confirm the kwarg semantics against the SWAG subspace code).
swag_model = SWAG(model_cfg.base,
                  num_classes=num_classes,
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': 20,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)

print('Loading: %s' % args.checkpoint)
ckpt = torch.load(args.checkpoint)
# strict=False tolerates key mismatches between the checkpoint and the
# freshly constructed SWAG state dict.
swag_model.load_state_dict(ckpt['state_dict'], strict=False)

# Evaluate at the SWA mean weights before extracting the subspace.
swag_model.set_swa()
print("SWA:",
      utils.eval(loaders["train"], swag_model, criterion=losses.cross_entropy))

# Flat posterior mean, per-weight variance, and low-rank covariance factor.
mean, var, cov_factor = swag_model.get_space()
예제 #27
0
    criterion = partial(criterion, weight=class_weights)

# Optionally resume SGD training: restore model weights, optimizer state,
# and the epoch counter from a saved checkpoint.
if args.resume is not None:
    print('Resume training from %s' % args.resume)
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    del checkpoint  # release the (potentially large) checkpoint dict

# When SWA is enabled, build the SWAG wrapper that will accumulate weight
# statistics during training (full covariance, up to 20 snapshots).
if args.swa:
    print('SWAG training')
    swag_model = SWAG(model_cfg.base,
                      no_cov_mat=False,
                      max_num_models=20,
                      *model_cfg.args,
                      num_classes=num_classes,
                      use_aleatoric=args.loss == 'aleatoric',
                      **model_cfg.kwargs)
    swag_model.to(args.device)
else:
    print('SGD training')

if args.swa and args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    swag_model = SWAG(model_cfg.base,
                      no_cov_mat=False,
                      max_num_models=20,
                      *model_cfg.args,
                      num_classes=num_classes,
                      use_aleatoric=args.loss == 'aleatoric',
예제 #28
0
# Select the compute device, preferring CUDA when available.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Preparing model')
print(*model_cfg.args)

# Base network used to load each individual SGD checkpoint.
model = model_cfg.base(
    *model_cfg.args, num_classes=args.num_classes, **model_cfg.kwargs
)
model.to(args.device)

# SWAG wrapper that accumulates the weight statistics across checkpoints.
swag_model = SWAG(
    model_cfg.base,
    subspace_type=args.subspace,
    subspace_kwargs={'max_rank': args.max_num_models},
    *model_cfg.args,
    num_classes=args.num_classes,
    **model_cfg.kwargs,
)
swag_model.to(args.device)

# Fold every saved checkpoint into the SWAG estimate, one at a time.
for path in args.checkpoint:
    print(path)
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['state_dict'])
    swag_model.collect_model(model)

# Persist the aggregated SWAG state for later sampling/evaluation.
torch.save({'state_dict': swag_model.state_dict()}, args.path)
예제 #29
0
    def test_swag_cov(self, **kwargs):
        """Probabilistic check of SWAG's covariance estimate.

        Trains a small linear model on synthetic data while SWAG collects
        snapshots, reconstructs the implied Gaussian covariance, then
        verifies that Mahalanobis quadratic forms of SWAG samples fall
        inside the 95% chi-square credible region at roughly the expected
        rate for several scale factors.
        """
        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            no_cov_mat=False,
            max_num_models=100,
            loading=False,
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # Construct the SWAG statistics by collecting a model snapshot
        # after every step of training on random regression data.
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            inputs = torch.randn(100, 300)  # renamed from `input`: don't shadow the builtin
            output = model(inputs)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # Check that the collected statistics have the correct sizes, and
        # gather them into flat vectors / one concatenated factor matrix.
        mean_list = []
        sq_mean_list = []
        cov_mat_sqrt_list = []
        for (module, name), param in zip(swag_model.params, model.parameters()):
            # getattr() is the idiomatic spelling; calling __getattr__
            # directly bypasses the normal attribute-lookup protocol.
            mean = getattr(module, "%s_mean" % name)
            sq_mean = getattr(module, "%s_sq_mean" % name)
            cov_mat_sqrt = getattr(module, "%s_cov_mat_sqrt" % name)

            self.assertEqual(param.size(), mean.size())
            self.assertEqual(param.size(), sq_mean.size())
            self.assertEqual(
                [swag_model.max_num_models, param.numel()], list(cov_mat_sqrt.size())
            )

            mean_list.append(mean)
            sq_mean_list.append(sq_mean)
            cov_mat_sqrt_list.append(cov_mat_sqrt)

        # Keep everything on the same device as the sampled parameters.
        # The previous .cuda() calls moved `mean` to the GPU while
        # swag_model.sample() left the parameters on the CPU, making
        # `flatten(curr_pars) - mean` a cross-device op (and forcing the
        # test to require a GPU).
        mean = flatten(mean_list)
        sq_mean = flatten(sq_mean_list)
        cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1)

        # Sigma = D^T D / (K - 1) + diag(E[w^2] - E[w]^2)
        true_cov_mat = (
            1.0 / (swag_model.max_num_models - 1)
        ) * cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(sq_mean - mean ** 2)

        # 95% quantile of the p-dimensional chi-square distribution.
        test_cutoff = chi2(df=mean.numel()).ppf(0.95)

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            scaled_cov_mat = true_cov_mat * scale
            scaled_cov_inv = torch.inverse(scaled_cov_mat)

            # If sampling follows N(mean, scale * Sigma), the quadratic form
            # (x - mu)^T (scale * Sigma)^{-1} (x - mu) of each draw is
            # chi-square(p) distributed.
            all_qforms = []
            for _ in range(2000):
                swag_model.sample(scale=scale, cov=True)
                curr_pars = []
                for (module, name) in swag_model.params:
                    curr_pars.append(getattr(module, name))
                dev = flatten(curr_pars) - mean

                qform = dev.matmul(scaled_cov_inv).matmul(dev)
                all_qforms.append(qform.item())

            samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
            print(samples_in_cr)

            # Between 94% and 96% of the 2000 draws should land inside the
            # 95% credible region; this is a deliberately loose bound.
            self.assertTrue(1880 <= samples_in_cr <= 1920)