def train_test_rf_vary_data_size(prefix, motif_scoring_kwargs=None,
                                 X_train=None, y_train=None,
                                 X_valid=None, y_valid=None,
                                 X_test=None, y_test=None,
                                 train_set_sizes=None):
    """Train and test a RandomForest on motif scores for several training-set sizes.

    Motif scores are computed once for ``X_train`` and ``X_test`` with
    ``get_motif_scores(..., **motif_scoring_kwargs)``. For every size in
    ``train_set_sizes`` a ``RandomForest`` is trained on the first
    ``train_set_size`` rows of the training scores/labels and evaluated on the
    full test set. Trained models are cached as pickles named
    ``"<prefix>.<dict2string(motif_scoring_kwargs)>.train_set_size_<size>.rf.pkl"``;
    a loadable cache file is reused instead of retraining.

    Arguments:
        prefix -- path prefix for the cached model pickles.
        motif_scoring_kwargs -- dict of keyword arguments for get_motif_scores;
            also encoded into the cache file name. None is treated as {}.
        X_train, y_train -- training inputs and labels.
        X_valid, y_valid -- not used by this function.
        X_test, y_test -- test inputs and labels.
        train_set_sizes -- iterable of training-set sizes. None is treated as
            an empty list.

    Returns:
        list with the result of ``rf.test(motif_scores_test, y_test)`` for each
        training-set size, in input order.
    """
    if motif_scoring_kwargs is None:
        # ``**None`` raises TypeError, so normalize the documented default.
        motif_scoring_kwargs = {}
    if train_set_sizes is None:
        train_set_sizes = []

    motif_scores_train = get_motif_scores(X_train, **motif_scoring_kwargs)
    motif_scores_test = get_motif_scores(X_test, **motif_scoring_kwargs)
    # The kwargs part of the cache name does not change inside the loop.
    kwargs_infix = dict2string(motif_scoring_kwargs)

    rf_results = []
    for train_set_size in train_set_sizes:
        ofname = "%s.%s.train_set_size_%s.rf.pkl" % (
            prefix, kwargs_infix, str(train_set_size))
        try:
            # NOTE: pickle.load can execute arbitrary code; only point
            # ``prefix`` at trusted, locally written cache files.
            with open(ofname, 'rb') as fp:
                rf = pickle.load(fp)
        except Exception:
            # Missing or unreadable cache -> train and (re)write it. Catching
            # Exception instead of a bare except lets KeyboardInterrupt and
            # SystemExit propagate.
            logger.info("training with %i examples..", train_set_size)
            rf = RandomForest()
            rf.train(motif_scores_train[:train_set_size],
                     y_train[:train_set_size].squeeze())
            with open(ofname, 'wb') as fid:
                pickle.dump(rf, fid)
        rf_results.append(rf.test(motif_scores_test, y_test))
    return rf_results
