A neural network approach that jointly learns a survival model, which predicts time-to-event outcomes, and a topic model, which captures how features relate. We tested this approach on seven healthcare datasets.
In addition to our proposed approach, the survival models listed below are used as baselines. Each model is linked to its implementation script. Models may take in data in different formats, as documented in the Data Format column; the next section explains each data format in detail.
Model | Description | Type | Data Format |
---|---|---|---|
coxph | Cox regression with lasso regularization | baseline | cox |
coxph_pca | Lasso-regularized Cox regression preceded by PCA | baseline | original |
coxph_unreg | Unregularized Cox regression | baseline | cox |
knnkm | KNN-Kaplan-Meier | baseline | original |
knnkm_pca | KNN-Kaplan-Meier preceded by PCA | baseline | original |
weibull | Weibull regression | baseline | original |
weibull_pca | Weibull regression preceded by PCA | baseline | original |
rsf | Random survival forest | baseline | original |
deepsurv | DeepSurv | baseline | original |
deephit | DeepHit | baseline | original |
lda_cox | Cox regression preceded by LDA | topic | discretize |
survscholar_linear | Supervised Cox regression preceded by scholar | topic | discretize |
survscholar_nonlinear | Supervised Cox regression preceded by scholar, with nonlinear survival layers | topic | discretize |
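The topic models above consume data in the `discretize` format, i.e., continuous features are binned so they can be treated as "words". As a minimal sketch of what such a conversion could look like (assuming equal-width binning; the repo's actual `discretize` preprocessing may bin differently, e.g., by quantiles):

```python
import numpy as np

def discretize_features(X, n_bins=5):
    """Map each continuous feature to a one-hot "word" per bin.

    Hypothetical sketch: equal-width bins per column. The repo's
    actual discretize preprocessing may differ.
    """
    n_samples, n_features = X.shape
    out = np.zeros((n_samples, n_features * n_bins), dtype=int)
    for j in range(n_features):
        col = X[:, j]
        # interior bin edges for equal-width binning
        edges = np.linspace(col.min(), col.max(), n_bins + 1)[1:-1]
        bins = np.digitize(col, edges)  # bin index in 0..n_bins-1
        out[np.arange(n_samples), j * n_bins + bins] = 1
    return out

X = np.array([[0.1, 10.0], [0.5, 20.0], [0.9, 30.0]])
print(discretize_features(X, n_bins=3).shape)  # (3, 6)
```

Each row then contains exactly one active "word" per original feature, which is the bag-of-words-style input a topic model expects.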
We performed random hyperparameter search for all models within the search spaces specified in this table; note that the search spaces vary across datasets. For each model, we selected the hyperparameter setting that achieved the best cross-validation time-dependent concordance index, as described in the paper.
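The selection procedure above can be sketched as follows. Note that `cv_score_fn` is a hypothetical stand-in for cross-validated training that returns, e.g., a time-dependent c-index; the repo's actual search lives in its own scripts:

```python
import random

def random_search(search_space, cv_score_fn, n_probes=5, seed=0):
    """Random hyperparameter search: sample n_probes configurations
    from the given ranges, keep the one with the best CV score.

    search_space maps each hyperparameter name to a (low, high) range.
    cv_score_fn is a hypothetical callback standing in for
    cross-validated training and evaluation.
    """
    rng = random.Random(seed)
    best_config, best_score = None, float("-inf")
    for _ in range(n_probes):
        config = {name: rng.uniform(lo, hi)
                  for name, (lo, hi) in search_space.items()}
        score = cv_score_fn(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score

# toy scoring function standing in for cross-validated training
space = {"learning_rate": (1e-4, 1e-1), "dropout": (0.0, 0.5)}
config, score = random_search(space, lambda c: -abs(c["dropout"] - 0.2))
print(config)
```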
In the paper, we described how we selected the number of topics for each dataset by examining training cross-validation c-index as a function of the number of topics; refer to this figure for the corresponding plots for all datasets.
A list of supported datasets, data preprocessing scripts and details are documented on this page.
Topics learned by our proposed approach are either visualized in heatmaps or listed as top words per topic, by convention. Below are examples on the SUPPORT-3 dataset. For all other datasets' outputs, go here. Refer to the paper for how these outputs should be interpreted.
Follow this section to replicate results in the paper.
Package requirements can be found here. You can set up the required environment from the command line:
$ python3 -m venv env_npsurvival
$ source env_npsurvival/bin/activate
$ pip3 install -r Survival2019/requirements.txt
To run an experiment:
- `git clone` this repo to a local directory.
- Make sure all required packages are installed (see section Required packages).
- `cd` into the repo directory, and replace the `dataset/` folder with one that actually contains the data. Data is omitted in this repo because some of our datasets require applying for access.
- Make sure hyperparameter search boxes are configured in a `json` file under `configurations/`. You can find plenty of examples here.
- Modify experiment settings in the bash script `run_experiments.sh`, and type `sh run_experiments.sh` in the command line. This will kick off the experiment. Be sure to name experiments properly using the `experiment_id` option, and note that rerunning with the same `experiment_id` will erase saved outputs from the previous experiment with that `experiment_id`.
Follow this demo to see how experiments are configured.
- For all experiments, the hyperparameter search configuration should be specified using a `json` file under `configurations/`. The `json` file's naming convention follows `dataset-model-suffix_identifier.json`. For this demo, we use `METABRIC-survscholar_linear-demo.json`.
{"params": {"n_topics": [1, 10], "survival_loss_weight": [0, 5], "batch_size": [32, 1024]},
"random": {"n_probes": 5}}
With this configuration:
- We search over three hyperparameters: `n_topics`, `survival_loss_weight`, and `batch_size`.
- For `n_topics`, we search over the range [1, 10].
- For `survival_loss_weight`, we search over the range [10^0, 10^5]. The exponentiation is done within the model's implementation: if the model takes in `survival_loss_weight` as 5, it converts it into 10^5. This is one example of why the user should always check the code to understand how configurations are interpreted.
- For `batch_size`, we search over the range [32, 1024].
- We use random sweeping, with only 5 random attempts (we only try 5 different hyperparameter combinations within the specified ranges).
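To make the sampling behavior concrete, here is a minimal sketch of how the demo `json` could drive 5 random probes, including the exponentiation applied to `survival_loss_weight`. This is an illustration only: integer sampling and the `effective_survival_loss_weight` name are assumptions, and the actual sampling logic in `experiments.py` may differ (e.g., floats or log-scale sampling):

```python
import json
import random

config_text = '''{"params": {"n_topics": [1, 10],
                             "survival_loss_weight": [0, 5],
                             "batch_size": [32, 1024]},
                  "random": {"n_probes": 5}}'''

cfg = json.loads(config_text)
rng = random.Random(0)

probes = []
for _ in range(cfg["random"]["n_probes"]):
    # sample each hyperparameter uniformly from its [low, high] range
    probe = {name: rng.randint(lo, hi)
             for name, (lo, hi) in cfg["params"].items()}
    # as noted above, the model exponentiates survival_loss_weight
    # itself: a sampled value of 5 means an effective weight of 10**5
    probe["effective_survival_loss_weight"] = 10 ** probe["survival_loss_weight"]
    probes.append(probe)

print(len(probes))  # 5 probes, matching n_probes
```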
- Modify settings in the bash script `run_experiments.sh` to specify which dataset and model to use, name the experiment, and indicate whether a previously trained model should be loaded, etc. Details are documented in the bash script. For this demo:
dataset=METABRIC
model=survscholar_linear
n_outer_iter=5
tuning_scheme=random
tuning_config=demo # this will locate the configuration json file to be METABRIC-survscholar_linear-demo.json
log_dir=logs # directory where experiment outputs are saved
experiment_id=bootstrap_predictions_demo_explain
saved_experiment_id=None
saved_model_path=saved_models
readme_msg=EnterAnyMessageHere
preset_dir=None
mkdir -p ${log_dir}/${dataset}/${model}/${experiment_id}/${saved_model_path}
python3 experiments.py ${dataset} ${model} ${n_outer_iter} ${tuning_scheme} ${tuning_config} ${experiment_id} ${saved_experiment_id} ${readme_msg} ${preset_dir} --log_dir ${log_dir}
Experiment outputs will be saved to `${log_dir}/${dataset}/${model}/${experiment_id}/`. For this demo, this evaluates to `logs/METABRIC/survscholar_linear/bootstrap_predictions_demo_explain/`.
As documented in the experiment transcript here, using only 5 random hyperparameter combinations, we get a mean bootstrapped time-dependent c-index of 0.66058302 (95% confidence interval [0.62127882, 0.70199634]) on the test set.
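A bootstrapped mean and percentile confidence interval of this kind can be sketched as follows. `scores_fn` is a hypothetical callback standing in for evaluating the time-dependent c-index on a resampled test set; the repo's actual bootstrap procedure may differ in details (number of resamples, CI method):

```python
import numpy as np

def bootstrap_ci(scores_fn, n_samples, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap over test-set indices.

    scores_fn(idx) is a hypothetical callback returning the metric
    (e.g., time-dependent c-index) on the resampled test subset idx.
    """
    rng = np.random.default_rng(seed)
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, n_samples, size=n_samples)  # resample with replacement
        stats.append(scores_fn(idx))
    stats = np.array(stats)
    lo, hi = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return stats.mean(), (lo, hi)

# toy metric: mean of a fixed per-sample score array, resampled
per_sample = np.linspace(0.5, 0.9, 100)
mean, (lo, hi) = bootstrap_ci(lambda idx: per_sample[idx].mean(),
                              len(per_sample))
print(round(mean, 3), (round(lo, 3), round(hi, 3)))
```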
For SurvScholar, this notebook demonstrates how to obtain the all-topic heatmaps, single-topic heatmaps, and per-topic top-words printouts. Running the visualization requires specifying a directory that contains the saved model outputs, which is usually `${log_dir}/${dataset}/${model}/${experiment_id}/`. In the notebook, we used an experiment on the SUPPORT_Cancer dataset, whose `experiment_id` is `bootstrap_predictions_3_explain`.