def compute_routing_probabilities_uptonode(self, input, node_idx):
    """Compute the routing probabilities up to a node.

    The leaves to consider are obtained from
    ``get_past_leaf_nodes(self.tree_struct, node_idx)``. For each of them,
    ``get_path_to_root`` yields a tuple ``(nodes, edges)``: the node
    indices on the path and, per edge, a boolean that is True if the child
    is on the left branch of its parent. The input is pushed through the
    ``transform`` of every node on the path except the last one, and the
    router outputs are multiplied together (``router(x)`` for a True edge,
    ``1 - router(x)`` otherwise).

    If ``self.split`` is set and a path ends at ``self.node_split``, the
    router of that last node is applied as well and the path contributes
    two columns (router output and its complement, each scaled by the
    probability of reaching the node) instead of one.

    Args:
        input (tensor): batch of samples; ``input.size(0)`` is used as the
            number of samples N.
        node_idx (int): index of the node handed to
            ``get_past_leaf_nodes``.

    Return:
        tuple ``(prob_tensor, leaves_up_to_node)`` where ``prob_tensor``
        is the routing probabilities tensor, torch tensor (N, nodes),
        built by concatenating the per-path columns along dim 1, and
        ``leaves_up_to_node`` is the value returned by
        ``get_past_leaf_nodes``.

    Raises:
        ValueError: if there is no leaf path to evaluate (previously this
            surfaced as an UnboundLocalError on return).
    """
    leaves_up_to_node = get_past_leaf_nodes(self.tree_struct, node_idx)

    # One (nodes, edges) tuple per leaf, see the docstring.
    paths_list_up_to_node = [
        get_path_to_root(i, self.tree_struct) for i in leaves_up_to_node
    ]
    if not paths_list_up_to_node:
        raise ValueError(
            "no leaf nodes found up to node {}".format(node_idx)
        )

    # Loop-invariant: tensor type and batch size are the same for all paths.
    dtype = torch.cuda.FloatTensor if self.cuda_on else torch.FloatTensor
    batch_size = input.size(0)

    prob_columns = []
    for nodes, edges in paths_list_up_to_node:
        # Start from probability one for every sample and work on a copy
        # of the input so each path starts from the untouched batch.
        prob = Variable(torch.ones(batch_size).type(dtype))
        output = input.clone()

        for node, state in zip(nodes[:-1], edges):
            module = self.tree_modules[node]
            output = module.transform(output)
            gate = module.router(output)
            prob = prob * (gate if state else 1.0 - gate)

        # prob is always a tensor here (it starts from torch.ones), so it
        # can be turned into a column unconditionally.
        prob = torch.unsqueeze(prob, 1)

        # account for the split at the last node
        if self.split and nodes[-1] == self.node_split:
            final_module = self.tree_modules[nodes[-1]]
            output = final_module.transform(output)
            prob_last = torch.unsqueeze(final_module.router(output), 1)
            prob = torch.cat(
                (prob_last * prob, (1.0 - prob_last) * prob), dim=1,
            )

        prob_columns.append(prob)

    # Concatenate once instead of growing the tensor inside the loop.
    prob_tensor = torch.cat(prob_columns, dim=1)
    return prob_tensor, leaves_up_to_node
def compute_routing_probability_specificnode(self, input, node_idx):
    """Compute the probability of reaching a selected node.

    The path for ``node_idx`` is obtained from ``get_path_to_root``. The
    input is passed through the ``transform`` of every node on that path
    except the last one, and the corresponding router outputs are
    multiplied together: ``router(x)`` where the edge flag is true,
    ``1 - router(x)`` where it is false.

    If a batch is provided, then the sum of probabilities over the batch
    is computed.

    Args:
        input (tensor): batch of samples.
        node_idx (int): index of the node whose reaching probability is
            wanted.

    Return:
        When at least one router was applied: the first element of
        ``.data`` of the probabilities summed over dim 0.
        When no router was applied (the path has no edge to traverse,
        presumably the root node): ``1.0 * input.size(0)``.
    """
    path_nodes, path_edges = get_path_to_root(node_idx, self.tree_struct)

    reach = 1.0
    for idx, edge_flag in zip(path_nodes[:-1], path_edges):
        module = self.tree_modules[idx]
        input = module.transform(input)
        gate = module.router(input)
        reach = reach * (gate if edge_flag else (1.0 - gate))

    # Still a plain float: no router was applied, so every sample reaches
    # the node with probability one.
    if isinstance(reach, float):
        return reach * input.size(0)

    summed = torch.unsqueeze(reach, 1).sum(dim=0)
    return summed.data[0]
