Beispiel #1
0
def entropy_batch_mixing(adata,
                         label_key='batch',
                         n_neighbors=50,
                         n_pools=50,
                         n_samples_per_pool=100,
                         subsample_frac=1.0):
    """Score how well the batches in ``adata.obs[label_key]`` are mixed in ``adata.X``.

    For every cell, the labels of its ``n_neighbors`` nearest neighbours in
    ``adata.X`` are gathered and passed row-wise to ``__entropy_from_indices``.
    The per-cell values are then averaged.

    Parameters
    ----------
    adata: `~anndata.AnnData`
        Annotated data matrix; neighbours are searched in ``adata.X``.
    label_key: str
        Column of ``adata.obs`` holding the batch labels.
    n_neighbors: int
        Number of neighbours used per cell. The first returned neighbour,
        normally the query cell itself, is dropped.
    n_pools: int
        Number of random pools whose means are averaged. If 1, the plain mean
        over all cells is returned and ``n_samples_per_pool`` is ignored.
    n_samples_per_pool: int
        Number of cells drawn (with replacement) for each pool.
    subsample_frac: float
        Fraction of cells kept (drawn without replacement) before scoring.

    Returns
    -------
    score: float
        Average of the per-cell values computed by ``__entropy_from_indices``.
    """
    adata = remove_sparsity(adata)

    n_samples = adata.shape[0]
    keep_idx = np.random.choice(np.arange(n_samples),
                                size=min(n_samples,
                                         int(subsample_frac * n_samples)),
                                replace=False)
    adata = adata[keep_idx, :]

    # Ask for one extra neighbour because the nearest hit for each query is
    # normally the query cell itself; it is dropped with [:, 1:].
    neighbors = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(adata.X)
    indices = neighbors.kneighbors(adata.X, return_distance=False)[:, 1:]

    # Fetch the label column once and gather all neighbour labels with a single
    # fancy-indexing operation. The previous per-element lambda re-read the
    # pandas column for every one of the n_cells * n_neighbors entries.
    batch_labels = np.asarray(adata.obs[label_key].values)
    batch_indices = batch_labels[indices]

    entropies = np.apply_along_axis(__entropy_from_indices,
                                    axis=1,
                                    arr=batch_indices)

    # Average n_pools pool means, each pool being n_samples_per_pool cells
    # sampled with replacement.
    if n_pools == 1:
        score = np.mean(entropies)
    else:
        score = np.mean([
            np.mean(entropies[np.random.choice(len(entropies),
                                               size=n_samples_per_pool)])
            for _ in range(n_pools)
        ])

    return score
Beispiel #2
0
    def to_mmd_layer(self, adata, encoder_labels, feed_fake=0):
        """
            Map `adata` into the MMD layer of the C-VAE (the layer following the latent layer).
            `adata.X` is reshaped to `(-1, *self.x_dim)` and fed, together with one-hot encoded
            encoder and decoder condition labels, through `cvae_model`; output index 1 is returned.
            # Parameters
                adata: `~anndata.AnnData`
                    Annotated data matrix to be mapped. `adata.X` must be reshapeable to [n_obs, *self.x_dim].
                encoder_labels: numpy nd-array
                    `numpy nd-array` of condition labels (class indices) fed as the encoder's condition array.
                feed_fake: int
                    if `feed_fake > 0`, every decoder label is set to `feed_fake`;
                    otherwise the decoder labels are identical to `encoder_labels`
                    (so condition 0 cannot be forced through this argument).
            # Returns
                mmd_adata: `~anndata.AnnData`
                    `anndata` object whose `X` holds the MMD-layer output for each cell of `adata`
                    and whose `obs` is a deep copy of `adata.obs`.
        """
        if feed_fake > 0:
            decoder_labels = np.zeros(shape=encoder_labels.shape) + feed_fake
        else:
            decoder_labels = encoder_labels
        adata = remove_sparsity(adata)

        images = np.reshape(adata.X, (-1, *self.x_dim))
        encoder_labels = to_categorical(encoder_labels,
                                        num_classes=self.n_conditions)
        decoder_labels = to_categorical(decoder_labels,
                                        num_classes=self.n_conditions)

        # Output index 1 of cvae_model is the MMD-layer activation.
        mmd_latent = self.cvae_model.predict(
            [images, encoder_labels, decoder_labels])[1]
        mmd_adata = anndata.AnnData(X=mmd_latent)
        mmd_adata.obs = adata.obs.copy(deep=True)

        return mmd_adata
