def find_genes_CX(drug, model, meta, gdsc_expr, gdsc_dr, test_tcga_expr, save_dir):
    """Fit a CXPlain explainer for ``model`` and write ranked gene lists.

    The explainer is trained on ``gdsc_expr`` / ``gdsc_dr`` and then used to
    attribute ``test_tcga_expr``. Two gene rankings are derived from the
    attributions and written side by side to ``<save_dir>/genes.csv``; the
    fitted explainer itself is saved under ``<save_dir>/explainer/``.

    Args:
        drug: Not used inside this function.
        model: Model to be explained; handed to ``get_masked_data_for_CXPlain``
            and to ``CXPlain``.
        meta: Not used inside this function.
        gdsc_expr: Object exposing ``.values``; used as explainer training input.
        gdsc_dr: Object exposing ``.values``; used as explainer training target.
        test_tcga_expr: Object exposing ``.values`` and ``.index``; the samples
            that get attributed.
        save_dir: Directory that receives ``genes.csv`` and ``explainer/``.
            It must already exist, since ``genes.csv`` is written first.

    Returns:
        None. Results are written to disk.
    """
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    print('obtaining masked data...')
    masked_data = get_masked_data_for_CXPlain(model, gdsc_expr)
    print('obtained masked data...')

    # TensorFlow / CXPlain are imported lazily, after the masked data has been
    # computed, exactly as in the original statement order.
    import tensorflow as tf
    tf.compat.v1.disable_v2_behavior()
    tf.keras.backend.clear_session()
    tf.random.set_seed(SEED)

    from tensorflow.python.keras.losses import mean_squared_error as loss
    from cxplain import MLPModelBuilder, CXPlain

    model_builder = MLPModelBuilder(num_layers=2, num_units=512,
                                    batch_size=8, learning_rate=0.001)

    print(gdsc_expr.values.shape, gdsc_dr.values.shape)

    print("Fitting CXPlain model")
    # No masking operation is passed (None); the precomputed ``masked_data``
    # is supplied to fit() instead.
    explainer = CXPlain(model, model_builder, None, loss)
    explainer.fit(gdsc_expr.values, gdsc_dr.values, masked_data=masked_data)

    print("Attributing using CXPlain")
    attr = explainer.explain(test_tcga_expr.values)
    # NOTE(review): ``dataset`` is not a parameter of this function; it is
    # presumably a module-level object whose ``hgnc`` holds the gene names in
    # the same order as the expression columns -- confirm.
    attr = pd.DataFrame(attr, index=test_tcga_expr.index, columns=dataset.hgnc)

    borda = get_ranked_list(attr)
    # Top 200 columns by mean absolute attribution across the test samples.
    attr_mean = list(np.abs(attr).mean(axis=0).nlargest(200).index)

    out = pd.DataFrame(columns=['borda', 'mean'])
    out['borda'] = borda
    out['mean'] = attr_mean
    out.to_csv(save_dir + '/genes.csv', index=False)

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(save_dir + '/explainer/', exist_ok=True)
    explainer.save(save_dir + '/explainer/')
