def make_group_generator():
    # Note that this Variable is NOT going to show up in `net.parameters()` and
    # therefore it is implicitly free from the ridge penalty/p(theta) prior.
    log_sigma = Variable(
        torch.log(1e-2 * torch.ones(image_size)),
        requires_grad=True
    )
    return NormalNet(
        mu_net=torch.nn.Linear(group_input_dim, image_size),
        sigma_net=Lambda(
            lambda x, log_sigma: torch.exp(log_sigma.expand(x.size(0), -1)) + 1e-3,
            extra_args=(log_sigma,)
        )
    )
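# --- Hedged sketch, not from the original source: minimal stand-ins for the
# `NormalNet` and `Lambda` helpers that the fragments in this section assume.
# As used above, `Lambda(fn, extra_args)` wraps a plain function (plus fixed
# extra arguments) as an nn.Module, and `NormalNet(mu_net, sigma_net)` bundles
# two sub-networks mapping the same input to the mean and standard deviation of
# a diagonal Gaussian. The real definitions in the codebase may differ.
import torch
from torch import nn
from torch.autograd import Variable
from torch.distributions import Normal


class Lambda(nn.Module):
    def __init__(self, fn, extra_args=()):
        super().__init__()
        self.fn = fn
        self.extra_args = extra_args

    def forward(self, x):
        return self.fn(x, *self.extra_args)


class NormalNet(nn.Module):
    def __init__(self, mu_net, sigma_net):
        super().__init__()
        self.mu_net = mu_net
        self.sigma_net = sigma_net

    def forward(self, x):
        # Diagonal Gaussian parameterized by the two sub-networks.
        return Normal(self.mu_net(x), self.sigma_net(x))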
# Fragment begins mid-construction of the synthetic bar-pattern data: `i` and
# `data` come from an enclosing loop over bar positions that is not shown here
# (a hedged reconstruction follows after this fragment).
A = torch.zeros(image_size, image_size)
A[i, :] = 0.5                  # one half-intensity horizontal bar at row i
A = torch.Tensor(A)
data.append(A.view(-1))        # flatten the image to a vector

temp = torch.stack([data[i] for i in range(len(data))])
X = torch.stack([temp for _ in range(num_samples)])   # replicate the bar patterns
X += 0.05 * torch.randn(X.size())                     # add Gaussian pixel noise
X = X.transpose(0, 1)

stddev_multiple = 0.1

inference_net = NormalNet(
    mu_net=nn.Sequential(nn.Linear(dim_h + dim_h, dim_z)),
    sigma_net=torch.nn.Sequential(
        nn.Linear(dim_h + dim_h, dim_z),
        Lambda(torch.exp),
        Lambda(lambda x: x * stddev_multiple + 1e-3)
    )
)


def make_group_generator():
    # Note that this Variable is NOT going to show up in `net.parameters()` and
    # therefore it is implicitly free from the ridge penalty/p(theta) prior.
    log_sigma = Variable(
        torch.log(1e-2 * torch.ones(image_size)),
        requires_grad=True
    )
    return NormalNet(
        mu_net=torch.nn.Sequential(
            torch.nn.Tanh(),
            torch.nn.Linear(group_input_dim, image_size)
        ),
        sigma_net=Lambda(
            lambda x, log_sigma: torch.exp(log_sigma.expand(x.size(0), -1)) + 1e-3,
            extra_args=(log_sigma,)
        )
    )
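# --- Hedged sketch, an assumption rather than the original source: a minimal
# reconstruction of the enclosing loop that the data-construction lines at the
# top of the previous fragment appear to come from, building one single-bar
# image per row index `i`. The concrete values of `image_size` and
# `num_samples` are hypothetical.
image_size = 8
num_samples = 1024

data = []
for i in range(image_size):
    A = torch.zeros(image_size, image_size)
    A[i, :] = 0.5              # half-intensity horizontal bar at row i
    data.append(A.view(-1))    # flattened to a vector of length image_size ** 2
# After the stacking and transpose in the fragment above, X has shape
# (image_size, num_samples, image_size ** 2): one noisy batch per bar pattern.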
# `inference_net` below. Zero means that the model has no actual connection to
# the output and therefore the standard deviation defaults to the minimum. One
# means that we're learning the real model. This value is flipped to 1 after
# some number of iterations.
stddev_multiple = 0.1

inference_net = NormalNet(
    mu_net=torch.nn.Sequential(
        # inference_net_base,
        torch.nn.Linear(dim_x, dim_z)),

    # Fixed standard deviation
    # sigma_net=Lambda(lambda x: 1e-3 * Variable(torch.ones(x.size(0), dim_z)))

    # Learned constant standard deviation
    # sigma_net=Lambda(
    #     lambda x: torch.exp(inference_net_log_stddev.expand(x.size(0), -1)) + 1e-3
    # )

    # Learned standard deviation as a function of the input
    sigma_net=torch.nn.Sequential(
        # inference_net_base,
        torch.nn.Linear(dim_x, dim_z),
        Lambda(torch.exp),
        Lambda(lambda x: x * stddev_multiple + 1e-3)))


def make_group_generator(group_output_dim):
    # Note that this Variable is NOT going to show up in `net.parameters()` and
    # therefore it is implicitly free from the ridge penalty/p(theta) prior.
    log_sigma = Variable(torch.log(1e-2 * torch.ones(group_output_dim).type(
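# --- Hedged sketch, an assumption rather than the original source: the comment
# above says `stddev_multiple` is flipped to 1 after some number of iterations.
# A minimal, module-level version of that warm-up schedule could look like the
# following; `num_iterations` and `warmup_iterations` are hypothetical.
num_iterations = 1000
warmup_iterations = 100

for iteration in range(num_iterations):
    if iteration == warmup_iterations:
        # Rebinding the module-level name that the sigma_net lambda closes over
        # hands the learned standard deviations their full scale from here on.
        stddev_multiple = 1.0
    # ... one optimization step on the ELBO would go here ...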
    shuffle=True
)

dim_x = metrain.size(2)
num_groups = len(groups)
# group_dims = 196
stddev_multiple = 1
group_dims = [196 for grp in groups]

inference_net = NormalNet(
    mu_net=torch.nn.Sequential(
        # inference_net_base,
        torch.nn.Linear(dim_h + dim_h, dim_z)
    ),
    sigma_net=torch.nn.Sequential(
        # inference_net_base,
        torch.nn.Linear(dim_h + dim_h, dim_z),
        Lambda(torch.exp),
        Lambda(lambda x: x * stddev_multiple + 1e-3)
    )
)


def make_group_generator(output_dim):
    # Note that this Variable is NOT going to show up in `net.parameters()` and
    # therefore it is implicitly free from the ridge penalty/p(theta) prior.
    log_sigma = Variable(
        torch.log(1e-2 * torch.ones(output_dim)),
        requires_grad=True
    )
    return NormalNet(
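# --- Hedged usage sketch, an assumption rather than the original source:
# assuming the truncated `make_group_generator` above is completed analogously
# to the earlier fragments, a natural next step is one generator per entry of
# `group_dims`, each queried with a latent code. Feeding every generator the
# full `dim_z`-dimensional code is purely illustrative; the real model may
# transform the code per group first, and the batch size of 16 is hypothetical.
group_generators = [make_group_generator(d) for d in group_dims]

z = torch.randn(16, dim_z)
per_group = [gen(z) for gen in group_generators]        # one Normal per group
x_mean = torch.cat([p.mean for p in per_group], dim=1)  # (16, sum(group_dims))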