def __init__(self, tree_struct, tree_modules,
             split=False, node_split=None, child_left=None,
             child_right=None, extend=False, node_extend=None,
             child_extension=None, cuda_on=True, breadth_first=True,
             soft_decision=True):
    """Initialise the class.

    Args:
        tree_struct (list): List of dictionaries, each of which contains
            meta information about one node of the tree.
        tree_modules (list): List of dictionaries, each of which contains
            the modules (nn.Module) of one node in the tree and takes the
            form
                module = {'transform': transformer_module (nn.Module),
                          'classifier': solver_module (nn.Module),
                          'router': router_module (nn.Module)}
        split (bool): Set True if the model is testing the 'split' growth
            option.
        node_split (int): Index of the node that is being split.
        child_left (dict): Left child of the node node_split; same form
            as an element of tree_modules. Only read when split is True.
        child_right (dict): Right child of the node node_split; same form
            as an element of tree_modules. Only read when split is True.
        extend (bool): Set True if the model is testing the 'extend'
            growth option.
        node_extend (int): Index of the node that is being extended.
        child_extension (dict): The extra node used to extend node
            node_extend; same keys as an element of tree_modules. Only
            read when extend is True.
        cuda_on (bool): Set True to train on a GPU.
        breadth_first (bool): Set True to perform a breadth-first forward
            pass. If set to False, a depth-first forward pass is
            performed.
        soft_decision (bool): Set True to perform multi-path inference,
            which computes the predictive distribution as the mean of the
            conditional distributions from all the leaf nodes, weighted
            by the corresponding reaching probabilities. If set to False,
            inference based on "hard" decisions is performed: if the
            routers are defined with stochastic=True, the stochastic
            single-path inference is used; otherwise the greedy
            single-path inference is carried out, whereby the input
            sample traverses the tree in the directions of the highest
            confidence of the routers.
    """
    super(Tree, self).__init__()

    # a node can only be split or extended, never both at once
    assert not (split and extend)

    def bundle(parts):
        # Wrap the three modules of one node in a named nn.Sequential.
        packed = nn.Sequential()
        for key in ('transform', 'classifier', 'router'):
            packed.add_module(key, parts[key])
        return packed

    # plain configuration
    self.soft_decision = soft_decision
    self.cuda_on = cuda_on
    self.breadth_first = breadth_first

    # growth options
    self.split = split
    self.extend = extend
    self.node_split = node_split
    self.node_extend = node_extend

    # tree layout and the list of its leaf nodes
    self.tree_struct = tree_struct
    self.leaves_list = get_leaf_nodes(tree_struct)

    # For each leaf predictor: a tuple (nodes, edges) holding the node
    # indices on its path to the root and the left-child status (boolean)
    # of every edge, i.e. edge = True if the child is on the left branch
    # of its parent.
    self.paths_list = [
        get_path_to_root(leaf, tree_struct) for leaf in self.leaves_list
    ]

    # modules of the existing nodes, one nn.Sequential per node
    self.tree_modules = nn.ModuleList(
        [bundle(node) for node in tree_modules]
    )

    # candidate children, case (1): splitting
    if split:
        self.child_left = bundle(child_left)
        self.child_right = bundle(child_right)

    # candidate children, case (2): making deeper
    if extend:
        self.child_extension = bundle(child_extension)