    def set_ldgm_layers(self):
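        """Build the layers of a ladder deep generative model (LDGM).

        Stacks LadderEncoder blocks over ``self.h_dim`` with stochastic
        sizes ``self.z_dims``, widens the top encoder so it can also consume
        the one-hot label, mirrors the stack with LadderDecoder blocks, and
        adds an MLP classifier plus a label-conditioned reconstruction
        decoder.
        """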
        self.z_dim = self.z_dims[-1]
        neurons = [self.input_size, *self.h_dim]
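        # One LadderEncoder per hidden layer: neurons[i-1] -> neurons[i],
        # with a stochastic layer of size z_dims[i-1] at each rung.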
        encoder_layers = [
            LadderEncoder(neurons[i - 1], neurons[i], self.z_dims[i - 1])
            for i in range(1, len(neurons))
        ]

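        # Rebuild the top encoder so it also receives the one-hot label
        # (its input width grows by num_classes).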
        e = encoder_layers[-1]
        encoder_layers[-1] = LadderEncoder(e.in_features + self.num_classes,
                                           e.out_features, e.z_dim)

        decoder_layers = [
            LadderDecoder(self.z_dims[i - 1], self.h_dim[i - 1],
                          self.z_dims[i]) for i in range(1, len(self.h_dim))
        ][::-1]

        h_dims = [self.h_dim[0]]

        self.classifier = MLP(self.input_size, h_dims, self.num_classes)

        self.encoder = nn.ModuleList(encoder_layers)
        self.decoder = nn.ModuleList(decoder_layers)
        self.reconstruction = Decoder(self.z_dims[0] + self.num_classes,
                                      self.h_dim, self.input_size)
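        # Xavier-initialise the weights of every linear layer; zero the biases.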
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
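
    # Hedged usage sketch (illustrative values, not defaults): with
    # input_size=784, h_dim=[256, 128], z_dims=[64, 32] and num_classes=10,
    # set_ldgm_layers() produces
    #   encoders:       LadderEncoder(784, 256, 64), LadderEncoder(256 + 10, 128, 32)
    #   decoders:       LadderDecoder(64, 256, 32)
    #   reconstruction: Decoder(64 + 10, [256, 128], 784)
    # and an MLP classifier mapping 784 -> 256 -> 10.
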
    def set_adgm_layers(self, h_dims, input_shape, is_hebb_layers=False, use_conv_classifier=False,
                        planes_classifier=None, classifier_kernels=None, classifier_pooling_layers=None):
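        """Build the layers of an auxiliary deep generative model (ADGM).

        Creates a convolutional or MLP classifier, auxiliary encoder/decoder
        networks over the ``a_dim`` auxiliary variable, a label- and
        auxiliary-conditioned encoder/decoder pair for the primary latent,
        and attaches a normalising flow to the auxiliary posterior.
        """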
        if use_conv_classifier:
            self.set_dgm_layers(input_shape=input_shape)
            self.classifier = ConvNet(input_shape=self.input_shape, h_dims=h_dims, num_classes=self.num_classes,
                                      planes=planes_classifier, kernels=classifier_kernels,
                                      pooling_layers=classifier_pooling_layers, a_dim=self.a_dim)
        else:
            self.set_dgm_layers(input_shape=input_shape, num_classes=self.num_classes, is_hebb_layers=is_hebb_layers)
            self.classifier = MLP(self.input_size, self.input_shape, self.indices_names, h_dims, self.num_classes,
                                  a_dim=self.a_dim, is_hebb_layers=is_hebb_layers, gt_input=self.gt_input,
                                  extra_class=False)

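        # Auxiliary inference and generative networks: the encoder consumes x
        # alone (y_dim=0), while the "decoder" consumes x concatenated with z
        # and is additionally conditioned on the label.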
        self.aux_encoder = Encoder(self.input_size, self.h_dims, self.a_dim, num_classes=self.num_classes, y_dim=0)

        self.aux_decoder = Encoder(self.input_size + self.z_dim, list(reversed(self.h_dims)),
                                   self.a_dim, num_classes=self.num_classes, y_dim=self.num_classes)

        self.encoder = Encoder(input_size=self.input_size, h_dim=self.h_dims, z_dim=self.z_dim,
                               num_classes=self.num_classes, a_dim=self.a_dim, y_dim=self.num_classes)
        self.decoder = Decoder(self.z_dim, list(reversed(self.h_dims)), self.input_size, num_classes=self.num_classes)
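        # Attach a normalising flow (n_flows steps of self.flavour) to the auxiliary variable.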
        self.add_flow_auxiliary(self.flow_flavour(in_features=[self.a_dim], n_flows=self.n_flows, h_last_dim=h_dims[-1],
                                auxiliary=True, flow_flavour=self.flavour))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
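
    # Hedged usage sketch (illustrative, assuming the attributes set in
    # __init__): model.set_adgm_layers(h_dims=[128, 64], input_shape=(1, 28, 28))
    # builds the MLP classifier branch; pass use_conv_classifier=True together
    # with the planes/kernels/pooling arguments to get the ConvNet branch.
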
    def set_conv_adgm_layers(self, hs_ae, h_dims, planes_ae, kernels_ae, padding_ae,
                             pooling_layers_ae, planes_classifier=None, classifier_kernels=None,
                             classifier_pooling_layers=None, use_conv_classifier=True, input_shape=None,
                             is_hebb_layers=False):
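        """Build the layers of a fully convolutional ADGM.

        Same structure as ``set_adgm_layers`` but with ConvEncoder/ConvDecoder
        networks for the auxiliary and primary variables; ``hs_ae`` sizes the
        autoencoder hidden layers and ``h_dims`` sizes the classifier.
        """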
        self.input_shape = input_shape
        self.input_size = np.prod(input_shape)
        if use_conv_classifier:
            self.set_conv_dgm_layers(hs_ae=hs_ae, hs_class=h_dims, z_dim=self.z_dim, planes_ae=planes_ae, kernels_ae=kernels_ae,
                                     padding_ae=padding_ae, pooling_layers_ae=pooling_layers_ae,
                                     planes_c=planes_classifier, kernels_c=classifier_kernels,
                                     pooling_layers_c=classifier_pooling_layers)
            self.classifier = ConvNet(input_shape=self.input_shape, h_dims=h_dims, num_classes=self.num_classes,
                                      planes=planes_classifier, kernels=classifier_kernels,
                                      pooling_layers=classifier_pooling_layers, a_dim=self.a_dim)
        else:
            self.set_dgm_layers(input_shape=self.input_shape, is_hebb_layers=is_hebb_layers)
            self.classifier = MLP(self.input_size, self.input_shape, self.indices_names, h_dims, self.num_classes,
                                  a_dim=self.a_dim, is_hebb_layers=is_hebb_layers, gt_input=self.gt_input,
                                  extra_class=False)

        self.aux_encoder = ConvEncoder(h_dim=hs_ae, z_dim=self.a_dim, planes=planes_ae, kernels=kernels_ae,
                                       padding=padding_ae, pooling_layers=pooling_layers_ae, y_size=0, a_size=0)
        self.aux_decoder = ConvEncoder(h_dim=list(reversed(hs_ae)), z_dim=self.a_dim, planes=planes_ae,
                                       kernels=kernels_ae, padding=padding_ae, pooling_layers=pooling_layers_ae,
                                       y_size=self.num_classes, a_size=self.z_dim)
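        # Fully connected alternative to the convolutional auxiliary networks, kept for reference: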
        # self.aux_encoder = Encoder(input_size=self.input_size, h_dim=self.h_dims, z_dim=self.a_dim)
        # self.aux_decoder = Encoder(input_size=self.input_size + self.z_dim + self.num_classes,
        #                           h_dim=list(reversed(self.h_dims)), z_dim=self.a_dim)

        self.encoder = ConvEncoder(hs_ae, z_dim=self.z_dim, planes=planes_ae, kernels=kernels_ae,
                                   padding=padding_ae, pooling_layers=pooling_layers_ae, y_size=self.num_classes,
                                   a_size=self.a_dim)
        self.decoder = ConvDecoder(z_dim=self.z_dim, y_dim=self.num_classes, h_dim=list(reversed(hs_ae)),
                                   input_shape=self.input_shape, planes=planes_ae, kernels=kernels_ae,
                                   padding=padding_ae, unpooling_layers=list(reversed(pooling_layers_ae)))
        self.add_flow_auxiliary(self.flow_flavour(in_features=[self.a_dim], n_flows=self.n_flows, h_last_dim=h_dims[-1],
                                auxiliary=True, flow_flavour=self.flavour))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
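
    # Hedged usage sketch (illustrative shapes only): for 28x28 grayscale
    # inputs one might call
    #   model.set_conv_adgm_layers(hs_ae=[256, 128], h_dims=[128, 64],
    #                              planes_ae=[16, 32], kernels_ae=[3, 3],
    #                              padding_ae=[1, 1], pooling_layers_ae=[True, True],
    #                              input_shape=(1, 28, 28))
    # which sets input_size = 1 * 28 * 28 = 784 before wiring the networks.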