Beispiel #1
0
        tf.gather(tf.constant(D['strand_emb'], dtype=tf.float32), x),
    ),
    y,
))

# Sequence encoder whose Keras model is used as the MIL network's sample encoder.
# NOTE(review): the meaning of the positional arguments (6, 4, 2, [16, 16, 8, 8])
# is defined by InstanceModels.VariantSequence, which is not visible here —
# confirm against its signature before changing them.
sequence_encoder = InstanceModels.VariantSequence(6,
                                                  4,
                                                  2, [16, 16, 8, 8],
                                                  fusion_dimension=128)
# MIL model with no instance encoders; the sequence encoder's model is the only
# sample encoder. The output width follows the last axis of y_label.
# NOTE(review): the semantics of output_type='other' and mode='none' live in
# RaggedModels.MIL, which is not visible here — confirm there.
mil = RaggedModels.MIL(instance_encoders=[],
                       sample_encoders=[sequence_encoder.model],
                       output_dim=y_label.shape[-1],
                       output_type='other',
                       mil_hidden=[128, 128],
                       mode='none')
# Single cross-entropy loss. The same accuracy / CE metrics are tracked twice:
# once unweighted (metrics) and once using sample weights (weighted_metrics).
losses = [Losses.CrossEntropy()]
mil.model.compile(
    loss=losses,
    metrics=[Metrics.Accuracy(), Metrics.CrossEntropy()],
    weighted_metrics=[Metrics.Accuracy(),
                      Metrics.CrossEntropy()],
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, ))

# Stop training once the monitored validation value (lower is better) has not
# improved by at least 0.001 for 10 consecutive epochs, then roll back to the
# best weights seen.
# NOTE(review): 'val_weighted_CE' must match the name Keras logs for the
# weighted CrossEntropy metric on validation data (presumably the metric is
# named 'CE') — verify; if the key is missing, EarlyStopping cannot read it and
# will not stop training early.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_weighted_CE',
                                     min_delta=0.001,
                                     patience=10,
                                     mode='min',
                                     restore_best_weights=True)
]
mil.model.fit(
Beispiel #2
0
    strat_dict[(group, event)]
    for group, event in zip(cancer_labels, y_label[:, 1])
])
# Inverse-frequency sample weights. Every sample gets 1 / (number of samples in
# its stratum), so each stratum contributes the same total weight. The weights
# are then normalised to sum to 1.
# np.unique(return_inverse=True) maps each sample to the index of its stratum,
# which replaces the per-sample Python dict lookup with one vectorised gather.
# NOTE(review): assumes y_strat is 1-D (one stratum label per sample), as the
# original per-element dict lookup also required.
strata, stratum_index, stratum_sizes = np.unique(y_strat,
                                                 return_inverse=True,
                                                 return_counts=True)
# Stratum label -> sample count; kept because later code may rely on it.
class_counts = dict(zip(strata, stratum_sizes))
y_weights = 1 / stratum_sizes[stratum_index]
y_weights /= np.sum(y_weights)

# NOTE(review): `weights` is not used in the visible code; presumably it is
# filled inside the cross-validation loop that follows — confirm or remove.
weights = []
# Stop training once validation CE (lower is better) has not improved by at
# least 0.0001 for 50 consecutive epochs, then restore the best weights seen.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_CE',
                                     min_delta=0.0001,
                                     patience=50,
                                     mode='min',
                                     restore_best_weights=True)
]
# Loss and CE metric are both built with from_logits=False.
# NOTE(review): presumably output_type='classification_probability' makes the
# model emit probabilities rather than logits — confirm in RaggedModels.MIL.
losses = [Losses.CrossEntropy(from_logits=False)]
# NOTE(review): the meaning of VariantSequence's positional arguments
# (20, 4, 2, [8, 8, 8, 8]) is defined in InstanceModels, not visible here.
sequence_encoder = InstanceModels.VariantSequence(20, 4, 2, [8, 8, 8, 8])
# MIL model with the sequence encoder as its instance encoder, 'sum' pooling
# and 2 outputs.
# NOTE(review): mil_hidden presumably lists the widths of the layers applied
# after pooling — confirm in RaggedModels.MIL.
mil = RaggedModels.MIL(instance_encoders=[sequence_encoder.model],
                       output_dim=2,
                       pooling='sum',
                       mil_hidden=(64, 64, 32, 16),
                       output_type='classification_probability')
# clipvalue=10000 clips each gradient element to [-10000, 10000], which is a
# very loose bound.
mil.model.compile(
    loss=losses,
    metrics=[Metrics.CrossEntropy(from_logits=False),
             Metrics.Accuracy()],
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, clipvalue=10000))
# Snapshot of the freshly built model's weights.
# NOTE(review): presumably restored at the start of each fold below — confirm.
initial_weights = mil.model.get_weights()

##stratified K fold for test
for idx_train, idx_test in StratifiedKFold(n_splits=9,