Example #1
 def guide():
     loc = pyro.param('loc', torch.tensor(0.))
     scale = pyro.param('scale', torch.tensor(0.5))
     pyro.sample('latent', Normal(loc, scale))
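A minimal sketch of how a guide like this is typically trained: pair it with a model that has a matching 'latent' sample site and run SVI (the model, observed value, and optimizer settings below are illustrative assumptions; Normal refers to pyro.distributions.Normal).

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model():
    # hypothetical model: one latent site matching the guide, one observation
    latent = pyro.sample('latent', dist.Normal(0., 1.))
    pyro.sample('obs', dist.Normal(latent, 1.), obs=torch.tensor(0.5))

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(1000):
    svi.step()

Note that 'scale' is registered here without a positivity constraint; Example #5 below adds constraint=constraints.positive, which keeps the optimizer from driving the scale negative.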
Example #2
 def guide(self, data):
     encoder = pyro.module("encoder", self.vae_encoder)
     with pyro.plate("data", data.size(0)):
         z_mean, z_var = encoder.forward(data)
         pyro.sample("latent", Normal(z_mean, z_var.sqrt()).to_event(1))
def model(data):
    #  print("in model")
    #  decay = pyro.sample("decay", dist.Normal(Variable(torch.FloatTensor([0.0])), Variable(torch.FloatTensor([10.0]))))
    decay = pyro.sample("decay", dist.Normal(0.0, 10.0))

    #  distances = torch.FloatTensor(xrange(1, 36))
    #  print(decay)
    #  print(decay+1)
    #  print(decay*2)
    #  quit()

    #print(torch.log(distances))
    distances = range(1, 36)
    contributions = [torch.exp(decay * log(d)) for d in distances]
    # contributions = torch.exp(torch.mul(torch.log(distances), decay))

    byParticipant_Intercept_logVariance = pyro.sample(
        "byParticipant_Intercept_logVariance",
        byParticipantIntercept_Variance_Prior)
    byItem_Intercept_logVariance = pyro.sample("byItem_Intercept_logVariance",
                                               byItemIntercept_Variance_Prior)
    byParticipant_SurprisalSlope_logVariance = pyro.sample(
        "byParticipant_SurprisalSlope_logVariance",
        byParticipantSlope_Variance_Prior)
    byParticipant_LogWordFreqSlope_logVariance = pyro.sample(
        "byParticipant_LogWordFreqSlope_logVariance",
        byParticipantSlope_Variance_Prior)

    byItem_Slope_logVariance = pyro.sample("byItem_Slope_logVariance",
                                           byItemSlope_Variance_Prior)

    alpha = pyro.sample("alpha", alpha_Prior)
    #  print("alpha in model")
    #  print(alpha)
    beta1 = pyro.sample("beta1", beta1_Prior)
    kappa = pyro.sample("kappa", kappa_Prior)
    log_sigma = pyro.sample("log_sigma", log_sigma_Prior)
    sigma = logExp(log_sigma)

    byParticipant_Intercept = pyro.sample(
        "byParticipant_Intercept",
        Normal(
            torch.autograd.Variable(torch.zeros(L)),
            logExp(byParticipant_Intercept_logVariance) *
            torch.autograd.Variable(torch.ones(L))))
    byItem_Intercept = pyro.sample(
        "byItem_Intercept",
        Normal(
            torch.autograd.Variable(torch.zeros(M)),
            logExp(byItem_Intercept_logVariance) *
            torch.autograd.Variable(torch.ones(M))))

    byParticipant_SurprisalSlope = pyro.sample(
        "byParticipant_SurprisalSlope",
        Normal(
            torch.autograd.Variable(torch.zeros(L)),
            logExp(byParticipant_SurprisalSlope_logVariance) *
            torch.autograd.Variable(torch.ones(L))))

    byParticipant_LogWordFreqSlope = pyro.sample(
        "byParticipant_LogWordFreqSlope",
        Normal(
            torch.autograd.Variable(torch.zeros(L)),
            logExp(byParticipant_LogWordFreqSlope_logVariance) *
            torch.autograd.Variable(torch.ones(L))))

    #print("Start of model")

    for q in pyro.irange("data_loop", len(data), subsample_size=50):
        point = data[q]

        participant = int(point[header["WorkerId.Renumbered"]]) - 1
        assert participant >= 0
        assert participant < L
        item = int(point[header["tokenID.Renumbered"]]) - 1
        surprisals = torch.FloatTensor(
            [point[header["Increment" + str(x)]] for x in range(0, 35)])
        effectiveSurprisal = sum(
            [x * y for x, y in zip(surprisals, contributions)])
        mean = alpha
        mean = mean + byParticipant_Intercept[participant]
        mean = mean + byItem_Intercept[item]
        mean = mean + (beta1 + byParticipant_LogWordFreqSlope[participant]
                       ) * point[header["Surprisal0"]]
        mean = mean + (kappa + byParticipant_SurprisalSlope[participant]
                       ) * effectiveSurprisal

        pyro.sample("time_{}".format(q),
                    dist.Normal(mean, sigma),
                    obs=torch.FloatTensor([log(point[header["RT"]])]))
        if random.random() > 0.99:
            print(surprisals)
            print([mean, log(point[header["RT"]])])
Example #4
 def model():
     prior_dist = Normal(self.loc0, torch.pow(self.lam0, -0.5))
     loc_latent = pyro.sample("loc_latent", prior_dist)
     x_dist = Normal(loc_latent, torch.pow(self.lam, -0.5))
     pyro.sample("obs", x_dist, obs=self.data)
     return loc_latent
Example #5
 def guide():
     loc = pyro.param('loc', torch.tensor(0.))
     scale = pyro.param('scale', torch.tensor(0.5), constraint=constraints.positive)
     pyro.sample('latent', Normal(loc, scale))
Example #6
        # take a gradient step
        optim.step()
        if (j + 1) % 50 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss.item()))
    # Inspect learned parameters
    print("Learned parameters:")
    for name, param in net.named_parameters():
        print(name, param.data.numpy())


main()

loc = torch.zeros(1, 1)
scale = torch.ones(1, 1)
# define a unit normal prior
prior = Normal(loc, scale)
# overload the parameters in the regression module with samples from the prior
lifted_module = pyro.random_module("regression_module", net, prior)
# sample a nn from the prior
sampled_reg_model = lifted_module()


def model(x_data, y_data):
    # weight and bias priors
    fc1w_prior = Normal(torch.zeros(1, 2), torch.ones(1, 2)).to_event(1)
    fc1b_prior = Normal(torch.tensor([[8.]]),
                        torch.tensor([[1000.]])).to_event(1)
    outw_prior = Normal(loc=torch.zeros_like(net.out.weight),
                        scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias),
                        scale=torch.ones_like(net.out.bias))
def vnormal(name, *shape):
    loc = pyro.param(name+"m", torch.randn(*shape, requires_grad=True, device=device))
    scale = pyro.param(name+"s", torch.randn(*shape, requires_grad=True, device=device))
    return Normal(loc, softplus(scale))
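A usage sketch (the layer shapes and site names below are assumptions, and device/softplus are taken from the surrounding snippet): each vnormal call registers a pair of parameters, '<name>m' and '<name>s', and returns the corresponding mean-field Normal, so a guide over the weights of a small network can be assembled factor by factor.

def guide(x_data, y_data):
    # hypothetical two-layer network: one mean-field factor per weight/bias tensor
    w1_dist = vnormal("w1", 20, 10)   # registers params "w1m" and "w1s"
    b1_dist = vnormal("b1", 20)
    w2_dist = vnormal("w2", 1, 20)
    b2_dist = vnormal("b2", 1)
    pyro.sample("w1", w1_dist.to_event(2))
    pyro.sample("b1", b1_dist.to_event(1))
    pyro.sample("w2", w2_dist.to_event(2))
    pyro.sample("b2", b2_dist.to_event(1))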
Example #8
 def model():
     z = pyro.sample(
         "z", Normal(10.0 * torch.ones(1), 0.0001 * torch.ones(1)))
     latent_prob = torch.exp(z) / (torch.exp(z) + torch.ones(1))
     flip = pyro.sample("flip", Bernoulli(latent_prob))
     return flip
    def pgm_model(self):
        sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)
        # pseudo call to register with pyro
        _ = self.sex_logits
        sex = pyro.sample(
            'sex',
            sex_dist,
            infer=dict(baseline={'use_decaying_avg_baseline': True}))

        slice_number_dist = Uniform(self.slice_number_min,
                                    self.slice_number_max).to_event(1)
        slice_number = pyro.sample('slice_number', slice_number_dist)

        age_base_dist = Normal(self.age_base_loc,
                               self.age_base_scale).to_event(1)
        age_dist = TransformedDistribution(age_base_dist,
                                           self.age_flow_transforms)
        age = pyro.sample('age', age_dist)
        _ = self.age_flow_components
        age_ = self.age_flow_constraint_transforms.inv(age)

        duration_context = torch.cat([sex, age_], 1)
        duration_base_dist = Normal(self.duration_base_loc,
                                    self.duration_base_scale).to_event(1)
        duration_dist = ConditionalTransformedDistribution(
            duration_base_dist, self.duration_flow_transforms).condition(
                duration_context)  # noqa: E501
        duration = pyro.sample('duration', duration_dist)
        _ = self.duration_flow_components
        duration_ = self.duration_flow_constraint_transforms.inv(duration)

        edss_context = torch.cat([sex, duration_], 1)
        edss_base_dist = Normal(self.edss_base_loc,
                                self.edss_base_scale).to_event(1)
        edss_dist = ConditionalTransformedDistribution(
            edss_base_dist,
            self.edss_flow_transforms).condition(edss_context)  # noqa: E501
        edss = pyro.sample('edss', edss_dist)
        _ = self.edss_flow_components
        edss_ = self.edss_flow_constraint_transforms.inv(edss)

        brain_context = torch.cat([sex, age_], 1)
        brain_volume_base_dist = Normal(
            self.brain_volume_base_loc,
            self.brain_volume_base_scale).to_event(1)
        brain_volume_dist = ConditionalTransformedDistribution(
            brain_volume_base_dist,
            self.brain_volume_flow_transforms).condition(brain_context)
        brain_volume = pyro.sample('brain_volume', brain_volume_dist)
        _ = self.brain_volume_flow_components
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)

        ventricle_context = torch.cat([age_, brain_volume_, duration_], 1)
        ventricle_volume_base_dist = Normal(
            self.ventricle_volume_base_loc,
            self.ventricle_volume_base_scale).to_event(1)
        ventricle_volume_dist = ConditionalTransformedDistribution(
            ventricle_volume_base_dist,
            self.ventricle_volume_flow_transforms).condition(
                ventricle_context)  # noqa: E501
        ventricle_volume = pyro.sample('ventricle_volume',
                                       ventricle_volume_dist)
        _ = self.ventricle_volume_flow_components
        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            ventricle_volume)

        lesion_context = torch.cat(
            [brain_volume_, ventricle_volume_, duration_, edss_], 1)
        lesion_volume_base_dist = Normal(
            self.lesion_volume_base_loc,
            self.lesion_volume_base_scale).to_event(1)
        lesion_volume_dist = ConditionalTransformedDistribution(
            lesion_volume_base_dist,
            self.lesion_volume_flow_transforms).condition(lesion_context)
        lesion_volume = pyro.sample('lesion_volume', lesion_volume_dist)
        _ = self.lesion_volume_flow_components

        return dict(age=age,
                    sex=sex,
                    ventricle_volume=ventricle_volume,
                    brain_volume=brain_volume,
                    lesion_volume=lesion_volume,
                    duration=duration,
                    edss=edss,
                    slice_number=slice_number)
Example #10
 def scale1_prior(tensor, *args, **kwargs):
     flat_tensor = tensor.view(-1)
     m = torch.zeros(flat_tensor.size(0))
     s = torch.ones(flat_tensor.size(0))
     return Normal(m, s).sample().view(tensor.size()).exp()
Example #11
 def stoch_fn(tensor, *args, **kwargs):
     loc = torch.zeros(tensor.size())
     scale = torch.ones(tensor.size())
     return pyro.sample("sample", Normal(loc, scale))
Example #12
def normal_like(X):
    # Looks like each scalar in X will be sampled with
    # this Normal distribution.
    return Normal(loc=0, scale=1).expand(X.shape)
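For illustration (the shapes below are arbitrary, and torch/pyro are assumed to be imported as in the surrounding examples): the expanded distribution has batch_shape equal to X.shape, so a single sample statement draws one independent value per scalar of X.

X = torch.randn(3, 4)
d = normal_like(X)
print(d.batch_shape)              # torch.Size([3, 4])
noise = pyro.sample("noise", d)   # tensor with the same shape as X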
Example #13
 def p_Z1(self, B, X0, A0):
     inp_cat = torch.cat([B, X0, A0], -1)
     mu = self.prior_W(inp_cat)
     sigma = torch.nn.functional.softplus(self.prior_sigma(inp_cat))
     p_z_bxa = Independent(Normal(mu, sigma), 1)
     return p_z_bxa
Example #14
from torch import tensor
from torch.distributions.constraints import positive
import pyro
from pyro.distributions import Normal

#pyro.set_rng_seed(101)

prior_weight = pyro.param("prior_weight", tensor(60.0))
weight_variance = pyro.param("weight_variance", tensor(1.0), constraint=positive)
weight = pyro.sample("weight", Normal(prior_weight, weight_variance))
prior_bmr = pyro.param("prior_bmr", tensor(1600.0))
bmr_variance = pyro.param("bmr_variance", tensor(50.0), constraint=positive)
logging_variance = pyro.param("logging_variance", tensor(250.0), constraint=positive)
bmr = pyro.sample("bmr", Normal(prior_bmr, bmr_variance))
cal_weight_fac = pyro.param("cal_weight_fac", tensor(1/2000.0))
consumed_calories = pyro.sample("consumed_calories", Normal(prior_bmr, logging_variance))
time_step = 1.0  # Time is in days
posterior_weight = weight + time_step * (consumed_calories - bmr) * cal_weight_fac
print(posterior_weight)
Example #15
 def guide(self, x, y):
     pyro.module("encoder", self.encoder)
     with pyro.plate("data", x.shape[0]):
         loc_z, scale_z = self.encoder(x, y)
         pyro.sample("latent", Normal(loc_z, scale_z).to_event(1))
Example #16
 def model():
     loc = pyro.sample("loc", Normal(torch.zeros(1), torch.ones(1)))
     xd = Normal(loc, torch.ones(1))
     pyro.sample("xs", xd, obs=self.data)
     return loc
Example #17
 def reconstruct_image(self, x, y):
     loc_z, scale_z = self.encoder(x, y)
     z = Normal(loc_z, scale_z).sample()
     loc_img = self.decoder(z, y)
     return loc_img
Example #18
 def guide():
     return pyro.sample("loc", Normal(torch.zeros(1), torch.ones(1)))
def normal(*shape):
    loc = torch.zeros(*shape).to(device)
    scale = torch.ones(*shape).to(device)
    return Normal(loc, scale)
Example #20
def guide(mini_batch,
          mini_batch_reversed,
          mini_batch_mask,
          mini_batch_seq_lengths,
          annealing_factor=1.0):
    pyro.module("c_lin_z_to_hidden", c_lin_z_to_hidden)
    pyro.module("c_lin_hidden_to_loc", c_lin_hidden_to_loc)
    pyro.module("c_lin_hidden_to_scale ", c_lin_hidden_to_scale)
    pyro.module("rnn", rnn)

    if debug:
        print("===== guide:S =====")
        print("mini_batch:\t type={}, shape={}".format(type(mini_batch),
                                                       mini_batch.size()))
        print("mini_batch_reversed:\t type={}, shape={}".format(
            type(mini_batch_reversed), mini_batch_reversed.size()))
        print("mini_batch_mask:\t type={}, shape={}".format(
            type(mini_batch_mask), mini_batch_mask.size()))
        print("mini_batch_seq_lengths:\t type={}, shape={}".format(
            type(mini_batch_seq_lengths), mini_batch_seq_lengths.size()))
        print("===== guide:E =====")

    #===== init tensor shape
    mini_batch = torch.reshape(mini_batch, [20, 160, 88])
    mini_batch_reversed = torch.reshape(mini_batch_reversed, [20, 160, 88])
    mini_batch_mask = torch.reshape(mini_batch_mask, [20, 160])
    mini_batch_seq_lengths = torch.reshape(mini_batch_seq_lengths, [20])
    #===== init tensor shape

    # this is the number of time steps we need to process in the mini-batch
    # # T_max = mini_batch.size(1)
    # T_max = 160
    # register all PyTorch (sub)modules with pyro
    pyro.module("rnn", rnn)
    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    # h_0_contig = h_0.expand(1, mini_batch.size(0), rnn.hidden_size).contiguous()
    h_0_contig = torch.Tensor.expand(h_0, [1, 20, 600])
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = rnn(mini_batch_reversed, h_0_contig)

    # reverse the time-ordering in the hidden state and un-pack it
    # rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    #===== pad_and_reverse
    #=== wy: disallow using PackedSequence in the whole program
    # rnn_output, _ = nn.utils.rnn.pad_packed_sequence(rnn_output, batch_first=True)
    #=== wy
    # # reversed_output = reverse_sequences(rnn_output, seq_lengths)
    # rnn_output = reverse_sequences(rnn_output, mini_batch_seq_lengths)
    #======= reverse_sequences
    # shape = [20, 160, 600]
    _mini_batch = rnn_output
    _seq_lengths = mini_batch_seq_lengths

    # reversed_mini_batch = _mini_batch.new_zeros(_mini_batch.size())
    reversed_mini_batch = torch.zeros(20, 160, 600)

    # for b in range(_mini_batch.size(0)):
    for b in range(20):
        T = _seq_lengths[b]
        # time_slice = torch.arange(T - 1, -1, -1, device=_mini_batch.device)
        time_slice = torch.arange(T - 1, -1, -1)
        reversed_sequence = torch.index_select(_mini_batch[b, :, :], 0,
                                               time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence

    # return reversed_mini_batch
    rnn_output = reversed_mini_batch
    #======= reverse_sequences
    #===== pad_and_reverse

    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    # z_prev = z_q_0.expand(mini_batch.size(0), z_q_0.size(0))
    z_prev = torch.Tensor.expand(z_q_0, [20, 100])

    # we enclose all the sample statements in the guide in a plate.
    # this marks that each datapoint is conditionally independent of the others.
    # with pyro.plate("z_minibatch", len(mini_batch)):
    with pyro.plate("z_minibatch", 20):
        # sample the latents z one time step at a time
        # # for t in range(1, T_max + 1):
        # for t in range(T_max):
        for t in range(160):
            # h_rnn = rnn_output[:, t - 1, :]
            h_rnn = rnn_output[:, t, :]
            h_combined = 0.5 * (c_tanh(c_lin_z_to_hidden(z_prev)) + h_rnn)
            loc = c_lin_hidden_to_loc(h_combined)
            scale = c_softplus(c_lin_hidden_to_scale(h_combined))
            z_loc = loc
            z_scale = scale
            z_dist = Normal(z_loc, z_scale)
            # assert z_dist.event_shape == ()
            # assert z_dist.batch_shape == (len(mini_batch), z_q_0.size(0))

            # sample z_t from the distribution z_dist
            with pyro.poutine.scale(scale=annealing_factor):
                # when no normalizing flow used, ".to_event(1)" indicates latent dimensions are independent
                # z_t = pyro.sample("z_%d" % t,
                z_t = pyro.sample(
                    "z_{}".format(t),
                    Normal(z_loc, z_scale)
                    # .mask(mini_batch_mask[:, t - 1:t])
                    .mask(mini_batch_mask[:, t:t + 1]).to_event(1))

            # the latent sampled at this time step will be conditioned upon in the next time step
            # so keep track of it
            z_prev = z_t
Example #21
 def observe_T(T_obs, obs_name):
     T_simulated = simulate(mu)
     T_obs_dist = Normal(T_simulated, torch.tensor(time_measurement_sigma))
     pyro.sample(obs_name, T_obs_dist, obs=T_obs)
def guide(data):
    if CUDA_:
        w_mu1 = Variable(torch.randn(hidden_nodes, p).cuda(),
                         requires_grad=True)
        w_log_sig1 = Variable((-3.0 * torch.ones(hidden_nodes, p) +
                               0.05 * torch.randn(hidden_nodes, p)).cuda(),
                              requires_grad=True)
        b_mu1 = Variable(torch.randn(1, hidden_nodes).cuda(),
                         requires_grad=True)
        b_log_sig1 = Variable((-3.0 * torch.ones(1, hidden_nodes) +
                               0.05 * torch.randn(1, hidden_nodes)).cuda(),
                              requires_grad=True)
        w_mu2 = Variable(torch.randn(output_nodes, hidden_nodes).cuda(),
                         requires_grad=True)
        w_log_sig2 = Variable(
            (-3.0 * torch.ones(output_nodes, hidden_nodes) +
             0.05 * torch.randn(output_nodes, hidden_nodes)).cuda(),
            requires_grad=True)
        b_mu2 = Variable(torch.randn(1, output_nodes).cuda(),
                         requires_grad=True)
        b_log_sig2 = Variable((-3.0 * torch.ones(1, output_nodes) +
                               0.05 * torch.randn(1, output_nodes)).cuda(),
                              requires_grad=True)
    else:
        w_mu1 = Variable(torch.randn(hidden_nodes, p), requires_grad=True)
        w_log_sig1 = Variable((-3.0 * torch.ones(hidden_nodes, p) +
                               0.05 * torch.randn(hidden_nodes, p)),
                              requires_grad=True)
        b_mu1 = Variable(torch.randn(1, hidden_nodes), requires_grad=True)
        b_log_sig1 = Variable((-3.0 * torch.ones(1, hidden_nodes) +
                               0.05 * torch.randn(1, hidden_nodes)),
                              requires_grad=True)
        w_mu2 = Variable(torch.randn(output_nodes, hidden_nodes),
                         requires_grad=True)
        w_log_sig2 = Variable((-3.0 * torch.ones(output_nodes, hidden_nodes) +
                               0.05 * torch.randn(output_nodes, hidden_nodes)),
                              requires_grad=True)
        b_mu2 = Variable(torch.randn(1, output_nodes), requires_grad=True)
        b_log_sig2 = Variable((-3.0 * torch.ones(1, output_nodes) +
                               0.05 * torch.randn(1, output_nodes)),
                              requires_grad=True)

    # register learnable params in the param store
    mw_param1 = pyro.param("guide_mean_weight1", w_mu1)
    sw_param1 = softplus(pyro.param("guide_log_sigma_weight1", w_log_sig1))
    mb_param1 = pyro.param("guide_mean_bias1", b_mu1)
    sb_param1 = softplus(pyro.param("guide_log_sigma_bias1", b_log_sig1))
    # gaussian guide distributions for w and b
    w_dist1 = Normal(mw_param1, sw_param1)
    b_dist1 = Normal(mb_param1, sb_param1)
    # register learnable params in the param store
    mw_param2 = pyro.param("guide_mean_weight2", w_mu2)
    sw_param2 = softplus(pyro.param("guide_log_sigma_weight2", w_log_sig2))
    mb_param2 = pyro.param("guide_mean_bias2", b_mu2)
    sb_param2 = softplus(pyro.param("guide_log_sigma_bias2", b_log_sig2))
    # gaussian guide distributions for w and b
    w_dist2 = Normal(mw_param2, sw_param2)
    b_dist2 = Normal(mb_param2, sb_param2)

    dists = {
        'fc1.weight': w_dist1,
        'fc1.bias': b_dist1,
        'fc2.weight': w_dist2,
        'fc2.bias': b_dist2
    }
    # overloading the parameters in the module with random samples from the guide distributions
    lifted_module = pyro.random_module("module", bnn_model, dists)
    # sample a regressor
    return lifted_module()
Example #23
 def model():
     sample = pyro.sample('latent', Normal(torch.tensor(0.), torch.tensor(0.3)))
     return pyro.sample('obs', Normal(sample, torch.tensor(0.2)), obs=torch.tensor(0.1))
def guide(batch, tag, hidden, label):
	softplus = torch.nn.Softplus()
    
	# embedding weight distribution priors
	embedding_mu = torch.randn_like(net.embedding.weight)
	embedding_sigma = torch.randn_like(net.embedding.weight)
	embedding_mu_param = pyro.param("embedding_mu", embedding_mu)
	embedding_sigma_param = softplus(pyro.param("embedding_sigma", embedding_sigma))
	embedding_prior = Normal(loc=embedding_mu_param, scale=embedding_sigma_param)
	# lstm layer 1 input-hidden weight distribution priors
	lstmih0w_mu = torch.randn_like(net.lstm.weight_ih_l0)
	lstmih0w_sigma = torch.randn_like(net.lstm.weight_ih_l0)
	lstmih0w_mu_param = pyro.param("lstmih0w_mu", lstmih0w_mu)
	lstmih0w_sigma_param = softplus(pyro.param("lstmih0w_sigma", lstmih0w_sigma))
	lstmih0w_prior = Normal(loc=lstmih0w_mu_param, scale=lstmih0w_sigma_param)
	# lstm layer 1 input-hidden bias distribution priors
	lstmih0b_mu = torch.randn_like(net.lstm.bias_ih_l0)
	lstmih0b_sigma = torch.randn_like(net.lstm.bias_ih_l0)
	lstmih0b_mu_param = pyro.param("lstmih0b_mu", lstmih0b_mu)
	lstmih0b_sigma_param = softplus(pyro.param("lstmih0b_sigma", lstmih0b_sigma))
	lstmih0b_prior = Normal(loc=lstmih0b_mu_param, scale=lstmih0b_sigma_param)
	# lstm layer 1 hidden-hidden weight distribution priors
	lstmhh0w_mu = torch.randn_like(net.lstm.weight_hh_l0)
	lstmhh0w_sigma = torch.randn_like(net.lstm.weight_hh_l0)
	lstmhh0w_mu_param = pyro.param("lstmhh0w_mu", lstmhh0w_mu)
	lstmhh0w_sigma_param = softplus(pyro.param("lstmhh0w_sigma", lstmhh0w_sigma))
	lstmhh0w_prior = Normal(loc=lstmhh0w_mu_param, scale=lstmhh0w_sigma_param)
	# lstm layer 1 hidden-hidden bias distribution priors
	lstmhh0b_mu = torch.randn_like(net.lstm.bias_hh_l0)
	lstmhh0b_sigma = torch.randn_like(net.lstm.bias_hh_l0)
	lstmhh0b_mu_param = pyro.param("lstmhh0b_mu", lstmhh0b_mu)
	lstmhh0b_sigma_param = softplus(pyro.param("lstmhh0b_sigma", lstmhh0b_sigma))
	lstmhh0b_prior = Normal(loc=lstmhh0b_mu_param, scale=lstmhh0b_sigma_param)
	# lstm layer 2 input-hidden weight distribution priors
	lstmih1w_mu = torch.randn_like(net.lstm.weight_ih_l1)
	lstmih1w_sigma = torch.randn_like(net.lstm.weight_ih_l1)
	lstmih1w_mu_param = pyro.param("lstmih1w_mu", lstmih1w_mu)
	lstmih1w_sigma_param = softplus(pyro.param("lstmih1w_sigma", lstmih1w_sigma))
	lstmih1w_prior = Normal(loc=lstmih1w_mu_param, scale=lstmih1w_sigma_param)
	# lstm layer 2 input-hidden bias distribution priors
	lstmih1b_mu = torch.randn_like(net.lstm.bias_ih_l1)
	lstmih1b_sigma = torch.randn_like(net.lstm.bias_ih_l1)
	lstmih1b_mu_param = pyro.param("lstmih1b_mu", lstmih1b_mu)
	lstmih1b_sigma_param = softplus(pyro.param("lstmih1b_sigma", lstmih1b_sigma))
	lstmih1b_prior = Normal(loc=lstmih1b_mu_param, scale=lstmih1b_sigma_param)
	# lstm layer 2 hidden-hidden weight distribution priors
	lstmhh1w_mu = torch.randn_like(net.lstm.weight_hh_l1)
	lstmhh1w_sigma = torch.randn_like(net.lstm.weight_hh_l1)
	lstmhh1w_mu_param = pyro.param("lstmhh1w_mu", lstmhh1w_mu)
	lstmhh1w_sigma_param = softplus(pyro.param("lstmhh1w_sigma", lstmhh1w_sigma))
	lstmhh1w_prior = Normal(loc=lstmhh1w_mu_param, scale=lstmhh1w_sigma_param)
	# lstm layer 2 hidden-hidden bias distribution priors
	lstmhh1b_mu = torch.randn_like(net.lstm.bias_hh_l1)
	lstmhh1b_sigma = torch.randn_like(net.lstm.bias_hh_l1)
	lstmhh1b_mu_param = pyro.param("lstmhh1b_mu", lstmhh1b_mu)
	lstmhh1b_sigma_param = softplus(pyro.param("lstmhh1b_sigma", lstmhh1b_sigma))
	lstmhh1b_prior = Normal(loc=lstmhh1b_mu_param, scale=lstmhh1b_sigma_param)
	# first fully connected layer weight distribution priors
	fc1w_mu = torch.randn_like(net.fc1.weight)
	fc1w_sigma = torch.randn_like(net.fc1.weight)
	fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
	fc1w_sigma_param = softplus(pyro.param("fc1w_sigma", fc1w_sigma))
	fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
	# first fully connected layer bias distribution priors
	fc1b_mu = torch.randn_like(net.fc1.bias)
	fc1b_sigma = torch.randn_like(net.fc1.bias)
	fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
	fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
	fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
	# second fully connected layer weight distribution priors
	fc2w_mu = torch.randn_like(net.fc2.weight)
	fc2w_sigma = torch.randn_like(net.fc2.weight)
	fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
	fc2w_sigma_param = softplus(pyro.param("fc2w_sigma", fc2w_sigma))
	fc2w_prior = Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)
	# Output layer bias distribution priors
	fc2b_mu = torch.randn_like(net.fc2.bias)
	fc2b_sigma = torch.randn_like(net.fc2.bias)
	fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
	fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
	fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)

	priors = {'embedding.weight': embedding_prior, 'lstm.weight_ih_l0': lstmih0w_prior, 'lstm.bias_ih_l0': lstmih0b_prior,
	'lstm.weight_hh_l0': lstmhh0w_prior, 'lstm.bias_hh_l0': lstmhh0b_prior, 'lstm.weight_ih_l1': lstmih1w_prior,
	'lstm.bias_ih_l1': lstmih1b_prior, 'lstm.weight_hh_l1': lstmhh1w_prior, 'lstm.bias_hh_l1': lstmhh1b_prior,
	'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior}
	    
	lifted_module = pyro.random_module("module", net, priors)
	    
	return lifted_module()
def main(mini_batch,
         mini_batch_reversed,
         mini_batch_mask,
         mini_batch_seq_lengths,
         annealing_factor=1.0):
    pyro.module("e_lin_z_to_hidden", e_lin_z_to_hidden)
    pyro.module("e_lin_hidden_to_hidden", e_lin_hidden_to_hidden)
    pyro.module("e_lin_hidden_to_input", e_lin_hidden_to_input)
    pyro.module("t_lin_gate_z_to_hidden", t_lin_gate_z_to_hidden)
    pyro.module("t_lin_gate_hidden_to_z", t_lin_gate_hidden_to_z)
    pyro.module("t_lin_proposed_mean_z_to_hidden",
                t_lin_proposed_mean_z_to_hidden)
    pyro.module("t_lin_proposed_mean_hidden_to_z",
                t_lin_proposed_mean_hidden_to_z)
    pyro.module("t_lin_sig", t_lin_sig)
    pyro.module("t_lin_z_to_loc", t_lin_z_to_loc)

    #===== init tensor shape
    mini_batch = torch.reshape(mini_batch, [20, 160, 88])
    mini_batch_reversed = torch.reshape(mini_batch_reversed, [20, 160, 88])
    mini_batch_mask = torch.reshape(mini_batch_mask, [20, 160])
    mini_batch_seq_lengths = torch.reshape(mini_batch_seq_lengths, [20])
    #===== init tensor shape

    # this is the number of time steps we need to process in the mini-batch
    # # T_max = mini_batch.size(1)
    # T_max = 160
    # set z_prev = z_0 to setup the recursive conditioning in p(z_t | z_{t-1})
    # z_prev = z_0.expand(mini_batch.size(0), z_0.size(0))
    z_prev = torch.Tensor.expand(z_0, [20, 100])

    # we enclose all the sample statements in the model in a plate.
    # this marks that each datapoint is conditionally independent of the others
    # with pyro.plate("z_minibatch", len(mini_batch)): #len(mini_batch)= 20
    with pyro.plate("z_minibatch", 20):
        # sample the latents z and observed x's one time step at a time
        # # for t in range(1, T_max + 1):
        # for t in range(T_max):
        for t in range(160):
            # the next chunk of code samples z_t ~ p(z_t | z_{t-1})
            # note that (both here and elsewhere) we use poutine.scale to take care
            # of KL annealing. we use the mask() method to deal with raggedness
            # in the observed data (i.e. different sequences in the mini-batch
            # have different lengths)
            _gate = t_relu(t_lin_gate_z_to_hidden(z_prev))
            gate = torch.sigmoid(t_lin_gate_hidden_to_z(_gate))
            _proposed_mean = t_relu(t_lin_proposed_mean_z_to_hidden(z_prev))
            proposed_mean = t_lin_proposed_mean_hidden_to_z(_proposed_mean)
            loc = (1 - gate) * t_lin_z_to_loc(z_prev) + gate * proposed_mean
            scale = t_softplus(t_lin_sig(t_relu(proposed_mean)))

            # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
            z_loc = loc
            z_scale = scale

            # then sample z_t according to dist.Normal(z_loc, z_scale)
            # note that we use ".to_event(1)" so that the univariate Normal distribution
            # is treated as a multivariate Normal distribution with a diagonal covariance.
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_{}".format(t),
                    Normal(z_loc, z_scale)
                    # .mask(mini_batch_mask[:, t - 1:t])
                    .mask(mini_batch_mask[:, t:t + 1]).to_event(1))

            # compute the probabilities that parameterize the bernoulli likelihood
            h1 = e_relu(e_lin_z_to_hidden(z_t))
            h2 = e_relu(e_lin_hidden_to_hidden(h1))
            ps = torch.sigmoid(e_lin_hidden_to_input(h2))
            emission_probs_t = ps
            # the next statement instructs pyro to observe x_t according to the
            # bernoulli distribution p(x_t|z_t)
            pyro.sample(
                "obs_x_{}".format(t),
                Bernoulli(emission_probs_t)
                # .mask(mini_batch_mask[:, t - 1:t])
                .mask(mini_batch_mask[:, t:t + 1]).to_event(1),
                # obs=mini_batch[:, t - 1, :])
                obs=mini_batch[:, t, :])
            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
def model(batch, tag, hidden, label):
	log_softmax = nn.LogSoftmax(dim=1)
	    
	embedding_prior = Normal(loc=pretrained_dict['embedding.weight'], scale=torch.ones_like(net.embedding.weight))

	lstmih0w_prior = Normal(loc=pretrained_dict['lstm.weight_ih_l0'], scale=torch.ones_like(net.lstm.weight_ih_l0))
	lstmih0b_prior = Normal(loc=pretrained_dict['lstm.bias_ih_l0'], scale=torch.ones_like(net.lstm.bias_ih_l0))
	lstmhh0w_prior = Normal(loc=pretrained_dict['lstm.weight_hh_l0'], scale=torch.ones_like(net.lstm.weight_hh_l0))
	lstmhh0b_prior = Normal(loc=pretrained_dict['lstm.bias_hh_l0'], scale=torch.ones_like(net.lstm.bias_hh_l0))

	lstmih1w_prior = Normal(loc=pretrained_dict['lstm.weight_ih_l1'], scale=torch.ones_like(net.lstm.weight_ih_l1))
	lstmih1b_prior = Normal(loc=pretrained_dict['lstm.bias_ih_l1'], scale=torch.ones_like(net.lstm.bias_ih_l1))
	lstmhh1w_prior = Normal(loc=pretrained_dict['lstm.weight_hh_l1'], scale=torch.ones_like(net.lstm.weight_hh_l1))
	lstmhh1b_prior = Normal(loc=pretrained_dict['lstm.bias_hh_l1'], scale=torch.ones_like(net.lstm.bias_hh_l1))

	fc1w_prior = Normal(loc=pretrained_dict['fc1.weight'], scale=torch.ones_like(net.fc1.weight))
	fc1b_prior = Normal(loc=pretrained_dict['fc1.bias'], scale=torch.ones_like(net.fc1.bias))
	    
	fc2w_prior = Normal(loc=pretrained_dict['fc2.weight'], scale=torch.ones_like(net.fc2.weight))
	fc2b_prior = Normal(loc=pretrained_dict['fc2.bias'], scale=torch.ones_like(net.fc2.bias))
	    
	priors = {'embedding.weight': embedding_prior, 'lstm.weight_ih_l0': lstmih0w_prior, 'lstm.bias_ih_l0': lstmih0b_prior,
	'lstm.weight_hh_l0': lstmhh0w_prior, 'lstm.bias_hh_l0': lstmhh0b_prior, 'lstm.weight_ih_l1': lstmih1w_prior,
	'lstm.bias_ih_l1': lstmih1b_prior, 'lstm.weight_hh_l1': lstmhh1w_prior, 'lstm.bias_hh_l1': lstmhh1b_prior,
	'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior, 'fc2.weight': fc2w_prior, 'fc2.bias': fc2b_prior}
	    
	# lift module parameters to random variables sampled from the priors
	lifted_module = pyro.random_module("module", net, priors)
	# sample module
	lifted_model = lifted_module()

	output, hidden = lifted_model(batch, tag, hidden)
	    
	lhat = log_softmax(output)
	    
	pyro.sample("obs", Categorical(logits=lhat), obs=label)
            d[i] = d[i][1:-1]

print("Processed")

from math import log, exp
from random import shuffle

N = len(data)
M = int(max([line[header["tokenID.Renumbered"]] for line in data]))
L = int(max([line[header["WorkerId.Renumbered"]] for line in data]))

#print(M)
#print(L)
#quit()

alpha_Prior = Normal(Variable(torch.FloatTensor([0.0])),
                     Variable(torch.FloatTensor([10.0])))
beta1_Prior = Normal(Variable(torch.FloatTensor([0.0])),
                     Variable(torch.FloatTensor([10.0])))
decay_Prior = Normal(Variable(torch.FloatTensor([0.0])),
                     Variable(torch.FloatTensor([10.0])))
kappa_Prior = Normal(Variable(torch.FloatTensor([0.0])),
                     Variable(torch.FloatTensor([10.0])))

log_sigma_Prior = Normal(Variable(torch.FloatTensor([0.0])),
                         Variable(torch.FloatTensor([10.0])))

byParticipantIntercept_Variance_Prior = Normal(
    Variable(torch.FloatTensor([0.0])), Variable(torch.FloatTensor([10.0])))
byItemIntercept_Variance_Prior = Normal(Variable(torch.FloatTensor([0.0])),
                                        Variable(torch.FloatTensor([10.0])))
Example #28
def build_0d_dist(x, a, s):
    return Normal(loc=a * x.values, scale=s)
Example #29
 def guide(self, data):
     encoder = pyro.module('encoder', self.vae_encoder)
     z_mean, z_var = encoder.forward(data)
     pyro.sample('latent', Normal(z_mean, z_var.sqrt()))
Example #30
def model(data):
    data = torch.reshape(data, [60000, 50, 50])

    pyro.module("decode_l1", decode_l1)
    pyro.module("decode_l2", decode_l2)

    with pyro.plate('data', 60000, 64) as ix:
        # size = [64, 50, 50]
        batch = data[ix]

        #================= prior
        state_x = torch.zeros([64, 50, 50])
        state_z_pres = torch.ones([64, 1])
        state_z_where = None

        z_pres = []
        z_where = []

        for t in range(3):
            #==================== prior_step
            # size = [64, 50, 50]
            prev_x = state_x
            # size = [64, 1]
            prev_z_pres = state_z_pres
            # size = None or [64, 3]
            prev_z_where = state_z_where

            # size = [64, 1]
            cur_z_pres =\
                pyro.sample('z_pres_{}'.format(t),
                            Bernoulli(trial_probs[t] * prev_z_pres)
                            .to_event(1))

            sample_mask = cur_z_pres
            # size = [64, 3]
            cur_z_where =\
                pyro.sample('z_where_{}'.format(t),
                            Normal(torch.Tensor.expand(z_where_loc_prior, [64, 3]),
                                   torch.Tensor.expand(z_where_scale_prior, [64, 3]))
                            .mask(sample_mask)
                            .to_event(1))

            # size = [64, 50]
            cur_z_what =\
                pyro.sample('z_what_{}'.format(t),
                            Normal(torch.zeros([64, 50]),
                                   torch.ones([64, 50]))
                            .mask(sample_mask)
                            .to_event(1))

            #===== decode
            # size = [64, 784]
            y_att = torch.sigmoid(
                decode_l2(F.relu(decode_l1(cur_z_what))) - 2.0)
            #===== decode

            #===== window_to_image
            windows = y_att

            #===== expand_z_where
            # size = [64, 4]
            out = torch.cat((torch.zeros(64, 1), cur_z_where), 1)
            # size = [64, 6]
            out = torch.index_select(out, 1, expansion_indices)
            # size = [64, 2, 3]
            out = torch.Tensor.view(out, [64, 2, 3])
            theta = out
            #===== expand_z_where
            # size = [64, 50, 50, 2]
            grid = F.affine_grid(theta, [64, 1, 50, 50])
            # size = [64, 1, 50, 50]
            out = F.grid_sample(torch.Tensor.view(windows, [64, 1, 28, 28]),
                                grid)

            y = torch.Tensor.view(out, [64, 50, 50])
            #===== window_to_image

            # size = [64, 50, 50]
            cur_x = prev_x + (y * torch.Tensor.view(cur_z_pres, [64, 1, 1]))

            state_x = cur_x
            state_z_pres = cur_z_pres
            state_z_where = cur_z_where
            #==================== prior_step

            z_where.append(state_z_where)
            z_pres.append(state_z_pres)

        # size = [64, 50, 50]
        x = state_x
        #================== prior

        pyro.sample('obs',
                    Normal(torch.Tensor.view(x, [64, 2500]),
                           (0.3 * torch.ones(64, 2500))).to_event(1),
                    obs=torch.Tensor.view(batch, [64, 2500]))