def plot_SequenceDNN_layer_outputs(dnn, simulation_data):
    """Plot convolution, ReLU and max-pool layer outputs of ``dnn`` as heatmaps.

    The first validation example with ``y_valid == 1`` is run through theano
    functions compiled for ``dnn.model.layers[0]``, ``layers[1]`` and
    ``layers[-4]``. Each squeezed output is drawn as a heatmap in a 3x1 figure:
    the layers[-4] output ("after max pooling") on top, layers[1] ("after ReLU")
    in the middle, layers[0] ("convolutional layer") at the bottom.

    For every motif in ``simulation_data.motif_names``, the argmax position of
    its motif score gets a grey band of half-width
    ``max(dnn.conv_width - 10, 0)`` on the bottom two panels.

    NOTE(review): this module defines ``plot_SequenceDNN_layer_outputs`` a
    second time further down; the later definition replaces this one when the
    module is imported.
    """
    import theano  # local import: theano is only needed when this runs

    model_input = dnn.model.layers[0].input

    def _layer_output_fn(layer_idx):
        # Compile input -> test-time output of one layer.
        return theano.function(
            [model_input],
            dnn.model.layers[layer_idx].get_output(train=False),
            allow_input_downcast=True)

    get_conv_output = _layer_output_fn(0)
    get_conv_relu_output = _layer_output_fn(1)
    get_maxpool_output = _layer_output_fn(-4)

    # layer outputs for the first positive validation example
    pos_indx = np.where(simulation_data.y_valid == 1)[0][0]
    pos_X = simulation_data.X_valid[pos_indx:(pos_indx + 1)]
    conv_outputs = get_conv_output(pos_X).squeeze()
    conv_relu_outputs = get_conv_relu_output(pos_X).squeeze()
    maxpool_outputs = get_maxpool_output(pos_X).squeeze()

    fig = plt.figure(figsize=(15, 12))

    def _add_heatmap(position, outputs, title):
        # One row of the 3x1 grid: heatmap + colorbar, no tick marks.
        ax = fig.add_subplot(3, 1, position)
        heatmap = ax.imshow(outputs, aspect='auto', interpolation='None',
                            cmap='seismic')
        fig.colorbar(heatmap)
        ax.set_ylabel("Convolutional Filters")
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        ax.set_title(title)
        return ax

    ax1 = _add_heatmap(
        3, conv_outputs,
        "SequenceDNN outputs from convolutional layer.\t"
        "Locations of motif sites are highlighted in grey.")
    ax1.set_xlabel("Position")
    ax2 = _add_heatmap(
        2, conv_relu_outputs,
        "Convolutional outputs after ReLU transformation.\t"
        "Locations of motif sites are highlighted in grey.")
    _add_heatmap(1, maxpool_outputs, "DNN outputs after max pooling")

    # Highlight the best-scoring position of every simulated motif (not just
    # the first two, which would fail for single-motif simulations).
    motif_scores = get_motif_scores(pos_X, simulation_data.motif_names)
    motif_sites = [np.argmax(motif_scores[0, i, :])
                   for i in range(len(simulation_data.motif_names))]
    half_width = max(dnn.conv_width - 10, 0)
    for motif_site in motif_sites:
        conv_output_start = motif_site - half_width
        conv_output_stop = motif_site + half_width
        for highlighted_ax in (ax1, ax2):
            highlighted_ax.axvspan(conv_output_start, conv_output_stop,
                                   color='grey', alpha=0.5)
