コード例 #1
0
 def z_to_x(self, z):
     """Invert this scale, mapping a latent tensor ``z`` back towards data space.

     Unless ``self.last_scale`` is set, the input is squeezed and split in two
     along dim 1. The first half is inverted by ``self.next_scale``; the second
     half (the factored-out part) is carried through untouched. The halves are
     then concatenated again and un-squeezed. Finally every entry of
     ``self.steps`` is applied with ``reverse=True``, last step first.
     """
     out = z
     if not self.last_scale:
         deeper, passthrough = squeeze(out).chunk(2, dim=1)
         deeper = self.next_scale(deeper, reverse=True)
         out = reverse_squeeze(torch.cat((deeper, passthrough), dim=1))
     # Undo the steps in the opposite order to the forward pass.
     for layer in reversed(self.steps):
         out = layer(out, reverse=True)
     return out
コード例 #2
0
    def x_to_z(self, x, log_det=None):
        """Map data ``x`` to latent ``z`` while accumulating the log-determinant.

        Args:
            x: Input tensor; dim 0 is read as the batch size.
            log_det: Running log-determinant that is threaded through every
                step. When omitted, a zero tensor of length ``batch_size`` is
                created on the same device as ``x``.

        Returns:
            Tuple ``(z, log_det)``.
        """
        if log_det is None:
            # Allocate directly on the input's device rather than on a global
            # default: this avoids a device mismatch when ``x`` lives elsewhere
            # and skips the extra CPU -> device copy.
            log_det = torch.zeros(x.shape[0], device=x.device)

        z = x
        for step in self.steps:
            z, log_det = step(z, log_det)
        if not self.last_scale:
            z = squeeze(z)
            # Only the first half along dim 1 goes on to the next scale; the
            # second half is factored out and re-attached unchanged.
            to_transform, factored_z = z.chunk(2, dim=1)
            z, log_det = self.next_scale(to_transform, log_det)
            z = torch.cat((z, factored_z), dim=1)
            z = reverse_squeeze(z)
        return z, log_det
コード例 #3
0
    def z_to_x(self, z):
        """Map latent ``z`` back to ``x`` (the inverse of the forward pass).

        Unless ``self.last_scale`` is set, the input is squeezed and split in
        two along dim 1. The first half is inverted by ``self.next_scale`` and
        re-joined with the untouched second half. ``self.channel_transforms``
        are then undone in reverse order and the result is un-squeezed.
        Finally ``self.checker_board`` is undone in reverse order.
        """

        def undo(couplings, tensor):
            # Apply each coupling with reverse=True, last one first; the second
            # element each coupling returns is discarded.
            for layer in reversed(couplings):
                tensor, _ = layer(tensor, reverse=True)
            return tensor

        out = z
        if not self.last_scale:
            head, factored_z = squeeze(out).chunk(2, dim=1)
            head = self.next_scale(head, reverse=True)
            out = undo(self.channel_transforms, torch.cat((head, factored_z), dim=1))
            out = reverse_squeeze(out)
        return undo(self.checker_board, out)
コード例 #4
0
    def x_to_z(self, x, log_det=None):
        """Map data ``x`` to latent ``z`` while accumulating the log-determinant.

        Args:
            x: Input tensor; dim 0 is read as the batch size.
            log_det: Running log-determinant to add to. It is never modified
                in place. When omitted, a zero tensor of length ``batch_size``
                is created on the same device as ``x``.

        Returns:
            Tuple ``(z, log_det)``, where ``log_det`` is the incoming value
            plus the contribution of every transform applied here and by
            ``self.next_scale``.
        """
        if log_det is None:
            # Allocate on the input's device instead of a global default.
            log_det = torch.zeros(x.shape[0], device=x.device)
        z = x
        for checker in self.checker_board:
            z, delta_log_det = checker(z)
            # Out-of-place add: ``+=`` would write into the tensor the caller
            # handed in.
            log_det = log_det + delta_log_det
        if not self.last_scale:
            z = squeeze(z)
            for op in self.channel_transforms:
                z, delta_log_det = op(z)
                log_det = log_det + delta_log_det

            # Only the first half along dim 1 goes on to the next scale; the
            # second half is factored out and re-attached unchanged.
            to_transform, factored_z = z.chunk(2, dim=1)
            z, log_det = self.next_scale(to_transform, log_det)
            z = torch.cat((z, factored_z), dim=1)

            z = reverse_squeeze(z)
        return z, log_det
コード例 #5
0
 def f(self, x):
     """Map ``x`` to ``z`` and return it together with the log determinant.

     The input is squeezed, passed through ``self.scales`` (which yields the
     transformed tensor and the log determinant), and un-squeezed again.
     """
     transformed, log_det = self.scales(squeeze(x))
     return reverse_squeeze(transformed), log_det
コード例 #6
0
 def g(self, z):
     """Map ``z`` back to ``x`` (the inverse of ``f``).

     The input is squeezed, run through ``self.scales`` with ``reverse=True``,
     and un-squeezed again.
     """
     inverted = self.scales(squeeze(z), reverse=True)
     return reverse_squeeze(inverted)