Esempio n. 1
0
    def forward_backward(self, input):
        """
        Run the HMM forward recursion in log-space and return log p(input).

        input: Variable([seq_len x batch_size]); cast to long below.

        Returns ``(alpha, beta, log_marginal)``:
          - alpha: list of ``seq_len`` forward messages in log-space,
            alpha[t] ~ log p(x_0..x_t, z_t) (presumably [batch_size x z_dim],
            since it is added to beta[-1], which has that shape).
          - beta: list of ``seq_len`` entries in which only beta[-1] is set
            (zeros, i.e. log 1). The backward recursion is NOT run here, so
            every other entry is None.
          - log_marginal: log_sum_exp of alpha[-1] + beta[-1] over the last
            (state) dimension, i.e. the log-likelihood per batch element.
        """
        input = input.long()

        seq_len, batch_size = input.size()
        alpha = [None for i in range(seq_len)]
        beta = [None for i in range(seq_len)]

        # Normalized along dim 0: column j of T is a distribution over the
        # next state, i.e. T[i, j] = log p(z_t = i | z_{t-1} = j).
        T = F.log_softmax(self.T, 0)
        # log initial-state distribution
        pi = F.log_softmax(self.pi, 0)
        emit = self.calc_emit()

        # forward pass: alpha[0] = log p(x_0 | z_0) + log pi(z_0)
        alpha[0] = self.log_prob(input[0], (emit, )) + pi.view(1, -1)
        # terminal backward message: log 1 = 0 for every state
        beta[-1] = Variable(torch.zeros(batch_size, self.z_dim))

        # move beta to the GPU when the parameters live there
        if T.is_cuda:
            beta[-1] = beta[-1].cuda()

        for t in range(1, seq_len):
            # logprod[b, j, i] = alpha[t-1][b, j] + T[i, j]
            # (previous state j on dim 1, next state i on dim 2)
            logprod = alpha[t - 1].unsqueeze(2).expand(
                batch_size, self.z_dim, self.z_dim) + T.t().unsqueeze(0)
            # marginalize the previous state (dim 1), then add the emission term
            alpha[t] = self.log_prob(input[t],
                                     (emit, )) + log_sum_exp(logprod, 1)

        # NOTE: the backward recursion for beta[0..seq_len-2] is intentionally
        # not computed here (noted as unnecessary in these models). If it is
        # ever needed, the emission term must sit INSIDE the sum over the next
        # state i:
        #     beta[t][b, j] = logsumexp_i(T[i, j] + log p(x_{t+1} | i) + beta[t+1][b, i])
        # An earlier commented-out draft added the emission outside that sum,
        # which is not the correct recursion, so it was removed.

        log_marginal = log_sum_exp(alpha[-1] + beta[-1], dim=-1)

        return alpha, beta, log_marginal
Esempio n. 2
0
    def forward(self, input, args, n_particles, test=False):
        """
        Sampled-posterior loss for the HMM with an analytic per-step KL.

        At every position the encoder gives logits of q(z_t | x). The loss has
        three parts:
          - the emission term marginalizes p(x_t | z_t) over q;
          - the KL term is the exact categorical KL(q(z_t | x) || prior_t);
          - a relaxed one-hot sample z_t ~ q pushes the prior through the
            transition matrix for step t + 1.

        n_particles is interpreted as 1 for now to not screw anything up
        (the argument is overwritten below).

        :param input: token ids, [seq_len x batch_size].
        :param args: options object; ``args.temp`` is the relaxation temperature.
        :param n_particles: ignored (forced to 1).
        :param test: unused.
        :return: (summed loss, summed NLL, number of (token, particle) pairs, 0)
        """
        n_particles = 1
        T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
        pi = nn.Softmax(dim=0)(self.pi)
        emit = self.calc_emit()

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)
        n_rows = batch_sz * n_particles

        nlls = hidden_states.data.new(seq_len, n_rows)
        loss = 0

        # Relaxation temperature, built once on the same device/dtype as the
        # encoder output. It used to be rebuilt every step and hard-coded to
        # CUDA, which made the method fail on CPU.
        temperature = Variable(hidden_states.data.new([args.temp]))

        # prior over z at the first step, in probability space (pi is a softmax)
        prior_probs = pi.unsqueeze(0).expand(n_rows, self.z_dim)

        for i in range(seq_len):
            logits = self.logits(hidden_states[i])

            # build the next z sample
            q = RelaxedOneHotCategorical(temperature=temperature,
                                         logits=logits)
            z = q.sample()

            # normalize the encoder logits into log q(z_t | x)
            lse = log_sum_exp(logits, dim=1).view(-1, 1)
            log_probs = logits - lse

            # emission is [batch_sz x z_dim]: emission[b, k] is the
            # log-probability of this step's token for row b given z = k
            emission = F.embedding(input[i].repeat(n_particles), emit)

            # -log sum_k q(k) p(x_t | k)
            NLL = -log_sum_exp(emission + log_probs, 1)
            nlls[i] = NLL.data
            # exact categorical KL; the 1e-16 guards against log(0)
            KL = (log_probs.exp() * (log_probs -
                                     (prior_probs + 1e-16).log())).sum(1)
            loss += (NLL + KL)

            if i != seq_len - 1:
                # next prior: prior_probs[b, k] = sum_j T[k, j] * z[b, j]
                prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

        return loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), 0