Beispiel #3
0
    def to_latent(self, adata, encoder_labels, return_adata=True):
        """
            Map `adata` into the latent space. This function feeds the data
            through the encoder part of the C-VAE and computes the latent space
            coordinates for each sample.
            # Parameters
                adata: `~anndata.AnnData`
                    Annotated data matrix to be mapped to latent space. `adata.X` has to be in shape [n_obs, n_vars].
                encoder_labels: numpy nd-array
                    `numpy nd-array` of condition labels (class indices) fed as the encoder's condition array.
                return_adata: boolean
                    if `True`, the latent coordinates are wrapped in a new `anndata` object;
                    if `False`, the raw latent `numpy nd-array` is returned.
                    Nothing is written to `adata.obsm` in either case.
            # Returns
                output: `~anndata.AnnData` or numpy nd-array
                    latent encoding of `adata` (output index 2 of `encoder_model`, with NaN/inf
                    replaced by `np.nan_to_num`). When wrapped, `obs` is a deep copy of `adata.obs`.
        """
        adata = remove_sparsity(adata)

        encoder_labels = to_categorical(encoder_labels,
                                        num_classes=self.n_conditions)
        latent = self.encoder_model.predict([adata.X, encoder_labels])[2]
        latent = np.nan_to_num(latent)

        if return_adata:
            output = anndata.AnnData(X=latent)
            output.obs = adata.obs.copy(deep=True)
        else:
            output = latent

        return output
Beispiel #4
0
    def to_latent(self, adata, encoder_labels):
        """
            Map `adata` into the latent space. This function feeds the data
            through the encoder part of the C-VAE and computes the latent space
            coordinates for each sample.
            # Parameters
                adata: `~anndata.AnnData`
                    Annotated data matrix to be mapped to latent space. `adata.X` must be
                    reshapeable to [n_obs, *self.x_dim].
                encoder_labels: numpy nd-array
                    `numpy nd-array` of condition labels (class indices) fed as the encoder's condition array.
            # Returns
                latent_adata: `~anndata.AnnData`
                    `anndata` object whose `X` is the latent encoding of `adata`
                    (output index 2 of `encoder_model`) and whose `obs` is a deep copy of `adata.obs`.
        """
        adata = remove_sparsity(adata)

        images = np.reshape(adata.X, (-1, *self.x_dim))
        encoder_labels = to_categorical(encoder_labels,
                                        num_classes=self.n_conditions)

        latent = self.encoder_model.predict([images, encoder_labels])[2]

        latent_adata = anndata.AnnData(X=latent)
        latent_adata.obs = adata.obs.copy(deep=True)

        return latent_adata
Beispiel #5
0
def asw(adata, label_key):
    """Average silhouette width of ``adata.X`` grouped by ``adata.obs[label_key]``.

    The labels are turned into integer codes with ``LabelEncoder`` before
    being handed to ``silhouette_score``.

    Parameters
    ----------
    adata: `~anndata.AnnData`
        Annotated data matrix; silhouettes are computed on ``adata.X``.
    label_key: str
        Column of ``adata.obs`` holding the grouping labels.

    Returns
    -------
    float
        The value returned by ``silhouette_score``.
    """
    dense_adata = remove_sparsity(adata)
    label_codes = LabelEncoder().fit_transform(dense_adata.obs[label_key].values)
    return silhouette_score(dense_adata.X, label_codes)
Beispiel #6
0
    def predict(self,
                adata,
                encoder_labels,
                decoder_labels,
                return_adata=True):
        """
            Reconstruct `adata` through the network, encoding each cell with `encoder_labels`
            and decoding it with `decoder_labels` (e.g. to translate cells to another condition).
            # Parameters
                adata: `~anndata.AnnData`
                    Annotated data matrix in primary (gene) space.
                encoder_labels: `numpy nd-array`
                    `numpy nd-array` of condition labels (class indices) fed as the encoder's condition array.
                decoder_labels: `numpy nd-array`
                    `numpy nd-array` of condition labels (class indices) fed as the decoder's condition array.
                return_adata: boolean
                    if `True`, the reconstruction is wrapped in a new `anndata` object;
                    if `False`, the raw reconstructed `numpy nd-array` is returned.
                    Nothing is written to `adata.obsm` in either case.
            # Returns
                output: `~anndata.AnnData` or `numpy nd-array`
                    predicted cells in primary space (output index 0 of `cvae_model`, with NaN/inf
                    replaced by `np.nan_to_num`). When wrapped, `obs` is a deep copy of `adata.obs`
                    and `var_names` are taken from `adata`.
            # Example
            ```python
            import scanpy as sc
            import trvae
            train_adata = sc.read("train.h5ad")
            valid_adata = sc.read("validation.h5ad")
            n_conditions = len(train_adata.obs['condition'].unique().tolist())
            network = trvae.archs.trVAEMulti(train_adata.shape[1], n_conditions)
            network.train(train_adata, valid_adata, le=None,
                          condition_key="condition", cell_type_key="cell_label",
                          n_epochs=1000, batch_size=256
                          )
            encoder_labels, _ = trvae.utils.label_encoder(train_adata, condition_key="condition")
            decoder_labels, _ = trvae.utils.label_encoder(train_adata, condition_key="condition")
            pred_adata = network.predict(train_adata, encoder_labels, decoder_labels)
            ```
        """
        adata = remove_sparsity(adata)

        encoder_labels = to_categorical(encoder_labels,
                                        num_classes=self.n_conditions)
        decoder_labels = to_categorical(decoder_labels,
                                        num_classes=self.n_conditions)

        reconstructed = self.cvae_model.predict(
            [adata.X, encoder_labels, decoder_labels])[0]
        reconstructed = np.nan_to_num(reconstructed)

        if return_adata:
            output = anndata.AnnData(X=reconstructed)
            output.obs = adata.obs.copy(deep=True)
            output.var_names = adata.var_names
        else:
            output = reconstructed

        return output
