示例#1
0
    def compute_routing_probabilities_uptonode(self, input, node_idx):
        """Compute the routing probabilities up to a node.

        For every leaf returned by ``get_past_leaf_nodes(self.tree_struct,
        node_idx)``, the probability of each sample reaching that leaf is the
        product of the router outputs along its path from the root: the
        router output is used on edges where the child is the left branch,
        and ``1 - router`` otherwise. Before each router is evaluated, the
        node's ``transform`` is applied to the running features.

        If ``self.split`` is set and the last node on a leaf's path equals
        ``self.node_split``, that leaf contributes two columns instead of
        one. They are its reaching probability multiplied by the split
        node's router output and by its complement, respectively.

        Args:
            input (tensor): batch of inputs; ``input.size(0)`` is taken as
                the batch size N.
            node_idx (int): index of the node up to which leaves are taken.

        Returns:
            tuple: ``(prob_tensor, leaves_up_to_node)``.

            ``prob_tensor`` is an (N, K) tensor with one column per leaf, in
            the order of ``leaves_up_to_node``. There is one extra column for
            a split leaf, as described above.

            ``leaves_up_to_node`` is the list of leaf indices.

        Raises:
            ValueError: if no leaves are found up to ``node_idx``.
        """
        leaves_up_to_node = get_past_leaf_nodes(self.tree_struct, node_idx)
        # len() rather than truthiness so that array-like results also work.
        if len(leaves_up_to_node) == 0:
            raise ValueError(
                'no leaf nodes found up to node {}'.format(node_idx))

        # For each leaf, get the nodes (indices) on its path to the root and
        # the left-child status of every edge (True if the child is on the
        # left branch of its parent). Each element is a (nodes, edges) tuple.
        paths_list_up_to_node = [
            get_path_to_root(i, self.tree_struct) for i in leaves_up_to_node
        ]

        # Loop-invariant: every path starts from probability one. Reusing the
        # same tensor is safe because it is only combined out-of-place below.
        dtype = torch.cuda.FloatTensor if self.cuda_on else torch.FloatTensor
        ones = Variable(torch.ones(input.size(0)).type(dtype))

        columns = []
        for nodes, edges in paths_list_up_to_node:
            prob = ones
            output = input.clone()

            # The last node on the path (nodes[-1]) does not route here.
            for node, state in zip(nodes[:-1], edges):
                output = self.tree_modules[node].transform(output)
                p_left = self.tree_modules[node].router(output)
                prob = prob * (p_left if state else 1.0 - p_left)

            # (N,) -> (N, 1), assuming router outputs are (N,) -- TODO confirm
            prob = prob.unsqueeze(1)

            # Account for the split at the last node.
            if self.split and nodes[-1] == self.node_split:
                node_final = nodes[-1]
                output = self.tree_modules[node_final].transform(output)
                prob_last = self.tree_modules[node_final].router(
                    output).unsqueeze(1)
                prob = torch.cat((prob_last * prob, (1.0 - prob_last) * prob),
                                 dim=1)

            columns.append(prob)

        # One concatenation instead of re-copying a growing tensor per leaf.
        return torch.cat(columns, dim=1), leaves_up_to_node
