import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model


def VAE_2():
    # inputs = Input(shape=(28 * 28,), name='encoder_input', batch_size=_BatchSize)  # specify batch_size when using method 3
    inputs = Input(shape=(28 * 28,), name='encoder_input')
    x = layers.Dense(128, activation='relu')(inputs)
    z_mean = layers.Dense(2, name='z_mean')(x)
    z_log_var = layers.Dense(2, name='z_log_var')(x)

    # Method 1: embed the sampling step directly in the model!
        # 1. Draw epsilon from a standard normal distribution
    eps = tf.random.normal((tf.shape(z_mean)[0], tf.shape(z_mean)[1]))
        # 2. Get the standard deviation from the log-variance
    std = tf.exp(0.5 * z_log_var)
        # 3. Reparameterize: z = z_mean + std * eps, via element-wise multiplication
    Sample_Z = layers.Add()([z_mean, layers.Multiply()([eps, std])])

    # Method 2: use a Lambda layer together with a sampling function applied to every element of the layer
    # Sample_Z = layers.Lambda(sampling, name='z')([z_mean, z_log_var])

    # Method 3: a custom sampling-layer subclass. This is no different from embedding the sampling in the
    # model, but note that it requires batch_size=_BatchSize to be set in both Input() calls
    # Sample_Z = Sample(z_log_var)(z_mean, z_log_var)


    # instantiate encoder model
    encoder = Model(inputs, [z_mean, z_log_var, Sample_Z], name='encoder')
    # encoder.summary()

    # build decoder model
    # latent_inputs = Input(shape=(2,), name='z_sampling', batch_size=_BatchSize)  # specify batch_size when using method 3
    latent_inputs = Input(shape=(2,), name='z_sampling')
    x = layers.Dense(128, activation='relu')(latent_inputs)
    outputs = layers.Dense(28*28, activation='sigmoid')(x)

    # instantiate decoder model
    decoder = Model(latent_inputs, outputs, name='decoder')
    # decoder.summary()

    # instantiate VAE model
    outputs = decoder(encoder(inputs)[2])
    vae = Model(inputs = inputs, outputs = outputs, name='vae_mlp')

    # Add the losses
    reconstruction_loss = tf.keras.losses.BinaryCrossentropy(reduction='sum', name='binary_crossentropy')(inputs, outputs)

    kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    kl_loss = -0.5 * tf.reduce_mean(kl_loss)
    # If kl_loss = -0.5 * tf.reduce_sum(kl_loss) were used here, the same error as in the first example would occur
    vae.add_loss(kl_loss)
    vae.add_metric(kl_loss, name='kl_loss', aggregation='mean')
    vae.add_loss(reconstruction_loss)
    vae.add_metric(reconstruction_loss, name='reconstruction_loss', aggregation='mean')

    return vae, encoder, decoder
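Method 2 above refers to a sampling helper that is not shown in this snippet. The following is a minimal sketch of such a function, consistent with the reparameterization trick used in Example #4 below; the exact body is an assumption, not the original author's code.

# Hypothetical sampling helper for the Lambda layer in Method 2 (not part of the original snippet)
def sampling(args):
    z_mean, z_log_var = args
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    eps = tf.random.normal(shape=(batch, dim))  # eps ~ N(0, I)
    # Reparameterization trick: z = mean + exp(0.5 * log_var) * eps
    return z_mean + tf.exp(0.5 * z_log_var) * eps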
Example #2
    def _build_compile(self, model_input):
        z_mean, z_log_var, z = self.encoder(model_input)
        surved_y_output = self.decoder_y(z)
        surved = Model(model_input, surved_y_output, name='SurVED')

        kl_loss_orig = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) -
                                             tf.exp(z_log_var) + 1)
        kl_loss = kl_loss_orig * self.kl_loss_weight
        surved.add_loss(K.mean(kl_loss))
        surved.add_metric(kl_loss_orig, name='kl_loss', aggregation='mean')
        opt = Adam(lr=self.surved_lr)
        surved.compile(loss=self._get_loss(),
                       optimizer=opt,
                       metrics=[self.cindex, self.surv_mse_loss])
        return surved
Example #3
def build(
    latent_dim,
    input_shape,
    repeat=1,
    use_inception=True,
    batch_size=1,
    learning_rate=1e-4,
):
    encoder_input, encoder = _build_encoder(
        input_shape, latent_dim, repeat, use_inception
    )
    decoder_input, decoder = _build_decoder(
        latent_dim, input_shape, repeat, use_inception
    )
    z_mean, z_log_var, z = encoder(encoder_input)
    decoder_output = decoder(z)
    model = Model(encoder_input, decoder_output, name="vae")

    print(f"Encoder input: {encoder_input.shape}")
    print(f"Decoder output: {decoder_output.shape}")
    encoder_input.shape.assert_is_compatible_with(decoder_output.shape)
    #     assert encoder_input.shape.as_list() == decoder_output.shape.as_list()

    reconstruction_loss = ReconstructionLoss(mean=True)([encoder_input, decoder_output])
    #     reconstruction_loss = tf.losses.mse(encoder_input, decoder_output)
    #     reconstruction_loss = tf.reduce_sum(reconstruction_loss, axis=[1, 2])
    kl_loss = KLLoss(mean=True)([z, z_mean, z_log_var])
    #     logpz = log_normal_pdf(z, 0.0, 0.0)
    #     logqz_x = log_normal_pdf(z, z_mean, z_log_var)
    #     kl_loss = logqz_x - logpz
    vae_loss = reconstruction_loss + kl_loss
    model.add_loss(vae_loss)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
    # model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate), loss=lambda yt, yp: vae_loss)

    model.add_metric(
        reconstruction_loss, aggregation="mean", name="reconstruction_loss"
    )
    model.add_metric(kl_loss, aggregation="mean", name="kl_loss")
    return model, encoder, decoder
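Example #3 relies on custom ReconstructionLoss and KLLoss layers defined elsewhere in that project. Based on the commented-out alternatives in the snippet (an MSE-style reconstruction term and a log-normal-pdf KL estimate), a rough sketch of what such layers might look like is given below; the definitions in the original repository may differ.

import tensorflow as tf

# Hypothetical re-creations of the loss layers used in Example #3; the originals are not shown
class ReconstructionLoss(tf.keras.layers.Layer):
    def __init__(self, mean=True, **kwargs):
        super().__init__(**kwargs)
        self.mean = mean

    def call(self, inputs):
        y_true, y_pred = inputs
        # Per-sample squared error summed over all non-batch axes, mirroring the
        # commented-out tf.losses.mse / reduce_sum variant above
        axes = list(range(1, len(y_true.shape)))
        loss = tf.reduce_sum(tf.square(y_true - y_pred), axis=axes)
        return tf.reduce_mean(loss) if self.mean else loss


class KLLoss(tf.keras.layers.Layer):
    def __init__(self, mean=True, **kwargs):
        super().__init__(**kwargs)
        self.mean = mean

    def call(self, inputs):
        z, z_mean, z_log_var = inputs
        # Analytic KL(q(z|x) || N(0, I)); the original may instead use the Monte Carlo
        # log_normal_pdf estimate sketched in the comments above
        kl = -0.5 * tf.reduce_sum(1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
        return tf.reduce_mean(kl) if self.mean else kl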
Example #4
def build_model():

    encoder_input = Input(shape=(time_step, input_dim), name='encoder_input')

    rnn1 = Bidirectional(GRU(rnn_dim, return_sequences=True),
                         name='rnn1')(encoder_input)
    rnn2 = Bidirectional(GRU(rnn_dim), name='rnn2')(rnn1)

    z_mean = Dense(z_dim, name='z_mean')(rnn2)
    z_log_var = Dense(z_dim, name='z_log_var')(rnn2)

    def sampling(args):
        z_mean, z_log_var = args
        batch = K.shape(z_mean)[0]
        dim = K.int_shape(z_mean)[1]
        # by default, random_normal has mean=0 and std=1.0
        epsilon = K.random_normal(shape=(batch, dim))
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    z = Lambda(sampling, output_shape=(z_dim, ), name='z')([z_mean, z_log_var])

    class kl_beta(tf.keras.layers.Layer):
        def __init__(self):
            super(kl_beta, self).__init__()

            # Non-trainable KL weight (beta); it can be updated from outside the model,
            # e.g. for KL annealing / warm-up
            self.beta = tf.Variable(0.0, trainable=False, dtype=tf.float32)

        def call(self, inputs, **kwargs):
            # Scale the incoming KL term by -beta (the sign flip expected by add_loss
            # happens here rather than in the kl_loss computation below)
            return -self.beta * inputs

    beta = kl_beta()
    encoder = Model(encoder_input, z, name='encoder')

    # decoder

    decoder_latent_input = Input(shape=z_dim, name='z_sampling')

    repeated_z = RepeatVector(time_step,
                              name='repeated_z_tension')(decoder_latent_input)

    rnn1_output = GRU(rnn_dim, name='decoder_rnn1',
                      return_sequences=True)(repeated_z)

    rnn2_output = GRU(rnn_dim, name='decoder_rnn2',
                      return_sequences=True)(rnn1_output)

    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss = tf.reduce_mean(kl_loss)

    kl_loss = 0.5 * kl_loss

    kl_loss = beta(kl_loss)
    tensile_middle_output = TimeDistributed(
        Dense(tension_middle_dim, activation='elu'),
        name='tensile_strain_dense1')(rnn2_output)

    tensile_output = TimeDistributed(
        Dense(tension_output_dim, activation='elu'),
        name='tensile_strain_dense2')(tensile_middle_output)

    diameter_middle_output = TimeDistributed(
        Dense(tension_middle_dim, activation='elu'),
        name='diameter_strain_dense1')(rnn2_output)

    diameter_output = TimeDistributed(
        Dense(tension_output_dim, activation='elu'),
        name='diameter_strain_dense2')(diameter_middle_output)

    melody_rhythm_1 = TimeDistributed(Dense(start_middle_dim,
                                            activation='elu'),
                                      name='melody_start_dense1')(rnn2_output)
    melody_rhythm_output = TimeDistributed(
        Dense(melody_note_start_dim, activation='sigmoid'),
        name='melody_start_dense2')(melody_rhythm_1)

    melody_pitch_1 = TimeDistributed(Dense(melody_bass_dense_1_dim,
                                           activation='elu'),
                                     name='melody_pitch_dense1')(rnn2_output)

    melody_pitch_output = TimeDistributed(
        Dense(melody_output_dim, activation='softmax'),
        name='melody_pitch_dense2')(melody_pitch_1)

    bass_rhythm_1 = TimeDistributed(Dense(start_middle_dim, activation='elu'),
                                    name='bass_start_dense1')(rnn2_output)

    bass_rhythm_output = TimeDistributed(
        Dense(bass_note_start_dim, activation='sigmoid'),
        name='bass_start_dense2')(bass_rhythm_1)

    bass_pitch_1 = TimeDistributed(Dense(melody_bass_dense_1_dim,
                                         activation='elu'),
                                   name='bass_pitch_dense1')(rnn2_output)
    bass_pitch_output = TimeDistributed(Dense(bass_output_dim,
                                              activation='softmax'),
                                        name='bass_pitch_dense2')(bass_pitch_1)

    decoder_output = [
        melody_pitch_output, melody_rhythm_output, bass_pitch_output,
        bass_rhythm_output, tensile_output, diameter_output
    ]

    decoder = Model(decoder_latent_input, decoder_output, name='decoder')

    model_input = encoder_input

    vae = Model(model_input,
                decoder(encoder(model_input)),
                name='encoder_decoder')

    vae.add_loss(kl_loss)

    vae.add_metric(kl_loss, name='kl_loss', aggregation='mean')

    optimizer = keras.optimizers.Adam()

    vae.compile(optimizer=optimizer,
                loss=[
                    'categorical_crossentropy', 'binary_crossentropy',
                    'categorical_crossentropy', 'binary_crossentropy', 'mse',
                    'mse'
                ],
                metrics=[[keras.metrics.CategoricalAccuracy()],
                         [keras.metrics.BinaryAccuracy()],
                         [keras.metrics.CategoricalAccuracy()],
                         [keras.metrics.BinaryAccuracy()],
                         [keras.metrics.MeanSquaredError()],
                         [keras.metrics.MeanSquaredError()]])

    return vae
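The kl_beta layer in Example #4 holds a non-trainable beta variable, which suggests KL annealing (beta warm-up) driven from outside the model. Below is a minimal sketch of such a schedule; it assumes you keep a reference to the kl_beta instance (for example by also returning it from build_model), which the original snippet does not do.

import tensorflow as tf

# Hypothetical KL-annealing callback; beta_layer is assumed to be the kl_beta instance used in build_model
class KLAnnealing(tf.keras.callbacks.Callback):
    def __init__(self, beta_layer, warmup_epochs=10, max_beta=1.0):
        super().__init__()
        self.beta_layer = beta_layer
        self.warmup_epochs = warmup_epochs
        self.max_beta = max_beta

    def on_epoch_begin(self, epoch, logs=None):
        # Linearly ramp beta from 0 to max_beta over the warm-up period
        new_beta = self.max_beta * min(1.0, (epoch + 1) / self.warmup_epochs)
        self.beta_layer.beta.assign(new_beta)

# Usage sketch: vae.fit(..., callbacks=[KLAnnealing(beta_layer)])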
Example #5
class WANN(object):
    """
    WANN: Weighting Adversarial Neural Network is an instance-based domain adaptation
    method suited for regression tasks. It assumes a supervised setting where some
    labeled target data are available.
    
    The goal of WANN is to compute a reweighting of source instances that corrects
    "shifts" between the source and target domains. This is done by minimizing the
    Y-discrepancy distance between the source and target distributions.
    
    WANN involves three networks:
        - the weighting network which learns the source weights.
        - the task network which learns the task.
        - the discrepancy network which is used to estimate a distance 
          between the reweighted source and target distributions: the Y-discrepancy
    
    Parameters
    ----------
    get_base_model: callable, optional
        Constructor for the two networks: task and discrepancer.
        The constructor should take the four following
        arguments:
        - shape: the input shape
        - C: the projecting constant
        - activation: the last layer activation function
        - name: the model name
        If None, get_default_model is used.
        
    get_weighting_model: callable, optional
        Constructor for the weighting network.
        The constructor should take the same arguments 
        as get_base_model.
        If None, get_base_model is used.
        
    C: float, optional (default=1.)
        Projecting constant: networks should be
        regularized by projecting the weights of each layer
        on the ball of radius C.
        
    C_w: float, optional (default=None)
        Projecting constant of the weighting network.
        If None C_w = C.
        
    optimizer: tf.keras Optimizer, optional (default="adam")
        Optimizer of WANN
        
    save_hist: boolean, optional (default=False)
        Whether to save the predicted weights and labels
        at each epoch
    """
    
    def __init__(self, get_base_model=None, get_weighting_model=None,
                 C=1., C_w=None, optimizer='adam', save_hist=False):
        
        self.get_base_model = get_base_model
        if self.get_base_model is None:
            self.get_base_model = _get_default_model
        
        self.get_weighting_model = get_weighting_model
        if self.get_weighting_model is None:
            self.get_weighting_model = self.get_base_model
        
        self.C = C
        self.C_w = C_w
        if self.C_w is None:
            self.C_w = C
        
        self.save_hist = save_hist
        self.optimizer = optimizer
        

    def fit(self, X, y, index=None, weights_target=None, **fit_params):
        """
        Fit WANN
        
        Parameters
        ----------
        X, y: numpy arrays
            Input data
            
        index: iterable
            Index should contain 2 lists or 1D-arrays
            corresponding to:
            index[0]: indexes of source labeled data in X, y
            index[1]: indexes of target labeled data in X, y
            
        weights_target: numpy array, optional (default=None)
            Weights for target sample.
            If None, all weights are set to 1.
            
        fit_params: key, value arguments
            Arguments to pass to the fit method (epochs, batch_size...)
            
        Returns
        -------
        self 
        """
        self.fit_params = fit_params
        assert hasattr(index, "__iter__"), "index should be iterable"
        assert len(index) == 2, "index length should be 2"
        src_index = index[0]
        tgt_index = index[1]
        self._fit(X, y, src_index, tgt_index, weights_target)        
        return self


    def _fit(self, X, y, src_index, tgt_index, weights_target):
        # Resize source and target index to the same length
        max_size = max((len(src_index), len(tgt_index)))
        resize_src_ind = np.array([src_index[i%len(src_index)]
                                   for i in range(max_size)])
        resize_tgt_ind = np.array([tgt_index[i%len(tgt_index)]
                                   for i in range(max_size)])
        
        # If no target weights are given, all are set to one
        if weights_target is None:
            resize_weights_target = np.ones(max_size)
        else:
            resize_weights_target = np.array([weights_target[i%len(weights_target)]
                                              for i in range(max_size)])
                     
        # Create WANN model
        if not hasattr(self, "model"):
            self._create_wann(shape=X.shape[1])

        # Callback to save predicted weights and labels
        callbacks = []
        if "callbacks" in self.fit_params:
            callbacks = self.fit_params["callbacks"]
            del self.fit_params["callbacks"]
            
        # Initialize weighting network
        self.weights_predictor.compile(optimizer=copy.deepcopy(self.optimizer), loss="mse")
        self.weights_predictor.fit(X[src_index], np.ones(len(src_index)), **self.fit_params)
        
        # Fit
        self.model.fit([X[resize_src_ind], X[resize_tgt_ind],
                        y[resize_src_ind], y[resize_tgt_ind],
                        resize_weights_target],
                       callbacks = callbacks,
                       **self.fit_params)
        return self
            
            
    def _create_wann(self, shape):
        # Build task, weights_predictor and discrepancer network
        # Weights_predictor should end with a relu activation
        self.weights_predictor = self.get_weighting_model(
                shape, activation='relu', C=self.C_w, name="weights")
        self.task = self.get_base_model(
                shape, activation=None, C=self.C, name="task")
        self.discrepancer = self.get_base_model(
                shape, activation=None, C=self.C, name="discrepancer")
        
        # Create input layers for Xs, Xt, ys, yt and target weights
        input_source = Input(shape=(shape,))
        input_target = Input(shape=(shape,))
        output_source = Input(shape=(1,))
        output_target = Input(shape=(1,))
        weights_target = Input(shape=(1,))
        Flip = _GradReverse()
        
        # Get networks output for both source and target
        weights_source = self.weights_predictor(input_source)      
        output_task_s = self.task(input_source)
        output_task_t = self.task(input_target)
        output_disc_s = self.discrepancer(input_source)
        output_disc_t = self.discrepancer(input_target)
        
        # Reversal layer at the end of discrepancer
        output_disc_s = Flip(output_disc_s)
        output_disc_t = Flip(output_disc_t)

        # Create model and define loss
        self.model = Model([input_source, input_target, output_source, output_target, weights_target],
                           [output_task_s, output_task_t, output_disc_s, output_disc_t, weights_source],
                           name='WANN')
            
        loss_task_s = K.mean(multiply([weights_source, K.square(output_source - output_task_s)]))
        loss_task_t = K.mean(multiply([weights_target, K.square(output_target - output_task_t)]))
            
        loss_disc_s = K.mean(multiply([weights_source, K.square(output_source - output_disc_s)]))
        loss_disc_t = K.mean(multiply([weights_target, K.square(output_target - output_disc_t)]))
            
        loss_task = loss_task_s #+ loss_task_t
        loss_disc = loss_disc_t - loss_disc_s
                         
        loss = loss_task + loss_disc
   
        self.model.add_loss(loss)
        self.model.add_metric(tf.reduce_sum(K.mean(weights_source)), name="weights", aggregation="mean")
        self.model.add_metric(tf.reduce_sum(loss_task_s), name="task_s", aggregation="mean")
        self.model.add_metric(tf.reduce_sum(loss_task_t), name="task_t", aggregation="mean")
        self.model.add_metric(tf.reduce_sum(loss_disc), name="disc", aggregation="mean")
        self.model.add_metric(tf.reduce_sum(loss_disc_s), name="disc_s", aggregation="mean")
        self.model.add_metric(tf.reduce_sum(loss_disc_t), name="disc_t", aggregation="mean")
        self.model.compile(optimizer=self.optimizer)
        return self
    
    
    def predict(self, X):
        """
        Predict method: return the prediction of task network
        
        Parameters
        ----------
        X: array
            input data
            
        Returns
        -------
        y_pred: array
            prediction of task network
        """
        return self.task.predict(X)
    
    
    def get_weight(self, X):
        """
        Return the predictions of weighting network
        
        Parameters
        ----------
        X: array
            input data
            
        Returns
        -------
        array:
            weights
        """
        return self.weights_predictor.predict(X)
    
    
    def save(self, path):
        """
        Save task network
        
        Parameters
        ----------
        path: str
            path where to save the model
        """
        self.task.save(path)
        self.weights_predictor.save(path + "_weights")
        return self
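As the fit docstring above describes, WANN expects a single (X, y) pair plus an index argument holding the source and target indices. A minimal usage sketch under those assumptions (the data arrays and index values below are invented for illustration):

import numpy as np

# Hypothetical data: 100 labeled source rows followed by 10 labeled target rows
X = np.random.randn(110, 5).astype("float32")
y = np.random.randn(110, 1).astype("float32")
src_index = np.arange(0, 100)
tgt_index = np.arange(100, 110)

wann = WANN(C=1., optimizer="adam")
wann.fit(X, y, index=[src_index, tgt_index], epochs=10, batch_size=32, verbose=0)
y_pred = wann.predict(X[tgt_index])           # predictions from the task network
src_weights = wann.get_weight(X[src_index])   # learned source instance weights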
Example #6
class CVAE():

    def __init__(self, 
                 x_input_size,
                 b_input_size,
                 lb_input_size,
                 sf_input_size = 1,
                 enc = (256, 256, 128),
                 dec = (128, 256, 256),
                 latent_k = 30,
                 alpha = 0.01,
                 input_dropout = 0.,
                 encoder_dropout = 0.1,
                 nonmissing_indicator = None,
                 init = tf.keras.initializers.Orthogonal(),
                 optimizer = None,
                 lr = 0.001,
                 clipvalue = 5,
                 clipnorm = 1,
                 theta_min = 1e-6,
                 theta_max = 1e2):

        self.x_input_size = x_input_size
        self.b_input_size = b_input_size
        self.lb_input_size = lb_input_size
        self.z_input_size = latent_k
        self.sf_input_size = sf_input_size
        self.disp_input_size = b_input_size
        self.enc = enc
        self.dec = dec
        self.latent_k = latent_k
        self.alpha = alpha
        self.input_dropout = input_dropout
        self.encoder_dropout = encoder_dropout
        self.init = init
        self.lr = lr
        self.clipvalue = clipvalue
        self.clipnorm = clipnorm
        self.theta_min = theta_min
        self.theta_max = theta_max



        if optimizer is None:
            self.optimizer = tf.keras.optimizers.Adam(learning_rate = lr, 
                                                  clipnorm = clipnorm, clipvalue = clipvalue)
        else:
            self.optimizer = optimizer

        
        self.extra_models = {}
        self.model = None


    
    def build(self, print_model = False):


        """ Inputs. """
        self.x_input = Input(shape = (self.x_input_size, ), name = 'x_input')
        self.b_input = Input(shape = (self.b_input_size, ), name = 'B')
        self.sf_input = Input(shape = (self.sf_input_size, ), name = 'sf_input')
        self.z_input = Input(shape = (self.z_input_size, ), name = 'z_input')
        self.disp_input = Input(shape = (self.disp_input_size, ), name = 'nb_input')
        self.x_raw_input = Input(shape = (self.x_input_size, ), name = 'x_raw_input')
        self.lb_input = Input(shape = (self.lb_input_size, ), name = 'lb_input')



        """ Build the encoder. """
        self.z = keras.layers.concatenate([self.x_input, self.b_input])

        for i, hid_size in enumerate(self.enc):
            dense_layer_name = 'e%s' % i
            bn_layer_name = 'be%s' % i
            self.z = Dense(hid_size, activation = None, use_bias = True, 
                        kernel_initializer = self.init, name = dense_layer_name)(self.z)
            self.z = LeakyReLU(alpha = 0.01)(self.z)
            self.z = BatchNormalization(center = False, scale = True, name = bn_layer_name)(self.z)
            if i == 0:
                self.z = Dropout(self.encoder_dropout)(self.z)
            
        self.z_mean = Dense(self.latent_k, activation = None, use_bias = True, 
                            kernel_initializer = self.init, name = 'z_mean_dense')(self.z)
        self.z_mean = LeakyReLU(alpha = 0.01, name = 'z_mean_act')(self.z_mean)
        self.z_mean = BatchNormalization(center = False, scale = True, name = 'bz')(self.z_mean)
        self.z_log_var = Dense(self.latent_k, activation = None, use_bias = True, 
                            kernel_initializer = tf.keras.initializers.Orthogonal(gain = 0.01), 
                            name = 'z_log_var')(self.z)

        # Sampling latent space
        self.z_out = Lambda(sample_z, output_shape = (self.latent_k, ))([self.z_mean, self.z_log_var])

        self.extra_models['mean_out'] = Model([self.x_input, self.b_input], self.z_mean, name = 'mean_out')
        self.extra_models['var_out'] = Model([self.x_input, self.b_input], self.z_log_var, name = 'var_out')
        self.extra_models['samp_out'] = Model([self.x_input, self.b_input], self.z_out, name = 'samp_out')


        """ Build the prediction network. """
        self.lb_pred = Dense(self.latent_k, activation = 'sigmoid', use_bias = True, 
                            kernel_initializer = self.init, name = 'pred_sigmoid')(self.z_mean)
        self.lb_pred = BatchNormalization(center = False, scale = True, name = 'lz1')(self.lb_pred)
        self.lb_pred = Dense(int(0.5*self.latent_k), activation = 'sigmoid', use_bias = True, 
                            kernel_initializer = self.init, name = 'pred_sigmoid2')(self.lb_pred)
        self.lb_pred = BatchNormalization(center = False, scale = True, name = 'lz2')(self.lb_pred)
        self.lb_pred = Dense(self.lb_input_size, activation = 'softmax', use_bias = True, 
                            kernel_initializer = self.init, name = 'pred_softmax')(self.lb_pred)
        self.extra_models['lb_pred'] = Model([self.x_input, self.b_input], self.lb_pred, name = 'lb_pred')


        """ Build the decoder. """
        #### decoder network
        self.decoder_dense_layers = []
        self.decoder_leaky_layers = []
        for i, hid_size in enumerate(self.dec):
            dense_layer_name = 'd%s' % i
            self.decoder_dense_layers.append ( Dense(hid_size, activation = None, use_bias = True, 
                                                kernel_initializer = self.init, name = dense_layer_name) )
            self.decoder_leaky_layers.append ( LeakyReLU(alpha = 0.01) )
        self.last_layer_mu = Dense(self.x_input_size, activation = None, use_bias = True, 
                                kernel_initializer = self.init, name = 'mu_out')


        #### start from sampled latent values
        self.decoder11 = keras.layers.concatenate([self.z_out, self.b_input])
        for i, hid_size in enumerate(self.dec):
            self.decoder11 = self.decoder_dense_layers[i](self.decoder11)
            self.decoder11 = self.decoder_leaky_layers[i](self.decoder11)
        self.mu_hat = self.last_layer_mu(self.decoder11)
        self.mu_hat_sf = AddLayer(name = 'mu_hat_sf')([self.mu_hat, self.sf_input])
        self.mu_hat_exp_sf = ExpLayer(name = 'mu_hat_exp_sf')(self.mu_hat_sf)
        self.mu_hat_exp = ExpLayer(name = 'mu_hat_exp')(self.mu_hat)


        #### start from zeroed latent values
        self.decoder12_mean = keras.layers.concatenate([self.z_input, self.b_input])
        for i, hid_size in enumerate(self.dec):
            self.decoder12_mean = self.decoder_dense_layers[i](self.decoder12_mean)
            self.decoder12_mean = self.decoder_leaky_layers[i](self.decoder12_mean)
        self.mu_hat_mean = self.last_layer_mu(self.decoder12_mean)
        self.mu_hat_mean_sf = AddLayer(name = 'mu_hat_mean_sf')([self.mu_hat_mean, self.sf_input])
        self.mu_hat_mean_exp_sf = ExpLayer(name = 'mu_hat_mean_exp_sf')(self.mu_hat_mean_sf)
        self.mu_hat_mean_exp = ExpLayer(name = 'mu_hat_mean_exp')(self.mu_hat_mean)

        self.extra_models['decoder_mean'] = Model([self.z_input, self.b_input], [self.mu_hat_mean_exp], name = 'decoder_mean')



        """ Build the dispersion network. """
        self.last_layer_theta = Dense(self.x_input_size, activation = None, use_bias = True, 
                                      kernel_initializer = self.init, name = 'theta_out')

        #### start from sampled latent values
        self.theta_hat = self.last_layer_theta(self.disp_input)
        self.theta_hat = ClipLayer(name = 'clip_theta_hat')(self.theta_hat)
        self.theta_hat_exp = ExpLayer(name = 'theta_hat_exp')(self.theta_hat)

        #### start from zeroed latent values
        self.theta_hat_mean = self.last_layer_theta(self.disp_input)
        self.theta_hat_mean = ClipLayer(name = 'clip_theta_hat_mean')(self.theta_hat_mean)
        self.theta_hat_mean_exp = ExpLayer(name = 'theta_hat_mean_exp')(self.theta_hat_mean)

        self.extra_models['disp_model'] = Model(self.disp_input, self.theta_hat_mean_exp, name = 'disp_model')



        """ Build the whole network. """
        # decoder output
        self.out_hat = keras.layers.concatenate([self.mu_hat_sf, self.theta_hat], name = 'out')
        self.out_hat_mean = keras.layers.concatenate([self.mu_hat_mean_sf, self.theta_hat_mean], name = 'out_mean')
        # the whole model
        self.model = Model(inputs = [self.z_input, self.x_input, self.b_input, self.sf_input, self.disp_input, self.x_raw_input, self.lb_input], 
                           outputs = [self.out_hat, self.out_hat_mean, self.lb_pred], 
                           name = 'model')

        if print_model:
            self.model.summary()


        self.pred_loss = K.sum( tf.keras.losses.categorical_crossentropy(self.lb_input, self.lb_pred), axis = -1)
        self.kl_loss = -0.5 * K.sum(1 + self.z_log_var - K.square(self.z_mean) - K.exp(self.z_log_var), axis = -1)
        self.recon_loss = ((1 - self.alpha) * self.nb_loss_func(self.x_raw_input, self.mu_hat_exp_sf) 
                            + self.alpha * self.nb_loss0_func(self.x_raw_input, self.mu_hat_mean_exp_sf))
        


    def add_loss(self, pred_weight, kl_weight=1):
        self.final_loss = kl_weight * self.kl_loss + self.recon_loss + pred_weight * self.pred_loss
        self.model.add_loss(self.final_loss)
        self.model.add_metric(self.pred_loss, name='pred_loss')
        self.model.add_metric(self.kl_loss, name='kl_loss')
        self.model.add_metric(self.recon_loss, name='recon_loss')



    def compile_model(self, pred_weight, kl_weight=1, optimizer = None):

        self.add_loss(pred_weight, kl_weight)

        if optimizer is not None:
            self.optimizer = optimizer

        self.model.compile(optimizer = self.optimizer)



    def kl_loss_func(self):

        kl_loss = -0.5 * K.sum(1 + self.z_log_var - K.square(self.z_mean) - K.exp(self.z_log_var), axis = -1)

        return kl_loss



    def nb_loss_func(self, y_true, y_pred):
    
        log_mu = self.mu_hat_sf
        log_theta = self.theta_hat
        mu = self.mu_hat_exp_sf
        theta = self.theta_hat_exp
        f0 = -1 * tf.math.lgamma(y_true + 1)
        f1 = -1 * tf.math.lgamma(theta)
        f2 = tf.math.lgamma(y_true + theta)
        f3 = - (y_true + theta) * tf.math.log(theta + mu)
        f4 = theta * log_theta
        f5 = y_true * log_mu
        final = - K.sum(f0 + f1 + f2 + f3 + f4 + f5, axis = 1)

        return final



    def nb_loss0_func(self, y_true, y_pred):
        
        log_mu = self.mu_hat_mean_sf
        log_theta = self.theta_hat_mean
        mu = self.mu_hat_mean_exp_sf
        theta = self.theta_hat_mean_exp
        f0 = -1 * tf.math.lgamma(y_true + 1)
        f1 = -1 * tf.math.lgamma(theta)
        f2 = tf.math.lgamma(y_true + theta)
        f3 = - (y_true + theta) * tf.math.log(theta + mu)
        f4 = theta * log_theta
        f5 = y_true * log_mu
        final = - K.sum(f0 + f1 + f2 + f3 + f4 + f5, axis = 1)

        return final


    def load_weights(self, filename):

        self.model.load_weights(filename)


    def save_weights(self, filename, save_extra = False, extra_filenames = None):

        self.model.save_weights(filename)

        if save_extra:
            self.extra_models['mean_out'].save_weights(extra_filenames["mean_out"])
            self.extra_models['var_out'].save_weights(extra_filenames["var_out"])
            self.extra_models['samp_out'].save_weights(extra_filenames["samp_out"])
            self.extra_models['disp_model'].save_weights(extra_filenames["disp_model"])
            self.extra_models['decoder_mean'].save_weights(extra_filenames["decoder_mean"])


    def predict_latent(self, X, B):

        latent_mean = self.extra_models['mean_out'].predict([X, B])

        return latent_mean


    
    def predict_beta(self, X, B, sf):

        zmean = self.extra_models['mean_out'].predict([X, B])
        X_lambda = self.extra_models['decoder_mean'].predict([zmean, B])
        X_theta = self.extra_models['disp_model'].predict(B)
        X_lambda = (X_lambda.T * sf).T

        return X_lambda, X_theta



    def model_initialize(self, adata, 
                         epochs=300, batch_size=64, 
                         validation_split=0.1, 
                         shuffle=True, fit_verbose=1, 
                         lr_patience=1, 
                         lr_factor=0.1, 
                         lr_verbose=True,
                         es_patience=2,
                         es_verbose=True):

        callbacks = []
        lr_cb = ReduceLROnPlateau(monitor='val_pred_loss', patience=lr_patience, factor=lr_factor, verbose=lr_verbose)
        callbacks.append(lr_cb)
        es_cb = EarlyStopping(monitor='val_pred_loss', patience=es_patience, verbose=es_verbose)
        callbacks.append(es_cb)

        z_blank = np.zeros((adata.n_obs, self.latent_k), dtype=np.float32)
        inputs = [z_blank, 
                  adata.X, 
                  adata.obsm['saver_batch'], 
                  np.log(adata.obs.size_factors), 
                  adata.obsm['saver_batch'], 
                  adata.raw.X, 
                  adata.obsm['saver_targetL']]
        outputs = [adata.raw.X, 
                  adata.raw.X, 
                  adata.obsm['saver_targetL']]

        loss = self.model.fit(inputs, outputs,
                              epochs=epochs,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              callbacks=callbacks,
                              validation_split=validation_split,
                              verbose=fit_verbose)


        return loss




    def model_finetune(self, adata, 
                         epochs=300, batch_size=64, 
                         validation_split=0.1, 
                         shuffle=True, fit_verbose=1, 
                         lr_patience=4, 
                         lr_factor=0.1, 
                         lr_verbose=True,
                         es_patience=6,
                         es_verbose=True):

        callbacks = []
        lr_cb = ReduceLROnPlateau(monitor='val_loss', patience=lr_patience, factor=lr_factor, verbose=lr_verbose)
        callbacks.append(lr_cb)
        es_cb = EarlyStopping(monitor='val_loss', patience=es_patience, verbose=es_verbose)
        callbacks.append(es_cb)

        z_blank = np.zeros((adata.n_obs, self.latent_k), dtype=np.float32)
        inputs = [z_blank, 
                  adata.X, 
                  adata.obsm['saver_batch'], 
                  np.log(adata.obs.size_factors), 
                  adata.obsm['saver_batch'], 
                  adata.raw.X, 
                  adata.obsm['saver_targetL']]
        outputs = [adata.raw.X, 
                  adata.raw.X, 
                  adata.obsm['saver_targetL']]

        loss = self.model.fit(inputs, outputs,
                              epochs=epochs,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              callbacks=callbacks,
                              validation_split=validation_split,
                              verbose=fit_verbose)

        return loss
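A typical call sequence for the CVAE class above, based on its own methods, is to build the graph, attach the losses with compile_model, and then train with model_initialize (and optionally model_finetune). A rough sketch follows; adata is an AnnData object prepared elsewhere, and the obsm/obs field names simply follow those used in model_initialize.

# Hypothetical usage of the CVAE class above
cvae = CVAE(x_input_size=adata.n_vars,
            b_input_size=adata.obsm['saver_batch'].shape[1],
            lb_input_size=adata.obsm['saver_targetL'].shape[1],
            latent_k=30)
cvae.build(print_model=False)
cvae.compile_model(pred_weight=1.0, kl_weight=1.0)
history = cvae.model_initialize(adata, epochs=300, batch_size=64)
latent = cvae.predict_latent(adata.X, adata.obsm['saver_batch'])  # latent means for each cell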