Beispiel #7
0
def nmi(adata, label_key, n_init=200, random_state=None):
    """Normalized mutual information between KMeans clusters of ``adata.X`` and known labels.

    ``adata.X`` is clustered with KMeans into as many clusters as there are
    distinct values in ``adata.obs[label_key]``. The clustering is then compared
    with the (integer-encoded) labels.

    Parameters
    ----------
    adata: `~anndata.AnnData`
        Annotated data matrix; clustering is run on ``adata.X``.
    label_key: str
        Column of ``adata.obs`` holding the reference labels.
    n_init: int
        Number of KMeans restarts (default 200, the previously hard-coded value).
    random_state: int, RandomState instance or None
        Seed forwarded to KMeans. ``None`` keeps the previous unseeded,
        non-deterministic behaviour; pass an int for reproducible scores.

    Returns
    -------
    float
        Value of ``normalized_mutual_info_score(true_labels, cluster_labels)``.
    """
    adata = remove_sparsity(adata)

    n_labels = len(adata.obs[label_key].unique().tolist())
    kmeans = KMeans(n_clusters=n_labels, n_init=n_init, random_state=random_state)

    labels_pred = kmeans.fit_predict(adata.X)
    labels = adata.obs[label_key].values
    labels_encoded = LabelEncoder().fit_transform(labels)

    return normalized_mutual_info_score(labels_encoded, labels_pred)
Beispiel #8
0
    def get_reconstruction_error(self, adata, condition_key):
        """Evaluate ``cvae_model`` on ``adata`` and return what ``evaluate`` reports.

        Each cell is fed with its own condition (taken from
        ``adata.obs[condition_key]``) as both the encoder and the decoder
        condition. The targets are ``adata.X`` and the integer condition codes.

        Parameters
        ----------
        adata: `~anndata.AnnData`
            Annotated data matrix to evaluate.
        condition_key: str
            Column of ``adata.obs`` holding each cell's condition.

        Returns
        -------
        The value(s) returned by ``self.cvae_model.evaluate(x, y, verbose=0)``.
        """
        adata = remove_sparsity(adata)

        # Reuse the condition mapping stored during training, when there is one,
        # so conditions get the same integer codes the model was trained with.
        # Passing None makes label_encoder derive a mapping from this adata
        # alone, which may differ, e.g. when only a subset of conditions is present.
        encoder = getattr(self, 'condition_encoder', None)
        labels_encoded, _ = label_encoder(adata, encoder, condition_key)
        labels_onehot = to_categorical(labels_encoded,
                                       num_classes=self.n_conditions)

        x = [adata.X, labels_onehot, labels_onehot]
        y = [adata.X, labels_encoded]

        return self.cvae_model.evaluate(x, y, verbose=0)
