def posterior(self, x, ctx, use_mean=False):
        """
        Encode the posterior distribution.

        Runs ``x`` through ``self.posterior_net``; any ConvLSTM layer is
        seeded with an initial (hidden, cell) state derived from ``ctx``
        via ``self.posterior_init_net``.  The network output is squashed
        with tanh and split into a mean and log-variance, from which a
        (reparameterized) sample is drawn.

        Returns a list with a single ``[mean, logvar, z0, z0, None]`` entry.
        """
        h = x
        for net_layer in self.posterior_net:
            if isinstance(net_layer, layers.ConvLSTM):
                # Fold the context frames into the channel dimension and
                # map them to the LSTM's initial state
                init = ctx.view(ctx.shape[0], -1, ctx.shape[-2], ctx.shape[-1])
                init = self.posterior_init_net(init.unsqueeze(1)).squeeze(1)

                # Forward the LSTM; chunk splits the init into (hidden, cell)
                h = net_layer(h, torch.chunk(init, 2, 1))
            else:
                h = net_layer(h)

        # Squash activations before splitting into distribution parameters
        h = torch.tanh(h)
        mean, logvar = torch.chunk(h, 2, 2)

        # Sample (or take the mean of) the resulting Gaussian
        z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

        return [[mean, logvar, z0, z0, None]]
    def prior(self, emb, q_dists, use_mean=False, scale_var=1.):
        """
        Compute the prior distribution for every latent branch.

        For each branch in ``self.prior_nets``, the conditioning activations
        are selected via ``self.arch['latent']['ctx_idx']``: the first
        ``self.n_ctx`` frames seed the ConvLSTM state, while the frames
        shifted one step back drive the prediction.

        Returns a list of ``[mean, logvar, z0, z0, None]`` per branch.
        """
        dists = []

        for net_idx, branch in enumerate(self.prior_nets):
            # Select the activations conditioning this branch
            ctx_idx = self.arch['latent']['ctx_idx'][net_idx]
            feats = emb[ctx_idx]
            out = feats[:, self.n_ctx - 1:-1].contiguous()
            cur_ctx = feats[:, :self.n_ctx].contiguous()

            for layer in branch:
                if isinstance(layer, ConvLSTM):
                    # Build the LSTM initial state from the context frames
                    cur_ctx = cur_ctx.view(
                        cur_ctx.shape[0], -1,
                        cur_ctx.shape[-2], cur_ctx.shape[-1]).unsqueeze(1)
                    cur_ctx = self.prior_init_nets[net_idx](cur_ctx).squeeze(1)

                    # Forward the LSTM; chunk splits into (hidden, cell)
                    out = layer(out, torch.chunk(cur_ctx, 2, 1))
                else:
                    out = layer(out)

            # Split network output into distribution parameters
            mean, var = torch.chunk(out, 2, 2)

            # NOTE(review): scale_var is applied *before* the softplus here,
            # unlike the dense-connectivity posterior which scales after —
            # confirm this asymmetry is intended.
            logvar = F.softplus(var * scale_var).log()

            # Sample (or take the mean of) the resulting Gaussian
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)
            dists.append([mean, logvar, z0, z0, None])

        return dists
    def posterior(self, x, ctx, use_mean=False, scale_var=1.):
        """
        Encode the posterior distribution for every stochastic branch.

        Branches are processed from the deepest layer index downwards.
        For each branch, frames after the first ``self.n_ctx`` context
        frames are encoded, with ConvLSTM layers seeded by an initial
        state derived from the context frames.

        Args:
            x: per-layer activations to encode (indexed by branch layer).
            ctx: per-layer activations whose first ``self.n_ctx`` frames
                seed the ConvLSTM initial states.
            use_mean: if True, use the distribution mean instead of sampling.
            scale_var: multiplier applied to the softplus'd variance
                (added for consistency with the other posterior variants;
                the default of 1. reproduces the previous behavior exactly).

        Returns:
            A list of ``[mean, logvar, z0, z0, None]`` per branch, in
            processing order (deepest branch first).
        """
        dists = []
        # Deepest (coarsest) branch first
        sto_branches = sorted(self.sto_branches.keys(), reverse=True)

        for layer_idx in sto_branches:

            # Find the corresponding activations
            out = x[layer_idx][:, self.n_ctx:].contiguous()
            cur_ctx = ctx[layer_idx][:, :self.n_ctx].contiguous()
            branch_layers = self.posterior_branches['layer_{}'.format(
                layer_idx)]

            # Process the current branch
            for layer in branch_layers:

                if isinstance(layer, layers.ConvLSTM):
                    # Fold the context frames into the channel dimension and
                    # map them to the LSTM's initial (hidden, cell) state
                    cur_ctx = cur_ctx.view(cur_ctx.shape[0], -1,
                                           cur_ctx.shape[-2],
                                           cur_ctx.shape[-1])
                    cur_ctx = cur_ctx.unsqueeze(1)
                    cur_ctx = self.posterior_init_nets['layer_{}'.format(
                        layer_idx)](cur_ctx)
                    cur_ctx = cur_ctx.squeeze(1)

                    # Forward LSTM (chunk splits the init into hidden/cell)
                    out = layer(out, torch.chunk(cur_ctx, 2, 1))

                else:
                    out = layer(out)

            # Compute distribution stats
            mean, var = torch.chunk(out, 2, 2)

            # Softplus keeps the variance positive; scale_var optionally
            # widens/narrows it (scale_var=1. leaves it untouched)
            logvar = (F.softplus(var) * scale_var).log()

            # Generate sample from this distribution
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

            dists.append([mean, logvar, z0, z0, None])

        return dists
    def posterior(self, x, ctx, use_mean=False, scale_var=1.):
        """
        Encode the posterior distribution for every stochastic branch, with
        hand-crafted dense connectivity between latent levels.

        Branches are processed from the deepest layer index downwards, so
        samples from earlier (coarser) branches can be upsampled and
        concatenated into later branches at their fourth layer
        (``branch_layer_idx == 3``).

        Args:
            x: per-layer activations; frames after the first ``self.n_ctx``
                context frames are encoded.
            ctx: per-layer activations whose first ``self.n_ctx`` frames
                seed the ConvLSTM initial states.
            use_mean: if True, use the distribution mean instead of sampling.
            scale_var: multiplier applied to the softplus'd variance.

        Returns:
            A list of ``[mean, logvar, z0, z0, None]`` per branch, in
            processing order (deepest branch first).
        """
        dists = []
        # Deepest (coarsest) branch first, so its sample can feed shallower ones
        sto_branches = sorted(self.sto_branches.keys(), reverse=True)

        for layer_idx in sto_branches:

            # Find the corresponding activations
            out = x[layer_idx][:, self.n_ctx:].contiguous()
            cur_ctx = ctx[layer_idx][:, :self.n_ctx].contiguous()
            branch_layers = self.posterior_branches['layer_{}'.format(layer_idx)]

            # Process the current branch
            for branch_layer_idx, layer in enumerate(branch_layers):

                if isinstance(layer, layers.ConvLSTM):
                    # Get initial condition: fold the context frames into the
                    # channel dim and map them to the LSTM (hidden, cell) state
                    cur_ctx = cur_ctx.view(cur_ctx.shape[0], -1, cur_ctx.shape[-2], cur_ctx.shape[-1])
                    cur_ctx = cur_ctx.unsqueeze(1)
                    cur_ctx = self.posterior_init_nets['layer_{}'.format(layer_idx)](cur_ctx)
                    cur_ctx = cur_ctx.squeeze(1)

                    # Forward LSTM (chunk splits the init into hidden/cell)
                    out = layer(out, torch.chunk(cur_ctx, 2, 1))

                # Handcrafted rules for integrating the different z's.
                # NOTE(review): the layer indices (16, 10, 4) and interpolate
                # scale factors (8, 32, 4) encode one fixed architecture —
                # confirm against the network config before changing them.
                elif branch_layer_idx == 3:

                    if layer_idx == 16:
                        # Deepest branch: no earlier z to integrate
                        out = layer(out)

                    elif layer_idx == 10:
                        # Upsample the deepest branch's sample (x8) and
                        # concatenate it along the channel dim
                        z1 = dists[0][-2]
                        b, t, c, h, w = z1.shape
                        z1 = z1.view(b*t, c, h, w)
                        z1 = F.interpolate(z1, scale_factor=8)
                        z1 = z1.view(b, t, c, z1.shape[-2], z1.shape[-1])
                        out = torch.cat([out, z1], 2)
                        out = layer(out)

                    elif layer_idx == 4:
                        # Shallowest branch: integrate both earlier samples
                        # (x32 for the deepest, x4 for the middle branch)
                        z1 = dists[0][-2]
                        z2 = dists[1][-2]

                        b, t, c, h, w = z1.shape
                        z1 = z1.view(b*t, c, h, w)
                        z1 = F.interpolate(z1, scale_factor=32)
                        z1 = z1.view(b, t, c, z1.shape[-2], z1.shape[-1])

                        b, t, c, h, w = z2.shape
                        z2 = z2.view(b*t, c, h, w)
                        z2 = F.interpolate(z2, scale_factor=4)
                        z2 = z2.view(b, t, c, z2.shape[-2], z2.shape[-1])
                        out = torch.cat([out, z1, z2], 2)
                        out = layer(out)

                else:
                    out = layer(out)

            # Compute distribution stats
            mean, var = torch.chunk(out, 2, 2)

            # Softplus keeps the variance positive; scaling happens *after*
            # the softplus here (unlike the prior, which scales before)
            scaled_var = F.softplus(var)*scale_var
            logvar = scaled_var.log()

            # Generate sample from this distribution
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

            dists.append([mean, logvar, z0, z0, None])

        return dists
    def posterior(self, emb, use_mean=False, scale_var=1.):
        """
        Encode the posterior distribution for every latent level, with
        architecture-driven dense connectivity between levels.

        Each branch in ``self.posterior_nets`` encodes the activations
        selected by ``self.arch['latent']['ctx_idx']``.  At the fourth
        layer of a branch (``branch_layer_idx == 3``) every previously
        computed latent sample is upsampled to the current branch's
        resolution (per ``self.arch['latent']['resolution']``) and
        concatenated along the channel dimension before the layer runs.

        Args:
            emb: per-layer embeddings; the first ``self.n_ctx`` frames seed
                the ConvLSTM states, the remaining frames are encoded.
            use_mean: if True, use the distribution mean instead of sampling.
            scale_var: multiplier applied to the raw variance before the
                softplus.

        Returns:
            A list of ``[mean, logvar, z0, z0, None]``, one per branch, in
            branch order (index 0 first, so later branches can read it).
        """
        dists = []

        for net_idx in range(len(self.posterior_nets)):

            # Find the corresponding activations
            ctx_idx = self.arch['latent']['ctx_idx'][net_idx]
            out = emb[ctx_idx][:, self.n_ctx:].contiguous()
            cur_ctx = emb[ctx_idx][:, :self.n_ctx].contiguous()
            branch_layers = self.posterior_nets[net_idx]

            # Process the current branch
            for branch_layer_idx, layer in enumerate(branch_layers):

                if isinstance(layer, ConvLSTM):

                    # Get initial condition: fold context frames into the
                    # channel dim and map them to the LSTM (hidden, cell) state
                    cur_ctx = cur_ctx.view(cur_ctx.shape[0], -1,
                                           cur_ctx.shape[-2],
                                           cur_ctx.shape[-1])
                    cur_ctx = cur_ctx.unsqueeze(1)
                    cur_ctx = self.posterior_init_nets[net_idx](cur_ctx)
                    cur_ctx = cur_ctx.squeeze(1)

                    # Forward LSTM (chunk splits the init into hidden/cell)
                    out = layer(out, torch.chunk(cur_ctx, 2, 1))

                # Dense connectivity latent: inject every earlier z,
                # upsampled to this branch's resolution
                elif branch_layer_idx == 3:

                    # Get current latent resolution
                    cur_res = self.arch['latent']['resolution'][net_idx]

                    # Accumulate previous z
                    prev_zs = []

                    for prev_z_idx in range(net_idx):

                        # Get previous z resolution
                        prev_res = self.arch['latent']['resolution'][
                            prev_z_idx]

                        # Compute scaling factor
                        # NOTE(review): assumes cur_res is an exact multiple
                        # of prev_res — confirm against the arch config
                        scaling_factor = cur_res // prev_res

                        # Interpolate previous z (dists[...][-2] is its z0
                        # sample); flatten time into batch for F.interpolate
                        z_prev = dists[prev_z_idx][-2]
                        b, t, c, h, w = z_prev.shape
                        z_prev = z_prev.view(b * t, c, h, w)
                        z_prev = F.interpolate(z_prev,
                                               scale_factor=scaling_factor)
                        z_prev = z_prev.view(b, t, c, z_prev.shape[-2],
                                             z_prev.shape[-1])
                        prev_zs.append(z_prev)

                    # Concatenate zs with the branch activations (channel dim)
                    prev_zs = torch.cat(prev_zs + [out], 2)

                    # Forward through layer
                    out = layer(prev_zs)

                else:
                    out = layer(out)

            # Compute distribution stats
            mean, var = torch.chunk(out, 2, 2)

            # Scale the variance (applied before the softplus here)
            var = var * scale_var

            # Softplus keeps the variance positive
            logvar = F.softplus(var).log()

            # Generate sample from this distribution
            z0 = flows.gaussian_rsample(mean, logvar, use_mean=use_mean)

            dists.append([mean, logvar, z0, z0, None])

        return dists