def forward(self, mini_batch, mini_batch_reversed, mini_batch_mask,
            mini_batch_seq_lengths, annealing_factor=1.0):
    # number of time steps to process in the mini-batch
    T_max = mini_batch.size(1)
    h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                 self.rnn.hidden_size).contiguous()
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))
    z_container = []
    z_loc_container = []
    z_scale_container = []
    for t in range(1, T_max + 1):
        # assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])
        # optionally clamp the scale from below (args.clip is a module-level option)
        if args.clip is not None:
            z_scale = torch.clamp(z_scale, min=args.clip)
        # reparameterization trick: z_t = loc + scale * eps, eps ~ N(0, I)
        if self.use_cuda:
            eps = torch.randn(z_loc.size()).cuda()
        else:
            eps = torch.randn(z_loc.size())
        z_t = z_loc + z_scale * eps
        z_prev = z_t
        z_container.append(z_t)
        z_loc_container.append(z_loc)
        z_scale_container.append(z_scale)
    z_container = torch.stack(z_container)
    z_loc_container = torch.stack(z_loc_container)
    z_scale_container = torch.stack(z_scale_container)
    return (z_container.transpose(0, 1),
            z_loc_container.transpose(0, 1),
            z_scale_container.transpose(0, 1))
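# NOTE: poly.pad_and_reverse is imported from the polyphonic data loader used by
# the Pyro DMM example and is not shown in these listings. The stand-in below is
# a minimal sketch, assuming rnn_output is a packed, time-reversed sequence; it
# is illustrative only and not the author's actual helper.
import torch
from torch.nn.utils.rnn import pad_packed_sequence


def pad_and_reverse(rnn_output, seq_lengths):
    # unpack the packed rnn output into a padded (batch, T, hidden) tensor
    rnn_output, _ = pad_packed_sequence(rnn_output, batch_first=True)
    reversed_output = torch.zeros_like(rnn_output)
    for b, T in enumerate(seq_lengths):
        T = int(T)
        # undo the time reversal that was applied to the rnn inputs
        reversed_output[b, :T, :] = rnn_output[b, :T, :].flip(0)
    return reversed_output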
def forward(self, mini_batch, mini_batch_reversed, mini_batch_mask,
            mini_batch_seq_lengths, annealing_factor=1.0):
    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                 self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))
    z_container = []
    z_loc_container = []
    z_scale_container = []
    for t in range(1, T_max + 1):
        # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])
        # optionally clamp the scale from below (args.clip is a module-level option)
        if args.clip is not None:
            z_scale = torch.clamp(z_scale, min=args.clip)
        # reparameterization trick: z_t = loc + scale * eps, eps ~ N(0, I)
        if self.use_cuda:
            eps = torch.randn(z_loc.size()).cuda()
        else:
            eps = torch.randn(z_loc.size())
        z_t = z_loc + z_scale * eps
        # the latent sampled at this time step will be conditioned upon
        # in the next time step so keep track of it
        z_prev = z_t
        z_container.append(z_t)
        z_loc_container.append(z_loc)
        z_scale_container.append(z_scale)
    z_container = torch.stack(z_container)
    z_loc_container = torch.stack(z_loc_container)
    z_scale_container = torch.stack(z_scale_container)
    return (z_container.transpose(0, 1),
            z_loc_container.transpose(0, 1),
            z_scale_container.transpose(0, 1))
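# NOTE: self.combiner is not shown in these listings. The sketch below is an
# assumption, modeled on the Combiner module of the Pyro DMM example: it maps
# (z_{t-1}, h_rnn) to the loc and scale of q(z_t | z_{t-1}, x_{t:T}). Layer
# names and the mixing scheme are illustrative, not the author's actual code.
import torch
import torch.nn as nn


class Combiner(nn.Module):
    def __init__(self, z_dim, rnn_dim):
        super().__init__()
        self.lin_z_to_hidden = nn.Linear(z_dim, rnn_dim)
        self.lin_hidden_to_loc = nn.Linear(rnn_dim, z_dim)
        self.lin_hidden_to_scale = nn.Linear(rnn_dim, z_dim)
        self.tanh = nn.Tanh()
        self.softplus = nn.Softplus()

    def forward(self, z_t_1, h_rnn):
        # combine the previous latent with the rnn hidden state
        h_combined = 0.5 * (self.tanh(self.lin_z_to_hidden(z_t_1)) + h_rnn)
        # mean of q(z_t | z_{t-1}, x_{t:T})
        loc = self.lin_hidden_to_loc(h_combined)
        # positive scale via softplus
        scale = self.softplus(self.lin_hidden_to_scale(h_combined))
        return loc, scale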
def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):
    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    # register all PyTorch (sub)modules with pyro
    pyro.module("dmm", self)
    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                 self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

    # we enclose all the sample statements in the guide in a plate.
    # this marks that each datapoint is conditionally independent of the others.
    with pyro.plate("z_minibatch", len(mini_batch)):
        # sample the latents z one time step at a time
        # we wrap this loop in pyro.markov so that TraceEnum_ELBO can use
        # multiple samples from the guide at each z
        for t in pyro.markov(range(1, T_max + 1)):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

            # if we are using normalizing flows, we apply the sequence of
            # transformations parameterized by self.iafs to the base distribution
            # defined in the previous line to yield a transformed distribution
            # that we use for q(z_t|...)
            if len(self.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                assert z_dist.event_shape == (self.z_q_0.size(0),)
                assert z_dist.batch_shape[-1:] == (len(mini_batch),)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
                assert z_dist.event_shape == ()
                assert z_dist.batch_shape[-2:] == (len(mini_batch), self.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            with pyro.poutine.scale(scale=annealing_factor):
                if len(self.iafs) > 0:
                    # in output of normalizing flow, all dimensions are correlated
                    # (event shape is not empty)
                    z_t = pyro.sample("z_%d" % t,
                                      z_dist.mask(mini_batch_mask[:, t - 1]))
                else:
                    # when no normalizing flow used, ".to_event(1)" indicates
                    # latent dimensions are independent
                    z_t = pyro.sample("z_%d" % t,
                                      z_dist.mask(mini_batch_mask[:, t - 1:t])
                                      .to_event(1))
            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
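# NOTE: a minimal, assumed sketch of how a guide like the one above is typically
# driven with Pyro's SVI. The objects dmm (with .model and .guide methods),
# loader, and the hyperparameter values are placeholders not shown in these
# listings; this is not the author's actual training loop.
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

optimizer = ClippedAdam({"lr": 3e-4, "clip_norm": 10.0})
svi = SVI(dmm.model, dmm.guide, optimizer, loss=Trace_ELBO())

num_epochs = 10
for epoch in range(num_epochs):
    for mini_batch, mini_batch_reversed, mini_batch_mask, seq_lengths in loader:
        # anneal the KL term from a small value up to 1.0 over training
        annealing_factor = min(1.0, 0.1 + epoch / 10.0)
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        seq_lengths, annealing_factor)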
def forward(self, mini_batch, mini_batch_reversed, mini_batch_mask,
            mini_batch_seq_lengths, annealing_factor=1.0):
    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                 self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # optional sanity check: fail fast if the rnn produced NaNs
    if self.rnn_check:
        if any(torch.isnan(rnn_output.data.reshape(-1))):
            assert False
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # set z_prev = z_0 to setup the recursive conditioning
    z_prev = self.z_0[:mini_batch.size(0)]
    x_container = []
    for t in range(1, T_max + 1):
        # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])
        # reparameterization trick for the gaussian latent
        if self.use_cuda:
            eps = torch.randn(z_loc.size()).cuda()
        else:
            eps = torch.randn(z_loc.size())
        z_t = z_loc + z_scale * eps
        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = self.emitter(z_t)
        # no reparameterization trick: emit the bernoulli probabilities directly
        x = emission_probs_t
        # reparameterization trick for the bernoulli emission: relax the discrete
        # sample with a logistic (binary concrete) approximation
        if self.rpt:
            if self.use_cuda:
                eps = torch.rand(88).cuda()
            else:
                eps = torch.rand(88)
            appxm = (torch.log(eps + 1e-20) - torch.log(1 - eps + 1e-20)
                     + torch.log(x + 1e-20) - torch.log(1 - x + 1e-20))
            x = torch.sigmoid(appxm)
        # the latent sampled at this time step will be conditioned upon
        # in the next time step so keep track of it
        z_prev = z_t
        x_container.append(x)
    x_container = torch.stack(x_container)
    return x_container.transpose(0, 1)
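# NOTE: the self.rpt branch above is a temperature-1 logistic ("binary concrete" /
# Gumbel-sigmoid) relaxation of a Bernoulli sample. The sketch below generalizes
# it with an explicit temperature; the function name and default values are
# illustrative assumptions, not part of the original code.
import torch


def relaxed_bernoulli(probs, temperature=1.0, eps=1e-20):
    # u ~ Uniform(0, 1), same shape as the bernoulli probabilities
    u = torch.rand_like(probs)
    # add logistic noise to the logits, then squash through a sigmoid;
    # temperature -> 0 approaches hard 0/1 samples, temperature = 1 matches
    # the appxm computation above
    logits = (torch.log(probs + eps) - torch.log(1 - probs + eps)
              + torch.log(u + eps) - torch.log(1 - u + eps))
    return torch.sigmoid(logits / temperature)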