Esempio n. 1
0
    def to_z_latent(self, adata, batch_key):
        """
            Map `adata` in to the latent space. This function will feed data
            in encoder part of C-VAE and compute the latent space coordinates
            for each sample in data.
            # Parameters
                adata: `~anndata.AnnData`
                    Annotated data matrix to be mapped to latent space. `adata.X` has to be in shape [n_obs, n_vars].
                    `adata` itself is not modified; a sparse `X` is densified into a local array.
                batch_key: str
                    Name of the column in `adata.obs` holding the condition labels. They are
                    encoded with `self.condition_encoder` and fed as the C-VAE's condition array.
            # Returns
                output: `~anndata.AnnData`
                    returns `anndata` object containing latent space encoding of 'adata'
                    (non-finite values sanitized with `np.nan_to_num`) together with a deep
                    copy of `adata.obs`.
        """
        # Densify into a local array rather than overwriting `adata.X`, so the
        # caller's object is left untouched. `.toarray()` works for every SciPy
        # sparse type, unlike the `.A` shorthand.
        expr = adata.X.toarray() if sparse.issparse(adata.X) else adata.X

        encoder_labels, _ = label_encoder(adata, self.condition_encoder, batch_key)
        encoder_labels = to_categorical(encoder_labels, num_classes=self.n_conditions)
        # The latent coordinates are taken from the third output of the encoder model.
        latent = self.encoder_model.predict([expr, encoder_labels])[2]
        latent = np.nan_to_num(latent)

        latent_adata = anndata.AnnData(X=latent)
        latent_adata.obs = adata.obs.copy(deep=True)

        return latent_adata
Esempio n. 2
0
    def get_reconstruction_error(self, adata, condition_key):
        """
            Evaluate the trained C-VAE on ``adata`` with Keras ``evaluate``.

            Parameters
            ----------
            adata: :class:`~anndata.AnnData`
                Data to evaluate. A sparse ``X`` is densified via ``remove_sparsity``.
                ``adata.X`` is used both as model input and as reconstruction target.
            condition_key: str
                Name of the column in ``adata.obs`` holding each cell's condition.

            Returns
            -------
            The value returned by ``self.cvae_model.evaluate`` for the two model
            outputs (presumably the total loss followed by the per-output losses —
            confirm against the model's compile call).
        """
        adata = remove_sparsity(adata)

        # Encode with the model's own condition encoder, as the other inference
        # methods do. Fitting a fresh encoder on `adata` could assign different
        # integer codes than the ones the model was trained with.
        labels_encoded, _ = label_encoder(adata, self.condition_encoder, condition_key)
        labels_onehot = to_categorical(labels_encoded,
                                       num_classes=self.n_conditions)

        x = [adata.X, labels_onehot, labels_onehot]
        y = [adata.X, labels_encoded]

        return self.cvae_model.evaluate(x, y, verbose=0)
Esempio n. 3
0
    def to_mmd_layer(self, adata, batch_key):
        """
            Map ``adata`` in to the MMD space. This function will feed data
            in ``cvae_model`` of trVAE and compute the MMD space coordinates
            for each sample in data.

            Parameters
            ----------
            adata: :class:`~anndata.AnnData`
                Annotated data matrix to be mapped to MMD latent space.
                Please note that ``adata.X`` has to be in shape [n_obs, x_dimension]
            batch_key: str
                Name of the column in ``adata.obs`` holding the condition labels. They are
                encoded with ``self.condition_encoder`` and fed as both trVAE's encoder
                and decoder condition arrays.

            Returns
            -------
            adata_mmd: :class:`~anndata.AnnData`
                returns Annotated data containing MMD latent space encoding of ``adata``
                (NaN and +/-inf replaced by 0) with a deep copy of ``adata.obs``.
        """
        adata = remove_sparsity(adata)

        # Encoder and decoder are conditioned on the same labels here, so the
        # one-hot condition array is built once and reused for both inputs.
        condition_labels, _ = label_encoder(adata, self.condition_encoder, batch_key)
        condition_onehot = to_categorical(condition_labels, num_classes=self.n_conditions)

        cvae_inputs = [adata.X, condition_onehot, condition_onehot]

        # The MMD-layer activations are the second output of `cvae_model`.
        mmd = self.cvae_model.predict(cvae_inputs)[1]
        mmd = np.nan_to_num(mmd, nan=0.0, posinf=0.0, neginf=0.0)

        adata_mmd = anndata.AnnData(X=mmd)
        adata_mmd.obs = adata.obs.copy(deep=True)

        return adata_mmd