示例#2
0
def visualise_routers_behaviours(model,
                                 data_loader,
                                 no_classes=10,
                                 objects=('0', '1', '2', '3', '4', '5', '6',
                                          '7', '8', '9'),
                                 fig_scale=None,
                                 title='',
                                 title_font=20,
                                 subtitle_font=20,
                                 axis_font=14,
                                 cuda_on=False,
                                 plot_on=True,
                                 save_as=''):
    """Visualise the probability of reaching each leaf node for each class.

    For each tree level there is one "edge node", taken from
    ``find_edgenode``. For each edge node, all batches of ``data_loader``
    are passed through ``model.compute_routing_probabilities_uptonode``.
    Two rows of subplots are then drawn per level, with one column per
    peripheral node:

    * a bar chart of the mean reaching probability of each class;
    * a histogram of the per-sample reaching probabilities.

    Args:
        model (nn.Module): tree model exposing ``tree_struct`` and
            ``compute_routing_probabilities_uptonode``.
        data_loader: iterable of ``(x, y)`` batches, where ``y`` holds the
            class labels compared against ``range(no_classes)``.
        no_classes (int): number of classes.
        objects (sequence of str): class tick labels. Their count sets the
            bar positions, so it should equal ``no_classes``.
        fig_scale (float or None): multiplier on the figure size. With
            ``None``, the size in inches is (num_cols, num_rows).
        title (str): figure title.
        title_font, subtitle_font, axis_font (int): font sizes.
        cuda_on (bool): move the model and batches to the GPU.
        plot_on (bool): call ``plt.show()`` at the end.
        save_as (str): if non-empty, path where the figure is saved as PNG.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if cuda_on:
        model.cuda()
    else:
        model.cpu()

    # Collect one edge node per level. Collection stops when find_edgenode
    # returns a negative index, presumably meaning there is no deeper level.
    tree_struct = model.tree_struct
    edge_nodes = []
    e_n = 0
    max_level = 0
    while e_n >= 0:
        edge_nodes.append(e_n)
        max_level += 1
        e_n = find_edgenode(tree_struct, max_level)

    # Two rows per level (class bars + histogram). There is one column per
    # peripheral node of the deepest level's edge node.
    num_rows = 2 * len(edge_nodes)
    num_cols = len(ops.get_past_leaf_nodes(tree_struct, edge_nodes[-1]))
    if fig_scale is None:
        fig = plt.figure(figsize=(num_cols, num_rows))
    else:
        fig = plt.figure(figsize=(fig_scale * num_cols, fig_scale * num_rows))
    fig.suptitle(title, fontsize=title_font)

    y_pos = np.arange(len(objects))
    for level, node in enumerate(edge_nodes):
        print('Computing histograms for level {}/{}'.format(
            level,
            len(edge_nodes) - 1))
        p_full, y_full, nodes_list = _collect_routing_probabilities(
            model, data_loader, node, cuda_on)

        # Average probability of samples from each class reaching each
        # peripheral node: C x (number of peripheral nodes).
        node_class_probs = np.vstack(
            [p_full[y_full == c].mean(axis=0) for c in range(no_classes)])

        # First row of this level: class-wise bar chart per node.
        for i, node_idx in enumerate(nodes_list):
            ax = fig.add_subplot(num_rows, num_cols,
                                 2 * num_cols * level + i + 1)
            ax.bar(y_pos, node_class_probs[:, i], align='center', alpha=0.5,
                   color='r')
            ax.set_xticks(y_pos)
            ax.set_xticklabels(objects, rotation='vertical',
                               fontsize=axis_font)
            ax.set_ylim((0, 1))
            if i == 0:
                ax.set_ylabel("reaching prob. per class", fontsize=axis_font)
            ax.set_title('Node ' + str(node_idx), fontsize=subtitle_font)

        # Second row of this level: histogram of reaching probabilities.
        for i, node_idx in enumerate(nodes_list):
            ax = fig.add_subplot(num_rows, num_cols,
                                 num_cols * (2 * level + 1) + i + 1)
            # `normed` was removed from matplotlib (3.1); `density` replaces it.
            ax.hist(p_full[:, i], density=False, bins=25, range=(0, 1.0))
            if i == 0:
                ax.set_ylabel("histogram of \n reaching prob. dist.",
                              fontsize=axis_font)

    fig.subplots_adjust(wspace=0.25, hspace=0.25)

    # Save before showing: with interactive backends the figure may already
    # be closed once the window is dismissed.
    if save_as:
        print('Save the histogram of splitting as ' + save_as)
        fig.savefig(save_as, format='png', dpi=300)

    if plot_on:
        plt.show()


def _collect_routing_probabilities(model, data_loader, node, cuda_on):
    """Return (p_full, y_full, nodes_list) for all batches up to ``node``.

    ``p_full`` stacks, over all samples, the probabilities returned by
    ``model.compute_routing_probabilities_uptonode``, with one row per
    sample. ``y_full`` holds the matching labels. ``nodes_list`` is the
    node list returned for the last batch.
    """
    import torch  # local: only needed for no_grad

    p_list, y_list = [], []
    nodes_list = None
    # No gradients are needed. no_grad replaces Variable(..., volatile=True),
    # which has had no effect since PyTorch 0.4.
    with torch.no_grad():
        for x, y in data_loader:
            x, y = Variable(x), Variable(y)
            if cuda_on:
                x, y = x.cuda(), y.cuda()
            p, nodes_list = model.compute_routing_probabilities_uptonode(
                x, node)
            p_list.append(p.data.cpu().numpy())
            y_list.append(y.data.cpu().numpy())

    if nodes_list is None:
        raise ValueError('data_loader yielded no batches')
    return np.concatenate(p_list), np.concatenate(y_list), nodes_list