import numpy as np
import torch
from scipy.spatial.distance import cdist

# Project-local dependencies (exact module paths depend on the repository layout):
# utils.get_means, SSLGaussMixture, TabularResidualFlow, RealNVPTabular


def ResidualTabularWPrior(num_classes,
                          dim_in,
                          coupling_layers,
                          k,
                          means_r=1.,
                          cov_std=1.,
                          nperlayer=1,
                          acc=0.9):
    """Build a residual flow for tabular data and attach a Gaussian-mixture
    prior whose class means are spaced according to the target accuracy."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inv_cov_std = torch.ones((num_classes, ), device=device) / cov_std
    model = TabularResidualFlow(in_dim=dim_in,
                                hidden_dim=k,
                                num_per_block=coupling_layers)
    # Choose the separation d so that exp(-d^2 / 8), a standard upper bound on the
    # Bayes error of two unit-variance Gaussians placed d apart, equals 1 - acc.
    dist_scaling = np.sqrt(-8 * np.log(1 - acc))
    means = utils.get_means('random',
                            r=means_r * dist_scaling,
                            num_means=num_classes,
                            trainloader=None,
                            shape=(dim_in),
                            device=device)
    # Place the first two means symmetrically at +/- dist_scaling / 2 along a
    # random direction, so they end up exactly dist_scaling apart.
    means[0] /= means[0].norm()
    means[0] *= dist_scaling / 2
    means[1] = -means[0]
    model.prior = SSLGaussMixture(means, inv_cov_std, device=device)
    return model
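
For reference, the relation between the `acc` argument and the resulting mean separation can be checked numerically; this is just the inverse of 1 - acc = exp(-d**2 / 8) used above:

import numpy as np

for acc in (0.5, 0.9, 0.99):
    d = np.sqrt(-8 * np.log(1 - acc))
    # exp(-d**2 / 8) recovers 1 - acc, confirming the spacing rule.
    print(f"acc={acc:.2f} -> separation d={d:.3f}, exp(-d^2/8)={np.exp(-d**2 / 8):.3f}")
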
def RealNVPTabularWPrior(num_classes,
                         dim_in,
                         coupling_layers,
                         k,
                         means_r=.8,
                         cov_std=1.,
                         nperlayer=1,
                         acc=0.9):
    """Build a RealNVP flow for tabular data and attach a Gaussian-mixture prior."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inv_cov_std = torch.ones((num_classes, ), device=device) / cov_std
    model = RealNVPTabular(num_coupling_layers=coupling_layers,
                           in_dim=dim_in,
                           hidden_dim=k,
                           num_layers=1,
                           dropout=True)
    if num_classes == 2:
        means = utils.get_means('random',
                                r=means_r,
                                num_means=num_classes,
                                trainloader=None,
                                shape=(dim_in),
                                device=device)
        # Rescale the two means to +/- 3.75 along the sampled direction,
        # i.e. a fixed separation of 7.5 between the two classes.
        dist = 2 * (means[0]**2).sum().sqrt()
        means[0] *= 7.5 / dist
        means[1] = -means[0]
    else:
        means = utils.get_means('random',
                                r=means_r * .7,
                                num_means=num_classes,
                                trainloader=None,
                                shape=(dim_in),
                                device=device)
    model.prior = SSLGaussMixture(means, inv_cov_std, device=device)
    means_np = means.cpu().numpy()
    print("Means :", means_np)
    print("Pairwise dists:", cdist(means_np, means_np))
    return model
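
A hypothetical instantiation of the two-class branch (the dimensions are made up, and it is assumed that SSLGaussMixture exposes the means it was built from as a `.means` attribute):

model = RealNVPTabularWPrior(num_classes=2, dim_in=20, coupling_layers=5, k=128)
# The two-class branch pins the means at +/- 3.75 along a random direction,
# so this should print a separation of roughly 7.5 (assumed .means attribute).
print((model.prior.means[0] - model.prior.means[1]).norm())
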
Example 3
# Training-script fragment: set up the GMM prior over the flow's latent space,
# optionally with learnable means, then build the optimizer. Assumed to run inside
# a script that already defines args, checkpoint, net, writer, device, img_shape,
# start_epoch, cov_std and means, and imports torch, torchvision,
# torch.optim as optim, cdist, utils, SSLGaussMixture and FlowLoss.
if args.resume is not None:
    print("Using the means from the checkpoint")
    means = checkpoint['means']

print("Means:", means)
print("Cov std:", cov_std)
means_np = means.cpu().numpy()
print("Pairwise dists:", cdist(means_np, means_np))

if args.means_trainable:
    print("Using learnable means")
    # Re-create the means as a leaf tensor so gradients flow to them;
    # keep them on the same device as the rest of the model.
    means = torch.tensor(means_np, requires_grad=True, device=device)
# Visualize the class means as an image grid (assumes 10 classes of shape img_shape).
mean_imgs = torchvision.utils.make_grid(means.reshape((10, *img_shape)),
                                        nrow=5)
writer.add_image("means", mean_imgs)
prior = SSLGaussMixture(means, device=device)
loss_fn = FlowLoss(prior)

# PAVEL: check why we need this
param_groups = utils.get_param_groups(net,
                                      args.weight_decay,
                                      norm_suffix='weight_g')
if args.means_trainable:
    param_groups.append({'name': 'means', 'params': means})

if args.optimizer == "SGD":
    optimizer = optim.SGD(param_groups, lr=args.lr)
elif args.optimizer == "Adam":
    optimizer = optim.Adam(param_groups, lr=args.lr)
else:
    raise ValueError(f"Unsupported optimizer: {args.optimizer}")

for epoch in range(start_epoch, args.num_epochs):
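    # (Loop body not included in this snippet.) A minimal sketch of one iteration,
    # assuming the usual RealNVP-style interface in which the flow returns latent
    # codes z together with the summed log-det sldj, FlowLoss is called as
    # loss_fn(z, sldj), and `trainloader` is a hypothetical unlabeled DataLoader.
    for x, _ in trainloader:
        x = x.to(device)
        optimizer.zero_grad()
        z, sldj = net(x)          # assumed signature: flow -> (latents, log|det J|)
        loss = loss_fn(z, sldj)   # negative log-likelihood under the GMM prior
        loss.backward()
        optimizer.step()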