def perm_coupling(L, parity=False, num_blocks=1):
    """Build a checkerboard-split RealNVP coupling stack over an L x L field."""
    modules = [flow.CheckerSplit((L, L))]
    for _ in range(num_blocks):
        modules.append(
            flow.RealNVPPermuteInverseAndLogProb(
                in_channels=1, hidden_size=16, parity=parity))
    modules.append(flow.CheckerConcat((L, L)))
    net = flow.RealNVPSequential(*modules)
    return net.to(device)
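# A minimal usage sketch (an assumption, not part of the original source): the
# factory above returns a RealNVPSequential whose forward pass yields the
# transformed configuration and a log-probability term, matching the call
# pattern used in _test_memory_perm below. The name _example_perm_coupling and
# the shapes are illustrative only.
def _example_perm_coupling(L=8):
    net = perm_coupling(L, parity=False, num_blocks=2)
    x = torch.randn(1, L, L, device=device)  # single batched L x L configuration
    y, log_x = net(x)                        # transformed field and log-prob term
    return y, log_x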
def __init__(self, latent_shape, flow_depth, hidden_size, flow_std):
    super().__init__()
    modules = [flow.CheckerSplit(latent_shape)]
    for flow_num in range(flow_depth):
        modules.append(
            flow.RealNVPPermuteSampleAndLogProb(
                in_channels=1,
                hidden_size=hidden_size,
                # invert mask opposite to prior
                parity=(flow_num % 2 == 1)))
    modules.append(flow.CheckerConcat(latent_shape))
    self.q_nu = flow.RealNVPSequential(*modules)
    self.q_nu_0 = distributions.Normal(loc=0.0, scale=flow_std)
    self.latent_shape = latent_shape
def __init__(self, latent_shape, flow_depth, flow_std, hidden_size):
    super().__init__()
    modules = [flow.CheckerSplit(latent_shape)]
    for flow_num in range(flow_depth):
        modules.append(
            flow.RealNVPPermuteInverseAndLogProb(
                in_channels=1,
                hidden_size=hidden_size,
                # invert mask opposite to prior
                parity=(flow_num % 2 == 0)))
    modules.append(flow.CheckerConcat(latent_shape))
    self.r_nu = flow.RealNVPSequential(*modules)
    self.log_r_nu_0 = BinaryDistribution(latent_shape=latent_shape, scale=flow_std)
def __init__(self, latent_shape, flow_depth, flow_std, hidden_size):
    super().__init__()
    self.split = flow.CheckerSplit(latent_shape)
    self.concat = flow.CheckerConcat(latent_shape)
    modules = []
    for flow_num in range(flow_depth):
        modules.append(
            flow.RealNVPPermuteInverseAndLogProb(
                # input: nu, z_transf (black squares of checkerboard) and z_const (white squares)
                in_channels=3,
                hidden_size=hidden_size,
                # invert mask opposite to prior
                parity=(flow_num % 2 == 0)))
    self.r_nu = flow.SplitSequential(*modules)
    self.log_r_0 = distributions.StandardNormalLogProb()
def _test_memory_perm(L, num_blocks):
    p = torch.distributions.Normal(0, 1)
    x = p.sample((L, L)).to(device)
    x = x.unsqueeze(0)  # add batch dimension: (1, L, L)
    m0, max0 = _get_memory()
    modules = [flow.CheckerSplit((L, L))]
    for _ in range(num_blocks):
        modules.append(
            flow.RealNVPPermuteInverseAndLogProb(in_channels=1, hidden_size=16))
    modules.append(flow.CheckerConcat((L, L)))
    net = flow.RealNVPSequential(*modules).to(device)
    m1, max1 = _get_memory()
    print('init mem, max:', m1 - m0, max1 - max0)
    y, log_x = net(x)
    m2, max2 = _get_memory()
    print('fwd mem, max:', m2 - m1, max2 - max1)
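# _get_memory is referenced above but defined elsewhere in the repository; a
# minimal sketch of such a helper (an assumption, not the repo's actual
# implementation), reading current and peak allocated bytes from PyTorch's
# CUDA allocator, could look like this:
def _get_memory():
    # current and peak allocated bytes on the active CUDA device
    return torch.cuda.memory_allocated(device), torch.cuda.max_memory_allocated(device)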