Esempio n. 3
0
    def mutual_info(self, x, lengths):
        """
        *modified from https://github.com/jxhe/vae-lagging-encoder*

        Approximate the mutual information between z and x under q(z|x):
            I(x, z) = E_x E_{q(z|x)} log q(z|x) - E_x E_{q(z|x)} log q(z)
        where the aggregate posterior q(z) is estimated from the current batch.

        :param x: input sentence. (seq_len, batch_size)
        :param lengths: (list[int]) length of each sequence in batch.
        :return: (float) approximate mutual information. Because it is a
            Monte Carlo estimate from a single batch, it is not guaranteed to
            be non-negative.
        """
        mu, logvar = self.encoder(x, lengths)
        x_batch, nz = mu.size()
        # E_{q(z|x)} log q(z|x) = -H[q(z|x)] for the diagonal Gaussian
        # N(mu, exp(logvar)), averaged over the batch
        neg_entropy = (-0.5 * nz * math.log(2 * math.pi) - 0.5 *
                       (1 + logvar).sum(-1)).mean()

        # one sample z ~ q(z|x) per x; the KL returned alongside is not needed
        # [z_batch, 1, nz]
        z, _ = self.reparameterize(mu, logvar)
        z = z.unsqueeze(1)

        # [1, x_batch, nz]
        mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
        var = logvar.exp()

        # (z_batch, x_batch, nz): every sample against every posterior
        dev = z - mu  # dimension broadcast

        # log q(z_i | x_j) for all pairs: (z_batch, x_batch)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
                      0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        # log q(z): aggregate posterior, a uniform mixture over the batch
        # [z_batch]
        log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)

        return (neg_entropy - log_qz.mean(-1)).item()
Esempio n. 4
0
    def forward(self, input, args, n_particles, test=False):
        """
        Mean-field ELBO for the HMM, using an encoder-produced q(z_i) at every
        position.

        n_particles is overwritten with 1 below, so the ``n_particles != 1``
        (IWAE) branch near the end is never taken. The IWAE estimator would not
        make sense for this analytic objective anyway.

        :param input: token ids, [seq_len x batch_size].
        :param args: unused.
        :param n_particles: ignored (forced to 1).
        :param test: unused.
        :return: (negated summed ELBO, summed expected emission NLL,
            number of tokens, 0)
        """
        n_particles = 1
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # in log-space, intentionally
        emit = self.calc_emit()  # also in log-space

        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        elbo = 0
        # expected emission NLL, accumulated from .data (carries no gradient)
        NLL = 0

        # log p(z_1) as log-probabilities, one row per sequence
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)

        prev_probs = None

        for i in range(seq_len):
            logits = F.log_softmax(self.logits(hidden_states[i]),
                                   1)  # log q(z_i)
            probs = logits.exp()  # q(z_i)
            emission = F.embedding(input[i].repeat(n_particles),
                                   emit)  # log p(x_i | z_i)

            # unary potentials
            elbo += (emission * probs).sum(1)  # E_q[log p(x_i | z_i)]
            NLL += -(emission * probs).sum(1).data

            # binary potentials q(z_t)q(z_{t - 1})log p(z_t | z_{t - 1})
            if i != 0:
                # [B, 1, Z] (previous state on dim 2) * [B, Z, 1] (current
                # state on dim 1) * T[current, previous], summed over both
                elbo += (prev_probs.unsqueeze(1) * probs.unsqueeze(2) *
                         T.unsqueeze(0)).sum(2).sum(1)
            else:
                # add the log p(z_1) term
                elbo += (probs * prior_logits).sum(1)

            # entropy term: subtracting E_q[log q] adds H(q)
            elbo -= (logits * probs).sum(1)

            prev_probs = probs

        # unreachable while n_particles is forced to 1 above
        if n_particles != 1:
            elbo = log_sum_exp(elbo.view(n_particles, batch_sz),
                               0) - math.log(n_particles)
            NLL = NLL.view(n_particles, batch_sz).mean(0)

        # return the negated ELBO so that it can be minimized
        return -elbo.sum(), NLL.sum(), (seq_len * batch_sz), 0
