def perm_coupling(L, parity=False, num_blocks=1):
    """Build a checkerboard RealNVP flow of permuting coupling layers on an L x L grid.

    The flow is CheckerSplit -> `num_blocks` x RealNVPPermuteInverseAndLogProb
    -> CheckerConcat, wrapped in a RealNVPSequential and moved to the
    module-level `device`.

    Args:
        L: side length of the square grid passed to the split/concat layers.
        parity: parity flag handed unchanged to every coupling layer.
        num_blocks: number of coupling layers to stack.

    Returns:
        The assembled RealNVPSequential, after `.to(device)`.
    """
    grid = (L, L)
    # Keep construction order split -> couplings -> concat, same as before.
    split = flow.CheckerSplit(grid)
    couplings = [
        flow.RealNVPPermuteInverseAndLogProb(in_channels=1,
                                             hidden_size=16,
                                             parity=parity)
        for _ in range(num_blocks)
    ]
    concat = flow.CheckerConcat(grid)
    return flow.RealNVPSequential(split, *couplings, concat).to(device)
Example #2
0
 def __init__(self, latent_shape, flow_depth, hidden_size, flow_std):
     """Set up the flow `q_nu` and its Normal base distribution `q_nu_0`.

     Args:
         latent_shape: shape given to CheckerSplit / CheckerConcat; also
             stored as `self.latent_shape`.
         flow_depth: number of RealNVPPermuteSampleAndLogProb layers.
         hidden_size: hidden width of each coupling layer.
         flow_std: scale of the Normal base distribution (loc is 0.0).
     """
     super().__init__()
     split = flow.CheckerSplit(latent_shape)
     couplings = [
         flow.RealNVPPermuteSampleAndLogProb(
             in_channels=1,
             hidden_size=hidden_size,
             # Odd-indexed layers get parity=True; per the original note the
             # mask alternation is inverted opposite to the prior's.
             parity=(layer_idx % 2 == 1))
         for layer_idx in range(flow_depth)
     ]
     concat = flow.CheckerConcat(latent_shape)
     self.q_nu = flow.RealNVPSequential(split, *couplings, concat)
     self.q_nu_0 = distributions.Normal(loc=0.0, scale=flow_std)
     self.latent_shape = latent_shape
 def __init__(self, latent_shape, flow_depth, flow_std, hidden_size):
     """Set up the flow `r_nu` and its base log-density `log_r_nu_0`.

     Args:
         latent_shape: shape given to CheckerSplit / CheckerConcat and to
             BinaryDistribution.
         flow_depth: number of RealNVPPermuteInverseAndLogProb layers.
         flow_std: `scale` passed to BinaryDistribution.
         hidden_size: hidden width of each coupling layer.
     """
     super().__init__()
     stages = [flow.CheckerSplit(latent_shape)]
     # Even-indexed layers get parity=True; per the original note the mask
     # alternation is inverted opposite to the prior's.
     stages.extend(
         flow.RealNVPPermuteInverseAndLogProb(
             in_channels=1,
             hidden_size=hidden_size,
             parity=(layer_idx % 2 == 0))
         for layer_idx in range(flow_depth))
     stages.append(flow.CheckerConcat(latent_shape))
     self.r_nu = flow.RealNVPSequential(*stages)
     self.log_r_nu_0 = BinaryDistribution(latent_shape=latent_shape,
                                          scale=flow_std)
 def __init__(self, latent_shape, flow_depth, flow_std, hidden_size):
     """Set up the split/concat layers, the flow `r_nu` and `log_r_0`.

     Args:
         latent_shape: shape given to CheckerSplit / CheckerConcat.
         flow_depth: number of RealNVPPermuteInverseAndLogProb layers.
         flow_std: accepted for signature compatibility; not used here.
         hidden_size: hidden width of each coupling layer.
     """
     super().__init__()
     # Split/concat are stored as attributes rather than wrapped in the flow.
     self.split = flow.CheckerSplit(latent_shape)
     self.concat = flow.CheckerConcat(latent_shape)
     couplings = [
         flow.RealNVPPermuteInverseAndLogProb(
             # Three inputs: nu, z_transf (black checkerboard squares) and
             # z_const (white squares).
             in_channels=3,
             hidden_size=hidden_size,
             # Even-indexed layers get parity=True; per the original note the
             # mask alternation is inverted opposite to the prior's.
             parity=(layer_idx % 2 == 0))
         for layer_idx in range(flow_depth)
     ]
     self.r_nu = flow.SplitSequential(*couplings)
     self.log_r_0 = distributions.StandardNormalLogProb()
def _test_memory_perm(L, num_blocks):
    """Print memory deltas for building a checkerboard permute-coupling flow and running it forward.

    Samples a standard-normal (L, L) input, adds a leading dimension of size 1,
    builds CheckerSplit -> `num_blocks` x RealNVPPermuteInverseAndLogProb ->
    CheckerConcat on `device`, and prints the change in the two values returned
    by `_get_memory()` after construction and after the forward pass.

    Args:
        L: side length of the square input grid.
        num_blocks: number of coupling layers between split and concat.
    """
    p = torch.distributions.Normal(0, 1)
    x = p.sample((L, L)).to(device)
    # (L, L) -> (1, L, L); presumably a batch dimension — TODO confirm.
    x = x.unsqueeze(0)
    # NOTE(review): presumably (current, peak) memory — confirm against _get_memory.
    m0, max0 = _get_memory()
    modules = [flow.CheckerSplit((L, L))]
    for _ in range(num_blocks):
        # parity is not passed, so each layer uses its own default.
        modules.append(
            flow.RealNVPPermuteInverseAndLogProb(in_channels=1,
                                                 hidden_size=16))
    modules.append(flow.CheckerConcat((L, L)))
    net = flow.RealNVPSequential(*modules).to(device)
    m1, max1 = _get_memory()
    # Memory added by constructing the model and moving it to `device`.
    print('init mem, max:', m1 - m0, max1 - max0)
    # NOTE(review): second output is presumably a log-prob / log-det term — confirm.
    y, log_x = net(x)
    m2, max2 = _get_memory()
    # Memory added by the forward pass.
    print('fwd mem, max:', m2 - m1, max2 - max1)