def compute_routing_probabilities_uptonode(self, input, node_idx):
    """Compute the routing probabilities of the leaves present up to a node.

    The leaves are obtained from
    ``get_past_leaf_nodes(self.tree_struct, node_idx)``.  For each leaf,
    ``get_path_to_root`` yields a ``(nodes, edges)`` pair: ``nodes`` holds
    the node indices on the path and ``edges`` the matching
    left-child flags (an edge is truthy when the child sits on the left
    branch of its parent).  Walking every node except the last one, the
    input is passed through that node's ``transform`` and the running
    probability is multiplied by the node's ``router`` output for a left
    edge, or by its complement for a right edge.

    If ``self.split`` is set and the last node on the path equals
    ``self.node_split``, that leaf's own router is applied as well and the
    leaf contributes two columns (router output, then its complement)
    instead of one.

    Args:
        input: batch of inputs; ``input.size(0)`` is used as the batch
            size N.  The tensor is cloned per path and never modified.
        node_idx: node index forwarded to ``get_past_leaf_nodes``.

    Returns:
        tuple: ``(prob_tensor, leaves_up_to_node)``.  ``prob_tensor`` is
        the column-wise concatenation of the per-leaf probabilities
        (N rows; assumes each router returns one value per sample --
        TODO confirm).  ``leaves_up_to_node`` is the leaf list exactly as
        returned by ``get_past_leaf_nodes``.

    Raises:
        ValueError: if ``get_past_leaf_nodes`` returns no leaves, since
            there is then nothing to concatenate.
    """
    leaves_up_to_node = get_past_leaf_nodes(self.tree_struct, node_idx)
    if len(leaves_up_to_node) == 0:
        raise ValueError(
            'no leaf nodes found up to node {}'.format(node_idx))

    # One (nodes, edges) pair per leaf, describing its path in the tree.
    paths_up_to_node = [
        get_path_to_root(leaf, self.tree_struct)
        for leaf in leaves_up_to_node
    ]

    # The tensor type only depends on the device flag, so pick it once.
    dtype = torch.cuda.FloatTensor if self.cuda_on else torch.FloatTensor

    columns = []
    for nodes, edges in paths_up_to_node:
        # Start from probability one for every sample in the batch.
        prob = Variable(torch.ones(input.size(0)).type(dtype))
        output = input.clone()

        # Accumulate the routing decisions along the path; the last node
        # is the leaf itself and is handled separately below.
        for node, is_left in zip(nodes[:-1], edges):
            module = self.tree_modules[node]
            output = module.transform(output)
            go_left = module.router(output)
            prob = prob * (go_left if is_left else (1.0 - go_left))

        # Turn the per-sample vector into a column.
        prob = torch.unsqueeze(prob, 1)

        # Account for the split at the last node: it yields two columns.
        if self.split and nodes[-1] == self.node_split:
            final_module = self.tree_modules[nodes[-1]]
            output = final_module.transform(output)
            prob_last = torch.unsqueeze(final_module.router(output), 1)
            prob = torch.cat(
                (prob_last * prob, (1.0 - prob_last) * prob), dim=1)

        columns.append(prob)

    # Concatenate once instead of re-copying the running result per leaf.
    prob_tensor = torch.cat(columns, dim=1)
    return prob_tensor, leaves_up_to_node