Esempio n. 5
0
    def forward_backward(self, input, speedup=False):
        """
        Log-space forward-backward for the HMM, including beta[t].

        beta[t] is computed because it is needed for checking the sampling in
        the particle filter case.

        :param input: token ids, [seq_len x batch_size]; cast to long below.
        :param speedup: if True, skip the backward pass and return
            ``(0, 0, log_marginal)``.
        :return: ``(posteriors, 0, log_marginal)``. Each ``posteriors[t]`` is
            ``alpha[t] + beta[t] - log_marginal``, i.e. log p(z_t | x) per
            batch row. ``log_marginal`` is log p(x) per batch row.
        """
        input = input.long()

        seq_len, batch_size = input.size()
        alpha = [None for i in range(seq_len)]
        beta = [None for i in range(seq_len)]

        # T[i, j] = log p(z_t = i | z_{t-1} = j) (normalized over dim 0)
        T = F.log_softmax(self.T, 0)
        pi = F.log_softmax(self.pi, 0)
        emit = self.calc_emit()

        # forward pass
        alpha[0] = self.log_prob(input[0], (emit, )) + pi.view(1, -1)
        # terminal backward message: log 1 = 0 for every state
        beta[-1] = Variable(torch.zeros(batch_size, self.z_dim))

        if T.is_cuda:
            beta[-1] = beta[-1].cuda()

        for t in range(1, seq_len):
            # logprod[b, j, i] = alpha[t-1][b, j] + T[i, j]; sum out j (dim 1)
            logprod = alpha[t - 1].unsqueeze(2).expand(
                batch_size, self.z_dim, self.z_dim) + T.t().unsqueeze(0)
            alpha[t] = self.log_prob(input[t],
                                     (emit, )) + log_sum_exp(logprod, 1)

        log_marginal = log_sum_exp(alpha[-1] + beta[-1], dim=-1)

        if speedup:
            return 0, 0, log_marginal
        else:
            for t in range(seq_len - 2, -1, -1):
                # beta[t][b, j] = logsumexp_i(T[i, j] + beta[t+1][b, i]
                #                             + log p(x_{t+1} | z = i)),
                # with the next state i on dim 1 and current state j on dim 2
                beta[t] = log_sum_exp(
                    T.unsqueeze(0) + beta[t + 1].unsqueeze(2) +
                    F.embedding(input[t + 1], emit).unsqueeze(2), 1)

            # log posterior marginals log p(z_t | x)
            return [
                alpha[i] + beta[i] - log_marginal.unsqueeze(1)
                for i in range(seq_len)
            ], 0, log_marginal
Esempio n. 6
0
    def forward(self, input, args, test=False):
        """
        HMM + LSTM language-model loss.

        At each step the HMM's one-step-ahead state distribution is
        concatenated with the LSTM hidden state to score the current token.
        That distribution is pi at t = 0; otherwise it is the previous
        filtered posterior pushed through T. The total loss is the HMM negative
        log-marginal plus the summed LSTM cross-entropy.

        :param input: token ids, [seq_len x batch_size].
        :param args: unused.
        :param test: unused.
        :return: (loss, summed cross-entropy NLL taken from ``.data``)
        """
        NO_HMM = False  # hard-coded switch that removes the HMM pathway

        seq_len, batch_size = input.size()
        # compute the loss as the sum of the forward-backward loss
        if not NO_HMM:
            alpha, _, log_marginal = self.forward_backward(input)
        emb = self.inp_embedding(input)
        T = F.log_softmax(self.T, 0)
        pi = F.log_softmax(self.pi,
                           0).unsqueeze(0).expand(batch_size, self.z_dim)
        if self.separate_opt:
            # stop gradients from the LSTM pathway into pi and T
            pi = pi.detach()
            T = T.detach()

        # LSTM (h, c) state, allocated on the same device/dtype as the
        # embeddings (previously hard-coded to CUDA, which broke CPU runs)
        h = (Variable(emb.data.new(batch_size, self.lstm_hidden_size).zero_()),
             Variable(emb.data.new(batch_size, self.lstm_hidden_size).zero_()))
        criterion = nn.CrossEntropyLoss(size_average=False)

        NLL = 0

        # alpha[t] already includes x_t, so the filtered posterior from step
        # t - 1 is propagated through T to predict step t
        current_state = None
        for i in range(seq_len):
            if NO_HMM:
                hmm_post = Variable(emb.data.new(batch_size, self.z_dim).zero_())
            else:
                if i == 0:
                    hmm_post = pi
                else:
                    # log p(z_t) = logsumexp_j(T[:, j] + log p(z_{t-1} = j | x_<t))
                    hmm_post = log_sum_exp(
                        T.unsqueeze(0) + current_state.unsqueeze(1), 2)
                hmm_post = hmm_post.exp()

            scores = self.project(torch.cat([h[0], hmm_post], 1))
            NLL += criterion(scores, input[i])

            # feed the current token to the LSTM (teacher forcing)
            h = self.lstm(emb[i], h)

            if not NO_HMM:
                current_state = F.log_softmax(alpha[i], 1)
                if self.separate_opt:
                    current_state = current_state.detach()

        if NO_HMM:
            loss = NLL.sum()
        else:
            loss = -log_marginal.sum() + NLL.sum()
        return loss, NLL.data.sum()
