Example #1
ds_test = tf.data.Dataset.from_tensor_slices((idx_test, y_label[idx_test]))
ds_test = ds_test.batch(len(idx_test), drop_remainder=False)
# gather the per-variant feature arrays from the data dict D for each batch of indices
ds_test = ds_test.map(lambda x, y: ((tf.gather(tf.constant(D['seq_5p'], dtype=tf.int32), x),
                                     tf.gather(tf.constant(D['seq_3p'], dtype=tf.int32), x),
                                     tf.gather(tf.constant(D['seq_ref'], dtype=tf.int32), x),
                                     tf.gather(tf.constant(D['seq_alt'], dtype=tf.int32), x),
                                     tf.gather(tf.constant(D['strand_emb'], dtype=tf.float32), x),
                                     tf.gather(tf.constant(D['cds_emb'], dtype=tf.float32), x)),
                                    y))



sequence_encoder = InstanceModels.VariantSequence(6, 4, 2, [64, 64, 64, 64], fusion_dimension=128, use_frame=True)
mil = RaggedModels.MIL(instance_encoders=[],
                       sample_encoders=[sequence_encoder.model],
                       output_dim=y_label.shape[-1],
                       output_type='other',
                       mil_hidden=[128, 128, 64, 32],
                       mode='none')
losses = [Losses.CrossEntropy()]
mil.model.compile(loss=losses,
                  metrics=[Metrics.Accuracy(), Metrics.CrossEntropy()],
                  weighted_metrics=[Metrics.Accuracy(), Metrics.CrossEntropy()],
                  optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_weighted_CE',
                                     min_delta=0.001,
                                     patience=10,
                                     mode='min',
                                     restore_best_weights=True)
]


mil.model.fit(ds_train,
              steps_per_epoch=50,
              validation_data=ds_valid,
              epochs=10000,
              callbacks=callbacks)
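The ds_train and ds_valid datasets passed to fit above are built outside this snippet. A minimal sketch of how they could be assembled, assuming they mirror the ds_test pipeline (the idx_train/idx_valid arrays, the shuffle-and-repeat step, and the batch size are assumptions, not shown in the original):

# hypothetical helper mirroring the ds_test pipeline above
def make_dataset(idx, batch_size, training=False):
    ds = tf.data.Dataset.from_tensor_slices((idx, y_label[idx]))
    if training:
        # assumption: an infinite shuffled stream is needed for steps_per_epoch=50
        ds = ds.shuffle(len(idx)).repeat()
    ds = ds.batch(batch_size, drop_remainder=False)
    return ds.map(lambda x, y: ((tf.gather(tf.constant(D['seq_5p'], dtype=tf.int32), x),
                                 tf.gather(tf.constant(D['seq_3p'], dtype=tf.int32), x),
                                 tf.gather(tf.constant(D['seq_ref'], dtype=tf.int32), x),
                                 tf.gather(tf.constant(D['seq_alt'], dtype=tf.int32), x),
                                 tf.gather(tf.constant(D['strand_emb'], dtype=tf.float32), x),
                                 tf.gather(tf.constant(D['cds_emb'], dtype=tf.float32), x)),
                                y))

ds_train = make_dataset(idx_train, batch_size=128, training=True)
ds_valid = make_dataset(idx_valid, batch_size=len(idx_valid))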
Example #2
            cwd / 'figures' / 'tmb' / 'tcga' / 'VICC_01_R2' / 'results' /
            'run_position.pkl', 'rb')),
    pickle.load(
        open(
            cwd / 'figures' / 'tmb' / 'tcga' / 'VICC_01_R2' / 'results' /
            'run_sequence.pkl', 'rb'))
]

results = {}

for encoder, loaders, weights, name in zip(encoders, loaders, all_weights,
                                           ['naive', 'position', 'sequence']):

    mil = RaggedModels.MIL(instance_encoders=[encoder.model],
                           output_dim=1,
                           pooling='sum',
                           mil_hidden=(64, 32, 16),
                           output_type='quantiles',
                           regularization=0)
    mil.model.compile(loss=losses,
                      metrics=metrics,
                      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
    ##test eval
    test_idx = []
    predictions = []

    for index, (idx_train, idx_test) in enumerate(
            StratifiedKFold(n_splits=8, random_state=0,
                            shuffle=True).split(y_strat, y_strat)):
        mil.model.set_weights(weights[index])

        ds_test = tf.data.Dataset.from_tensor_slices(
            (idx_test, y_label[idx_test]))
Example #3
ds_test = tf.data.Dataset.from_tensor_slices((idx_test, y_label[idx_test]))
ds_test = ds_test.batch(len(idx_test), drop_remainder=False)
ds_test = ds_test.map(lambda x, y: (
    (five_p_loader(x, ragged_output=True),
     three_p_loader(x, ragged_output=True),
     ref_loader(x, ragged_output=True),
     alt_loader(x, ragged_output=True),
     strand_loader(x, ragged_output=True)), y))

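The five_p_loader, three_p_loader, ref_loader, alt_loader, and strand_loader objects used above are constructed elsewhere. A minimal sketch of the interface they appear to expose, assuming each wraps one variable-length feature array per sample (the RaggedLoader class is hypothetical; only the call signature loader(x, ragged_output=True) comes from the examples):

import tensorflow as tf

class RaggedLoader:  # hypothetical stand-in for the project's loader objects
    def __init__(self, per_sample_arrays):
        # one variable-length array of instance features per sample
        self.data = tf.ragged.constant(per_sample_arrays)

    def __call__(self, idx, ragged_output=True):
        batch = tf.gather(self.data, idx)  # select the requested samples
        # the examples always request ragged output; densifying is an assumed fallback
        return batch if ragged_output else batch.to_tensor()
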
histories = []
evaluations = []
weights = []
for i in range(3):
    tile_encoder = InstanceModels.VariantSequence(6, 4, 2, [16, 16, 8, 8])
    # mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model], output_dim=1, pooling='both', output_type='regression', pooled_layers=[32, ])
    mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model],
                           output_dim=1,
                           pooling='dynamic',
                           output_type='regression')
    losses = ['mse']
    mil.model.compile(loss=losses,
                      metrics=['mse'],
                      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_mse',
                                         min_delta=0.001,
                                         patience=20,
                                         mode='min',
                                         restore_best_weights=True)
    ]
    history = mil.model.fit(ds_train,
                            steps_per_epoch=10,
Example #4
        x, ragged_output=True), ref_loader(x, ragged_output=True),
     alt_loader(x, ragged_output=True), strand_loader(x, ragged_output=True),
     tf.gather(tf.constant(types), x)), y))

histories = []
evaluations = []
weights = []
for i in range(3):
    sequence_encoder = InstanceModels.VariantSequence(6, 4, 2, [16, 16, 8, 8])
    sample_encoder = SampleModels.Type(shape=(), dim=len(np.unique(types)))
    # mil = RaggedModels.MIL(instance_encoders=[sequence_encoder.model], sample_encoders=[sample_encoder.model], sample_layers=[64, ], output_dim=1, pooling='both', output_type='other', pooled_layers=[32, ])
    mil = RaggedModels.MIL(instance_encoders=[sequence_encoder.model],
                           sample_encoders=[sample_encoder.model],
                           fusion='before',
                           output_dim=1,
                           pooling='both',
                           output_type='other',
                           pooled_layers=[32])
    losses = ['mse']
    mil.model.compile(loss=losses,
                      metrics=['mse'],
                      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_mse',
                                         min_delta=0.001,
                                         patience=20,
                                         mode='min',
                                         restore_best_weights=True)
    ]
Example #5
ds_test = tf.data.Dataset.from_tensor_slices((idx_test, y_label[idx_test]))
ds_test = ds_test.batch(len(idx_test), drop_remainder=False)
ds_test = ds_test.map(lambda x, y: (
    (five_p_loader(x, ragged_output=True),
     three_p_loader(x, ragged_output=True),
     ref_loader(x, ragged_output=True),
     alt_loader(x, ragged_output=True),
     strand_loader(x, ragged_output=True)), y))

histories = []
evaluations = []
weights = []
for i in range(3):
    tile_encoder = InstanceModels.VariantSequence(6, 4, 2, [16, 16, 8, 8])
    mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model],
                           output_dim=2,
                           pooling='dynamic')
    # mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model], output_dim=2, pooling='both', pooled_layers=[32, ])
    losses = [tf.keras.losses.CategoricalCrossentropy(from_logits=True)]
    mil.model.compile(
        loss=losses,
        metrics=[
            'accuracy',
            tf.keras.metrics.CategoricalCrossentropy(from_logits=True)
        ],
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_categorical_crossentropy',
            min_delta=0.00001,
            patience=20,
            mode='min',
            restore_best_weights=True)
    ]
Example #6
                                          tf.gather(tf.constant(y_weights, dtype=tf.float32), x)
                                           ))

    ds_valid = tf.data.Dataset.from_tensor_slices((idx_valid, y_label[idx_valid]))
    ds_valid = ds_valid.batch(len(idx_valid), drop_remainder=False)
    # map to ((features...), labels, per-sample weights)
    ds_valid = ds_valid.map(lambda x, y: ((pos_loader(x, ragged_output=True),
                                           bin_loader(x, ragged_output=True),
                                           chr_loader(x, ragged_output=True)),
                                          y,
                                          tf.gather(tf.constant(y_weights, dtype=tf.float32), x)))

    while True:
        position_encoder = InstanceModels.VariantPositionBin(24, 100)
        mil = RaggedModels.MIL(instance_encoders=[position_encoder.model],
                               output_dim=2,
                               pooling='sum',
                               mil_hidden=(64, 32, 16, 8),
                               output_type='anlulogits')

        mil.model.compile(loss=losses,
                          metrics=[Metrics.CrossEntropy(), Metrics.Accuracy()],
                          weighted_metrics=[Metrics.CrossEntropy(), Metrics.Accuracy()],
                          optimizer=tf.keras.optimizers.Adam(learning_rate=0.005,
                                                             clipvalue=10000))
        mil.model.fit(ds_train,
                      steps_per_epoch=20,
                      validation_data=ds_valid,
                      epochs=10000,
                      callbacks=callbacks)


        evaluation = mil.model.evaluate(ds_valid)
        if evaluation[2] >= 0.985:
Example #7
     ref_loader(x, ragged_output=True), alt_loader(x, ragged_output=True),
     strand_loader(x, ragged_output=True)), y))

ds_test = tf.data.Dataset.from_tensor_slices((idx_test, y_label[idx_test]))
ds_test = ds_test.batch(len(idx_test), drop_remainder=False)
ds_test = ds_test.map(lambda x, y: (
    (five_p_loader(x, ragged_output=True),
     three_p_loader(x, ragged_output=True),
     ref_loader(x, ragged_output=True),
     alt_loader(x, ragged_output=True),
     strand_loader(x, ragged_output=True)), y))

tile_encoder = InstanceModels.VariantSequence(6, 4, 2, [16, 16, 8, 8])
# mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model], output_dim=2, pooling='dynamic')
mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model],
                       output_dim=2,
                       pooling='both',
                       pooled_layers=[32])
attentions = []
evaluations, histories, weights = pickle.load(
    open(
        cwd / 'figures' / 'controls' / 'samples' / 'sim_data' /
        'classification' / 'experiment_4' / 'sample_model_attention_both.pkl',
        'rb'))
for i in range(3):
    mil.model.set_weights(weights[i])
    # run the attention head once per restored weight set
    attention = mil.attention_model.predict(ds_test)
    attentions.append(attention.to_list())

with open(
        cwd / 'figures' / 'controls' / 'samples' / 'sim_data' /
Example #8
losses = [Losses.QuantileLoss()]
metrics = [Metrics.QuantileLoss()]

pass_encoder = InstanceModels.PassThrough(shape=(1, ))
type_encoder = SampleModels.Type(shape=(), dim=max(y_strat) + 1)

weights = pickle.load(
    open(
        cwd / 'figures' / 'tmb' / 'tcga' / 'MDA_409' / 'results' /
        'run_naive_sample_tcga.pkl', 'rb'))

mil = RaggedModels.MIL(
    sample_encoders=[pass_encoder.model, type_encoder.model],
    output_dim=1,
    mil_hidden=(64, 32, 16),
    output_type='quantiles',
    regularization=0,
    mode='none')

##test eval
test_idx = []
predictions = []

for index, (idx_train, idx_test) in enumerate(
        StratifiedKFold(n_splits=8, random_state=0,
                        shuffle=True).split(y_strat, y_strat)):
    mil.model.set_weights(weights[index])

    ds_test = tf.data.Dataset.from_tensor_slices((idx_test, y_label[idx_test]))
    ds_test = ds_test.batch(len(idx_test), drop_remainder=False)
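The snippet ends after batching; by analogy with the other cross-validation examples in this list, a map onto the two sample encoders and a per-fold predict would follow, roughly as below (sample_values and sample_types are hypothetical names for the PassThrough and Type inputs; only the accumulation pattern comes from the surrounding code):

    # hypothetical continuation of the fold loop
    ds_test = ds_test.map(lambda x, y: ((
        tf.gather(tf.constant(sample_values, dtype=tf.float32), x),
        tf.gather(tf.constant(sample_types, dtype=tf.int32), x)),
        y))
    test_idx.append(idx_test)
    predictions.append(mil.model.predict(ds_test))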
Example #9
y_weights = np.array([1 / class_counts[_] for _ in y_strat])
y_weights /= np.sum(y_weights)

weights = []
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_CE',
                                     min_delta=0.0001,
                                     patience=50,
                                     mode='min',
                                     restore_best_weights=True)
]
losses = [Losses.CrossEntropy(from_logits=False)]
sequence_encoder = InstanceModels.VariantSequence(20, 4, 2, [8, 8, 8, 8])
mil = RaggedModels.MIL(instance_encoders=[sequence_encoder.model],
                       output_dim=2,
                       pooling='sum',
                       mil_hidden=(64, 64, 32, 16),
                       output_type='classification_probability')
mil.model.compile(
    loss=losses,
    metrics=[Metrics.CrossEntropy(from_logits=False),
             Metrics.Accuracy()],
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipvalue=10000))
initial_weights = mil.model.get_weights()

##stratified K fold for test
for idx_train, idx_test in StratifiedKFold(n_splits=9,
                                           random_state=0,
                                           shuffle=True).split(
                                               y_strat, y_strat):
    ##because the y_strat levels are not constant, this idx_train/idx_valid split is not deterministic
Example #10
ds_test = tf.data.Dataset.from_tensor_slices((idx_test, y_label[idx_test]))
ds_test = ds_test.batch(len(idx_test), drop_remainder=False)
ds_test = ds_test.map(lambda x, y: (
    (five_p_loader(x, ragged_output=True),
     three_p_loader(x, ragged_output=True),
     ref_loader(x, ragged_output=True),
     alt_loader(x, ragged_output=True),
     strand_loader(x, ragged_output=True)), y))

histories = []
evaluations = []
weights = []
for i in range(3):
    tile_encoder = InstanceModels.VariantSequence(6, 4, 2, [16, 16, 8, 8])
    # mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model], output_dim=2, pooling='both', pooled_layers=[32, ])
    mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model],
                           output_dim=2,
                           pooling='sum',
                           mode='none')
    losses = [tf.keras.losses.CategoricalCrossentropy(from_logits=True)]
    mil.model.compile(
        loss=losses,
        metrics=[
            'accuracy',
            tf.keras.metrics.CategoricalCrossentropy(from_logits=True)
        ],
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_categorical_crossentropy',
            min_delta=0.00001,
            patience=20,
            mode='min',
            restore_best_weights=True)
    ]
Example #11
ds_test = ds_test.batch(len(idx_test), drop_remainder=False)
# labels are dropped in this map: the dataset yields inputs only, for predict()
ds_test = ds_test.map(lambda x, y: (
    (five_p_loader(x, ragged_output=True),
     three_p_loader(x, ragged_output=True),
     ref_loader(x, ragged_output=True),
     alt_loader(x, ragged_output=True),
     strand_loader(x, ragged_output=True)), ))

from model.Sample_MIL import InstanceModels, RaggedModels
evaluations, histories, weights = pickle.load(
    open(
        cwd / 'figures' / 'controls' / 'samples' / 'sim_data' / 'regression' /
        'experiment_1' / 'sample_model_mean.pkl', 'rb'))

predictions = []
attentions = []
for i in range(3):
    tile_encoder = InstanceModels.VariantSequence(6, 4, 2, [16, 16, 8, 8])
    # mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model], output_dim=1, pooling='both', output_type='regression', pooled_layers=[32, ])
    mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model],
                           output_dim=1,
                           pooling='mean',
                           output_type='regression',
                           mode='none')
    mil.model.set_weights(weights[i])
    predictions.append(mil.model.predict(ds_test))
    # attentions.append(mil.attention_model.predict(ds_test).to_list())
with open(
        cwd / 'figures' / 'controls' / 'samples' / 'sim_data' / 'regression' /
        'experiment_1' / 'sample_model_mean_predictions.pkl', 'wb') as f:
    pickle.dump([idx_test, predictions], f)
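For downstream use, the pickle written above can be read back and the three repeats summarized, e.g. (averaging the repeats is an assumption about intent; the path and the [idx_test, predictions] layout come from the code above):

import pickle
import numpy as np

with open(cwd / 'figures' / 'controls' / 'samples' / 'sim_data' / 'regression' /
          'experiment_1' / 'sample_model_mean_predictions.pkl', 'rb') as f:
    idx_test, predictions = pickle.load(f)

# collapse the three training repeats into one prediction per test sample
mean_prediction = np.mean(np.stack(predictions, axis=0), axis=0)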
Example #12
            tf.gather(tf.constant(D['pos_bin'], dtype=tf.float32), x),
            tf.gather(tf.constant(D['chr'], dtype=tf.int32), x),
        ),
        y,
    ))

    # pull one full batch from each dataset so fit() can consume plain tensors
    ds_train = iter(ds_train).get_next()
    ds_valid_batch = iter(ds_valid).get_next()

    losses = [Losses.CrossEntropy()]

    while True:
        position_encoder = InstanceModels.VariantPositionBin(24, 100)
        mil = RaggedModels.MIL(instance_encoders=[],
                               sample_encoders=[position_encoder.model],
                               output_dim=y_label.shape[-1],
                               output_type='anlulogits',
                               mil_hidden=[32, 16],
                               mode='none')
        mil.model.compile(loss=losses,
                          metrics=[Metrics.CrossEntropy(),
                                   Metrics.Accuracy()],
                          optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))

        mil.model.fit(x=ds_train[0],
                      y=ds_train[1],
                      batch_size=len(idx_train) // 4,
                      epochs=1000000,
                      validation_data=ds_valid_batch,
                      shuffle=True,
                      callbacks=callbacks)
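The shown lines never exit the while True loop. Example #6 above breaks out of the same retry pattern with an evaluate-and-threshold check, so a plausible (assumed) continuation is:

        # assumed continuation, mirroring Example #6: retrain from scratch
        # until the chosen validation metric clears a threshold
        evaluation = mil.model.evaluate(x=ds_valid_batch[0], y=ds_valid_batch[1])
        if evaluation[2] >= 0.985:  # assumption: index 2 is the target metric
            break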
Example #13
ds_test = ds_test.batch(len(idx_test), drop_remainder=False)
ds_test = ds_test.map(lambda x, y: (
    (five_p_loader(x, ragged_output=True),
     three_p_loader(x, ragged_output=True),
     ref_loader(x, ragged_output=True),
     alt_loader(x, ragged_output=True),
     strand_loader(x, ragged_output=True)), y))
X = False
while not X:
    print(index)
    try:
        tile_encoder = InstanceModels.VariantSequence(6, 4, 2, [16, 16, 8, 8])
        # mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model], output_dim=1, pooling='both', output_type='other', pooled_layers=[32, 128, 64])
        mil = RaggedModels.MIL(instance_encoders=[tile_encoder.model],
                               output_dim=1,
                               pooling='dynamic',
                               output_type='other',
                               pooled_layers=[128, 64])
        losses = [Losses.CoxPH()]
        mil.model.compile(loss=losses,
                          metrics=[Losses.CoxPH()],
                          optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
        callbacks = [
            tf.keras.callbacks.EarlyStopping(monitor='val_coxph',
                                             min_delta=0.0001,
                                             patience=20,
                                             mode='min',
                                             restore_best_weights=True)
        ]
        history = mil.model.fit(ds_train,