def interpret_data_with_SequenceDNN(dnn, simulation_data):
    """Plot motif, ISM and DeepLIFT scores for one positive and one negative example.

    The third validation example with ``y_valid == 1`` and the third with
    ``y_valid == 0`` are scored with ``get_motif_scores``,
    ``dnn.in_silico_mutagenesis`` and ``dnn.deeplift``. The latter two are
    reduced with ``.max(axis=-2)``.

    The figure is a 3x2 grid: one row per score type, with the Positive example
    on the left and the Negative example on the right. The argmax position of
    each motif's score is shaded in grey (half-width 5) on every panel of that
    example.
    """
    # get a positive and a negative example from the simulation data
    pos_indx = np.where(simulation_data.y_valid == 1)[0][2]
    neg_indx = np.where(simulation_data.y_valid == 0)[0][2]
    pos_X = simulation_data.X_valid[pos_indx:(pos_indx + 1)]
    neg_X = simulation_data.X_valid[neg_indx:(neg_indx + 1)]

    # get motif scores, ISM scores, and DeepLIFT scores (insertion order is the
    # row order of the figure)
    scores_dict = defaultdict(OrderedDict)
    for key, example in (('Positive', pos_X), ('Negative', neg_X)):
        scores_dict[key]['Motif Scores'] = get_motif_scores(
            example, simulation_data.motif_names)
        scores_dict[key]['ISM Scores'] = dnn.in_silico_mutagenesis(
            example).max(axis=-2)
        scores_dict[key]['DeepLIFT Scores'] = dnn.deeplift(example).max(
            axis=-2)

    # get motif site locations (argmax of each motif's score)
    num_motifs = len(simulation_data.motif_names)
    motif_sites = {}
    for key in ('Positive', 'Negative'):
        motif_sites[key] = [
            np.argmax(scores_dict[key]['Motif Scores'][0, i, :])
            for i in range(num_motifs)
        ]

    # organize legends: one ISM/DeepLIFT curve per motif when there is one
    # task per motif, otherwise a single joint label
    motif_label_dict = {}
    motif_label_dict['Motif Scores'] = simulation_data.motif_names
    if len(simulation_data.motif_names) == dnn.num_tasks:
        motif_label_dict['ISM Scores'] = simulation_data.motif_names
    else:
        motif_label_dict['ISM Scores'] = [
            '_'.join(simulation_data.motif_names)
        ]
    motif_label_dict['DeepLIFT Scores'] = motif_label_dict['ISM Scores']

    # plot scores and highlight motif site locations
    seq_length = pos_X.shape[-1]
    plots_per_row = 2
    plots_per_column = 3
    ylim_dict = {
        'Motif Scores': (-80, 30),
        'ISM Scores': (-1.5, 3.0),
        'DeepLIFT Scores': (-1.5, 3.0)
    }
    motif_colors = ['b', 'r', 'c', 'm', 'g', 'k', 'y']
    font_size = 12
    num_x_ticks = 5
    highlight_width = 5
    motif_labels_cache = []

    f = plt.figure(figsize=(10, 12))
    f.subplots_adjust(hspace=0.15, wspace=0.15)
    f.set_tight_layout(True)
    for j, key in enumerate(['Positive', 'Negative']):
        # .items() works on Python 2 and 3; .iteritems() is Python 2 only.
        for i, (score_type, scores) in enumerate(scores_dict[key].items()):
            ax = f.add_subplot(plots_per_column, plots_per_row,
                               plots_per_row * i + j + 1)
            ax.set_ylim(ylim_dict[score_type])
            ax.set_xlim((0, seq_length))
            ax.set_frame_on(False)
            if j == 0:  # put y axis and ticks only on left side
                xmin, xmax = ax.get_xaxis().get_view_interval()
                ymin, ymax = ax.get_yaxis().get_view_interval()
                ax.add_artist(
                    Line2D((xmin, xmin), (ymin, ymax), color='black',
                           linewidth=2))
                ax.get_yaxis().tick_left()
                for tick in ax.yaxis.get_major_ticks():
                    # label1 is the long-standing name; Tick.label was
                    # removed in matplotlib 3.8.
                    tick.label1.set_fontsize(font_size / 1.5)
                ax.set_ylabel(score_type)
            if j > 0:  # remove y axes
                ax.get_yaxis().set_visible(False)
            if i < (plots_per_column - 1):  # remove x axes
                ax.get_xaxis().set_visible(False)
            if i == (plots_per_column - 1):  # set x axis and ticks on bottom
                ax.set_xticks(seq_length / num_x_ticks *
                              (np.arange(num_x_ticks + 1)))
                xmin, xmax = ax.get_xaxis().get_view_interval()
                ymin, ymax = ax.get_yaxis().get_view_interval()
                ax.add_artist(
                    Line2D((xmin, xmax), (ymin, ymin), color='black',
                           linewidth=2))
                ax.get_xaxis().tick_bottom()
                for tick in ax.xaxis.get_major_ticks():
                    tick.label1.set_fontsize(font_size / 1.5)
                ax.set_xlabel("Position")
            if j > 0 and i < (plots_per_column - 1):  # remove all axes
                ax.axis('off')

            add_legend = False
            for _i, motif_label in enumerate(motif_label_dict[score_type]):
                if score_type == 'Motif Scores':
                    scores_to_plot = scores[0, _i, :]
                else:
                    # One curve per label: select task _i of the single
                    # example. NOTE(review): assumes the ISM/DeepLIFT arrays
                    # are (task, sample, 1, position) after the max over
                    # bases -- confirm against SequenceDNN.
                    scores_to_plot = scores[_i, 0, 0, :]
                if motif_label not in motif_labels_cache:
                    motif_labels_cache.append(motif_label)
                    add_legend = True
                motif_color = motif_colors[motif_labels_cache.index(
                    motif_label)]
                ax.plot(scores_to_plot, label=motif_label, c=motif_color)
            if add_legend:
                leg = ax.legend(loc=[0, 0.85], frameon=False,
                                fontsize=font_size, ncol=3,
                                handlelength=-0.5)
                # Hide the line handles; the label text carries the colour.
                handles = getattr(leg, 'legend_handles', None)
                if handles is None:  # matplotlib < 3.7
                    handles = leg.legendHandles
                for legobj in handles:
                    legobj.set_color('w')
                for _j, text in enumerate(leg.get_texts()):
                    text_color = motif_colors[motif_labels_cache.index(
                        motif_label_dict[score_type][_j])]
                    text.set_color(text_color)
            for motif_site in motif_sites[key]:
                ax.axvspan(motif_site - highlight_width,
                           motif_site + highlight_width,
                           color='grey', alpha=0.1)