Beispiel #9
0
def predict_between_conditions(network, adata, pred_adatas, source_condition, source_label, target_label, name,
                               condition_key='condition'):
    """Translate the cells of ``source_condition`` into another condition with ``network``.

    Source cells are taken from ``adata``. If ``adata`` contains none with
    ``obs[condition_key] == source_condition``, they are taken from
    ``pred_adatas`` instead.

    Parameters
    ----------
    network:
        Trained model exposing ``predict(adata, encoder_labels=..., decoder_labels=...)``
        that returns an `~anndata.AnnData`.
    adata: `~anndata.AnnData`
        Primary source of cells.
    pred_adatas: `~anndata.AnnData`
        Fallback source of cells when ``adata`` has none in ``source_condition``.
    source_condition:
        Value of ``obs[condition_key]`` selecting the cells to translate.
    source_label: int
        Encoder condition label assigned to every selected cell.
    target_label: int
        Decoder condition label assigned to every selected cell.
    name:
        Value written to ``obs[condition_key]`` of the returned prediction.
    condition_key: str
        Column of ``obs`` holding the condition.

    Returns
    -------
    pred_adata: `~anndata.AnnData`
        The network's prediction, passed through ``remove_sparsity``.
    """
    # Subset first, then copy. Copying the whole object before subsetting
    # duplicates every cell only to discard most of them, and leaves a view.
    adata_source = adata[adata.obs[condition_key] == source_condition].copy()

    if adata_source.shape[0] == 0:
        adata_source = pred_adatas[pred_adatas.obs[condition_key] == source_condition].copy()

    n_cells = adata_source.shape[0]
    source_labels = np.zeros(n_cells) + source_label
    target_labels = np.zeros(n_cells) + target_label

    pred_adata = network.predict(adata_source,
                                 encoder_labels=source_labels,
                                 decoder_labels=target_labels)
    pred_adata.obs[condition_key] = name
    pred_adata = remove_sparsity(pred_adata)

    return pred_adata
Beispiel #10
0
    def to_mmd_layer(self,
                     adata,
                     encoder_labels,
                     feed_fake=0,
                     return_adata=True):
        """
            Map `adata` into the MMD layer of the trVAE network. This function computes the
            output activation of the MMD layer in trVAE (output index 1 of `cvae_model`).
            # Parameters
                adata: `~anndata.AnnData`
                    Annotated data matrix to be mapped to the MMD space. `adata.X` has to be in shape [n_obs, n_vars].
                encoder_labels: numpy nd-array
                    `numpy nd-array` of condition labels (class indices) fed as the encoder's condition array.
                feed_fake: int
                    if `feed_fake` is non-negative (including the default 0), every decoder label is set to `feed_fake`.
                    if `feed_fake` is negative, the decoder labels are identical to `encoder_labels`.
                return_adata: boolean
                    if `True`, the activations are wrapped in a new `anndata` object;
                    if `False`, the raw `numpy nd-array` is returned.
                    Nothing is written to `adata.obsm` in either case.
            # Returns
                output: `~anndata.AnnData` or numpy nd-array
                    MMD-layer encoding of `adata` (NaN/inf replaced by `np.nan_to_num`).
                    When wrapped, `obs` is a deep copy of `adata.obs`.
        """
        if feed_fake >= 0:
            decoder_labels = np.zeros(shape=encoder_labels.shape) + feed_fake
        else:
            decoder_labels = encoder_labels

        encoder_labels = to_categorical(encoder_labels,
                                        num_classes=self.n_conditions)
        decoder_labels = to_categorical(decoder_labels,
                                        num_classes=self.n_conditions)

        adata = remove_sparsity(adata)

        x = [adata.X, encoder_labels, decoder_labels]
        mmd_latent = self.cvae_model.predict(x)[1]
        mmd_latent = np.nan_to_num(mmd_latent)
        if return_adata:
            output = anndata.AnnData(X=mmd_latent)
            output.obs = adata.obs.copy(deep=True)
        else:
            output = mmd_latent

        return output
Beispiel #11
0
    def predict(self, adata, encoder_labels, decoder_labels):
        """
            Reconstruct `adata` through the C-VAE, encoding each cell with `encoder_labels`
            and decoding it with `decoder_labels` (e.g. to translate cells to another condition).
            # Parameters
                adata: `~anndata.AnnData`
                    Annotated data matrix in primary space. `adata.X` must be reshapeable to [n_obs, *self.x_dim].
                encoder_labels: numpy nd-array
                    `numpy nd-array` of condition labels (class indices) fed as the encoder's condition array.
                decoder_labels: numpy nd-array
                    `numpy nd-array` of condition labels (class indices) fed as the decoder's condition array.
            # Returns
                reconstructed_adata: `~anndata.AnnData`
                    predicted cells in primary space (output index 0 of `cvae_model`, flattened to
                    [n_obs, prod(self.x_dim)]); `obs` is a deep copy of `adata.obs` and `var_names`
                    are taken from `adata`.
            # Example
            ```python
            pred_adata = network.predict(adata, encoder_labels, decoder_labels)
            ```
        """
        adata = remove_sparsity(adata)

        images = np.reshape(adata.X, (-1, *self.x_dim))
        encoder_labels = to_categorical(encoder_labels,
                                        num_classes=self.n_conditions)
        decoder_labels = to_categorical(decoder_labels,
                                        num_classes=self.n_conditions)

        reconstructed = self.cvae_model.predict(
            [images, encoder_labels, decoder_labels])[0]
        # Flatten back to one row per cell so the result lines up with adata.var_names.
        reconstructed = np.reshape(reconstructed, (-1, np.prod(self.x_dim)))

        reconstructed_adata = anndata.AnnData(X=reconstructed)
        reconstructed_adata.obs = adata.obs.copy(deep=True)
        reconstructed_adata.var_names = adata.var_names

        return reconstructed_adata
