def testInv(self): m1 = fnn.FlowSequential(fnn.InvertibleMM(NUM_INPUTS)) x = torch.randn(BATCH_SIZE, NUM_INPUTS) y, logdets = m1(x) z, inv_logdets = m1(y, mode='inverse') self.assertTrue((logdets + inv_logdets).abs().max() < EPS, 'InvMM Det is not zero.') self.assertTrue((x - z).abs().max() < EPS, 'InvMM is wrong.')
def testSequentialBN(self): m1 = fnn.FlowSequential(fnn.BatchNormFlow(NUM_INPUTS), fnn.InvertibleMM(NUM_INPUTS), fnn.CouplingLayer(NUM_INPUTS, NUM_HIDDEN)) m1.train() x = torch.randn(BATCH_SIZE, NUM_INPUTS) y, logdets = m1(x) z, inv_logdets = m1(y, mode='inverse') self.assertTrue((logdets + inv_logdets).abs().max() < EPS, 'Sequential BN Det is not zero.') self.assertTrue((x - z).abs().max() < EPS, 'Sequential BN is wrong.') # Second run. x = torch.randn(BATCH_SIZE, NUM_INPUTS) y, logdets = m1(x) z, inv_logdets = m1(y, mode='inverse') self.assertTrue((logdets + inv_logdets).abs().max() < EPS, 'Sequential BN Det is not zero for the second run.') self.assertTrue((x - z).abs().max() < EPS, 'Sequential BN is wrong for the second run.') m1.eval() # Eval run. x = torch.randn(BATCH_SIZE, NUM_INPUTS) y, logdets = m1(x) z, inv_logdets = m1(y, mode='inverse') self.assertTrue((logdets + inv_logdets).abs().max() < EPS, 'Sequential BN Det is not zero for the eval run.') self.assertTrue((x - z).abs().max() < EPS, 'Sequential BN is wrong for the eval run.')
def testSequential(self): m1 = fnn.FlowSequential( fnn.ActNorm(NUM_INPUTS), fnn.InvertibleMM(NUM_INPUTS), fnn.CouplingLayer(NUM_INPUTS, NUM_HIDDEN, mask)) x = torch.randn(BATCH_SIZE, NUM_INPUTS) y, logdets = m1(x) z, inv_logdets = m1(y, mode='inverse') self.assertTrue((logdets + inv_logdets).abs().max() < EPS, 'Sequential Det is not zero.') self.assertTrue((x - z).abs().max() < EPS, 'Sequential is wrong.') # Second run. x = torch.randn(BATCH_SIZE, NUM_INPUTS) y, logdets = m1(x) z, inv_logdets = m1(y, mode='inverse') self.assertTrue((logdets + inv_logdets).abs().max() < EPS, 'Sequential Det is not zero for the second run.') self.assertTrue((x - z).abs().max() < EPS, 'Sequential is wrong for the second run.')
num_cond_inputs, s_act='tanh', t_act='relu'), fnn.BatchNormFlow(num_inputs), fnn.Reverse(num_inputs) ] elif args.flow == 'maf-split-glow': for _ in range(args.num_blocks): modules += [ fnn.MADESplit(num_inputs, num_hidden, num_cond_inputs, s_act='tanh', t_act='relu'), fnn.BatchNormFlow(num_inputs), fnn.InvertibleMM(num_inputs) ] model = fnn.FlowSequential(*modules) for module in model.modules(): if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight) if hasattr(module, 'bias') and module.bias is not None: module.bias.data.fill_(0) model.to(device) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6) writer = SummaryWriter(comment=args.flow + "_" + args.dataset)