Esempio n. 4
0
    def predict(self, adata, condition_key, target_condition=None):
        """Feeds ``adata`` to trVAE and produces the reconstructed data.

            Parameters
            ----------
            adata: :class:`~anndata.AnnData`
                Annotated data matrix in primary space.
            condition_key: str
                Name of the column in ``adata.obs`` holding each cell's condition. The
                values are encoded with ``self.condition_encoder`` and fed as trVAE's
                encoder condition array.
            target_condition: str, optional
                Condition (a key of ``self.condition_encoder``) that every cell is decoded
                into. If ``None``, each cell is decoded into its own condition.

            Returns
            -------
            adata_pred: `~anndata.AnnData`
                Annotated data of predicted cells in primary space, with a deep copy of
                ``adata.obs`` and the same ``var_names`` as ``adata``.
        """
        adata = remove_sparsity(adata)

        encoder_labels, _ = label_encoder(adata, self.condition_encoder, condition_key)
        if target_condition is not None:
            # Decode every cell into the requested condition.
            decoder_labels = np.zeros_like(encoder_labels) + self.condition_encoder[
                target_condition]
        else:
            # Decode each cell into its own condition. These are the same labels
            # as the encoder's, so reuse them instead of re-encoding.
            decoder_labels = encoder_labels

        encoder_onehot = to_categorical(encoder_labels, num_classes=self.n_conditions)
        decoder_onehot = to_categorical(decoder_labels, num_classes=self.n_conditions)

        x_hat = self.cvae_model.predict([adata.X, encoder_onehot, decoder_onehot])[0]

        adata_pred = anndata.AnnData(X=x_hat)
        # Copy, as the other mapping methods do, so that editing the prediction's
        # metadata cannot silently mutate the input's `obs`.
        adata_pred.obs = adata.obs.copy(deep=True)
        adata_pred.var_names = adata.var_names

        return adata_pred
Esempio n. 5
0
def create_model(net_train_adata, net_valid_adata, domain_key, label_key,
                 source_domains, target_domains, domain_encoder):
    """Hyperas objective: train one trVAEATAC configuration and score it.

    The ``*_choices`` assignments below are hyperas search-space template
    expressions; hyperas replaces them with sampled hyperparameter values.
    The trained network's classifier is evaluated on the target-domain cells
    of both the training and validation sets.

    Returns:
        dict with ``'loss'`` set to the negated target-domain accuracy (so that
        minimising the loss maximises accuracy) and ``'status'`` set to
        ``STATUS_OK``.
    """
    z_dim_choices = {{choice([20, 40, 50, 60, 80, 100])}}
    mmd_dim_choices = {{choice([64, 128, 256])}}

    alpha_choices = {{choice([0.1, 0.01, 0.001, 0.0001, 0.00001, 0.000001])}}
    beta_choices = {{choice([1, 100, 500, 1000, 1500, 2000, 5000])}}
    gamma_choices = {{choice([1, 10, 100, 1000, 5000, 10000])}}
    eta_choices = {{choice([0.01, 0.1, 1, 5, 10, 50])}}
    batch_size_choices = {{choice([128, 256, 512, 1024, 1500])}}
    dropout_rate_choices = {{choice([0.1, 0.2, 0.5])}}

    n_labels = len(net_train_adata.obs[label_key].unique().tolist())
    n_domains = len(net_train_adata.obs[domain_key].unique().tolist())

    network = trvae.archs.trVAEATAC(
        x_dimension=net_train_adata.shape[1],
        z_dimension=z_dim_choices,
        mmd_dimension=mmd_dim_choices,
        learning_rate=0.001,
        alpha=alpha_choices,
        beta=beta_choices,
        gamma=gamma_choices,
        eta=eta_choices,
        model_path="./models/trVAEATAC/hyperopt/Pancreas/",
        n_labels=n_labels,
        n_domains=n_domains,
        output_activation='leaky_relu',
        mmd_computation_way="1",
        dropout_rate=dropout_rate_choices)

    network.train(
        net_train_adata,
        net_valid_adata,
        domain_key,
        label_key,
        source_key=source_domains,
        target_key=target_domains,
        domain_encoder=domain_encoder,
        n_epochs=10000,
        batch_size=batch_size_choices,
        early_stop_limit=500,
        lr_reducer=0,
    )

    # Subset first, then copy. Copying the whole AnnData before masking would
    # duplicate every source-domain cell only to discard it.
    target_train_mask = net_train_adata.obs[domain_key].isin(target_domains)
    target_valid_mask = net_valid_adata.obs[domain_key].isin(target_domains)
    target_adata_train = net_train_adata[target_train_mask].copy()
    target_adata_valid = net_valid_adata[target_valid_mask].copy()

    target_adata = target_adata_train.concatenate(target_adata_valid)
    target_adata = remove_sparsity(target_adata)

    target_adata_domains_encoded, _ = label_encoder(
        target_adata, condition_key=domain_key, label_encoder=domain_encoder)
    target_adata_domains_onehot = to_categorical(target_adata_domains_encoded,
                                                 num_classes=n_domains)

    target_adata_classes_encoded = network.label_enc.transform(
        target_adata.obs[label_key].values)
    target_adata_classes_onehot = to_categorical(target_adata_classes_encoded,
                                                 num_classes=n_labels)

    # The classifier takes the same domain one-hot as encoder and decoder condition.
    x_target = [
        target_adata.X, target_adata_domains_onehot,
        target_adata_domains_onehot
    ]
    y_target = target_adata_classes_onehot

    _, target_acc = network.classifier_model.evaluate(x_target,
                                                      y_target,
                                                      verbose=0)
    objective = -target_acc
    print(
        f'alpha = {network.alpha}, beta = {network.beta}, eta={network.eta}, z_dim = {network.z_dim}, mmd_dim = {network.mmd_dim}, batch_size = {batch_size_choices}, dropout_rate = {network.dr_rate}, gamma = {network.gamma}'
    )
    return {'loss': objective, 'status': STATUS_OK}
