Example #1
    def evaluate(self, adata, batch_key, cell_type_key):
        adata = remove_sparsity(adata)

        encoder_labels, _ = label_encoder(adata, self.condition_encoder, batch_key)
        decoder_labels, _ = label_encoder(adata, self.condition_encoder, batch_key)

        encoder_labels = to_categorical(encoder_labels, num_classes=self.n_conditions)
        decoder_labels = to_categorical(decoder_labels, num_classes=self.n_conditions)

        cvae_inputs = [adata.X, encoder_labels, decoder_labels]

        # The third output head of the CVAE is the cell-type classifier.
        encoded_labels = self.cvae_model.predict(cvae_inputs)[2].argmax(axis=1)

        # Map predicted class indices back to cell-type names.
        self._reverse_cell_type_encoder()
        labels = np.array([self.inv_cell_type_encoder[label] for label in encoded_labels])

        # Compare predictions against the true cell-type annotations.
        true_labels = adata.obs[cell_type_key].values
        accuracy = np.mean(labels == true_labels)

        print(classification_report(true_labels, labels))

        return accuracy, confusion_matrix(true_labels, labels)
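
A minimal usage sketch for ``evaluate`` (the trained instance ``model``, the query ``query_adata``, and both key names are hypothetical, not part of the example above):

    # Hypothetical usage: score cell-type predictions on an annotated query set.
    accuracy, conf_mat = model.evaluate(query_adata,
                                        batch_key='study',
                                        cell_type_key='cell_type')
    print(f"prediction accuracy: {accuracy:.3f}")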
Example #2
    def get_latent(self, adata, batch_key):
        """ Transforms `adata` in latent space of CVAE and returns the latent
        coordinates in the annotated (adata) format.

        Parameters
        ----------
        adata: :class:`~anndata.AnnData`
            Annotated dataset matrix in Primary space.




        """
        if set(self.gene_names).issubset(set(adata.var_names)):
            adata = adata[:, self.gene_names]
        else:
            raise Exception(
                "gene names in adata are inconsistent with scNet's gene_names"
            )

        encoder_labels, _ = label_encoder(adata, self.condition_encoder,
                                          batch_key)
        encoder_labels = to_categorical(encoder_labels,
                                        num_classes=self.n_conditions)

        return self.get_z_latent(adata, encoder_labels)
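
A usage sketch for ``get_latent``, embedding a dataset and visualizing the result with scanpy (``network``, ``adata``, and the 'study' column are assumptions):

    import scanpy as sc

    # Embed the data, then run the standard scanpy neighbors/UMAP workflow
    # on the latent coordinates.
    latent_adata = network.get_latent(adata, batch_key='study')
    sc.pp.neighbors(latent_adata)
    sc.tl.umap(latent_adata)
    sc.pl.umap(latent_adata, color='study')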
Example #3
    def get_latent(self, adata, batch_key, return_z=False):
        """ Transforms `adata` in latent space of scNet and returns the latent
        coordinates in the annotated (adata) format.

        Parameters
        ----------
        adata: :class:`~anndata.AnnData`
            Annotated dataset matrix in Primary space.
        batch_key: str
            Name of the column containing the study (batch) names for each sample.
        return_z: bool
            ``False`` by defaul. if ``True``, the output of bottleneck layer of network will be computed.

        Returns
        -------
        adata_pred: `~anndata.AnnData`
            Annotated data of transformed ``adata`` into latent space.
        """
        if set(self.gene_names).issubset(set(adata.var_names)):
            adata = adata[:, self.gene_names]
        else:
            raise Exception("gene names in adata are inconsistent with scNet's gene_names")

        if self.beta == 0:
            return_z = True

        encoder_labels, _ = label_encoder(adata, self.condition_encoder, batch_key)
        encoder_labels = to_categorical(encoder_labels, num_classes=self.n_conditions)

        if return_z:
            return self.get_z_latent(adata, encoder_labels)
        else:
            return self.to_mmd_layer(adata, batch_key)
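
A sketch of the two embedding modes above (``network`` is a hypothetical trained instance and 'study' an assumed batch column):

    # Bottleneck (z) embedding; forced automatically when beta == 0.
    z_adata = network.get_latent(adata, batch_key='study', return_z=True)

    # MMD-layer embedding; only meaningful when the model was trained with beta > 0.
    mmd_adata = network.get_latent(adata, batch_key='study', return_z=False)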
Example #4
    def __init__(self,
                 filename: str,
                 adata: anndata.AnnData,
                 batch_key: str,
                 cell_type_key: str,
                 encoder_model: Model,
                 n_per_epoch: int = 5,
                 n_batch_labels: int = 0,
                 n_celltype_labels: int = 0,
                 clustering_scores: list_or_str = 'all'):
        super(ScoreCallback, self).__init__()
        self.adata = remove_sparsity(adata)

        # Integer-encode and one-hot-encode the batch labels.
        self.batch_labels, _ = label_encoder(self.adata,
                                             le=None,
                                             condition_key=batch_key)
        self.batch_labels = np.reshape(self.batch_labels, (-1,))
        self.batch_labels_onehot = to_categorical(self.batch_labels,
                                                  num_classes=n_batch_labels)

        # Integer-encode and one-hot-encode the cell-type labels.
        self.celltype_labels, _ = label_encoder(self.adata,
                                                le=None,
                                                condition_key=cell_type_key)
        self.celltype_labels = np.reshape(self.celltype_labels, (-1,))
        self.celltype_labels_onehot = to_categorical(
            self.celltype_labels, num_classes=n_celltype_labels)

        self.filename = filename
        self.encoder_model = encoder_model
        self.n_per_epoch = n_per_epoch

        self.n_batch_labels = n_batch_labels
        self.n_celltype_labels = n_celltype_labels

        self.clustering_scores = clustering_scores
        self.score_computers = {
            "asw": self.asw,
            "ari": self.ari,
            "nmi": self.nmi,
            "ebm": self.entropy_of_batch_mixing,
            "knn": self.knn_purity
        }

        self.kmeans_batch = KMeans(self.n_batch_labels, n_init=200)
        self.kmeans_celltype = KMeans(self.n_celltype_labels, n_init=200)
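
A construction sketch for ``ScoreCallback`` (all names are illustrative; the callback is passed to ``fit`` like any other Keras callback):

    # Hypothetical setup: compute clustering scores every 5 epochs.
    score_cb = ScoreCallback(filename='scores.csv',
                             adata=adata,
                             batch_key='study',
                             cell_type_key='cell_type',
                             encoder_model=network.encoder_model,
                             n_per_epoch=5,
                             n_batch_labels=adata.obs['study'].nunique(),
                             n_celltype_labels=adata.obs['cell_type'].nunique())
    # network.cvae_model.fit(..., callbacks=[score_cb])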
Example #5
    def annotate(self, adata, batch_key, cell_type_key):
        adata = remove_sparsity(adata)

        encoder_labels, _ = label_encoder(adata, self.condition_encoder, batch_key)
        decoder_labels, _ = label_encoder(adata, self.condition_encoder, batch_key)

        encoder_labels = to_categorical(encoder_labels, num_classes=self.n_conditions)
        decoder_labels = to_categorical(decoder_labels, num_classes=self.n_conditions)

        cvae_inputs = [adata.X, encoder_labels, decoder_labels]

        # The third output head of the CVAE is the cell-type classifier.
        encoded_labels = self.cvae_model.predict(cvae_inputs)[2].argmax(axis=1)

        # Map predicted class indices back to cell-type names.
        self._reverse_cell_type_encoder()
        labels = [self.inv_cell_type_encoder[label] for label in encoded_labels]

        # ``remove_sparsity`` may return a copy, so return the annotated object
        # rather than relying on in-place modification of the caller's ``adata``.
        adata.obs[f'pred_{cell_type_key}'] = np.array(labels)
        return adata
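
A usage sketch for ``annotate`` (``model``, ``query_adata``, and the key names are hypothetical):

    annotated = model.annotate(query_adata, batch_key='study', cell_type_key='cell_type')
    # Predictions land in a 'pred_<cell_type_key>' column of ``obs``.
    print(annotated.obs['pred_cell_type'].value_counts())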
Example #6
    def to_mmd_layer(self, adata, batch_key):
        """
            Map ``adata`` in to the MMD space. This function will feed data
            in ``mmd_model`` of scArches and compute the MMD space coordinates
            for each sample in data.

            Parameters
            ----------
            adata: :class:`~anndata.AnnData`
                Annotated data matrix to be mapped to MMD latent space.
                Please note that ``adata.X`` has to be in shape [n_obs, x_dimension]
            encoder_labels: :class:`~numpy.ndarray`
                :class:`~numpy.ndarray` of labels to be fed as scArches' encoder condition array.
            decoder_labels: :class:`~numpy.ndarray`
                :class:`~numpy.ndarray` of labels to be fed as scArches' decoder condition array.

            Returns
            -------
            adata_mmd: :class:`~anndata.AnnData`
                returns Annotated data containing MMD latent space encoding of ``adata``
        """
        adata = remove_sparsity(adata)

        encoder_labels, _ = label_encoder(adata, self.condition_encoder,
                                          batch_key)
        decoder_labels, _ = label_encoder(adata, self.condition_encoder,
                                          batch_key)

        encoder_labels = to_categorical(encoder_labels,
                                        num_classes=self.n_conditions)
        decoder_labels = to_categorical(decoder_labels,
                                        num_classes=self.n_conditions)

        cvae_inputs = [adata.X, encoder_labels, decoder_labels]

        mmd = self.cvae_model.predict(cvae_inputs)[1]
        mmd = np.nan_to_num(mmd, nan=0.0, posinf=0.0, neginf=0.0)

        adata_mmd = anndata.AnnData(X=mmd)
        adata_mmd.obs = adata.obs.copy(deep=True)

        return adata_mmd
Example #7
    def get_latent(self, adata, batch_key, return_mean=False):
        """ Transforms `adata` in latent space of CVAE and returns the latent
        coordinates in the annotated (adata) format.

        Parameters
        ----------
        adata: :class:`~anndata.AnnData`
            Annotated dataset matrix in Primary space.
        batch_key: str
            key for the observation that has batch labels in adata.obs.

        return_mean: bool
            if False, z will be sampled. Set to `True` if want a fix z value every time you call
             get_latent.

        Returns
        -------
        latent_adata: :class:`~anndata.AnnData`
            Annotated dataset matrix in Latent space.



        """
        if set(self.gene_names).issubset(set(adata.var_names)):
            adata = adata[:, self.gene_names]
        else:
            raise Exception(
                "gene names in adata are inconsistent with scNet's gene_names"
            )

        encoder_labels, _ = label_encoder(adata, self.condition_encoder,
                                          batch_key)
        encoder_labels = to_categorical(encoder_labels,
                                        num_classes=self.n_conditions)

        return self.get_z_latent(adata, encoder_labels, return_mean)
Example #8
    def _train_on_batch(self,
                        adata,
                        condition_key,
                        train_size=0.8,
                        cell_type_key='cell_type',
                        n_epochs=100,
                        batch_size=128,
                        early_stop_limit=10,
                        lr_reducer=8,
                        n_per_epoch=0,
                        score_filename=None,
                        save=True,
                        retrain=True,
                        verbose=3):
        train_adata, valid_adata = train_test_split(adata, train_size)

        if self.gene_names is None:
            self.gene_names = train_adata.var_names.tolist()
        else:
            if set(self.gene_names).issubset(set(train_adata.var_names)):
                train_adata = train_adata[:, self.gene_names]
            else:
                raise Exception(
                    "gene names in train adata are inconsistent with class' gene_names"
                )

            if set(self.gene_names).issubset(set(valid_adata.var_names)):
                valid_adata = valid_adata[:, self.gene_names]
            else:
                raise Exception(
                    "gene names in valid adata are inconsistent with class' gene_names"
                )

        train_conditions_encoded, self.condition_encoder = label_encoder(
            train_adata,
            le=self.condition_encoder,
            condition_key=condition_key)

        valid_conditions_encoded, self.condition_encoder = label_encoder(
            valid_adata,
            le=self.condition_encoder,
            condition_key=condition_key)

        if not retrain and os.path.exists(
                os.path.join(self.model_path, f"{self.model_name}.h5")):
            self.restore_model_weights()
            return

        train_conditions_onehot = to_categorical(train_conditions_encoded,
                                                 num_classes=self.n_conditions)
        valid_conditions_onehot = to_categorical(valid_conditions_encoded,
                                                 num_classes=self.n_conditions)

        is_sparse = sparse.issparse(train_adata.X)

        # Keep the (possibly sparse) training matrix as-is and densify it per
        # mini-batch; the validation matrix is densified once up front.
        train_expr = train_adata.X
        valid_expr = valid_adata.X.A if is_sparse else valid_adata.X
        x_valid = [
            valid_expr, valid_conditions_onehot, valid_conditions_onehot
        ]

        if self.loss_fn in ['nb', 'zinb']:
            x_valid.append(valid_adata.obs[self.size_factor_key].values)
            y_valid = valid_adata.raw.X.A if sparse.issparse(
                valid_adata.raw.X) else valid_adata.raw.X
        else:
            y_valid = valid_expr

        es_patience, best_val_loss = 0, 1e10
        # Cap each epoch at 500 mini-batches; indices are sampled with replacement.
        n_batches = min(500, train_adata.shape[0] // batch_size)
        for i in range(n_epochs):
            train_loss = train_recon_loss = train_kl_loss = 0.0
            for j in range(n_batches):
                batch_indices = np.random.choice(train_adata.shape[0],
                                                 batch_size)

                batch_expr = train_expr[
                    batch_indices, :].A if is_sparse else train_expr[
                        batch_indices, :]

                x_train = [
                    batch_expr, train_conditions_onehot[batch_indices],
                    train_conditions_onehot[batch_indices]
                ]

                if self.loss_fn in ['nb', 'zinb']:
                    x_train.append(train_adata.obs[
                        self.size_factor_key].values[batch_indices])
                    raw_batch = train_adata.raw.X[batch_indices]
                    y_train = raw_batch.A if sparse.issparse(raw_batch) else raw_batch
                else:
                    y_train = batch_expr

                batch_loss, batch_recon_loss, batch_kl_loss = self.cvae_model.train_on_batch(
                    x_train, y_train)

                # Accumulate the epoch-mean training losses.
                train_loss += batch_loss / n_batches
                train_recon_loss += batch_recon_loss / n_batches
                train_kl_loss += batch_kl_loss / n_batches

            valid_loss, valid_recon_loss, valid_kl_loss = self.cvae_model.evaluate(
                x_valid, y_valid, verbose=0)

            if valid_loss < best_val_loss:
                best_val_loss = valid_loss
                es_patience = 0
            else:
                es_patience += 1
                if es_patience == early_stop_limit:
                    print("Training stopped with Early Stopping")
                    break

            logs = {
                "loss": train_loss,
                "recon_loss": train_recon_loss,
                "kl_loss": train_kl_loss,
                "val_loss": valid_loss,
                "val_recon_loss": valid_recon_loss,
                "val_kl_loss": valid_kl_loss
            }
            print_progress(i, logs, n_epochs)

        if save:
            self.update_kwargs()
            self.save(make_dir=True)
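
A call sketch for the mini-batch trainer above (``network``, ``adata``, and the 'study' column are assumptions):

    network._train_on_batch(adata,
                            condition_key='study',
                            train_size=0.8,
                            n_epochs=300,
                            batch_size=128,
                            early_stop_limit=15,
                            lr_reducer=10)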
Example #9
    def _fit(self,
             adata,
             condition_key,
             train_size=0.8,
             cell_type_key='cell_type',
             n_epochs=100,
             batch_size=128,
             early_stop_limit=10,
             lr_reducer=8,
             n_per_epoch=0,
             score_filename=None,
             save=True,
             retrain=True,
             verbose=3):
        train_adata, valid_adata = train_test_split(adata, train_size)

        if self.gene_names is None:
            self.gene_names = train_adata.var_names.tolist()
        else:
            if set(self.gene_names).issubset(set(train_adata.var_names)):
                train_adata = train_adata[:, self.gene_names]
            else:
                raise Exception(
                    "gene names in train adata are inconsistent with class' gene_names"
                )

            if set(self.gene_names).issubset(set(valid_adata.var_names)):
                valid_adata = valid_adata[:, self.gene_names]
            else:
                raise Exception(
                    "gene names in valid adata are inconsistent with class' gene_names"
                )

        if self.loss_fn in ['nb', 'zinb']:
            train_raw_expr = train_adata.raw.X.A if sparse.issparse(
                train_adata.raw.X) else train_adata.raw.X
            valid_raw_expr = valid_adata.raw.X.A if sparse.issparse(
                valid_adata.raw.X) else valid_adata.raw.X

        train_expr = train_adata.X.A if sparse.issparse(
            train_adata.X) else train_adata.X
        valid_expr = valid_adata.X.A if sparse.issparse(
            valid_adata.X) else valid_adata.X

        train_conditions_encoded, self.condition_encoder = label_encoder(
            train_adata,
            le=self.condition_encoder,
            condition_key=condition_key)

        valid_conditions_encoded, self.condition_encoder = label_encoder(
            valid_adata,
            le=self.condition_encoder,
            condition_key=condition_key)

        if not retrain and os.path.exists(
                os.path.join(self.model_path, f"{self.model_name}.h5")):
            self.restore_model_weights()
            return

        callbacks = [
            History(),
        ]

        if verbose > 2:
            callbacks.append(
                LambdaCallback(on_epoch_end=lambda epoch, logs: print_progress(
                    epoch, logs, n_epochs)))
            fit_verbose = 0
        else:
            fit_verbose = verbose

        if (n_per_epoch > 0 or n_per_epoch == -1) and score_filename:
            adata = train_adata.concatenate(valid_adata)

            train_celltypes_encoded, _ = label_encoder(
                train_adata, le=None, condition_key=cell_type_key)
            valid_celltypes_encoded, _ = label_encoder(
                valid_adata, le=None, condition_key=cell_type_key)
            celltype_labels = np.concatenate(
                [train_celltypes_encoded, valid_celltypes_encoded], axis=0)

            callbacks.append(
                ScoreCallback(score_filename,
                              adata,
                              condition_key,
                              cell_type_key,
                              self.cvae_model,
                              n_per_epoch=n_per_epoch,
                              n_batch_labels=self.n_conditions,
                              n_celltype_labels=len(
                                  np.unique(celltype_labels))))

        if early_stop_limit > 0:
            callbacks.append(
                EarlyStopping(patience=early_stop_limit, monitor='val_loss'))

        if lr_reducer > 0:
            callbacks.append(
                ReduceLROnPlateau(monitor='val_loss', patience=lr_reducer))

        train_conditions_onehot = to_categorical(train_conditions_encoded,
                                                 num_classes=self.n_conditions)
        valid_conditions_onehot = to_categorical(valid_conditions_encoded,
                                                 num_classes=self.n_conditions)

        x_train = [
            train_expr, train_conditions_onehot, train_conditions_onehot
        ]
        x_valid = [
            valid_expr, valid_conditions_onehot, valid_conditions_onehot
        ]

        if self.loss_fn in ['nb', 'zinb']:
            x_train.append(train_adata.obs[self.size_factor_key].values)
            y_train = train_raw_expr

            x_valid.append(valid_adata.obs[self.size_factor_key].values)
            y_valid = valid_raw_expr
        else:
            y_train = train_expr
            y_valid = valid_expr

        self.cvae_model.fit(
            x=x_train,
            y=y_train,
            validation_data=(x_valid, y_valid),
            epochs=n_epochs,
            batch_size=batch_size,
            verbose=fit_verbose,
            callbacks=callbacks,
        )
        if save:
            self.update_kwargs()
            self.save(make_dir=True)
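
When ``loss_fn`` is 'nb' or 'zinb', ``_fit`` reads raw counts from ``adata.raw`` and per-cell size factors from ``adata.obs[self.size_factor_key]``. A preparation sketch (the 'size_factors' key and the median normalization are assumptions, not dictated by the code above):

    import numpy as np
    from scipy import sparse

    adata.raw = adata  # freeze raw counts before any normalization
    counts = adata.X.A if sparse.issparse(adata.X) else adata.X
    library_sizes = counts.sum(axis=1)
    adata.obs['size_factors'] = library_sizes / np.median(library_sizes)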
Example #10
    def _fit(self, adata, condition_key, cell_type_key,
             train_size=0.8,
             n_epochs=25, batch_size=32,
             early_stop_limit=20, lr_reducer=10,
             n_per_epoch=0, score_filename=None,
             save=True, retrain=True, verbose=3):
        """
            Trains scNet with ``n_epochs`` times given ``train_adata``
            and validates the model using ``valid_adata``
            This function is using ``early stopping`` and ``learning rate reduce on plateau``
            techniques to prevent over-fitting.

            Parameters
            ----------
            adata: :class:`~anndata.AnnData`
                Annotated dataset used to train & evaluate scNet.
            condition_key: str
                column name for conditions in the `obs` matrix of `train_adata` and `valid_adata`.
            train_size: float
                fraction of samples used to train scNet.
            n_epochs: int
                number of epochs.
            batch_size: int
                number of samples in the mini-batches used to optimize scNet.
            early_stop_limit: int
                patience of EarlyStopping
            lr_reducer: int
                patience of LearningRateReduceOnPlateau.
            save: bool
                Whether to save scNet after the training or not.
            verbose: int
                Verbose level
            retrain: bool
                ``True`` by default. if ``True`` scNet will be trained regardless of existance of pre-trained scNet in ``model_path``. if ``False`` scNet will not be trained if pre-trained scNet exists in ``model_path``.

        """
        train_adata, valid_adata = train_test_split(adata, train_size)

        if self.gene_names is None:
            self.gene_names = train_adata.var_names.tolist()
        else:
            if set(self.gene_names).issubset(set(train_adata.var_names)):
                train_adata = train_adata[:, self.gene_names]
            else:
                raise Exception("gene names in train adata are inconsistent with scNet's gene_names")

            if set(self.gene_names).issubset(set(valid_adata.var_names)):
                valid_adata = valid_adata[:, self.gene_names]
            else:
                raise Exception("gene names in valid adata are inconsistent with scNet's gene_names")

        train_conditions_encoded, self.condition_encoder = label_encoder(train_adata, le=self.condition_encoder,
                                                                         condition_key=condition_key)

        valid_conditions_encoded, self.condition_encoder = label_encoder(valid_adata, le=self.condition_encoder,
                                                                         condition_key=condition_key)

        train_cell_types_encoded, encoder = label_encoder(train_adata, le=self.cell_type_encoder,
                                                          condition_key=cell_type_key)

        if self.cell_type_encoder is None:
            self.cell_type_encoder = encoder

        valid_cell_types_encoded, self.cell_type_encoder = label_encoder(valid_adata, le=self.cell_type_encoder,
                                                                         condition_key=cell_type_key)

        if not retrain and self.restore_model_weights():
            return

        train_conditions_onehot = to_categorical(train_conditions_encoded, num_classes=self.n_conditions)
        valid_conditions_onehot = to_categorical(valid_conditions_encoded, num_classes=self.n_conditions)

        train_cell_types_onehot = to_categorical(train_cell_types_encoded, num_classes=self.n_classes)
        valid_cell_types_onehot = to_categorical(valid_cell_types_encoded, num_classes=self.n_classes)

        if self.loss_fn in ['nb', 'zinb']:
            train_raw_expr = train_adata.raw.X.A if sparse.issparse(train_adata.raw.X) else train_adata.raw.X
            valid_raw_expr = valid_adata.raw.X.A if sparse.issparse(valid_adata.raw.X) else valid_adata.raw.X

        train_expr = train_adata.X.A if sparse.issparse(train_adata.X) else train_adata.X
        valid_expr = valid_adata.X.A if sparse.issparse(valid_adata.X) else valid_adata.X

        x_train = [train_expr, train_conditions_onehot, train_conditions_onehot]
        y_train = [train_expr, train_conditions_encoded, train_cell_types_onehot]

        x_valid = [valid_expr, valid_conditions_onehot, valid_conditions_onehot]
        y_valid = [valid_expr, valid_conditions_encoded, valid_cell_types_onehot]

        callbacks = [
            History(),
        ]

        if verbose > 2:
            callbacks.append(
                LambdaCallback(on_epoch_end=lambda epoch, logs: print_progress(epoch, logs, n_epochs)))
            fit_verbose = 0
        else:
            fit_verbose = verbose

        if (n_per_epoch > 0 or n_per_epoch == -1) and score_filename:
            adata = train_adata.concatenate(valid_adata)

            celltype_labels = np.concatenate([train_cell_types_encoded, valid_cell_types_encoded], axis=0)

            callbacks.append(ScoreCallback(score_filename, adata, condition_key, cell_type_key, self.cvae_model,
                                           n_per_epoch=n_per_epoch, n_batch_labels=self.n_conditions,
                                           n_celltype_labels=len(np.unique(celltype_labels))))

        if early_stop_limit > 0:
            callbacks.append(EarlyStopping(patience=early_stop_limit, monitor='val_loss'))

        if lr_reducer > 0:
            callbacks.append(ReduceLROnPlateau(monitor='val_loss', patience=lr_reducer))

        self.cvae_model.fit(x=x_train,
                            y=y_train,
                            validation_data=(x_valid, y_valid),
                            epochs=n_epochs,
                            batch_size=batch_size,
                            verbose=fit_verbose,
                            callbacks=callbacks,
                            )
        if save:
            self.update_kwargs()
            self.save(make_dir=True)
Example #11
def make_dataset(adata,
                 condition_key,
                 le,
                 batch_size,
                 n_epochs,
                 is_training,
                 loss_fn,
                 n_conditions,
                 size_factor_key=None,
                 use_mmd=False):
    expressions = adata.X.A if sparse.issparse(adata.X) else adata.X

    encoded_conditions, le = label_encoder(adata, le, condition_key)

    if loss_fn in ('nb', 'zinb'):
        # NB/ZINB losses reconstruct raw counts and need per-cell size factors.
        raw_expressions = adata.raw.X.A if sparse.issparse(adata.raw.X) else adata.raw.X
        dataset = tf.data.Dataset.from_tensor_slices((
            {
                "expression": expressions,
                "encoder_label": encoded_conditions,
                "decoder_label": encoded_conditions,
                "size_factor": adata.obs[size_factor_key].values,
            },
            {"reconstruction": raw_expressions},
        ))
    elif use_mmd:
        dataset = tf.data.Dataset.from_tensor_slices((
            {
                "expression": expressions,
                "encoder_label": encoded_conditions,
                "decoder_label": encoded_conditions,
            },
            {
                "reconstruction": expressions,
                "mmd": encoded_conditions,
            },
        ))
    else:
        dataset = tf.data.Dataset.from_tensor_slices((
            {
                "expression": expressions,
                "encoder_label": encoded_conditions,
                "decoder_label": encoded_conditions,
            },
            {"reconstruction": expressions},
        ))

    if is_training:
        dataset = dataset.shuffle(1000)
    dataset = dataset.map(preprocess_cvae_input(n_conditions),
                          num_parallel_calls=4,
                          deterministic=None)
    # Repeat ``n_epochs`` times for training; indefinitely for evaluation.
    dataset = dataset.repeat(n_epochs) if is_training else dataset.repeat()

    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    dataset = dataset.prefetch(buffer_size=5)

    return dataset, le
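
A usage sketch for ``make_dataset`` (``train_adata``, the 'study' column, the 'mse' loss, and the ``model`` handle are assumptions):

    n_conditions = train_adata.obs['study'].nunique()
    train_dataset, le = make_dataset(train_adata,
                                     condition_key='study',
                                     le=None,
                                     batch_size=128,
                                     n_epochs=100,
                                     is_training=True,
                                     loss_fn='mse',
                                     n_conditions=n_conditions)
    # Since the dataset repeats, pass explicit steps:
    # model.fit(train_dataset, steps_per_epoch=train_adata.shape[0] // 128, epochs=100)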