Esempio n. 7
0
    def sampled_elbo(self, input, args, n_particles, emb, hidden_states):
        """
        Relaxed-sample ELBO for the HMM; backpropagates its own objective.

        At each step a reparameterized relaxed one-hot z_t is drawn from
        q(z_t | x). Two terms are then computed:
          - the emission term is -log sum_k q(k) p(x_t | k);
          - the KL term is the single-sample estimate
            log q(z_t) - log p(z_t) under the relaxed distributions.

        :param input: token ids, [seq_len x batch_size].
        :param args: unused.
        :param n_particles: number of copies of each sequence.
        :param emb: unused.
        :param hidden_states: encoder output; tiled n_particles times below.
        :return: (per-row loss, presumably [batch_size * n_particles]; 0;
            number of (token, particle) pairs; 0). The loss averaged over
            those pairs has already been backpropagated with
            retain_graph=True.
        """
        seq_len, batch_sz = input.size()
        T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
        pi = nn.Softmax(dim=0)(self.pi)
        emit = self.calc_emit()

        hidden_states = hidden_states.repeat(1, n_particles, 1)
        # NOTE(review): nlls is filled below but never returned (the second
        # return value is the constant 0)
        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # now a value in probability space
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        # dead store: overwritten at the top of every loop iteration
        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)

        for i in range(seq_len):
            logits = self.logits(hidden_states[i])

            # relaxed prior p and proposal q; z is reparameterized
            p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                         probs=prior_probs)
            q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            z = q.rsample()

            log_probs = F.log_softmax(logits, dim=1)

            # emission is [batch_sz x z_dim]: emission[b, k] is the
            # log-probability of this step's token for row b given z = k
            emission = F.embedding(input[i].repeat(n_particles), emit)

            NLL = -log_sum_exp(emission + log_probs, 1)
            nlls[i] = NLL.data
            KL = q.log_prob(z) - p.log_prob(z)  # pretty inexact
            loss += (NLL + KL)

            if i != seq_len - 1:
                # next prior in probability space: sum_j T[:, j] * z[:, j]
                prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

        # NOTE(review): with seq_len == 0, loss stays the int 0 and .sum()
        # would fail
        (loss.sum() /
         (seq_len * batch_sz * n_particles)).backward(retain_graph=True)
        return loss, 0, seq_len * batch_sz * n_particles, 0
Esempio n. 8
0
    def sampled_filter(self, input, args, n_particles, emb, hidden_states):
        """
        Particle-filter (SMC) objective for the HMM; backpropagates in place.

        Particles are laid out particle-major: row ``p * batch_sz + b`` holds
        particle ``p`` of sequence ``b``. This follows from
        ``hidden_states.repeat(1, n_particles, 1)``,
        ``input[i].repeat(n_particles)`` and ``alpha.view(n_particles, batch_sz)``.

        In training mode, relaxed (Concrete) distributions make z
        reparameterized; in eval mode, exact one-hot categoricals are used.
        When ``args.filter`` is set, all particles are resampled whenever the
        effective sample size of any sequence drops below 30% of n_particles.

        :param input: token ids, [seq_len x batch_size].
        :param args: options object; reads ``args.filter``.
        :param n_particles: number of particles per sequence.
        :param emb: unused.
        :param hidden_states: encoder output; tiled n_particles times below.
        :return: (negated summed log-normalizer estimate, summed NLL,
            number of (token, particle) pairs, 0). In training mode the
            normalized objective has already been backpropagated.
        """
        seq_len, batch_sz = input.size()
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
        emit = self.calc_emit()

        hidden_states = hidden_states.repeat(1, n_particles, 1)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # log w_{t - 1}; uniform weights at the start
        accumulated_weights = -math.log(n_particles)
        resamples = 0

        # in log probability space
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)

        for i in range(seq_len):
            # the approximate posterior comes from the same thing as before
            logits = self.logits(hidden_states[i])

            if not self.training:
                # exact discrete sampling at evaluation time (this is crucial!!)
                p = OneHotCategorical(logits=prior_logits)
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                             logits=prior_logits)
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()

            # log-likelihood of the data given this z-sample
            emission = F.embedding(input[i].repeat(n_particles), emit)
            NLL = -(emission * z).sum(1)
            nlls[i] = NLL.data

            # incremental weight (`reweight`): log p(x|z) + log p(z) - log q(z)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            Z = log_sum_exp(wa, dim=0)  # log incremental normalizer
            loss += Z
            accumulated_weights = wa - Z  # normalized log-weights

            # sample ancestors, and reindex everything
            if args.filter:
                probs = accumulated_weights.data.exp()
                probs += 0.01
                probs = probs / probs.sum(0, keepdim=True)
                effective_sample_size = 1. / probs.pow(2).sum(0)

                # resample / RSAMP
                if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                    resamples += 1
                    # ancestors[b, k]: particle index in [0, n_particles)
                    # that new particle k of sequence b copies
                    ancestors = torch.multinomial(probs.transpose(0, 1),
                                                  n_particles, True)

                    # Particle-major flat index: new row k * batch_sz + b takes
                    # old row ancestors[b, k] * batch_sz + b. The previous
                    # index (ancestors + n_particles * b) assumed a batch-major
                    # layout and mixed particles across different sequences.
                    batch_idx = torch.arange(batch_sz).long().unsqueeze(0)
                    if ancestors.is_cuda:
                        batch_idx = batch_idx.cuda()
                    unrolled_idx = Variable(
                        (ancestors.t() * batch_sz +
                         batch_idx).contiguous().view(-1))
                    z = torch.index_select(z, 0, unrolled_idx)

                    # reset to uniform weights after resampling
                    accumulated_weights = -math.log(n_particles)

            if i != seq_len - 1:
                # log p(z_{t+1} = k | z_t) = logsumexp_j(T[k, j] + log z[j]).
                # z is a (relaxed) one-hot in probability space, so it must be
                # logged before being combined with the log-space T; adding z
                # itself barely conditioned the prior on z. 1e-16 guards log(0).
                prior_logits = log_sum_exp(
                    T.unsqueeze(0) + (z + 1e-16).log().unsqueeze(1), 2)

        if self.training:
            (-loss.sum() /
             (seq_len * batch_sz * n_particles)).backward(retain_graph=True)
        return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0