Esempio n. 6
0
    def _train_on_batch(self, adata,
                        condition_key, train_size=0.8,
                        n_epochs=300, batch_size=512,
                        early_stop_limit=10, lr_reducer=7,
                        save=True, retrain=True, verbose=3):
        """
        Train the C-VAE with a hand-written mini-batch loop and early stopping.

        ``adata`` is split with ``train_test_split``. Each epoch runs up to 200
        mini-batches of ``batch_size`` cells sampled with replacement, then
        evaluates on the whole validation split.

        Parameters
        ----------
        adata: :class:`~anndata.AnnData`
            Data to train on; split into train/validation by ``train_size``.
        condition_key: str
            Name of the column in ``adata.obs`` holding each cell's condition.
        train_size: float
            Fraction passed to ``train_test_split``.
        n_epochs: int
            Maximum number of epochs.
        batch_size: int
            Number of cells sampled per mini-batch.
        early_stop_limit: int
            Stop after this many consecutive epochs without a new best
            validation loss.
        lr_reducer, verbose:
            Accepted for signature parity with ``_fit``; not used here.
        save: bool
            If ``True``, call ``self.save(make_dir=True)`` after training.
        retrain: bool
            If ``False`` and ``<model_path>/<model_name>.h5`` exists, restore
            those weights instead of training.

        Raises
        ------
        Exception
            If ``self.gene_names`` is set and is not a subset of the train or
            validation ``var_names``.
        """
        train_adata, valid_adata = train_test_split(adata, train_size)

        if self.gene_names is None:
            self.gene_names = train_adata.var_names.tolist()
        else:
            if set(self.gene_names).issubset(set(train_adata.var_names)):
                train_adata = train_adata[:, self.gene_names]
            else:
                raise Exception("set of gene names in train adata are inconsistent with class' gene_names")

            if set(self.gene_names).issubset(set(valid_adata.var_names)):
                valid_adata = valid_adata[:, self.gene_names]
            else:
                raise Exception("set of gene names in valid adata are inconsistent with class' gene_names")

        train_conditions_encoded, self.condition_encoder = label_encoder(train_adata, le=self.condition_encoder,
                                                                         condition_key=condition_key)

        valid_conditions_encoded, self.condition_encoder = label_encoder(valid_adata, le=self.condition_encoder,
                                                                         condition_key=condition_key)

        if not retrain and os.path.exists(os.path.join(self.model_path, f"{self.model_name}.h5")):
            self.restore_model_weights()
            return

        train_conditions_onehot = to_categorical(train_conditions_encoded, num_classes=self.n_conditions)
        valid_conditions_onehot = to_categorical(valid_conditions_encoded, num_classes=self.n_conditions)

        use_count_loss = self.loss_fn in ['nb', 'zinb']

        # Training rows are densified lazily, one mini-batch at a time, while the
        # validation set is densified once. Each matrix is checked on its own
        # because the two need not share a storage format.
        train_is_sparse = sparse.issparse(train_adata.X)
        train_expr = train_adata.X
        valid_expr = valid_adata.X.toarray() if sparse.issparse(valid_adata.X) else valid_adata.X
        x_valid = [valid_expr, valid_conditions_onehot, valid_conditions_onehot]

        if use_count_loss:
            # NB/ZINB losses take size factors as an extra input and are fit against raw counts.
            x_valid.append(valid_adata.obs[self.size_factor_key].values)
            valid_raw = valid_adata.raw.X
            y_valid = [valid_raw.toarray() if sparse.issparse(valid_raw) else valid_raw,
                       valid_conditions_encoded]
        else:
            y_valid = [valid_expr, valid_conditions_encoded]

        n_train = train_adata.shape[0]
        # Run at least one batch per epoch whenever there is training data, even
        # if it holds fewer than `batch_size` cells. Rows are sampled with
        # replacement, so small datasets still get trained on.
        n_batches = min(200, n_train // batch_size)
        if n_batches == 0 and n_train > 0:
            n_batches = 1

        es_patience, best_val_loss = 0, 1e10
        for i in range(n_epochs):
            train_loss = train_recon_loss = train_mmd_loss = 0.0
            for _ in range(n_batches):
                batch_indices = np.random.choice(n_train, batch_size)

                batch_expr = train_expr[batch_indices, :]
                if train_is_sparse:
                    batch_expr = batch_expr.toarray()

                batch_conditions_onehot = train_conditions_onehot[batch_indices]
                x_train = [batch_expr, batch_conditions_onehot, batch_conditions_onehot]

                if use_count_loss:
                    x_train.append(train_adata.obs[self.size_factor_key].values[batch_indices])
                    batch_raw = train_adata.raw.X[batch_indices]
                    if sparse.issparse(batch_raw):
                        batch_raw = batch_raw.toarray()
                    y_train = [batch_raw, train_conditions_encoded[batch_indices]]
                else:
                    y_train = [batch_expr, train_conditions_encoded[batch_indices]]

                batch_loss, batch_recon_loss, batch_mmd_loss = self.cvae_model.train_on_batch(x_train, y_train)

                train_loss += batch_loss
                train_recon_loss += batch_recon_loss
                train_mmd_loss += batch_mmd_loss

            # Report the mean over this epoch's batches, so the training losses
            # are on a scale comparable to the validation losses from `evaluate`.
            denom = max(n_batches, 1)
            train_loss /= denom
            train_recon_loss /= denom
            train_mmd_loss /= denom

            valid_loss, valid_recon_loss, valid_mmd_loss = self.cvae_model.evaluate(x_valid, y_valid, verbose=0)

            # Log before the early-stopping check so the final epoch is reported too.
            logs = {"loss": train_loss, "recon_loss": train_recon_loss, "mmd_loss": train_mmd_loss,
                    "val_loss": valid_loss, "val_recon_loss": valid_recon_loss, "val_mmd_loss": valid_mmd_loss}
            print_progress(i, logs, n_epochs)

            if valid_loss < best_val_loss:
                best_val_loss = valid_loss
                es_patience = 0
            else:
                es_patience += 1
                if es_patience == early_stop_limit:
                    print("Training stopped with Early Stopping")
                    break

        if save:
            self.save(make_dir=True)
Esempio n. 7
0
    def _fit(self, adata,
             condition_key, train_size=0.8,
             n_epochs=300, batch_size=512,
             early_stop_limit=10, lr_reducer=7,
             save=True, retrain=True, verbose=3):
        """
        Train the C-VAE with Keras ``fit`` on a train/validation split of ``adata``.

        Parameters
        ----------
        adata: :class:`~anndata.AnnData`
            Data to train on; split by ``train_test_split(adata, train_size)``.
        condition_key: str
            Name of the column in ``adata.obs`` holding each cell's condition.
        train_size: float
            Fraction passed to ``train_test_split``.
        n_epochs: int
            Maximum number of epochs.
        batch_size: int
            Mini-batch size passed to ``fit``.
        early_stop_limit: int
            ``EarlyStopping`` patience on ``val_loss``; disabled if <= 0.
        lr_reducer: int
            ``ReduceLROnPlateau`` patience on ``val_loss``; disabled if <= 0.
        save: bool
            If ``True``, call ``self.save(make_dir=True)`` after training.
        retrain: bool
            If ``False`` and ``<model_path>/<model_name>.h5`` exists, restore
            those weights instead of training.
        verbose: int
            Values > 2 print progress via ``print_progress`` each epoch (Keras
            output silenced); otherwise the value is passed as Keras ``verbose``.

        Raises
        ------
        Exception
            If ``self.gene_names`` is set and is not a subset of the train or
            validation ``var_names``.
        """
        train_adata, valid_adata = train_test_split(adata, train_size)

        if self.gene_names is None:
            self.gene_names = train_adata.var_names.tolist()
        else:
            if set(self.gene_names).issubset(set(train_adata.var_names)):
                train_adata = train_adata[:, self.gene_names]
            else:
                raise Exception("set of gene names in train adata are inconsistent with class' gene_names")

            if set(self.gene_names).issubset(set(valid_adata.var_names)):
                valid_adata = valid_adata[:, self.gene_names]
            else:
                raise Exception("set of gene names in valid adata are inconsistent with class' gene_names")

        # `.toarray()` densifies any SciPy sparse matrix or array; the `.A`
        # shorthand is not available on every SciPy version or sparse type.
        train_expr = train_adata.X.toarray() if sparse.issparse(train_adata.X) else train_adata.X
        valid_expr = valid_adata.X.toarray() if sparse.issparse(valid_adata.X) else valid_adata.X

        train_conditions_encoded, self.condition_encoder = label_encoder(train_adata, le=self.condition_encoder,
                                                                         condition_key=condition_key)

        valid_conditions_encoded, self.condition_encoder = label_encoder(valid_adata, le=self.condition_encoder,
                                                                         condition_key=condition_key)

        if not retrain and os.path.exists(os.path.join(self.model_path, f"{self.model_name}.h5")):
            self.restore_model_weights()
            return

        callbacks = [
            History(),
        ]

        if verbose > 2:
            callbacks.append(
                LambdaCallback(on_epoch_end=lambda epoch, logs: print_progress(epoch, logs, n_epochs)))
            fit_verbose = 0
        else:
            fit_verbose = verbose

        if early_stop_limit > 0:
            callbacks.append(EarlyStopping(patience=early_stop_limit, monitor='val_loss'))

        if lr_reducer > 0:
            callbacks.append(ReduceLROnPlateau(monitor='val_loss', patience=lr_reducer))

        train_conditions_onehot = to_categorical(train_conditions_encoded, num_classes=self.n_conditions)
        valid_conditions_onehot = to_categorical(valid_conditions_encoded, num_classes=self.n_conditions)

        # The same condition one-hot is fed as both encoder and decoder condition.
        x_train = [train_expr, train_conditions_onehot, train_conditions_onehot]
        x_valid = [valid_expr, valid_conditions_onehot, valid_conditions_onehot]

        y_train = [train_expr, train_conditions_encoded]
        y_valid = [valid_expr, valid_conditions_encoded]

        self.cvae_model.fit(x=x_train,
                            y=y_train,
                            validation_data=(x_valid, y_valid),
                            epochs=n_epochs,
                            batch_size=batch_size,
                            verbose=fit_verbose,
                            callbacks=callbacks,
                            )
        if save:
            self.save(make_dir=True)
Esempio n. 8
0
    def train(self,
              train_adata,
              valid_adata=None,
              condition_encoder=None,
              condition_key='condition',
              n_epochs=10000,
              batch_size=1024,
              early_stop_limit=100,
              lr_reducer=80,
              threshold=0.0,
              monitor='val_loss',
              shuffle=True,
              verbose=0,
              save=True,
              monitor_best=True):
        """
            Trains the network `n_epochs` times with given `train_adata`
            and validates the model on `valid_adata` (or on a 20% hold-out of the
            training samples when `valid_adata` is None). This function is using
            `early stopping` technique to prevent overfitting.
            # Parameters
                train_adata: `~anndata.AnnData`
                    `AnnData` object for training trVAE. It is not modified; a sparse
                    `X` is densified into a local array.
                valid_adata: `~anndata.AnnData`
                    `AnnData` object for validating trVAE. If None, Keras holds out a
                    fraction of the training samples via `validation_split=0.2`.
                condition_encoder: dict
                    encoder passed to `label_encoder` for both training and validation
                    conditions (if None, `label_encoder` builds one from the data).
                condition_key: str
                    name of conditions (domains) column in obs matrix
                n_epochs: int
                    number of epochs to iterate and optimize network weights
                batch_size: int
                    number of samples to be used in each batch for network weights optimization
                early_stop_limit: int
                    `EarlyStopping` patience: number of consecutive epochs without improvement
                    of `monitor` before training stops. Disabled if <= 0.
                lr_reducer: int
                    `ReduceLROnPlateau` patience on `monitor`. Disabled if <= 0.
                threshold: float
                    `min_delta` for early stopping: an epoch only counts as an improvement
                    if `monitor` improves by more than this value.
                monitor: str
                    metric to be monitored for early stopping, LR reduction and checkpointing.
                shuffle: boolean
                    forwarded to Keras `fit`.
                verbose: int
                    level of verbosity; values > 2 print progress through `print_message`
                    each epoch and silence Keras' own output.
                save: boolean
                    if `True` and `monitor_best` is `False`, `self.save_model()` is called
                    after training.
                monitor_best: boolean
                    if `True`, the best model by `monitor` is checkpointed to
                    `<model_to_use>/best_model.h5` and `self.restore_model()` is called
                    after training.
            # Returns
                Nothing will be returned
            # Example
            ```python
            import scanpy as sc
            import trvae
            train_adata = sc.read("train.h5ad")
            valid_adata = sc.read("validation.h5ad")
            n_conditions = len(train_adata.obs['condition'].unique().tolist())
            network = trvae.archs.trVAEMulti(train_adata.shape[1], n_conditions)
            network.train(train_adata, valid_adata, condition_encoder=None,
                          condition_key="condition",
                          n_epochs=1000, batch_size=256)
            ```
        """
        # NOTE(review): when `condition_encoder` is None, training and validation
        # conditions are encoded independently, so their integer codes may differ
        # if the two sets contain different conditions. Pass an explicit encoder
        # to be safe.
        train_labels_encoded, _ = label_encoder(train_adata, condition_encoder,
                                                condition_key)
        train_labels_onehot = to_categorical(train_labels_encoded,
                                             num_classes=self.n_conditions)

        callbacks = [
            History(),
            CSVLogger(filename="./csv_logger.log"),
        ]

        if early_stop_limit > 0:
            callbacks.append(
                EarlyStopping(patience=early_stop_limit,
                              monitor=monitor,
                              min_delta=threshold))

        if lr_reducer > 0:
            callbacks.append(
                ReduceLROnPlateau(monitor=monitor,
                                  patience=lr_reducer,
                                  verbose=verbose))

        if verbose > 2:
            callbacks.append(
                LambdaCallback(on_epoch_end=lambda epoch, logs: print_message(
                    epoch, logs, n_epochs, verbose)))
            fit_verbose = 0
        else:
            fit_verbose = verbose
        if monitor_best:
            os.makedirs(self.model_to_use, exist_ok=True)
            callbacks.append(
                ModelCheckpoint(filepath=os.path.join(self.model_to_use,
                                                      "best_model.h5"),
                                save_best_only=True,
                                monitor=monitor,
                                period=50))

        # Densify into local arrays instead of overwriting `train_adata.X`, so the
        # caller's object is left untouched.
        train_expr = train_adata.X.toarray() if sparse.issparse(train_adata.X) else train_adata.X

        x = [train_expr, train_labels_onehot, train_labels_onehot]
        y = [train_expr, train_labels_encoded]

        if valid_adata is not None:
            valid_expr = valid_adata.X.toarray() if sparse.issparse(valid_adata.X) else valid_adata.X

            valid_labels_encoded, _ = label_encoder(valid_adata,
                                                    condition_encoder,
                                                    condition_key)
            valid_labels_onehot = to_categorical(valid_labels_encoded,
                                                 num_classes=self.n_conditions)

            x_valid = [valid_expr, valid_labels_onehot, valid_labels_onehot]
            y_valid = [valid_expr, valid_labels_encoded]

            self.cvae_model.fit(x=x,
                                y=y,
                                epochs=n_epochs,
                                batch_size=batch_size,
                                validation_data=(x_valid, y_valid),
                                shuffle=shuffle,
                                callbacks=callbacks,
                                verbose=fit_verbose)
        else:
            self.cvae_model.fit(x=x,
                                y=y,
                                epochs=n_epochs,
                                batch_size=batch_size,
                                validation_split=0.2,
                                shuffle=shuffle,
                                callbacks=callbacks,
                                verbose=fit_verbose)

        # Runs for both branches above. Previously this was nested inside the
        # `valid_adata is None` branch, so with explicit validation data the best
        # checkpoint was never restored and the model was never saved.
        if monitor_best:
            self.restore_model()
        elif save:
            self.save_model()
Esempio n. 9
0
 def train(self,
           train_data,
           use_validation=False,
           valid_data=None,
           n_epochs=25,
           batch_size=32,
           early_stop_limit=20,
           threshold=0.0025,
           initial_run=True,
           shuffle=True):
     """
         Trains the network `n_epochs` times with given `train_data` and, when
         `use_validation` is `True`, computes the validation loss on `valid_data`
         after every epoch. Training stops early when the per-sample validation
         loss has not dropped by more than `threshold` versus the previous epoch
         for more than `early_stop_limit` consecutive epochs. The session is saved
         to `self.model_to_use` when training ends.
         # Parameters
             train_data: `~anndata.AnnData`
                 training data, iterated in row order in mini-batches of `batch_size` rows.
             use_validation: bool
                 if `True`, compute the validation loss each epoch (requires `valid_data`).
             valid_data: `~anndata.AnnData`
                 validation data; required when `use_validation` is `True`.
             n_epochs: int
                 number of epochs to iterate and optimize network weights
             batch_size: int
                 number of rows per mini-batch.
             early_stop_limit: int
                 number of consecutive epochs in which network loss is not going lower.
                 After this limit, the network will stop training.
             threshold: float
                 minimum decrease of the per-sample validation loss versus the previous
                 epoch for that epoch to count as an improvement.
             initial_run: bool
                 if `True`: log a start message, reset `global_step` to 0 and train from
                     the current weights.
                 if `False`: restore the session from `self.model_to_use` and resume training.
             shuffle: bool
                 NOTE(review): currently unused — mini-batches are taken in row order.
         # Raises
             Exception: if `use_validation` is `True` but `valid_data` is None.
         # Returns
             Nothing will be returned
         # Example
         ```python
         import scanpy as sc
         train_data = sc.read("train_kang.h5ad")
         validation_data = sc.read("valid_kang.h5ad")
         # `network` is an already-constructed model exposing this method
         network.train(train_data, use_validation=True, valid_data=validation_data, n_epochs=20)
         ```
     """
     if initial_run:
         log.info("----Training----")
         assign_step_zero = tensorflow.assign(self.global_step, 0)
         self.sess.run(assign_step_zero)
     else:
         self.saver.restore(self.sess, self.model_to_use)
     train_labels, _ = label_encoder(train_data)
     if use_validation and valid_data is None:
         raise Exception("valid_data is None but use_validation is True.")
     if use_validation:
         valid_labels, _ = label_encoder(valid_data)

     # Build the step-increment op once. Creating it inside the epoch loop would
     # add new nodes to the TensorFlow graph on every epoch.
     increment_global_step_op = tensorflow.assign(
         self.global_step, self.global_step + 1)

     n_train = train_data.shape[0]
     # Number of mini-batches actually run per epoch (the last may be partial).
     # At least 1, so the progress line never divides by zero on tiny datasets.
     n_train_batches = max(1, -(-n_train // batch_size))

     loss_hist = []
     patience = early_stop_limit
     min_delta = threshold
     patience_cnt = 0
     for it in range(n_epochs):
         self.sess.run(increment_global_step_op)
         current_step = self.sess.run(self.global_step)
         train_loss = 0
         for lower in range(0, n_train, batch_size):
             upper = min(lower + batch_size, n_train)
             if sparse.issparse(train_data.X):
                 x_mb = train_data[lower:upper, :].X.toarray()
             else:
                 x_mb = train_data[lower:upper, :].X
             y_mb = train_labels[lower:upper]
             _, current_loss_train = self.sess.run(
                 [self.solver, self.vae_loss],
                 feed_dict={
                     self.x: x_mb,
                     self.y: y_mb,
                     self.time_step: current_step,
                     self.size: len(x_mb),
                     self.is_training: True
                 })
             train_loss += current_loss_train
         # Mean training loss per mini-batch for this epoch.
         print(
             f"iteration {it}: {train_loss / n_train_batches}"
         )
         if use_validation:
             valid_loss = 0
             for lower in range(0, valid_data.shape[0], batch_size):
                 upper = min(lower + batch_size, valid_data.shape[0])
                 if sparse.issparse(valid_data.X):
                     x_mb = valid_data[lower:upper, :].X.toarray()
                 else:
                     x_mb = valid_data[lower:upper, :].X
                 y_mb = valid_labels[lower:upper]
                 current_loss_valid = self.sess.run(self.vae_loss,
                                                    feed_dict={
                                                        self.x: x_mb,
                                                        self.y: y_mb,
                                                        self.time_step:
                                                        current_step,
                                                        self.size:
                                                        len(x_mb),
                                                        self.is_training:
                                                        False
                                                    })
                 valid_loss += current_loss_valid
             loss_hist.append(valid_loss / valid_data.shape[0])
             # Improvement is measured against the previous epoch, not the best so far.
             if it > 0 and loss_hist[it - 1] - loss_hist[it] > min_delta:
                 patience_cnt = 0
             else:
                 patience_cnt += 1
             if patience_cnt > patience:
                 break

     # Save exactly once, whether training ran to completion or stopped early.
     os.makedirs(self.model_to_use, exist_ok=True)
     save_path = self.saver.save(self.sess, self.model_to_use)
     log.info(f"Model saved in file: {save_path}. Training finished")
Esempio n. 10
0
    def train(self,
              train_adata,
              valid_adata=None,
              condition_encoder=None,
              condition_key='condition',
              n_epochs=25,
              batch_size=32,
              early_stop_limit=20,
              lr_reducer=10,
              threshold=0.0025,
              monitor='val_loss',
              shuffle=True,
              verbose=2,
              save=True):
        """
            Train the convolutional C-VAE on image-shaped samples from `train_adata`.

            Every row of `X` is reshaped to `self.x_dim` before being passed to
            `self.cvae_model.fit`. The same one-hot condition array is used as
            both encoder and decoder condition.
            # Parameters
                train_adata: `~anndata.AnnData`
                    training data; a sparse `X` is densified via `remove_sparsity`.
                valid_adata: `~anndata.AnnData`
                    optional validation data. If None, no validation data is given to
                    `fit`, so a `val_*` `monitor` metric is unavailable to the
                    early-stopping and LR-reduction callbacks.
                condition_encoder: dict
                    encoder passed to `label_encoder`; the encoder returned for the
                    training data is stored in `self.condition_encoder`.
                condition_key: str
                    name of the conditions column in `obs`.
                n_epochs: int
                    maximum number of training epochs.
                batch_size: int
                    number of samples per gradient step.
                early_stop_limit: int
                    `EarlyStopping` patience on `monitor`; disabled if <= 0.
                lr_reducer: int
                    `ReduceLROnPlateau` patience on `monitor`; disabled if <= 0.
                threshold: float
                    `min_delta` used by `EarlyStopping`.
                monitor: str
                    metric watched by the callbacks.
                shuffle: boolean
                    forwarded to Keras `fit`.
                verbose: int
                    values > 2 print progress through `print_message` each epoch and
                    silence Keras' own output; otherwise passed as Keras `verbose`.
                save: boolean
                    if `True`, `self.save_model()` is called after training.
            # Returns
                Nothing will be returned
        """
        train_adata = remove_sparsity(train_adata)
        train_labels_encoded, self.condition_encoder = label_encoder(train_adata, condition_encoder, condition_key)
        train_labels_onehot = to_categorical(train_labels_encoded, num_classes=self.n_conditions)

        callbacks = [History(), CSVLogger(filename="./csv_logger.log")]
        if early_stop_limit > 0:
            callbacks.append(EarlyStopping(patience=early_stop_limit, monitor=monitor, min_delta=threshold))
        if lr_reducer > 0:
            callbacks.append(ReduceLROnPlateau(monitor=monitor, patience=lr_reducer, verbose=verbose))

        fit_verbose = verbose
        if verbose > 2:
            # Custom per-epoch printing replaces Keras' progress output.
            fit_verbose = 0
            callbacks.append(LambdaCallback(
                on_epoch_end=lambda epoch, logs: print_message(epoch, logs, n_epochs, verbose)))

        train_images = np.reshape(train_adata.X, (-1, *self.x_dim))

        fit_kwargs = {
            "x": [train_images, train_labels_onehot, train_labels_onehot],
            "y": [train_images, train_labels_encoded],
            "epochs": n_epochs,
            "batch_size": batch_size,
            "shuffle": shuffle,
            "callbacks": callbacks,
            "verbose": fit_verbose,
        }

        if valid_adata is not None:
            valid_adata = remove_sparsity(valid_adata)
            # NOTE(review): validation conditions are encoded with the
            # `condition_encoder` argument, not with `self.condition_encoder`; when
            # that argument is None the two splits are encoded independently.
            valid_labels_encoded, _ = label_encoder(valid_adata, condition_encoder, condition_key)
            valid_labels_onehot = to_categorical(valid_labels_encoded, num_classes=self.n_conditions)
            valid_images = np.reshape(valid_adata.X, (-1, *self.x_dim))
            fit_kwargs["validation_data"] = (
                [valid_images, valid_labels_onehot, valid_labels_onehot],
                [valid_images, valid_labels_encoded],
            )

        self.cvae_model.fit(**fit_kwargs)

        if save:
            self.save_model()