def visualise_routers_behaviours(model, data_loader, no_classes=10,
                                 objects=('0', '1', '2', '3', '4',
                                          '5', '6', '7', '8', '9'),
                                 fig_scale=None, title='', title_font=20,
                                 subtitle_font=20, axis_font=14,
                                 cuda_on=False, plot_on=True, save_as=''):
    """Visualise the probability of reaching peripheral nodes per class.

    For every "edge node" reported by ``find_edgenode`` (starting from
    node 0 and increasing the level until a negative index is returned),
    the whole ``data_loader`` is run through
    ``model.compute_routing_probabilities_uptonode`` and two rows of
    subplots are drawn:

    * a bar chart, per peripheral node, of the mean reaching probability
      of the samples of each class;
    * a histogram, per peripheral node, of the reaching probabilities of
      all samples.

    Args:
        model: tree model exposing ``tree_struct``, ``cuda()``, ``cpu()``
            and ``compute_routing_probabilities_uptonode(x, node)``.
            It is moved to GPU or CPU according to ``cuda_on``.
        data_loader: iterable yielding ``(x, y)`` batches.  It is
            iterated once per level.
        no_classes (int): number of classes; labels ``0 .. no_classes-1``
            are looked up in ``y``.
        objects (sequence of str): tick labels of the bar charts.  The
            bars are positioned with ``len(objects)`` while their heights
            have ``no_classes`` entries, so both should agree.
        fig_scale (float or None): if given, scales the figure size,
            which otherwise is ``(num_cols, num_rows)``.
        title (str): figure super-title.
        title_font (int): font size of the super-title.
        subtitle_font (int): font size of the per-node subplot titles.
        axis_font (int): font size of tick and axis labels.
        cuda_on (bool): run the model and the batches on GPU.
        plot_on (bool): call ``plt.show()`` when done.
        save_as (str): if non-empty, path the figure is saved to as PNG.

    Returns:
        None
    """
    if cuda_on:
        model.cuda()
    else:
        model.cpu()

    # Get the list of edge nodes on the respective levels; the search
    # stops as soon as find_edgenode returns a negative index.
    tree_struct = model.tree_struct
    edge_nodes = []
    e_n = 0
    max_level = 0
    while e_n >= 0:
        edge_nodes.append(e_n)
        max_level += 1
        e_n = find_edgenode(tree_struct, max_level)

    # Two rows per level (class distribution + histogram) and one column
    # per leaf present up to the last edge node.
    num_rows = 2 * len(edge_nodes)
    num_cols = len(ops.get_past_leaf_nodes(tree_struct, edge_nodes[-1]))

    if fig_scale is None:
        fig = plt.figure(figsize=(num_cols, num_rows))
    else:
        fig = plt.figure(
            figsize=(fig_scale * num_cols, fig_scale * num_rows))
    plt.suptitle(title, fontsize=title_font)

    # -------------- compute and plot stuff ------------------------------
    for level, node in enumerate(edge_nodes):
        print('Computuing histograms for level {}/{}'.format(
            level, len(edge_nodes) - 1))

        # Collect labels and routing probabilities over the whole loader.
        y_list, p_list = [], []
        for x, y in data_loader:
            x, y = Variable(x, volatile=True), Variable(y)
            if cuda_on:
                x, y = x.cuda(), y.cuda()
            p, nodes_list = model.compute_routing_probabilities_uptonode(
                x, node)
            if cuda_on:
                p, y = p.cpu(), y.cpu()
            p_list.append(p.data.numpy())
            y_list.append(y.data.numpy())

        y_full = np.concatenate(y_list)
        p_full = np.concatenate(p_list)  # N x number of peripheral nodes

        # Class-specific mean probability of reaching each peripheral
        # node: C x number of peripheral nodes.  A class with no samples
        # yields the mean of an empty selection.
        node_class_probs = np.vstack([
            p_full[y_full == c].mean(axis=0) for c in range(no_classes)
        ])

        # Bar chart of the node-wise class distributions (first row of
        # this level).
        y_pos = np.arange(len(objects))
        for i, node_idx in enumerate(nodes_list):
            performance = node_class_probs[:, i]
            ax1 = fig.add_subplot(num_rows, num_cols,
                                  2 * num_cols * level + i + 1)
            ax1.bar(y_pos, performance, align='center', alpha=0.5,
                    color='r')
            plt.xticks(y_pos, objects, rotation='vertical',
                       fontsize=axis_font)
            plt.ylim((0, 1))
            if i == 0:
                ax1.set_ylabel("reaching prob. per class",
                               fontsize=axis_font)
            ax1.set_title('Node ' + str(node_idx), fontsize=subtitle_font)

        # Histogram of reaching probabilities for the respective
        # peripheral nodes (second row of this level).
        for i, node_idx in enumerate(nodes_list):
            ax1 = fig.add_subplot(num_rows, num_cols,
                                  num_cols * (2 * level + 1) + i + 1)
            # Raw counts are plotted, which is the default of hist(); the
            # former ``normed=False`` keyword is omitted because newer
            # Matplotlib releases no longer accept it.
            ax1.hist(p_full[:, i], bins=25, range=(0, 1.0))
            if i == 0:
                ax1.set_ylabel("histogram of \n reaching prob. dist.",
                               fontsize=axis_font)

    plt.subplots_adjust(wspace=0.25, hspace=0.25)
    if plot_on:
        plt.show()

    if save_as:
        # Save the full figure
        print('Save the histogram of splitting as ' + save_as)
        fig.savefig(save_as, format='png', dpi=300)