Beispiel #12
0
    def to_mmd_layer(self, adata, batch_key):
        """
            Map ``adata`` into the MMD space. This function feeds the data
            through ``cvae_model`` of trVAE and returns the MMD-layer coordinates
            (model output index 1) for each sample in the data.

            Parameters
            ----------
            adata: :class:`~anndata.AnnData`
                Annotated data matrix to be mapped to MMD latent space.
                Please note that ``adata.X`` has to be in shape [n_obs, x_dimension]
            batch_key: str
                Column of ``adata.obs`` holding each cell's condition (batch). It is
                encoded with ``self.condition_encoder`` and used as both trVAE's
                encoder and decoder condition array.

            Returns
            -------
            adata_mmd: :class:`~anndata.AnnData`
                Annotated data whose ``X`` is the MMD latent space encoding of ``adata``
                (NaN and +/-inf replaced by 0) and whose ``obs`` is a deep copy of ``adata.obs``.
        """
        adata = remove_sparsity(adata)

        # Encoder and decoder receive the same conditions, so encode them once.
        labels_encoded, _ = label_encoder(adata, self.condition_encoder, batch_key)
        labels_onehot = to_categorical(labels_encoded, num_classes=self.n_conditions)

        cvae_inputs = [adata.X, labels_onehot, labels_onehot]

        mmd = self.cvae_model.predict(cvae_inputs)[1]
        mmd = np.nan_to_num(mmd, nan=0.0, posinf=0.0, neginf=0.0)

        adata_mmd = anndata.AnnData(X=mmd)
        adata_mmd.obs = adata.obs.copy(deep=True)

        return adata_mmd
Beispiel #13
0
    def predict(self, adata, condition_key, target_condition=None):
        """Feeds ``adata`` to trVAE and produces the reconstructed data.

            Parameters
            ----------
            adata: :class:`~anndata.AnnData`
                Annotated data matrix in primary space.
            condition_key: str
                Column of ``adata.obs`` holding each cell's condition. It is encoded
                with ``self.condition_encoder`` and fed as trVAE's encoder condition array.
            target_condition: str or None
                If given, every cell is decoded in this condition
                (``self.condition_encoder[target_condition]``). If ``None``, each
                cell is decoded in its own condition.

            Returns
            -------
            adata_pred: `~anndata.AnnData`
                Annotated data of predicted cells in primary space (output index 0 of
                ``cvae_model``). ``obs`` is a deep copy of ``adata.obs`` and
                ``var_names`` are taken from ``adata``.
        """
        adata = remove_sparsity(adata)

        encoder_labels, _ = label_encoder(adata, self.condition_encoder, condition_key)
        if target_condition is not None:
            decoder_labels = np.zeros_like(encoder_labels) + self.condition_encoder[
                target_condition]
        else:
            # Same labels as the encoder; no need to run the encoder a second time.
            decoder_labels = encoder_labels

        encoder_labels = to_categorical(encoder_labels, num_classes=self.n_conditions)
        decoder_labels = to_categorical(decoder_labels, num_classes=self.n_conditions)

        x_hat = self.cvae_model.predict([adata.X, encoder_labels, decoder_labels])[0]

        adata_pred = anndata.AnnData(X=x_hat)
        # Copy rather than share the DataFrame, so later edits to the prediction's
        # obs (e.g. overwriting a column) cannot leak into the caller's adata.obs.
        adata_pred.obs = adata.obs.copy(deep=True)
        adata_pred.var_names = adata.var_names

        return adata_pred
