Example #1
    def sample(self, sample_shape=torch.Size([])):
        """
        :param ~torch.Size sample_shape: Sample shape, last dimension must be
            ``num_steps`` and must be broadcastable to
            ``(batch_size, num_steps)``. ``batch_size`` must be an int, not a tuple.
        """
        # shape: batch_size x num_steps x categorical_size
        shape = broadcast_shape(
            torch.Size(list(self.batch_shape) + [1, 1]),
            torch.Size(list(sample_shape) + [1]),
            torch.Size((1, 1, self.event_shape[-1])),
        )
        # state: batch_size x state_dim
        state = OneHotCategorical(logits=self.initial_logits).sample()
        # sample: batch_size x num_steps x categorical_size
        sample = torch.zeros(shape)
        for i in range(shape[-2]):
            # batch_size x 1 x state_dim @
            # batch_size x state_dim x categorical_size
            obs_logits = torch.matmul(state.unsqueeze(-2),
                                      self.observation_logits).squeeze(-2)
            sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
            # batch_size x 1 x state_dim @
            # batch_size x state_dim x state_dim
            trans_logits = torch.matmul(state.unsqueeze(-2),
                                        self.transition_logits).squeeze(-2)
            state = OneHotCategorical(logits=trans_logits).sample()

        return sample
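
This method unrolls a hidden Markov model: at each step it selects the emission logits of the current one-hot state, samples an observation, then selects the transition logits and samples the next state. Below is a minimal standalone sketch of the same pattern with plain tensors and hypothetical sizes, outside the class the method belongs to:

import torch
from torch.distributions import OneHotCategorical

# Hypothetical sizes, for illustration only.
batch_size, num_steps, state_dim, categorical_size = 4, 6, 3, 5

initial_logits = torch.randn(state_dim)
transition_logits = torch.randn(state_dim, state_dim)
observation_logits = torch.randn(state_dim, categorical_size)

# batch_size x state_dim one-hot initial states
state = OneHotCategorical(logits=initial_logits).sample((batch_size,))
sample = torch.zeros(batch_size, num_steps, categorical_size)
for i in range(num_steps):
    # a one-hot state times a logits matrix picks out the row for the current state
    obs_logits = state @ observation_logits          # batch_size x categorical_size
    sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
    trans_logits = state @ transition_logits         # batch_size x state_dim
    state = OneHotCategorical(logits=trans_logits).sample()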
Example #2
    def sample(self, params):
        # params of a Gaussian mixture policy:
        # pi: mixture weights; mean, log_std: per-component Gaussian parameters
        pi, mean, log_std = params['pi'], params['mean'], params['log_std']
        # draw a one-hot component indicator per batch element
        pi_onehot = OneHotCategorical(probs=pi).sample()
        # sample from every component, then keep only the chosen one via the one-hot mask
        ac = torch.sum((mean + torch.randn_like(mean) * torch.exp(log_std)) *
                       pi_onehot.unsqueeze(-1), 1)
        return ac
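
Here ``OneHotCategorical`` picks one mixture component per batch element, and the one-hot mask zeroes out the Gaussian samples from the other components before summing. A minimal sketch of the same computation with made-up shapes (the names and sizes below are assumptions, not part of the original code):

import torch
from torch.distributions import OneHotCategorical

# Hypothetical shapes: 2 batch elements, 3 mixture components, 4-dimensional actions.
batch, n_components, act_dim = 2, 3, 4
pi = torch.softmax(torch.randn(batch, n_components), dim=-1)   # mixture weights
mean = torch.randn(batch, n_components, act_dim)
log_std = torch.randn(batch, n_components, act_dim)

# one-hot component choice per batch element: (batch, n_components)
pi_onehot = OneHotCategorical(probs=pi).sample()
# sample from every component, then select the chosen one by masking and summing
eps = torch.randn_like(mean)
ac = torch.sum((mean + eps * log_std.exp()) * pi_onehot.unsqueeze(-1), dim=1)
print(ac.shape)  # torch.Size([2, 4])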
Example #3
    def forward(self, input, args, n_particles, test=False):
        """
        n_particles is interpreted as 1 for now to not screw anything up
        """
        if test:
            n_particles = 10
        else:
            n_particles = 1
        T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
        pi = nn.Softmax(dim=0)(self.pi)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        # decoder hidden state, one row per (batch element, particle)
        h = hidden_states.new_zeros(batch_sz * n_particles, self.hidden_size)

        # per-step NLLs, filled in inside the loop below
        nlls = hidden_states.new_empty(seq_len, batch_sz * n_particles)
        loss = 0

        # prior over z, in probability space (its log is taken in the KL term below)
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)
        feed = None

        for i in range(seq_len):
            # build the next z sample - not differentiable! we don't train the inference network
            logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
            z = OneHotCategorical(logits=logits).sample()

            feed = self.project(torch.cat([h, z], 1))  # batch_sz x hidden_dim
            scores = torch.mm(feed, self.emit.t())     # batch_sz x x_dim

            NLL = nn.CrossEntropyLoss(reduction='none')(
                scores, input[i].repeat(n_particles))
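            # KL(q || prior) for categoricals: sum_z q(z) * (log q(z) - log p(z)),
            # with q given in log-space (logits) and the prior in probability space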
            KL = (logits.exp() * (logits - (prior_probs + 1e-16).log())).sum(1)
            loss += (NLL + KL)

            nlls[i] = NLL.detach()

            # set things up for next time
            if i != seq_len - 1:
                prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)
                h = self.hidden_rnn(emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        if n_particles != 1:
            loss = -log_sum_exp(-loss.view(n_particles, batch_sz),
                                0) + math.log(n_particles)
            NLL = -log_sum_exp(
                -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                    n_particles)  # not quite accurate, but what can you do
        else:
            NLL = nlls

        # now, we calculate the final log-marginal estimator
        return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
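
When more than one particle is used, the per-particle losses are combined into an importance-weighted estimate of the marginal NLL, -log((1/K) * sum_k exp(-loss_k)). The snippet above does this with a repo-specific log_sum_exp helper; a minimal sketch of the same reduction using torch.logsumexp and made-up shapes:

import math
import torch

# Hypothetical per-particle negative log-weights: (n_particles, batch_sz).
n_particles, batch_sz = 10, 5
loss = torch.rand(n_particles, batch_sz) * 3.0

# -log( (1/K) * sum_k exp(-loss_k) ) = -logsumexp(-loss, dim=0) + log(K)
marginal_nll = -torch.logsumexp(-loss, dim=0) + math.log(n_particles)
print(marginal_nll.shape)  # torch.Size([5])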