Example #1
0
# NOTE(review): this excerpt starts mid-script; t1, halflife2, halflife_cens,
# n_samples, breaks and n_intervals are defined above the visible lines.
# scale = halflife / ln(2), so each exponential draw has median `halflife`.
t2 = np.random.exponential(scale=1 / (np.log(2) / halflife2),
                           size=int(n_samples / 2))
t = np.concatenate((t1, t2))
# Independent censoring: an event is observed only if it happens before the
# subject's censoring time; otherwise the censoring time is recorded instead.
censtime = np.random.exponential(scale=1 / (np.log(2) / (halflife_cens)),
                                 size=n_samples)
f = t < censtime
t[~f] = censtime[~f]

# Convert (time, event) pairs into the per-interval target array used by
# nnet_survival's likelihood loss.
y_train = nnet_survival.make_surv_array(t, f, breaks)
# Single binary covariate: 0 for the first half of subjects, 1 for the second.
x_train = np.zeros(n_samples)
x_train[int(n_samples / 2):] = 1

model = Sequential()
# Hidden layers would go here. For this example, using simple linear model with no hidden layers.
model.add(Dense(1, input_dim=1, use_bias=False, kernel_initializer='zeros'))
model.add(nnet_survival.PropHazards(n_intervals))
model.compile(loss=nnet_survival.surv_likelihood(n_intervals),
              optimizer=optimizers.RMSprop())
early_stopping = EarlyStopping(monitor='loss', patience=2)
history = model.fit(x_train,
                    y_train,
                    batch_size=32,
                    epochs=1000,
                    callbacks=[early_stopping])
# Sequential.predict_proba was deprecated and then removed from Keras; it only
# ever delegated to predict(), which returns the same per-interval values
# (their cumulative product is plotted as the survival curve below).
y_pred = model.predict(x_train, verbose=0)

