def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):

    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    # register all PyTorch (sub)modules with pyro
    pyro.module("dmm", self)

    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

    # we enclose all the sample statements in the guide in a plate.
    # this marks that each datapoint is conditionally independent of the others.
    with pyro.plate("z_minibatch", len(mini_batch)):
        # sample the latents z one time step at a time
        # we wrap this loop in pyro.markov so that TraceEnum_ELBO can use multiple samples from the guide at each z
        for t in pyro.markov(range(1, T_max + 1)):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(self.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                assert z_dist.event_shape == (self.z_q_0.size(0),)
                assert z_dist.batch_shape[-1:] == (len(mini_batch),)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
                assert z_dist.event_shape == ()
                assert z_dist.batch_shape[-2:] == (len(mini_batch), self.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            with pyro.poutine.scale(scale=annealing_factor):
                if len(self.iafs) > 0:
                    # in output of normalizing flow, all dimensions are correlated (event shape is not empty)
                    z_t = pyro.sample("z_%d" % t, z_dist.mask(mini_batch_mask[:, t - 1]))
                else:
                    # when no normalizing flow used, ".to_event(1)" indicates latent dimensions are independent
                    z_t = pyro.sample("z_%d" % t,
                                      z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))

            # the latent sampled at this time step will be conditioned upon in the next time step
            # so keep track of it
            z_prev = z_t
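The guide above assumes a `combiner` module that maps the previous latent and the RNN hidden state to the location and scale of q(z_t | z_{t-1}, x_{t:T}). For reference, here is a minimal sketch in the spirit of the combiner used in the Pyro DMM example; the layer names and the gating scheme are illustrative assumptions, not necessarily the exact module used above.

import torch.nn as nn


class Combiner(nn.Module):
    # sketch of a combiner network: parameterizes q(z_t | z_{t-1}, x_{t:T}),
    # where the dependence on x_{t:T} enters through the RNN hidden state h_rnn
    def __init__(self, z_dim, rnn_dim):
        super().__init__()
        self.lin_z_to_hidden = nn.Linear(z_dim, rnn_dim)
        self.lin_hidden_to_loc = nn.Linear(rnn_dim, z_dim)
        self.lin_hidden_to_scale = nn.Linear(rnn_dim, z_dim)
        self.tanh = nn.Tanh()
        self.softplus = nn.Softplus()

    def forward(self, z_t_1, h_rnn):
        # combine the previous latent with the rnn hidden state
        h_combined = 0.5 * (self.tanh(self.lin_z_to_hidden(z_t_1)) + h_rnn)
        # compute the location and (positive) scale of q(z_t | ...)
        loc = self.lin_hidden_to_loc(h_combined)
        scale = self.softplus(self.lin_hidden_to_scale(h_combined))
        return loc, scale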
def calc_loss(self, x_mb, x_mb_reversed, mini_batch_seq_lengths):
    """
    :param x_mb: the mini-batch of input sequences
    :param x_mb_reversed: the mini-batch in reversed temporal order
    :param mini_batch_seq_lengths: the sequence lengths in the mini-batch
    :return: rec_loss, kl_loss
    """
    # this is the number of time steps we need to process in the mini-batch
    T_max = int(mini_batch_seq_lengths.max())

    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0.expand(1, x_mb.size(0), self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_out, _ = self.rnn(x_mb_reversed, h_0_contig)
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_out = poly.pad_and_reverse(rnn_out, mini_batch_seq_lengths)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0.expand(x_mb.size(0), self.z_q_0.size(0))

    rec_losses = torch.zeros((x_mb.size(0), T_max), device=x_mb.device)
    kl_divs = torch.zeros((x_mb.size(0), T_max), device=x_mb.device)

    # sample the latents z one time step at a time
    for t in range(1, T_max + 1):
        # get prior parameters of p(z_t | z_{t-1})
        z_prior_loc, z_prior_logvar = self.trans(z_prev)
        # sample from posterior q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_logvar = self.combiner(z_prev, rnn_out[:, t - 1, :])
        z_t = sample_via_reparam(z_loc, z_logvar)
        # get predicted logits p(x_t|z_t)
        p_x_t = self.emitter(z_t).contiguous()
        # calculate loss
        kl_divs[:, t - 1] = self.kl_div(z_prior_loc, z_prior_logvar, z_loc, z_logvar)
        rec_loss = nn.BCEWithLogitsLoss(reduction='none')(
            p_x_t.view(-1), x_mb[:, t - 1, :].contiguous().view(-1))
        rec_losses[:, t - 1] = rec_loss.view(x_mb.size(0), -1).mean(dim=1)
        # the latent sampled at this time step will be conditioned upon in the next time step
        # so keep track of it
        z_prev = z_t

    # only average over the time steps that are actually observed
    x_mask = sequence_mask(mini_batch_seq_lengths)
    x_mask = x_mask.gt(0).view(-1)
    rec_loss = rec_losses.view(-1).masked_select(x_mask).mean()
    kl_loss = kl_divs.view(-1).masked_select(x_mask).mean()
    return rec_loss, kl_loss
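calc_loss relies on three helpers that are not defined in this snippet: sample_via_reparam, kl_div (a method on the model here), and sequence_mask. Below is a minimal sketch of plausible implementations, assuming the standard reparameterization trick, the closed-form KL(q || p) between diagonal Gaussians parameterized by (loc, logvar), and a length-based binary mask; the names, argument order, and conventions are assumptions, not taken from the original code.

import torch


def sample_via_reparam(loc, logvar):
    # reparameterization trick: z = loc + eps * std with eps ~ N(0, I)
    eps = torch.randn_like(loc)
    return loc + eps * torch.exp(0.5 * logvar)


def kl_div(p_loc, p_logvar, q_loc, q_logvar):
    # KL( N(q_loc, exp(q_logvar)) || N(p_loc, exp(p_logvar)) ) for diagonal
    # Gaussians, summed over the latent dimension -> shape (batch,)
    kl = 0.5 * (p_logvar - q_logvar
                + (q_logvar.exp() + (q_loc - p_loc).pow(2)) / p_logvar.exp()
                - 1.0)
    return kl.sum(dim=-1)


def sequence_mask(seq_lengths, max_len=None):
    # binary mask of shape (batch, max_len): 1 for observed steps, 0 for padding
    if max_len is None:
        max_len = int(seq_lengths.max())
    steps = torch.arange(max_len, device=seq_lengths.device)
    return (steps.unsqueeze(0) < seq_lengths.unsqueeze(1)).float()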
def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):

    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    # register all PyTorch (sub)modules with pyro
    pyro.module("dmm", self)

    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

    # we enclose all the sample statements in the guide in an iarange.
    # this marks that each datapoint is conditionally independent of the others.
    with pyro.iarange("z_minibatch", len(mini_batch)):
        # sample the latents z one time step at a time
        for t in range(1, T_max + 1):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(self.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), self.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample("z_%d" % t,
                                  z_dist.mask(mini_batch_mask[:, t - 1:t])
                                  .independent(1))

            # the latent sampled at this time step will be conditioned upon in the next time step
            # so keep track of it
            z_prev = z_t
def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):

    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    # register all PyTorch (sub)modules with pyro
    pyro.module("dmm", self)

    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0 if not self.use_cuda \
        else self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0

    # sample the latents z one time step at a time
    for t in range(1, T_max + 1):
        # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_mu, z_sigma = self.combiner(z_prev, rnn_output[:, t - 1, :])
        z_dist = dist.normal

        # if we are using normalizing flows, we apply the sequence of transformations
        # parameterized by self.iafs to the base distribution defined in the previous line
        # to yield a transformed distribution that we use for q(z_t|...)
        if len(self.iafs) > 0:
            z_dist = TransformedDistribution(z_dist, self.iafs)

        # sample z_t from the distribution z_dist
        z_t = pyro.sample("z_%d" % t, z_dist, z_mu, z_sigma,
                          log_pdf_mask=annealing_factor * mini_batch_mask[:, t - 1:t])

        # the latent sampled at this time step will be conditioned upon in the next time step
        # so keep track of it
        z_prev = z_t
def test_minibatch(dmm, mini_batch, args, sample_z=True):
    # Generate data that we can feed into the below fn.
    test_data_sequences = mini_batch.type(torch.FloatTensor)
    mini_batch_indices = torch.arange(0, test_data_sequences.size(0))
    test_seq_lengths = torch.full(
        (test_data_sequences.size(0), ),
        test_data_sequences.size(1)).type(torch.IntTensor)

    # grab a fully prepped mini-batch using the helper function in the data loader
    mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
        = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
                              test_seq_lengths, cuda=args.cuda)

    # Get the initial RNN state.
    h_0 = dmm.h_0
    h_0_contig = h_0.expand(1, mini_batch.size(0), dmm.rnn.hidden_size).contiguous()

    # Feed the test sequence into the RNN.
    rnn_output, rnn_hidden_state = dmm.rnn(mini_batch_reversed, h_0_contig)

    # Reverse the time ordering of the hidden state and unpack it.
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # print(rnn_output)
    # print(rnn_output.shape)

    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = dmm.z_q_0.expand(mini_batch.size(0), dmm.z_q_0.size(0))

    # sample the latents z one time step at a time
    T_max = mini_batch.size(1)
    sequence_z = []
    sequence_output = []
    for t in range(1, T_max + 1):
        # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_scale = dmm.combiner(z_prev, rnn_output[:, t - 1, :])

        if sample_z:
            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(dmm.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), dmm.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), dmm.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            annealing_factor = 1.0
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))
        else:
            z_t = z_loc

        z_t_np = z_t.detach().numpy()
        z_t_np = z_t_np[:, np.newaxis, :]
        sequence_z.append(z_t_np)
        # print("z_{}:".format(t), z_t)
        # print(z_t.shape)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)
        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)
        # print("x_{}:".format(t), emission_probs_t)
        # print(emission_probs_t.shape)

        # the latent sampled at this time step will be conditioned upon in the next time step
        # so keep track of it
        z_prev = z_t

    # Run the model another few steps.
    n_extra_steps = 100
    for t in range(1, n_extra_steps + 1):
        # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
        z_loc, z_scale = dmm.trans(z_prev)
        # then sample z_t according to dist.Normal(z_loc, z_scale)
        # note that we use .to_event(1) so that the univariate Normal distribution
        # is treated as a multivariate Normal distribution with a diagonal covariance.
        annealing_factor = 1.0
        with poutine.scale(scale=annealing_factor):
            z_t = pyro.sample(
                "z_%d" % t,
                dist.Normal(z_loc, z_scale)
                # .mask(mini_batch_mask[:, t - 1:t])
                .to_event(1))

        z_t_np = z_t.detach().numpy()
        z_t_np = z_t_np[:, np.newaxis, :]
        sequence_z.append(z_t_np)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)
        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)

        # the latent sampled at this time step will be conditioned upon
        # in the next time step so keep track of it
        z_prev = z_t

    sequence_z = np.concatenate(sequence_z, axis=1)
    sequence_output = np.concatenate(sequence_output, axis=1)
    # print(sequence_output.shape)

    # n_plots = 5
    # fig, axes = plt.subplots(nrows=n_plots, ncols=1)
    # x = range(sequence_output.shape[1])
    # for i in range(n_plots):
    #     input = mini_batch[i, :].numpy().squeeze()
    #     output = sequence_output[i, :]
    #     axes[i].plot(range(input.shape[0]), input)
    #     axes[i].plot(range(len(output)), output)
    #     axes[i].grid()

    return mini_batch, sequence_z, sequence_output  # fig
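A minimal usage sketch for the test_minibatch function above, assuming a constructed (ideally trained) dmm and a batch of binary sequences; the Args namedtuple is a stand-in for the argparse namespace the function expects (only its cuda flag is used), and the batch shape is illustrative.

import collections

import torch

Args = collections.namedtuple("Args", ["cuda"])

# hypothetical batch: 8 binary sequences of length 50 with a 1-dimensional observation
batch = torch.bernoulli(0.5 * torch.ones(8, 50, 1))
mini_batch, sequence_z, sequence_output = test_minibatch(dmm, batch, Args(cuda=False))
# sequence_z has shape (8, 50 + 100, z_dim): the 50 observed steps plus the
# 100 extrapolated steps; sequence_output stacks the corresponding emitter outputs.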
def test_minibatch(which_mini_batch, shuffled_indices):
    # compute which sequences in the test set we should grab
    mini_batch_start = (which_mini_batch * args.mini_batch_size)
    mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_test_data])
    mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]

    # grab a fully prepped mini-batch using the helper function in the data loader
    mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
        = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
                              test_seq_lengths, cuda=args.cuda)

    # Get the initial RNN state.
    h_0 = dmm.h_0
    h_0_contig = h_0.expand(1, mini_batch.size(0), dmm.rnn.hidden_size).contiguous()

    # Feed the test sequence into the RNN.
    rnn_output, rnn_hidden_state = dmm.rnn(mini_batch_reversed, h_0_contig)

    # Reverse the time ordering of the hidden state and unpack it.
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    print(rnn_output)
    print(rnn_output.shape)

    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = dmm.z_q_0.expand(mini_batch.size(0), dmm.z_q_0.size(0))

    # sample the latents z one time step at a time
    T_max = mini_batch.size(1)
    sequence_output = []
    for t in range(1, T_max + 1):
        # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_scale = dmm.combiner(z_prev, rnn_output[:, t - 1, :])

        # if we are using normalizing flows, we apply the sequence of transformations
        # parameterized by self.iafs to the base distribution defined in the previous line
        # to yield a transformed distribution that we use for q(z_t|...)
        if len(dmm.iafs) > 0:
            z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), dmm.iafs)
        else:
            z_dist = dist.Normal(z_loc, z_scale)
        assert z_dist.event_shape == ()
        assert z_dist.batch_shape == (len(mini_batch), dmm.z_q_0.size(0))

        # sample z_t from the distribution z_dist
        annealing_factor = 1.0
        with pyro.poutine.scale(scale=annealing_factor):
            z_t = pyro.sample(
                "z_%d" % t,
                z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))
        print("z_{}:".format(t), z_t)
        print(z_t.shape)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)
        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)
        print("x_{}:".format(t), emission_probs_t)
        print(emission_probs_t.shape)

        # the latent sampled at this time step will be conditioned upon in the next time step
        # so keep track of it
        z_prev = z_t

    # Run the model another few steps.
    n_steps = 100
    for t in range(1, n_steps + 1):
        # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
        z_loc, z_scale = dmm.trans(z_prev)
        # then sample z_t according to dist.Normal(z_loc, z_scale)
        # note that we use .to_event(1) so that the univariate Normal distribution
        # is treated as a multivariate Normal distribution with a diagonal covariance.
        with poutine.scale(scale=annealing_factor):
            z_t = pyro.sample(
                "z_%d" % t,
                dist.Normal(z_loc, z_scale).mask(
                    mini_batch_mask[:, t - 1:t]).to_event(1))

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)
        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)

        # # the next statement instructs pyro to observe x_t according to the
        # # bernoulli distribution p(x_t|z_t)
        # pyro.sample("obs_x_%d" % t,
        #             # dist.Bernoulli(emission_probs_t)
        #             dist.Normal(emission_probs_t, 0.5)
        #             .mask(mini_batch_mask[:, t - 1:t])
        #             .to_event(1),
        #             obs=mini_batch[:, t - 1, :])

        # the latent sampled at this time step will be conditioned upon
        # in the next time step so keep track of it
        z_prev = z_t

    sequence_output = np.concatenate(sequence_output, axis=1)
    print(sequence_output.shape)

    n_plots = 5
    fig, axes = plt.subplots(nrows=n_plots, ncols=1)
    x = range(sequence_output.shape[1])
    for i in range(n_plots):
        input = mini_batch[i, :].numpy().squeeze()
        output = sequence_output[i, :]
        axes[i].plot(range(input.shape[0]), input)
        axes[i].plot(range(len(output)), output)
        axes[i].grid()
    # plt.plot(sequence_output[0, :])
    plt.show()