Example #1
0
    [five_p_loader, three_p_loader, ref_loader, alt_loader, strand_loader],
]

# Regression target: log(1 + non_syn_counts / (cds / 1e6)), where `cds` is
# taken from the 'Agilent_kit' row of `panels`; reshaped to a column vector.
# NOTE(review): 'cds' is presumably a coding-sequence length in bases, which
# would make this log(1 + mutations per megabase) — confirm.
y_label = np.log(
    1 + sample_df['non_syn_counts'].values
    / (panels.loc[panels['Panel'] == 'Agilent_kit', 'cds'].values[0] / 1e6)
)[:, None]
# Stratification labels: argmax over the last axis of samples['histology']
# (presumably a one-hot histology encoding — confirm).
y_strat = np.argmax(samples['histology'], axis=-1)

# One quantile loss object in `losses` and a matching quantile metric in `metrics`.
losses, metrics = [Losses.QuantileLoss()], [Metrics.QuantileLoss()]

# Instance encoders, constructed in this order:
#   1. PassThrough over a shape-(1,) input,
#   2. VariantPositionBin(24, 100),
#   3. VariantSequence(6, 4, 2, [16, 16, 8, 8]) with fusion_dimension=32.
encoders = [
    InstanceModels.PassThrough(shape=(1,)),
    InstanceModels.VariantPositionBin(24, 100),
    InstanceModels.VariantSequence(
        6, 4, 2, [16, 16, 8, 8], fusion_dimension=32,
    ),
]

all_weights = [
    pickle.load(
        open(
            cwd / 'figures' / 'tmb' / 'tcga' / 'VICC_01_R2' / 'results' /
            'run_naive.pkl', 'rb')),
    pickle.load(
        open(
            cwd / 'figures' / 'tmb' / 'tcga' / 'VICC_01_R2' / 'results' /
            'run_position.pkl', 'rb')),
Example #2
0
                                           y,
                                          tf.gather(tf.constant(y_weights, dtype=tf.float32), x)
                                           ))

    # Validation pipeline: the whole validation split as one batch
    # (batch size = len(idx_valid)). Each element is mapped to
    # ((pos_loader, bin_loader, chr_loader outputs), label, weight), where the
    # weight is y_weights gathered at the sample indices in the batch.
    ds_valid = (
        tf.data.Dataset.from_tensor_slices((idx_valid, y_label[idx_valid]))
        .batch(len(idx_valid), drop_remainder=False)
        .map(lambda x, y: (
            (
                pos_loader(x, ragged_output=True),
                bin_loader(x, ragged_output=True),
                chr_loader(x, ragged_output=True),
            ),
            y,
            tf.gather(tf.constant(y_weights, dtype=tf.float32), x),
        ))
    )

    while True:
        # Each pass of the enclosing `while True` builds a brand-new position
        # encoder and MIL model, so every attempt trains from a fresh model.
        # NOTE(review): the loop's exit condition is not visible here;
        # presumably it inspects `eval` (computed below) — confirm.
        position_encoder = InstanceModels.VariantPositionBin(24, 100)
        mil = RaggedModels.MIL(instance_encoders=[position_encoder.model], output_dim=2, pooling='sum', mil_hidden=(64, 32, 16, 8), output_type='anlulogits')

        # Metrics are reported both unweighted and weighted; separate metric
        # instances are constructed for each list. The weighted variants
        # presumably use the third dataset element (gathered y_weights) as
        # sample weights — confirm. Gradients are value-clipped at 10000.
        mil.model.compile(loss=losses,
                          metrics=[Metrics.CrossEntropy(), Metrics.Accuracy()],
                          weighted_metrics=[Metrics.CrossEntropy(), Metrics.Accuracy()],
                          optimizer=tf.keras.optimizers.Adam(learning_rate=0.005,
                                                             clipvalue=10000))
        # 20 steps per epoch; epochs=10000 acts as an upper bound.
        # NOTE(review): presumably a callback in `callbacks` (e.g. early
        # stopping) ends training well before that — confirm.
        mil.model.fit(ds_train,
                      steps_per_epoch=20,
                      validation_data=ds_valid,
                      epochs=10000,
                      callbacks=callbacks)


        # Final evaluation on the validation set.
        # NOTE(review): `eval` shadows the builtin; not renamed here because
        # code beyond this view may read it.
        eval = mil.model.evaluate(ds_valid)