def test(net=None, prior=None, device=None, ckpt=None):
    """Qualitatively evaluate a trained flow model.

    Either pass a live ``net`` (with its matching ``prior``), or pass a
    checkpoint path in ``ckpt`` and the model is rebuilt and restored here.

    Params
    ---
    net: trained flow network, or None to load one from ``ckpt``.
    prior: base distribution matching ``net``; rebuilt when ``net`` is None.
    device: unused when ``net`` is given; chosen automatically otherwise.
    ckpt: path to a checkpoint dict with keys 'net' and 'epoch'.
    """
    assert net is not None or ckpt is not None
    if net is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("device : %s" % (device))
        net = Net(N=4, input_dim=2, hidden_dim=256, device=device).to(device)
        prior = MultivariateNormal(
            torch.zeros(2).to(device), torch.eye(2).to(device))
        # Loading belongs here: only in this branch is ckpt guaranteed to be
        # a path (the assert allows ckpt=None when a live net was passed,
        # and torch.load(None) would crash).
        ckpt = torch.load(ckpt)
        net.load_state_dict(ckpt['net'])
        print("Load checkpoint at Epoch %d." % (ckpt["epoch"]))
    with torch.no_grad():
        net.eval()
        # sample_n is deprecated; sample((n,)) is the supported spelling.
        d = prior.sample((128,))
        pred_x, _ = net.forward(d, reverse=True)
        draw_plt(d[:, 0], d[:, 1], name="z")
        draw_plt(pred_x[:, 0], pred_x[:, 1], name="pred_x")
class BijectiveTransform(nn.Module):
    r"""Implementation of the bijective (affine coupling) transformation in Real NVP.

    A stack of ``layer_num`` coupling layers with alternating binary masks.
    Each layer leaves the masked half of the vector unchanged and applies an
    affine map (scale + translation, both predicted by small networks) to the
    other half, which keeps the Jacobian triangular and its log-determinant
    cheap to compute.
    """

    def __init__(self, v_size, layer_num,
                 scale_net_hidden_layer_num=1,
                 scale_net_hidden_layer_size=256,
                 translate_net_hidden_layer_num=1,
                 translate_net_hidden_layer_size=256,
                 condition_vector_size=0):
        """
        Params
        ---
        v_size: dimensionality of the variable x / z.
        layer_num: number of coupling layers.
        condition_vector_size: size of an additional vector which is
            concatenated to hidden variables.  z = f([x ; cond])
        """
        super().__init__()
        self._v_size = v_size
        self._condition_vector_size = condition_vector_size
        self._layer_num = layer_num
        # Alternating checkerboard-style masks: even layers keep the first
        # half fixed, odd layers keep the second half fixed.
        m = torch.cat([torch.ones(v_size // 2),
                       torch.zeros(v_size - v_size // 2)])
        self.register_buffer(
            "_masks",
            torch.stack([m.clone() if i % 2 == 0 else 1. - m.clone()
                         for i in range(self._layer_num)]))
        # Scale nets end in tanh to keep exp(s) numerically bounded.
        self._s = nn.ModuleList(
            [NN(v_size + condition_vector_size, v_size,
                scale_net_hidden_layer_num, scale_net_hidden_layer_size,
                torch.relu, torch.tanh)
             for _ in range(layer_num)])
        self._t = nn.ModuleList(
            [NN(v_size + condition_vector_size, v_size,
                translate_net_hidden_layer_num, translate_net_hidden_layer_size,
                torch.relu)
             for _ in range(layer_num)])
        self._prior = MultivariateNormal(
            torch.zeros(v_size), torch.eye(v_size))  # N(0, I)

    def _calc_log_determinant(self, s):
        r"""log det(diag( exp( s(x_{1:d}) ) )) = sum_i s_i per batch row."""
        return s.sum(dim=1)

    def infer(self, x, cond=None):
        r"""Inference z = f(x)

        Return
        ---
        z, log_det_J where log_det_J = log det(\frac{\partial f}{\partial x})
        """
        batch_size = x.size(0)
        assert x.size() == torch.Size([batch_size, self._v_size]), x.size()
        if self._condition_vector_size > 0:
            # Was a bare no-op expression; the assert keyword was missing.
            assert cond.size() == torch.Size(
                [batch_size, self._condition_vector_size]), cond.size()
        log_det_J = 0
        z = x
        for i in range(self._layer_num):
            mask = self._masks[i]
            z_ = torch.cat([mask * z, cond], dim=1) if cond is not None else mask * z
            # s and t are already zeroed on the masked (pass-through) half.
            s = self._s[i](z_) * (1. - mask)
            t = self._t[i](z_) * (1. - mask)
            z = mask * z + (1. - mask) * z * s.exp() + t
            log_det_J += self._calc_log_determinant(s)
        return z, log_det_J

    def generate(self, z, cond=None):
        r"""Generation x = f^-1(z)

        Return
        ---
        x, log_det_inv_J where
        log_det_inv_J = log det(\frac{\partial f^{-1}}{\partial z})
        """
        batch_size = z.size(0)
        assert z.size() == torch.Size([batch_size, self._v_size]), z.size()
        if self._condition_vector_size > 0:
            # Was a bare no-op expression; the assert keyword was missing.
            assert cond.size() == torch.Size(
                [batch_size, self._condition_vector_size]), cond.size()
        log_det_inv_J = 0
        x = z
        # Invert the coupling layers in reverse order.
        for i in reversed(range(self._layer_num)):
            mask = self._masks[i]
            x_ = torch.cat([mask * x, cond], dim=1) if cond is not None else mask * x
            s = self._s[i](x_) * (1. - mask)
            t = self._t[i](x_) * (1. - mask)
            x = mask * x + ((1. - mask) * x - t) * (-s).exp()
            log_det_inv_J += self._calc_log_determinant(-s)
        return x, log_det_inv_J

    def calc_log_likelihood(self, x_batch, cond_batch=None):
        r"""Maximize log p(x) = log p(f(x)) + log | det(\frac{\partial f}{\partial x}) |

        Returns the mean log-likelihood over the batch.
        """
        batch_size = x_batch.size(0)
        assert x_batch.size() == torch.Size([batch_size, self._v_size])
        if self._condition_vector_size > 0:
            # Was a bare no-op expression; the assert keyword was missing.
            assert cond_batch.size() == torch.Size(
                [batch_size, self._condition_vector_size]), cond_batch.size()
        z, log_det_J = self.infer(x_batch, cond_batch)
        log_pz = self._prior.log_prob(z)
        log_px = (log_pz + log_det_J).mean()
        return log_px

    def sample(self, sample_size, cond=None):
        """Draw ``sample_size`` samples by pushing prior noise through f^-1."""
        if self._condition_vector_size > 0:
            # Was a bare no-op expression; the assert keyword was missing.
            assert cond.size() == torch.Size(
                [sample_size, self._condition_vector_size]), cond.size()
        # sample_n is deprecated; sample((n,)) is the supported spelling.
        z = self._prior.sample((sample_size,))
        x, _ = self.generate(z, cond)
        return x.detach()
def sample(self, num_samples):
    """Draw ``num_samples`` samples from a Gaussian built from this object.

    Params
    ---
    num_samples: number of samples to draw.

    Returns a tensor of shape (num_samples, d) where d = self.means.size(-1).
    NOTE(review): self.std is passed as the covariance_matrix argument, so it
    is presumably a covariance matrix, not a per-dim std vector — confirm.
    """
    distr = MultivariateNormal(self.means, self.std)
    # sample_n is deprecated; sample((n,)) is the supported spelling.
    return distr.sample((num_samples,))