def plot_SequenceDNN_layer_outputs(dnn, simulation_data):
    """Plot convolution, ReLU and max-pool layer outputs of ``dnn`` as heatmaps.

    The first validation example with ``y_valid == 1`` is run through theano
    functions compiled for ``dnn.model.layers[0]``, ``layers[1]`` and
    ``layers[-4]``. Each squeezed output is drawn as a heatmap in a 3x1 figure:
    the layers[-4] output ("after max pooling") on top, layers[1] ("after ReLU")
    in the middle, layers[0] ("convolutional layer") at the bottom.

    For every motif in ``simulation_data.motif_names``, the argmax position of
    its motif score gets a grey band of half-width
    ``max(dnn.conv_width - 10, 0)`` on the bottom two panels.

    NOTE(review): an earlier definition with the same name exists in this
    module; this later one replaces it when the module is imported.
    """
    # Local import so this does not depend on a module-level theano import.
    import theano

    model_input = dnn.model.layers[0].input

    def _layer_output_fn(layer_idx):
        # Compile input -> test-time output of one layer.
        return theano.function(
            [model_input],
            dnn.model.layers[layer_idx].get_output(train=False),
            allow_input_downcast=True)

    get_conv_output = _layer_output_fn(0)
    get_conv_relu_output = _layer_output_fn(1)
    get_maxpool_output = _layer_output_fn(-4)

    # layer outputs for the first positive validation example
    pos_indx = np.where(simulation_data.y_valid == 1)[0][0]
    pos_X = simulation_data.X_valid[pos_indx:(pos_indx + 1)]
    conv_outputs = get_conv_output(pos_X).squeeze()
    conv_relu_outputs = get_conv_relu_output(pos_X).squeeze()
    maxpool_outputs = get_maxpool_output(pos_X).squeeze()

    fig = plt.figure(figsize=(15, 12))

    def _add_heatmap(position, outputs, title):
        # One row of the 3x1 grid: heatmap + colorbar, no tick marks.
        ax = fig.add_subplot(3, 1, position)
        heatmap = ax.imshow(outputs, aspect='auto', interpolation='None',
                            cmap='seismic')
        fig.colorbar(heatmap)
        ax.set_ylabel("Convolutional Filters")
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        ax.set_title(title)
        return ax

    ax1 = _add_heatmap(
        3, conv_outputs,
        "SequenceDNN outputs from convolutional layer.\t"
        "Locations of motif sites are highlighted in grey.")
    ax1.set_xlabel("Position")
    ax2 = _add_heatmap(
        2, conv_relu_outputs,
        "Convolutional outputs after ReLU transformation.\t"
        "Locations of motif sites are highlighted in grey.")
    _add_heatmap(1, maxpool_outputs, "DNN outputs after max pooling")

    # Highlight the best-scoring position of every simulated motif (not just
    # the first two, which would fail for single-motif simulations).
    motif_scores = get_motif_scores(pos_X, simulation_data.motif_names)
    motif_sites = [np.argmax(motif_scores[0, i, :])
                   for i in range(len(simulation_data.motif_names))]
    half_width = max(dnn.conv_width - 10, 0)
    for motif_site in motif_sites:
        conv_output_start = motif_site - half_width
        conv_output_stop = motif_site + half_width
        for highlighted_ax in (ax1, ax2):
            highlighted_ax.axvspan(conv_output_start, conv_output_stop,
                                   color='grey', alpha=0.5)