Esempio n. 9
0
 def sampled_iwae(self, input, args, n_particles, loss, tokens):
     """
     Backpropagate an importance-weighted (IWAE) combination of particle losses.

     ``loss`` holds one loss per (particle, sequence) row and is viewed as
     [n_particles x batch_size]. Assuming ``log_sum_exp`` reduces over the
     given dim, the per-sequence bound is
     -log((1 / n_particles) * sum_k exp(-loss_k)). The summed bound divided
     by ``tokens`` is backpropagated with the graph retained; nothing is
     returned.
     """
     _, batch_sz = input.size()
     neg_losses = -loss.view(n_particles, batch_sz)
     iwae_bound = -log_sum_exp(neg_losses, 0) + math.log(n_particles)
     normalized = iwae_bound.sum() / tokens
     normalized.backward(retain_graph=True)
Esempio n. 10
0
    def forward(self, input, args, n_particles, test=False):
        """
        Sequential latent-variable LM loss with an RNN decoder over (h, z).

        The incoming n_particles argument is ignored: it becomes 10 when
        ``test`` is truthy and 1 otherwise. With 10 particles the per-row
        losses are combined with an IWAE-style log-mean-exp.

        The proposal logits are detached, so the inference network is not
        trained by this loss. The per-step loss is NLL + an analytic
        categorical KL to the HMM-style prior.

        :param input: token ids, [seq_len x batch_size].
        :param args: unused.
        :param test: selects the 10-particle evaluation path.
        :return: (summed loss, summed NLL, number of tokens, 0)
        """
        if test:
            n_particles = 10
        else:
            n_particles = 1
        T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
        pi = nn.Softmax(dim=0)(self.pi)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # particle-major layout: row p * batch_sz + b is particle p of sequence b
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        h = Variable(
            hidden_states.data.new(batch_sz * n_particles,
                                   self.hidden_size).zero_())

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # prior over z at the first step, in probability space (pi is a softmax)
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        # dead stores: both are overwritten before being read
        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)
        feed = None

        for i in range(seq_len):
            # build the next z sample - not differentiable! we don't train the inference network
            logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
            z = OneHotCategorical(logits=logits).sample()

            # this should be batch_sz x x_dim
            feed = self.project(torch.cat([h, z], 1))  # batch_sz x hidden_dim
            scores = torch.mm(feed, self.emit.t())  # batch_sz x x_dim

            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            # analytic categorical KL(q || prior); 1e-16 guards log(0)
            KL = (logits.exp() * (logits - (prior_probs + 1e-16).log())).sum(1)
            loss += (NLL + KL)

            nlls[i] = NLL.data

            # set things up for next time
            if i != seq_len - 1:
                # next prior: sum_j T[:, j] * z[:, j]
                prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)
                h = self.hidden_rnn(emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        # NOTE(review): the per-row loss uses an analytic KL rather than a
        # sampled log importance weight, so this is not the standard IWAE
        # estimator
        if n_particles != 1:
            loss = -log_sum_exp(-loss.view(n_particles, batch_sz),
                                0) + math.log(n_particles)
            NLL = -log_sum_exp(
                -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                    n_particles)  # not quite accurate, but what can you do
        else:
            NLL = nlls

        # now, we calculate the final log-marginal estimator
        return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
Esempio n. 11
0
    def forward(self, input, args, n_particles, test=False):
        """
        Particle-filter (SMC) estimate with an autoregressive RNN prior over z.

        Particles are laid out particle-major: row ``p * batch_sz + b`` holds
        particle ``p`` of sequence ``b``.

        The proposal q(z_t | x) comes from the bidirectional encoder and is
        detached, so the inference network is not trained through this loss.
        At every step:
          - each particle gets the incremental log-weight
            log p(x_t | h, z_t) + log p(z_t | prior RNN) - log q(z_t | x);
          - the weights are normalized across particles;
          - every particle (and its RNN states) is resampled when the
            effective sample size of any sequence drops below 30% of
            ``n_particles``.

        :param input: token ids, [seq_len x batch_size].
        :param args: unused.
        :param n_particles: number of particles per sequence.
        :param test: unused (previously only used to enter the debugger).
        :return: (summed log-normalizer estimate, summed per-token NLL,
            number of (token, particle) pairs, 0)
        :raises RuntimeError: if the normalized particle weights contain NaNs.
        """
        pi = F.log_softmax(self.pi, 0)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)
        n_rows = batch_sz * n_particles

        # (h, c) of the observation RNN, on the encoder output's device/dtype
        h = (Variable(hidden_states.data.new(n_rows, self.hidden_size).zero_()),
             Variable(hidden_states.data.new(n_rows, self.hidden_size).zero_()))

        nlls = hidden_states.data.new(seq_len, n_rows)
        loss = 0

        # log p(z_1), one row per (particle, sequence)
        prior_logits = pi.unsqueeze(0).expand(n_rows, self.z_dim)
        # (h, c) of the z-prior RNN (width 50). Allocated alongside the
        # encoder output instead of being hard-coded to CUDA.
        prior_h = (Variable(hidden_states.data.new(n_rows, 50).zero_()),
                   Variable(hidden_states.data.new(n_rows, 50).zero_()))

        # log w_{t - 1}; uniform weights at the start
        accumulated_weights = -math.log(n_particles)

        # dropout on the teacher-forced inputs to the observation RNN
        x_emb = self.lockdrop(emb, self.dropout_x)

        for i in range(seq_len):
            # proposal q(z_t | x); detached, so no gradient reaches the encoder
            logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()

            q = OneHotCategorical(logits=logits)
            p = OneHotCategorical(logits=prior_logits)
            sample = q.sample()

            # to guard against being too crazy: no exact zeros in z
            smoothed = sample + 1e-16
            z = smoothed / smoothed.sum(1, keepdim=True)

            # this should be batch_sz x x_dim
            scores = torch.mm(self.project(torch.cat([h[0], z], 1)),
                              self.emit.t())

            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            nlls[i] = NLL.data

            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            Z = log_sum_exp(wa, dim=0)  # log incremental normalizer
            loss += Z
            accumulated_weights = wa - Z  # normalized log-weights

            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # fail loudly instead of dropping into the debugger
            if any_nans(probs):
                raise RuntimeError(
                    'NaN in particle weights at step %d' % i)

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                # ancestors[b, k]: particle index in [0, n_particles) that
                # new particle k of sequence b copies
                ancestors = torch.multinomial(probs.transpose(0, 1),
                                              n_particles, True)

                # Particle-major flat index: new row k * batch_sz + b takes
                # old row ancestors[b, k] * batch_sz + b. The previous index
                # (ancestors + n_particles * b) assumed a batch-major layout
                # and mixed particles across different sequences.
                batch_idx = torch.arange(batch_sz).long().unsqueeze(0)
                if ancestors.is_cuda:
                    batch_idx = batch_idx.cuda()
                unrolled_idx = Variable(
                    (ancestors.t() * batch_sz +
                     batch_idx).contiguous().view(-1))

                # shuffle z and every per-particle recurrent state
                z = torch.index_select(z, 0, unrolled_idx)
                h = tuple(torch.index_select(s, 0, unrolled_idx) for s in h)
                prior_h = tuple(
                    torch.index_select(s, 0, unrolled_idx) for s in prior_h)

                # reset to uniform weights after resampling
                accumulated_weights = -math.log(n_particles)

            # set things up for next time
            if i != seq_len - 1:
                feed = torch.cat(
                    [emb[i].repeat(n_particles, 1),
                     self.z_emb(z)], 1)
                prior_h = self.z_decoder(feed, prior_h)
                prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
                h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        # NOTE(review): `loss` accumulates log Z_t, a log-likelihood estimate
        # (higher is better), and is returned un-negated. The sibling
        # particle-filter code returns -loss. Confirm the caller maximizes
        # this value rather than minimizing it.
        return loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), 0
