Example #1
0
                                      tf.gather(tf.constant(D['seq_3p'], dtype=tf.int32), x),
                                      tf.gather(tf.constant(D['seq_ref'], dtype=tf.int32), x),
                                      tf.gather(tf.constant(D['seq_alt'], dtype=tf.int32), x),
                                      tf.gather(tf.constant(D['strand_emb'], dtype=tf.float32), x),
                                      tf.gather(tf.constant(D['cds_emb'], dtype=tf.float32), x)
                                       ),
                                       y,
                                      ))



# Sample-level model: the VariantSequence encoder (built with use_frame=True)
# is handed to the MIL model as its only sample encoder; no instance
# encoders are used.
sequence_encoder = InstanceModels.VariantSequence(
    6, 4, 2, [64, 64, 64, 64],
    fusion_dimension=128,
    use_frame=True,
)
mil = RaggedModels.MIL(
    instance_encoders=[],
    sample_encoders=[sequence_encoder.model],
    output_dim=y_label.shape[-1],
    output_type='other',
    mil_hidden=[128, 128, 64, 32],
    mode='none',
)
losses = [Losses.CrossEntropy()]
mil.model.compile(
    loss=losses,
    metrics=[Metrics.Accuracy(), Metrics.CrossEntropy()],
    # Weighted copies of the same metrics. Presumably the weighted CE is
    # reported as "weighted_CE", which the callback below monitors
    # (with the "val_" prefix) — confirm against the Metrics module.
    weighted_metrics=[Metrics.Accuracy(), Metrics.CrossEntropy()],
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
)

# Stop once validation weighted CE fails to improve by at least 0.001 for
# 10 consecutive epochs, then roll back to the best weights seen.
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_weighted_CE',
        min_delta=0.001,
        patience=10,
        mode='min',
        restore_best_weights=True,
    )
]

# epochs is set very high on purpose; EarlyStopping is expected to end
# training long before that limit.
mil.model.fit(
    ds_train,
    steps_per_epoch=50,
    validation_data=ds_valid,
    epochs=10000,
    callbacks=callbacks,
)


with open(cwd / 'figures' / 'controls' / 'instances' / 'sequence' / 'codons' / 'results' / 'weights_with_frame.pkl', 'wb') as f:
Example #2
0
    tf.keras.callbacks.EarlyStopping(monitor='val_CE',
                                     min_delta=0.0001,
                                     patience=50,
                                     mode='min',
                                     restore_best_weights=True)
]
# Loss and metric both use from_logits=False. This presumably matches the
# probability outputs of output_type='classification_probability' below.
losses = [Losses.CrossEntropy(from_logits=False)]

# Instance-level encoder; the MIL model pools its outputs with pooling='sum'.
sequence_encoder = InstanceModels.VariantSequence(
    20, 4, 2, [8, 8, 8, 8]
)
mil = RaggedModels.MIL(
    instance_encoders=[sequence_encoder.model],
    output_dim=2,
    pooling='sum',
    mil_hidden=(64, 64, 32, 16),
    output_type='classification_probability',
)
mil.model.compile(
    loss=losses,
    metrics=[Metrics.CrossEntropy(from_logits=False), Metrics.Accuracy()],
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipvalue=10000),
)

# Snapshot of the freshly initialised weights. Presumably used to reset the
# model between cross-validation folds — confirm where it is read.
initial_weights = mil.model.get_weights()

##stratified K fold for test
for idx_train, idx_test in StratifiedKFold(n_splits=9,
                                           random_state=0,
                                           shuffle=True).split(
                                               y_strat, y_strat):
    ##due to the y_strat levels not being constant this idx_train/idx_valid split is not deterministic
    idx_train, idx_valid = [
        idx_train[idx] for idx in list(
            StratifiedShuffleSplit(n_splits=1, test_size=300, random_state=0).
            split(np.zeros_like(y_strat)[idx_train], y_strat[idx_train]))[0]
    ]
Example #3
0
))

# Sequence-only model: the VariantSequence encoder is the MIL model's single
# sample encoder, and there are no instance encoders.
sequence_encoder = InstanceModels.VariantSequence(
    6, 4, 2, [16, 16, 8, 8], fusion_dimension=128
)
mil = RaggedModels.MIL(
    instance_encoders=[], sample_encoders=[sequence_encoder.model],
    output_dim=y_label.shape[-1], output_type='other',
    mil_hidden=[128, 128], mode='none',
)
losses = [Losses.CrossEntropy()]
mil.model.compile(
    loss=losses,
    metrics=[Metrics.Accuracy(), Metrics.CrossEntropy()],
    weighted_metrics=[Metrics.Accuracy(), Metrics.CrossEntropy()],
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
)

# Early stopping on validation weighted CE: at least 0.001 improvement is
# required within 10 epochs; best weights are restored when training ends.
callbacks = [tf.keras.callbacks.EarlyStopping(
    monitor='val_weighted_CE', min_delta=0.001, patience=10,
    mode='min', restore_best_weights=True,
)]
mil.model.fit(
    ds_train,
    steps_per_epoch=200,
    validation_data=ds_valid,
Example #4
0
    # Pull one batch out of each dataset. After this line `ds_train` is no
    # longer a dataset but the (x, y) batch it yielded; fit() below reads it
    # as ds_train[0] / ds_train[1].
    # NOTE(review): assumes each dataset yields the whole split as a single
    # batch — confirm against how ds_train/ds_valid are built.
    # NOTE(review): rebinding ds_train breaks re-execution. If the enclosing
    # scope runs this again without rebuilding ds_train, iter() would be
    # called on the batch tuple, which has no get_next().
    ds_train = iter(ds_train).get_next()
    ds_valid_batch = iter(ds_valid).get_next()

    losses = [Losses.CrossEntropy()]

    # Retry loop: each pass builds, compiles and trains a brand-new model,
    # and stops only once evaluation on ds_valid meets the threshold below.
    # NOTE(review): there is no attempt limit, so this never terminates if
    # the threshold is never reached.
    while True:
        position_encoder = InstanceModels.VariantPositionBin(24, 100)
        mil = RaggedModels.MIL(instance_encoders=[],
                               sample_encoders=[position_encoder.model],
                               output_dim=y_label.shape[-1],
                               output_type='anlulogits',
                               mil_hidden=[32, 16],
                               mode='none')
        mil.model.compile(loss=losses,
                          metrics=[Metrics.CrossEntropy(),
                                   Metrics.Accuracy()],
                          optimizer=tf.keras.optimizers.Adam(
                              learning_rate=0.01, ))

        # Train on the in-memory batch. The validation_data is the single
        # ds_valid_batch, not the ds_valid dataset.
        # batch_size gives roughly 4 steps per epoch, assuming ds_train[0]
        # holds len(idx_train) samples; it would be 0 if len(idx_train) < 4.
        # epochs is effectively unbounded, so `callbacks` (defined outside
        # this view) presumably contains an early-stopping callback.
        mil.model.fit(x=ds_train[0],
                      y=ds_train[1],
                      batch_size=len(idx_train) // 4,
                      epochs=1000000,
                      validation_data=ds_valid_batch,
                      shuffle=True,
                      callbacks=callbacks)

        # Keras evaluate() returns the loss followed by the compiled metrics,
        # so eval[1] should be the CrossEntropy metric.
        # NOTE(review): `eval` shadows the builtin. The name is kept because
        # code outside this view may read it.
        eval = mil.model.evaluate(ds_valid)
        if eval[1] < .0005:
            break