    def forward(self,
                mini_batch,
                mini_batch_reversed,
                mini_batch_mask,
                mini_batch_seq_lengths,
                annealing_factor=1.0):
        # number of time steps to process in this mini-batch
        T_max = mini_batch.size(1)
        # broadcast the learnable initial RNN state across the batch and make it contiguous
        h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                     self.rnn.hidden_size).contiguous()
        # push the (time-reversed) observations through the RNN;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # set z_prev = z_q_0 to set up the recursive conditioning in q(z_t | ...)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))
        z_container = []
        z_loc_container = []
        z_scale_container = []
        for t in range(1, T_max + 1):
            # assemble the parameters of q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])
            # optionally clamp the posterior scale from below for numerical stability
            if args.clip is not None:
                z_scale = torch.clamp(z_scale, min=args.clip)
            # reparameterization trick: z_t = loc + scale * eps, eps ~ N(0, I)
            if self.use_cuda:
                eps = torch.randn(z_loc.size()).cuda()
            else:
                eps = torch.randn(z_loc.size())
            z_t = z_loc + z_scale * eps
            # the sample at this time step conditions the next time step
            z_prev = z_t
            z_container.append(z_t)
            z_loc_container.append(z_loc)
            z_scale_container.append(z_scale)

        z_container = torch.stack(z_container)
        z_loc_container = torch.stack(z_loc_container)
        z_scale_container = torch.stack(z_scale_container)
        return (z_container.transpose(0, 1),
                z_loc_container.transpose(0, 1),
                z_scale_container.transpose(0, 1))
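
# --- Usage sketch (not from the source) -----------------------------------
# A minimal, hypothetical illustration of the tensor shapes this encoder
# expects and returns. The toy shapes and the flipped copy used here as
# `mini_batch_reversed` are assumptions for illustration; in the original
# Pyro DMM example the reversed batch is a packed sequence produced by the
# `poly` helper module.
import torch

batch_size, t_max, input_dim = 20, 16, 88
mini_batch = torch.rand(batch_size, t_max, input_dim)
mini_batch_reversed = torch.flip(mini_batch, dims=[1])   # time-reversed copy
mini_batch_mask = torch.ones(batch_size, t_max)          # 1 = real step, 0 = padding
mini_batch_seq_lengths = torch.full((batch_size,), t_max)

# encoder = DMMEncoder(...)  # hypothetical module exposing the forward() above
# z, z_loc, z_scale = encoder(mini_batch, mini_batch_reversed,
#                             mini_batch_mask, mini_batch_seq_lengths)
# each returned tensor has shape (batch_size, t_max, z_dim)
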
    def forward(self, mini_batch, mini_batch_reversed, mini_batch_mask,
                mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)

        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)

        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)

        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        z_container = []
        z_loc_container = []
        z_scale_container = []
        for t in range(1, T_max + 1):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])
            if args.clip is not None:
                z_scale = torch.clamp(z_scale, min=args.clip)
            # Reparameterization Trick
            if self.use_cuda:
                eps = torch.randn(z_loc.size()).cuda()
            else:
                eps = torch.randn(z_loc.size())
            z_t = z_loc + z_scale * eps

            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
            z_container.append(z_t)
            z_loc_container.append(z_loc)
            z_scale_container.append(z_scale)
        
        z_container = torch.stack(z_container)
        z_loc_container = torch.stack(z_loc_container)
        z_scale_container = torch.stack(z_scale_container)
        return (z_container.transpose(0, 1),
                z_loc_container.transpose(0, 1),
                z_scale_container.transpose(0, 1))
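
# --- Equivalent reparameterized sampling via torch.distributions ----------
# A small aside (not from the source): the manual `z_loc + z_scale * eps`
# step in the loops above is the same reparameterization trick that
# torch.distributions.Normal provides through .rsample(); the values below
# are toy stand-ins.
import torch
from torch.distributions import Normal

z_loc = torch.zeros(20, 100)                   # batch of 20, latent dimension 100
z_scale = 0.5 * torch.ones(20, 100)

eps = torch.randn(z_loc.size())
z_manual = z_loc + z_scale * eps               # manual reparameterization, as above

z_rsample = Normal(z_loc, z_scale).rsample()   # differentiable sample
# Both forms keep gradients flowing into z_loc and z_scale,
# unlike Normal(z_loc, z_scale).sample().
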
Example #3
    def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)
        # register all PyTorch (sub)modules with pyro
        pyro.module("dmm", self)

        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        # we enclose all the sample statements in the guide in a plate.
        # this marks that each datapoint is conditionally independent of the others.
        with pyro.plate("z_minibatch", len(mini_batch)):
            # sample the latents z one time step at a time
            # we wrap this loop in pyro.markov so that TraceEnum_ELBO can use multiple samples from the guide at each z
            for t in pyro.markov(range(1, T_max + 1)):
                # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
                z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

                # if we are using normalizing flows, we apply the sequence of transformations
                # parameterized by self.iafs to the base distribution defined in the previous line
                # to yield a transformed distribution that we use for q(z_t|...)
                if len(self.iafs) > 0:
                    z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                    assert z_dist.event_shape == (self.z_q_0.size(0),)
                    assert z_dist.batch_shape[-1:] == (len(mini_batch),)
                else:
                    z_dist = dist.Normal(z_loc, z_scale)
                    assert z_dist.event_shape == ()
                    assert z_dist.batch_shape[-2:] == (len(mini_batch), self.z_q_0.size(0))

                # sample z_t from the distribution z_dist
                with pyro.poutine.scale(scale=annealing_factor):
                    if len(self.iafs) > 0:
                        # in output of normalizing flow, all dimensions are correlated (event shape is not empty)
                        z_t = pyro.sample("z_%d" % t,
                                          z_dist.mask(mini_batch_mask[:, t - 1]))
                    else:
                        # when no normalizing flow used, ".to_event(1)" indicates latent dimensions are independent
                        z_t = pyro.sample("z_%d" % t,
                                          z_dist.mask(mini_batch_mask[:, t - 1:t])
                                          .to_event(1))
                # the latent sampled at this time step will be conditioned upon in the next time step
                # so keep track of it
                z_prev = z_t
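
# --- Training sketch with Pyro's SVI (assumptions noted) -------------------
# A minimal sketch of how a guide like this is typically paired with the
# corresponding model for stochastic variational inference. The `dmm`
# instance, the learning-rate settings, and the already-prepared mini-batch
# tensors are assumptions for illustration, so the usage lines are left
# commented out.
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

optimizer = ClippedAdam({"lr": 3e-4, "clip_norm": 10.0})

# svi = SVI(dmm.model, dmm.guide, optimizer, loss=Trace_ELBO())
# One gradient step on the annealed ELBO; annealing_factor < 1.0 down-weights
# the KL term early in training (KL annealing):
# loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
#                 mini_batch_seq_lengths, annealing_factor=0.2)
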
Example #4
    def forward(self,
                mini_batch,
                mini_batch_reversed,
                mini_batch_mask,
                mini_batch_seq_lengths,
                annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)

        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                     self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        if self.rnn_check:
            # optional sanity check: fail fast if the RNN produced any NaNs
            assert not torch.isnan(rnn_output.data).any(), "NaN in rnn_output"

        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # initialize z_prev from the learned initial latent state to set up
        # the recursive conditioning on z_{t-1}
        z_prev = self.z_0[:mini_batch.size(0)]

        x_container = []
        for t in range(1, T_max + 1):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

            # Reparameterization Trick
            if self.use_cuda:
                eps = torch.randn(z_loc.size()).cuda()
            else:
                eps = torch.randn(z_loc.size())
            z_t = z_loc + z_scale * eps

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = self.emitter(z_t)

            # without the reparameterization trick we use the emission
            # probabilities directly as the reconstruction x_t
            x = emission_probs_t

            # reparameterization trick for the Bernoulli emission: add logistic
            # noise to the logits and squash with a sigmoid, giving a
            # differentiable relaxation of the Bernoulli sample
            if self.rpt:
                if self.use_cuda:
                    eps = torch.rand(88).cuda()   # 88 = emission dimension (hard-coded)
                else:
                    eps = torch.rand(88)
                # logit(eps) + logit(x), with small constants for numerical stability
                appxm = (torch.log(eps + 1e-20) - torch.log(1 - eps + 1e-20)
                         + torch.log(x + 1e-20) - torch.log(1 - x + 1e-20))
                x = torch.sigmoid(appxm)

            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
            x_container.append(x)

        x_container = torch.stack(x_container)
        return x_container.transpose(0, 1)
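
# --- Evaluation sketch (assumptions noted) ---------------------------------
# A small illustration (not from the source) of scoring reconstructed
# emission probabilities against the observed batch with a masked Bernoulli
# negative log-likelihood. The toy shapes (batch 20, 16 steps, 88 notes) and
# the `decoder` name are assumptions.
import torch
import torch.nn.functional as F

x_probs = torch.rand(20, 16, 88)                       # stand-in for the forward() output
observed = torch.bernoulli(torch.rand(20, 16, 88))     # stand-in for mini_batch
mask = torch.ones(20, 16)                              # stand-in for mini_batch_mask

# x_probs = decoder(mini_batch, mini_batch_reversed, mini_batch_mask,
#                   mini_batch_seq_lengths)            # hypothetical real call

nll = F.binary_cross_entropy(x_probs, observed, reduction="none").sum(-1)
masked_nll = (nll * mask).sum() / mask.sum()           # mean NLL over real (unpadded) steps
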