Esempio n. 12
0
    def forward(self, input, args, n_particles, test=False):
        """
        Sequential VAE loss with an autoregressive RNN prior over z.

        The incoming n_particles is ignored; the two modes are:
          - training: 1 particle, and z is a reparameterized relaxed
            categorical sample;
          - test: 10 particles, z comes from the exact categorical, and the
            per-row losses are combined into the IWAE-10 bound.
        The proposal logits are not detached, so the inference network is
        trained through the KL term (and, in training, through the sample).

        :param input: token ids, [seq_len x batch_size].
        :param args: unused.
        :param n_particles: ignored (overridden by ``test``).
        :param test: selects the evaluation path described above.
        :return: (summed loss, summed NLL, number of tokens, 0)
        """
        n_particles = 10 if test else 1
        pi = F.log_softmax(self.pi, 0)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # particle-major layout: row p * batch_sz + b is particle p of sequence b
        hidden_states = hidden_states.repeat(1, n_particles, 1)
        n_rows = batch_sz * n_particles

        # (h, c) of the observation RNN
        h = (Variable(hidden_states.data.new(n_rows, self.hidden_size).zero_()),
             Variable(hidden_states.data.new(n_rows, self.hidden_size).zero_()))

        nlls = hidden_states.data.new(seq_len, n_rows)
        loss = 0

        # log p(z_1), one row per (particle, sequence)
        prior_logits = pi.unsqueeze(0).expand(n_rows, self.z_dim)
        # (h, c) of the z-prior RNN (width 50). Allocated alongside the
        # encoder output instead of being hard-coded to CUDA.
        prior_h = (Variable(hidden_states.data.new(n_rows, 50).zero_()),
                   Variable(hidden_states.data.new(n_rows, 50).zero_()))

        # dropout on the teacher-forced inputs to the observation RNN
        x_emb = self.lockdrop(emb, self.dropout_x)

        for i in range(seq_len):
            # log q(z_t | x)
            logits = F.log_softmax(self.logits(hidden_states[i]), 1)

            if test:
                z = OneHotCategorical(logits=logits).sample()
            else:
                z = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits).rsample()

            # this should be batch_sz x x_dim
            scores = torch.mm(self.project(torch.cat([h[0], z], 1)),
                              self.emit.t())

            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            # analytic categorical KL(q(z_t | x) || p(z_t | prior RNN))
            KL = (logits.exp() * (logits - prior_logits)).sum(1)
            loss += (NLL + KL)

            nlls[i] = NLL.data

            # set things up for next time
            if i != seq_len - 1:
                feed = torch.cat(
                    [emb[i].repeat(n_particles, 1),
                     self.z_emb(z)], 1)
                prior_h = self.z_decoder(feed, prior_h)
                prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
                h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        if n_particles != 1:
            # IWAE: -log mean_k exp(-loss_k) per sequence
            loss = -log_sum_exp(-loss.view(n_particles, batch_sz),
                                0) + math.log(n_particles)
            NLL = -log_sum_exp(
                -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                    n_particles)  # not quite accurate, but what can you do
        else:
            NLL = nlls

        return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
