def test_actnorm(self):
    """Check ActNorm's generic bijection properties and its data-dependent init.

    First runs the shared shape and bijectivity checks on a fresh
    ``flows.ActNorm()``. Then it initializes an ActNorm (wrapped in
    ``flows.Serial``) on a fixed batch via ``init_inputs``. Applying the
    forward map to that same batch must give per-feature mean ~0 and
    standard deviation ~1.
    """
    for test in (returns_correct_shape, is_bijective):
        test(self, flows.ActNorm())

    # Test data-dependent initialization.
    inputs = random.uniform(random.PRNGKey(0), (20, 3), minval=-10.0, maxval=10.0)
    input_dim = inputs.shape[1]

    init_fun = flows.Serial(flows.ActNorm())
    # The inverse function is not needed for this check.
    params, direct_fun, _ = init_fun(random.PRNGKey(0), inputs.shape[1:], init_inputs=inputs)
    mapped_inputs, _ = direct_fun(params, inputs)

    # The previous check (`abs(mean) > 1e6`) could never fail for inputs in
    # [-10, 10]. Compare against zero with an absolute tolerance loose enough
    # for float32 round-off but tight enough to catch a missing mean shift.
    self.assertTrue(np.allclose(np.zeros(input_dim), mapped_inputs.mean(0), atol=1e-5))
    self.assertTrue(np.allclose(np.ones(input_dim), mapped_inputs.std(0)))
def __init__(self, input_size, channels_h, K, L, save_memory=False):
    """Assemble the per-level flow stacks of a Glow model.

    Args:
        input_size: ``(channels, height, width)`` of the model input.
        channels_h: second argument forwarded to every
            ``flows.AffineCouplingLayer`` (presumably its hidden channel
            width; confirm against that class).
        K: number of (ActNorm, permutation, coupling) steps in each level.
        L: number of levels. Each level begins with a ``flows.Squeeze()``.
        save_memory: when True the channel permutation is
            ``flows.RandomRotation``, an easily invertible variant according to
            the original author's comment. Otherwise it is
            ``flows.InversibleConv1x1``.

    After construction, ``self.blocks`` is an ``nn.ModuleList`` with one
    ``flows.FlowSequential`` per level. ``self.output_sizes[i]`` is the
    ``(c, h, w)`` shape tracked for level ``i`` after its squeeze, before the
    channel split.
    """
    super(Glow, self).__init__()
    self.L = L
    self.save_memory = save_memory
    self.output_sizes = []

    n_ch, height, width = input_size
    levels = []
    for _ in range(L):
        steps = [flows.Squeeze()]
        # Shape produced by the squeeze: 4x channels, half height and width.
        n_ch, height, width = n_ch * 4, height // 2, width // 2

        for _ in range(K):
            norm = flows.ActNorm(n_ch)
            perm = flows.RandomRotation(n_ch) if save_memory else flows.InversibleConv1x1(n_ch)
            coupling = flows.AffineCouplingLayer(n_ch, channels_h)
            steps.extend((norm, perm, coupling))

        levels.append(flows.FlowSequential(*steps))
        self.output_sizes.append((n_ch, height, width))
        # Split: the next level works on half of the channels.
        n_ch //= 2

    self.blocks = nn.ModuleList(levels)
def build_model():
    """Create, initialize and return the flow model used by this script.

    The model is ``args.num_blocks`` repetitions of
    (ActNorm, LUInvertibleMM, CouplingLayer) wrapped in ``fnn.FlowSequential``.
    The coupling mask starts as an alternating 0/1 pattern over feature
    indices and is complemented after every block. Every ``nn.Linear`` inside
    the model gets an orthogonal weight init and a zero bias. The model is
    moved to ``device`` before being returned.

    Depends on the module-level names ``num_inputs``, ``num_hidden``,
    ``num_cond_inputs``, ``args`` and ``device``.
    """
    mask = (torch.arange(0, num_inputs) % 2).to(device).float()

    layers = []
    for _ in range(args.num_blocks):
        layers.extend([
            fnn.ActNorm(num_inputs),
            fnn.LUInvertibleMM(num_inputs),
            fnn.CouplingLayer(num_inputs, num_hidden, mask, num_cond_inputs,
                              s_act='tanh', t_act='relu'),
        ])
        # Each following block's coupling receives the complementary mask.
        mask = 1 - mask

    model = fnn.FlowSequential(*layers)

    for layer in model.modules():
        if not isinstance(layer, nn.Linear):
            continue
        nn.init.orthogonal_(layer.weight)
        if hasattr(layer, 'bias') and layer.bias is not None:
            layer.bias.data.fill_(0)

    model.to(device)
    return model
