Example #1
    def forward(self, input, targets, args, n_particles, criterion, test=False):
        """
        This version takes the inputs and, rather than exposing the logits,
        computes the losses directly
        """

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (h, c) = self.encoder(emb, hidden)

        # teacher-forcing
        out_emb = self.dropout(self.dec_embedding(targets))

        # now, we'll replicate it for each particle - it's currently [seq_len x batch_sz x nhid]
        hidden_states = hidden_states.repeat(1, n_particles, 1)
        out_emb = out_emb.repeat(1, n_particles, 1)
        # now [seq_len x (n_particles x batch_sz) x nhid]
        # out_emb and hidden_states are now indexed as (n_particles x batch_sz) along the batch dimension; the same holds for h

        # run the z-decoder at this point, evaluating the NLL at each step
        p_h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)  # initially zero
        h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
        d_h = self.init_hidden(batch_sz * n_particles, self.nhid, squeeze=True)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
        resamples = 0

        for i in range(seq_len):
            h = self.z_decoder(hidden_states[i], h)
            logits = self.logits(h)

            # build the next z sample
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
                z = q.rsample()
            h = z

            # prior
            if test:
                p = OneHotCategorical(logits=p_h)
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=p_h)

            # now, compute the log-likelihood of the data given this z-sample and the teacher-forcing input out_emb
            d_h = self.decoder(torch.cat([z, out_emb[i]], 1), d_h)
            decoder_logits = self.out_embedding(d_h)
            NLL = criterion(decoder_logits, input[i].repeat(n_particles))
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + args.anneal * (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            # sample ancestors, and reindex everything
            Z = log_sum_exp(wa, dim=0)  # line 7
            if (Z.data > 0.1).any():
                pdb.set_trace()  # debugging breakpoint: incremental log-marginal is unexpectedly large

            loss += Z  # line 8
            accumulated_weights = wa - Z  # line 9
            probs = accumulated_weights.data.exp()
            probs += 0.01  # smooth so every particle keeps some chance of being selected
            probs = probs / probs.sum(0, keepdim=True)
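            # effective sample size of the normalized weights: ESS = 1 / sum_k w_k^2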
            effective_sample_size = 1./probs.pow(2).sum(0)

            # resample / RSAMP when any batch element's normalized ESS drops below 0.3
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex the particle state so each new particle follows its sampled ancestor
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)
                h = torch.index_select(h, 0, unrolled_idx)
                p_h = torch.index_select(p_h, 0, unrolled_idx)
                d_h = torch.index_select(d_h, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # advance the autoregressive prior state, feeding in the (possibly resampled) ancestor
                p_h = self.ar_prior_logits(torch.cat([h, out_emb[i]], 1), p_h)

        # now, we calculate the final log-marginal estimator
        nll = nlls.view(seq_len, n_particles, batch_sz).mean(1).sum()
        return -loss.sum(), nll, (seq_len * batch_sz), resamples
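All of the examples here call a `log_sum_exp` helper that is not shown in the snippets. A minimal sketch of a numerically stable version, written to match how it is called above (positional or keyword `dim`, reduced dimension squeezed out so that `wa - log_sum_exp(wa, dim=0)` broadcasts over the particle dimension), is:

import torch

def log_sum_exp(value, dim):
    # numerically stable log(sum(exp(value))) along `dim`
    m, _ = torch.max(value, dim=dim, keepdim=True)
    out = m + (value - m).exp().sum(dim=dim, keepdim=True).log()
    return out.squeeze(dim)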
Example #2
    def sampled_filter(self, input, args, n_particles, emb, hidden_states):
        seq_len, batch_sz = input.size()
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
        emit = self.calc_emit()

        hidden_states = hidden_states.repeat(1, n_particles, 1)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        accumulated_weights = -math.log(
            n_particles)  # will contain log w_{t - 1}
        resamples = 0

        # in log probability space
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)

        for i in range(seq_len):
            # the approximate posterior comes from the same thing as before
            logits = self.logits(hidden_states[i])

            if not self.training:
                # crucial: at eval time use the exact categorical distributions, not the relaxations
                p = OneHotCategorical(logits=prior_logits)
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                             logits=prior_logits)
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()

            # now, compute the log-likelihood of the data given this z-sample
            emission = F.embedding(input[i].repeat(n_particles), emit)
            NLL = -(emission * z).sum(1)
            # NLL = -self.decode(z, input[i].repeat(n_particles), (emit,))  # diff. w.r.t. z
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            Z = log_sum_exp(wa, dim=0)  # line 7

            loss += Z  # line 8
            accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

            # sample ancestors, and reindex everything
            if args.filter:
                probs = accumulated_weights.data.exp()
                probs += 0.01
                probs = probs / probs.sum(0, keepdim=True)
                effective_sample_size = 1. / probs.pow(2).sum(0)

                # probs is [n_particles, batch_sz]
                # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
                # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

                # resample / RSAMP
                if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                    resamples += 1
                    ancestors = torch.multinomial(probs.transpose(0, 1),
                                                  n_particles, True)

                    # now, reindex the particle state so each new particle follows its sampled ancestor
                    offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                        1).repeat(1, n_particles).long()
                    if ancestors.is_cuda:
                        offsets = offsets.cuda()
                    unrolled_idx = Variable(ancestors + offsets).view(-1)
                    z = torch.index_select(z, 0, unrolled_idx)

                    # reset accumulated_weights
                    accumulated_weights = -math.log(
                        n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # now in log-probability space
                prior_logits = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

        if self.training:
            (-loss.sum() /
             (seq_len * batch_sz * n_particles)).backward(retain_graph=True)
        return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0
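The resampling branch above fires when the effective sample size of the smoothed, normalized particle weights drops below 30% of the particle count. A standalone sketch of that check for a [n_particles x batch_sz] tensor of normalized log-weights (the function and argument names are illustrative, not from the snippet):

import torch

def should_resample(accumulated_log_weights, threshold=0.3):
    # accumulated_log_weights: [n_particles x batch_sz], normalized so that
    # exp() sums to one over the particle dimension
    n_particles = accumulated_log_weights.size(0)
    probs = accumulated_log_weights.exp()
    probs = probs + 0.01  # same smoothing as in the snippets
    probs = probs / probs.sum(0, keepdim=True)
    ess = 1.0 / probs.pow(2).sum(0)  # ESS = 1 / sum_k w_k^2, in [1, n_particles]
    return bool((ess / n_particles < threshold).any())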
Example #3
    def forward(self, input, args, n_particles, test=False):
        T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
        pi = nn.Softmax(dim=0)(self.pi)
        emit = self.calc_emit()

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        z = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # prior over z, in probability space here (pi is a plain softmax)
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)

        for i in range(seq_len):
            # logits = self.logits(torch.cat([hidden_states[i], h], 1))
            # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
            logits = self.logits(
                nn.functional.relu(
                    self.z_decoder(torch.cat([hidden_states[i], z], 1),
                                   logits)))

            # build the next z sample
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()

            lse = log_sum_exp(logits, dim=1).view(-1, 1)
            log_probs = logits - lse
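            # log_probs is now log q(z); the NLL below computes -log sum_z q(z) p(x | z)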

            # now, compute the log-likelihood of the data given this z-sample
            # so emission is [batch_sz x z_dim], i.e. emission[i, j] is the log-probability of getting this
            # data for element i given choice z
            emission = F.embedding(input[i].repeat(n_particles), emit)

            NLL = -log_sum_exp(emission + log_probs, 1)
            nlls[i] = NLL.data
            KL = (log_probs.exp() * (log_probs -
                                     (prior_probs + 1e-16).log())).sum(1)
            loss += (NLL + KL)

            if i != seq_len - 1:
                prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

        # now, we calculate the final log-marginal estimator
        return loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), 0
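Example #3 avoids importance weights entirely: the emission is marginalized over z under q, and an analytic KL to the prior is added. A small sketch of that KL term in isolation (hypothetical helper name; the prior is passed in probability space, exactly as in the snippet):

import torch
import torch.nn.functional as F

def categorical_kl(logits, prior_probs, eps=1e-16):
    # KL(q || p) = sum_j q_j * (log q_j - log p_j), with q given by raw logits
    # and p given as probabilities
    log_q = F.log_softmax(logits, dim=1)
    log_p = (prior_probs + eps).log()
    return (log_q.exp() * (log_q - log_p)).sum(1)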
Example #4
    def forward(self, input, args, n_particles, test=False):
        """
        evaluation is the IWAE-10 bound
        """
        if test:
            n_particles = 10
        else:
            n_particles = 1
        pi = F.log_softmax(self.pi, 0)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        h = (Variable(
            hidden_states.data.new(batch_sz * n_particles,
                                   self.hidden_size).zero_()),
             Variable(
                 hidden_states.data.new(batch_sz * n_particles,
                                        self.hidden_size).zero_()))

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # now a log-prob
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)
        prior_h = (Variable(torch.zeros(batch_sz * n_particles, 50).cuda()),
                   Variable(torch.zeros(batch_sz * n_particles, 50).cuda()))

        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)
        feed = None

        x_emb = self.lockdrop(emb, self.dropout_x)

        for i in range(seq_len):
            # build the next z sample - not differentiable! we don't train the inference network
            logits = F.log_softmax(self.logits(hidden_states[i]), 1)

            if test:
                q = OneHotCategorical(logits=logits)
                # p = OneHotCategorical(logits=prior_logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                # p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
                z = q.rsample()

            # this should be batch_sz x x_dim
            scores = torch.mm(self.project(torch.cat([h[0], z], 1)),
                              self.emit.t())

            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            # KL = q.log_prob(z) - p.log_prob(z)
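            # analytic KL(q || p) between categoricals; logits and prior_logits are both log-probabilities here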
            KL = (logits.exp() * (logits - prior_logits)).sum(1)
            loss += (NLL + KL)
            # else:
            #     loss += (NLL + args.anneal * KL)

            nlls[i] = NLL.data

            # set things up for next time
            if i != seq_len - 1:
                feed = torch.cat(
                    [emb[i].repeat(n_particles, 1),
                     self.z_emb(z)], 1)
                prior_h = self.z_decoder(feed, prior_h)
                prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
                h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        if n_particles != 1:
            loss = -log_sum_exp(-loss.view(n_particles, batch_sz),
                                0) + math.log(n_particles)
            NLL = -log_sum_exp(
                -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                    n_particles)  # per-step combination of the NLLs; not exactly the sequence-level bound
        else:
            NLL = nlls

        # now, we calculate the final log-marginal estimator
        return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
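The aggregation at the end of Example #4 turns the per-particle totals of NLL + KL into the IWAE bound. A tiny numeric sketch of that combination (made-up values, and torch.logsumexp in place of the snippet's log_sum_exp helper):

import math
import torch

# per-particle losses, shape [n_particles x batch_sz]
per_particle = torch.tensor([[2.3, 1.9],
                             [2.1, 2.4],
                             [2.0, 2.2]])
K = per_particle.size(0)

# IWAE bound: -log((1/K) * sum_k exp(-loss_k)) = -logsumexp(-loss, 0) + log K
iwae = -torch.logsumexp(-per_particle, dim=0) + math.log(K)
print(iwae)                  # one value per batch element
print(per_particle.mean(0))  # always >= the IWAE value, by Jensen's inequality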
Example #5
    def forward(self, input, args, n_particles, test=False):
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
        emit = self.calc_emit()

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        accumulated_weights = -math.log(
            n_particles)  # will contain log w_{t - 1}
        resamples = 0

        # in log probability space
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)

        for i in range(seq_len):
            # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
            logits = self.logits(
                nn.functional.relu(
                    self.z_decoder(torch.cat([hidden_states[i], h], 1),
                                   logits)))

            # build the next z sample
            if any_nans(logits):
                pdb.set_trace()
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()
            h = z

            # prior
            if any_nans(prior_probs):
                pdb.set_trace()
            if test:
                p = OneHotCategorical(logits=prior_probs)
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                             logits=prior_probs)

            if any_nans(prior_probs):
                pdb.set_trace()
            if any_nans(logits):
                pdb.set_trace()

            # now, compute the log-likelihood of the data given this z-sample
            NLL = -self.decode(z, input[i].repeat(n_particles),
                               (emit, ))  # diff. w.r.t. z
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            # sample ancestors, and reindex everything
            Z = log_sum_exp(wa, dim=0)  # line 7

            loss += Z  # line 8
            accumulated_weights = wa - Z  # line 9

            if args.filter:
                probs = accumulated_weights.data.exp()
                probs += 0.01
                probs = probs / probs.sum(0, keepdim=True)
                effective_sample_size = 1. / probs.pow(2).sum(0)

                # probs is [n_particles, batch_sz]
                # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
                # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

                # resample / RSAMP
                if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                    resamples += 1
                    ancestors = torch.multinomial(probs.transpose(0, 1),
                                                  n_particles, True)

                    # now, reindex the particle state so each new particle follows its sampled ancestor
                    offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                        1).repeat(1, n_particles).long()
                    if ancestors.is_cuda:
                        offsets = offsets.cuda()
                    unrolled_idx = Variable(ancestors + offsets).view(-1)
                    h = torch.index_select(h, 0, unrolled_idx)

                    # reset accumulated_weights
                    accumulated_weights = -math.log(
                        n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # now in log-probability space (T is a log-softmax)
                prior_probs = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

                # let's normalize things - slower, but safer
                # prior_probs += 0.01
                # prior_probs = prior_probs / prior_probs.sum(1, keepdim=True)

            # if ((prior_probs.sum(1) - 1) > 1e-3).any()[0]:
            #     pdb.set_trace()

        if any_nans(loss):
            pdb.set_trace()

        # now, we calculate the final log-marginal estimator
        return -loss.sum(), nlls.sum(), (seq_len * batch_sz *
                                         n_particles), resamples
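Example #5 also guards several steps with an `any_nans` helper that is not defined in the snippet. A minimal sketch consistent with how it is called (accepts a tensor or an old-style Variable and returns a Python bool) could be:

def any_nans(x):
    # NaN is the only value that is not equal to itself, so this check works
    # on both plain tensors and old-style autograd Variables
    data = x.data if hasattr(x, 'data') else x
    return bool((data != data).any())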