def sample(self, sample_shape=torch.Size([])):
    """
    :param ~torch.Size sample_shape: Sample shape, last dimension must be
        ``num_steps`` and must be broadcastable to ``(batch_size, num_steps)``.
        batch_size must be int not tuple.
    """
    # shape: batch_size x num_steps x categorical_size
    shape = broadcast_shape(
        torch.Size(list(self.batch_shape) + [1, 1]),
        torch.Size(list(sample_shape) + [1]),
        torch.Size((1, 1, self.event_shape[-1])),
    )
    # state: batch_size x state_dim
    state = OneHotCategorical(logits=self.initial_logits).sample()
    # sample: batch_size x num_steps x categorical_size
    sample = torch.zeros(shape)
    for i in range(shape[-2]):
        # batch_size x 1 x state_dim @
        # batch_size x state_dim x categorical_size
        obs_logits = torch.matmul(state.unsqueeze(-2),
                                  self.observation_logits).squeeze(-2)
        sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
        # batch_size x 1 x state_dim @
        # batch_size x state_dim x state_dim
        trans_logits = torch.matmul(state.unsqueeze(-2),
                                    self.transition_logits).squeeze(-2)
        state = OneHotCategorical(logits=trans_logits).sample()
    return sample
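# A minimal, self-contained sketch of the same ancestral-sampling loop on
# concrete, un-batched tensors. The dimensions and randomly initialized logits
# below are illustrative assumptions, not part of the original class.
import torch
from torch.distributions import OneHotCategorical

state_dim, obs_dim, num_steps = 3, 4, 5
initial_logits = torch.zeros(state_dim)                # uniform initial state
transition_logits = torch.randn(state_dim, state_dim)  # per-state transition logits
observation_logits = torch.randn(state_dim, obs_dim)   # per-state emission logits

state = OneHotCategorical(logits=initial_logits).sample()  # one-hot, shape (state_dim,)
sample = torch.zeros(num_steps, obs_dim)
for t in range(num_steps):
    # the one-hot product selects the emission logits of the current state
    sample[t] = OneHotCategorical(logits=state @ observation_logits).sample()
    # likewise select the transition logits and advance the chain
    state = OneHotCategorical(logits=state @ transition_logits).sample()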
def sample(self, params):
    # Unpack mixture parameters: component weights, means, and log-stds.
    pi, mean, log_std = params['pi'], params['mean'], params['log_std']
    # Draw a one-hot component indicator per batch element.
    pi_onehot = OneHotCategorical(pi).sample()
    # Sample from every Gaussian component via the reparameterized noise,
    # then keep only the chosen component by masking with the one-hot
    # indicator and summing over the component dimension.
    ac = torch.sum((mean + torch.randn_like(mean) * torch.exp(log_std))
                   * pi_onehot.unsqueeze(-1), 1)
    return ac
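# A minimal sketch of the parameter shapes this mixture sampler appears to
# expect (batched weights, means, and log-stds). The batch size B, component
# count K, and output dimension D are illustrative assumptions.
import torch
from torch.distributions import OneHotCategorical

B, K, D = 8, 4, 2
params = {
    'pi': torch.softmax(torch.randn(B, K), dim=-1),  # mixture weights per batch element
    'mean': torch.randn(B, K, D),                    # component means
    'log_std': torch.zeros(B, K, D),                 # component log-stds
}
pi, mean, log_std = params['pi'], params['mean'], params['log_std']
pi_onehot = OneHotCategorical(pi).sample()           # (B, K) one-hot component choice
ac = torch.sum((mean + torch.randn_like(mean) * torch.exp(log_std))
               * pi_onehot.unsqueeze(-1), 1)         # (B, D) sampled outputs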
def forward(self, input, args, n_particles, test=False):
    """
    n_particles is interpreted as 1 for now to not screw anything up
    """
    if test:
        n_particles = 10
    else:
        n_particles = 1
    T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
    pi = nn.Softmax(dim=0)(self.pi)

    # run the input and teacher-forcing inputs through the embedding layers here
    seq_len, batch_sz = input.size()
    emb = self.inp_embedding(input)
    hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
    hidden_states, (_, _) = self.encoder(emb, hidden)
    hidden_states = hidden_states.repeat(1, n_particles, 1)

    # run the z-decoder at this point, evaluating the NLL at each step
    h = Variable(hidden_states.data.new(batch_sz * n_particles, self.hidden_size).zero_())
    nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
    loss = 0  # now a log-prob

    prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles, self.z_dim)
    logits = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

    feed = None
    for i in range(seq_len):
        # build the next z sample - not differentiable! we don't train the inference network
        logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
        z = OneHotCategorical(logits=logits).sample()  # this should be batch_sz x x_dim

        feed = self.project(torch.cat([h, z], 1))  # batch_sz x hidden_dim
        scores = torch.mm(feed, self.emit.t())  # batch_sz x x_dim

        NLL = nn.CrossEntropyLoss(reduce=False)(scores, input[i].repeat(n_particles))
        KL = (logits.exp() * (logits - (prior_probs + 1e-16).log())).sum(1)
        loss += (NLL + KL)
        nlls[i] = NLL.data

        # set things up for next time
        if i != seq_len - 1:
            prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)
            h = self.hidden_rnn(emb[i].repeat(n_particles, 1), h)  # feed the next word into the RNN

    if n_particles != 1:
        loss = -log_sum_exp(-loss.view(n_particles, batch_sz), 0) + math.log(n_particles)
        NLL = -log_sum_exp(-nlls.view(seq_len, n_particles, batch_sz), 1) + \
            math.log(n_particles)  # not quite accurate, but what can you do
    else:
        NLL = nlls

    # now, we calculate the final log-marginal estimator
    return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
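# The multi-particle branch at the end is a log-mean-exp over particles, i.e.
# an importance-weighted estimate of the marginal NLL. A standalone sketch of
# that reduction using torch.logsumexp in place of the module's own
# log_sum_exp helper; the shapes below are illustrative assumptions.
import math
import torch

n_particles, batch_sz = 10, 4
loss = torch.rand(batch_sz * n_particles)            # per-particle negative log-weights
per_particle = loss.view(n_particles, batch_sz)
# -log( (1/K) * sum_k exp(-loss_k) )  ==  -logsumexp(-loss, dim=0) + log K
marginal_nll = -torch.logsumexp(-per_particle, dim=0) + math.log(n_particles)
print(marginal_nll.shape)  # torch.Size([4]), one estimate per batch element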