Ejemplo n.º 1
0
                        out_survival = utils.predict_survival_exponient(
                            model, X.to(device), times)
                    elif arg.distribution == "weibull":
                        out_survival = utils.predict_survival_weibull(
                            model, X.to(device), times)
                    elif arg.distribution == "combine":
                        out_survival = utils.predict_survival_multiple_distributions(
                            model, X.to(device), times)
                    surv = pd.DataFrame(out_survival, index=times)
                    durations_test, events_test = T.detach().numpy(), E.detach(
                    ).numpy()
                    ev = EvalSurv(surv,
                                  durations_test,
                                  events_test,
                                  censor_surv='km')
                    c_index_train = ev.concordance_td()

            ##### valid
            model.eval()
            batch_size = 8000
            with torch.no_grad():
                for X, T, E in dataloader.load_data(arg.dataname,
                                                    arg.path_clinical_val,
                                                    'test'):
                    T = T / ratio

                    #pdf = misc.cal_pdf(T)

                    if arg.distribution == "lognormal":
                        out_survival = utils.predict_survival_lognormal(
                            model, X.to(device), times)
Ejemplo n.º 2
0
                            surv_model.interpolate(10).predict_surv_df(
                                fold_X_val_std)
                    else:
                        surv_df = surv_model.predict_surv_df(fold_X_val_std)
                    ev = EvalSurv(surv_df,
                                  fold_y_val[:, 0],
                                  fold_y_val[:, 1],
                                  censor_surv='km')

                    sorted_fold_y_val = np.sort(np.unique(fold_y_val[:, 0]))
                    time_grid = np.linspace(sorted_fold_y_val[0],
                                            sorted_fold_y_val[-1], 100)

                    surv = surv_df.to_numpy().T

                    cindex_scores.append(ev.concordance_td('antolini'))
                    integrated_brier_scores.append(
                        ev.integrated_brier_score(time_grid))
                    print('  c-index (td):', cindex_scores[-1])
                    print('  Integrated Brier score:',
                          integrated_brier_scores[-1])

                cross_val_cindex = np.mean(cindex_scores)
                cross_val_integrated_brier = np.mean(integrated_brier_scores)
                print(hyperparam,
                      ':',
                      cross_val_cindex,
                      cross_val_integrated_brier,
                      flush=True)
                print(hyperparam,
                      ':',
Ejemplo n.º 3
0
    def _survLDA_train_variational_EM(self, word_count_matrix,
                                      survival_or_censoring_times,
                                      censoring_indicator_variables,
                                      val_word_count_matrix,
                                      val_survival_or_censoring_times,
                                      val_censoring_indicator_variables):
        """
        Fit the survLDA model with variational EM, evaluating on validation
        data after every iteration.

        Input:
        - word_count_matrix: 2D numpy array (num subjects x vocabulary size)
            of training word counts
        - survival_or_censoring_times: 1D numpy array of training
            survival/censoring times
        - censoring_indicator_variables: 1D numpy array of training event
            indicators
        - val_word_count_matrix, val_survival_or_censoring_times,
            val_censoring_indicator_variables: the same three quantities for
            the validation set

        Quantities fitted and stored on self (the method itself returns None):
        - tau: 2D numpy array; g-th row is for g-th topic and consists of word
            distribution for that topic
        - beta: 1D numpy array with length equal to the number of topics
            (Cox regression coefficients)
        - log_h0_discretized: 1D numpy array; discretized log baseline hazard
            aligned with sorted_unique_times (sorted unique training times)
        - gamma: 2D numpy array: i-th row is variational parameter for Dirichlet
            distribution for i-th subject (length is number of topics)
        - phi: list of length given by number of subjects; i-th element is a
            2D numpy array with number of rows given by the number of words in
            i-th subject's document and number of columns given by number of
            topics (variational parameter for how much each word of each subject
            belongs to different topics)

        Per-iteration validation metrics (median-survival-time RMSE and MAE,
        and time-dependent concordance index) are printed when self.verbose.
        """
        # *********************************************************************
        # 1. SET UP TOPIC MODEL PARAMETERS BASED ON TRAINING DATA

        self.word_count_matrix = word_count_matrix
        num_subjects, num_words = self.word_count_matrix.shape  # train size and vocabulary size
        # the length of each patient's document (i.e. sum of all words including duplicates)
        doc_length = self.word_count_matrix.sum(axis=1).astype(np.int64)

        # tau tells us what the word distribution is per topic
        # - each row corresponds to a topic
        # - each row is a probability distribution, initialized randomly or
        #   uniformly (each entry 1 / num_words)
        if self.random_init_tau:
            self.tau = np.random.rand(self.n_topics, num_words)
            self.tau /= self.tau.sum(axis=1)[:, np.newaxis]  # normalize tau
        else:
            self.tau = np.full((self.n_topics, num_words),
                               1. / num_words,
                               dtype=np.float64)

        # variational distribution parameter gamma tells us the Dirichlet
        # distribution parameter for the topic distribution specific to each
        # subject; all ones corresponds to a uniform prior
        self.gamma = np.ones((num_subjects, self.n_topics), dtype=np.float64)

        # variational distribution parameter phi tells us the probabilities of
        # each subject's words coming from each of the K different topics;
        # initialized uniform over topics (1/K)
        self.W, self.phi = self._word_count_matrix_to_word_vectors_phi(
            word_count_matrix)
        val_W, _ = self._word_count_matrix_to_word_vectors_phi(
            val_word_count_matrix)

        # *********************************************************************
        # 2. SET UP COX BASELINE HAZARD FUNCTION

        # Cox baseline hazard function h0 can be represented as a finite 1D
        # vector over the sorted unique training times;
        # death_counter[t] = number of subjects whose time equals t
        death_counter = {}
        for t in survival_or_censoring_times:
            if t not in death_counter:
                death_counter[t] = 1
            else:
                death_counter[t] += 1
        sorted_unique_times = np.sort(list(death_counter.keys()))
        self.sorted_unique_times = sorted_unique_times

        num_unique_times = len(sorted_unique_times)
        self.log_h0_discretized = np.zeros(num_unique_times)
        for r, t in enumerate(sorted_unique_times):
            self.log_h0_discretized[r] = np.log(death_counter[t])

        # log cumulative baseline hazard: prefix log-sum-exp of log(h0)
        log_H0 = []
        for r in range(num_unique_times):
            log_H0.append(logsumexp(self.log_h0_discretized[:(r + 1)]))
        log_H0 = np.array(log_H0)

        # time_order[i] = position of subject i's time in sorted_unique_times
        time_map = {t: r for r, t in enumerate(sorted_unique_times)}
        time_order = np.array(
            [time_map[t] for t in survival_or_censoring_times], dtype=np.int64)

        # *********************************************************************
        # 3. EM MAIN LOOP

        # NOTE(review): the pool is only closed on normal exit; an exception
        # inside the loop leaks the worker processes.
        pool = multiprocessing.Pool(4)
        stop_iter = 0
        # fix: np.bool was removed from NumPy (1.24+); the builtin bool is the
        # correct dtype for this boolean mask
        val_censoring_indicator_variables = val_censoring_indicator_variables.astype(
            bool)
        for EM_iter_idx in range(self.max_iter):
            # ------------------------------------------------------------------
            # E-step (update gamma, phi; uses helper variables psi, xi)
            if self.verbose:
                print('[Variational EM iteration %d]' % (EM_iter_idx + 1))
                print('  Running E-step...', end='', flush=True)
            tic = time()

            # update gamma (dimensions: `num_subjects` by `K`)
            for i, phi_i in enumerate(self.phi):
                self.gamma[i] = self.alpha + phi_i.sum(axis=0)

            # compute psi (dimensions: `num_subjects` by `K`)
            psi = digamma(self.gamma) - digamma(
                self.gamma.sum(axis=1))[:, np.newaxis]
            # update phi in parallel; the helper returns rows that are
            # already normalized
            self.phi = pool.map(SurvLDA._compute_updated_phi_i_pmap_helper,
                                [(i, self.phi[i], psi[i], self.W[i], self.tau,
                                  self.beta, np.exp(log_H0[time_order[i]]),
                                  censoring_indicator_variables[i],
                                  doc_length[i], self.n_topics)
                                 for i in range(num_subjects)])

            toc = time()
            if self.verbose:
                print(' Done. Time elapsed: %f second(s).' % (toc - tic),
                      flush=True)

            # --------------------------------------------------------------------
            # M-step (update tau, beta, h0, H0; uses helper variable phi_bar)
            if self.verbose:
                print('  Running M-step...', end='', flush=True)
            tic = time()

            # update tau: accumulate expected word-topic counts, then normalize
            tau = np.zeros((self.n_topics, num_words), dtype=np.float64)
            for i in range(num_subjects):
                for j in range(doc_length[i]):
                    word = self.W[i][j]
                    # add subject i / word j topic responsibilities to that
                    # vocabulary entry (vectorized over topics)
                    tau[:, word] += self.phi[i][j]
            # fix: assign the normalized *new* counts; the previous code
            # divided the old self.tau by the new row sums
            self.tau = tau / tau.sum(axis=1)[:, np.newaxis]

            # phi_bar[i] = average topic responsibility over subject i's words
            phi_bar = np.zeros((num_subjects, self.n_topics), dtype=np.float64)
            for i in range(num_subjects):
                phi_bar[i] = self.phi[i].sum(axis=0) / doc_length[i]

            # provisional baseline hazard under the current beta (log_H0_ is
            # used inside the beta objective below)
            beta_phi_bar = np.dot(phi_bar, self.beta)
            log_h0_discretized_ = np.zeros(num_unique_times)
            log_H0_ = []
            for r, t in enumerate(sorted_unique_times):
                # risk set: subjects still at risk at time t
                R_t = np.where(survival_or_censoring_times >= t, True, False)
                log_h0_discretized_[r] = np.log(death_counter[t]) - logsumexp(
                    beta_phi_bar[R_t])
                log_H0_.append(logsumexp(log_h0_discretized_[:(r + 1)]))

            # update beta by minimizing the negative expected log-likelihood
            def obj_fun(beta_):
                beta_phi_bar_ = np.dot(phi_bar, beta_)
                fun_val = np.dot(censoring_indicator_variables, beta_phi_bar_)
                # finally, we add in the third term
                for i in range(num_subjects):
                    product_terms = np.dot(self.phi[i],
                                           np.exp(beta_ / doc_length[i]))
                    fun_val -= np.exp(log_H0_[time_order[i]] +
                                      np.sum(np.log(product_terms)))
                return -fun_val

            self.beta = minimize(obj_fun, self.beta).x

            # compute dot product <beta, phi_bar[i]> for each patient i
            # (resulting in a 1D array of length `num_subjects`)
            beta_phi_bar = np.dot(phi_bar, self.beta)

            # update h0 (Breslow-style estimate using the new beta)
            for r, t in enumerate(sorted_unique_times):
                R_t = np.where(survival_or_censoring_times >= t, True, False)
                self.log_h0_discretized[r] = np.log(
                    death_counter[t]) - logsumexp(beta_phi_bar[R_t])
            log_H0 = []
            for r in range(num_unique_times):
                log_H0.append(logsumexp(self.log_h0_discretized[:(r + 1)]))
            log_H0 = np.array(log_H0)

            toc = time()
            if self.verbose:
                print(' Done. Time elapsed: %f second(s).' % (toc - tic),
                      flush=True)
                print('  Beta:', self.beta)

            # per-iteration validation metrics.
            # NOTE(review): despite the original "convergence criteria"
            # comment, no early-stopping break is implemented — the loop
            # always runs for max_iter iterations.
            surv_estimates = []
            median_estimates = []
            for i in range(len(val_W)):
                preds = self._predict_survival(val_W[i])
                surv_estimates.append(preds[1])
                median_estimates.append(preds[0])
            surv_estimates = np.array(surv_estimates)
            median_estimates = np.array(median_estimates)

            val_ev = EvalSurv(pd.DataFrame(np.transpose(surv_estimates),
                                           index=sorted_unique_times),
                              val_survival_or_censoring_times,
                              val_censoring_indicator_variables,
                              censor_surv='km')
            val_cindex = val_ev.concordance_td('antolini')

            # RMSE/MAE of estimated median survival times, over
            # event-observed validation subjects only
            val_rmse = np.sqrt(np.mean((median_estimates[val_censoring_indicator_variables] - \
                            val_survival_or_censoring_times[val_censoring_indicator_variables])**2))
            val_mae = np.mean(np.abs(median_estimates[val_censoring_indicator_variables] - \
                            val_survival_or_censoring_times[val_censoring_indicator_variables]))

            if self.verbose:
                print('  Validation set survival time c-index: %f' %
                      val_cindex)
                print('  Validation set survival time rmse: %f' % val_rmse)
                print('  Validation set survival time mae: %f' % val_mae)

            stop_iter += 1

        pool.close()
        if self.verbose:
            print('  EM algorithm completes training at iteration ', stop_iter)

        return
Ejemplo n.º 4
0
            verbose = True
            log = model.fit(x_train,
                            y_train,
                            batch_size,
                            epochs,
                            callbacks,
                            verbose,
                            val_data=val.repeat(10).cat())
            # Evaluation ===================================================================

            _ = model.compute_baseline_hazards()
            surv = model.predict_surv_df(x_test)
            ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

            # ctd = ev.concordance_td()
            ctd = ev.concordance_td()
            time_grid = np.linspace(durations_test.min(), durations_test.max(),
                                    100)

            ibs = ev.integrated_brier_score(time_grid)
            nbll = ev.integrated_nbll(time_grid)
            val_loss = min(log.monitors['val_'].scores['loss']['score'])

            wandb.log({
                'val_loss': val_loss,
                'ctd': ctd,
                'ibs': ibs,
                'nbll': nbll
            })
            wandb.finish()
            fold_ctd.append(ctd)
Ejemplo n.º 5
0
    y_pred_train_surv = np.cumprod((1 - ypred_train_NN), axis=1)
    oneyr_surv_train = y_pred_train_surv[:, 50]
    oneyr_surv_valid = y_pred_valid_surv[:, 50]
    surv_valid = pd.DataFrame(np.transpose(y_pred_valid_surv))
    surv_valid.index = interval_l
    surv_train = pd.DataFrame(np.transpose(y_pred_train_surv))
    surv_train.index = interval_l
    dict_cv_cindex_train[key] = concordance_index(dataTrain.time,
                                                  oneyr_surv_train)
    dict_cv_cindex_valid[key] = concordance_index(dataValid.time,
                                                  oneyr_surv_valid)
    ev_valid = EvalSurv(surv_valid,
                        dataValid['time'].values,
                        dataValid['dead'].values,
                        censor_surv='km')
    scores_test += ev_valid.concordance_td()
    ev_train = EvalSurv(surv_train,
                        dataTrain['time'].values,
                        dataTrain['dead'].values,
                        censor_surv='km')
    scores_train += ev_train.concordance_td()
    cta.append(concordance_index(dataTrain.time, oneyr_surv_train))
    cte.append(concordance_index(dataValid.time, oneyr_surv_valid))
    ctda.append(ev_train.concordance_td())
    ctde.append(ev_valid.concordance_td())
    #scores_train += concordance_index(dataTrain.time,oneyr_surv_train)
    #scores_test += concordance_index(dataValid.time,oneyr_surv_valid)

# Output filename for the cross-validation loss plot.
# NOTE(review): `h` is defined earlier in this script (outside this view) —
# presumably the current hyperparameter setting; verify.
save_loss_png = "loss_cv_" + str(h) + "_KIRC_vK.png"
# NOTE(review): mid-file import; seaborn (sns) is not used in the visible
# code below — verify it is needed.
import seaborn as sns
# Base font size used for plotting.
fs = 20
def fold_split(n, fold):
    """Deterministically split range(n) for 5-fold cross-validation.

    Shuffles np.arange(n) with a fixed seed (0) and returns position arrays
    (train_idx, valid_idx, test_idx) for the given fold; the validation set
    is the first 10% of the shuffled training portion, and fold 4 takes the
    remainder as its test set.
    """
    index = np.arange(n)
    fold_size = n // 5
    np.random.seed(seed=0)
    np.random.shuffle(index)
    if fold < 4:
        test_idx = index[fold * fold_size:(fold + 1) * fold_size]
        # fix: the original used `+` on two numpy arrays here, which is
        # elementwise addition (a shape error for unequal lengths), not
        # concatenation
        train_idx = np.concatenate((index[:fold * fold_size],
                                    index[(fold + 1) * fold_size:]))
    else:
        test_idx = index[fold * fold_size:]
        train_idx = index[:fold * fold_size]
    valid_n = len(train_idx) // 10
    return train_idx[valid_n:], train_idx[:valid_n], test_idx


def main(data_root, cancer_type, anatomical_location, fold):
    """Train a CoxPH survival model on multi-omics features for one CV fold.

    Loads cached PPI-network artifacts and per-patient omics files, builds
    one feature vector per patient, performs a censoring-stratified 5-fold
    split, trains a CoxPH model, and prints the test-set concordance index.

    Parameters:
    - data_root: root directory of the data files.
      FIXME(review): this argument is currently shadowed by a hard-coded
      path below — confirm which one is intended.
    - cancer_type: integer index into the 33-dim one-hot cancer-type vector
    - anatomical_location: integer index into the 52-dim one-hot
      anatomical-location vector
    - fold: cross-validation fold number in [0, 4]
    """
    # Import the RDF graph for PPI network
    with open('seen.pkl', 'rb') as f:
        seen = pickle.load(f)
    #####################

    with open('ei.pkl', 'rb') as f:
        ei = pickle.load(f)

    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')

    # One-hot context vectors describing the tumour being modelled
    cancer_type_vector = np.zeros((33, ), dtype=np.float32)
    cancer_type_vector[cancer_type] = 1

    cancer_subtype_vector = np.zeros((25, ), dtype=np.float32)
    for subtype in CANCER_SUBTYPES[cancer_type]:
        cancer_subtype_vector[subtype] = 1

    anatomical_location_vector = np.zeros((52, ), dtype=np.float32)
    anatomical_location_vector[anatomical_location] = 1
    cell_type_vector = np.zeros((10, ), dtype=np.float32)
    cell_type_vector[CELL_TYPES[cancer_type]] = 1

    pt_tensor_cancer_type = torch.FloatTensor(cancer_type_vector).to(device)
    pt_tensor_cancer_subtype = torch.FloatTensor(cancer_subtype_vector).to(
        device)
    pt_tensor_anatomical_location = torch.FloatTensor(
        anatomical_location_vector).to(device)
    pt_tensor_cell_type = torch.FloatTensor(cell_type_vector).to(device)
    edge_index = torch.LongTensor(ei).to(device)

    # Dictionary mapping proteins to their corresponding genes (Ensembl);
    # dic: gene -> {protein: 1}
    with open('ens_dic.pkl', 'rb') as f:
        dicty = pickle.load(f)
    dic = {}
    for prot in dicty:
        gene = dicty[prot]
        if gene not in dic:
            dic[gene] = {}
        dic[gene][prot] = 1

    # Build a dictionary from ENSG -- ENST
    d = {}
    with open('data1/prot_names1.txt') as f:
        for line in f:
            tok = line.split()
            d[tok[1]] = tok[0]

    clin = [
    ]  # clinical data (days to death for dead patients, days to last followup for alive patients)
    suv_time = [
    ]  # whether a patient is alive or dead (0 for dead and 1 for alive)
    can_types = ["BRCA_v2"]
    # FIXME(review): the data_root parameter is ignored from here on
    data_root = '/ibex/scratch/projects/c2014/sara/'
    feat_vecs = None
    for ct in range(len(can_types)):
        # index file: patient IDs with their corresponding 6 per-modality
        # file names (gene expression, diff expression, methylation,
        # diff methylation, VCF and CNV)
        with open(data_root + can_types[ct] + '.txt') as f:
            lines = f.read().splitlines()
        lines = lines[1:]  # drop header row
        # one row per patient: 6 modalities x 17186 genes
        feat_vecs = np.zeros((len(lines), 17186 * 6), dtype=np.float32)
        # fix: the original reused `i` both as the outer loop variable and
        # as the row counter; enumerate avoids the clash
        for row, l in enumerate(tqdm(lines)):
            l = l.split('\t')
            clinical_file = l[6]
            surv_file = l[2]
            myth_file = 'myth/' + l[3]
            diff_myth_file = 'diff_myth/' + l[1]
            exp_norm_file = 'exp_count/' + l[-1]
            diff_exp_norm_file = 'diff_exp/' + l[0]
            cnv_file = 'cnv/' + l[4] + '.txt'
            vcf_file = 'vcf/' + 'OutputAnnoFile_' + l[
                5] + '.hg38_multianno.txt.dat'
            # Abort unless all 6 files exist for the patient (some patients
            # have no reported survival time)
            all_files = [
                myth_file, diff_exp_norm_file, diff_myth_file, exp_norm_file,
                cnv_file, vcf_file
            ]
            for fname in all_files:
                if not os.path.exists(fname):
                    print('File ' + fname + ' does not exist!')
                    sys.exit(1)
            # NOTE(review): the clinical/survival values are taken directly
            # from the index-file columns; earlier revisions read them from
            # the per-patient files instead — confirm the columns are numeric
            clin.append(clinical_file)
            suv_time.append(surv_file)
            temp_myth = myth_data(myth_file, seen, d, dic)
            vec = np.array(get_data(exp_norm_file, diff_exp_norm_file,
                                    diff_myth_file, cnv_file, vcf_file,
                                    temp_myth, seen, dic),
                           dtype=np.float32)
            feat_vecs[row, :] = vec.flatten()

    min_max_scaler = MinMaxScaler(clip=True)
    labels_days = []
    labels_surv = []
    for days, surv in zip(clin, suv_time):
        labels_days.append(float(days))
        labels_surv.append(float(surv))

    # Train by batch
    dataset = feat_vecs
    labels_days = np.array(labels_days)
    labels_surv = np.array(labels_surv)

    # Stratify patients into censored (alive, label 1) and uncensored (dead)
    censored_index = []
    uncensored_index = []
    for i in range(len(dataset)):
        if labels_surv[i] == 1:
            censored_index.append(i)
        else:
            uncensored_index.append(i)
    model = CoxPH(MyNet(edge_index).to(device), tt.optim.Adam(0.0001))

    censored_index = np.array(censored_index)
    uncensored_index = np.array(uncensored_index)

    # 5-fold split done independently within each stratum so the censoring
    # rate stays similar across train/valid/test
    ctrain_idx, cvalid_idx, ctest_idx = fold_split(len(censored_index), fold)
    utrain_idx, uvalid_idx, utest_idx = fold_split(len(uncensored_index), fold)

    # fix: np.concatenate takes a sequence of arrays; the original passed the
    # second array as the `axis` argument (TypeError)
    train_idx = np.concatenate((censored_index[ctrain_idx],
                                uncensored_index[utrain_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(train_idx)
    valid_idx = np.concatenate((censored_index[cvalid_idx],
                                uncensored_index[uvalid_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(valid_idx)
    test_idx = np.concatenate((censored_index[ctest_idx],
                               uncensored_index[utest_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(test_idx)

    # Scale features using statistics fitted on the training split only
    train_data = min_max_scaler.fit_transform(dataset[train_idx])
    train_labels = (labels_days[train_idx], labels_surv[train_idx])

    val_data = min_max_scaler.transform(dataset[valid_idx])
    val_labels = (labels_days[valid_idx], labels_surv[valid_idx])
    test_data = min_max_scaler.transform(dataset[test_idx])
    test_labels_days = labels_days[test_idx]
    test_labels_surv = labels_surv[test_idx]
    print(val_labels)

    callbacks = [tt.callbacks.EarlyStopping()]
    batch_size = 16
    epochs = 100
    val = (val_data, val_labels)
    log = model.fit(train_data,
                    train_labels,
                    batch_size,
                    epochs,
                    callbacks,
                    True,
                    val_data=val,
                    val_batch_size=batch_size)
    log.plot()
    plt.show()
    # print(model.partial_log_likelihood(*val).mean())
    train = train_data, train_labels
    # Compute the evaluation measurements
    model.compute_baseline_hazards(*train)
    surv = model.predict_surv_df(test_data)
    print(surv)
    # NOTE(review): no censor_surv='km' here, unlike the other EvalSurv
    # call sites in this project — confirm whether the Kaplan-Meier
    # censoring adjustment is intended
    ev = EvalSurv(surv, test_labels_days, test_labels_surv)
    print(ev.concordance_td())
Ejemplo n.º 7
0
                       show_progress=False,
                       step_size=.1)
        elapsed = time.time() - tic
        print('Time elapsed: %f second(s)' % elapsed)
        np.savetxt(time_elapsed_filename, np.array(elapsed).reshape(1, -1))

        # ---------------------------------------------------------------------
        # evaluation
        #

        sorted_y_test = np.unique(y_test[:, 0])
        surv_df = surv_model.predict_survival_function(X_test_std,
                                                       sorted_y_test)
        surv = surv_df.values.T
        ev = EvalSurv(surv_df, y_test[:, 0], y_test[:, 1], censor_surv='km')
        cindex_td = ev.concordance_td('antolini')
        print('c-index (td):', cindex_td)

        linear_predictors = \
            surv_model.predict_log_partial_hazard(X_test_std)
        cindex = concordance_index(y_test[:, 0], -linear_predictors, y_test[:,
                                                                            1])
        print('c-index:', cindex)

        time_grid = np.linspace(sorted_y_test[0], sorted_y_test[-1], 100)
        integrated_brier = ev.integrated_brier_score(time_grid)
        print('Integrated Brier score:', integrated_brier, flush=True)

        test_set_metrics = [cindex_td, integrated_brier]

        rng = np.random.RandomState(bootstrap_random_seed)