def test_logdet(input_size=(12, 4, 8, 8)):
    """Run ``check_logdet`` on every layer listed below, then the SNF check.

    Layers are produced lazily, one at a time, so each layer is constructed,
    moved with ``.to('cuda')`` and checked before the next one is built.
    Presumably this needs a CUDA-capable device -- TODO confirm.

    Args:
        input_size: Size tuple forwarded to every check. ``input_size[1]``
            and ``input_size[1:]`` are also used as constructor arguments
            for ``ActNorm`` and ``Coupling`` respectively.
    """

    def layers_under_test():
        # A generator keeps construction interleaved with the checks instead
        # of instantiating every layer up front.
        yield LearnableLeakyRelu()
        yield SplineActivation(input_size)
        yield SmoothLeakyRelu()
        yield SmoothTanh()
        yield Identity()
        yield ActNorm(input_size[1])
        yield Coupling(input_size[1:])
        yield Normalization(translation=-1e-6, scale=1 / (1 - 2 * 1e-6))
        yield Squeeze()
        yield UnSqueeze()

    for layer in layers_under_test():
        check_logdet(layer.to('cuda'), input_size)

    # The SNF-specific log-det test runs last, with the same input size.
    test_snf_logdet(input_size)

    print("All log-det tests passed")
def test_inverses(input_size=(12, 4, 16, 16)):
    """Run ``check_inverse`` on every layer listed below, then the SNF check.

    Layers are produced lazily, one at a time, so each layer is constructed,
    moved with ``.to('cuda')`` and checked before the next one is built.
    Presumably this needs a CUDA-capable device -- TODO confirm.

    Args:
        input_size: Size tuple forwarded to every check. ``input_size[1]``
            and ``input_size[1:]`` are also used as constructor arguments
            for several of the layers.
    """

    def layers_under_test():
        # A generator keeps construction interleaved with the checks instead
        # of instantiating every layer up front.
        yield LearnableLeakyRelu()
        yield SplineActivation(input_size)
        yield SmoothLeakyRelu()
        yield SmoothTanh()
        yield Identity()
        yield ActNorm(input_size[1])
        yield Conv1x1(input_size[1])
        yield Conv1x1Householder(input_size[1], 10)
        yield Coupling(input_size[1:])
        yield Normalization(translation=-1e-6, scale=1 / (1 - 2 * 1e-6))
        yield Squeeze()
        yield UnSqueeze()

    for layer in layers_under_test():
        check_inverse(layer.to('cuda'), input_size)

    # The SNF-specific inverse test runs last, with the same input size.
    test_snf_layer_inverses(input_size)

    print("All inverse tests passed")
from torch import optim
from torch.optim.lr_scheduler import StepLR

from snf.layers import Dequantization, Normalization
from snf.layers.distributions.uniform import UniformDistribution
from snf.layers.flowsequential import FlowSequential
from snf.layers.convexp.convexp_module import ConvExp
from snf.layers.activations import SmoothLeakyRelu, SplineActivation, LearnableLeakyRelu, Identity
from snf.layers.squeeze import Squeeze
from snf.layers.transforms import LogitTransform
from snf.train.losses import NegativeGaussianLoss
from snf.train.experiment import Experiment
from snf.datasets.mnist import load_data

# Registry of activation factories keyed by a short name. Every factory takes
# a single ``size`` argument; only the 'Spline' entry forwards it to the layer
# constructor, the others ignore it.
# NOTE(review): ``SELU`` is not among the names imported in the visible part
# of this file -- confirm it is in scope before selecting the 'SELU' entry.
activations = {
    'SLR': lambda size: SmoothLeakyRelu(alpha=0.3),
    'LLR': lambda size: LearnableLeakyRelu(),
    'Spline': lambda size: SplineActivation(
        size,
        tail_bound=10,
        individual_weights=True,
    ),
    'SELU': lambda size: SELU(
        alpha=1.6733,
        lamb=1.0507,
    ),
}

def create_model(num_layers=100, sym_recon_grad=False, 
                 activation='Spline', recon_loss_weight=1.0,
                 num_blocks=3):
    block_size = int(num_layers / num_blocks)
    act = activations[activation]

    alpha = 1e-6
    layers = [
        Dequantization(UniformDistribution(size=(1, 28, 28))),
        Normalization(translation=0, scale=256),