def guide():
    loc = pyro.param('loc', torch.tensor(0.))
    # note: 'scale' is unconstrained here and can go negative during
    # optimization; see the constrained variant of this guide below
    scale = pyro.param('scale', torch.tensor(0.5))
    pyro.sample('latent', Normal(loc, scale))
def guide(self, data):
    encoder = pyro.module("encoder", self.vae_encoder)
    with pyro.plate("data", data.size(0)):
        z_mean, z_var = encoder.forward(data)
        pyro.sample("latent", Normal(z_mean, z_var.sqrt()).to_event(1))
def model(data):
    decay = pyro.sample("decay", dist.Normal(0.0, 10.0))
    # each word's contribution decays as a power law of its distance: d ** decay
    distances = range(1, 36)
    contributions = [torch.exp(decay * log(d)) for d in distances]
    byParticipant_Intercept_logVariance = pyro.sample(
        "byParticipant_Intercept_logVariance", byParticipantIntercept_Variance_Prior)
    byItem_Intercept_logVariance = pyro.sample(
        "byItem_Intercept_logVariance", byItemIntercept_Variance_Prior)
    byParticipant_SurprisalSlope_logVariance = pyro.sample(
        "byParticipant_SurprisalSlope_logVariance", byParticipantSlope_Variance_Prior)
    byParticipant_LogWordFreqSlope_logVariance = pyro.sample(
        "byParticipant_LogWordFreqSlope_logVariance", byParticipantSlope_Variance_Prior)
    byItem_Slope_logVariance = pyro.sample(
        "byItem_Slope_logVariance", byItemSlope_Variance_Prior)
    alpha = pyro.sample("alpha", alpha_Prior)
    beta1 = pyro.sample("beta1", beta1_Prior)
    kappa = pyro.sample("kappa", kappa_Prior)
    log_sigma = pyro.sample("log_sigma", log_sigma_Prior)
    sigma = logExp(log_sigma)
    # random effects (torch.autograd.Variable is a deprecated no-op wrapper,
    # so plain tensors are used here)
    byParticipant_Intercept = pyro.sample(
        "byParticipant_Intercept",
        Normal(torch.zeros(L),
               logExp(byParticipant_Intercept_logVariance) * torch.ones(L)))
    byItem_Intercept = pyro.sample(
        "byItem_Intercept",
        Normal(torch.zeros(M),
               logExp(byItem_Intercept_logVariance) * torch.ones(M)))
    byParticipant_SurprisalSlope = pyro.sample(
        "byParticipant_SurprisalSlope",
        Normal(torch.zeros(L),
               logExp(byParticipant_SurprisalSlope_logVariance) * torch.ones(L)))
    byParticipant_LogWordFreqSlope = pyro.sample(
        "byParticipant_LogWordFreqSlope",
        Normal(torch.zeros(L),
               logExp(byParticipant_LogWordFreqSlope_logVariance) * torch.ones(L)))
    for q in pyro.irange("data_loop", len(data), subsample_size=50):
        point = data[q]
        participant = int(point[header["WorkerId.Renumbered"]]) - 1
        assert 0 <= participant < L
        item = int(point[header["tokenID.Renumbered"]]) - 1
        surprisals = torch.FloatTensor(
            [point[header["Increment" + str(x)]] for x in range(0, 35)])
        effectiveSurprisal = sum(x * y for x, y in zip(surprisals, contributions))
        mean = alpha
        mean = mean + byParticipant_Intercept[participant]
        mean = mean + byItem_Intercept[item]
        mean = mean + (beta1 + byParticipant_LogWordFreqSlope[participant]) \
            * point[header["Surprisal0"]]
        mean = mean + (kappa + byParticipant_SurprisalSlope[participant]) \
            * effectiveSurprisal
        pyro.sample("time_{}".format(q),
                    dist.Normal(mean, sigma),
                    obs=torch.FloatTensor([log(point[header["RT"]])]))
        # occasional debug print (~1% of data points)
        if random.random() > 0.99:
            print(surprisals)
            print([mean, log(point[header["RT"]])])
def model():
    prior_dist = Normal(self.loc0, torch.pow(self.lam0, -0.5))
    loc_latent = pyro.sample("loc_latent", prior_dist)
    x_dist = Normal(loc_latent, torch.pow(self.lam, -0.5))
    pyro.sample("obs", x_dist, obs=self.data)
    return loc_latent
def guide():
    loc = pyro.param('loc', torch.tensor(0.))
    scale = pyro.param('scale', torch.tensor(0.5),
                       constraint=constraints.positive)
    pyro.sample('latent', Normal(loc, scale))
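# A minimal sketch of how a guide like the one above is typically used: pair it
# with a model that exposes the same 'latent' site and train with SVI. The toy
# model, optimizer settings, and step count below are illustrative assumptions,
# not part of the snippets in this section.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def toy_model():
    # same site name as in the guide above
    latent = pyro.sample('latent', dist.Normal(torch.tensor(0.), torch.tensor(1.)))
    pyro.sample('obs', dist.Normal(latent, torch.tensor(0.2)), obs=torch.tensor(0.1))

svi = SVI(toy_model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
for step in range(1000):
    loss = svi.step()  # one ELBO gradient step; updates the 'loc' and 'scale' params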
# take a gradient step
optim.step()
if (j + 1) % 50 == 0:
    print("[iteration %04d] loss: %.4f" % (j + 1, loss.item()))

# Inspect learned parameters
print("Learned parameters:")
for name, param in net.named_parameters():
    print(name, param.data.numpy())

main()

loc = torch.zeros(1, 1)
scale = torch.ones(1, 1)
# define a unit normal prior
prior = Normal(loc, scale)
# overload the parameters in the regression module with samples from the prior
lifted_module = pyro.random_module("regression_module", net, prior)
# sample a nn from the prior
sampled_reg_model = lifted_module()

def model(x_data, y_data):
    # weight and bias priors
    fc1w_prior = Normal(torch.zeros(1, 2), torch.ones(1, 2)).to_event(1)
    fc1b_prior = Normal(torch.tensor([[8.]]),
                        torch.tensor([[1000.]])).to_event(1)
    outw_prior = Normal(loc=torch.zeros_like(net.out.weight),
                        scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias),
                        scale=torch.ones_like(net.out.bias))
def vnormal(name, *shape):
    loc = pyro.param(name + "m",
                     torch.randn(*shape, requires_grad=True, device=device))
    scale = pyro.param(name + "s",
                       torch.randn(*shape, requires_grad=True, device=device))
    return Normal(loc, softplus(scale))
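# A minimal usage sketch for vnormal, assuming `device` is defined and
# `softplus` is torch.nn.functional.softplus as the helper above expects:
# inside a guide, it can back a mean-field factor for a hypothetical
# 3x4 weight site "w".
w_dist = vnormal("w", 3, 4)               # registers params "wm" and "ws"
w = pyro.sample("w", w_dist.to_event(2))  # shape [3, 4]; both dims are event dims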
def model():
    z = pyro.sample("z", Normal(10.0 * torch.ones(1), 0.0001 * torch.ones(1)))
    # equivalent to torch.sigmoid(z)
    latent_prob = torch.exp(z) / (torch.exp(z) + torch.ones(1))
    flip = pyro.sample("flip", Bernoulli(latent_prob))
    return flip
def pgm_model(self):
    sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)
    # pseudo call to register with pyro
    _ = self.sex_logits
    sex = pyro.sample('sex', sex_dist,
                      infer=dict(baseline={'use_decaying_avg_baseline': True}))

    slice_number_dist = Uniform(self.slice_number_min,
                                self.slice_number_max).to_event(1)
    slice_number = pyro.sample('slice_number', slice_number_dist)

    age_base_dist = Normal(self.age_base_loc, self.age_base_scale).to_event(1)
    age_dist = TransformedDistribution(age_base_dist, self.age_flow_transforms)
    age = pyro.sample('age', age_dist)
    _ = self.age_flow_components
    age_ = self.age_flow_constraint_transforms.inv(age)

    duration_context = torch.cat([sex, age_], 1)
    duration_base_dist = Normal(self.duration_base_loc,
                                self.duration_base_scale).to_event(1)
    duration_dist = ConditionalTransformedDistribution(
        duration_base_dist,
        self.duration_flow_transforms).condition(duration_context)
    duration = pyro.sample('duration', duration_dist)
    _ = self.duration_flow_components
    duration_ = self.duration_flow_constraint_transforms.inv(duration)

    edss_context = torch.cat([sex, duration_], 1)
    edss_base_dist = Normal(self.edss_base_loc, self.edss_base_scale).to_event(1)
    edss_dist = ConditionalTransformedDistribution(
        edss_base_dist, self.edss_flow_transforms).condition(edss_context)
    edss = pyro.sample('edss', edss_dist)
    _ = self.edss_flow_components
    edss_ = self.edss_flow_constraint_transforms.inv(edss)

    brain_context = torch.cat([sex, age_], 1)
    brain_volume_base_dist = Normal(self.brain_volume_base_loc,
                                    self.brain_volume_base_scale).to_event(1)
    brain_volume_dist = ConditionalTransformedDistribution(
        brain_volume_base_dist,
        self.brain_volume_flow_transforms).condition(brain_context)
    brain_volume = pyro.sample('brain_volume', brain_volume_dist)
    _ = self.brain_volume_flow_components
    brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(brain_volume)

    ventricle_context = torch.cat([age_, brain_volume_, duration_], 1)
    ventricle_volume_base_dist = Normal(
        self.ventricle_volume_base_loc,
        self.ventricle_volume_base_scale).to_event(1)
    ventricle_volume_dist = ConditionalTransformedDistribution(
        ventricle_volume_base_dist,
        self.ventricle_volume_flow_transforms).condition(ventricle_context)
    ventricle_volume = pyro.sample('ventricle_volume', ventricle_volume_dist)
    _ = self.ventricle_volume_flow_components
    ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
        ventricle_volume)

    lesion_context = torch.cat(
        [brain_volume_, ventricle_volume_, duration_, edss_], 1)
    lesion_volume_base_dist = Normal(self.lesion_volume_base_loc,
                                     self.lesion_volume_base_scale).to_event(1)
    lesion_volume_dist = ConditionalTransformedDistribution(
        lesion_volume_base_dist,
        self.lesion_volume_flow_transforms).condition(lesion_context)
    lesion_volume = pyro.sample('lesion_volume', lesion_volume_dist)
    _ = self.lesion_volume_flow_components

    return dict(age=age, sex=sex, ventricle_volume=ventricle_volume,
                brain_volume=brain_volume, lesion_volume=lesion_volume,
                duration=duration, edss=edss, slice_number=slice_number)
def scale1_prior(tensor, *args, **kwargs):
    flat_tensor = tensor.view(-1)
    m = torch.zeros(flat_tensor.size(0))
    s = torch.ones(flat_tensor.size(0))
    return Normal(m, s).sample().view(tensor.size()).exp()
def stoch_fn(tensor, *args, **kwargs):
    loc = torch.zeros(tensor.size())
    scale = torch.ones(tensor.size())
    return pyro.sample("sample", Normal(loc, scale))
def normal_like(X):
    # Each scalar in X is modeled with this standard Normal distribution.
    return Normal(loc=0, scale=1).expand(X.shape)
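# Usage sketch (assuming torch and the Normal used above are in scope):
# normal_like gives a standard-normal distribution whose batch shape matches X,
# so a single .sample() draws a tensor shaped like X.
X = torch.randn(4, 3)
eps = normal_like(X).sample()  # eps.shape == torch.Size([4, 3])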
def p_Z1(self, B, X0, A0):
    inp_cat = torch.cat([B, X0, A0], -1)
    mu = self.prior_W(inp_cat)
    sigma = torch.nn.functional.softplus(self.prior_sigma(inp_cat))
    p_z_bxa = Independent(Normal(mu, sigma), 1)
    return p_z_bxa
from torch import tensor
from torch.distributions.constraints import positive

import pyro
from pyro.distributions import Normal

# pyro.set_rng_seed(101)

# note: despite their names, the *_variance params below are passed to Normal
# as its scale argument, i.e. they act as standard deviations
prior_weight = pyro.param("prior_weight", tensor(60.0))
weight_variance = pyro.param("weight_variance", tensor(1.0), constraint=positive)
weight = pyro.sample("weight", Normal(prior_weight, weight_variance))

prior_bmr = pyro.param("prior_bmr", tensor(1600.0))
bmr_variance = pyro.param("bmr_variance", tensor(50.0), constraint=positive)
logging_variance = pyro.param("logging_variance", tensor(250.0),
                              constraint=positive)
bmr = pyro.sample("bmr", Normal(prior_bmr, bmr_variance))

cal_weight_fac = pyro.param("cal_weight_fac", tensor(1 / 2000.0))
consumed_calories = pyro.sample("consumed_calories",
                                Normal(prior_bmr, logging_variance))

time_step = 1.0  # time is in days
posterior_weight = weight + time_step * (consumed_calories - bmr) * cal_weight_fac
print(posterior_weight)
def guide(self, x, y):
    pyro.module("encoder", self.encoder)
    with pyro.plate("data", x.shape[0]):
        loc_z, scale_z = self.encoder(x, y)
        pyro.sample("latent", Normal(loc_z, scale_z).to_event(1))
def model():
    loc = pyro.sample("loc", Normal(torch.zeros(1), torch.ones(1)))
    xd = Normal(loc, torch.ones(1))
    pyro.sample("xs", xd, obs=self.data)
    return loc
def reconstruct_image(self, x, y):
    loc_z, scale_z = self.encoder(x, y)
    z = Normal(loc_z, scale_z).sample()
    loc_img = self.decoder(z, y)
    return loc_img
def guide():
    return pyro.sample("loc", Normal(torch.zeros(1), torch.ones(1)))
def normal(*shape):
    loc = torch.zeros(*shape).to(device)
    scale = torch.ones(*shape).to(device)
    return Normal(loc, scale)
def guide(mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):
    # register all PyTorch (sub)modules with pyro
    # (the original registered "c_lin_hidden_to_scale " with a stray trailing
    # space in the name, and registered "rnn" twice)
    pyro.module("c_lin_z_to_hidden", c_lin_z_to_hidden)
    pyro.module("c_lin_hidden_to_loc", c_lin_hidden_to_loc)
    pyro.module("c_lin_hidden_to_scale", c_lin_hidden_to_scale)
    pyro.module("rnn", rnn)

    if debug:
        print("===== guide:S =====")
        print("mini_batch:\t type={}, shape={}".format(
            type(mini_batch), mini_batch.size()))
        print("mini_batch_reversed:\t type={}, shape={}".format(
            type(mini_batch_reversed), mini_batch_reversed.size()))
        print("mini_batch_mask:\t type={}, shape={}".format(
            type(mini_batch_mask), mini_batch_mask.size()))
        print("mini_batch_seq_lengths:\t type={}, shape={}".format(
            type(mini_batch_seq_lengths), mini_batch_seq_lengths.size()))
        print("===== guide:E =====")

    #===== init tensor shape
    mini_batch = torch.reshape(mini_batch, [20, 160, 88])
    mini_batch_reversed = torch.reshape(mini_batch_reversed, [20, 160, 88])
    mini_batch_mask = torch.reshape(mini_batch_mask, [20, 160])
    mini_batch_seq_lengths = torch.reshape(mini_batch_seq_lengths, [20])
    #===== init tensor shape

    # number of time steps to process in the mini-batch:
    # T_max = mini_batch.size(1) = 160, hard-coded below

    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    # h_0_contig = h_0.expand(1, mini_batch.size(0), rnn.hidden_size).contiguous()
    h_0_contig = torch.Tensor.expand(h_0, [1, 20, 600])

    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = rnn(mini_batch_reversed, h_0_contig)

    # reverse the time-ordering in the hidden state and un-pack it
    # (inlined from poly.pad_and_reverse; PackedSequence is deliberately
    # not used anywhere in this program)
    #======= reverse_sequences, shape = [20, 160, 600]
    _mini_batch = rnn_output
    _seq_lengths = mini_batch_seq_lengths
    reversed_mini_batch = torch.zeros(20, 160, 600)
    for b in range(20):
        T = _seq_lengths[b]
        time_slice = torch.arange(T - 1, -1, -1)
        reversed_sequence = torch.index_select(_mini_batch[b, :, :], 0, time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence
    rnn_output = reversed_mini_batch
    #======= reverse_sequences

    # set z_prev = z_q_0 to set up the recursive conditioning in q(z_t | ...)
    # z_prev = z_q_0.expand(mini_batch.size(0), z_q_0.size(0))
    z_prev = torch.Tensor.expand(z_q_0, [20, 100])

    # we enclose all the sample statements in the guide in a plate.
    # this marks that each datapoint is conditionally independent of the others.
# with pyro.plate("z_minibatch", len(mini_batch)): with pyro.plate("z_minibatch", 20): # sample the latents z one time step at a time # # for t in range(1, T_max + 1): # for t in range(T_max): for t in range(160): # h_rnn = rnn_output[:, t - 1, :] h_rnn = rnn_output[:, t, :] h_combined = 0.5 * (c_tanh(c_lin_z_to_hidden(z_prev)) + h_rnn) loc = c_lin_hidden_to_loc(h_combined) scale = c_softplus(c_lin_hidden_to_scale(h_combined)) z_loc = loc z_scale = scale z_dist = Normal(z_loc, z_scale) # assert z_dist.event_shape == () # assert z_dist.batch_shape == (len(mini_batch), z_q_0.size(0)) # sample z_t from the distribution z_dist with pyro.poutine.scale(scale=annealing_factor): # when no normalizing flow used, ".to_event(1)" indicates latent dimensions are independent # z_t = pyro.sample("z_%d" % t, z_t = pyro.sample( "z_{}".format(t), Normal(z_loc, z_scale) # .mask(mini_batch_mask[:, t - 1:t]) .mask(mini_batch_mask[:, t:t + 1]).to_event(1)) # the latent sampled at this time step will be conditioned upon in the next time step # so keep track of it z_prev = z_t
def observe_T(T_obs, obs_name):
    T_simulated = simulate(mu)
    T_obs_dist = Normal(T_simulated, torch.tensor(time_measurement_sigma))
    pyro.sample(obs_name, T_obs_dist, obs=T_obs)
def guide(data):
    # torch.autograd.Variable is a deprecated no-op wrapper; plain tensors
    # with requires_grad=True would do the same job
    if CUDA_:
        w_mu1 = Variable(torch.randn(hidden_nodes, p).cuda(), requires_grad=True)
        w_log_sig1 = Variable((-3.0 * torch.ones(hidden_nodes, p) +
                               0.05 * torch.randn(hidden_nodes, p)).cuda(),
                              requires_grad=True)
        b_mu1 = Variable(torch.randn(1, hidden_nodes).cuda(), requires_grad=True)
        b_log_sig1 = Variable((-3.0 * torch.ones(1, hidden_nodes) +
                               0.05 * torch.randn(1, hidden_nodes)).cuda(),
                              requires_grad=True)
        w_mu2 = Variable(torch.randn(output_nodes, hidden_nodes).cuda(),
                         requires_grad=True)
        w_log_sig2 = Variable((-3.0 * torch.ones(output_nodes, hidden_nodes) +
                               0.05 * torch.randn(output_nodes, hidden_nodes)).cuda(),
                              requires_grad=True)
        b_mu2 = Variable(torch.randn(1, output_nodes).cuda(), requires_grad=True)
        b_log_sig2 = Variable((-3.0 * torch.ones(1, output_nodes) +
                               0.05 * torch.randn(1, output_nodes)).cuda(),
                              requires_grad=True)
    else:
        w_mu1 = Variable(torch.randn(hidden_nodes, p), requires_grad=True)
        w_log_sig1 = Variable(-3.0 * torch.ones(hidden_nodes, p) +
                              0.05 * torch.randn(hidden_nodes, p),
                              requires_grad=True)
        b_mu1 = Variable(torch.randn(1, hidden_nodes), requires_grad=True)
        b_log_sig1 = Variable(-3.0 * torch.ones(1, hidden_nodes) +
                              0.05 * torch.randn(1, hidden_nodes),
                              requires_grad=True)
        w_mu2 = Variable(torch.randn(output_nodes, hidden_nodes),
                         requires_grad=True)
        w_log_sig2 = Variable(-3.0 * torch.ones(output_nodes, hidden_nodes) +
                              0.05 * torch.randn(output_nodes, hidden_nodes),
                              requires_grad=True)
        b_mu2 = Variable(torch.randn(1, output_nodes), requires_grad=True)
        b_log_sig2 = Variable(-3.0 * torch.ones(1, output_nodes) +
                              0.05 * torch.randn(1, output_nodes),
                              requires_grad=True)

    # register learnable params in the param store
    mw_param1 = pyro.param("guide_mean_weight1", w_mu1)
    sw_param1 = softplus(pyro.param("guide_log_sigma_weight1", w_log_sig1))
    mb_param1 = pyro.param("guide_mean_bias1", b_mu1)
    sb_param1 = softplus(pyro.param("guide_log_sigma_bias1", b_log_sig1))
    # gaussian guide distributions for w and b
    w_dist1 = Normal(mw_param1, sw_param1)
    b_dist1 = Normal(mb_param1, sb_param1)

    # register learnable params in the param store
    mw_param2 = pyro.param("guide_mean_weight2", w_mu2)
    sw_param2 = softplus(pyro.param("guide_log_sigma_weight2", w_log_sig2))
    mb_param2 = pyro.param("guide_mean_bias2", b_mu2)
    sb_param2 = softplus(pyro.param("guide_log_sigma_bias2", b_log_sig2))
    # gaussian guide distributions for w and b
    w_dist2 = Normal(mw_param2, sw_param2)
    b_dist2 = Normal(mb_param2, sb_param2)

    dists = {
        'fc1.weight': w_dist1, 'fc1.bias': b_dist1,
        'fc2.weight': w_dist2, 'fc2.bias': b_dist2
    }
    # overload the parameters in the module with random samples
    # from the guide distributions
    lifted_module = pyro.random_module("module", bnn_model, dists)
    # sample a regressor
    return lifted_module()
def model():
    sample = pyro.sample('latent', Normal(torch.tensor(0.), torch.tensor(0.3)))
    return pyro.sample('obs', Normal(sample, torch.tensor(0.2)),
                       obs=torch.tensor(0.1))
def guide(batch, tag, hidden, label):
    softplus = torch.nn.Softplus()

    # embedding weight distribution priors
    embedding_mu = torch.randn_like(net.embedding.weight)
    embedding_sigma = torch.randn_like(net.embedding.weight)
    embedding_mu_param = pyro.param("embedding_mu", embedding_mu)
    embedding_sigma_param = softplus(pyro.param("embedding_sigma", embedding_sigma))
    embedding_prior = Normal(loc=embedding_mu_param, scale=embedding_sigma_param)

    # lstm layer 1 input-hidden weight distribution priors
    # (the original overwrote the mu param with the softplus'd sigma param and
    # then used it for both loc and scale; fixed here)
    lstmih0w_mu = torch.randn_like(net.lstm.weight_ih_l0)
    lstmih0w_sigma = torch.randn_like(net.lstm.weight_ih_l0)
    lstmih0w_mu_param = pyro.param("lstmih0w_mu", lstmih0w_mu)
    lstmih0w_sigma_param = softplus(pyro.param("lstmih0w_sigma", lstmih0w_sigma))
    lstmih0w_prior = Normal(loc=lstmih0w_mu_param, scale=lstmih0w_sigma_param)

    # lstm layer 1 input-hidden bias distribution priors
    lstmih0b_mu = torch.randn_like(net.lstm.bias_ih_l0)
    lstmih0b_sigma = torch.randn_like(net.lstm.bias_ih_l0)
    lstmih0b_mu_param = pyro.param("lstmih0b_mu", lstmih0b_mu)
    lstmih0b_sigma_param = softplus(pyro.param("lstmih0b_sigma", lstmih0b_sigma))
    lstmih0b_prior = Normal(loc=lstmih0b_mu_param, scale=lstmih0b_sigma_param)

    # lstm layer 1 hidden-hidden weight distribution priors
    lstmhh0w_mu = torch.randn_like(net.lstm.weight_hh_l0)
    lstmhh0w_sigma = torch.randn_like(net.lstm.weight_hh_l0)
    lstmhh0w_mu_param = pyro.param("lstmhh0w_mu", lstmhh0w_mu)
    lstmhh0w_sigma_param = softplus(pyro.param("lstmhh0w_sigma", lstmhh0w_sigma))
    lstmhh0w_prior = Normal(loc=lstmhh0w_mu_param, scale=lstmhh0w_sigma_param)

    # lstm layer 1 hidden-hidden bias distribution priors
    lstmhh0b_mu = torch.randn_like(net.lstm.bias_hh_l0)
    lstmhh0b_sigma = torch.randn_like(net.lstm.bias_hh_l0)
    lstmhh0b_mu_param = pyro.param("lstmhh0b_mu", lstmhh0b_mu)
    lstmhh0b_sigma_param = softplus(pyro.param("lstmhh0b_sigma", lstmhh0b_sigma))
    lstmhh0b_prior = Normal(loc=lstmhh0b_mu_param, scale=lstmhh0b_sigma_param)

    # lstm layer 2 input-hidden weight distribution priors (same fix as layer 1)
    lstmih1w_mu = torch.randn_like(net.lstm.weight_ih_l1)
    lstmih1w_sigma = torch.randn_like(net.lstm.weight_ih_l1)
    lstmih1w_mu_param = pyro.param("lstmih1w_mu", lstmih1w_mu)
    lstmih1w_sigma_param = softplus(pyro.param("lstmih1w_sigma", lstmih1w_sigma))
    lstmih1w_prior = Normal(loc=lstmih1w_mu_param, scale=lstmih1w_sigma_param)

    # lstm layer 2 input-hidden bias distribution priors
    lstmih1b_mu = torch.randn_like(net.lstm.bias_ih_l1)
    lstmih1b_sigma = torch.randn_like(net.lstm.bias_ih_l1)
    lstmih1b_mu_param = pyro.param("lstmih1b_mu", lstmih1b_mu)
    lstmih1b_sigma_param = softplus(pyro.param("lstmih1b_sigma", lstmih1b_sigma))
    lstmih1b_prior = Normal(loc=lstmih1b_mu_param, scale=lstmih1b_sigma_param)

    # lstm layer 2 hidden-hidden weight distribution priors
    lstmhh1w_mu = torch.randn_like(net.lstm.weight_hh_l1)
    lstmhh1w_sigma = torch.randn_like(net.lstm.weight_hh_l1)
    lstmhh1w_mu_param = pyro.param("lstmhh1w_mu", lstmhh1w_mu)
    lstmhh1w_sigma_param = softplus(pyro.param("lstmhh1w_sigma", lstmhh1w_sigma))
    lstmhh1w_prior = Normal(loc=lstmhh1w_mu_param, scale=lstmhh1w_sigma_param)

    # lstm layer 2 hidden-hidden bias distribution priors
    lstmhh1b_mu = torch.randn_like(net.lstm.bias_hh_l1)
    lstmhh1b_sigma = torch.randn_like(net.lstm.bias_hh_l1)
    lstmhh1b_mu_param = pyro.param("lstmhh1b_mu", lstmhh1b_mu)
    lstmhh1b_sigma_param = softplus(pyro.param("lstmhh1b_sigma", lstmhh1b_sigma))
    lstmhh1b_prior = Normal(loc=lstmhh1b_mu_param, scale=lstmhh1b_sigma_param)

    # first fully connected layer weight distribution priors
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)

    # first fully connected layer bias distribution priors
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)

    # second fully connected layer weight distribution priors
    fc2w_mu = torch.randn_like(net.fc2.weight)
    fc2w_sigma = torch.randn_like(net.fc2.weight)
    fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
    fc2w_sigma_param = softplus(pyro.param("fc2w_sigma", fc2w_sigma))
    fc2w_prior = Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)

    # output layer bias distribution priors
    fc2b_mu = torch.randn_like(net.fc2.bias)
    fc2b_sigma = torch.randn_like(net.fc2.bias)
    fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
    fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
    fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)

    # (the original dict repeated the layer-0 LSTM keys and dropped the
    # layer-1 priors defined above; fixed here)
    priors = {'embedding.weight': embedding_prior,
              'lstm.weight_ih_l0': lstmih0w_prior, 'lstm.bias_ih_l0': lstmih0b_prior,
              'lstm.weight_hh_l0': lstmhh0w_prior, 'lstm.bias_hh_l0': lstmhh0b_prior,
              'lstm.weight_ih_l1': lstmih1w_prior, 'lstm.bias_ih_l1': lstmih1b_prior,
              'lstm.weight_hh_l1': lstmhh1w_prior, 'lstm.bias_hh_l1': lstmhh1b_prior,
              'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior}

    lifted_module = pyro.random_module("module", net, priors)
    return lifted_module()
def main(mini_batch, mini_batch_reversed, mini_batch_mask,
         mini_batch_seq_lengths, annealing_factor=1.0):
    pyro.module("e_lin_z_to_hidden", e_lin_z_to_hidden)
    pyro.module("e_lin_hidden_to_hidden", e_lin_hidden_to_hidden)
    pyro.module("e_lin_hidden_to_input", e_lin_hidden_to_input)
    pyro.module("t_lin_gate_z_to_hidden", t_lin_gate_z_to_hidden)
    pyro.module("t_lin_gate_hidden_to_z", t_lin_gate_hidden_to_z)
    pyro.module("t_lin_proposed_mean_z_to_hidden", t_lin_proposed_mean_z_to_hidden)
    pyro.module("t_lin_proposed_mean_hidden_to_z", t_lin_proposed_mean_hidden_to_z)
    pyro.module("t_lin_sig", t_lin_sig)
    pyro.module("t_lin_z_to_loc", t_lin_z_to_loc)

    #===== init tensor shape
    mini_batch = torch.reshape(mini_batch, [20, 160, 88])
    mini_batch_reversed = torch.reshape(mini_batch_reversed, [20, 160, 88])
    mini_batch_mask = torch.reshape(mini_batch_mask, [20, 160])
    mini_batch_seq_lengths = torch.reshape(mini_batch_seq_lengths, [20])
    #===== init tensor shape

    # number of time steps to process in the mini-batch:
    # T_max = mini_batch.size(1) = 160, hard-coded below

    # set z_prev = z_0 to set up the recursive conditioning in p(z_t | z_{t-1})
    # z_prev = z_0.expand(mini_batch.size(0), z_0.size(0))
    z_prev = torch.Tensor.expand(z_0, [20, 100])

    # we enclose all the sample statements in the model in a plate.
    # this marks that each datapoint is conditionally independent of the others
    # with pyro.plate("z_minibatch", len(mini_batch)):  # len(mini_batch) = 20
    with pyro.plate("z_minibatch", 20):
        # sample the latents z and observed x's one time step at a time
        for t in range(160):  # T_max = 160
            # the next chunk of code samples z_t ~ p(z_t | z_{t-1}).
            # note that (both here and elsewhere) we use poutine.scale to take
            # care of KL annealing. we use the mask() method to deal with
            # raggedness in the observed data (i.e. different sequences in the
            # mini-batch have different lengths)
            _gate = t_relu(t_lin_gate_z_to_hidden(z_prev))
            gate = torch.sigmoid(t_lin_gate_hidden_to_z(_gate))
            _proposed_mean = t_relu(t_lin_proposed_mean_z_to_hidden(z_prev))
            proposed_mean = t_lin_proposed_mean_hidden_to_z(_proposed_mean)
            loc = (1 - gate) * t_lin_z_to_loc(z_prev) + gate * proposed_mean
            scale = t_softplus(t_lin_sig(t_relu(proposed_mean)))

            # parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
            z_loc = loc
            z_scale = scale

            # sample z_t according to dist.Normal(z_loc, z_scale).
            # ".to_event(1)" treats the univariate Normal as a multivariate
            # Normal with diagonal covariance.
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_{}".format(t),
                    Normal(z_loc, z_scale)
                    .mask(mini_batch_mask[:, t:t + 1])
                    .to_event(1))

            # compute the probabilities that parameterize the bernoulli likelihood
            h1 = e_relu(e_lin_z_to_hidden(z_t))
            h2 = e_relu(e_lin_hidden_to_hidden(h1))
            emission_probs_t = torch.sigmoid(e_lin_hidden_to_input(h2))

            # observe x_t according to the bernoulli distribution p(x_t | z_t)
            pyro.sample(
                "obs_x_{}".format(t),
                Bernoulli(emission_probs_t)
                .mask(mini_batch_mask[:, t:t + 1])
                .to_event(1),
                obs=mini_batch[:, t, :])

            # the latent sampled at this time step will be conditioned upon
            # in the next time step, so keep track of it
            z_prev = z_t
def model(batch, tag, hidden, label):
    log_softmax = nn.LogSoftmax(dim=1)
    embedding_prior = Normal(loc=pretrained_dict['embedding.weight'],
                             scale=torch.ones_like(net.embedding.weight))
    lstmih0w_prior = Normal(loc=pretrained_dict['lstm.weight_ih_l0'],
                            scale=torch.ones_like(net.lstm.weight_ih_l0))
    lstmih0b_prior = Normal(loc=pretrained_dict['lstm.bias_ih_l0'],
                            scale=torch.ones_like(net.lstm.bias_ih_l0))
    lstmhh0w_prior = Normal(loc=pretrained_dict['lstm.weight_hh_l0'],
                            scale=torch.ones_like(net.lstm.weight_hh_l0))
    lstmhh0b_prior = Normal(loc=pretrained_dict['lstm.bias_hh_l0'],
                            scale=torch.ones_like(net.lstm.bias_hh_l0))
    lstmih1w_prior = Normal(loc=pretrained_dict['lstm.weight_ih_l1'],
                            scale=torch.ones_like(net.lstm.weight_ih_l1))
    lstmih1b_prior = Normal(loc=pretrained_dict['lstm.bias_ih_l1'],
                            scale=torch.ones_like(net.lstm.bias_ih_l1))
    lstmhh1w_prior = Normal(loc=pretrained_dict['lstm.weight_hh_l1'],
                            scale=torch.ones_like(net.lstm.weight_hh_l1))
    lstmhh1b_prior = Normal(loc=pretrained_dict['lstm.bias_hh_l1'],
                            scale=torch.ones_like(net.lstm.bias_hh_l1))
    fc1w_prior = Normal(loc=pretrained_dict['fc1.weight'],
                        scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=pretrained_dict['fc1.bias'],
                        scale=torch.ones_like(net.fc1.bias))
    fc2w_prior = Normal(loc=pretrained_dict['fc2.weight'],
                        scale=torch.ones_like(net.fc2.weight))
    fc2b_prior = Normal(loc=pretrained_dict['fc2.bias'],
                        scale=torch.ones_like(net.fc2.bias))

    # (the original dict repeated the layer-0 LSTM keys and dropped the
    # layer-1 priors defined above; fixed here)
    priors = {'embedding.weight': embedding_prior,
              'lstm.weight_ih_l0': lstmih0w_prior, 'lstm.bias_ih_l0': lstmih0b_prior,
              'lstm.weight_hh_l0': lstmhh0w_prior, 'lstm.bias_hh_l0': lstmhh0b_prior,
              'lstm.weight_ih_l1': lstmih1w_prior, 'lstm.bias_ih_l1': lstmih1b_prior,
              'lstm.weight_hh_l1': lstmhh1w_prior, 'lstm.bias_hh_l1': lstmhh1b_prior,
              'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
              'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior}

    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)
    # sample a model
    lifted_model = lifted_module()

    output, hidden = lifted_model(batch, tag, hidden)
    lhat = log_softmax(output)
    pyro.sample("obs", Categorical(logits=lhat), obs=label)
    d[i] = d[i][1:-1]
print("Processed")

from math import log, exp
from random import shuffle

N = len(data)
M = int(max(line[header["tokenID.Renumbered"]] for line in data))
L = int(max(line[header["WorkerId.Renumbered"]] for line in data))

# priors for the fixed effects and (log-)variance terms
# (torch.autograd.Variable is a deprecated no-op wrapper; plain tensors suffice)
alpha_Prior = Normal(torch.tensor([0.0]), torch.tensor([10.0]))
beta1_Prior = Normal(torch.tensor([0.0]), torch.tensor([10.0]))
decay_Prior = Normal(torch.tensor([0.0]), torch.tensor([10.0]))
kappa_Prior = Normal(torch.tensor([0.0]), torch.tensor([10.0]))
log_sigma_Prior = Normal(torch.tensor([0.0]), torch.tensor([10.0]))
byParticipantIntercept_Variance_Prior = Normal(torch.tensor([0.0]),
                                               torch.tensor([10.0]))
byItemIntercept_Variance_Prior = Normal(torch.tensor([0.0]),
                                        torch.tensor([10.0]))
def build_0d_dist(x, a, s):
    return Normal(loc=a * x.values, scale=s)
def guide(self, data):
    encoder = pyro.module('encoder', self.vae_encoder)
    z_mean, z_var = encoder.forward(data)
    pyro.sample('latent', Normal(z_mean, z_var.sqrt()))
def model(data):
    data = torch.reshape(data, [60000, 50, 50])
    pyro.module("decode_l1", decode_l1)
    pyro.module("decode_l2", decode_l2)
    with pyro.plate('data', 60000, 64) as ix:
        # size = [64, 50, 50]
        batch = data[ix]

        #================= prior
        state_x = torch.zeros([64, 50, 50])
        state_z_pres = torch.ones([64, 1])
        state_z_where = None
        z_pres = []
        z_where = []
        for t in range(3):
            #==================== prior_step
            prev_x = state_x              # size = [64, 50, 50]
            prev_z_pres = state_z_pres    # size = [64, 1]
            prev_z_where = state_z_where  # size = None or [64, 3]

            # size = [64, 1]
            cur_z_pres = pyro.sample(
                'z_pres_{}'.format(t),
                Bernoulli(trial_probs[t] * prev_z_pres).to_event(1))
            sample_mask = cur_z_pres

            # size = [64, 3]
            cur_z_where = pyro.sample(
                'z_where_{}'.format(t),
                Normal(torch.Tensor.expand(z_where_loc_prior, [64, 3]),
                       torch.Tensor.expand(z_where_scale_prior, [64, 3]))
                .mask(sample_mask)
                .to_event(1))

            # size = [64, 50]
            cur_z_what = pyro.sample(
                'z_what_{}'.format(t),
                Normal(torch.zeros([64, 50]), torch.ones([64, 50]))
                .mask(sample_mask)
                .to_event(1))

            #===== decode
            # size = [64, 784]
            y_att = torch.sigmoid(decode_l2(F.relu(decode_l1(cur_z_what))) - 2.0)
            #===== decode

            #===== window_to_image
            windows = y_att
            #===== expand_z_where
            # size = [64, 4]
            out = torch.cat((torch.zeros(64, 1), cur_z_where), 1)
            # size = [64, 6]
            out = torch.index_select(out, 1, expansion_indices)
            # size = [64, 2, 3]
            theta = torch.Tensor.view(out, [64, 2, 3])
            #===== expand_z_where
            # size = [64, 50, 50, 2]
            grid = F.affine_grid(theta, [64, 1, 50, 50])
            # size = [64, 1, 50, 50]
            out = F.grid_sample(torch.Tensor.view(windows, [64, 1, 28, 28]), grid)
            y = torch.Tensor.view(out, [64, 50, 50])
            #===== window_to_image

            # size = [64, 50, 50]
            cur_x = prev_x + (y * torch.Tensor.view(cur_z_pres, [64, 1, 1]))

            state_x = cur_x
            state_z_pres = cur_z_pres
            state_z_where = cur_z_where
            #==================== prior_step
            z_where.append(state_z_where)
            z_pres.append(state_z_pres)
        # size = [64, 50, 50]
        x = state_x
        #================= prior

        pyro.sample('obs',
                    Normal(torch.Tensor.view(x, [64, 2500]),
                           0.3 * torch.ones(64, 2500)).to_event(1),
                    obs=torch.Tensor.view(batch, [64, 2500]))