def testActNorm(self):
    """Round-trip a FlowSequential containing a single ActNorm.

    For two successive random batches, this checks two things. The inverse
    pass must reconstruct the input to within ``EPS``. The forward and inverse
    log-determinants must sum to (nearly) zero.

    NOTE(review): the second run presumably exercises ActNorm after any
    first-call data-dependent initialization. Confirm against ``fnn.ActNorm``.
    """
    model = fnn.FlowSequential(fnn.ActNorm(NUM_INPUTS))

    def check_round_trip(suffix):
        # One forward/inverse pass on a fresh batch. `suffix` distinguishes
        # the failure messages of the two runs.
        x = torch.randn(BATCH_SIZE, NUM_INPUTS)
        y, logdets = model(x)
        z, inv_logdets = model(y, mode='inverse')
        # assertLess on Python floats reports the actual error on failure.
        self.assertLess((logdets + inv_logdets).abs().max().item(), EPS,
                        'ActNorm Det is not zero' + suffix + '.')
        self.assertLess((x - z).abs().max().item(), EPS,
                        'ActNorm is wrong' + suffix + '.')

    check_round_trip('')
    # Second run, on a new batch and with the same module.
    check_round_trip(' for the second run')
def testSequential(self):
    """Round-trip a FlowSequential of ActNorm, InvertibleMM and CouplingLayer.

    For two successive random batches, this checks two things. The inverse
    pass must reconstruct the input to within ``EPS``. The forward and inverse
    log-determinants must sum to (nearly) zero.
    """
    # NOTE(review): elsewhere fnn.CouplingLayer is constructed with a mask as
    # its third positional argument. Confirm that this two-argument call
    # matches the installed fnn.CouplingLayer signature.
    model = fnn.FlowSequential(
        fnn.ActNorm(NUM_INPUTS),
        fnn.InvertibleMM(NUM_INPUTS),
        fnn.CouplingLayer(NUM_INPUTS, NUM_HIDDEN))

    def check_round_trip(suffix):
        # One forward/inverse pass on a fresh batch. `suffix` distinguishes
        # the failure messages of the two runs.
        x = torch.randn(BATCH_SIZE, NUM_INPUTS)
        y, logdets = model(x)
        z, inv_logdets = model(y, mode='inverse')
        # assertLess on Python floats reports the actual error on failure.
        self.assertLess((logdets + inv_logdets).abs().max().item(), EPS,
                        'Sequential Det is not zero' + suffix + '.')
        self.assertLess((x - z).abs().max().item(), EPS,
                        'Sequential is wrong' + suffix + '.')

    check_round_trip('')
    # Second run, on a new batch and with the same model.
    check_round_trip(' for the second run')
def get_modules(flow, num_blocks, normalization, hidden_dim=64):
    """Return the flat list of flow layers for the named architecture.

    Args:
        flow: architecture name. One of ``'realnvp'``, ``'realnvp-conv'``,
            ``'nice'``, ``'glow'``, ``'maf'``, ``'neural-spline'``,
            ``'maf-glow'`` or ``'custom'``.
        num_blocks: how many times the architecture's block is repeated.
        normalization: if truthy, a ``flows.ActNorm()`` follows every block.
            It is ignored by ``'neural-spline'`` (never normalized) and
            ``'custom'`` (always normalized).
        hidden_dim: forwarded to each block's transform factory
            (``get_transform``, ``get_conv_transform``, ``get_nice_transform``
            or ``get_masked_transform``). Unused by ``'neural-spline'``.

    Returns:
        A list of layer objects, in the order the blocks were built.

    Raises:
        ValueError: if ``flow`` is not one of the names above. It is a
            subclass of ``Exception``, so existing ``except Exception``
            callers still catch it.
    """
    # flow name -> (factory building one block's layers, ActNorm policy).
    # A policy of None defers to `normalization`; True/False overrides it.
    # The factories are lambdas, so only the selected flow's layers (and
    # transform helpers) are ever constructed, in the same order as before.
    builders = {
        'realnvp': (lambda: [flows.AffineCoupling(get_transform(hidden_dim)),
                             flows.Reverse()], None),
        'realnvp-conv': (lambda: [MNISTAffineCoupling(get_conv_transform(hidden_dim)),
                                  flows.Reverse()], None),
        'nice': (lambda: [flows.AffineCoupling(get_nice_transform(hidden_dim)),
                          flows.Reverse()], None),
        'glow': (lambda: [flows.AffineCoupling(get_transform(hidden_dim)),
                          flows.InvertibleLinear()], None),
        'maf': (lambda: [flows.MADE(get_masked_transform(hidden_dim)),
                         flows.Reverse()], None),
        'neural-spline': (lambda: [flows.NeuralSplineCoupling()], False),
        'maf-glow': (lambda: [flows.MADE(get_masked_transform(hidden_dim)),
                              flows.InvertibleLinear()], None),
        'custom': (lambda: [flows.MADE(get_masked_transform(hidden_dim)),
                            flows.InvertibleLinear()], True),
    }

    entry = builders.get(flow)
    if entry is None:
        raise ValueError('Invalid flow: {}'.format(flow))
    build_block, norm_policy = entry
    add_norm = normalization if norm_policy is None else norm_policy

    modules = []
    for _ in range(num_blocks):
        modules += build_block()
        if add_norm:
            modules.append(flows.ActNorm())
    return modules