Example #1
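The snippets below are methods extracted from CVAE-style model classes (trVAE/scNet-like wrappers), so they reference ``self`` and several package-internal helpers. A minimal set of module-level imports they rely on, assuming standalone Keras (substitute the ``tensorflow.keras`` equivalents where applicable):

import os

import numpy as np
from scipy import sparse
from keras.utils import to_categorical
from keras.callbacks import (History, LambdaCallback, EarlyStopping,
                             ReduceLROnPlateau)

# Package-internal helpers (exact module paths depend on the package layout):
# train_test_split, label_encoder, print_progress, ScoreCallback, make_dataset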
    def _train_on_batch(self,
                        adata,
                        condition_key,
                        train_size=0.8,
                        cell_type_key='cell_type',
                        n_epochs=100,
                        batch_size=128,
                        early_stop_limit=10,
                        lr_reducer=8,
                        n_per_epoch=0,
                        score_filename=None,
                        save=True,
                        retrain=True,
                        verbose=3):
        train_adata, valid_adata = train_test_split(adata, train_size)

        if self.gene_names is None:
            self.gene_names = train_adata.var_names.tolist()
        else:
            if set(self.gene_names).issubset(set(train_adata.var_names)):
                train_adata = train_adata[:, self.gene_names]
            else:
                raise ValueError(
                    "gene names in train adata are inconsistent with the class' gene_names"
                )

            if set(self.gene_names).issubset(set(valid_adata.var_names)):
                valid_adata = valid_adata[:, self.gene_names]
            else:
                raise ValueError(
                    "gene names in valid adata are inconsistent with the class' gene_names"
                )

        train_conditions_encoded, self.condition_encoder = label_encoder(
            train_adata,
            le=self.condition_encoder,
            condition_key=condition_key)

        valid_conditions_encoded, self.condition_encoder = label_encoder(
            valid_adata,
            le=self.condition_encoder,
            condition_key=condition_key)

        if not retrain and os.path.exists(
                os.path.join(self.model_path, f"{self.model_name}.h5")):
            self.restore_model_weights()
            return

        train_conditions_onehot = to_categorical(train_conditions_encoded,
                                                 num_classes=self.n_conditions)
        valid_conditions_onehot = to_categorical(valid_conditions_encoded,
                                                 num_classes=self.n_conditions)

        # Keep the training matrix in its native (possibly sparse) format and
        # densify mini-batches on the fly; densify the validation matrix once.
        is_sparse = sparse.issparse(train_adata.X)

        train_expr = train_adata.X
        valid_expr = valid_adata.X.A if is_sparse else valid_adata.X
        x_valid = [
            valid_expr, valid_conditions_onehot, valid_conditions_onehot
        ]

        if self.loss_fn in ['nb', 'zinb']:
            x_valid.append(valid_adata.obs[self.size_factor_key].values)
            y_valid = valid_adata.raw.X.A if sparse.issparse(
                valid_adata.raw.X) else valid_adata.raw.X
        else:
            y_valid = valid_expr

        es_patience, best_val_loss = 0, 1e10
        # Cap the number of optimization steps per epoch at 500.
        n_steps_per_epoch = min(500, train_adata.shape[0] // batch_size)
        for i in range(n_epochs):
            train_loss = train_recon_loss = train_kl_loss = 0.0
            for j in range(n_steps_per_epoch):
                batch_indices = np.random.choice(train_adata.shape[0],
                                                 batch_size)

                batch_expr = train_expr[
                    batch_indices, :].A if is_sparse else train_expr[
                        batch_indices, :]

                x_train = [
                    batch_expr, train_conditions_onehot[batch_indices],
                    train_conditions_onehot[batch_indices]
                ]

                if self.loss_fn in ['nb', 'zinb']:
                    x_train.append(train_adata.obs[
                        self.size_factor_key].values[batch_indices])
                    # Slice the raw counts once, then densify if needed.
                    raw_batch = train_adata.raw.X[batch_indices]
                    y_train = raw_batch.A if sparse.issparse(
                        raw_batch) else raw_batch
                else:
                    y_train = batch_expr

                batch_loss, batch_recon_loss, batch_kl_loss = self.cvae_model.train_on_batch(
                    x_train, y_train)

                # Average the per-batch mean losses over this epoch's steps so
                # the logged training losses are comparable to the validation
                # losses returned by evaluate().
                train_loss += batch_loss / n_steps_per_epoch
                train_recon_loss += batch_recon_loss / n_steps_per_epoch
                train_kl_loss += batch_kl_loss / n_steps_per_epoch

            valid_loss, valid_recon_loss, valid_kl_loss = self.cvae_model.evaluate(
                x_valid, y_valid, verbose=0)

            if valid_loss < best_val_loss:
                best_val_loss = valid_loss
                es_patience = 0
            else:
                es_patience += 1
                if es_patience == early_stop_limit:
                    print("Training stopped with Early Stopping")
                    break

            logs = {
                "loss": train_loss,
                "recon_loss": train_recon_loss,
                "kl_loss": train_kl_loss,
                "val_loss": valid_loss,
                "val_recon_loss": valid_recon_loss,
                "val_kl_loss": valid_kl_loss
            }
            print_progress(i, logs, n_epochs)

        if save:
            self.update_kwargs()
            self.save(make_dir=True)
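A minimal call-site sketch for the loop above, assuming ``model`` is an instance of the class defining ``_train_on_batch`` (the ``'condition'`` column name is illustrative, not part of the original source):

# Hypothetical usage; adata is an anndata.AnnData with a 'condition' column
# in .obs and, for 'nb'/'zinb' losses, raw counts in .raw plus size factors
# under model.size_factor_key.
model._train_on_batch(adata,
                      condition_key='condition',
                      train_size=0.8,
                      n_epochs=100,
                      batch_size=128,
                      early_stop_limit=10,
                      retrain=False,  # reuse weights from model_path if present
                      verbose=3)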
Example #2
    def _fit(self, adata, condition_key, cell_type_key,
             train_size=0.8,
             n_epochs=25, batch_size=32,
             early_stop_limit=20, lr_reducer=10,
             n_per_epoch=0, score_filename=None,
             save=True, retrain=True, verbose=3):
        """
            Trains scNet with ``n_epochs`` times given ``train_adata``
            and validates the model using ``valid_adata``
            This function is using ``early stopping`` and ``learning rate reduce on plateau``
            techniques to prevent over-fitting.

            Parameters
            ----------
            adata: :class:`~anndata.AnnData`
                Annotated dataset used to train & evaluate scNet.
            condition_key: str
                column name for conditions in the `obs` matrix of `train_adata` and `valid_adata`.
            train_size: float
                fraction of samples used to train scNet.
            n_epochs: int
                number of epochs.
            batch_size: int
                number of samples in the mini-batches used to optimize scNet.
            early_stop_limit: int
                patience of EarlyStopping
            lr_reducer: int
                patience of LearningRateReduceOnPlateau.
            save: bool
                Whether to save scNet after the training or not.
            verbose: int
                Verbose level
            retrain: bool
                ``True`` by default. if ``True`` scNet will be trained regardless of existance of pre-trained scNet in ``model_path``. if ``False`` scNet will not be trained if pre-trained scNet exists in ``model_path``.

        """
        train_adata, valid_adata = train_test_split(adata, train_size)

        if self.gene_names is None:
            self.gene_names = train_adata.var_names.tolist()
        else:
            if set(self.gene_names).issubset(set(train_adata.var_names)):
                train_adata = train_adata[:, self.gene_names]
            else:
                raise Exception("set of gene names in train adata are inconsistent with scNet's gene_names")

            if set(self.gene_names).issubset(set(valid_adata.var_names)):
                valid_adata = valid_adata[:, self.gene_names]
            else:
                raise Exception("set of gene names in valid adata are inconsistent with scNet's gene_names")

        train_conditions_encoded, self.condition_encoder = label_encoder(train_adata, le=self.condition_encoder,
                                                                         condition_key=condition_key)

        valid_conditions_encoded, self.condition_encoder = label_encoder(valid_adata, le=self.condition_encoder,
                                                                         condition_key=condition_key)

        train_cell_types_encoded, encoder = label_encoder(train_adata, le=self.cell_type_encoder,
                                                          condition_key=cell_type_key)

        if self.cell_type_encoder is None:
            self.cell_type_encoder = encoder

        valid_cell_types_encoded, self.cell_type_encoder = label_encoder(valid_adata, le=self.cell_type_encoder,
                                                                         condition_key=cell_type_key)

        if not retrain and self.restore_model_weights():
            return

        train_conditions_onehot = to_categorical(train_conditions_encoded, num_classes=self.n_conditions)
        valid_conditions_onehot = to_categorical(valid_conditions_encoded, num_classes=self.n_conditions)

        train_cell_types_onehot = to_categorical(train_cell_types_encoded, num_classes=self.n_classes)
        valid_cell_types_onehot = to_categorical(valid_cell_types_encoded, num_classes=self.n_classes)

        if self.loss_fn in ['nb', 'zinb']:
            train_raw_expr = train_adata.raw.X.A if sparse.issparse(train_adata.raw.X) else train_adata.raw.X
            valid_raw_expr = valid_adata.raw.X.A if sparse.issparse(valid_adata.raw.X) else valid_adata.raw.X

        train_expr = train_adata.X.A if sparse.issparse(train_adata.X) else train_adata.X
        valid_expr = valid_adata.X.A if sparse.issparse(valid_adata.X) else valid_adata.X

        # For count-based losses ('nb'/'zinb') the reconstruction target is the
        # raw count matrix, matching the other examples; otherwise it is the
        # expression matrix itself.
        train_recon_target = train_raw_expr if self.loss_fn in ['nb', 'zinb'] else train_expr
        valid_recon_target = valid_raw_expr if self.loss_fn in ['nb', 'zinb'] else valid_expr

        x_train = [train_expr, train_conditions_onehot, train_conditions_onehot]
        y_train = [train_recon_target, train_conditions_encoded, train_cell_types_onehot]

        x_valid = [valid_expr, valid_conditions_onehot, valid_conditions_onehot]
        y_valid = [valid_recon_target, valid_conditions_encoded, valid_cell_types_onehot]

        callbacks = [
            History(),
        ]

        if verbose > 2:
            callbacks.append(
                LambdaCallback(on_epoch_end=lambda epoch, logs: print_progress(epoch, logs, n_epochs)))
            fit_verbose = 0
        else:
            fit_verbose = verbose

        if (n_per_epoch > 0 or n_per_epoch == -1) and not score_filename:
            adata = train_adata.concatenate(valid_adata)

            celltype_labels = np.concatenate([train_cell_types_encoded, valid_cell_types_encoded], axis=0)

            callbacks.append(ScoreCallback(score_filename, adata, condition_key, cell_type_key, self.cvae_model,
                                           n_per_epoch=n_per_epoch, n_batch_labels=self.n_conditions,
                                           n_celltype_labels=len(np.unique(celltype_labels))))

        if early_stop_limit > 0:
            callbacks.append(EarlyStopping(patience=early_stop_limit, monitor='val_loss'))

        if lr_reducer > 0:
            callbacks.append(ReduceLROnPlateau(monitor='val_loss', patience=lr_reducer))

        self.cvae_model.fit(x=x_train,
                            y=y_train,
                            validation_data=(x_valid, y_valid),
                            epochs=n_epochs,
                            batch_size=batch_size,
                            verbose=fit_verbose,
                            callbacks=callbacks,
                            )
        if save:
            self.update_kwargs()
            self.save(make_dir=True)
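A minimal call-site sketch, assuming ``model`` is a scNet instance and that ``adata.obs`` carries the (illustrative) ``'study'`` and ``'cell_type'`` columns:

# Hypothetical usage of the semi-supervised variant, which also encodes
# cell-type labels for the classifier output.
model._fit(adata,
           condition_key='study',
           cell_type_key='cell_type',
           train_size=0.8,
           n_epochs=25,
           batch_size=32,
           early_stop_limit=20,
           lr_reducer=10,
           save=True,
           retrain=True,
           verbose=3)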
Example #3
    def _fit(self,
             adata,
             condition_key,
             train_size=0.8,
             cell_type_key='cell_type',
             n_epochs=100,
             batch_size=128,
             early_stop_limit=10,
             lr_reducer=8,
             n_per_epoch=0,
             score_filename=None,
             save=True,
             retrain=True,
             verbose=3):
        train_adata, valid_adata = train_test_split(adata, train_size)

        if self.gene_names is None:
            self.gene_names = train_adata.var_names.tolist()
        else:
            if set(self.gene_names).issubset(set(train_adata.var_names)):
                train_adata = train_adata[:, self.gene_names]
            else:
                raise ValueError(
                    "gene names in train adata are inconsistent with the class' gene_names"
                )

            if set(self.gene_names).issubset(set(valid_adata.var_names)):
                valid_adata = valid_adata[:, self.gene_names]
            else:
                raise ValueError(
                    "gene names in valid adata are inconsistent with the class' gene_names"
                )

        if self.loss_fn in ['nb', 'zinb']:
            train_raw_expr = train_adata.raw.X.A if sparse.issparse(
                train_adata.raw.X) else train_adata.raw.X
            valid_raw_expr = valid_adata.raw.X.A if sparse.issparse(
                valid_adata.raw.X) else valid_adata.raw.X

        train_expr = train_adata.X.A if sparse.issparse(
            train_adata.X) else train_adata.X
        valid_expr = valid_adata.X.A if sparse.issparse(
            valid_adata.X) else valid_adata.X

        train_conditions_encoded, self.condition_encoder = label_encoder(
            train_adata,
            le=self.condition_encoder,
            condition_key=condition_key)

        valid_conditions_encoded, self.condition_encoder = label_encoder(
            valid_adata,
            le=self.condition_encoder,
            condition_key=condition_key)

        if not retrain and os.path.exists(
                os.path.join(self.model_path, f"{self.model_name}.h5")):
            self.restore_model_weights()
            return

        callbacks = [
            History(),
        ]

        if verbose > 2:
            callbacks.append(
                LambdaCallback(on_epoch_end=lambda epoch, logs: print_progress(
                    epoch, logs, n_epochs)))
            fit_verbose = 0
        else:
            fit_verbose = verbose

        if (n_per_epoch > 0 or n_per_epoch == -1) and not score_filename:
            adata = train_adata.concatenate(valid_adata)

            train_celltypes_encoded, _ = label_encoder(
                train_adata, le=None, condition_key=cell_type_key)
            valid_celltypes_encoded, _ = label_encoder(
                valid_adata, le=None, condition_key=cell_type_key)
            celltype_labels = np.concatenate(
                [train_celltypes_encoded, valid_celltypes_encoded], axis=0)

            callbacks.append(
                ScoreCallback(score_filename,
                              adata,
                              condition_key,
                              cell_type_key,
                              self.cvae_model,
                              n_per_epoch=n_per_epoch,
                              n_batch_labels=self.n_conditions,
                              n_celltype_labels=len(
                                  np.unique(celltype_labels))))

        if early_stop_limit > 0:
            callbacks.append(
                EarlyStopping(patience=early_stop_limit, monitor='val_loss'))

        if lr_reducer > 0:
            callbacks.append(
                ReduceLROnPlateau(monitor='val_loss', patience=lr_reducer))

        train_conditions_onehot = to_categorical(train_conditions_encoded,
                                                 num_classes=self.n_conditions)
        valid_conditions_onehot = to_categorical(valid_conditions_encoded,
                                                 num_classes=self.n_conditions)

        x_train = [
            train_expr, train_conditions_onehot, train_conditions_onehot
        ]
        x_valid = [
            valid_expr, valid_conditions_onehot, valid_conditions_onehot
        ]

        if self.loss_fn in ['nb', 'zinb']:
            x_train.append(train_adata.obs[self.size_factor_key].values)
            y_train = train_raw_expr

            x_valid.append(valid_adata.obs[self.size_factor_key].values)
            y_valid = valid_raw_expr
        else:
            y_train = train_expr
            y_valid = valid_expr

        self.cvae_model.fit(
            x=x_train,
            y=y_train,
            validation_data=(x_valid, y_valid),
            epochs=n_epochs,
            batch_size=batch_size,
            verbose=fit_verbose,
            callbacks=callbacks,
        )
        if save:
            self.update_kwargs()
            self.save(make_dir=True)
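When ``loss_fn`` is ``'nb'`` or ``'zinb'``, this method expects raw counts in ``adata.raw`` and per-cell size factors in ``adata.obs``. A minimal preparation sketch under those assumptions (the ``'size_factors'`` column name is hypothetical and must match the model's ``size_factor_key``; mean-normalized library size is one common convention for size factors):

import numpy as np
import scanpy as sc
from scipy import sparse

# Preserve raw counts so adata.raw.X can serve as the nb/zinb reconstruction target.
adata.raw = adata

# Per-cell size factors: library size divided by the mean library size.
counts = adata.X.A if sparse.issparse(adata.X) else adata.X
library_size = np.asarray(counts.sum(axis=1)).ravel()
adata.obs['size_factors'] = library_size / np.mean(library_size)

# Normalize and log-transform the working matrix used as model input.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)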
Example #4
    def _fit_dataset(self,
                     adata,
                     condition_key,
                     train_size=0.8,
                     n_epochs=100,
                     batch_size=128,
                     steps_per_epoch=100,
                     early_stop_limit=10,
                     lr_reducer=8,
                     save=True,
                     retrain=True,
                     verbose=3):
        train_adata, valid_adata = train_test_split(adata, train_size)

        if self.gene_names is None:
            self.gene_names = train_adata.var_names.tolist()
        else:
            if set(self.gene_names).issubset(set(train_adata.var_names)):
                train_adata = train_adata[:, self.gene_names]
            else:
                raise ValueError(
                    "gene names in train adata are inconsistent with the class' gene_names"
                )

            if set(self.gene_names).issubset(set(valid_adata.var_names)):
                valid_adata = valid_adata[:, self.gene_names]
            else:
                raise ValueError(
                    "gene names in valid adata are inconsistent with the class' gene_names"
                )

        if not retrain and os.path.exists(
                os.path.join(self.model_path, f"{self.model_name}.h5")):
            self.restore_model_weights()
            self.restore_class_config(compile_and_consturct=False)
            return

        callbacks = [
            History(),
        ]

        if verbose > 2:
            callbacks.append(
                LambdaCallback(on_epoch_end=lambda epoch, logs: print_progress(
                    epoch, logs, n_epochs)))
            fit_verbose = 0
        else:
            fit_verbose = verbose

        if early_stop_limit > 0:
            callbacks.append(
                EarlyStopping(patience=early_stop_limit, monitor='val_loss'))

        if lr_reducer > 0:
            callbacks.append(
                ReduceLROnPlateau(monitor='val_loss', patience=lr_reducer))

        train_dataset, self.condition_encoder = make_dataset(
            train_adata,
            condition_key,
            self.condition_encoder,
            batch_size,
            n_epochs,
            is_training=True,
            loss_fn=self.loss_fn,
            n_conditions=self.n_conditions)
        valid_dataset, _ = make_dataset(valid_adata,
                                        condition_key,
                                        self.condition_encoder,
                                        valid_adata.shape[0],
                                        n_epochs,
                                        is_training=False,
                                        loss_fn=self.loss_fn,
                                        n_conditions=self.n_conditions)

        self.log_history = self.fit(
            train_dataset,
            validation_data=valid_dataset,
            epochs=n_epochs,
            # batch_size is intentionally omitted: the datasets produced by
            # make_dataset are already batched, and Keras rejects batch_size
            # when fitting on dataset inputs.
            verbose=fit_verbose,
            callbacks=callbacks,
            steps_per_epoch=steps_per_epoch,
            validation_steps=1,
        )

        if save:
            self.update_kwargs()
            self.save(make_dir=True)
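A minimal call-site sketch for the dataset-based variant, assuming ``model`` is an instance of the class defining ``_fit_dataset`` (the ``'condition'`` column name is illustrative):

# Hypothetical usage; make_dataset handles batching and condition encoding
# internally, so batch_size and n_epochs are forwarded to it rather than
# passed to Keras directly.
model._fit_dataset(adata,
                   condition_key='condition',
                   train_size=0.8,
                   n_epochs=100,
                   batch_size=128,
                   steps_per_epoch=100,
                   early_stop_limit=10,
                   lr_reducer=8,
                   save=True,
                   retrain=True,
                   verbose=3)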