[five_p_loader, three_p_loader, ref_loader, alt_loader, strand_loader], ] # set y label y_label = np.log( sample_df['non_syn_counts'].values / (panels.loc[panels['Panel'] == 'Agilent_kit']['cds'].values[0] / 1e6) + 1)[:, np.newaxis] y_strat = np.argmax(samples['histology'], axis=-1) losses = [Losses.QuantileLoss()] metrics = [Metrics.QuantileLoss()] encoders = [ InstanceModels.PassThrough(shape=(1, )), InstanceModels.VariantPositionBin(24, 100), InstanceModels.VariantSequence(6, 4, 2, [16, 16, 8, 8], fusion_dimension=32) ] all_weights = [ pickle.load( open( cwd / 'figures' / 'tmb' / 'tcga' / 'VICC_01_R2' / 'results' / 'run_naive.pkl', 'rb')), pickle.load( open( cwd / 'figures' / 'tmb' / 'tcga' / 'VICC_01_R2' / 'results' / 'run_position.pkl', 'rb')),
y, tf.gather(tf.constant(y_weights, dtype=tf.float32), x) )) ds_valid = tf.data.Dataset.from_tensor_slices((idx_valid, y_label[idx_valid])) ds_valid = ds_valid.batch(len(idx_valid), drop_remainder=False) ds_valid = ds_valid.map(lambda x, y: ((pos_loader(x, ragged_output=True), bin_loader(x, ragged_output=True), chr_loader(x, ragged_output=True), ), y, tf.gather(tf.constant(y_weights, dtype=tf.float32), x) )) while True: position_encoder = InstanceModels.VariantPositionBin(24, 100) mil = RaggedModels.MIL(instance_encoders=[position_encoder.model], output_dim=2, pooling='sum', mil_hidden=(64, 32, 16, 8), output_type='anlulogits') mil.model.compile(loss=losses, metrics=[Metrics.CrossEntropy(), Metrics.Accuracy()], weighted_metrics=[Metrics.CrossEntropy(), Metrics.Accuracy()], optimizer=tf.keras.optimizers.Adam(learning_rate=0.005, clipvalue=10000)) mil.model.fit(ds_train, steps_per_epoch=20, validation_data=ds_valid, epochs=10000, callbacks=callbacks) eval = mil.model.evaluate(ds_valid)