Esempio n. 13
0
    def forward(self, input, args, n_particles, test=False):
        """
        Sequential latent-variable LM loss with a GRU-style prior over z.

        The major difference from the linear-transition variant is that a
        recurrent cell (``self.z_decoder``) predicts the prior z logits instead
        of the linear map T. Fitting this recurrent prior may be hard.

        The incoming n_particles is ignored: it becomes 10 when ``test`` is
        truthy (IWAE-style combination) and 1 otherwise. The proposal logits
        are detached, so the inference network is not trained by this loss.

        :param input: token ids, [seq_len x batch_size].
        :param args: unused.
        :param n_particles: ignored (overridden by ``test``).
        :param test: selects the 10-particle evaluation path.
        :return: (summed loss, summed NLL, number of tokens, 0)
        """
        n_particles = 10 if test else 1
        pi = F.log_softmax(self.pi, 0)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # particle-major layout: row p * batch_sz + b is particle p of sequence b
        hidden_states = hidden_states.repeat(1, n_particles, 1)
        n_rows = batch_sz * n_particles

        # hidden state of the observation RNN
        h = Variable(hidden_states.data.new(n_rows, self.hidden_size).zero_())

        nlls = hidden_states.data.new(seq_len, n_rows)
        loss = 0

        # log p(z_1), one row per (particle, sequence)
        prior_logits = pi.unsqueeze(0).expand(n_rows, self.z_dim)
        # hidden state of the z-prior cell (width 50). Allocated alongside the
        # encoder output instead of being hard-coded to CUDA.
        prior_h = Variable(hidden_states.data.new(n_rows, 50).zero_())

        # use dropout on the teacher-forcing
        x_emb = self.lockdrop(emb, self.dropout_x)

        for i in range(seq_len):
            # build the next z sample - not differentiable! we don't train the inference network
            logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
            z = OneHotCategorical(logits=logits).sample()

            # this should be batch_sz x x_dim
            scores = torch.mm(self.project(torch.cat([h, z], 1)),
                              self.emit.t())

            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            # analytic categorical KL(q(z_t | x) || p(z_t | prior cell))
            KL = (logits.exp() * (logits - prior_logits)).sum(1)
            loss += (NLL + KL)

            nlls[i] = NLL.data

            # set things up for next time
            if i != seq_len - 1:
                feed = torch.cat(
                    [emb[i].repeat(n_particles, 1),
                     self.z_emb(z)], 1)
                prior_h = self.z_decoder(feed, prior_h)
                prior_logits = F.log_softmax(self.project_z(prior_h), 1)
                h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        if n_particles != 1:
            # IWAE-style: -log mean_k exp(-loss_k) per sequence
            loss = -log_sum_exp(-loss.view(n_particles, batch_sz),
                                0) + math.log(n_particles)
            NLL = -log_sum_exp(
                -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                    n_particles)  # not quite accurate, but what can you do
        else:
            NLL = nlls

        return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