kmf = KaplanMeierFitter()
kmf.fit(t[0:int(n_samples / 2)], event_observed=f[0:int(n_samples / 2)])
plt.plot(breaks, np.concatenate(([1], np.cumprod(y_pred[0, :]))), 'bo-')
plt.plot(kmf.survival_function_.index.values,
Example #2
0
    def architecture(self, n_intervals):
        """Build the multi-omics survival network selected by ``self.omics``.

        Each omics source ('mrna', 'meth', 'mirna') gets its own convolutional
        branch: two Conv2D(256, 3x3) -> BatchNorm -> ReLU -> MaxPool(2x2)
        stages, then Flatten. The selected branches are optionally joined by a
        22-feature clinical input (when ``self.clinical`` is truthy) and feed a
        shared 512 -> 128 dense head. The output layer is
        ``nnet_survival.PropHazards`` when ``self.PH`` is truthy, otherwise
        ``n_intervals`` sigmoid units.

        Args:
            n_intervals: number of follow-up time intervals the model outputs.

        Returns:
            A Keras ``Model``. Its inputs are the selected omics images in the
            order mrna, meth, mirna, followed by the clinical vector when
            ``self.clinical`` is truthy.

        Raises:
            ValueError: if ``self.omics`` is not a supported combination.
                Previously this case failed with an UnboundLocalError on
                ``model``.
        """
        # Input image shape of each omics branch.
        omics_shapes = {
            'mrna': (122, 122, 1),
            'meth': (122, 122, 1),
            'mirna': (42, 42, 1),
        }
        # Branches (in model-input order) used by each supported setting.
        omics_layouts = {
            'mrna': ('mrna',),
            'meth': ('meth',),
            'mirna': ('mirna',),
            'mrna_meth': ('mrna', 'meth'),
            'mrna_mirna': ('mrna', 'mirna'),
            'mrna_meth_mirna': ('mrna', 'meth', 'mirna'),
        }
        # Validate before creating any layers so a bad setting fails fast.
        if self.omics not in omics_layouts:
            raise ValueError(
                'Unsupported omics setting %r; expected one of: %s'
                % (self.omics, ', '.join(omics_layouts)))

        inputs = []
        features = []
        # Only the branches actually wired into the model are built.
        for branch in omics_layouts[self.omics]:
            branch_input, branch_flat = self._conv_branch(omics_shapes[branch])
            inputs.append(branch_input)
            features.append(branch_flat)

        if self.clinical:
            clinical_input = Input(shape=(22, ), name='clinical')
            inputs.append(clinical_input)
            features.append(
                Dense(1, activation='relu',
                      kernel_initializer='glorot_normal')(clinical_input))

        # As in the original, a single feature tensor is used as-is and
        # several are concatenated.
        if len(features) == 1:
            merged = features[0]
        else:
            merged = Concatenate()(features)

        output = self._survival_head(merged, n_intervals, self.PH)
        return Model(inputs=inputs, outputs=[output])

    @staticmethod
    def _conv_branch(input_shape):
        """Build one omics conv tower; return ``(input_layer, flat_features)``."""
        branch_input = Input(shape=input_shape)
        x = branch_input
        for _ in range(2):
            x = Convolution2D(256, (3, 3), kernel_initializer='glorot_normal')(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
            x = MaxPooling2D(pool_size=(2, 2))(x)
        return branch_input, Flatten()(x)

    @staticmethod
    def _survival_head(features, n_intervals, proportional_hazards):
        """Apply the shared 512 -> 128 dense head and the survival output layer."""
        x = Dense(512, activation='relu', kernel_initializer='glorot_normal')(features)
        x = Dropout(0.5)(x)
        x = Dense(128, activation='relu', kernel_initializer='glorot_normal')(x)
        x = Dropout(0.1)(x)
        if proportional_hazards:
            # Single bias-free linear score, expanded by PropHazards.
            score = Dense(1, use_bias=False, kernel_initializer='zeros')(x)
            return nnet_survival.PropHazards(n_intervals)(score)
        # Every omics setting now takes the dropout output here. The original
        # 'mrna_meth_mirna' branch bypassed Dropout(0.1) by using dense_2.
        return Dense(n_intervals, activation='sigmoid',
                     kernel_initializer='he_normal')(x)
Example #3
0
def binary_ANN_surviavl():
    """Compare a one-covariate nnet_survival model with Kaplan-Meier and Cox fits.

    Simulates two equal-sized groups with exponential event times (half-lives
    200 and 400) under independent exponential censoring (half-life 400). It
    fits a proportional-hazards neural network with a single binary covariate
    and plots the network's predicted survival curve for each group over that
    group's Kaplan-Meier estimate. It then fits a lifelines Cox model on the
    same data and prints both models' coefficients and hazard ratios.

    Side effects: reseeds NumPy's global RNG (seed 0), trains a Keras model,
    calls ``plt.show()`` and prints to stdout.
    """
    breaks = np.arange(0, 5000, 50)
    n_intervals = len(breaks) - 1

    halflife1 = 200
    halflife2 = 400
    halflife_cens = 400
    n_samples = 5000
    half = n_samples // 2  # first group: indices [0, half); second: [half, end)

    np.random.seed(seed=0)
    # scale = halflife / ln(2), so each exponential draw has median `halflife`.
    t1 = np.random.exponential(scale=1 / (np.log(2) / halflife1), size=half)
    t2 = np.random.exponential(scale=1 / (np.log(2) / halflife2), size=half)
    t = np.concatenate((t1, t2))
    censtime = np.random.exponential(scale=1 / (np.log(2) / halflife_cens),
                                     size=n_samples)
    # An event is observed only if it precedes censoring; otherwise record the
    # censoring time instead.
    f = t < censtime
    t[~f] = censtime[~f]

    y_train = nnet_survival.make_surv_array(t, f, breaks)
    x_train = np.zeros(n_samples)
    x_train[half:] = 1

    model = Sequential()
    # Hidden layers would go here. For this example, using simple linear model with no hidden layers.
    model.add(Dense(1, input_dim=1, use_bias=False, kernel_initializer='zeros'))
    model.add(nnet_survival.PropHazards(n_intervals))
    model.compile(loss=nnet_survival.surv_likelihood(n_intervals),
                  optimizer=optimizers.RMSprop())
    early_stopping = EarlyStopping(monitor='loss', patience=2)
    model.fit(x_train,
              y_train,
              batch_size=32,
              epochs=1000,
              callbacks=[early_stopping])
    # Sequential.predict_proba was removed from Keras; it only wrapped predict().
    y_pred = model.predict(x_train, verbose=0)

    kmf = KaplanMeierFitter()
    kmf.fit(t[:half], event_observed=f[:half])
    plt.plot(breaks, np.concatenate(([1], np.cumprod(y_pred[0, :]))), 'bo-')
    plt.plot(kmf.survival_function_.index.values,
             kmf.survival_function_.KM_estimate,
             color='k')
    # The second group starts at index `half`. The original sliced from
    # half + 1, which dropped one subject from this Kaplan-Meier fit.
    kmf.fit(t[half:], event_observed=f[half:])
    plt.plot(breaks, np.concatenate(([1], np.cumprod(y_pred[-1, :]))), 'ro-')
    plt.plot(kmf.survival_function_.index.values,
             kmf.survival_function_.KM_estimate,
             color='k')
    plt.xticks(np.arange(0, 2000.0001, 200))
    plt.yticks(np.arange(0, 1.0001, 0.125))
    plt.xlim([0, 2000])
    plt.ylim([0, 1])
    plt.xlabel('Follow-up time (days)')
    plt.ylabel('Proportion surviving')
    plt.title('One covariate. Actual=black, predicted=blue/red.')
    plt.show()

    myData = pd.DataFrame({'x_train': x_train, 't': t, 'f': f})
    cf = CoxPHFitter()
    cf.fit(myData, duration_col='t', event_col='f')
    # lifelines replaced the deprecated `hazards_` attribute with `params_`,
    # a Series of fitted coefficients indexed by covariate name.
    cox_coef = cf.params_['x_train']
    # Weight of the single bias-free Dense unit, i.e. the network's coefficient
    # for the binary covariate.
    nn_coef = model.get_weights()[0][0][0]
    print('Cox model coefficient:')
    print(cox_coef)
    print('Cox model hazard ratio:')
    print(np.exp(cox_coef))
    print('Neural network coefficient:')
    print(nn_coef)
    print('Neural network hazard ratio:')
    print(np.exp(nn_coef))