def __init__(self, img_ch, n_ctx,
            n_hid=64,
            n_z=10,
            enc_dim=512,
            share_prior_enc=False,
            reverse_post=False,
        ):
        """Build the encoder, renderer and hierarchical prior/posterior networks.

        This variant has three stochastic levels ('layer_16', 'layer_10' and
        'layer_4'), each with its own prior and posterior branch. It also has
        deterministic skip connections from six encoder positions into the
        renderer.

        Args:
            img_ch: Number of image channels. This is the input width of
                ``emb_net`` and the output width of the last ``render_net``
                layer.
            n_ctx: Number of context steps. Every init network takes an input
                ``n_ctx`` times wider than a single encoder feature map
                (presumably context features concatenated along channels;
                confirm in ``forward``).
            n_hid: Base channel width. The encoder widens up to ``n_hid*8``.
            n_z: Not referenced in this constructor (latent widths here are
                derived from ``n_hid``).
            enc_dim: Stored as ``self.enc_dim``. No layer built here is sized
                from it.
            share_prior_enc: Not referenced in this constructor.
            reverse_post: Not referenced in this constructor.
        """
        super().__init__()

        self.n_ctx = n_ctx
        self.enc_dim = enc_dim

        # Deterministic encoder. Entry positions matter: the values of
        # det_init_connections and the keys of sto_branches (1, 4, 7, 10, 13,
        # 16) name entries whose output widths (taking the second ResnetBlock
        # argument as its output width) equal the per-step input widths of
        # det_init_nets['layer_<index>'] (presumably they index this list).
        self.emb_net = nn.ModuleList([
            nn.Conv2d(img_ch, n_hid, 1),
            ResnetBlock(n_hid, n_hid),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*1, n_hid*2),
            ResnetBlock(n_hid*2, n_hid*2),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*2, n_hid*4),
            ResnetBlock(n_hid*4, n_hid*4),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*4, n_hid*4),
            ResnetBlock(n_hid*4, n_hid*8),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid*8, n_hid*8),
            ResnetBlock(n_hid*8, n_hid*8),
            nn.MaxPool2d(4, 1),
            ResnetBlock(n_hid*8, n_hid*8, norm_ch=1),
            ResnetBlock(n_hid*8, n_hid*8, norm_ch=1),
        ])

        # Multiplier applied to the input width of the DcUpConv layers
        # (1 = no widening).
        mult = 1
        # Renderer. The even indices 0..10 are ConvLSTM cells (the keys of
        # det_init_connections). Cells 0, 4 and 8 have a doubled input width
        # (the keys of sto_branches / rend_sto_branches); presumably the
        # feature map is concatenated with a latent sample there.
        self.render_net = nn.ModuleList([
            layers.ConvLSTM(n_hid*8 + n_hid*8, n_hid*8),
            layers.DcUpConv(n_hid*8, n_hid*8, 4, 1, 0),
            layers.ConvLSTM(n_hid*8, n_hid*8, norm=True),
            layers.DcUpConv(n_hid*8*mult, n_hid*8, 4, 2, 1),
            layers.ConvLSTM(n_hid*8 + n_hid*8, n_hid*8, norm=True),
            layers.DcUpConv(n_hid*8*mult, n_hid*4, 4, 2, 1),
            layers.ConvLSTM(n_hid*4, n_hid*4, norm=True),
            layers.DcUpConv(n_hid*4*mult, n_hid*2, 4, 2, 1),
            layers.ConvLSTM(n_hid*2 + n_hid*2, n_hid*2, norm=True),
            layers.DcUpConv(n_hid*2*mult, n_hid, 4, 2, 1),
            layers.ConvLSTM(n_hid, n_hid, norm=True),
            layers.DcConv(n_hid, n_hid, 3, 1, 1),
            layers.TemporalConv2d(n_hid, img_ch, 3, 1, 1),
        ])

        # The norm must cover the 2*n_hid*8 channels produced by the conv just
        # before it. Sizing it from enc_dim only works when
        # enc_dim == n_hid*8 (true for the defaults, 512 == 64*8).
        self.det_init_net = nn.Sequential(
            layers.DcConv(2*n_hid*8*self.n_ctx, 2*n_hid*8*self.n_ctx, 1),
            layers.TemporalConv2d(2*n_hid*8*self.n_ctx, 2*n_hid*8, 1),
            layers.TemporalNorm2d(1, 2*n_hid*8),
        )

        def make_init_nets():
            """Return one init network per stochastic level.

            Each maps ``n_ctx``-times-wide features to twice the hidden width
            of the ConvLSTM in the matching branch (presumably its initial
            (h, c) state).
            """
            return nn.ModuleDict({
                'layer_16': nn.Sequential(
                    layers.DcConv(n_hid*8*self.n_ctx, n_hid*8*self.n_ctx, 1),
                    layers.TemporalConv2d(n_hid*8*self.n_ctx, n_hid*8*2, 1),
                    layers.TemporalNorm2d(1, 2*n_hid*8),
                ),
                'layer_10': nn.Sequential(
                    layers.DcConv(n_hid*8*self.n_ctx, n_hid*8*self.n_ctx, 1),
                    layers.TemporalConv2d(n_hid*8*self.n_ctx, n_hid*8*2, 1),
                    layers.TemporalNorm2d(16, 2*n_hid*8),
                ),
                'layer_4': nn.Sequential(
                    layers.DcConv(n_hid*2*self.n_ctx, n_hid*2*self.n_ctx, 1),
                    layers.TemporalConv2d(n_hid*2*self.n_ctx, n_hid*2*2, 1),
                    layers.TemporalNorm2d(16, 2*n_hid*2),
                ),
            })

        def make_branches():
            """Return one latent branch per stochastic level.

            Each branch is [TemporalConv2d, TemporalNorm2d, ConvLSTM,
            TemporalConv2d, TemporalNorm2d]. The final conv of 'layer_10' takes
            n_hid*8 extra input channels, and that of 'layer_4' takes
            n_hid*8 + n_hid*8 extra channels (presumably latents of the coarser
            levels). Each final conv emits twice the level's width (presumably
            mean and log-variance).
            """
            return nn.ModuleDict({
                'layer_4': nn.ModuleList([
                    layers.TemporalConv2d(n_hid*2, n_hid*2, 1),
                    layers.TemporalNorm2d(16, n_hid*2),
                    layers.ConvLSTM(n_hid*2, n_hid*2, norm=True),
                    layers.TemporalConv2d(n_hid*2 + n_hid*8 + n_hid*8, n_hid*2*2, 1),
                    layers.TemporalNorm2d(16, n_hid*2*2),
                ]),
                'layer_10': nn.ModuleList([
                    layers.TemporalConv2d(n_hid*8, n_hid*8, 1),
                    layers.TemporalNorm2d(16, n_hid*8),
                    layers.ConvLSTM(n_hid*8, n_hid*8, norm=True),
                    layers.TemporalConv2d(n_hid*8 + n_hid*8, n_hid*8*2, 1),
                    layers.TemporalNorm2d(16, n_hid*8*2),
                ]),
                'layer_16': nn.ModuleList([
                    layers.TemporalConv2d(n_hid*8, n_hid*8, 1),
                    layers.TemporalNorm2d(1, n_hid*8),
                    layers.ConvLSTM(n_hid*8, n_hid*8, norm=True),
                    layers.TemporalConv2d(n_hid*8, n_hid*8*2, 1),
                    layers.TemporalNorm2d(1, n_hid*8*2),
                ]),
            })

        # Keep this order (init nets prior-then-posterior, branches
        # posterior-then-prior). It fixes parameter registration order, and
        # with it optimizer-state layout and seeded initialization draws.
        self.prior_init_nets = make_init_nets()
        self.posterior_init_nets = make_init_nets()
        self.posterior_branches = make_branches()
        self.prior_branches = make_branches()

        # Zero the scale and give a tiny random shift to each branch's final
        # norm (assuming `.model` is an affine norm layer), so branch outputs
        # start close to zero. Order of RNG draws: posterior 4/10/16, then
        # prior 4/10/16.
        for branches in (self.posterior_branches, self.prior_branches):
            for key in ('layer_4', 'layer_10', 'layer_16'):
                out_norm = branches[key][-1].model
                nn.init.constant_(out_norm.weight, 0)
                nn.init.normal_(out_norm.bias, std=1e-3)

        # Connection list: render_net ConvLSTM index -> encoder index.
        # det_init_nets['layer_<v>'] outputs twice the hidden width of
        # render_net[k] (presumably its initial (h, c) state).
        self.det_init_connections = {
            0: 16,
            2: 13,
            4: 10,
            6: 7,
            8: 4,
            10: 1,
        }

        # Connection branches
        self.det_init_nets = nn.ModuleDict({
            'layer_16': nn.Sequential(
                layers.DcConv(n_hid*8*self.n_ctx, n_hid*8*self.n_ctx, 1),
                layers.TemporalConv2d(n_hid*8*self.n_ctx, n_hid*8*2, 1),
                layers.TemporalNorm2d(1, n_hid*8*2)
            ),
            'layer_13': nn.Sequential(
                layers.DcConv(n_hid*8*self.n_ctx, n_hid*8*self.n_ctx, 3, 1, 1),
                layers.TemporalConv2d(self.n_ctx*n_hid*8, 2*n_hid*8, 1),
                layers.TemporalNorm2d(16, n_hid*8*2)
            ),
            'layer_10': nn.Sequential(
                layers.DcConv(n_hid*8*self.n_ctx, n_hid*8*self.n_ctx, 1),
                layers.TemporalConv2d(n_hid*8*self.n_ctx, n_hid*8*2, 1),
                layers.TemporalNorm2d(16, n_hid*8*2)
            ),
            'layer_7': nn.Sequential(
                layers.DcConv(n_hid*4*self.n_ctx, n_hid*4*self.n_ctx, 1),
                layers.TemporalConv2d(n_hid*4*self.n_ctx, n_hid*4*2, 1),
                layers.TemporalNorm2d(16, n_hid*8)
            ),
            'layer_4': nn.Sequential(
                layers.DcConv(n_hid*2*self.n_ctx, n_hid*2*self.n_ctx, 1),
                layers.TemporalConv2d(n_hid*2*self.n_ctx, n_hid*2*2, 1),
                layers.TemporalNorm2d(16, n_hid*4)
            ),
            'layer_1': nn.Sequential(
                layers.DcConv(n_hid*1*self.n_ctx, n_hid*1*self.n_ctx, 1),
                layers.TemporalConv2d(n_hid*1*self.n_ctx, n_hid*1*2, 1),
                layers.TemporalNorm2d(16, n_hid*2)
            ),
        })

        # Stochastic connection list
        # encoder -> renderer: half of each branch's output width equals the
        # extra input width of render_net[v].
        self.sto_branches = {
            16: 0,
            10: 4,
            4: 8,
        }
        # renderer -> encoder: render_net index -> ordinal 0..2 (presumably
        # the stochastic level index).
        self.rend_sto_branches = {
            0: 0,
            4: 1,
            8: 2,
        }
    def __init__(
        self,
        img_ch,
        n_ctx,
        n_hid=64,
        n_z=10,
        enc_dim=512,
        share_prior_enc=False,
        reverse_post=False,
    ):
        """Build the encoders, renderer and single-level latent branches.

        This variant has one stochastic level ('layer_4'). Its latent has
        ``n_z`` channels and enters the first renderer ConvLSTM together with
        an ``enc_dim``-channel input.

        Args:
            img_ch: Number of image channels. This is the input width of both
                encoders and the output width of the last ``render_net`` layer.
            n_ctx: Number of context steps. The init networks take inputs
                ``n_ctx`` times wider than a single feature map (presumably
                context features concatenated along channels; confirm in
                ``forward``).
            n_hid: Base channel width. The deterministic encoder widens up to
                ``n_hid * 8``.
            n_z: Latent channel count. The prior/posterior branches output
                ``2 * n_z`` channels (presumably mean and log-variance;
                confirm).
            enc_dim: Output width of ``sto_emb_net``. It is also the hidden
                width of ``render_net[0]`` and of the branch ConvLSTMs.
            share_prior_enc: Not referenced in this constructor.
            reverse_post: Not referenced in this constructor.
        """
        super().__init__()

        self.n_ctx = n_ctx
        self.enc_dim = enc_dim

        # Stochastic-path encoder: four DcConv(..., 4, 2, 1) layers widening
        # img_ch up to n_hid * 8, then DcConv(..., 4, 1, 0) to enc_dim channels
        # with a single-group GroupNorm.
        self.sto_emb_net = nn.ModuleList([
            layers.DcConv(img_ch, n_hid, 4, 2, 1),
            layers.DcConv(n_hid, n_hid * 2, 4, 2, 1),
            layers.DcConv(n_hid * 2, n_hid * 4, 4, 2, 1),
            layers.DcConv(n_hid * 4, n_hid * 8, 4, 2, 1),
            layers.DcConv(n_hid * 8,
                          enc_dim,
                          4,
                          1,
                          0,
                          norm=partial(nn.GroupNorm, 1)),
        ])

        # Deterministic encoder. Entries 1, 4, 7, 10, 13 and 16 output n_hid,
        # n_hid*2, n_hid*4, n_hid*8, n_hid*8 and n_hid*8 channels (taking the
        # second ResnetBlock argument as its output width). These are the
        # per-step input widths of det_init_nets['layer_<index>'] (presumably
        # det_init_connections indexes this list).
        self.det_emb_net = nn.ModuleList([
            nn.Conv2d(img_ch, n_hid, 1),
            ResnetBlock(n_hid, n_hid),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid * 1, n_hid * 2),
            ResnetBlock(n_hid * 2, n_hid * 2),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid * 2, n_hid * 4),
            ResnetBlock(n_hid * 4, n_hid * 4),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid * 4, n_hid * 4),
            ResnetBlock(n_hid * 4, n_hid * 8),
            nn.MaxPool2d(2, 2),
            ResnetBlock(n_hid * 8, n_hid * 8),
            ResnetBlock(n_hid * 8, n_hid * 8),
            nn.MaxPool2d(4, 1),
            ResnetBlock(n_hid * 8, n_hid * 8, norm_ch=1),
            ResnetBlock(n_hid * 8, n_hid * 8, norm_ch=1),
        ])

        # Multiplier applied to the input width of the DcUpConv layers
        # (1 = no widening).
        mult = 1
        # Renderer. Entry 0 is a ConvLSTM over n_z + enc_dim input channels
        # with hidden width enc_dim. The even indices 0..10 are ConvLSTM cells
        # (the keys of det_init_connections).
        self.render_net = nn.ModuleList([
            layers.ConvLSTM(n_z + enc_dim, enc_dim),
            layers.DcUpConv(enc_dim, n_hid * 8, 4, 1, 0),
            layers.ConvLSTM(n_hid * 8, n_hid * 8, norm=True),
            layers.DcUpConv(n_hid * 8 * mult, n_hid * 8, 4, 2, 1),
            layers.ConvLSTM(n_hid * 8, n_hid * 8, norm=True),
            layers.DcUpConv(n_hid * 8 * mult, n_hid * 4, 4, 2, 1),
            layers.ConvLSTM(n_hid * 4, n_hid * 4, norm=True),
            layers.DcUpConv(n_hid * 4 * mult, n_hid * 2, 4, 2, 1),
            layers.ConvLSTM(n_hid * 2, n_hid * 2, norm=True),
            layers.DcUpConv(n_hid * 2 * mult, n_hid, 4, 2, 1),
            layers.ConvLSTM(n_hid, n_hid, norm=True),
            layers.DcConv(n_hid, n_hid, 3, 1, 1),
            layers.TemporalConv2d(n_hid, img_ch, 3, 1, 1),
        ])

        # Maps 2 * enc_dim * n_ctx channels down to 2 * enc_dim.
        self.det_init_net = nn.Sequential(
            layers.DcConv(2 * enc_dim * self.n_ctx, 2 * enc_dim * self.n_ctx,
                          1),
            layers.TemporalConv2d(2 * enc_dim * self.n_ctx, 2 * enc_dim, 1),
            layers.TemporalNorm2d(1, 2 * enc_dim),
        )

        # Init nets for the single stochastic level: enc_dim * n_ctx channels
        # -> 2 * enc_dim. That is twice the hidden width of the branch ConvLSTM
        # below (presumably its initial (h, c) state).
        self.prior_init_nets = nn.ModuleDict({
            'layer_4':
            nn.Sequential(
                layers.DcConv(enc_dim * self.n_ctx, enc_dim * self.n_ctx, 1),
                layers.TemporalConv2d(enc_dim * self.n_ctx, enc_dim * 2, 1),
                layers.TemporalNorm2d(1, 2 * enc_dim),
            ),
        })

        self.posterior_init_nets = nn.ModuleDict({
            'layer_4':
            nn.Sequential(
                layers.DcConv(enc_dim * self.n_ctx, enc_dim * self.n_ctx, 1),
                layers.TemporalConv2d(enc_dim * self.n_ctx, enc_dim * 2, 1),
                layers.TemporalNorm2d(1, 2 * enc_dim),
            ),
        })

        # Latent branches: enc_dim -> n_z, then norm, then ConvLSTM(n_z,
        # enc_dim), then 2 * n_z output channels. The posterior and prior
        # branches are built identically; keep their construction order, since
        # it fixes parameter registration order.
        self.posterior_branches = nn.ModuleDict({
            'layer_4':
            nn.ModuleList([
                layers.TemporalConv2d(enc_dim, n_z, 1),
                layers.TemporalNorm2d(1, n_z),
                layers.ConvLSTM(n_z, enc_dim, norm=True),
                layers.TemporalConv2d(enc_dim, n_z * 2, 1),
            ])
        })

        self.prior_branches = nn.ModuleDict({
            'layer_4':
            nn.ModuleList([
                layers.TemporalConv2d(enc_dim, n_z, 1),
                layers.TemporalNorm2d(1, n_z),
                layers.ConvLSTM(n_z, enc_dim, norm=True),
                layers.TemporalConv2d(enc_dim, n_z * 2, 1),
            ])
        })

        # Connection list: render_net ConvLSTM index -> det_emb_net index
        # (presumably each value v selects det_init_nets['layer_v']).
        # NOTE(review): det_init_nets['layer_16'] outputs n_hid * 8 * 2
        # channels. render_net[0] has hidden width enc_dim, so if that output
        # initializes render_net[0]'s state, the widths only agree when
        # enc_dim == n_hid * 8 (true for the defaults). Confirm in forward.
        self.det_init_connections = {
            0: 16,
            2: 13,
            4: 10,
            6: 7,
            8: 4,
            10: 1,
        }

        # Connection branches: each maps n_ctx times the width of the named
        # det_emb_net entry to twice the hidden width of the matching
        # render_net ConvLSTM.
        self.det_init_nets = nn.ModuleDict({
            'layer_16':
            nn.Sequential(
                layers.DcConv(n_hid * 8 * self.n_ctx, n_hid * 8 * self.n_ctx,
                              1),
                layers.TemporalConv2d(n_hid * 8 * self.n_ctx, n_hid * 8 * 2,
                                      1),
                layers.TemporalNorm2d(1, n_hid * 8 * 2)),
            'layer_13':
            nn.Sequential(
                layers.DcConv(n_hid * 8 * self.n_ctx, n_hid * 8 * self.n_ctx,
                              3, 1, 1),
                layers.TemporalConv2d(self.n_ctx * n_hid * 8, 2 * n_hid * 8,
                                      1),
                layers.TemporalNorm2d(16, n_hid * 8 * 2)),
            'layer_10':
            nn.Sequential(
                layers.DcConv(n_hid * 8 * self.n_ctx, n_hid * 8 * self.n_ctx,
                              1),
                layers.TemporalConv2d(n_hid * 8 * self.n_ctx, n_hid * 8 * 2,
                                      1),
                layers.TemporalNorm2d(16, n_hid * 8 * 2)),
            'layer_7':
            nn.Sequential(
                layers.DcConv(n_hid * 4 * self.n_ctx, n_hid * 4 * self.n_ctx,
                              1),
                layers.TemporalConv2d(n_hid * 4 * self.n_ctx,
                                      n_hid * 4 * 2, 1),
                layers.TemporalNorm2d(16, n_hid * 8)),
            'layer_4':
            nn.Sequential(
                layers.DcConv(n_hid * 2 * self.n_ctx, n_hid * 2 * self.n_ctx,
                              1),
                layers.TemporalConv2d(n_hid * 2 * self.n_ctx,
                                      n_hid * 2 * 2, 1),
                layers.TemporalNorm2d(16, n_hid * 4)),
            'layer_1':
            nn.Sequential(
                layers.DcConv(n_hid * 1 * self.n_ctx, n_hid * 1 * self.n_ctx,
                              1),
                layers.TemporalConv2d(n_hid * 1 * self.n_ctx,
                                      n_hid * 1 * 2, 1),
                layers.TemporalNorm2d(16, n_hid * 2)),
        })

        # Stochastic connection list
        # encoder -> renderer: the single latent level feeds render_net[0],
        # whose input width includes n_z extra channels.
        self.sto_branches = {
            4: 0,
        }
        # renderer -> encoder: render_net index -> ordinal (presumably the
        # stochastic level index).
        self.rend_sto_branches = {
            0: 0,
        }