Beispiel #14
0
def create_model(net_train_adata, net_valid_adata, domain_key, label_key,
                 source_domains, target_domains, domain_encoder):
    """Hyper-parameter search objective: train one trVAEATAC model and score it.

    NOTE(review): the double-brace choice(...) expressions below are not plain
    Python; they appear to be hyperas template markers that are replaced by a
    sampled value before this function runs. Confirm with the driver script
    before editing those lines.

    The model is trained on net_train_adata (validated on net_valid_adata).
    Then network.classifier_model is evaluated on the cells of both sets whose
    domain_key value is in target_domains.

    Returns
    -------
    dict
        Keys 'loss' (the negated second value returned by
        classifier_model.evaluate, named target_acc here) and 'status'
        (STATUS_OK).
    """
    # Search space: each variable below receives one sampled value per trial.
    z_dim_choices = {{choice([20, 40, 50, 60, 80, 100])}}
    mmd_dim_choices = {{choice([64, 128, 256])}}

    alpha_choices = {{choice([0.1, 0.01, 0.001, 0.0001, 0.00001, 0.000001])}}
    beta_choices = {{choice([1, 100, 500, 1000, 1500, 2000, 5000])}}
    gamma_choices = {{choice([1, 10, 100, 1000, 5000, 10000])}}
    eta_choices = {{choice([0.01, 0.1, 1, 5, 10, 50])}}
    batch_size_choices = {{choice([128, 256, 512, 1024, 1500])}}
    dropout_rate_choices = {{choice([0.1, 0.2, 0.5])}}

    # Class and domain counts are taken from the training set only.
    n_labels = len(net_train_adata.obs[label_key].unique().tolist())
    n_domains = len(net_train_adata.obs[domain_key].unique().tolist())

    # NOTE(review): model_path hard-codes the "Pancreas" dataset name.
    network = trvae.archs.trVAEATAC(
        x_dimension=net_train_adata.shape[1],
        z_dimension=z_dim_choices,
        mmd_dimension=mmd_dim_choices,
        learning_rate=0.001,
        alpha=alpha_choices,
        beta=beta_choices,
        gamma=gamma_choices,
        eta=eta_choices,
        model_path=f"./models/trVAEATAC/hyperopt/Pancreas/",
        n_labels=n_labels,
        n_domains=n_domains,
        output_activation='leaky_relu',
        mmd_computation_way="1",
        dropout_rate=dropout_rate_choices)

    network.train(
        net_train_adata,
        net_valid_adata,
        domain_key,
        label_key,
        source_key=source_domains,
        target_key=target_domains,
        domain_encoder=domain_encoder,
        n_epochs=10000,
        batch_size=batch_size_choices,
        early_stop_limit=500,
        lr_reducer=0,
    )

    # Collect the target-domain cells from both the training and validation sets.
    target_adata_train = net_train_adata.copy()[
        net_train_adata.obs[domain_key].isin(target_domains)]
    target_adata_valid = net_valid_adata.copy()[
        net_valid_adata.obs[domain_key].isin(target_domains)]

    target_adata = target_adata_train.concatenate(target_adata_valid)
    target_adata = remove_sparsity(target_adata)

    target_adata_domains_encoded, _ = label_encoder(
        target_adata, condition_key=domain_key, label_encoder=domain_encoder)
    target_adata_domains_onehot = to_categorical(target_adata_domains_encoded,
                                                 num_classes=n_domains)

    # Class labels are encoded with the network's own label encoder (network.label_enc).
    target_adata_classes_encoded = network.label_enc.transform(
        target_adata.obs[label_key].values)
    target_adata_classes_onehot = to_categorical(target_adata_classes_encoded,
                                                 num_classes=n_labels)

    # Each cell's own domain is used as both the encoder and decoder condition.
    x_target = [
        target_adata.X, target_adata_domains_onehot,
        target_adata_domains_onehot
    ]
    y_target = target_adata_classes_onehot

    _, target_acc = network.classifier_model.evaluate(x_target,
                                                      y_target,
                                                      verbose=0)
    # Negated so that minimising 'loss' maximises target-domain accuracy.
    objective = -target_acc
    print(
        f'alpha = {network.alpha}, beta = {network.beta}, eta={network.eta}, z_dim = {network.z_dim}, mmd_dim = {network.mmd_dim}, batch_size = {batch_size_choices}, dropout_rate = {network.dr_rate}, gamma = {network.gamma}'
    )
    return {'loss': objective, 'status': STATUS_OK}
