Example #1
    def test_swag_cov(self, **kwargs):
        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(torch.nn.Linear, in_features=300, out_features=3, bias=True,
                          subspace_type='covariance', subspace_kwargs={'max_rank': 140})

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            input = torch.randn(100, 300)
            output = model(input)
            loss = ((torch.randn(100, 3) - output)**2.0).sum()
            loss.backward()
            
            optimizer.step()
            
            swag_model.collect_model(model)

        # recover the fitted SWAG mean, variance, and low-rank covariance factor
        mean, var = swag_model._get_mean_and_variance()
        cov_mat_sqrt = swag_model.subspace.get_space()

        true_cov_mat = cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(var)

        test_cutoff = chi2(df=mean.numel()).ppf(0.95)  # 95% quantile of the p-dimensional chi-square distribution

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            scaled_cov_mat = true_cov_mat * scale
            scaled_cov_inv = torch.inverse(scaled_cov_mat)
            # now test to ensure that sampling has the correct covariance matrix probabilistically
            all_qforms = []            
            for _ in range(3000):
                swag_model.sample(scale=scale)
                curr_pars = []
                for (module, name, _) in swag_model.base_params:
                    curr_pars.append(getattr(module, name))
                dev = flatten(curr_pars) - mean

                # (x - mu)^T Sigma^{-1} (x - mu)
                qform = dev.matmul(scaled_cov_inv).matmul(dev)

                all_qforms.append(qform.item())
            
            samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
            print(samples_in_cr)

            #between 94 and 96% of the samples should fall within the threshold
            #this should be very loose
            self.assertTrue(0.94*3000 <= samples_in_cr <= 0.96*3000) 
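The check above rests on a standard fact: if x ~ N(mu, Sigma) with Sigma positive definite of dimension p, the quadratic form (x - mu)^T Sigma^{-1} (x - mu) is chi-square distributed with p degrees of freedom, so roughly 95% of draws should fall below the 0.95 quantile. A minimal self-contained sketch of the same coverage check in plain NumPy/SciPy, independent of SWAG:

import numpy as np
from scipy.stats import chi2

rng = np.random.default_rng(0)
p = 5
A = rng.standard_normal((p, p))
cov = A @ A.T + np.eye(p)                  # a random symmetric positive-definite covariance
mean = rng.standard_normal(p)

samples = rng.multivariate_normal(mean, cov, size=3000)
dev = samples - mean
qforms = np.einsum('ij,jk,ik->i', dev, np.linalg.inv(cov), dev)

cutoff = chi2(df=p).ppf(0.95)
print((qforms < cutoff).mean())            # should be close to 0.95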
Example #2
def RealNVPTabularSWAG(dim_in,
                       coupling_layers,
                       k,
                       nperlayer=1,
                       subspace='covariance',
                       max_num_models=10):
    swag_model = SWAG(RealNVPTabular,
                      subspace_type=subspace,
                      subspace_kwargs={'max_rank': max_num_models},
                      num_coupling_layers=coupling_layers,
                      in_dim=dim_in,
                      hidden_dim=k,
                      num_layers=1,
                      dropout=True)
    return swag_model
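The function only builds the wrapper; no statistics have been collected yet. A hedged sketch of how such a wrapper is typically driven (the names loader, num_epochs, swa_start, and nll_loss below are illustrative assumptions, not part of the example; dim_in, coupling_layers, and k are the function's own arguments):

import torch

model = RealNVPTabular(num_coupling_layers=coupling_layers, in_dim=dim_in,
                       hidden_dim=k, num_layers=1, dropout=True)
swag_model = RealNVPTabularSWAG(dim_in, coupling_layers, k)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    for x in loader:
        optimizer.zero_grad()
        loss = nll_loss(model, x)        # assumed flow negative log-likelihood
        loss.backward()
        optimizer.step()
    if epoch >= swa_start:
        swag_model.collect_model(model)  # accumulate SWAG running statistics

swag_model.sample(scale=0.5)             # draw one weight sample from the posterior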
Example #3

print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Loading dataset %s from %s' % (args.dataset, args.data_path))
loaders, num_classes = data.loaders(args.dataset,
                                    args.data_path,
                                    args.batch_size,
                                    args.num_workers,
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test)

print('Preparing model')
swag_model = SWAG(model_cfg.base,
                  no_cov_mat=not args.cov_mat,
                  max_num_models=args.swag_rank,
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

ckpt = torch.load(args.checkpoint)

criterion = losses.cross_entropy

fractions = np.logspace(-np.log10(0.005 * len(loaders['train'].dataset)), 0.0,
                        args.N)
swa_accuracies = np.zeros(args.N)
swa_nlls = np.zeros(args.N)
swag_accuracies = np.zeros(args.N)
swag_nlls = np.zeros(args.N)
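A note on the fractions grid: since 10^(-log10(0.005 * n)) = 1/(0.005 * n) = 200/n, the smallest fraction always corresponds to 200 training examples regardless of dataset size, and the args.N values grow geometrically up to the full training set. A quick check with illustrative numbers (n and N are assumptions):

import numpy as np

n, N = 45000, 5                               # e.g. a CIFAR-sized training set
fractions = np.logspace(-np.log10(0.005 * n), 0.0, N)
print(fractions)                              # from ~0.0044 up to 1.0
print(fractions[0] * n)                       # 200.0 examples in the smallest subset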
Example #4
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       input_shape=input_shape,
                       **model_cfg.kwargs)
#model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True
if args.swa:
    print("SWAG training")
    swag_model = SWAG(model_cfg.base,  # SWAG expects the model class, not an instance (cf. Example #5)
                      no_cov_mat=args.no_cov_mat,
                      max_num_models=args.max_num_models,
                      *model_cfg.args,
                      num_classes=num_classes,
                      **model_cfg.kwargs)
    #swag_model.to(args.device)
else:
    print("SGD training")


def schedule(epoch):
    t = epoch / (args.swa_start if args.swa else args.epochs)
    lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
        factor = lr_ratio  # completion of the truncated snippet (standard SWAG schedule)
    return args.lr_init * factor
Example #5
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
wandb.watch(model)
model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True
if args.swa:
    print("SWAG training")
    swag_model = SWAG(model_cfg.base,
                      no_cov_mat=args.no_cov_mat,
                      max_num_models=args.max_num_models,
                      *model_cfg.args,
                      num_classes=num_classes,
                      **model_cfg.kwargs)
    swag_model.to(args.device)
else:
    print("SGD training")


def schedule(epoch):
    t = epoch / (args.swa_start if args.swa else args.epochs)
    lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
        factor = lr_ratio  # completion of the truncated snippet (standard SWAG schedule)
    return args.lr_init * factor
Example #6
        'max_rank': 20,
        'pca_rank': 'mle',
    }
}, {
    'name': 'freq_dir',
    'kwargs': {
        'max_rank': 20,
    }
}]

for item in subspaces:
    name, kwargs = item['name'], item['kwargs']
    print('Now running %s %r' % (name, kwargs))
    model = model_generator()
    small_swag_model = SWAG(base=model_generator,
                            subspace_type=name,
                            subspace_kwargs=kwargs)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=1e-3,
                                momentum=0.9,
                                weight_decay=1e-4)
    loader = generate_dataloaders(N=10)

    state_dict = None

    for epoch in range(num_epochs):
        model.train()

        for x, y in loader:
            model.zero_grad()
Example #7

    label_arr = np.load(args.label_arr)
    print("Corruption:",
          (loaders['train'].dataset.targets != label_arr).mean())
    loaders['train'].dataset.targets = label_arr

print('Preparing model')
model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(args.device)
print("Model has {} parameters".format(
    sum([p.numel() for p in model.parameters()])))

swag_model = SWAG(model_cfg.base,
                  args.subspace, {'max_rank': args.max_num_models},
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

columns = ['swag', 'sample', 'te_loss', 'te_acc', 'ens_loss', 'ens_acc']

if args.warm_start:
    print('estimating initial PI using K-Means')
    kmeans_input = []
    for ckpt_i, ckpt in enumerate(args.swag_ckpts):
        #print("Checkpoint {}".format(ckpt))
        checkpoint = torch.load(ckpt)
        swag_model.subspace.rank = torch.tensor(0)
        swag_model.load_state_dict(checkpoint['state_dict'])
        mean, variance = swag_model._get_mean_and_variance()
Example #8

torch.cuda.manual_seed(args.seed)

print('Using model %s' % args.model)
model_class = getattr(torchvision.models, args.model)

print('Loading ImageNet from %s' % (args.data_path))
loaders, num_classes = data.loaders(
    args.data_path,
    args.batch_size,
    args.num_workers,
)

print('Preparing model')
swag_model = SWAG(model_class,
                  no_cov_mat=not args.cov_mat,
                  loading=True,
                  max_num_models=20,
                  num_classes=num_classes)
swag_model.to(args.device)

criterion = losses.cross_entropy

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
swag_model.load_state_dict(checkpoint['state_dict'])

print('SWA')
swag_model.sample(0.0)
print('SWA BN update')
utils.bn_update(loaders['train'], swag_model, verbose=True, subset=0.1)
print('SWA EVAL')
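sample(0.0) sets the wrapped weights to the SWA mean, and averaged weights invalidate the BatchNorm running statistics accumulated during training, so utils.bn_update re-estimates them with forward passes over (here a 10% subset of) the training loader. A minimal sketch of what such an update does, assuming the standard SWA recipe; this is not the repository's utils.bn_update, which additionally rescales BN momentum so the pass yields an exact average:

import torch

def bn_update_sketch(loader, model, device='cuda'):
    """Reset and re-estimate BatchNorm running statistics."""
    bn_layers = [m for m in model.modules()
                 if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)]
    if not bn_layers:
        return
    for bn in bn_layers:
        bn.reset_running_stats()
    was_training = model.training
    model.train()                      # BN updates running stats only in train mode
    with torch.no_grad():
        for x, _ in loader:
            model(x.to(device))
    model.train(was_training)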
Example #9
    def test_swag_cov(self, **kwargs):
        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            no_cov_mat=False,
            max_num_models=100,
            loading=False,
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            input = torch.randn(100, 300)
            output = model(input)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # check to ensure parameters have the correct sizes
        mean_list = []
        sq_mean_list = []
        cov_mat_sqrt_list = []
        for (module, name), param in zip(swag_model.params, model.parameters()):
            mean = module.__getattr__("%s_mean" % name)
            sq_mean = module.__getattr__("%s_sq_mean" % name)
            cov_mat_sqrt = module.__getattr__("%s_cov_mat_sqrt" % name)

            self.assertEqual(param.size(), mean.size())
            self.assertEqual(param.size(), sq_mean.size())
            self.assertEqual(
                [swag_model.max_num_models, param.numel()], list(cov_mat_sqrt.size())
            )

            mean_list.append(mean)
            sq_mean_list.append(sq_mean)
            cov_mat_sqrt_list.append(cov_mat_sqrt)

        mean = flatten(mean_list).cuda()
        sq_mean = flatten(sq_mean_list).cuda()
        cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1).cuda()

        true_cov_mat = (
            1.0 / (swag_model.max_num_models - 1)
        ) * cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(sq_mean - mean ** 2)

        test_cutoff = chi2(df=mean.numel()).ppf(
            0.95
        )  # 95% quantile of p dimensional chi-square distribution

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            scaled_cov_mat = true_cov_mat * scale
            scaled_cov_inv = torch.inverse(scaled_cov_mat)
            # now test to ensure that sampling has the correct covariance matrix probabilistically
            all_qforms = []
            for _ in range(2000):
                swag_model.sample(scale=scale, cov=True)
                curr_pars = []
                for (module, name) in swag_model.params:
                    curr_pars.append(getattr(module, name))
                dev = flatten(curr_pars) - mean

                # (x - mu)^T Sigma^{-1} (x - mu)
                qform = dev.matmul(scaled_cov_inv).matmul(dev)

                all_qforms.append(qform.item())

            samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
            print(samples_in_cr)

            # between 94 and 96% of the samples should fall within the threshold
            # this should be very loose
            self.assertTrue(1880 <= samples_in_cr <= 1920)
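For reference, true_cov_mat above is the SWAG covariance estimate assembled from the test's own quantities: cov_mat_sqrt stacks the last K = max_num_models deviation rows, so the sampler is checked against

    Sigma = diag(sq_mean - mean^2) + (1 / (K - 1)) * D^T D,    D = cov_mat_sqrt

scaled by each scale value, and the acceptance band 1880-1920 is exactly 94-96% of the 2000 draws.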
Example #10
    criterion = partial(criterion, weight=class_weights)

if args.resume is not None:
    print('Resume training from %s' % args.resume)
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    del checkpoint

if args.swa:
    print('SWAG training')
    swag_model = SWAG(model_cfg.base,
                      no_cov_mat=False,
                      max_num_models=20,
                      *model_cfg.args,
                      num_classes=num_classes,
                      use_aleatoric=args.loss == 'aleatoric',
                      **model_cfg.kwargs)
    swag_model.to(args.device)
else:
    print('SGD training')

if args.swa and args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    swag_model = SWAG(model_cfg.base,
                      no_cov_mat=False,
                      max_num_models=20,
                      *model_cfg.args,
                      num_classes=num_classes,
                      use_aleatoric=args.loss == 'aleatoric',
Example #11
    target_transform=model_cfg.target_transform,
)

# criterion = nn.NLLLoss(weight=camvid.class_weight[:-1].cuda(), reduction='none').cuda()
if args.loss == "cross_entropy":
    criterion = losses.seg_cross_entropy
else:
    criterion = losses.seg_ale_cross_entropy

# construct and load model
if args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    model = SWAG(
        model_cfg.base,
        no_cov_mat=False,
        max_num_models=20,
        num_classes=num_classes,
        use_aleatoric=args.loss == "aleatoric",
    )
    model.cuda()
    model.load_state_dict(checkpoint["state_dict"])

    model.sample(0.0)
    bn_update(loaders["fine_tune"], model)
else:
    model = model_cfg.base(num_classes=num_classes,
                           use_aleatoric=args.loss == "aleatoric").cuda()
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint["epoch"]
    print(start_epoch)
    model.load_state_dict(checkpoint["state_dict"])
Example #12
File: models.py  Project: yyht/drbayes
    def __init__(self,
                 base,
                 epochs,
                 criterion,
                 batch_size=50,
                 lr_init=1e-2,
                 momentum=0.9,
                 wd=1e-4,
                 swag_lr=1e-3,
                 swag_freq=1,
                 swag_start=50,
                 subspace_type='pca',
                 subspace_kwargs={'max_rank': 20},
                 use_cuda=False,
                 use_swag=False,
                 double_bias_lr=False,
                 model_variance=True,
                 num_samples=30,
                 scale=0.5,
                 const_lr=False,
                 *args,
                 **kwargs):

        self.base = base
        self.model = base(*args, **kwargs)
        num_pars = 0
        for p in self.model.parameters():
            num_pars += p.numel()
        print('number of parameters: ', num_pars)

        if use_cuda:
            self.model.cuda()

        if use_swag:
            self.swag_model = SWAG(base,
                                   subspace_type=subspace_type,
                                   subspace_kwargs=subspace_kwargs,
                                   *args,
                                   **kwargs)
            if use_cuda:
                self.swag_model.cuda()
        else:
            self.swag_model = None

        self.use_cuda = use_cuda

        if not double_bias_lr:
            pars = self.model.parameters()
        else:
            pars = []
            for name, param in self.model.named_parameters():
                if 'bias' in name:
                    print('Doubling lr of ', name)
                    pars.append({'params': param, 'lr': 2.0 * lr_init})
                else:
                    pars.append({'params': param, 'lr': lr_init})

        self.optimizer = torch.optim.SGD(pars,
                                         lr=lr_init,
                                         momentum=momentum,
                                         weight_decay=wd)

        self.const_lr = const_lr
        self.batch_size = batch_size

        # TODO: set up criterions better for classification
        if model_variance:
            self.criterion = criterion(noise_var=None)
        else:
            self.criterion = criterion(noise_var=1.0)

        if self.criterion.noise_var is not None:
            self.var = self.criterion.noise_var

        self.epochs = epochs

        self.lr_init = lr_init

        self.use_swag = use_swag
        self.swag_start = swag_start
        self.swag_lr = swag_lr
        self.swag_freq = swag_freq

        self.num_samples = num_samples
        self.scale = scale
Example #13

    args.batch_size,
    args.num_workers,
)

print('Preparing model')
model = model_class(pretrained=args.pretrained, num_classes=num_classes)
model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True
if args.swa:
    print('SWAG training')
    args.swa_device = 'cpu' if args.swa_cpu else args.device
    swag_model = SWAG(model_class, no_cov_mat=args.no_cov_mat, max_num_models=20,
                      num_classes=num_classes)
    swag_model.to(args.swa_device)
    if args.pretrained:
        model.to(args.swa_device)
        swag_model.collect_model(model)
        model.to(args.device)
else:
    print('SGD training')


def schedule(epoch):
    if args.swa and epoch >= args.swa_start:
        return args.swa_lr
    else:
        return args.lr_init * (0.1 ** (epoch // 30))
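This schedule is the standard step decay (x0.1 every 30 epochs) until args.swa_start, after which the learning rate is pinned at args.swa_lr so the SGD iterates keep exploring a flat region for averaging. A quick illustration with assumed values:

lr_init, swa_lr, swa_start = 0.1, 0.02, 61    # illustrative settings

def schedule_demo(epoch, swa=True):
    if swa and epoch >= swa_start:
        return swa_lr
    return lr_init * (0.1 ** (epoch // 30))

print([schedule_demo(e) for e in (0, 29, 30, 60, 61, 90)])
# 0.1, 0.1, 0.01, 0.001 (up to float rounding), then pinned at 0.02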
Example #14

args = parser.parse_args()

args.device = None
if torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

# torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

print('Using model %s' % args.model)
model_class = getattr(torchvision.models, args.model)

num_classes = 1000

print('Preparing model')
model = model_class(pretrained=args.pretrained, num_classes=num_classes)
model.to(args.device)

print('SWAG training')
swag_model = SWAG(model_class,
                  no_cov_mat=False,
                  max_num_models=20,
                  num_classes=num_classes)
swag_model.to('cpu')

for k in range(100):
    swag_model.collect_model(model)
    print(k + 1)
Example #15
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes)

print('Preparing model')
print(*model_cfg.args)

model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.cuda()

swag_model = SWAG(model_cfg.base,
                  num_classes=num_classes,
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': 140,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)


def checkpoint_num(filename):
    num = filename.split("-")[1]
    num = num.split(".")[0]
    num = int(num)
    return num
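checkpoint_num extracts the numeric index from file names of the form prefix-<number>.<extension> (the exact naming pattern is an assumption consistent with the parsing), e.g.:

print(checkpoint_num('checkpoint-00300.pt'))   # -> 300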


for file in os.listdir(args.dir):
Example #16
if torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args,
                       num_classes=args.num_classes,
                       **model_cfg.kwargs)
model.to(args.device)

swag_model = SWAG(model_cfg.base,
                  subspace_type=args.subspace,
                  subspace_kwargs={'max_rank': args.max_num_models},
                  *model_cfg.args,
                  num_classes=args.num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

for path in args.checkpoint:
    print(path)
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['state_dict'])
    swag_model.collect_model(model)

torch.save({'state_dict': swag_model.state_dict()}, args.path)
Example #17
    def test_swag_diag(self, **kwargs):
        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            no_cov_mat=True,
            max_num_models=100,
            loading=False,
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            input = torch.randn(100, 300)
            output = model(input)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # check to ensure parameters have the correct sizes
        mean_list = []
        sq_mean_list = []
        for (module, name), param in zip(swag_model.params, model.parameters()):
            mean = module.__getattr__("%s_mean" % name)
            sq_mean = module.__getattr__("%s_sq_mean" % name)

            self.assertEqual(param.size(), mean.size())
            self.assertEqual(param.size(), sq_mean.size())

            mean_list.append(mean)
            sq_mean_list.append(sq_mean)

        mean = flatten(mean_list).cuda()
        sq_mean = flatten(sq_mean_list).cuda()

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            var = scale * (sq_mean - mean ** 2)

            std = torch.sqrt(var)
            dist = torch.distributions.Normal(mean, std)

            # now test that each coordinate's marginal matches the diagonal SWAG posterior
            all_qforms = 0
            for _ in range(20):
                swag_model.sample(scale=scale, cov=False)
                curr_pars = []
                for (module, name) in swag_model.params:
                    curr_pars.append(getattr(module, name))

                curr_probs = dist.cdf(flatten(curr_pars))

                # check if within 95% CI
                num_in_cr = ((curr_probs > 0.025) & (curr_probs < 0.975)).float().sum()
                # all_qforms.append( num_in_cr )
                all_qforms += num_in_cr

            # print(all_qforms/(20 * mean.numel()))
            # now compute average
            avg_prob_in_cr = all_qforms / (20 * mean.numel())

            # CLT should hold a bit tighter here
            self.assertTrue(0.945 <= avg_prob_in_cr <= 0.955)
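This diagonal variant checks coverage coordinate-wise: if a drawn coordinate really is N(mean_i, scale * (sq_mean_i - mean_i^2)), then dist.cdf maps it to Uniform(0, 1), so on average 95% of the 903 parameters of the Linear(300, 3) model land in (0.025, 0.975), and averaging over 20 draws tightens the band to 0.945-0.955 by the CLT. The same check in miniature (values are illustrative):

import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
mean, std = 1.5, 0.3
draws = rng.normal(mean, std, size=(20, 903))      # 20 draws x 903 coordinates
probs = norm(mean, std).cdf(draws)
print(((probs > 0.025) & (probs < 0.975)).mean())  # close to 0.95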