Esempio n. 14
0
    def forward(self, input, args, n_particles, test=False):
        """Estimate the log-marginal likelihood of ``input`` with a particle filter.

        A bidirectional encoder summarises ``input``. At each time step a
        proposal q is built from the encoder state and the previous particle
        state, and a sample z is drawn. During training this is a relaxed one-hot
        sample (``rsample``) so the estimator stays differentiable; when ``test``
        is true it is an exact one-hot sample. Importance weights are then
        updated with the log-likelihood from ``self.decode`` plus the
        prior/proposal log-ratio. When ``args.filter`` is set, all particles are
        resampled whenever the effective sample size of any batch element drops
        below 30% of ``n_particles``.

        Particle layout: rows are particle-major, i.e. row ``p * batch_sz + b``
        holds particle ``p`` of batch element ``b`` (this follows from
        ``hidden_states.repeat(1, n_particles, 1)``). All ``view(n_particles,
        batch_sz)`` calls and the resampling index rely on this layout.

        Args:
            input: token indices, unpacked as ``seq_len, batch_sz = input.size()``.
            args: options object; only ``args.filter`` is read here.
            n_particles: number of particles per batch element.
            test: if true, use exact (non-relaxed) one-hot categoricals.

        Returns:
            ``(neg_log_marginal, nll_sum, n_tokens, resamples)``. These are the
            negated sum of the per-step log-normaliser estimates, the summed
            per-particle NLL, ``seq_len * batch_sz * n_particles`` and the
            number of time steps at which resampling happened.

        Raises:
            FloatingPointError: if NaNs appear in the proposal logits, the prior
                logits or the accumulated loss.
        """
        # log_softmax over dim 0: T[k, j] = log p(z_{t+1} = k | z_t = j)
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
        emit = self.calc_emit()

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # tile along the batch dim -> particle-major rows (see docstring)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # current particle sample, fed back into the z-decoder
        h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
        resamples = 0

        # prior logits (log-probabilities) for the current step
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        # recurrent state of z_decoder; also serves as the proposal logits
        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)

        for i in range(seq_len):
            logits = self.logits(
                nn.functional.relu(
                    self.z_decoder(torch.cat([hidden_states[i], h], 1),
                                   logits)))

            if any_nans(logits):
                raise FloatingPointError(
                    'NaN in proposal logits at step %d' % i)
            if any_nans(prior_probs):
                raise FloatingPointError('NaN in prior logits at step %d' % i)

            # build the next z sample (proposal q) and the matching prior p
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
                p = OneHotCategorical(logits=prior_probs)
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()
                p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                             logits=prior_probs)
            h = z

            # now, compute the log-likelihood of the data given this z-sample
            NLL = -self.decode(z, input[i].repeat(n_particles),
                               (emit, ))  # diff. w.r.t. z
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            # per-batch log-normaliser of the incremental weights
            Z = log_sum_exp(wa, dim=0)  # line 7

            loss += Z  # line 8
            accumulated_weights = wa - Z  # line 9

            if args.filter:
                probs = accumulated_weights.data.exp()
                # NOTE(review): smoothing makes the resampling distribution
                # differ slightly from the weights; kept as originally written.
                probs += 0.01
                probs = probs / probs.sum(0, keepdim=True)
                effective_sample_size = 1. / probs.pow(2).sum(0)

                # resample / RSAMP
                if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                    resamples += 1
                    # ancestors[b, k] = old particle of batch b that becomes
                    # new particle k; shape [batch_sz, n_particles]
                    ancestors = torch.multinomial(probs.transpose(0, 1),
                                                  n_particles, True)

                    # particle-major flat index: new row k * batch_sz + b
                    # copies old row ancestors[b, k] * batch_sz + b
                    batch_idx = torch.arange(batch_sz).long().unsqueeze(0)
                    if ancestors.is_cuda:
                        batch_idx = batch_idx.cuda()
                    flat_idx = ancestors.t().contiguous() * batch_sz + batch_idx
                    unrolled_idx = Variable(flat_idx.view(-1))

                    # every piece of per-particle state must follow its ancestor
                    h = torch.index_select(h, 0, unrolled_idx)
                    logits = torch.index_select(logits, 0, unrolled_idx)

                    # reset accumulated_weights
                    accumulated_weights = -math.log(
                        n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # one-step prior predictive for the (possibly resampled) state:
                # log sum_j p(k | j) * h[:, j]. T is in log-space while h lies on
                # the simplex, so mix in probability space and then take the log.
                prior_probs = torch.log(torch.mm(h, T.exp().t()))

        if any_nans(loss):
            raise FloatingPointError('NaN in accumulated loss')

        # now, we calculate the final log-marginal estimator
        return -loss.sum(), nlls.sum(), (seq_len * batch_sz *
                                         n_particles), resamples