# NOTE(review): this script section relies on seq_length, num_positives,
# num_negatives, GC_fraction, test_fraction and validation_fraction being
# defined before this point.
num_epochs = 100
use_deep_CNN = False
use_RNN = False

print('Generating sequences...')
sequences, labels = simulate_single_motif_detection(
    'SPI1_disc1', seq_length, num_positives, num_negatives, GC_fraction)
print('One-hot encoding sequences...')
encoded_sequences = one_hot_encode(sequences)
print('Getting motif scores...')
motif_scores = get_motif_scores(encoded_sequences, motif_names=['SPI1_disc1'])
print('Partitioning data into training, validation and test sets...')
X_train, X_test, y_train, y_test = train_test_split(
    encoded_sequences, labels, test_size=test_fraction)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train, y_train, test_size=validation_fraction)
print('Adding reverse complements...')
X_train = np.concatenate((X_train, reverse_complement(X_train)))
y_train = np.concatenate((y_train, y_train))
print('Shuffling training data...')
# The reverse complements were appended after all forward sequences, so the
# permutation must actually be applied. Use the same order for inputs and
# labels to keep them paired.
random_order = np.arange(len(X_train))
np.random.shuffle(random_order)
X_train = X_train[random_order]
y_train = y_train[random_order]
def multi_method_interpret(model, X, task_idx, deeplift_score_func,
                           motif_names=None, batch_size=200,
                           target_layer_idx=-2, num_refs_per_seq=10,
                           reference="shuffled_ref", one_hot_func=None,
                           pfm=None, GC_fraction=0.4, generate_plots=True):
    """Interpret a model on X with motif scanning, ISM, input gradients and DeepLIFT.

    Arguments:
        model -- keras model object
        X -- numpy array with shape (1, 1, n_bases_in_sample, 4) or list of
            FASTA sequences
        task_idx -- numerical index (starting with 0) of the task to
            interpret. For a single-tasked model, set this to 0.
        deeplift_score_func -- scoring function to use with the DeepLIFT
            algorithm.
        motif_names -- list of motif name strings to scan for in the input
            sequence; if unknown, keep the default of None and the motif scan
            is skipped.
        batch_size -- batch size passed to the DeepLIFT step
        target_layer_idx -- should be -2 for classification; -1 for regression
        num_refs_per_seq -- integer number of references to use for each
            input sequence if reference is 'shuffled_ref'; ignored for
            'zero_ref' and 'gc_ref'.
        reference -- one of 'shuffled_ref', 'gc_ref', 'zero_ref'
        one_hot_func -- one-hot function used to encode FASTA string inputs;
            if the inputs are already one-hot-encoded, keep the default None.
        pfm -- passed to get_motif_scores (only used when motif_names is set)
        GC_fraction -- passed to get_motif_scores (only used when motif_names
            is set)
        generate_plots -- default True. If truthy, plot_all_interpretations
            is called on the results.

    Returns:
        dictionary with keys 'motif_scan', 'ism', 'input_grad', 'deeplift';
        'motif_scan' is None when motif_names is None.
    """
    outputs = {}
    # 1) motif scan (only when motif names are given)
    if motif_names is not None:
        print("getting 'motif_scan' value")
        outputs['motif_scan'] = get_motif_scores(X, motif_names, pfm=pfm,
                                                 GC_fraction=GC_fraction,
                                                 return_positions=True)
    else:
        outputs['motif_scan'] = None
    # 2) ISM
    print("getting 'ism' value")
    outputs['ism'] = in_silico_mutagenesis(model, X, task_idx,
                                           target_layer_idx=target_layer_idx)
    # 3) input gradient
    # NOTE(review): task_idx is not passed to input_grad here; confirm
    # whether input_grad should be restricted to the requested task.
    print("getting 'input_grad' value")
    outputs['input_grad'] = input_grad(model, X,
                                       target_layer_idx=target_layer_idx)
    # 4) DeepLIFT
    print("getting 'deeplift' value")
    outputs['deeplift'] = deeplift(deeplift_score_func, X,
                                   batch_size=batch_size,
                                   task_idx=task_idx,
                                   num_refs_per_seq=num_refs_per_seq,
                                   reference=reference,
                                   one_hot_func=one_hot_func)
    if generate_plots:
        plot_all_interpretations([outputs], X)
    return outputs
