t2 = np.random.exponential(scale=1 / (np.log(2) / halflife2), size=int(n_samples / 2)) t = np.concatenate((t1, t2)) censtime = np.random.exponential(scale=1 / (np.log(2) / (halflife_cens)), size=n_samples) f = t < censtime t[~f] = censtime[~f] y_train = nnet_survival.make_surv_array(t, f, breaks) x_train = np.zeros(n_samples) x_train[int(n_samples / 2):] = 1 model = Sequential() #Hidden layers would go here. For this example, using simple linear model with no hidden layers. model.add(Dense(1, input_dim=1, use_bias=0, kernel_initializer='zeros')) model.add(nnet_survival.PropHazards(n_intervals)) model.compile(loss=nnet_survival.surv_likelihood(n_intervals), optimizer=optimizers.RMSprop()) #model.summary() early_stopping = EarlyStopping(monitor='loss', patience=2) history = model.fit(x_train, y_train, batch_size=32, epochs=1000, callbacks=[early_stopping]) y_pred = model.predict_proba(x_train, verbose=0) kmf = KaplanMeierFitter() kmf.fit(t[0:int(n_samples / 2)], event_observed=f[0:int(n_samples / 2)]) plt.plot(breaks, np.concatenate(([1], np.cumprod(y_pred[0, :]))), 'bo-') plt.plot(kmf.survival_function_.index.values,
def architecture(self, n_intervals):
    """Build the CNN survival model for the omics combination in ``self.omics``.

    One convolutional tower is built per omics input the combination uses
    (mRNA and methylation as 122x122x1 images, miRNA as a 42x42x1 image).
    Their flattened features are concatenated, together with a one-unit
    dense projection of the 22-feature clinical input when ``self.clinical``
    is truthy. The result is fed through a shared fully-connected head.

    Args:
        n_intervals: number of survival-time intervals produced by the
            output layer.

    Returns:
        An uncompiled Keras ``Model``. Its inputs are ordered mRNA,
        methylation, miRNA, clinical, restricted to the ones in use. When
        ``self.PH`` is truthy, the output is ``nnet_survival.PropHazards``
        applied to a single bias-free linear score. Otherwise it is
        ``n_intervals`` sigmoid units.

    Raises:
        ValueError: if ``self.omics`` is not a supported combination.
    """
    # Side length of each omics "image" input (channel dimension is 1).
    omics_sides = {'mrna': 122, 'meth': 122, 'mirna': 42}
    # Supported values of self.omics -> towers used, in the order their
    # features are concatenated and their inputs passed to Model(...).
    combinations = {
        'mrna': ('mrna',),
        'meth': ('meth',),
        'mirna': ('mirna',),
        'mrna_meth': ('mrna', 'meth'),
        'mrna_mirna': ('mrna', 'mirna'),
        'mrna_meth_mirna': ('mrna', 'meth', 'mirna'),
    }
    if self.omics not in combinations:
        # Without this check an unknown value would surface later as an
        # UnboundLocalError on `model`.
        raise ValueError(
            'Unsupported omics combination {!r}; expected one of {}'.format(
                self.omics, sorted(combinations)))

    def conv_tower(side):
        """Return (input tensor, flattened features) for one omics input."""
        tower_input = Input(shape=(side, side, 1))
        x = tower_input
        # Two identical conv -> batch-norm -> ReLU -> 2x2 max-pool stages.
        for _ in range(2):
            x = Convolution2D(256, (3, 3), kernel_initializer='glorot_normal')(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
            x = MaxPooling2D(pool_size=(2, 2))(x)
        return tower_input, Flatten()(x)

    # NOTE: only the towers this combination needs are created, so Keras
    # auto-generated layer names may differ from a build of all three towers.
    inputs = []
    features = []
    for omic in combinations[self.omics]:
        tower_input, flat = conv_tower(omics_sides[omic])
        inputs.append(tower_input)
        features.append(flat)

    if self.clinical:
        clinical_input = Input(shape=(22, ), name='clinical')
        clinical_dense = Dense(1, activation='relu',
                               kernel_initializer='glorot_normal')(clinical_input)
        inputs.append(clinical_input)
        features.append(clinical_dense)

    # A single feature tensor is used directly; several are concatenated.
    concat = features[0] if len(features) == 1 else Concatenate()(features)

    # Shared fully-connected head.
    dense_1 = Dense(512, activation='relu', kernel_initializer='glorot_normal')(concat)
    dense_1_dropout = Dropout(0.5)(dense_1)
    dense_2 = Dense(128, activation='relu', kernel_initializer='glorot_normal')(dense_1_dropout)
    dense_2_dropout = Dropout(0.1)(dense_2)

    if self.PH:
        # Proportional-hazards output: one bias-free linear score fed to
        # nnet_survival.PropHazards.
        dense_3 = Dense(1, use_bias=False, kernel_initializer='zeros')(dense_2_dropout)
        output = nnet_survival.PropHazards(n_intervals)(dense_3)
    else:
        # Every combination reads from the dropout output (the three-omics
        # case used to bypass Dropout(0.1) by reading dense_2).
        output = Dense(n_intervals, activation='sigmoid',
                       kernel_initializer='he_normal')(dense_2_dropout)

    model = Model(inputs=inputs, outputs=[output])
    return model
def binary_ANN_surviavl():
    """Compare a one-covariate survival network with a Cox model on simulated data.

    Simulation setup:
      - Two equal-sized groups with exponential survival times
        (half-lives 200 and 400).
      - Independent exponential censoring (half-life 400).

    Steps:
      1. Fit a bias-free linear ``nnet_survival.PropHazards`` network on the
         binary group indicator.
      2. Plot each group's predicted survival curve (blue/red) against its
         Kaplan-Meier estimate (black).
      3. Print the network's coefficient and hazard ratio next to those of a
         lifelines ``CoxPHFitter`` fitted to the same data.

    Returns None; output is produced via ``plt.show()`` and ``print``.
    """
    breaks = np.arange(0, 5000, 50)
    n_intervals = len(breaks) - 1
    halflife1 = 200
    halflife2 = 400
    halflife_cens = 400
    n_samples = 5000
    half = n_samples // 2  # group boundary: [0, half) is group 0, [half, n) group 1

    # Seed the global NumPy RNG so the simulated data are reproducible.
    np.random.seed(seed=0)
    # An exponential with half-life h has rate ln(2)/h, i.e. scale h/ln(2).
    t1 = np.random.exponential(scale=1 / (np.log(2) / halflife1), size=half)
    t2 = np.random.exponential(scale=1 / (np.log(2) / halflife2), size=half)
    t = np.concatenate((t1, t2))
    censtime = np.random.exponential(scale=1 / (np.log(2) / (halflife_cens)), size=n_samples)
    # f is True where the event was observed; censored subjects take their
    # censoring time as the follow-up time.
    f = t < censtime
    t[~f] = censtime[~f]
    y_train = nnet_survival.make_surv_array(t, f, breaks)
    # Binary covariate: 0 for the first group, 1 for the second.
    x_train = np.zeros(n_samples)
    x_train[half:] = 1

    model = Sequential()
    # Hidden layers would go here. For this example, using simple linear model with no hidden layers.
    model.add(Dense(1, input_dim=1, use_bias=False, kernel_initializer='zeros'))
    model.add(nnet_survival.PropHazards(n_intervals))
    model.compile(loss=nnet_survival.surv_likelihood(n_intervals), optimizer=optimizers.RMSprop())
    early_stopping = EarlyStopping(monitor='loss', patience=2)
    model.fit(x_train, y_train, batch_size=32, epochs=1000, callbacks=[early_stopping])
    # Sequential.predict_proba has been removed from Keras; in Keras 2.x it
    # simply returned predict()'s output.
    y_pred = model.predict(x_train, verbose=0)

    kmf = KaplanMeierFitter()
    # Group 0: predicted survival (blue) vs Kaplan-Meier estimate (black).
    kmf.fit(t[:half], event_observed=f[:half])
    plt.plot(breaks, np.concatenate(([1], np.cumprod(y_pred[0, :]))), 'bo-')
    plt.plot(kmf.survival_function_.index.values, kmf.survival_function_.KM_estimate, color='k')
    # Group 1 starts at index `half`; slicing from half + 1 would drop the
    # group's first subject from the Kaplan-Meier estimate.
    kmf.fit(t[half:], event_observed=f[half:])
    plt.plot(breaks, np.concatenate(([1], np.cumprod(y_pred[-1, :]))), 'ro-')
    plt.plot(kmf.survival_function_.index.values, kmf.survival_function_.KM_estimate, color='k')
    plt.xticks(np.arange(0, 2000.0001, 200))
    plt.yticks(np.arange(0, 1.0001, 0.125))
    plt.xlim([0, 2000])
    plt.ylim([0, 1])
    plt.xlabel('Follow-up time (days)')
    plt.ylabel('Proportion surviving')
    plt.title('One covariate. Actual=black, predicted=blue/red.')
    plt.show()

    cox_df = pd.DataFrame({'x_train': x_train, 't': t, 'f': f})
    cf = CoxPHFitter()
    cf.fit(cox_df, 't', event_col='f')
    # Newer lifelines releases expose fitted coefficients as `params_`;
    # older ones used `hazards_`. Support both.
    coefs = getattr(cf, 'params_', None)
    if coefs is None:
        coefs = cf.hazards_
    cox_coef = coefs['x_train']
    # Kernel weight of the single Dense unit (shape (1, 1)).
    nn_coef = model.get_weights()[0][0][0]
    print('Cox model coefficient:')
    print(cox_coef)
    print('Cox model hazard ratio:')
    print(np.exp(cox_coef))
    print('Neural network coefficient:')
    print(nn_coef)
    print('Neural network hazard ratio:')
    print(np.exp(nn_coef))