def feature_importance(crispr_model, x_len, xs, y, unique_test_index):
    """Rank input features by perturbation importance and persist the ranking.

    Builds the feature-name list (one "for_<i>" name per sequence position,
    followed by the extra categorical/numerical feature names from config),
    runs the input-perturbation ranker over the unique test rows, and writes
    the resulting table to config.feature_importance_path as CSV.

    Args:
        crispr_model: trained model handed to the ranker.
        x_len: number of per-position sequence features to name.
        xs: list of input arrays; each is subset by unique_test_index.
        y: label array, subset by unique_test_index.
        unique_test_index: indices of the deduplicated test rows.
    """
    logger.debug("Getting features ranks")
    feature_names = ["for_" + str(pos) for pos in range(x_len)]
    feature_names.extend(config.extra_categorical_features +
                         config.extra_numerical_features)
    perturbation_ranker = feature_imp.InputPerturbationRank(feature_names)
    test_inputs = [block[unique_test_index] for block in xs]
    ranks = perturbation_ranker.rank(20, y[unique_test_index], crispr_model,
                                     test_inputs)
    pd.DataFrame(ranks).to_csv(config.feature_importance_path, index=False)
    logger.debug("Get features ranks successfully")
elif attention_setting.output_FF_layers[-1] == 1: crispr_model = attention_model.get_OT_model(data_pre.feature_length_map) regressor_training(crispr_model, X, y, cv_splitter) else: crispr_model_classifier = attention_model.get_OT_model(data_pre.feature_length_map, classifier=True) best_crispr_model, train_index = classifier_training(crispr_model_classifier, X, y_binary, cv_splitter) if config.check_feature_importance: logger.debug("Getting features ranks") names = [] names += ["src_" + str(i) for i in range(data_pre.feature_length_map[0][1])] if data_pre.feature_length_map[1] is not None: names += ["trg_" + str(i) for i in range( data_pre.feature_length_map[1][1] - data_pre.feature_length_map[1][0])] if data_pre.feature_length_map[ 2] is not None: names += config.extra_categorical_features + config.extra_numerical_features ranker = feature_imp.InputPerturbationRank(names) feature_ranks = ranker.rank(2, y_binary[train_index], best_crispr_model, [X[train_index, :]]) feature_ranks_df = pd.DataFrame(feature_ranks) feature_ranks_df.to_csv(config.feature_importance_path, index=False) logger.debug("Get features ranks successfully")
def run():
    """Train and evaluate the scaled-dot-product attention regression model.

    Pipeline: prepare features/labels via OT_crispr_attn.data_preparer,
    optionally clip and rescale the labels, restrict the split to the
    configured cell lines, deduplicate (features, label) rows, train with
    Adam on MSE loss while plotting per-epoch Pearson/Spearman, keep the
    state dict of the epoch with the best test-set Spearman, persist the
    best model, run test_model, and optionally write input-perturbation
    feature ranks to config.feature_importance_path.

    Relies on module-level names: config, logger, device2, my_data,
    attention_model, torch_visual, feature_imp, save/load, test_model.
    """
    data_pre = OT_crispr_attn.data_preparer()
    print(data_pre.get_crispr_preview())
    X = data_pre.prepare_x()
    y = data_pre.get_labels()
    data_pre.persist_data()
    print(X.head())
    print(data_pre.feature_length_map)
    train_index, test_index = data_pre.train_test_split()
    if config.log2fc_filter:
        # Clip training-row labels to [-10, 10]; the masks are indexed by
        # train_index and y[mask] = v aligns on that index.
        train_filter_1 = (y.loc[train_index, :] > 10)
        y[train_filter_1] = 10
        train_filter_2 = (y.loc[train_index, :] < -10)
        y[train_filter_2] = -10
        # train_filter = (y.loc[train_index, :] > 10)
        # train_index = train_index[train_filter.iloc[:,0]]
    print(train_index, test_index)
    torch.manual_seed(0)  # deterministic torch-side initialization/shuffling
    # X = X.values.astype(np.float32)
    # y = y.values.astype(np.float32)
    # y_binary = (y > np.quantile(y, 0.8)).astype(int).reshape(-1, )
    #y[y.loc[train_index, config.y] < 0] = 0
    # essen_filter = list(data_pre.crispr[data_pre.crispr['log2fc'] > 0].index)
    # train_index = list(set(train_index).intersection(essen_filter))
    # test_index = list(set(test_index).intersection(essen_filter))
    std_scaler = StandardScaler()
    m_m = MinMaxScaler((0, 100))
    if config.y_transform:
        # Standardize (fit on train rows only), scale x100, then min-max
        # the result into [0, 100].
        std_scaler.fit(y.loc[train_index, :])
        new_y = std_scaler.transform(y) * 100
        y = pd.DataFrame(new_y, columns=y.columns, index=y.index)
        m_m.fit(y.loc[train_index, :])
        new_y = m_m.transform(y)
        y = pd.DataFrame(new_y, columns=y.columns, index=y.index)
    if config.test_cellline is not None:
        test_cellline_index = data_pre.crispr[
            data_pre.crispr['cellline'].isin(config.test_cellline)].index
        test_index = test_cellline_index.intersection(test_index)
    if config.train_cellline is not None:
        train_cellline_index = data_pre.crispr[
            data_pre.crispr['cellline'].isin(config.train_cellline)].index
        train_index = train_cellline_index.intersection(train_index)
    logger.debug("training data amounts: %s, testing data amounts: %s" %
                 (len(train_index), len(test_index)))
    x_train, x_test, y_train, y_test = \
        X.loc[train_index, :], X.loc[test_index, :], \
        y.loc[train_index, :], y.loc[test_index, :]
    # Drop duplicated (features, label) rows inside each split.
    _, unique_train_index = np.unique(pd.concat([x_train, y_train], axis=1),
                                      return_index=True, axis=0)
    _, unique_test_index = np.unique(pd.concat([x_test, y_test], axis=1),
                                     return_index=True, axis=0)
    logger.debug(
        "after deduplication, training data amounts: %s, testing data amounts: %s"
        % (len(unique_train_index), len(unique_test_index)))
    train_index = train_index[unique_train_index]
    test_index = test_index[unique_test_index]
    x_concat = pd.concat([X, y], axis=1)
    # NOTE(review): unique_index is computed but never used afterwards.
    _, unique_index = np.unique(x_concat, return_index=True, axis=0)
    logger.debug("{0!r}, {1!r}".format(
        (len(x_concat.loc[train_index, :].drop_duplicates())),
        str(len(x_concat.loc[train_index, :]))))
    logger.debug("Splitted dataset successfully")
    train_eval_index_list = list(train_index)
    random.shuffle(train_eval_index_list)
    sep = int(len(train_eval_index_list) * 0.9)
    # NOTE(review): this 90/10 split is unused below -- the 'train' partition
    # uses the full shuffled list.
    train_index_list = train_eval_index_list[:sep]
    eval_index_list = train_eval_index_list[sep:]
    partition = {
        'train': train_eval_index_list,
        'eval': list(test_index),  # per-epoch "test" evaluation set
        'test': list(test_index)
    }
    labels = {
        key: value
        for key, value in zip(range(len(y)), list(y.values.reshape(-1)))
    }
    train_params = {'batch_size': config.batch_size, 'shuffle': True}
    eval_params = {'batch_size': len(test_index), 'shuffle': False}
    test_params = {'batch_size': len(test_index), 'shuffle': False}
    logger.debug("Preparing datasets ... ")
    training_set = my_data.MyDataset(partition['train'], labels)
    training_generator = data.DataLoader(training_set, **train_params)
    # Background loader over the training rows, used for validation metrics.
    train_bg_params = {
        'batch_size': len(train_eval_index_list) // 6 + 5,
        'shuffle': False
    }
    training_bg_set = my_data.MyDataset(partition['train'], labels)
    training_bg_generator = data.DataLoader(training_bg_set, **train_bg_params)
    validation_set = my_data.MyDataset(partition['eval'], labels)
    validation_generator = data.DataLoader(validation_set, **eval_params)
    test_set = my_data.MyDataset(partition['test'], labels)
    test_generator = data.DataLoader(test_set, **test_params)
    logger.debug("I might need to augment data")
    logger.debug("Building the scaled dot product attention model")
    # NOTE(review): for_input_len / extra_input_len are unused in this function.
    for_input_len = data_pre.feature_length_map[0][1] - data_pre.feature_length_map[0][0]
    extra_input_len = 0 if not data_pre.feature_length_map[2] \
        else data_pre.feature_length_map[2][1] - data_pre.feature_length_map[2][0]
    crispr_model = attention_model.get_OT_model(data_pre.feature_length_map)
    best_crispr_model = attention_model.get_OT_model(data_pre.feature_length_map)
    crispr_model.to(device2)
    best_crispr_model.to(device2)
    # crispr_model = attention_model.get_model(d_input=for_input_len)
    # best_crispr_model = attention_model.get_model(d_input=for_input_len)
    # #crispr_model = attention_model.get_OT_model(data_pre.feature_length_map, classifier=True)
    # crispr_model.to(device2)
    # best_crispr_model.to(device2)
    best_cv_spearman_score = 0
    if config.retraining:
        logger.debug("I need to load a old trained model")
        crispr_model = load(config.retraining_model)
        crispr_model.load_state_dict(load(config.retraining_model_state))
        crispr_model.to(device2)
    logger.debug("I might need to freeze some of the weights")
    logger.debug("Built the Crispr model successfully")
    optimizer = torch.optim.Adam(crispr_model.parameters(),
                                 lr=config.start_lr,
                                 weight_decay=config.lr_decay,
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    try:
        if config.training:
            logger.debug("Training the model")
            mse_visualizer = torch_visual.VisTorch(env_name='MSE')
            pearson_visualizer = torch_visual.VisTorch(env_name='Pearson')
            spearman_visualizer = torch_visual.VisTorch(env_name='Spearman')
            for epoch in range(config.n_epochs):
                start = time()
                cur_epoch_train_loss = []
                train_total_loss = 0
                i = 0
                # Training
                for local_batch, local_labels in training_generator:
                    i += 1
                    # Transfer to GPU
                    local_batch, local_labels = local_batch.float().to(
                        device2), local_labels.float().to(device2)
                    # seq_local_batch = local_batch.narrow(dim=1, start=0, length=for_input_len).long()
                    # extra_local_batch = local_batch.narrow(dim=1, start=for_input_len, length=extra_input_len)
                    # Model computations
                    preds = crispr_model(local_batch).contiguous().view(-1)
                    ys = local_labels.contiguous().view(-1)
                    optimizer.zero_grad()
                    assert preds.size(-1) == ys.size(-1)
                    #loss = F.nll_loss(preds, ys)
                    # NOTE(review): train() is toggled after the forward pass,
                    # so each epoch's first forward may run in eval mode.
                    crispr_model.train()
                    loss = F.mse_loss(preds, ys)
                    loss.backward()
                    optimizer.step()
                    train_total_loss += loss.item()
                    n_iter = 2
                    if i % n_iter == 0:
                        # Progress bar plus running loss every n_iter batches.
                        sample_size = len(train_index)
                        p = int(100 * i * config.batch_size / sample_size)
                        avg_loss = train_total_loss / n_iter
                        if config.y_inverse_transform:
                            avg_loss = \
                                std_scaler.inverse_transform(np.array(avg_loss / 100).reshape(-1, 1)).reshape(-1)[0]
                        logger.debug(" %dm: epoch %d [%s%s] %d%% loss = %.3f" % \
                                     ((time() - start) // 60, epoch,
                                      "".join('#' * (p // 5)),
                                      "".join(' ' * (20 - (p // 5))), p, avg_loss))
                        train_total_loss = 0
                        cur_epoch_train_loss.append(avg_loss)
                ### Evaluation over the (unshuffled) training background set
                val_i = 0
                val_total_loss = 0
                val_loss = []
                val_preds = []
                val_ys = []
                val_pearson = 0
                val_spearman = 0
                with torch.set_grad_enabled(False):
                    crispr_model.eval()
                    for local_batch, local_labels in training_bg_generator:
                        val_i += 1
                        local_labels_on_cpu = np.array(local_labels).reshape(-1)
                        sample_size = local_labels_on_cpu.shape[-1]
                        local_labels_on_cpu = local_labels_on_cpu[:sample_size]
                        # Transfer to GPU
                        local_batch, local_labels = local_batch.float().to(
                            device2), local_labels.float().to(device2)
                        # seq_local_batch = local_batch.narrow(dim=1, start=0, length=for_input_len).long()
                        # extra_local_batch = local_batch.narrow(dim=1, start=for_input_len, length=extra_input_len)
                        preds = crispr_model(local_batch).contiguous().view(-1)
                        assert preds.size(-1) == local_labels.size(-1)
                        prediction_on_cpu = preds.cpu().numpy().reshape(-1)
                        # mean_prediction_on_cpu = np.mean([prediction_on_cpu[:sample_size],
                        #                                   prediction_on_cpu[sample_size:]], axis=0)
                        mean_prediction_on_cpu = prediction_on_cpu[:sample_size]
                        if config.y_inverse_transform:
                            local_labels_on_cpu, mean_prediction_on_cpu = \
                                std_scaler.inverse_transform(local_labels_on_cpu.reshape(-1, 1) / 100), \
                                std_scaler.inverse_transform(mean_prediction_on_cpu.reshape(-1, 1) / 100)
                        loss = mean_squared_error(local_labels_on_cpu,
                                                  mean_prediction_on_cpu)
                        val_preds.append(mean_prediction_on_cpu)
                        val_ys.append(local_labels_on_cpu)
                        val_total_loss += loss
                        n_iter = 1
                        if val_i % n_iter == 0:
                            avg_loss = val_total_loss / n_iter
                            val_loss.append(avg_loss)
                            val_total_loss = 0
                    mean_prediction_on_cpu = np.concatenate(tuple(val_preds))
                    local_labels_on_cpu = np.concatenate(tuple(val_ys))
                    val_pearson = pearsonr(mean_prediction_on_cpu.reshape(-1),
                                           local_labels_on_cpu.reshape(-1))[0]
                    val_spearman = spearmanr(mean_prediction_on_cpu.reshape(-1),
                                             local_labels_on_cpu.reshape(-1))[0]
                logger.debug(
                    "Validation mse is {0}, Validation pearson correlation is {1!r} and Validation "
                    "spearman correlation is {2!r}".format(
                        np.mean(val_loss), val_pearson, val_spearman))
                mse_visualizer.plot_loss(epoch,
                                         np.mean(cur_epoch_train_loss),
                                         np.mean(val_loss),
                                         loss_type='mse')
                pearson_visualizer.plot_loss(epoch,
                                             val_pearson,
                                             loss_type='pearson_loss',
                                             ytickmin=0,
                                             ytickmax=1)
                spearman_visualizer.plot_loss(epoch,
                                              val_spearman,
                                              loss_type='spearman_loss',
                                              ytickmin=0,
                                              ytickmax=1)
                ### Evaluation over the held-out test indices
                val_i = 0
                val_total_loss = 0
                val_loss = []
                val_preds = []
                val_ys = []
                val_pearson = 0
                val_spearman = 0
                with torch.set_grad_enabled(False):
                    crispr_model.eval()
                    for local_batch, local_labels in validation_generator:
                        val_i += 1
                        local_labels_on_cpu = np.array(local_labels).reshape(-1)
                        sample_size = local_labels_on_cpu.shape[-1]
                        local_labels_on_cpu = local_labels_on_cpu[:sample_size]
                        # Transfer to GPU
                        local_batch, local_labels = local_batch.float().to(
                            device2), local_labels.float().to(device2)
                        # seq_local_batch = local_batch.narrow(dim=1, start=0, length=for_input_len).long()
                        # extra_local_batch = local_batch.narrow(dim=1, start=for_input_len, length=extra_input_len)
                        preds = crispr_model(local_batch).contiguous().view(-1)
                        assert preds.size(-1) == local_labels.size(-1)
                        prediction_on_cpu = preds.cpu().numpy().reshape(-1)
                        # mean_prediction_on_cpu = np.mean([prediction_on_cpu[:sample_size],
                        #                                   prediction_on_cpu[sample_size:]], axis=0)
                        mean_prediction_on_cpu = prediction_on_cpu[:sample_size]
                        if config.y_inverse_transform:
                            local_labels_on_cpu, mean_prediction_on_cpu = \
                                std_scaler.inverse_transform(local_labels_on_cpu.reshape(-1, 1) / 100), \
                                std_scaler.inverse_transform(mean_prediction_on_cpu.reshape(-1, 1) / 100)
                        loss = mean_squared_error(local_labels_on_cpu,
                                                  mean_prediction_on_cpu)
                        val_preds.append(mean_prediction_on_cpu)
                        val_ys.append(local_labels_on_cpu)
                        val_total_loss += loss
                        n_iter = 1
                        if val_i % n_iter == 0:
                            avg_loss = val_total_loss / n_iter
                            val_loss.append(avg_loss)
                            val_total_loss = 0
                    mean_prediction_on_cpu = np.concatenate(tuple(val_preds))
                    local_labels_on_cpu = np.concatenate(tuple(val_ys))
                    val_pearson = pearsonr(mean_prediction_on_cpu.reshape(-1),
                                           local_labels_on_cpu.reshape(-1))[0]
                    val_spearman = spearmanr(mean_prediction_on_cpu.reshape(-1),
                                             local_labels_on_cpu.reshape(-1))[0]
                # Keep the weights of the epoch with the best test Spearman.
                if best_cv_spearman_score < val_spearman:
                    best_cv_spearman_score = val_spearman
                    best_crispr_model.load_state_dict(crispr_model.state_dict())
                logger.debug(
                    "Test mse is {0}, Test pearson correlation is {1!r} and Test "
                    "spearman correlation is {2!r}".format(
                        np.mean(val_loss), val_pearson, val_spearman))
                mse_visualizer.plot_loss(epoch,
                                         np.mean(cur_epoch_train_loss),
                                         np.mean(val_loss),
                                         loss_type='mse')
                pearson_visualizer.plot_loss(epoch,
                                             val_pearson,
                                             loss_type='pearson_loss',
                                             ytickmin=0,
                                             ytickmax=1)
                spearman_visualizer.plot_loss(epoch,
                                              val_spearman,
                                              loss_type='spearman_loss',
                                              ytickmin=0,
                                              ytickmax=1)
            logger.debug("Saving training history")
            logger.debug("Saved training history successfully")
            logger.debug("Trained crispr model successfully")
        else:
            logger.debug("loading in old model")
            logger.debug("Load in model successfully")
    except KeyboardInterrupt as e:
        # Ctrl-C: fall through and persist whatever best model exists so far.
        logger.debug("Loading model")
        logger.debug("loading some intermediate step's model")
        logger.debug("Load in model successfully")
    logger.debug("Persisting model")
    # serialize weights to HDF5
    save(best_crispr_model, config.hdf5_path)
    save(best_crispr_model.state_dict(), config.hdf5_path_state)
    logger.debug("Saved model to disk")
    save_output = []
    if 'essentiality' in config.extra_numerical_features:
        save_output.append(data_pre.crispr.loc[test_index, 'essentiality'])
    test_model(best_crispr_model, test_generator, save_output, std_scaler)
    if config.check_feature_importance:
        logger.debug("Getting features ranks")
        names = []
        # NOTE(review): src names use feature_length_map[0][1] directly while
        # trg names use end - start; confirm map[0] always starts at 0.
        names += [
            "src_" + str(i) for i in range(data_pre.feature_length_map[0][1])
        ]
        if data_pre.feature_length_map[1] is not None:
            names += [
                "trg_" + str(i)
                for i in range(data_pre.feature_length_map[1][1] -
                               data_pre.feature_length_map[1][0])
            ]
        if data_pre.feature_length_map[2] is not None:
            names += config.extra_categorical_features + config.extra_numerical_features
        ranker = feature_imp.InputPerturbationRank(names)
        feature_ranks = ranker.rank(
            2, y.loc[train_eval_index_list, :].values, best_crispr_model,
            [torch.FloatTensor(X.loc[train_eval_index_list, :].values)],
            torch=True)
        feature_ranks_df = pd.DataFrame(feature_ranks)
        feature_ranks_df.to_csv(config.feature_importance_path, index=False)
        logger.debug("Get features ranks successfully")
def run():
    """Train and evaluate the Keras GRU/CNN CRISPR efficiency model.

    Reads the CSV dataset from config.input_dataset, derives sequence features
    (forward k-mers, reversed-sequence k-mers, single-nucleotide tokens),
    one-hot encodes the extra categorical features, splits train/test
    (optionally loading cached indices), builds or fine-tunes the
    CrisprCasModel, trains it (unless config.ml_train short-circuits into the
    h2o/random-forest path), then scores both the best checkpoint and the
    final model with Spearman correlation and optionally writes
    input-perturbation feature ranks to config.feature_importance_path.

    Relies on module-level names: config, logger, utils, models,
    process_features, feature_imp, h2o, tf, Keras Input / callbacks,
    ml_train, get_prediction_* helpers.
    """
    logger.debug("Reading in the crispr dataset %s" % config.input_dataset)
    crispr = pd.read_csv(config.input_dataset)
    crispr['PAM'] = crispr['sequence'].str[-3:]  # last 3 characters of the sequence
    if config.log_cen:
        crispr['essentiality'] = np.log(crispr['essentiality'] * 100 + 1)
    if config.with_pam:
        pam_code = 8
    else:
        pam_code = 0
    # scale_features
    process_features.scale_features(crispr)
    process_features.scale_output(crispr)
    logger.debug("Read in data successfully")
    logger.debug("Transforming data")
    # Forward k-mers, reversed-sequence k-mers, and single-nucleotide tokens.
    X_for = crispr.loc[:, 'sequence'].apply(
        lambda seq: utils.split_seqs(seq[:config.seq_len]))
    X_rev = crispr.loc[:, 'sequence'].apply(
        lambda seq: utils.split_seqs(seq[config.seq_len - 1::-1]))
    X_cnn = crispr.loc[:, 'sequence'].apply(
        lambda seq: utils.split_seqs(seq[:config.seq_len], nt=1))
    X = pd.concat([X_for, X_rev, X_cnn], axis=1)
    logger.debug("Get sequence sucessfully")
    off_target_X = pd.DataFrame(np.empty(shape=[X_for.shape[0], 0]))
    # off_target_X = crispr.loc[:, 'sequence'].apply(lambda seq: utils.map_to_matrix(seq, 1, 22))
    # y = pd.DataFrame(np.abs(crispr[config.y].copy()) * 10)
    y = pd.DataFrame(crispr[config.y].copy() * 8)
    logger.debug("Transformed data successfully")
    logger.debug(
        "Starting to prepare for splitting dataset to training dataset and testing dataset based on genes"
    )
    logger.debug("Generating groups based on gene names")
    if config.group:
        crispr.loc[:, "group"] = pd.Categorical(crispr.loc[:, config.group])
    logger.debug("Generated groups information successfully")
    logger.debug("Splitting dataset")
    if os.path.exists(config.train_index) and os.path.exists(
            config.test_index):
        # NOTE(review): these two file handles are never closed explicitly.
        train_index = pickle.load(open(config.train_index, "rb"))
        test_index = pickle.load(open(config.test_index, "rb"))
    else:
        train_test_split = getattr(process_features,
                                   config.split_method + "_split",
                                   process_features.regular_split)
        # NOTE(review): len(crispr) / 100 is a float under Python 3 -- confirm
        # the splitter accepts a non-integer n_split.
        train_index, test_index = train_test_split(crispr,
                                                   group_col=config.group_col,
                                                   n_split=max(
                                                       len(crispr) / 100, 10),
                                                   rd_state=7)
        with open(config.train_index, 'wb') as train_file:
            pickle.dump(train_index, train_file)
        with open(config.test_index, 'wb') as test_file:
            pickle.dump(test_index, test_file)
    if config.test_cellline:
        test_cellline_index = crispr[crispr['cellline'] ==
                                     config.test_cellline].index
        test_index = test_cellline_index.intersection(test_index)
    test_index_list = [
        x.index
        for _, x in crispr.loc[test_index, :].reset_index().groupby('group')
        if len(x)
    ] if config.test_method == 'group' else []
    logger.debug("Splitted data successfully")
    logger.debug("training data amounts: %s, testing data amounts: %s" %
                 (len(train_index), len(test_index)))
    x_train, x_test, y_train, y_test, off_target_X_train, off_target_X_test = \
        X.loc[train_index, :], X.loc[test_index, :], \
        y.loc[train_index, :], y.loc[test_index, :], \
        off_target_X.loc[train_index, :], off_target_X.loc[test_index, :]
    # Deduplicate identical (features, label) rows inside each split.
    _, unique_train_index = np.unique(pd.concat([x_train, y_train], axis=1),
                                      return_index=True, axis=0)
    _, unique_test_index = np.unique(pd.concat([x_test, y_test], axis=1),
                                     return_index=True, axis=0)
    logger.debug(
        "after deduplication, training data amounts: %s, testing data amounts: %s"
        % (len(unique_train_index), len(unique_test_index)))
    logger.debug("Splitted dataset successfully")
    logger.debug("Generating one hot vector for categorical data")
    extra_crispr_df = crispr[config.extra_categorical_features +
                             config.extra_numerical_features]
    # With PAM the first categorical feature has pam_code levels and the rest
    # are binary; the conditional applies to the whole concatenated list.
    n_values = [pam_code] + ([2] * (len(config.extra_categorical_features) - 1)
                             ) if config.with_pam else [2] * len(
                                 config.extra_categorical_features)
    # presumably encodes extra_crispr_df in place -- TODO confirm
    process_features.process_categorical_features(extra_crispr_df, n_values)
    extra_x_train, extra_x_test = extra_crispr_df.loc[
        train_index, :].values, extra_crispr_df.loc[test_index, :].values
    logger.debug("Generating on hot vector for categorical data successfully")
    logger.debug("Seperate forward and reverse seq")
    x_train = x_train.values
    for_input_len = config.seq_len - config.word_len + 1
    # Columns are laid out [forward | reverse | cnn] per the X concat above.
    for_input, rev_input, for_cnn = x_train[:, :for_input_len], \
        x_train[:, for_input_len:2 * for_input_len], \
        x_train[:, 2 * for_input_len:]
    x_test = x_test.values
    for_x_test, rev_x_test, for_cnn_test = x_test[:, :for_input_len], \
        x_test[:, for_input_len:2 * for_input_len], \
        x_test[:, 2 * for_input_len:]
    off_target_X_train = off_target_X_train.values
    off_target_X_test = off_target_X_test.values
    if not config.off_target:
        off_target_X_train, off_target_X_test = np.empty(
            shape=[off_target_X_train.shape[0], 0]), np.empty(
                shape=[off_target_X_test.shape[0], 0])
    if (not config.rev_seq) or (config.model_type == 'mixed'):
        rev_input, rev_x_test = np.empty(
            shape=[rev_input.shape[0], 0]), np.empty(
                shape=[rev_x_test.shape[0], 0])
    y_train = y_train.values
    # NOTE(review): 'filter' shadows the builtin of the same name.
    filter = y_train.flatten() > 0
    y_test = y_test.values
    if config.ml_train:
        try:
            ml_train(X, extra_crispr_df, y, train_index, test_index)
        # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit.
        except:
            logger.debug("Fail to use random forest")
        finally:
            h2o.cluster().shutdown()
        return
    logger.debug("Building the RNN graph")
    weight_matrix = [utils.get_weight_matrix()
                     ] if config.word2vec_weight_matrix else None
    for_seq_input = Input(shape=(for_input.shape[1], ))
    rev_seq_input = Input(shape=(rev_input.shape[1], ))
    for_cnn_input = Input(shape=(for_cnn.shape[1], ))
    bio_features = Input(shape=(extra_x_train.shape[1], ))
    off_target_features = Input(shape=(off_target_X_train.shape[1], ))
    all_features = Input(shape=(for_input.shape[1] + rev_input.shape[1] +
                                extra_x_train.shape[1] +
                                off_target_X_train.shape[1], ))
    if not config.ensemble:
        crispr_model = models.CrisprCasModel(
            bio_features=bio_features,
            for_seq_input=for_seq_input,
            rev_seq_input=rev_seq_input,
            weight_matrix=weight_matrix,
            off_target_features=off_target_features,
            all_features=all_features).get_model()
    else:
        crispr_model = models.CrisprCasModel(
            bio_features=bio_features,
            for_seq_input=for_seq_input,
            rev_seq_input=rev_seq_input,
            for_cnn_input=for_cnn_input,
            weight_matrix=weight_matrix,
            off_target_features=off_target_features,
            all_features=all_features).get_model()
    if config.retraining:
        # Transfer learning: load the previous model and set layer
        # trainability according to model_type / frozen_embedding_only.
        loaded_model = load_model(config.retraining_model,
                                  custom_objects={
                                      'revised_mse_loss':
                                      utils.revised_mse_loss,
                                      'tf': tf
                                  })
        for layer in loaded_model.layers:
            print(layer.name)
        if config.model_type == 'cnn':
            for_layer = loaded_model.get_layer(name='embedding_1')
            for_layer.trainable = config.fine_tune_trainable
            full_connected = loaded_model.get_layer(name='sequential_6')
        elif (config.model_type == 'mixed') or (config.model_type == 'ensemble'):
            for_layer = loaded_model.get_layer(name='sequential_5')
            if config.frozen_embedding_only:
                for_layer = for_layer.get_layer(name='embedding_1')
            for_layer.trainable = config.fine_tune_trainable
            cnn_layer = loaded_model.get_layer(name='embedding_2')
            cnn_layer.trainable = config.fine_tune_trainable
            if not config.frozen_embedding_only:
                cnn_layer_1 = loaded_model.get_layer(name='sequential_3')
                cnn_layer_2 = loaded_model.get_layer(name='sequential_4')
                cnn_layer_1.trainable = config.fine_tune_trainable
                cnn_layer_2.trainable = config.fine_tune_trainable
            full_connected = loaded_model.get_layer(name='sequential_6')
        else:
            for_layer = loaded_model.get_layer(name='sequential_5')
            if config.frozen_embedding_only:
                for_layer = for_layer.get_layer(name='embedding_1')
            for_layer.trainable = config.fine_tune_trainable
            if config.rev_seq:
                rev_layer = loaded_model.get_layer(name='sequential_2')
                if config.frozen_embedding_only:
                    rev_layer = rev_layer.get_layer(name='embedding_2')
                rev_layer.trainable = config.fine_tune_trainable
                full_connected = loaded_model.get_layer(name='sequential_3')
            else:
                full_connected = loaded_model.get_layer(name='sequential_6')
        # presumably each dense unit spans 4 layers (hence the / 4) -- TODO confirm
        for i in range(
                int((len(full_connected.layers) / 4) *
                    (1 - config.fullly_connected_train_fraction))):
            dense_layer = full_connected.get_layer(name='dense_' + str(i + 1))
            dense_layer.trainable = config.fine_tune_trainable
        crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
            loaded_model)
    utils.output_model_info(crispr_model)
    logger.debug("Built the RNN model successfully")
    try:
        if config.training:
            logger.debug("Training the model")
            # x_train = x_train.values.astype('int32').reshape((-1, 21, 200))
            checkpoint = ModelCheckpoint(config.temp_hdf5_path,
                                         verbose=1,
                                         save_best_only=True,
                                         period=1)
            reduce_lr = LearningRateScheduler(utils.cosine_decay_lr)
            logger.debug("augmenting data")
            processed_for_input = utils.augment_data(
                for_input, filter=filter,
                is_seq=True) if config.augment_data else for_input
            if config.augment_data:
                if rev_input.shape[0] and rev_input.shape[1]:
                    processed_rev_input = utils.augment_data(rev_input,
                                                             filter=filter,
                                                             is_seq=True,
                                                             is_rev=True)
                else:
                    processed_rev_input = utils.augment_data(rev_input,
                                                             filter=filter)
            else:
                processed_rev_input = rev_input
            processed_off_target_X_train = utils.augment_data(
                off_target_X_train,
                filter=filter) if config.augment_data else off_target_X_train
            processed_extra_x_train = utils.augment_data(
                extra_x_train,
                filter=filter) if config.augment_data else extra_x_train
            processed_y_train = utils.augment_data(
                y_train, filter=filter) if config.augment_data else y_train
            logger.debug("augmented data successfully")
            logger.debug("selecting %d data for training" %
                         (config.retraining_datasize * len(processed_y_train)))
            # Random subsample of the (possibly augmented) training rows.
            index_range = list(range(len(processed_y_train)))
            np.random.shuffle(index_range)
            selected_index = index_range[:int(config.retraining_datasize *
                                              len(processed_y_train))]
            # NOTE(review): duplicate of the debug line above.
            logger.debug("selecting %d data for training" %
                         (config.retraining_datasize * len(processed_y_train)))
            features_list = [
                processed_for_input[selected_index],
                processed_rev_input[selected_index],
                processed_off_target_X_train[selected_index],
                processed_extra_x_train[selected_index]
            ]
            if config.ensemble:
                processed_for_cnn = utils.augment_data(
                    for_cnn, filter=filter,
                    is_seq=True) if config.augment_data else for_cnn
                features_list.append(processed_for_cnn[selected_index])
                print("ensemble")
                print(len(features_list))
            training_history = utils.print_to_training_log(crispr_model.fit)(
                x=features_list,
                validation_split=0.05,
                y=processed_y_train[selected_index],
                epochs=config.n_epochs,
                batch_size=config.batch_size,
                verbose=2,
                callbacks=[checkpoint, reduce_lr])
            logger.debug("Saving history")
            with open(config.training_history, 'wb') as history_file:
                pickle.dump(training_history.history, history_file)
            logger.debug("Saved training history successfully")
            logger.debug("Trained crispr model successfully")
        else:
            logger.debug("Logging in old model")
            loaded_model = load_model(config.old_model_hdf5,
                                      custom_objects={
                                          'revised_mse_loss':
                                          utils.revised_mse_loss,
                                          'tf': tf
                                      })
            crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
                loaded_model)
            crispr_model.save(config.temp_hdf5_path)
            logger.debug("Load in model successfully")
    except KeyboardInterrupt as e:
        # Ctrl-C mid-training: fall back to the last checkpoint.
        logger.debug("Loading model")
        loaded_model = load_model(config.temp_hdf5_path,
                                  custom_objects={
                                      'revised_mse_loss':
                                      utils.revised_mse_loss,
                                      'tf': tf
                                  })
        crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
            loaded_model)
        logger.debug("Load in model successfully")
    logger.debug("Persisting model")
    # serialize weights to HDF5
    crispr_model.save(config.hdf5_path)
    print("Saved model to disk")
    logger.debug("Loading best model for testing")
    loaded_model = load_model(config.temp_hdf5_path,
                              custom_objects={
                                  'revised_mse_loss': utils.revised_mse_loss,
                                  'tf': tf
                              })
    crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
        loaded_model)
    logger.debug("Load in model successfully")
    logger.debug("Predicting data with best model")
    train_list = [
        for_input[unique_train_index], rev_input[unique_train_index],
        off_target_X_train[unique_train_index],
        extra_x_train[unique_train_index]
    ]
    if config.ensemble:
        train_list.append(for_cnn[unique_train_index])
    train_prediction = crispr_model.predict(x=train_list)
    train_performance = spearmanr(train_prediction,
                                  y_train[unique_train_index])
    logger.debug(
        "GRU model spearman correlation coefficient for training dataset is: %s"
        % str(train_performance))
    # Dispatch to get_prediction_<test_method>, defaulting to the group version.
    get_prediction = getattr(sys.modules[__name__],
                             "get_prediction_" + config.test_method,
                             get_prediction_group)
    test_list = [for_x_test, rev_x_test, off_target_X_test, extra_x_test]
    if config.ensemble:
        test_list.append(for_cnn_test)
    performance, prediction = get_prediction(crispr_model, test_index_list,
                                             unique_test_index, y_test,
                                             test_list)
    logger.debug("GRU model spearman correlation coefficient: %s" %
                 str(performance))
    logger.debug("Loading last model for testing")
    loaded_model = load_model(config.hdf5_path,
                              custom_objects={
                                  'revised_mse_loss': utils.revised_mse_loss,
                                  'tf': tf
                              })
    crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
        loaded_model)
    logger.debug("Load in model successfully")
    logger.debug("Predicting data with last model")
    last_train_prediction = crispr_model.predict(x=train_list)
    last_train_performance = spearmanr(last_train_prediction,
                                       y_train[unique_train_index])
    utils.output_config_info()
    logger.debug(
        "GRU model spearman correlation coefficient for training dataset is: %s"
        % str(last_train_performance))
    last_performance, last_prediction = get_prediction(crispr_model,
                                                       test_index_list,
                                                       unique_test_index,
                                                       y_test, test_list)
    logger.debug("GRU model spearman correlation coefficient: %s" %
                 str(last_performance))
    logger.debug("Saving test and prediction data plot")
    # Report predictions from whichever of best/last model scored higher.
    if last_performance > performance:
        prediction = last_prediction
    utils.ytest_and_prediction_output(y_test[unique_test_index], prediction)
    logger.debug("Saved test and prediction data plot successfully")
    if config.check_feature_importance:
        if performance > last_performance:
            # Re-load the best checkpoint so ranking uses the better model.
            loaded_model = load_model(config.temp_hdf5_path,
                                      custom_objects={
                                          'revised_mse_loss':
                                          utils.revised_mse_loss,
                                          'tf': tf
                                      })
            crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
                loaded_model)
        logger.debug("Getting features ranks")
        names = []
        names += ["for_" + str(i) for i in range(for_input.shape[1])]
        names += ["rev_" + str(i) for i in range(rev_input.shape[1])]
        names += ["off_" + str(i) for i in range(off_target_X_train.shape[1])]
        names += config.extra_categorical_features + config.extra_numerical_features
        ranker = feature_imp.InputPerturbationRank(names)
        feature_ranks = ranker.rank(
            20, y_test[unique_test_index], crispr_model,
            [data[unique_test_index] for data in test_list])
        feature_ranks_df = pd.DataFrame(feature_ranks)
        feature_ranks_df.to_csv(config.feature_importance_path, index=False)
        logger.debug("Get features ranks successfully")
def run():
    """End-to-end training/evaluation pipeline for the off-target CRISPR
    attention classifier.

    Steps: prepare mismatch features and binary labels, split train/test,
    oversample the training set, train with CrossEntropyLoss while tracking
    ROC-AUC / PR-AUC each epoch, keep the best-scoring weights, persist the
    model, run a final test pass, and optionally rank feature importance by
    input perturbation.

    Relies on module-level globals (config, logger, device2, save,
    test_model, and the project modules imported at file top); returns None.
    """
    # --- Data preparation -------------------------------------------------
    data_pre = OT_crispr_attn.data_preparer()
    print(data_pre.get_crispr_preview())
    # mismatch=True: features encode guide-vs-target sequence mismatches.
    X = data_pre.prepare_x(mismatch=True, trg_seq_col='Target sequence')
    y = data_pre.get_labels(binary=True)  # binary classification labels
    data_pre.persist_data()
    print(X.head())
    logger.debug("{0!r}".format(data_pre.feature_length_map))
    train_index, test_index = data_pre.train_test_split(n_split=5)
    logger.debug("{0!r}".format(
        set(data_pre.crispr.loc[test_index, :]['sequence'])))
    logger.debug("{0!r}".format(
        set(data_pre.crispr.loc[train_index, :]['sequence'])))
    #assert len(set(data_pre.crispr.loc[test_index, :]['sequence']) & set(data_pre.crispr.loc[train_index, :]['sequence'])) == 0
    logger.debug("{0!r}".format(train_index))
    logger.debug("{0!r}".format(test_index))
    logger.debug("training data amounts: %s, testing data amounts: %s" %
                 (len(train_index), len(test_index)))
    torch.manual_seed(0)  # deterministic weight init / batch shuffling

    # Optionally restrict the test set to a single cell line.
    if config.test_cellline:
        test_cellline_index = data_pre.crispr[data_pre.crispr['cellline'] ==
                                              config.test_cellline].index
        test_index = test_cellline_index.intersection(test_index)

    # --- Class balancing: oversample the training set ---------------------
    # fit_resample's return is discarded; only the resampled row positions
    # (sample_indices_) are used, mapped back onto original index labels.
    ros = RandomOverSampler(random_state=42)
    _ = ros.fit_resample(X.loc[train_index, :], y.loc[train_index, :])
    new_train_index = train_index[ros.sample_indices_]
    oversample_train_index = list(new_train_index)
    random.shuffle(oversample_train_index)
    # sep = int(len(train_eval_index_list) * 0.9)
    # train_index_list = train_eval_index_list[:sep]
    # eval_index_list = train_eval_index_list[sep:]
    # Sanity checks: no leakage into test, and every original training row
    # survives oversampling.
    assert len(set(oversample_train_index) & set(test_index)) == 0
    assert len(set(oversample_train_index) & set(train_index)) == len(
        set(train_index))

    partition = {
        'train': oversample_train_index,  # oversampled + shuffled
        'train_val': train_index,         # original (non-oversampled) train rows
        'eval': list(test_index),
        'test': list(test_index)
    }
    # NOTE(review): keys are positional ints 0..len(y)-1 while the partition
    # lists hold index labels; this assumes y's index is the default
    # RangeIndex so positions and labels coincide — confirm.
    labels = {
        key: value
        for key, value in zip(list(range(len(y))), list(y.values.reshape(-1)))
    }

    train_params = {'batch_size': config.batch_size, 'shuffle': True}
    train_bg_params = {'batch_size': config.batch_size, 'shuffle': True}
    # Evaluation/test run as one full-size batch each.
    eval_params = {'batch_size': len(test_index), 'shuffle': False}
    test_params = {'batch_size': len(test_index), 'shuffle': False}

    logger.debug("Preparing datasets ... ")
    training_set = my_data.MyDataset(partition['train'], labels)
    training_generator = data.DataLoader(training_set, **train_params)
    # "bg" pass: the non-oversampled training rows, used only for per-epoch
    # training-side metrics.
    training_bg_set = my_data.MyDataset(partition['train_val'], labels)
    training_bg_generator = data.DataLoader(training_bg_set, **train_bg_params)
    validation_set = my_data.MyDataset(partition['eval'], labels)
    validation_generator = data.DataLoader(validation_set, **eval_params)
    test_set = my_data.MyDataset(partition['test'], labels)
    test_generator = data.DataLoader(test_set, **test_params)
    logger.debug("I might need to augment data")

    # --- Model construction ----------------------------------------------
    logger.debug("Building the scaled dot product attention model")
    crispr_model = attention_model.get_OT_model(data_pre.feature_length_map,
                                                classifier=True)
    # Second instance receives a copy of the best weights seen so far.
    best_crispr_model = attention_model.get_OT_model(
        data_pre.feature_length_map, classifier=True)
    crispr_model.to(device2)
    best_crispr_model.to(device2)
    best_cv_roc_auc_scores = 0
    optimizer = torch.optim.Adam(crispr_model.parameters(),
                                 lr=config.start_lr,
                                 weight_decay=config.lr_decay,
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    if config.retraining:
        # Placeholder branch: retraining/freezing is only logged, not done.
        logger.debug("I need to load a old trained model")
        logger.debug("I might need to freeze some of the weights")
    logger.debug("Built the RNN model successfully")

    try:
        if config.training:
            logger.debug("Training the model")
            nllloss_visualizer = torch_visual.VisTorch(env_name='NLLLOSS')
            roc_auc_visualizer = torch_visual.VisTorch(env_name='ROC_AUC')
            pr_auc_visualizer = torch_visual.VisTorch(env_name='PR_AUC')

            for epoch in range(config.n_epochs):
                crispr_model.train()
                start = time()
                cur_epoch_train_loss = []
                train_total_loss = 0
                i = 0
                # Training
                for local_batch, local_labels in training_generator:
                    i += 1
                    # Transfer to GPU
                    local_batch, local_labels = local_batch.float().to(
                        device2), local_labels.long().to(device2)
                    # seq_local_batch = local_batch.narrow(dim=1, start=0, length=for_input_len).long()
                    # extra_local_batch = local_batch.narrow(dim=1, start=for_input_len, length=extra_input_len)
                    # Model computations
                    preds = crispr_model(local_batch)
                    ys = local_labels.contiguous().view(-1)
                    optimizer.zero_grad()
                    assert preds.size(0) == ys.size(0)
                    #loss = F.nll_loss(preds, ys)
                    criterion = torch.nn.CrossEntropyLoss()
                    loss = criterion(preds, ys)
                    loss.backward()
                    optimizer.step()
                    train_total_loss += loss.item()

                    n_iter = 50
                    if i % n_iter == 0:
                        # Log a crude progress bar plus the running mean loss
                        # over the last n_iter batches.
                        sample_size = len(train_index)
                        p = int(100 * i * config.batch_size / sample_size)
                        avg_loss = train_total_loss / n_iter
                        logger.debug(
                            " %dm: epoch %d [%s%s] %d%% loss = %.3f" %
                            ((time() - start) // 60, epoch,
                             "".join('#' * (p // 5)),
                             "".join(' ' * (20 - (p // 5))), p, avg_loss))
                        train_total_loss = 0
                        cur_epoch_train_loss.append(avg_loss)

                ### Evaluation on the non-oversampled training rows
                train_val_i = 0
                train_val_total_loss = 0
                train_val_loss = []
                # train_val_roc_auc = 0
                # train_val_pr_auc = 0
                train_val_preds = []
                train_val_ys = []
                n_pos, n_neg = 0, 0
                with torch.set_grad_enabled(False):
                    crispr_model.eval()
                    for local_batch, local_labels in training_bg_generator:
                        train_val_i += 1
                        local_labels_on_cpu = np.array(local_labels).reshape(
                            -1)
                        train_val_ys.append(local_labels_on_cpu)
                        # Running class counts across batches (+=, unlike the
                        # per-batch counts in the test loop below).
                        n_pos += sum(local_labels_on_cpu)
                        n_neg += len(local_labels_on_cpu) - sum(
                            local_labels_on_cpu)
                        # Transfer to GPU
                        local_batch, local_labels = local_batch.float().to(
                            device2), local_labels.long().to(device2)
                        preds = crispr_model(local_batch)
                        assert preds.size(0) == local_labels.size(0)
                        criterion = torch.nn.CrossEntropyLoss()
                        nllloss_val = criterion(preds, local_labels).item()
                        # nllloss_val = F.nll_loss(preds, local_labels).item()
                        train_val_total_loss += nllloss_val
                        # Column 1 is used as the positive-class score.
                        # NOTE(review): assumes the model's output column 1
                        # is the positive class — confirm in attention_model.
                        prediction_on_cpu = preds.cpu().numpy()[:, 1]
                        train_val_preds.append(prediction_on_cpu)
                        n_iter = 10
                        if train_val_i % n_iter == 0:
                            avg_loss = train_val_total_loss / n_iter
                            train_val_loss.append(avg_loss)
                            train_val_total_loss = 0

                preds = np.concatenate(tuple(train_val_preds))
                ys = np.concatenate(tuple(train_val_ys))
                train_val_roc_auc = roc_auc_score(ys, preds)
                train_val_pr_auc = average_precision_score(ys, preds)
                logger.debug(
                    "{0!r} positive samples and {1!r} negative samples".format(
                        n_pos, n_neg))
                logger.debug(
                    "Validation nllloss is {0}, Validation roc_auc is {1!r} and Validation "
                    "pr_auc correlation is {2!r}".format(
                        np.mean(train_val_loss), train_val_roc_auc,
                        train_val_pr_auc))
                # nllloss_visualizer.plot_loss(epoch, np.mean(cur_epoch_train_loss), np.mean(train_val_loss), loss_type='nllloss')
                # roc_auc_visualizer.plot_loss(epoch, train_val_roc_auc, loss_type='roc_auc', ytickmin=0, ytickmax=1)
                # pr_auc_visualizer.plot_loss(epoch, train_val_pr_auc, loss_type='pr_auc', ytickmin=0, ytickmax=1)

                ### Evaluation on the held-out test rows
                val_i = 0
                val_total_loss = 0
                val_loss = []
                val_preds = []
                val_ys = []
                with torch.set_grad_enabled(False):
                    crispr_model.eval()
                    for local_batch, local_labels in validation_generator:
                        val_i += 1
                        local_labels_on_cpu = np.array(local_labels).reshape(
                            -1)
                        val_ys.append(local_labels_on_cpu)
                        n_pos = sum(local_labels_on_cpu)
                        n_neg = len(local_labels_on_cpu) - sum(
                            local_labels_on_cpu)
                        logger.debug(
                            "{0!r} positive samples and {1!r} negative samples"
                            .format(n_pos, n_neg))
                        local_batch, local_labels = local_batch.float().to(
                            device2), local_labels.long().to(device2)
                        preds = crispr_model(local_batch)
                        assert preds.size(0) == local_labels.size(0)
                        criterion = torch.nn.CrossEntropyLoss()
                        nllloss_val = criterion(preds, local_labels).item()
                        #nllloss_val = F.nll_loss(preds, local_labels).item()
                        prediction_on_cpu = preds.cpu().numpy()[:, 1]
                        val_preds.append(prediction_on_cpu)
                        val_total_loss += nllloss_val
                        n_iter = 1
                        if val_i % n_iter == 0:
                            avg_loss = val_total_loss / n_iter
                            val_loss.append(avg_loss)
                            val_total_loss = 0

                preds = np.concatenate(tuple(val_preds))
                ys = np.concatenate(tuple(val_ys))
                val_roc_auc = roc_auc_score(ys, preds)
                val_pr_auc = average_precision_score(ys, preds)
                logger.debug(
                    "Test NLLloss is {0}, Test roc_auc is {1!r} and Test "
                    "pr_auc is {2!r}".format(np.mean(val_loss), val_roc_auc,
                                             val_pr_auc))

                # Per-epoch curves (visdom).
                nllloss_visualizer.plot_loss(epoch,
                                             np.mean(cur_epoch_train_loss),
                                             np.mean(val_loss),
                                             loss_type='nllloss')
                roc_auc_visualizer.plot_loss(epoch,
                                             train_val_roc_auc,
                                             val_roc_auc,
                                             loss_type='roc_auc',
                                             ytickmin=0,
                                             ytickmax=1)
                pr_auc_visualizer.plot_loss(epoch,
                                            train_val_pr_auc,
                                            val_pr_auc,
                                            loss_type='pr_auc',
                                            ytickmin=0,
                                            ytickmax=1)

                # Snapshot the weights whenever test ROC-AUC improves.
                if best_cv_roc_auc_scores < val_roc_auc:
                    best_cv_roc_auc_scores = val_roc_auc
                    best_crispr_model.load_state_dict(
                        crispr_model.state_dict())

            logger.debug("Saving training history")
            logger.debug("Saved training history successfully")
            logger.debug("Trained crispr model successfully")

        else:
            # Placeholder branch: loading an old model is only logged here.
            logger.debug("loading in old model")
            logger.debug("Load in model successfully")

    except KeyboardInterrupt as e:
        # Ctrl-C falls through to persisting whatever has been trained.
        logger.debug("Loading model")
        logger.debug("loading some intermediate step's model")
        logger.debug("Load in model successfully")

    logger.debug("Persisting model")
    # serialize weights to HDF5
    save(best_crispr_model, config.hdf5_path)
    save(best_crispr_model.state_dict(), config.hdf5_path_state)
    logger.debug("Saved model to disk")

    # Final test pass; optionally carry essentiality scores into the output.
    save_output = []
    if 'essentiality' in config.extra_numerical_features:
        save_output.append(data_pre.crispr.loc[test_index, 'essentiality'])
    test_model(best_crispr_model, test_generator, save_output)

    # --- Feature importance via input perturbation ------------------------
    if config.check_feature_importance:
        logger.debug("Getting features ranks")
        # Build one name per input column: source sequence, optional target
        # sequence, then any extra categorical/numerical features.
        names = []
        names += [
            "src_" + str(i) for i in range(data_pre.feature_length_map[0][1])
        ]
        if data_pre.feature_length_map[1] is not None:
            names += [
                "trg_" + str(i)
                for i in range(data_pre.feature_length_map[1][1] -
                               data_pre.feature_length_map[1][0])
            ]
        if data_pre.feature_length_map[2] is not None:
            names += config.extra_categorical_features + config.extra_numerical_features
        ranker = feature_imp.InputPerturbationRank(names)
        feature_ranks = ranker.rank(
            2, y.loc[train_index, :].values, best_crispr_model,
            [torch.FloatTensor(X.loc[train_index, :].values)],
            torch=True,
            classifier=True)
        feature_ranks_df = pd.DataFrame(feature_ranks)
        feature_ranks_df.to_csv(config.feature_importance_path, index=False)
        logger.debug("Get features ranks successfully")