def interpret_SequenceDNN_integrative(dnn, simulation_data):
    """Plot motif, ISM and DeepLIFT scores for one positive and one negative example.

    The first validation example with ``y_valid == 1`` and the first with
    ``y_valid == 0`` are scored with ``get_motif_scores``,
    ``dnn.in_silico_mutagenesis`` and ``dnn.deeplift``. The latter two are
    reduced with ``.max(axis=-2)``.

    The figure is a 3x2 grid: one row per score type, with the Positive example
    on the left and the Negative example on the right; the x range is
    ``dnn.seq_length``. The argmax position of each motif's score is shaded
    in grey (half-width 5) on every panel of that example.
    """
    # get a positive and a negative example from the simulation data
    pos_indx = np.where(simulation_data.y_valid == 1)[0][0]
    pos_X = simulation_data.X_valid[pos_indx:(pos_indx + 1)]
    neg_indx = np.where(simulation_data.y_valid == 0)[0][0]
    neg_X = simulation_data.X_valid[neg_indx:(neg_indx + 1)]

    # get motif scores, ISM scores, and DeepLIFT scores (insertion order is the
    # row order of the figure)
    scores_dict = defaultdict(OrderedDict)
    for key, example in (('Positive', pos_X), ('Negative', neg_X)):
        scores_dict[key]['Motif Scores'] = get_motif_scores(
            example, simulation_data.motif_names)
        scores_dict[key]['ISM Scores'] = dnn.in_silico_mutagenesis(
            example).max(axis=-2)
        scores_dict[key]['DeepLIFT Scores'] = dnn.deeplift(example).max(
            axis=-2)

    # get motif site locations (argmax of each motif's score)
    num_motifs = len(simulation_data.motif_names)
    motif_sites = {}
    for key in ('Positive', 'Negative'):
        motif_sites[key] = [
            np.argmax(scores_dict[key]['Motif Scores'][0, i, :])
            for i in range(num_motifs)
        ]

    # organize legends: one ISM/DeepLIFT curve per motif when there is one
    # task per motif, otherwise a single joint label
    motif_label_dict = {}
    motif_label_dict['Motif Scores'] = simulation_data.motif_names
    if len(simulation_data.motif_names) == dnn.num_tasks:
        motif_label_dict['ISM Scores'] = simulation_data.motif_names
    else:
        motif_label_dict['ISM Scores'] = [
            '_'.join(simulation_data.motif_names)
        ]
    motif_label_dict['DeepLIFT Scores'] = motif_label_dict['ISM Scores']

    # plot scores and highlight motif site locations
    seq_length = dnn.seq_length
    plots_per_row = 2
    plots_per_column = 3
    ylim_dict = {
        'Motif Scores': (-80, 30),
        'ISM Scores': (-1.5, 3.0),
        'DeepLIFT Scores': (-1.5, 3.0)
    }
    motif_colors = ['b', 'r', 'c', 'm', 'g', 'k', 'y']
    font_size = 12
    num_x_ticks = 5
    highlight_width = 5
    motif_labels_cache = []

    f = plt.figure(figsize=(10, 12))
    f.subplots_adjust(hspace=0.15, wspace=0.15)
    f.set_tight_layout(True)
    for j, key in enumerate(['Positive', 'Negative']):
        # .items() works on Python 2 and 3; .iteritems() is Python 2 only.
        for i, (score_type, scores) in enumerate(scores_dict[key].items()):
            ax = f.add_subplot(plots_per_column, plots_per_row,
                               plots_per_row * i + j + 1)
            ax.set_ylim(ylim_dict[score_type])
            ax.set_xlim((0, seq_length))
            ax.set_frame_on(False)
            if j == 0:  # put y axis and ticks only on left side
                xmin, xmax = ax.get_xaxis().get_view_interval()
                ymin, ymax = ax.get_yaxis().get_view_interval()
                ax.add_artist(
                    Line2D((xmin, xmin), (ymin, ymax), color='black',
                           linewidth=2))
                ax.get_yaxis().tick_left()
                for tick in ax.yaxis.get_major_ticks():
                    # label1 is the long-standing name; Tick.label was
                    # removed in matplotlib 3.8.
                    tick.label1.set_fontsize(font_size / 1.5)
                ax.set_ylabel(score_type)
            if j > 0:  # remove y axes
                ax.get_yaxis().set_visible(False)
            if i < (plots_per_column - 1):  # remove x axes
                ax.get_xaxis().set_visible(False)
            if i == (plots_per_column - 1):  # set x axis and ticks on bottom
                ax.set_xticks(seq_length / num_x_ticks *
                              (np.arange(num_x_ticks + 1)))
                xmin, xmax = ax.get_xaxis().get_view_interval()
                ymin, ymax = ax.get_yaxis().get_view_interval()
                ax.add_artist(
                    Line2D((xmin, xmax), (ymin, ymin), color='black',
                           linewidth=2))
                ax.get_xaxis().tick_bottom()
                for tick in ax.xaxis.get_major_ticks():
                    tick.label1.set_fontsize(font_size / 1.5)
                ax.set_xlabel("Position")
            if j > 0 and i < (plots_per_column - 1):  # remove all axes
                ax.axis('off')

            add_legend = False
            for _i, motif_label in enumerate(motif_label_dict[score_type]):
                if score_type == 'Motif Scores':
                    scores_to_plot = scores[0, _i, :]
                else:
                    # ax.plot needs a 1-D curve; squeezing only axis 2 left an
                    # array of 3 or more dimensions. Select task _i of the
                    # single example. NOTE(review): assumes the ISM/DeepLIFT
                    # arrays are (task, sample, 1, position) after the max
                    # over bases -- confirm against SequenceDNN.
                    scores_to_plot = scores[_i, 0, 0, :]
                if motif_label not in motif_labels_cache:
                    motif_labels_cache.append(motif_label)
                    add_legend = True
                motif_color = motif_colors[motif_labels_cache.index(
                    motif_label)]
                ax.plot(scores_to_plot, label=motif_label, c=motif_color)
            if add_legend:
                leg = ax.legend(loc=[0, 0.85], frameon=False,
                                fontsize=font_size, ncol=3,
                                handlelength=-0.5)
                # Hide the line handles; the label text carries the colour.
                handles = getattr(leg, 'legend_handles', None)
                if handles is None:  # matplotlib < 3.7
                    handles = leg.legendHandles
                for legobj in handles:
                    legobj.set_color('w')
                # Colour each legend label with its own motif colour, not the
                # colour of the last plotted curve.
                for _j, text in enumerate(leg.get_texts()):
                    text_color = motif_colors[motif_labels_cache.index(
                        motif_label_dict[score_type][_j])]
                    text.set_color(text_color)
            for motif_site in motif_sites[key]:
                ax.axvspan(motif_site - highlight_width,
                           motif_site + highlight_width,
                           color='grey', alpha=0.1)