# Example #1 — file: main.py, project: ssumin6/real_nvp
def test(net=None, prior=None, device=None, ckpt=None):
    """Sample latents from the prior and visualize generated data.

    Either a ready ``net``/``prior`` pair or a checkpoint path ``ckpt`` must
    be supplied.  When ``net`` is None the model is rebuilt from scratch and
    its weights are restored from the checkpoint at ``ckpt``.

    Params
    ---
    net: trained flow model (or None to load one from ``ckpt``).
    prior: latent distribution; required whenever ``net`` is given directly.
    device: unused when ``net`` is given; recomputed when loading a checkpoint.
    ckpt: path to a torch checkpoint containing 'net' and 'epoch' entries.
    """
    assert net is not None or ckpt is not None
    if net is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("device : %s" % (device))
        net = Net(N=4, input_dim=2, hidden_dim=256, device=device).to(device)
        prior = MultivariateNormal(
            torch.zeros(2).to(device),
            torch.eye(2).to(device))
        ckpt = torch.load(ckpt)
        net.load_state_dict(ckpt['net'])
        print("Load checkpoint at Epoch %d." % (ckpt["epoch"]))
    # Fail early with a clear message instead of an AttributeError below.
    assert prior is not None, "prior must be supplied together with net"

    with torch.no_grad():
        net.eval()
        # sample((n,)) replaces the deprecated Distribution.sample_n(n).
        d = prior.sample((128,))
        pred_x, _ = net.forward(d, reverse=True)

    draw_plt(d[:, 0], d[:, 1], name="z")
    draw_plt(pred_x[:, 0], pred_x[:, 1], name="pred_x")
class BijectiveTransform(nn.Module):
    """Implementation of the bijective transformation in Real NVP.

    Stacks ``layer_num`` affine coupling layers with alternating binary
    masks.  Each layer transforms one half of the variables conditioned on
    the other half (plus an optional external condition vector), which keeps
    the Jacobian triangular so its log-determinant is just a sum of the
    scale outputs.
    """
    def __init__(self, v_size, layer_num, scale_net_hidden_layer_num=1, scale_net_hidden_layer_size=256,
            translate_net_hidden_layer_num=1, translate_net_hidden_layer_size=256, condition_vector_size=0):
        """
        Params
        ---
        v_size: dimensionality of the transformed variable.
        layer_num: number of coupling layers.
        condition_vector_size: size of an additional vector which is concatenated to hidden variables.
            z = f([x ; cond])
        """
        super().__init__()
        self._v_size = v_size
        self._condition_vector_size = condition_vector_size
        self._layer_num = layer_num
        # Alternating checkerboard-style masks: even layers keep the first
        # half fixed, odd layers keep the second half fixed.
        m = torch.cat([torch.ones(v_size // 2), torch.zeros(v_size - v_size // 2)])
        # Registered as a buffer so the masks move with .to(device) / state_dict.
        self.register_buffer("_masks", torch.stack([m.clone() if i % 2 == 0 else 1. - m.clone() for i in range(self._layer_num)]))
        self._s = nn.ModuleList([NN(v_size + condition_vector_size, v_size, scale_net_hidden_layer_num, scale_net_hidden_layer_size, torch.relu, torch.tanh) for _ in range(layer_num)])
        self._t = nn.ModuleList([NN(v_size + condition_vector_size, v_size, translate_net_hidden_layer_num, translate_net_hidden_layer_size, torch.relu) for _ in range(layer_num)])
        # NOTE(review): the prior lives on CPU; it is not a buffer, so it does
        # not follow the module across devices — callers appear to sample on
        # CPU. Kept as-is to preserve behavior.
        self._prior = MultivariateNormal(torch.zeros(v_size), torch.eye(v_size))  # N(0, I)

    def _calc_log_determinant(self, s):
        """log det(diag(exp(s(x_{1:d})))) = sum_j s_j, computed per batch row."""
        return s.sum(dim=1)

    def infer(self, x, cond=None):
        """Inference z = f(x)

        Return
        ---
        z, log_det_J
        where log_det_J = log det(\\frac{\\partial f}{\\partial x})
        """
        batch_size = x.size(0)
        assert x.size() == torch.Size([batch_size, self._v_size]), x.size()
        if self._condition_vector_size > 0:
            # Was a no-op expression statement; the assert keyword was missing.
            assert cond is not None and cond.size() == torch.Size([batch_size, self._condition_vector_size])

        log_det_J = 0
        z = x
        for i in range(self._layer_num):
            mask = self._masks[i]
            z_ = torch.cat([mask * z, cond], dim=1) if cond is not None else mask * z
            s = self._s[i](z_) * (1. - mask)
            t = self._t[i](z_) * (1. - mask)
            # t is already masked, so adding it only affects the active half.
            z = mask * z + (1. - mask) * z * s.exp() + t
            log_det_J += self._calc_log_determinant(s)
        return z, log_det_J

    def generate(self, z, cond=None):
        """Generation x = f^-1(z)

        Return
        ---
        x, log_det_inv_J
        where log_det_inv_J = log det(\\frac{\\partial f^{-1}}{\\partial z})
        """
        batch_size = z.size(0)
        assert z.size() == torch.Size([batch_size, self._v_size]), z.size()
        if self._condition_vector_size > 0:
            # Was a no-op expression statement; the assert keyword was missing.
            assert cond is not None and cond.size() == torch.Size([batch_size, self._condition_vector_size])

        log_det_inv_J = 0
        x = z
        # Invert the coupling layers in reverse order.
        for i in reversed(range(self._layer_num)):
            mask = self._masks[i]
            x_ = torch.cat([mask * x, cond], dim=1) if cond is not None else mask * x
            s = self._s[i](x_) * (1. - mask)
            t = self._t[i](x_) * (1. - mask)
            x = mask * x + ((1. - mask) * x - t) * (-s).exp()
            log_det_inv_J += self._calc_log_determinant(-s)
        return x, log_det_inv_J

    def calc_log_likelihood(self, x_batch, cond_batch=None):
        """Maximize log p(x) = log p(f(x)) + log | det(\\frac{\\partial f}{\\partial x}) |

        Returns the mean log-likelihood over the batch.
        """
        batch_size = x_batch.size(0)
        assert x_batch.size() == torch.Size([batch_size, self._v_size])
        if self._condition_vector_size > 0:
            # Was a no-op expression statement; the assert keyword was missing.
            assert cond_batch is not None and cond_batch.size() == torch.Size([batch_size, self._condition_vector_size])

        z, log_det_J = self.infer(x_batch, cond_batch)
        log_pz = self._prior.log_prob(z)
        log_px = (log_pz + log_det_J).mean()
        return log_px

    def sample(self, sample_size, cond=None):
        """Draw ``sample_size`` samples x = f^-1(z), z ~ N(0, I)."""
        if self._condition_vector_size > 0:
            # Was a no-op expression statement; the assert keyword was missing.
            assert cond is not None and cond.size() == torch.Size([sample_size, self._condition_vector_size])
        # sample((n,)) replaces the deprecated Distribution.sample_n(n).
        z = self._prior.sample((sample_size,))
        x, _ = self.generate(z, cond)
        return x.detach()
# Example #3
 def sample(self, num_samples):
     """Draw ``num_samples`` samples from N(self.means, self.std).

     NOTE(review): self.std is passed as the covariance matrix argument of
     MultivariateNormal — confirm it holds a covariance, not a std vector.
     """
     distr = MultivariateNormal(self.means, self.std)
     # sample((n,)) replaces the deprecated Distribution.sample_n(n).
     return distr.sample((num_samples,))