def cxpl(model_dir, data_dir, results_subdir, random_seed, resolution):
    """Evaluate a trained DenseNet121 classifier and explain it with CXPlain.

    Steps, in order:
      1. Seed NumPy / TensorFlow and install a single-threaded TF session.
      2. Read ``<model_dir>/config.ini`` and resolve the weights file under
         ``<results_subdir>/classification_results/train``.
      3. Copy ``<results_subdir>/inference/test/test.csv`` into the output
         directory ``<results_subdir>/classification_results/test``.
      4. Build DenseNet121 + sigmoid head, load the weights, load all test
         data from the TFRecords in ``<data_dir>/test`` and print per-class
         and mean AUROC.
      5. Fit a CXPlain explainer on every sample after the first 100 and
         explain the first 100; save ``x_cxpl.npy``, ``y_cxpl.npy``,
         ``attr.npy`` and ``conf.npy`` into the output directory.

    Args:
        model_dir: Directory containing ``config.ini``.
        data_dir: Directory whose ``test`` sub-directory holds the TFRecords.
        results_subdir: Root directory for weights, inference CSV and outputs.
        random_seed: Seed for NumPy; the TensorFlow seed is drawn from it.
        resolution: GAN image resolution; its log2 selects the TFRecord file
            suffix.

    Returns:
        None. Any exception raised while explaining or saving is printed and
        not re-raised.
    """
    np.random.seed(random_seed)
    tf.set_random_seed(np.random.randint(1 << 31))
    session_conf = tf.ConfigProto(intra_op_parallelism_threads=1,
                                  inter_op_parallelism_threads=1)
    sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
    set_session(sess)

    # parser config
    config_file = model_dir + "/config.ini"
    print("Config File Path:", config_file, flush=True)
    # ConfigParser.read() silently ignores missing files, so this check is
    # what surfaces a missing config before the section lookups below.
    assert os.path.isfile(config_file)
    cp = ConfigParser()
    cp.read(config_file)

    output_dir = os.path.join(results_subdir, "classification_results/test")
    print("Output Directory:", output_dir, flush=True)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # default config
    image_dimension = cp["TRAIN"].getint("image_dimension")
    gan_resolution = resolution
    batch_size = cp["TEST"].getint("batch_size")
    use_best_weights = cp["TEST"].getboolean("use_best_weights")

    if use_best_weights:
        print("** Using BEST weights", flush=True)
        model_weights_path = os.path.join(
            results_subdir, "classification_results/train/best_weights.h5")
    else:
        print("** Using LAST weights", flush=True)
        model_weights_path = os.path.join(
            results_subdir, "classification_results/train/weights.h5")

    print("** DenseNet Input Resolution:", image_dimension, flush=True)
    print("** GAN Image Resolution:", gan_resolution, flush=True)

    # get test sample count
    test_dir = os.path.join(results_subdir, "inference/test")
    shutil.copy(test_dir + "/test.csv", output_dir)

    # Get class names
    class_names = get_class_names(output_dir, "test")

    tfrecord_dir_te = os.path.join(data_dir, "test")
    test_counts, _ = get_sample_counts(output_dir, "test", class_names)

    # get indices (all of csv file for validation)
    print("** test counts:", test_counts, flush=True)

    # compute steps (only reported; the generator below is built with steps=None)
    test_steps = int(np.floor(test_counts / batch_size))
    print("** test_steps:", test_steps, flush=True)

    log2_record = int(np.log2(gan_resolution))
    # ``np.str`` was only an alias of the builtin and has been removed from
    # NumPy; the builtin ``str`` gives the identical result.
    record_file_ending = "*" + str(log2_record) + ".tfrecords"
    print("** resolution ", gan_resolution, " corresponds to ",
          record_file_ending, " TFRecord file.", flush=True)

    # Get Model
    # ------------------------------------
    input_shape = (image_dimension, image_dimension, 3)
    img_input = Input(shape=input_shape)

    base_model = DenseNet121(
        include_top=False,
        weights=None,
        input_tensor=img_input,
        input_shape=input_shape,
        pooling="avg")

    x = base_model.output
    predictions = Dense(len(class_names), activation="sigmoid",
                        name="predictions")(x)
    model = Model(inputs=img_input, outputs=predictions)

    print(" ** load model from:", model_weights_path, flush=True)
    model.load_weights(model_weights_path)
    # ------------------------------------

    print("** load test generator **", flush=True)
    test_seq = TFWrapper(
        tfrecord_dir=tfrecord_dir_te,
        record_file_endings=record_file_ending,
        batch_size=batch_size,
        model_target_size=(image_dimension, image_dimension),
        steps=None,
        augment=False,
        shuffle=False,
        prefetch=True,
        repeat=False)

    print("** make prediction **", flush=True)
    test_seq.initialise()
    x_all, y_all = test_seq.get_all_test_data()
    print("X-Test Shape:", x_all.shape, flush=True)
    print("Y-Test Shape:", y_all.shape, flush=True)

    print("----------------------------------------", flush=True)
    print("Test Model AUROC", flush=True)
    y_pred = model.predict(x_all)
    current_auroc = []
    for i, class_name in enumerate(class_names):
        try:
            score = roc_auc_score(y_all[:, i], y_pred[:, i])
        except ValueError:
            # roc_auc_score rejected this column; it is recorded as 0 and
            # still enters the mean below.
            score = 0
        current_auroc.append(score)
        print(i + 1, class_name, ": ", score, flush=True)
    mean_auroc = np.mean(current_auroc)
    print("Mean auroc: ", mean_auroc, flush=True)

    print("----------------------------------------", flush=True)
    downscale_factor = 8
    num_models_to_use = 3
    num_test_images = 100
    print("Number of Models to use:", num_models_to_use, flush=True)
    print("Number of Test images:", num_test_images, flush=True)

    # The first ``num_test_images`` samples are explained; the remainder is
    # used to fit the explainer.
    x_tr, y_tr = x_all[num_test_images:], y_all[num_test_images:]
    x_te, y_te = x_all[0:num_test_images], y_all[0:num_test_images]

    downsample_factors = (downscale_factor, downscale_factor)
    print("Downsample Factors:", downsample_factors, flush=True)

    model_builder = UNetModelBuilder(downsample_factors, num_layers=2,
                                     num_units=8, activation="relu",
                                     p_dropout=0.0, verbose=0,
                                     batch_size=32, learning_rate=0.001)
    print("Model build done.", flush=True)

    masking_operation = ZeroMasking()
    loss = categorical_crossentropy

    explainer = CXPlain(model, model_builder, masking_operation, loss,
                        num_models=num_models_to_use,
                        downsample_factors=downsample_factors,
                        flatten_for_explained_model=False)
    print("Explainer build done.", flush=True)

    explainer.fit(x_tr, y_tr)
    print("Explainer fit done.", flush=True)

    # Best effort: a failure while explaining or saving is reported on stdout
    # and does not abort the caller.
    try:
        attr, conf = explainer.explain(x_te, confidence_level=0.80)
        np.save(output_dir + "/x_cxpl.npy", x_te)
        np.save(output_dir + "/y_cxpl.npy", y_te)
        np.save(output_dir + "/attr.npy", attr)
        np.save(output_dir + "/conf.npy", conf)
        print("Explainer explain done and saved.", flush=True)
    except Exception as ef:
        print(ef, flush=True)