Ejemplo n.º 1
0
    def testInv(self):
        """Check that InvertibleMM round-trips a batch and its log-dets cancel.

        A random batch is pushed forward and then back through the flow. Two
        quantities must both stay below EPS: the largest absolute value of the
        summed forward/inverse log-determinants, and the reconstruction error.
        """
        m1 = fnn.FlowSequential(fnn.InvertibleMM(NUM_INPUTS))

        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        # assertLess (rather than assertTrue(a < b)) includes the offending
        # value in the failure message; float() turns the scalar reduction
        # into a plain Python number so the message is readable.
        self.assertLess(float((logdets + inv_logdets).abs().max()), EPS,
                        'InvMM Det is not zero.')
        self.assertLess(float((x - z).abs().max()), EPS, 'InvMM is wrong.')
Ejemplo n.º 2
0
    def testSequentialBN(self):
        """Round-trip a BatchNorm/InvMM/Coupling flow in train and eval mode.

        Three passes are made (two in train mode, one in eval mode). Each pass
        maps a fresh random batch forward and back. It then checks that the
        forward and inverse log-determinants cancel and that the input is
        reconstructed, both within EPS.
        """
        m1 = fnn.FlowSequential(fnn.BatchNormFlow(NUM_INPUTS),
                                fnn.InvertibleMM(NUM_INPUTS),
                                fnn.CouplingLayer(NUM_INPUTS, NUM_HIDDEN))

        def check_round_trip(suffix):
            # One forward/inverse pass on a new batch; `suffix` is spliced
            # into the failure messages to identify which pass failed.
            x = torch.randn(BATCH_SIZE, NUM_INPUTS)

            y, logdets = m1(x)
            z, inv_logdets = m1(y, mode='inverse')

            self.assertLess(float((logdets + inv_logdets).abs().max()), EPS,
                            'Sequential BN Det is not zero' + suffix + '.')
            self.assertLess(float((x - z).abs().max()), EPS,
                            'Sequential BN is wrong' + suffix + '.')

        m1.train()
        check_round_trip('')

        # Second run: a new batch in train mode, presumably so BatchNormFlow
        # is exercised with different batch statistics -- confirm.
        check_round_trip(' for the second run')

        m1.eval()
        # Eval run.
        check_round_trip(' for the eval run')
Ejemplo n.º 3
0
    def testSequential(self):
        """Round-trip an ActNorm/InvMM/Coupling flow on two fresh batches.

        Each pass maps a random batch forward and back. It then checks that the
        forward and inverse log-determinants cancel and that the input is
        reconstructed, both within EPS.
        """
        # `mask` is not defined in this method; presumably a module-level
        # coupling mask -- confirm against the test module.
        m1 = fnn.FlowSequential(
            fnn.ActNorm(NUM_INPUTS), fnn.InvertibleMM(NUM_INPUTS),
            fnn.CouplingLayer(NUM_INPUTS, NUM_HIDDEN, mask))

        def check_round_trip(suffix):
            # One forward/inverse pass on a new batch; `suffix` is spliced
            # into the failure messages to identify which pass failed.
            x = torch.randn(BATCH_SIZE, NUM_INPUTS)

            y, logdets = m1(x)
            z, inv_logdets = m1(y, mode='inverse')

            self.assertLess(float((logdets + inv_logdets).abs().max()), EPS,
                            'Sequential Det is not zero' + suffix + '.')
            self.assertLess(float((x - z).abs().max()), EPS,
                            'Sequential is wrong' + suffix + '.')

        check_round_trip('')

        # Second run on a new batch, presumably to catch layers whose behaviour
        # changes after their first call (e.g. data-dependent init) -- confirm.
        check_round_trip(' for the second run')
Ejemplo n.º 4
0
                          num_cond_inputs,
                          s_act='tanh',
                          t_act='relu'),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
elif args.flow == 'maf-split-glow':
    # Stack args.num_blocks blocks. Each block is a MADESplit layer (s_act='tanh',
    # t_act='relu'), then a BatchNormFlow, then an InvertibleMM. That last layer
    # takes the place of the fnn.Reverse used by the preceding branch.
    for _ in range(args.num_blocks):
        modules += [
            fnn.MADESplit(num_inputs,
                          num_hidden,
                          num_cond_inputs,
                          s_act='tanh',
                          t_act='relu'),
            fnn.BatchNormFlow(num_inputs),
            fnn.InvertibleMM(num_inputs)
        ]

model = fnn.FlowSequential(*modules)

# Initialise every nn.Linear reachable from the flow: orthogonal weights and,
# when the layer actually has a bias parameter, a bias filled with zeros.
for module in model.modules():
    if not isinstance(module, nn.Linear):
        continue
    nn.init.orthogonal_(module.weight)
    if getattr(module, 'bias', None) is not None:
        module.bias.data.fill_(0)

model.to(device)

# Adam over all flow parameters, learning rate from the CLI, weight decay 1e-6.
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)

# Log writer whose run comment is "<flow>_<dataset>".
writer = SummaryWriter(comment=args.flow + "_" + args.dataset)