Beispiel #15
0
def create_model(train_adata,
                 net_train_adata, net_valid_adata,
                 condition_key, cell_type_key,
                 cell_type, condition_encoder,
                 data_name, source_condition, target_condition):
    """Hyper-parameter search objective: train one trVAETaskSpecific model and score it.

    NOTE(review): the double-brace choice(...) expressions below are not plain
    Python; they appear to be hyperas template markers that are replaced by a
    sampled value before this function runs. Confirm with the driver script
    before editing those lines.

    After training, cells of cell_type in source_condition are translated to
    target_condition. The prediction is compared with the real target_condition
    cells on the top 10 up- and top 10 down-regulated genes, using the R^2 of
    per-gene means and the R^2 of per-gene variances.

    Returns
    -------
    dict
        Keys 'loss' (minus the sum of the two R^2 values) and 'status'
        (STATUS_OK).
    """
    n_conditions = len(net_train_adata.obs[condition_key].unique().tolist())

    # Search space: each variable below receives one sampled value per trial.
    z_dim_choices = {{choice([10, 20, 40, 50, 60, 80, 100])}}
    mmd_dim_choices = {{choice([64, 128, 256])}}

    alpha_choices = {{choice([0.001, 0.0001, 0.00001, 0.000001])}}
    beta_choices = {{choice([1, 5, 10, 20, 40, 50, 100])}}
    eta_choices = {{choice([1, 2, 5, 10, 50, 100])}}
    batch_size_choices = {{choice([128, 256, 512, 1024, 1500])}}
    dropout_rate_choices = {{choice([0.1, 0.2, 0.5])}}

    network = trvae.archs.trVAETaskSpecific(x_dimension=net_train_adata.shape[1],
                                            z_dimension=z_dim_choices,
                                            n_conditions=n_conditions,
                                            mmd_dimension=mmd_dim_choices,
                                            alpha=alpha_choices,
                                            beta=beta_choices,
                                            eta=eta_choices,
                                            kernel='multi-scale-rbf',
                                            learning_rate=0.001,
                                            clip_value=1e6,
                                            loss_fn='mse',
                                            model_path=f"./models/trVAETaskSpecific/hyperopt/{data_name}/{cell_type}/{target_condition}/",
                                            dropout_rate=dropout_rate_choices,
                                            output_activation="relu",
                                            )

    network.train(net_train_adata,
                  net_valid_adata,
                  condition_encoder,
                  condition_key,
                  n_epochs=10000,
                  batch_size=batch_size_choices,
                  verbose=2,
                  early_stop_limit=100,
                  lr_reducer=80,
                  monitor='val_loss',
                  shuffle=True,
                  save=False)

    cell_type_adata = train_adata.copy()[train_adata.obs[cell_type_key] == cell_type]

    # Top 10 genes ranked for target_condition against source_condition.
    sc.tl.rank_genes_groups(cell_type_adata,
                            key_added='up_reg_genes',
                            groupby=condition_key,
                            groups=[target_condition],
                            reference=source_condition,
                            n_genes=10)

    # Top 10 genes ranked for source_condition against target_condition.
    sc.tl.rank_genes_groups(cell_type_adata,
                            key_added='down_reg_genes',
                            groupby=condition_key,
                            groups=[source_condition],
                            reference=target_condition,
                            n_genes=10)

    up_genes = cell_type_adata.uns['up_reg_genes']['names'][target_condition].tolist()
    down_genes = cell_type_adata.uns['down_reg_genes']['names'][source_condition].tolist()

    top_genes = up_genes + down_genes

    source_adata = cell_type_adata.copy()[cell_type_adata.obs[condition_key] == source_condition]

    source_label = condition_encoder[source_condition]
    target_label = condition_encoder[target_condition]

    # Encode every source cell as source_condition and decode it as target_condition.
    source_labels = np.zeros(source_adata.shape[0]) + source_label
    target_labels = np.zeros(source_adata.shape[0]) + target_label

    pred_target = network.predict(source_adata,
                                  encoder_labels=source_labels,
                                  decoder_labels=target_labels)

    real_target = cell_type_adata.copy()[cell_type_adata.obs[condition_key] == target_condition]

    real_target = remove_sparsity(real_target)

    # Restrict both predicted and real cells to the selected genes (indexed by name).
    pred_target = pred_target[:, top_genes]
    real_target = real_target[:, top_genes]

    # R^2 (squared correlation from linregress) between predicted and real per-gene variances.
    x_var = np.var(pred_target.X, axis=0)
    y_var = np.var(real_target.X, axis=0)
    m, b, r_value_var, p_value, std_err = stats.linregress(x_var, y_var)
    r_value_var = r_value_var ** 2

    # R^2 between predicted and real per-gene means.
    x_mean = np.mean(pred_target.X, axis=0)
    y_mean = np.mean(real_target.X, axis=0)
    m, b, r_value_mean, p_value, std_err = stats.linregress(x_mean, y_mean)
    r_value_mean = r_value_mean ** 2

    # Reported only; not part of the objective.
    # NOTE(review): best_var_diff is the variance of the per-gene variance
    # differences, not a mean difference; confirm this is intended.
    best_mean_diff = np.abs(np.mean(x_mean - y_mean))
    best_var_diff = np.abs(np.var(x_var - y_var))
    objective = r_value_mean + r_value_var
    # NOTE(review): the first message below ends with a stray ')' character.
    print(f'Reg_mean_diff: {r_value_mean}, Reg_var_all: {r_value_var})')
    print(f'Mean diff: {best_mean_diff}, Var_diff: {best_var_diff}')
    print(
        f'alpha = {network.alpha}, beta = {network.beta}, eta={network.eta}, z_dim = {network.z_dim}, mmd_dim = {network.mmd_dim}, batch_size = {batch_size_choices}, dropout_rate = {network.dr_rate}, lr = {network.lr}')
    # Negated so that minimising 'loss' maximises the summed R^2.
    return {'loss': -objective, 'status': STATUS_OK}
