def ResidualTabularWPrior(num_classes, dim_in, coupling_layers, k,
                          means_r=1., cov_std=1., nperlayer=1, acc=0.9):
    """Build a TabularResidualFlow with an SSL Gaussian-mixture prior.

    Args:
        num_classes: number of mixture components (one mean per class).
        dim_in: dimensionality of the tabular input features.
        coupling_layers: number of residual blocks in the flow.
        k: hidden width of each block.
        means_r: radius multiplier for the randomly drawn means.
        cov_std: per-component covariance std; stored as its inverse.
        nperlayer: unused here — kept for signature compatibility with callers.
        acc: target binary accuracy used to derive the mean separation.

    Returns:
        The flow model with ``model.prior`` set to the mixture prior.
    """
    # NOTE(review): device is hard-coded to CUDA, so this crashes on
    # CPU-only machines — consider threading a device argument through.
    device = torch.device('cuda')
    inv_cov_std = torch.ones((num_classes,), device=device) / cov_std
    model = TabularResidualFlow(in_dim=dim_in, hidden_dim=k,
                                num_per_block=coupling_layers)
    # Separation between two unit-covariance Gaussians at which a Bayes
    # classifier reaches roughly `acc`: d = sqrt(-8 * ln(1 - acc)).
    dist_scaling = np.sqrt(-8 * np.log(1 - acc))
    # NOTE(review): shape=(dim_in) is just the int dim_in, not a 1-tuple —
    # presumably utils.get_means accepts either; verify before changing.
    means = utils.get_means('random', r=means_r * dist_scaling,
                            num_means=num_classes, trainloader=None,
                            shape=(dim_in), device=device)
    # Place the first two means symmetrically at +/- dist_scaling/2 along a
    # random unit direction (binary-classification layout).
    means[0] /= means[0].norm()
    means[0] *= dist_scaling / 2
    means[1] = -means[0]
    model.prior = SSLGaussMixture(means, inv_cov_std, device=device)
    # (dead `means_np = means.cpu().numpy()` removed: its only consumer was
    # a commented-out debug print)
    return model
def RealNVPTabularWPrior(num_classes, dim_in, coupling_layers, k,
                         means_r=.8, cov_std=1., nperlayer=1, acc=0.9):
    """Construct a RealNVPTabular flow whose prior is an SSL Gaussian mixture.

    One random mean is drawn per class. In the binary case the two means are
    rescaled to sit symmetrically about the origin; otherwise the raw random
    means are used with a slightly shrunk radius. ``nperlayer`` and ``acc``
    are accepted for signature compatibility and not used here.
    """
    device = torch.device('cuda')
    inv_cov_std = torch.ones((num_classes,), device=device) / cov_std
    model = RealNVPTabular(num_coupling_layers=coupling_layers, in_dim=dim_in,
                           hidden_dim=k, num_layers=1, dropout=True)
    if num_classes == 2:
        means = utils.get_means('random', r=means_r, num_means=num_classes,
                                trainloader=None, shape=(dim_in),
                                device=device)
        # Rescale the first mean so the mirrored pair is exactly 7.5 apart.
        separation = 2 * (means[0] ** 2).sum().sqrt()
        means[0] *= 7.5 / separation
        means[1] = -means[0]
    else:
        means = utils.get_means('random', r=means_r * .7,
                                num_means=num_classes, trainloader=None,
                                shape=(dim_in), device=device)
    model.prior = SSLGaussMixture(means, inv_cov_std, device=device)
    means_host = means.cpu().numpy()
    print("Means :", means_host)
    print("Pairwise dists:", cdist(means_host, means_host))
    return model
# Resume path: restore the mixture means from the checkpoint instead of
# using freshly sampled ones.
# NOTE(review): this fragment references `checkpoint`, `cov_std`,
# `img_shape`, `writer`, `net`, `device`, `start_epoch` and `args`, which
# must be defined earlier in the file (not visible in this chunk). The
# statement nesting below was reconstructed from a collapsed line —
# confirm against the original file.
if args.resume is not None:
    print("Using the means for ckpt")
    means = checkpoint['means']
print("Means:", means)
print("Cov std:", cov_std)
means_np = means.cpu().numpy()
print("Pairwise dists:", cdist(means_np, means_np))
if args.means_trainable:
    print("Using learnable means")
    # NOTE(review): rebuilds the means from the numpy copy with gradients
    # enabled; this places them on CPU regardless of `device`, and
    # torch.tensor(ndarray) copies — confirm both are intended.
    means = torch.tensor(means_np, requires_grad=True)
# Log the 10 mean "images" to TensorBoard as a 2x5 grid.
# NOTE(review): hard-codes 10 classes via the reshape — verify num_classes.
mean_imgs = torchvision.utils.make_grid(means.reshape((10, *img_shape)), nrow=5)
writer.add_image("means", mean_imgs)
prior = SSLGaussMixture(means, device=device)
loss_fn = FlowLoss(prior)
#PAVEL: check why do we need this
param_groups = utils.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
if args.means_trainable:
    # Make the means themselves a trainable parameter group.
    param_groups.append({'name': 'means', 'params': means})
if args.optimizer == "SGD":
    optimizer = optim.SGD(param_groups, lr=args.lr)
elif args.optimizer == "Adam":
    optimizer = optim.Adam(param_groups, lr=args.lr)
# NOTE(review): any other args.optimizer value silently leaves `optimizer`
# undefined — later use would raise NameError.
# Training loop — its body continues beyond this chunk.
for epoch in range(start_epoch, args.num_epochs):