Beispiel #16
0
    def train(self,
              train_adata,
              valid_adata=None,
              condition_encoder=None,
              condition_key='condition',
              n_epochs=25,
              batch_size=32,
              early_stop_limit=20,
              lr_reducer=10,
              threshold=0.0025,
              monitor='val_loss',
              shuffle=True,
              verbose=2,
              save=True):
        """
            Trains the network for `n_epochs` epochs on `train_adata`, validating on
            `valid_adata` when it is given. Early stopping and learning-rate reduction
            callbacks are used to limit overfitting. Per-epoch logs are written to
            `./csv_logger.log` by a `CSVLogger` callback.
            # Parameters
                train_adata: `~anndata.AnnData`
                    Training data. `train_adata.X` must be reshapeable to [n_obs, *self.x_dim].
                valid_adata: `~anndata.AnnData`
                    Optional validation data with the same layout as `train_adata`.
                condition_encoder: mapping or None
                    Passed to `label_encoder` together with `train_adata`. The encoder it returns
                    is stored in `self.condition_encoder` and reused to encode `valid_adata`.
                condition_key: str
                    Column of `obs` holding each cell's condition.
                n_epochs: int
                    number of epochs to iterate and optimize network weights
                batch_size: int
                    number of samples per gradient update
                early_stop_limit: int
                    patience of `EarlyStopping`: number of consecutive epochs without improvement
                    of `monitor` after which training stops. `0` disables early stopping.
                lr_reducer: int
                    patience of `ReduceLROnPlateau`. `0` disables learning-rate reduction.
                threshold: float
                    `min_delta` of `EarlyStopping`: changes of `monitor` smaller than this
                    are not counted as an improvement.
                monitor: str
                    quantity watched by the early-stopping and learning-rate callbacks.
                shuffle: bool
                    whether Keras shuffles the training data each epoch.
                verbose: int
                    values above 2 print a message per epoch via `print_message` and silence
                    Keras' own output; otherwise the value is passed to `fit` as-is.
                save: bool
                    if `True`, `self.save_model()` is called after training.
            # Returns
                Nothing will be returned
            # Example
            ```python
            network.train(train_adata, valid_adata, condition_key="condition", n_epochs=20)
            ```
        """
        train_adata = remove_sparsity(train_adata)

        train_labels_encoded, self.condition_encoder = label_encoder(
            train_adata, condition_encoder, condition_key)
        train_labels_onehot = to_categorical(train_labels_encoded,
                                             num_classes=self.n_conditions)

        callbacks = [
            History(),
            CSVLogger(filename="./csv_logger.log"),
        ]
        if early_stop_limit > 0:
            callbacks.append(
                EarlyStopping(patience=early_stop_limit,
                              monitor=monitor,
                              min_delta=threshold))

        if lr_reducer > 0:
            callbacks.append(
                ReduceLROnPlateau(monitor=monitor,
                                  patience=lr_reducer,
                                  verbose=verbose))

        if verbose > 2:
            callbacks.append(
                LambdaCallback(on_epoch_end=lambda epoch, logs: print_message(
                    epoch, logs, n_epochs, verbose)))
            fit_verbose = 0
        else:
            fit_verbose = verbose

        train_images = np.reshape(train_adata.X, (-1, *self.x_dim))

        x = [train_images, train_labels_onehot, train_labels_onehot]
        y = [train_images, train_labels_encoded]

        validation_data = None
        if valid_adata is not None:
            valid_adata = remove_sparsity(valid_adata)

            # Encode validation conditions with the encoder obtained for the
            # training data. The raw `condition_encoder` argument may be None,
            # which would let validation cells receive a different mapping.
            valid_labels_encoded, _ = label_encoder(valid_adata,
                                                    self.condition_encoder,
                                                    condition_key)
            valid_labels_onehot = to_categorical(valid_labels_encoded,
                                                 num_classes=self.n_conditions)

            valid_images = np.reshape(valid_adata.X, (-1, *self.x_dim))

            x_valid = [valid_images, valid_labels_onehot, valid_labels_onehot]
            y_valid = [valid_images, valid_labels_encoded]
            validation_data = (x_valid, y_valid)

        # validation_data=None is Keras' default, so one call covers both cases.
        self.cvae_model.fit(x=x,
                            y=y,
                            epochs=n_epochs,
                            batch_size=batch_size,
                            validation_data=validation_data,
                            shuffle=shuffle,
                            callbacks=callbacks,
                            verbose=fit_verbose)
        if save:
            self.save_model()