예제 #1
0
def sample_tree(data, settings, param, cache, cache_tmp):
    """Grow a new ``TreeMCMC`` tree breadth-first and set up its MCMC bookkeeping.

    Starting from the root (node 0), each queued node is handed to
    ``p.precomputed_proposal``. If the proposal declines to split, the node is
    marked in ``p.do_not_split``; otherwise the children's statistics are
    stored via ``p.update_left_right_statistics`` and both children are queued.

    Besides the tree itself, this maintains the structures the MCMC moves use:

    * ``p.both_children_terminal`` -- split nodes none of whose children were
      split afterwards (i.e. both children are leaves).
    * ``p.inner_pc_pairs`` -- ``(parent, child)`` pairs where the child was
      split and the parent is in ``p.non_leaf_nodes``.

    Args:
        data: dataset mapping; ``range(data['n_train'])`` is passed to the
            ``TreeMCMC`` constructor.
        settings: run settings; when ``settings.debug`` is truthy the sampled
            tree is printed.
        param: model hyperparameters forwarded to the split helpers.
        cache: precomputed quantities forwarded to the proposal.
        cache_tmp: scratch cache passed to the ``TreeMCMC`` constructor.

    Returns:
        The grown ``TreeMCMC`` instance.
    """
    p = TreeMCMC(range(data['n_train']), param, settings, cache_tmp)
    grow_nodes = [0]
    while grow_nodes:
        node_id = grow_nodes.pop(0)  # FIFO -> breadth-first growth
        p.depth = max(p.depth, get_depth(node_id))
        log_psplit = np.log(p.compute_psplit(node_id, param))
        train_ids = p.train_ids[node_id]
        (do_not_split_node_id, feat_id_chosen, split_chosen, idx_split_global,
         log_sis_ratio, logprior_nodeid, train_ids_left, train_ids_right,
         cache_tmp, loglik_left, loglik_right) = p.precomputed_proposal(
            data, param, settings, cache, node_id, train_ids, log_psplit)
        if do_not_split_node_id:
            p.do_not_split[node_id] = True
            continue
        p.update_left_right_statistics(
            cache_tmp, node_id, logprior_nodeid, train_ids_left,
            train_ids_right, loglik_left, loglik_right, feat_id_chosen,
            split_chosen, idx_split_global, settings, param, data, cache)
        left, right = get_children_id(node_id)
        grow_nodes.append(left)
        grow_nodes.append(right)
        # MCMC bookkeeping: node_id now has two (leaf) children
        p.both_children_terminal.append(node_id)
        if node_id != 0:
            parent = get_parent_id(node_id)
            if parent in p.non_leaf_nodes:
                p.inner_pc_pairs.append((parent, node_id))
            # parent now has a non-leaf child, so it no longer qualifies;
            # it is already gone if the sibling was split first
            try:
                p.both_children_terminal.remove(parent)
            except ValueError:
                pass
    if settings.debug:
        print('sampled new tree:')
        p.print_tree()
    return p
예제 #2
0
 def compute_log_inv_acc_p(self, node_id, param, len_both_children_terminal, loglik, grow_nodes,
         cache, settings, data):
     """Return the log of the inverse acceptance ratio of a PRUNE move at ``node_id``.

     1/acc for PRUNE is acc for the reverse GROW, except that the caller must
     supply the ``both_children_terminal`` count and ``grow_nodes`` list
     corrected for the pruned tree.

     As a consistency check, the children's log-prior is recomputed (log of
     the no-split probability for each child that still has a valid split,
     0 otherwise) and compared with the stored ``self.logprior`` entries.

     Args:
         node_id: internal node whose two children would be removed.
         param: model hyperparameters for the split probabilities.
         len_both_children_terminal: ``both_children_terminal`` size to use in
             the proposal correction.
         loglik: combined log-likelihood of the two children.
         grow_nodes: list of growable nodes in the pruned tree.
         cache, settings, data: forwarded to ``no_valid_split_exists``;
             ``settings.verbose >= 2`` prints both log-ratio components.

     Returns:
         ``log(1 / acc)``; negate it to obtain the PRUNE log acceptance.

     Raises:
         AssertionError: if ``check_if_zero`` rejects the recomputed children
             log-prior (diagnostics are printed first), or if the result is
             ``-inf``.
     """
     logprior_children = 0.0
     left, right = get_children_id(node_id)
     if not no_valid_split_exists(data, cache, self.train_ids[left], settings):
         logprior_children += np.log(self.compute_pnosplit(left, param))
     if not no_valid_split_exists(data, cache, self.train_ids[right], settings):
         logprior_children += np.log(self.compute_pnosplit(right, param))
     try:
         check_if_zero(logprior_children - self.logprior[left] - self.logprior[right])
     except AssertionError:
         print('oh oh ... looks like a bug in compute_log_inv_acc_p')
         print('term 1 = %s' % logprior_children)
         print('term 2 = %s, 2a = %s, 2b = %s' % (self.logprior[left] + self.logprior[right],
               self.logprior[left], self.logprior[right]))
         print('node_id = %s, left = %s, right = %s, logprior = %s' % (node_id, left, right, self.logprior))
         # re-raise the original error rather than a new, message-less one
         raise
     log_inv_acc_prior = np.log(self.compute_psplit(node_id, param)) \
             - np.log(self.compute_pnosplit(node_id, param)) \
             - np.log(len_both_children_terminal) + np.log(len(grow_nodes)) \
             + logprior_children
     log_inv_acc_loglik = (loglik - self.loglik[node_id])
     log_inv_acc = log_inv_acc_loglik + log_inv_acc_prior
     if settings.verbose >= 2:
         print('compute_log_inv_acc_p: log_acc_loglik = %s, log_acc_prior = %s'
               % (-log_inv_acc_loglik, -log_inv_acc_prior))
     assert(log_inv_acc > -np.inf)
     return log_inv_acc
예제 #3
0
 def compute_log_inv_acc_p(self, node_id, param, len_both_children_terminal, loglik, grow_nodes,
         cache, settings, data):
     """Return the log of the inverse acceptance ratio of a PRUNE move at ``node_id``.

     1/acc for PRUNE is acc for the reverse GROW, except that the caller must
     supply the ``both_children_terminal`` count and ``grow_nodes`` list
     corrected for the pruned tree.

     As a consistency check, the children's log-prior is recomputed (log of
     the no-split probability for each child that still has a valid split,
     0 otherwise) and compared with the stored ``self.logprior`` entries.

     Args:
         node_id: internal node whose two children would be removed.
         param: model hyperparameters for the split probabilities.
         len_both_children_terminal: ``both_children_terminal`` size to use in
             the proposal correction.
         loglik: combined log-likelihood of the two children.
         grow_nodes: list of growable nodes in the pruned tree.
         cache, settings, data: forwarded to ``no_valid_split_exists``;
             ``settings.verbose >= 2`` prints both log-ratio components.

     Returns:
         ``log(1 / acc)``; negate it to obtain the PRUNE log acceptance.

     Raises:
         AssertionError: if ``check_if_zero`` rejects the recomputed children
             log-prior (diagnostics are printed first), or if the result is
             ``-inf``.
     """
     logprior_children = 0.0
     left, right = get_children_id(node_id)
     if not no_valid_split_exists(data, cache, self.train_ids[left], settings):
         logprior_children += np.log(self.compute_pnosplit(left, param))
     if not no_valid_split_exists(data, cache, self.train_ids[right], settings):
         logprior_children += np.log(self.compute_pnosplit(right, param))
     try:
         check_if_zero(logprior_children - self.logprior[left] - self.logprior[right])
     except AssertionError:
         print('oh oh ... looks like a bug in compute_log_inv_acc_p')
         print('term 1 = %s' % logprior_children)
         print('term 2 = %s, 2a = %s, 2b = %s' % (self.logprior[left] + self.logprior[right],
               self.logprior[left], self.logprior[right]))
         print('node_id = %s, left = %s, right = %s, logprior = %s' % (node_id, left, right, self.logprior))
         # re-raise the original error rather than a new, message-less one
         raise
     log_inv_acc_prior = np.log(self.compute_psplit(node_id, param)) \
             - np.log(self.compute_pnosplit(node_id, param)) \
             - np.log(len_both_children_terminal) + np.log(len(grow_nodes)) \
             + logprior_children
     log_inv_acc_loglik = (loglik - self.loglik[node_id])
     log_inv_acc = log_inv_acc_loglik + log_inv_acc_prior
     if settings.verbose >= 2:
         print('compute_log_inv_acc_p: log_acc_loglik = %s, log_acc_prior = %s'
               % (-log_inv_acc_loglik, -log_inv_acc_prior))
     assert(log_inv_acc > -np.inf)
     return log_inv_acc
예제 #4
0
def sample_tree(data, settings, param, cache, cache_tmp):
    """Grow a new ``TreeMCMC`` tree breadth-first and set up its MCMC bookkeeping.

    Starting from the root (node 0), each queued node is handed to
    ``p.precomputed_proposal``. If the proposal declines to split, the node is
    marked in ``p.do_not_split``; otherwise the children's statistics are
    stored via ``p.update_left_right_statistics`` and both children are queued.

    Besides the tree itself, this maintains the structures the MCMC moves use:

    * ``p.both_children_terminal`` -- split nodes none of whose children were
      split afterwards (i.e. both children are leaves).
    * ``p.inner_pc_pairs`` -- ``(parent, child)`` pairs where the child was
      split and the parent is in ``p.non_leaf_nodes``.

    Args:
        data: dataset mapping; ``range(data['n_train'])`` is passed to the
            ``TreeMCMC`` constructor.
        settings: run settings; when ``settings.debug`` is truthy the sampled
            tree is printed.
        param: model hyperparameters forwarded to the split helpers.
        cache: precomputed quantities forwarded to the proposal.
        cache_tmp: scratch cache passed to the ``TreeMCMC`` constructor.

    Returns:
        The grown ``TreeMCMC`` instance.
    """
    p = TreeMCMC(range(data['n_train']), param, settings, cache_tmp)
    grow_nodes = [0]
    while grow_nodes:
        node_id = grow_nodes.pop(0)  # FIFO -> breadth-first growth
        p.depth = max(p.depth, get_depth(node_id))
        log_psplit = np.log(p.compute_psplit(node_id, param))
        train_ids = p.train_ids[node_id]
        (do_not_split_node_id, feat_id_chosen, split_chosen, idx_split_global,
         log_sis_ratio, logprior_nodeid, train_ids_left, train_ids_right,
         cache_tmp, loglik_left, loglik_right) = p.precomputed_proposal(
            data, param, settings, cache, node_id, train_ids, log_psplit)
        if do_not_split_node_id:
            p.do_not_split[node_id] = True
            continue
        p.update_left_right_statistics(
            cache_tmp, node_id, logprior_nodeid, train_ids_left,
            train_ids_right, loglik_left, loglik_right, feat_id_chosen,
            split_chosen, idx_split_global, settings, param, data, cache)
        left, right = get_children_id(node_id)
        grow_nodes.append(left)
        grow_nodes.append(right)
        # MCMC bookkeeping: node_id now has two (leaf) children
        p.both_children_terminal.append(node_id)
        if node_id != 0:
            parent = get_parent_id(node_id)
            if parent in p.non_leaf_nodes:
                p.inner_pc_pairs.append((parent, node_id))
            # parent now has a non-leaf child, so it no longer qualifies;
            # it is already gone if the sibling was split first
            try:
                p.both_children_terminal.remove(parent)
            except ValueError:
                pass
    if settings.debug:
        print('sampled new tree:')
        p.print_tree()
    return p
예제 #5
0
파일: pg.py 프로젝트: Sandy4321/pgbart
def grow_next_pg(p, tree_pg, itr, settings):
    """Replay step ``itr`` of the stored tree ``tree_pg`` onto particle ``p``.

    Rather than proposing new splits, this copies what ``tree_pg`` recorded
    at step ``itr`` (presumably the reference particle in particle Gibbs --
    TODO confirm with callers):

    * every node processed at that step except the last must be marked
      do-not-split in ``tree_pg`` (asserted) and is marked so on ``p``;
    * if the last processed node was split in ``tree_pg`` (it appears in
      ``tree_pg.node_info``), the split and the children's statistics are
      copied into ``p`` and the SIS log-ratio is recomputed from the current
      log-likelihoods plus the stored prior term. The refreshed
      ``(loglik, prior)`` pair is written back to ``tree_pg.log_sis_ratio_d``.

    Finally ``p.grow_nodes``, ``p.log_sis_ratio_d`` and ``p.depth`` are taken
    from ``tree_pg`` (for step ``itr + 1`` in the case of ``grow_nodes``).

    Args:
        p: particle being grown; mutated in place.
        tree_pg: tree whose per-step history is replayed; its
            ``log_sis_ratio_d`` entry for the split node is updated.
        itr: index of the step to replay.
        settings: run settings; ``settings.verbose >= 2`` prints diagnostics.

    Sets ``p.do_not_grow = True`` when ``tree_pg`` has no (or an empty)
    record for step ``itr`` or no ``grow_nodes`` for step ``itr + 1``.
    NOTE: the outer ``except IndexError`` covers the whole body, so an
    IndexError from any other line is also treated as "stop growing".

    Raises:
        KeyError: if the split node has no entry in ``tree_pg.log_sis_ratio_d``
            (diagnostics are printed first).
    """
    p.log_sis_ratio = 0.
    p.do_not_grow = False
    p.grow_nodes = []
    try:
        nodes_processed = tree_pg.nodes_processed_itr[itr]
        p.nodes_processed_itr.append(nodes_processed[:])
        for node_id in nodes_processed[:-1]:
            assert(tree_pg.do_not_split[node_id])
            p.do_not_split[node_id] = True
        node_id = nodes_processed[-1]
        if node_id in tree_pg.node_info:
            left, right = get_children_id(node_id)
            log_sis_ratio_loglik_new = tree_pg.loglik[left] + tree_pg.loglik[right] - tree_pg.loglik[node_id]
            try:
                log_sis_ratio_loglik_old, log_sis_ratio_prior = tree_pg.log_sis_ratio_d[node_id]
            except KeyError:
                print('tree_pg: node_info = %s, log_sis_ratio_d = %s' % (tree_pg.node_info, tree_pg.log_sis_ratio_d))
                raise  # keep the original KeyError (it names the missing key)
            if settings.verbose >= 2:
                print('log_sis_ratio_loglik_old = %s' % log_sis_ratio_loglik_old)
                print('log_sis_ratio_loglik_new = %s' % log_sis_ratio_loglik_new)
            p.log_sis_ratio = log_sis_ratio_loglik_new + log_sis_ratio_prior
            tree_pg.log_sis_ratio_d[node_id] = (log_sis_ratio_loglik_new, log_sis_ratio_prior)
            p.log_sis_ratio_d[node_id] = tree_pg.log_sis_ratio_d[node_id]
            p.non_leaf_nodes.append(node_id)
            try:
                p.leaf_nodes.remove(node_id)
            except ValueError:
                # best effort: warn and carry on
                print('warning: unable to remove node_id = %s from leaf_nodes = %s' % (node_id, p.leaf_nodes))
            p.leaf_nodes.append(left)
            p.leaf_nodes.append(right)
            # copying relevant bits
            p.node_info[node_id] = tree_pg.node_info[node_id]
            p.logprior[node_id] = tree_pg.logprior[node_id]
            for node_id_child in [left, right]:
                p.do_not_split[node_id_child] = False   # can look up where node_id_child occurred in nodes_processed_itr
                p.loglik[node_id_child] = tree_pg.loglik[node_id_child]
                p.logprior[node_id_child] = tree_pg.logprior[node_id_child]
                p.train_ids[node_id_child] = tree_pg.train_ids[node_id_child]
                p.sum_y[node_id_child] = tree_pg.sum_y[node_id_child]
                p.sum_y2[node_id_child] = tree_pg.sum_y2[node_id_child]
                p.param_n[node_id_child] = tree_pg.param_n[node_id_child]
                p.n_points[node_id_child] = tree_pg.n_points[node_id_child]
        if settings.verbose >= 2:
            print('p.leaf_nodes = %s' % p.leaf_nodes)
            print('p.non_leaf_nodes = %s' % p.non_leaf_nodes)
            print('p.node_info.keys() = %s' % sorted(p.node_info.keys()))
        try:
            # NOTE: these are shared with tree_pg, not copied
            p.grow_nodes = tree_pg.grow_nodes_itr[itr+1]
            p.log_sis_ratio_d = tree_pg.log_sis_ratio_d
            p.depth = tree_pg.depth
        except IndexError:
            p.do_not_grow = True
    except IndexError:
        p.do_not_grow = True
예제 #6
0
def grow_next_pg(p, tree_pg, itr, settings):
    """Replay step ``itr`` of the stored tree ``tree_pg`` onto particle ``p``.

    Rather than proposing new splits, this copies what ``tree_pg`` recorded
    at step ``itr`` (presumably the reference particle in particle Gibbs --
    TODO confirm with callers):

    * every node processed at that step except the last must be marked
      do-not-split in ``tree_pg`` (asserted) and is marked so on ``p``;
    * if the last processed node was split in ``tree_pg`` (it appears in
      ``tree_pg.node_info``), the split and the children's statistics are
      copied into ``p`` and the SIS log-ratio is recomputed from the current
      log-likelihoods plus the stored prior term. The refreshed
      ``(loglik, prior)`` pair is written back to ``tree_pg.log_sis_ratio_d``.

    Finally ``p.grow_nodes``, ``p.log_sis_ratio_d`` and ``p.depth`` are taken
    from ``tree_pg`` (for step ``itr + 1`` in the case of ``grow_nodes``).

    Args:
        p: particle being grown; mutated in place.
        tree_pg: tree whose per-step history is replayed; its
            ``log_sis_ratio_d`` entry for the split node is updated.
        itr: index of the step to replay.
        settings: run settings; ``settings.verbose >= 2`` prints diagnostics.

    Sets ``p.do_not_grow = True`` when ``tree_pg`` has no (or an empty)
    record for step ``itr`` or no ``grow_nodes`` for step ``itr + 1``.
    NOTE: the outer ``except IndexError`` covers the whole body, so an
    IndexError from any other line is also treated as "stop growing".

    Raises:
        KeyError: if the split node has no entry in ``tree_pg.log_sis_ratio_d``
            (diagnostics are printed first).
    """
    p.log_sis_ratio = 0.
    p.do_not_grow = False
    p.grow_nodes = []
    try:
        nodes_processed = tree_pg.nodes_processed_itr[itr]
        p.nodes_processed_itr.append(nodes_processed[:])
        for node_id in nodes_processed[:-1]:
            assert(tree_pg.do_not_split[node_id])
            p.do_not_split[node_id] = True
        node_id = nodes_processed[-1]
        if node_id in tree_pg.node_info:
            left, right = get_children_id(node_id)
            log_sis_ratio_loglik_new = tree_pg.loglik[left] + tree_pg.loglik[right] - tree_pg.loglik[node_id]
            try:
                log_sis_ratio_loglik_old, log_sis_ratio_prior = tree_pg.log_sis_ratio_d[node_id]
            except KeyError:
                print('tree_pg: node_info = %s, log_sis_ratio_d = %s' % (tree_pg.node_info, tree_pg.log_sis_ratio_d))
                raise  # keep the original KeyError (it names the missing key)
            if settings.verbose >= 2:
                print('log_sis_ratio_loglik_old = %s' % log_sis_ratio_loglik_old)
                print('log_sis_ratio_loglik_new = %s' % log_sis_ratio_loglik_new)
            p.log_sis_ratio = log_sis_ratio_loglik_new + log_sis_ratio_prior
            tree_pg.log_sis_ratio_d[node_id] = (log_sis_ratio_loglik_new, log_sis_ratio_prior)
            p.log_sis_ratio_d[node_id] = tree_pg.log_sis_ratio_d[node_id]
            p.non_leaf_nodes.append(node_id)
            try:
                p.leaf_nodes.remove(node_id)
            except ValueError:
                # best effort: warn and carry on
                print('warning: unable to remove node_id = %s from leaf_nodes = %s' % (node_id, p.leaf_nodes))
            p.leaf_nodes.append(left)
            p.leaf_nodes.append(right)
            # copying relevant bits
            p.node_info[node_id] = tree_pg.node_info[node_id]
            p.logprior[node_id] = tree_pg.logprior[node_id]
            for node_id_child in [left, right]:
                p.do_not_split[node_id_child] = False   # can look up where node_id_child occurred in nodes_processed_itr
                p.loglik[node_id_child] = tree_pg.loglik[node_id_child]
                p.logprior[node_id_child] = tree_pg.logprior[node_id_child]
                p.train_ids[node_id_child] = tree_pg.train_ids[node_id_child]
                p.sum_y[node_id_child] = tree_pg.sum_y[node_id_child]
                p.sum_y2[node_id_child] = tree_pg.sum_y2[node_id_child]
                p.param_n[node_id_child] = tree_pg.param_n[node_id_child]
                p.n_points[node_id_child] = tree_pg.n_points[node_id_child]
        if settings.verbose >= 2:
            print('p.leaf_nodes = %s' % p.leaf_nodes)
            print('p.non_leaf_nodes = %s' % p.non_leaf_nodes)
            print('p.node_info.keys() = %s' % sorted(p.node_info.keys()))
        try:
            # NOTE: these are shared with tree_pg, not copied
            p.grow_nodes = tree_pg.grow_nodes_itr[itr+1]
            p.log_sis_ratio_d = tree_pg.log_sis_ratio_d
            p.depth = tree_pg.depth
        except IndexError:
            p.do_not_grow = True
    except IndexError:
        p.do_not_grow = True
예제 #7
0
def grow_next_pg(p, tree_pg, iterations, settings):
    """Replay step ``iterations`` of the stored tree ``tree_pg`` onto particle ``p``.

    Rather than proposing new splits, this copies what ``tree_pg`` recorded
    at that step (presumably the reference particle in particle Gibbs --
    TODO confirm with callers):

    * every node processed at that step except the last must be marked
      do-not-split in ``tree_pg`` (asserted) and is marked so on ``p``;
    * if the last processed node was split in ``tree_pg`` (it appears in
      ``tree_pg.node_info``), the split and the children's statistics are
      copied into ``p`` and the SIS log-ratio is recomputed from the current
      log-likelihoods plus the stored prior term. The refreshed
      ``(loglik, prior)`` pair is written back to ``tree_pg.log_sis_ratio_d``.

    Finally ``p.grow_nodes``, ``p.log_sis_ratio_d`` and ``p.depth`` are taken
    from ``tree_pg`` (for step ``iterations + 1`` in the case of
    ``grow_nodes``).

    Args:
        p: particle being grown; mutated in place.
        tree_pg: tree whose per-step history is replayed; its
            ``log_sis_ratio_d`` entry for the split node is updated.
        iterations: index of the step to replay.
        settings: run settings; ``settings.verbose >= 2`` prints diagnostics.

    Sets ``p.do_not_grow = True`` when ``tree_pg`` has no (or an empty)
    record for the step or no ``grow_nodes`` for the next step.
    NOTE: the outer ``except IndexError`` covers the whole body, so an
    IndexError from any other line is also treated as "stop growing".

    Raises:
        KeyError: if the split node has no entry in ``tree_pg.log_sis_ratio_d``
            (diagnostics are printed first).
    """
    p.log_sis_ratio = 0.
    p.do_not_grow = False
    p.grow_nodes = []
    try:
        processed_nodes = tree_pg.processed_nodes_iterations[iterations]
        p.processed_nodes_iterations.append(processed_nodes[:])
        for id_of_node in processed_nodes[:-1]:
            assert(tree_pg.do_not_split[id_of_node])
            p.do_not_split[id_of_node] = True
        id_of_node = processed_nodes[-1]
        if id_of_node in tree_pg.node_info:
            left_node, right_node = get_children_id(id_of_node)
            log_sis_ratio_loglik_new = tree_pg.loglik[left_node] + tree_pg.loglik[right_node] - tree_pg.loglik[id_of_node]
            try:
                log_sis_ratio_loglik_old, log_sis_ratio_prior = tree_pg.log_sis_ratio_d[id_of_node]
            except KeyError:
                print('tree_pg: node_info =%s, log_sis_ratio_d = %s' % (tree_pg.node_info, tree_pg.log_sis_ratio_d))
                raise  # keep the original KeyError (it names the missing key)
            if settings.verbose >= 2:
                print('log_sis_ratio_loglik_old =%s' % log_sis_ratio_loglik_old)
                print('log_sis_ratio_loglik_new = %s' % log_sis_ratio_loglik_new)
            p.log_sis_ratio = log_sis_ratio_loglik_new + log_sis_ratio_prior
            tree_pg.log_sis_ratio_d[id_of_node] = (log_sis_ratio_loglik_new, log_sis_ratio_prior)
            p.log_sis_ratio_d[id_of_node] = tree_pg.log_sis_ratio_d[id_of_node]
            p.non_leaf_nodes.append(id_of_node)
            try:
                p.leaf_nodes.remove(id_of_node)
            except ValueError:
                # best effort: warn and carry on
                print('Warning !!! Unable to remove id_of_node =%s from leaf_nodes = %s' % (id_of_node, p.leaf_nodes))
            p.leaf_nodes.append(left_node)
            p.leaf_nodes.append(right_node)
            # copy the split and the children's statistics from tree_pg
            p.node_info[id_of_node] = tree_pg.node_info[id_of_node]
            p.logprior[id_of_node] = tree_pg.logprior[id_of_node]
            for child_node_id in [left_node, right_node]:
                p.do_not_split[child_node_id] = False
                p.loglik[child_node_id] = tree_pg.loglik[child_node_id]
                p.logprior[child_node_id] = tree_pg.logprior[child_node_id]
                p.train_ids[child_node_id] = tree_pg.train_ids[child_node_id]
                p.sum_y[child_node_id] = tree_pg.sum_y[child_node_id]
                p.sum_y2[child_node_id] = tree_pg.sum_y2[child_node_id]
                p.param_n[child_node_id] = tree_pg.param_n[child_node_id]
                p.n_points[child_node_id] = tree_pg.n_points[child_node_id]
        if settings.verbose >= 2:
            print('p.leaf_nodes = %s' % p.leaf_nodes)
            print('p.non_leaf_nodes =%s' % p.non_leaf_nodes)
            print('p.node_info.keys() = %s' % sorted(p.node_info.keys()))
        try:
            # NOTE: these are shared with tree_pg, not copied
            p.grow_nodes = tree_pg.grow_nodes_iterations[iterations+1]
            p.log_sis_ratio_d = tree_pg.log_sis_ratio_d
            p.depth = tree_pg.depth
        except IndexError:
            p.do_not_grow = True
    except IndexError:
        p.do_not_grow = True
예제 #8
0
 def get_nodes_subtree(self, node_id):
     """Return the nodes of the subtree rooted at ``node_id`` in breadth-first order.

     ``node_id`` itself is the first element. A node is expanded into its two
     children (via ``get_children_id``) unless it is in ``self.leaf_nodes``.
     """
     # The result list doubles as the BFS queue: ``cursor`` points at the next
     # node whose children still have to be enqueued.
     ordered = [node_id]
     cursor = 0
     while cursor < len(ordered):
         current = ordered[cursor]
         cursor += 1
         if current in self.leaf_nodes:
             continue
         left_child, right_child = get_children_id(current)
         ordered.extend((left_child, right_child))
     return ordered
예제 #9
0
 def get_nodes_subtree(self, node_id):
     """Collect ``node_id`` and all of its descendants, level by level.

     Nodes listed in ``self.leaf_nodes`` are not expanded; any other node
     contributes its two children (from ``get_children_id``) to the next level.
     The returned order is breadth-first, starting with ``node_id``.
     """
     collected = []
     frontier = [node_id]
     while frontier:
         collected.extend(frontier)
         next_level = []
         for current in frontier:
             if current not in self.leaf_nodes:
                 left_child, right_child = get_children_id(current)
                 next_level += [left_child, right_child]
         frontier = next_level
     return collected
예제 #10
0
 def evaluate_new_subtree(self, data, node_id_start, param, nodes_subtree,
                          cache, settings):
     """Rebuild the ``*_new`` statistics of the subtree under ``node_id_start``.

     Pass 1 sends every training point of ``node_id_start`` down the subtree
     using the split rules in ``self.node_info_new``. Each point's response,
     squared response, count and index are added to ``sum_y_new``,
     ``sum_y2_new``, ``n_points_new`` and ``train_ids_new`` of every node it
     visits, including the start node and the leaf. This accumulates onto
     whatever those entries already hold; presumably the caller zeroed them
     beforehand -- TODO confirm.

     Pass 2 visits each node in ``nodes_subtree``. ``loglik_new`` and
     ``param_n_new`` come from ``compute_normal_normalizer`` when the node
     received points; otherwise ``loglik_new`` is ``-inf``. Leaves get
     ``logprior_new`` = 0.0 when ``stop_split`` holds and log pnosplit
     otherwise. Internal nodes call ``recompute_prob_split``.

     With ``settings.debug == 1``, ``check_if_zero`` must accept the start
     node's old-minus-new log-likelihood. If it does not, the node's train ids
     are printed and an AssertionError is raised.
     """
     # Pass 1: route each point to its leaf under the proposed split rules.
     for point_id in self.train_ids[node_id_start]:
         x_point, y_point = data['x_train'][point_id, :], data['y_train'][point_id]
         current = copy(node_id_start)
         at_leaf = False
         while not at_leaf:
             self.sum_y_new[current] += y_point
             self.sum_y2_new[current] += y_point ** 2
             self.n_points_new[current] += 1
             self.train_ids_new[current] = np.append(self.train_ids_new[current], point_id)
             at_leaf = current in self.leaf_nodes
             if not at_leaf:
                 left, right = get_children_id(current)
                 feat_id, split, _ = self.node_info_new[current]
                 current = left if x_point[feat_id] <= split else right

     # Pass 2: likelihoods and priors for every node of the subtree.
     for node in nodes_subtree:
         self.loglik_new[node] = -np.inf
         if self.n_points_new[node] > 0:
             self.loglik_new[node], self.param_n_new[node] = compute_normal_normalizer(
                 self.sum_y_new[node], self.sum_y2_new[node],
                 self.n_points_new[node], param, cache, settings)
         if node not in self.leaf_nodes:
             # split probability might have changed if train_ids have changed
             self.recompute_prob_split(data, param, settings, cache, node)
         elif stop_split(self.train_ids_new[node], settings, data, cache):
             # 0.0 is not the right prior for an empty leaf, but its
             # loglik_new of -inf rejects such a move anyway
             self.logprior_new[node] = 0.0
         else:
             # a leaf that had a single point before may hold more now
             self.logprior_new[node] = np.log(self.compute_pnosplit(node, param))

     if settings.debug == 1:
         try:
             check_if_zero(self.loglik[node_id_start] - self.loglik_new[node_id_start])
         except AssertionError:
             print('train_ids[node_id_start] = %s, train_ids_new[node_id_start] = %s'
                   % (self.train_ids[node_id_start], self.train_ids_new[node_id_start]))
             raise AssertionError
예제 #11
0
	def process_node_id(self, data, param, settings, cache, id_of_node):
		"""Process one node of the particle's tree: skip it or propose a split.

		If ``id_of_node`` is already marked in ``self.do_not_split``, nothing
		happens and 0.0 is returned. Otherwise a split is drawn with
		``self.prior_proposal``. If the proposal declines to split, the node is
		marked do-not-split. Otherwise the children's statistics are stored via
		``self.update_left_right_statistics`` and both children are appended
		to ``self.grow_nodes``.

		Returns:
			The log SIS ratio reported by ``prior_proposal`` (0.0 if skipped).
		"""
		if self.do_not_split[id_of_node]:
			log_sis_ratio = 0.0
		else:
			log_psplit = np.log(self.compute_psplit(id_of_node, param))
			train_ids = self.train_ids[id_of_node]
			left_node, right_node = get_children_id(id_of_node)
			if settings.verbose >= 4:
				print('Train_ids for this node =%s' % train_ids)
			(dont_split_node_id, id_feat_chosen, chosen_split, split_global_idx,
				log_sis_ratio, logprior_nodeid, training_ids_left, training_ids_right,
				cache_tmp, loglikelihood_left, loglikelihood_right) = \
				self.prior_proposal(data, param, settings, cache, id_of_node, train_ids, log_psplit)
			if dont_split_node_id:
				self.do_not_split[id_of_node] = True
			else:
				self.update_left_right_statistics(
					cache_tmp, id_of_node, logprior_nodeid, training_ids_left,
					training_ids_right, loglikelihood_left, loglikelihood_right,
					id_feat_chosen, chosen_split, split_global_idx, settings,
					param, data, cache)
				self.grow_nodes.append(left_node)
				self.grow_nodes.append(right_node)
		return log_sis_ratio
예제 #12
0
 def evaluate_new_subtree(self, data, node_id_start, param, nodes_subtree, cache, settings):
     """Rebuild the ``*_new`` statistics of the subtree under ``node_id_start``.

     Pass 1 sends every training point of ``node_id_start`` down the subtree
     using the split rules in ``self.node_info_new``. Each point's response,
     squared response, count and index are added to ``sum_y_new``,
     ``sum_y2_new``, ``n_points_new`` and ``train_ids_new`` of every node it
     visits, including the start node and the leaf. This accumulates onto
     whatever those entries already hold; presumably the caller zeroed them
     beforehand -- TODO confirm.

     Pass 2 visits each node in ``nodes_subtree``. ``loglik_new`` and
     ``param_n_new`` come from ``compute_normal_normalizer`` when the node
     received points; otherwise ``loglik_new`` is ``-inf``. Leaves get
     ``logprior_new`` = 0.0 when ``stop_split`` holds and log pnosplit
     otherwise. Internal nodes call ``recompute_prob_split``.

     Raises:
         AssertionError: with ``settings.debug == 1``, when ``check_if_zero``
             rejects the start node's old-minus-new log-likelihood. The
             node's train ids are printed first.
     """
     for i in self.train_ids[node_id_start]:
         x_, y_ = data['x_train'][i, :], data['y_train'][i]
         node_id = copy(node_id_start)
         while True:
             self.sum_y_new[node_id] += y_
             self.sum_y2_new[node_id] += y_ ** 2
             self.n_points_new[node_id] += 1
             self.train_ids_new[node_id] = np.append(self.train_ids_new[node_id], i)
             if node_id in self.leaf_nodes:
                 break
             left, right = get_children_id(node_id)
             feat_id, split, idx_split_global = self.node_info_new[node_id]   # splitting on new criteria
             if x_[feat_id] <= split:
                 node_id = left
             else:
                 node_id = right
     for node_id in nodes_subtree:
         self.loglik_new[node_id] = -np.inf
         if self.n_points_new[node_id] > 0:
             self.loglik_new[node_id], self.param_n_new[node_id] = \
                     compute_normal_normalizer(self.sum_y_new[node_id], self.sum_y2_new[node_id],
                             self.n_points_new[node_id], param, cache, settings)
         if node_id in self.leaf_nodes:
             if stop_split(self.train_ids_new[node_id], settings, data, cache):
                 # if leaf is empty, logprior_new[node_id] = 0.0 is incorrect; however
                 #      loglik_new[node_id] = -np.inf will reject move to a tree with empty leaves
                 self.logprior_new[node_id] = 0.0
             else:
                 # node with just 1 data point earlier could have more data points now
                 self.logprior_new[node_id] = np.log(self.compute_pnosplit(node_id, param))
         else:
             # split probability might have changed if train_ids have changed
             self.recompute_prob_split(data, param, settings, cache, node_id)
     if settings.debug == 1:
         try:
             check_if_zero(self.loglik[node_id_start] - self.loglik_new[node_id_start])
         except AssertionError:
             print('train_ids[node_id_start] = %s, train_ids_new[node_id_start] = %s'
                   % (self.train_ids[node_id_start], self.train_ids_new[node_id_start]))
             # re-raise the original error rather than a new, message-less one
             raise
예제 #13
0
 def compute_log_acc_g(self, node_id, param, len_both_children_terminal, loglik,
         train_ids_left, train_ids_right, cache, settings, data, grow_nodes):
     """Return the log acceptance ratio of a GROW move that splits ``node_id``.

     The prior part is::

         log psplit(node) - log pnosplit(node)
         - log(len_both_children_terminal) + log(len(grow_nodes))
         + sum of log pnosplit(child) over children that have a valid split

     and the likelihood part is ``loglik - self.loglik[node_id]``.

     Args:
         node_id: node proposed for splitting.
         param: model hyperparameters for the split probabilities.
         len_both_children_terminal: ``both_children_terminal`` size after
             the grow.
         loglik: combined log-likelihood of the two proposed children.
         train_ids_left, train_ids_right: training ids routed to each child.
         cache, settings, data: forwarded to ``no_valid_split_exists``;
             ``settings.verbose >= 2`` prints both components.
         grow_nodes: list of growable nodes in the current tree.

     Returns:
         The log acceptance ratio, or ``-inf`` when ``loglik`` is ``-inf`` so
         that an invalid split is never accepted.
     """
     # effect of do_not_split does not matter for node_id since it has children
     logprior_children = 0.0
     left, right = get_children_id(node_id)
     if not no_valid_split_exists(data, cache, train_ids_left, settings):
         logprior_children += np.log(self.compute_pnosplit(left, param))
     if not no_valid_split_exists(data, cache, train_ids_right, settings):
         logprior_children += np.log(self.compute_pnosplit(right, param))
     log_acc_prior = np.log(self.compute_psplit(node_id, param)) \
             - np.log(self.compute_pnosplit(node_id, param)) \
             - np.log(len_both_children_terminal) + np.log(len(grow_nodes)) \
             + logprior_children
     log_acc_loglik = (loglik - self.loglik[node_id])
     log_acc = log_acc_prior + log_acc_loglik
     if settings.verbose >= 2:
         print('compute_log_acc_g: log_acc_loglik = %s, log_acc_prior = %s'
               % (log_acc_loglik, log_acc_prior))
     if loglik == -np.inf:   # just need to ensure that an invalid split is not grown
         log_acc = -np.inf
     return log_acc
예제 #14
0
파일: pg.py 프로젝트: Sandy4321/pgbart
 def process_node_id(self, data, param, settings, cache, node_id):
     """Process one node of the particle's tree: skip it or propose a split.

     If ``node_id`` is already marked in ``self.do_not_split``, nothing
     happens and 0.0 is returned. Otherwise a split is drawn with
     ``self.prior_proposal``. If the proposal declines to split, the node is
     marked do-not-split. Otherwise the children's statistics are stored via
     ``self.update_left_right_statistics`` and both children are appended to
     ``self.grow_nodes``.

     Returns:
         The log SIS ratio reported by ``prior_proposal`` (0.0 if skipped).
     """
     if self.do_not_split[node_id]:
         log_sis_ratio = 0.0
     else:
         log_psplit = np.log(self.compute_psplit(node_id, param))
         train_ids = self.train_ids[node_id]
         left, right = get_children_id(node_id)
         if settings.verbose >= 4:
             print('train_ids for this node = %s' % train_ids)
         (do_not_split_node_id, feat_id_chosen, split_chosen, idx_split_global, log_sis_ratio, logprior_nodeid,
             train_ids_left, train_ids_right, cache_tmp, loglik_left, loglik_right) \
             = self.prior_proposal(data, param, settings, cache, node_id, train_ids, log_psplit)
         if do_not_split_node_id:
             self.do_not_split[node_id] = True
         else:
             self.update_left_right_statistics(cache_tmp, node_id, logprior_nodeid, train_ids_left,
                 train_ids_right, loglik_left, loglik_right, feat_id_chosen, split_chosen,
                 idx_split_global, settings, param, data, cache)
             self.grow_nodes.append(left)
             self.grow_nodes.append(right)
     return log_sis_ratio
예제 #15
0
 def process_node_id(self, data, param, settings, cache, node_id):
     """Handle a single node of this particle: leave it alone or try to split it.

     Nodes already flagged in ``self.do_not_split`` are skipped with a log SIS
     ratio of 0.0. For any other node, ``self.prior_proposal`` is asked for a
     split. A refusal flags the node as do-not-split. An accepted split has
     its children's statistics recorded through
     ``self.update_left_right_statistics``, and the two children are queued on
     ``self.grow_nodes``.

     Returns:
         The log SIS ratio from the proposal, or 0.0 for a skipped node.
     """
     if self.do_not_split[node_id]:
         return 0.0

     log_psplit = np.log(self.compute_psplit(node_id, param))
     node_train_ids = self.train_ids[node_id]
     left_child, right_child = get_children_id(node_id)
     if settings.verbose >= 4:
         print('train_ids for this node = %s' % node_train_ids)

     proposal = self.prior_proposal(data, param, settings, cache, node_id, node_train_ids, log_psplit)
     (refuse_split, feat_id, split_value, split_idx_global, log_sis_ratio, logprior_node,
      ids_left, ids_right, scratch_cache, loglik_left, loglik_right) = proposal

     if refuse_split:
         self.do_not_split[node_id] = True
         return log_sis_ratio

     self.update_left_right_statistics(
         scratch_cache, node_id, logprior_node, ids_left, ids_right,
         loglik_left, loglik_right, feat_id, split_value, split_idx_global,
         settings, param, data, cache)
     self.grow_nodes.extend((left_child, right_child))
     return log_sis_ratio
예제 #16
0
 def compute_log_acc_g(self, node_id, param, len_both_children_terminal, loglik,
         train_ids_left, train_ids_right, cache, settings, data, grow_nodes):
     """Log acceptance ratio for GROWing ``node_id`` into two children.

     Prior term: log psplit(node) - log pnosplit(node)
     - log(len_both_children_terminal) + log(len(grow_nodes)), plus
     log pnosplit(child) for each child whose training ids still admit a
     valid split. Likelihood term: ``loglik - self.loglik[node_id]``.

     ``loglik`` is the combined log-likelihood of the proposed children.
     ``cache``, ``settings`` and ``data`` are forwarded to
     ``no_valid_split_exists``. ``settings.verbose >= 2`` prints both terms.

     Returns:
         The summed log ratio, or ``-inf`` whenever ``loglik`` is ``-inf``.
     """
     # do_not_split is irrelevant for node_id: after the grow it has children
     left, right = get_children_id(node_id)
     logprior_children = 0.0
     for child, child_train_ids in ((left, train_ids_left), (right, train_ids_right)):
         if no_valid_split_exists(data, cache, child_train_ids, settings):
             continue
         logprior_children += np.log(self.compute_pnosplit(child, param))

     log_split_ratio = np.log(self.compute_psplit(node_id, param)) - np.log(self.compute_pnosplit(node_id, param))
     log_acc_prior = (log_split_ratio
                      - np.log(len_both_children_terminal)
                      + np.log(len(grow_nodes))
                      + logprior_children)
     log_acc_loglik = loglik - self.loglik[node_id]
     log_acc = log_acc_prior + log_acc_loglik
     if settings.verbose >= 2:
         print('compute_log_acc_g: log_acc_loglik = %s, log_acc_prior = %s'
               % (log_acc_loglik, log_acc_prior))
     # an invalid split must never be grown
     return -np.inf if loglik == -np.inf else log_acc
예제 #17
0
 def sample(self, data, settings, param, cache):
     if settings.mcmc_type == 'growprune':
         step_id = random.randint(0, 1)  # only grow and prune moves permitted
     elif settings.mcmc_type == 'cgm':
         step_id = random.randint(0, 3)  # all 4 moves equally likely (or think of 50% grow/prune, 25% change, 25% swap)
     else:
         raise Exception('invalid mcmc_type')
     log_acc = -np.inf
     log_r = 0
     self.grow_nodes = [n_id for n_id in self.leaf_nodes \
                 if not stop_split(self.train_ids[n_id], settings, data, cache)]
     grow_nodes = self.grow_nodes
     if step_id == 0:        # GROW step
         if not grow_nodes:
             change = False
         else:
             node_id = random.choice(grow_nodes)
             if settings.verbose >= 1:
                 print 'grow_nodes = %s, chosen node_id = %s' % (grow_nodes, node_id)
             do_not_split_node_id, feat_id, split, idx_split_global, logprior_nodeid = \
                     self.sample_split_prior(data, param, settings, cache, node_id)
             assert not do_not_split_node_id
             if settings.verbose >= 1:
                 print 'grow: do_not_split = %s, feat_id = %s, split = %s' \
                         % (do_not_split_node_id, feat_id, split)
             train_ids = self.train_ids[node_id]
             (train_ids_left, train_ids_right, cache_tmp, loglik_left, loglik_right) = \
                 compute_left_right_statistics(data, param, cache, train_ids, \
                     feat_id, split, settings)
             loglik = loglik_left + loglik_right
             len_both_children_terminal_new = len(self.both_children_terminal)
             if get_sibling_id(node_id) not in self.leaf_nodes:
                 len_both_children_terminal_new += 1
             log_acc = self.compute_log_acc_g(node_id, param, len_both_children_terminal_new, \
                         loglik, train_ids_left, train_ids_right, cache, settings, data, grow_nodes)
             log_r = np.log(np.random.rand(1))
             if log_r <= log_acc:
                 self.update_left_right_statistics(cache_tmp, node_id, logprior_nodeid, \
                         train_ids_left, train_ids_right, loglik_left, loglik_right, \
                         feat_id, split, idx_split_global, settings, param, data, cache)
                 # MCMC specific data structure updates
                 self.both_children_terminal.append(node_id)
                 parent = get_parent_id(node_id) 
                 if (node_id != 0) and (parent in self.non_leaf_nodes):
                     self.inner_pc_pairs.append((parent, node_id))
                 sibling = get_sibling_id(node_id)
                 if sibling in self.leaf_nodes:
                     self.both_children_terminal.remove(parent)
                 change = True
             else:
                 change = False
     elif step_id == 1:      # PRUNE step
         if not self.both_children_terminal:
             change = False      # nothing to prune here
         else:
             node_id = random.choice(self.both_children_terminal)
             feat_id = self.node_info[node_id][0]
             if settings.verbose >= 1:
                 print 'prune: node_id = %s, feat_id = %s' % (node_id, feat_id)
             left, right = get_children_id(node_id)
             loglik = self.loglik[left] + self.loglik[right]
             len_both_children_new = len(self.both_children_terminal)
             grow_nodes_tmp = grow_nodes[:]
             grow_nodes_tmp.append(node_id)
             try:
                 grow_nodes_tmp.remove(left)
             except ValueError:
                 pass
             try:
                 grow_nodes_tmp.remove(right)
             except ValueError:
                 pass
             log_acc = - self.compute_log_inv_acc_p(node_id, param, len_both_children_new, \
                             loglik, grow_nodes_tmp, cache, settings, data)
             log_r = np.log(np.random.rand(1))
             if log_r <= log_acc:
                 self.remove_leaf_node_statistics(left, settings)
                 self.remove_leaf_node_statistics(right, settings)
                 self.leaf_nodes.append(node_id)
                 self.non_leaf_nodes.remove(node_id)
                 self.logprior[node_id] = np.log(self.compute_pnosplit(node_id, param))
                 # OK to set logprior as above since we know that a valid split exists
                 # MCMC specific data structure updates
                 self.both_children_terminal.remove(node_id)
                 parent = get_parent_id(node_id) 
                 if (node_id != 0) and (parent in self.non_leaf_nodes):
                     self.inner_pc_pairs.remove((parent, node_id))
                 if node_id != 0:
                     sibling = get_sibling_id(node_id) 
                     if sibling in self.leaf_nodes:
                         if settings.debug == 1:
                             assert(parent not in self.both_children_terminal)
                         self.both_children_terminal.append(parent)
                 change = True
             else:
                 change = False
     elif step_id == 2:      # CHANGE
         if not self.non_leaf_nodes:
             change = False
         else:
             node_id = random.choice(self.non_leaf_nodes)
             do_not_split_node_id, feat_id, split, idx_split_global, logprior_nodeid = \
                     self.sample_split_prior(data, param, settings, cache, node_id)
             if settings.verbose >= 1:
                 print 'change: node_id = %s, do_not_split = %s, feat_id = %s, split = %s' \
                         % (node_id, do_not_split_node_id, feat_id, split)
             # Note: this just samples a split criterion, not guaranteed to "change" 
             assert(not do_not_split_node_id)
             nodes_subtree = self.get_nodes_subtree(node_id)
             nodes_not_in_subtree = self.get_nodes_not_in_subtree(node_id)
             if settings.debug == 1:
                 set1 = set(list(nodes_subtree) + list(nodes_not_in_subtree))
                 set2 = set(self.leaf_nodes + self.non_leaf_nodes)
                 assert(sorted(set1) == sorted(set2))
             self.create_new_statistics(nodes_subtree, nodes_not_in_subtree, node_id, settings)
             self.node_info_new[node_id] = (feat_id, split, idx_split_global)         
             self.evaluate_new_subtree(data, node_id, param, nodes_subtree, cache, settings)
             # log_acc will be be modified below
             log_acc_tmp, loglik_diff, logprior_diff = self.compute_log_acc_cs(nodes_subtree, node_id)
             if settings.debug == 1:
                 self.check_if_same(log_acc_tmp, loglik_diff, logprior_diff)
             log_acc = log_acc_tmp + self.logprior[node_id] - self.logprior_new[node_id]
             log_r = np.log(np.random.rand(1))
             if log_r <= log_acc:
                 self.node_info[node_id] = copy(self.node_info_new[node_id])
                 self.update_subtree(node_id, nodes_subtree, settings)
                 change = True
             else:
                 change = False
     elif step_id == 3:      # SWAP
         if not self.inner_pc_pairs:
             change = False 
         else:
             node_id, child_id = random.choice(self.inner_pc_pairs)
             nodes_subtree = self.get_nodes_subtree(node_id)
             nodes_not_in_subtree = self.get_nodes_not_in_subtree(node_id)
             if settings.debug == 1:
                 set1 = set(list(nodes_subtree) + list(nodes_not_in_subtree))
                 set2 = set(self.leaf_nodes + self.non_leaf_nodes)
                 assert(sorted(set1) == sorted(set2))
             self.create_new_statistics(nodes_subtree, nodes_not_in_subtree, node_id, settings)
             self.node_info_new[node_id] = copy(self.node_info[child_id])
             self.node_info_new[child_id] = copy(self.node_info[node_id])
             if settings.verbose >= 1:
                 print 'swap: node_id = %s, child_id = %s' % (node_id, child_id)
                 print 'node_info[node_id] = %s, node_info[child_id] = %s' \
                         % (self.node_info[node_id], self.node_info[child_id])
             self.evaluate_new_subtree(data, node_id, param, nodes_subtree, cache, settings)
             log_acc, loglik_diff, logprior_diff = self.compute_log_acc_cs(nodes_subtree, node_id)
             if settings.debug == 1:
                 self.check_if_same(log_acc, loglik_diff, logprior_diff)
             log_r = np.log(np.random.rand(1))
             if log_r <= log_acc:
                 self.node_info[node_id] = copy(self.node_info_new[node_id])
                 self.node_info[child_id] = copy(self.node_info_new[child_id])
                 self.update_subtree(node_id, nodes_subtree, settings)
                 change = True
             else:
                 change = False
     if settings.verbose >= 1:
         print 'trying move: step_id = %d, move = %s, log_acc = %s, log_r = %s' \
                 % (step_id, STEP_NAMES[step_id], log_acc, log_r)
     if change:
         self.depth = max([get_depth(node_id) for node_id in \
                 self.leaf_nodes])
         self.loglik_current = sum([self.loglik[node_id] for node_id in \
                 self.leaf_nodes])
         if settings.verbose >= 1:
             print 'accepted move: step_id = %d, move = %s' % (step_id, STEP_NAMES[step_id])
             self.print_stuff()
     if settings.debug == 1:
         both_children_terminal, inner_pc_pairs = self.recompute_mcmc_data_structures()
         print '\nstats from recompute_mcmc_data_structures'
         print 'both_children_terminal = %s' % both_children_terminal
         print 'inner_pc_pairs = %s' % inner_pc_pairs
         assert(sorted(both_children_terminal) == sorted(self.both_children_terminal))
         assert(sorted(inner_pc_pairs) == sorted(self.inner_pc_pairs))
         grow_nodes_new = [n_id for n_id in self.leaf_nodes \
                 if not stop_split(self.train_ids[n_id], settings, data, cache)]
         if change and (step_id == 1):
             print 'grow_nodes_new = %s, grow_nodes_tmp = %s' % (sorted(grow_nodes_new), sorted(grow_nodes_tmp))
             assert(sorted(grow_nodes_new) == sorted(grow_nodes_tmp))
     return (change, step_id)
예제 #18
0
    def sample(self, data, settings, param, cache, databc, cachebc):
        """Propose one Metropolis-Hastings move on this tree and apply it if accepted.

        The move type ``step_id`` is chosen from ``settings.mcmc_type``:

        * ``'growprune'``: GROW (0) or PRUNE (1), via ``random.randint(0, 1)``.
        * ``'cgm'``: GROW (0), PRUNE (1), CHANGE (2) or SWAP (3), via
          ``random_pick`` with weights 0.25 / 0.25 / 0.4 / 0.1.

        A proposal is accepted when ``log(U) <= log_acc`` with ``U`` drawn from
        ``np.random.rand``. On acceptance the tree (leaf/non-leaf lists, node_info,
        per-node statistics) and the MCMC bookkeeping lists
        ``both_children_terminal`` / ``inner_pc_pairs`` are updated in place, and
        ``self.depth`` / ``self.loglik_current`` are recomputed. ``self.grow_nodes``
        is always refreshed to the leaves that ``stop_split`` does not reject.

        Args:
            data: dataset dict; the parallel GROW path indexes
                ``data['x_train'][train_ids, feat_id]``.
            settings: run configuration (reads ``mcmc_type``, ``verbose``,
                ``debug``, ``parallelize_compute_statistics``, ``timer``).
            param: model hyper-parameters, forwarded to the helpers.
            cache: precomputed statistics, forwarded to the helpers.
            databc, cachebc: forwarded to ``sample_split_prior`` and
                ``evaluate_new_subtree`` (``databc`` also to the parallel
                statistics helper); presumably broadcast copies of ``data`` /
                ``cache`` for a distributed backend -- TODO confirm.

        Returns:
            tuple ``(change, step_id)``: whether the tree was modified and the
            index of the attempted move.

        Raises:
            Exception: if ``settings.mcmc_type`` is neither ``'growprune'`` nor ``'cgm'``.
        """
        if settings.mcmc_type == 'growprune':
            step_id = random.randint(0, 1)  # only grow and prune moves permitted
        elif settings.mcmc_type == 'cgm':
            step_id = random_pick([0, 1, 2, 3], [0.25, 0.25, 0.4, 0.1])
        else:
            raise Exception('invalid mcmc_type')
        # Values reported by the verbose log when no proposal gets evaluated.
        log_acc = -np.inf
        log_r = 0
        change = False
        # Leaves that can still be split.
        self.grow_nodes = [n_id for n_id in self.leaf_nodes
                           if not stop_split(self.train_ids[n_id], settings, data, cache)]
        grow_nodes = self.grow_nodes

        if step_id == 0:        # GROW
            if grow_nodes:
                node_id = random.choice(grow_nodes)
                if settings.verbose >= 1:
                    print('grow_nodes = %s, chosen node_id = %s' % (grow_nodes, node_id))
                do_not_split_node_id, feat_id, split, idx_split_global, logprior_nodeid = \
                        self.sample_split_prior(data, param, settings, cache, node_id,
                                                databc, cachebc)
                if not do_not_split_node_id:
                    if settings.verbose >= 1:
                        print('grow: do_not_split = %s, feat_id = %s, split = %s'
                              % (do_not_split_node_id, feat_id, split))
                    train_ids = self.train_ids[node_id]
                    t0 = time.time()
                    if settings.parallelize_compute_statistics == 0:
                        (train_ids_left, train_ids_right, cache_tmp, loglik_left, loglik_right) = \
                                compute_left_right_statistics(data, param, cache, train_ids,
                                                              feat_id, split, settings)
                    else:
                        (cache_tmp, loglik_left, loglik_right) = \
                                compute_left_right_statistics_parallel(data, param, cache, train_ids,
                                                                       feat_id, split, settings, databc)
                        # The parallel helper does not return the partition; rebuild it here.
                        cond = data['x_train'][train_ids, feat_id] <= split
                        train_ids_left = train_ids[cond]
                        train_ids_right = train_ids[~cond]
                    if settings.timer == 1:
                        settings.parallelize_compute_statistics_time += time.time() - t0
                    loglik = loglik_left + loglik_right
                    # Size of both_children_terminal after this grow: node_id joins it and,
                    # if its sibling is a leaf, the parent leaves it (net change 0).
                    # The root has no sibling or parent, so growing it always adds one.
                    len_both_children_terminal_new = len(self.both_children_terminal)
                    if node_id == 0 or get_sibling_id(node_id) not in self.leaf_nodes:
                        len_both_children_terminal_new += 1
                    log_acc = self.compute_log_acc_g(node_id, param, len_both_children_terminal_new,
                                                     loglik, train_ids_left, train_ids_right,
                                                     cache, settings, data, grow_nodes)
                    log_r = np.log(np.random.rand(1))
                    if log_r <= log_acc:
                        self.update_left_right_statistics(cache_tmp, node_id, logprior_nodeid,
                                train_ids_left, train_ids_right, loglik_left, loglik_right,
                                feat_id, split, idx_split_global, settings, param, data, cache)
                        # MCMC specific data structure updates
                        self.both_children_terminal.append(node_id)
                        if node_id != 0:
                            parent = get_parent_id(node_id)
                            if parent in self.non_leaf_nodes:
                                self.inner_pc_pairs.append((parent, node_id))
                            if get_sibling_id(node_id) in self.leaf_nodes:
                                # parent now has a non-terminal child
                                self.both_children_terminal.remove(parent)
                        change = True
        elif step_id == 1:      # PRUNE
            if self.both_children_terminal:
                node_id = random.choice(self.both_children_terminal)
                feat_id = self.node_info[node_id][0]
                if settings.verbose >= 1:
                    print('prune: node_id = %s, feat_id = %s' % (node_id, feat_id))
                left, right = get_children_id(node_id)
                loglik = self.loglik[left] + self.loglik[right]
                len_both_children_new = len(self.both_children_terminal)
                # grow_nodes as it would look after the prune: the children disappear
                # and node_id becomes a growable leaf.
                grow_nodes_tmp = [n_id for n_id in grow_nodes if n_id not in (left, right)]
                grow_nodes_tmp.append(node_id)
                # The prune acceptance is the inverse of the reverse GROW acceptance.
                log_acc = -self.compute_log_inv_acc_p(node_id, param, len_both_children_new,
                                                      loglik, grow_nodes_tmp, cache, settings, data)
                log_r = np.log(np.random.rand(1))
                if log_r <= log_acc:
                    self.remove_leaf_node_statistics(left, settings)
                    self.remove_leaf_node_statistics(right, settings)
                    self.leaf_nodes.append(node_id)
                    self.non_leaf_nodes.remove(node_id)
                    # OK to set logprior this way since we know that a valid split exists
                    self.logprior[node_id] = np.log(self.compute_pnosplit(node_id, param))
                    # MCMC specific data structure updates
                    self.both_children_terminal.remove(node_id)
                    if node_id != 0:
                        parent = get_parent_id(node_id)
                        if parent in self.non_leaf_nodes:
                            self.inner_pc_pairs.remove((parent, node_id))
                        if get_sibling_id(node_id) in self.leaf_nodes:
                            if settings.debug == 1:
                                assert parent not in self.both_children_terminal
                            self.both_children_terminal.append(parent)
                    change = True
        elif step_id == 2:      # CHANGE
            if self.non_leaf_nodes:
                node_id = random.choice(self.non_leaf_nodes)
                do_not_split_node_id, feat_id, split, idx_split_global, logprior_nodeid = \
                        self.sample_split_prior(data, param, settings, cache, node_id,
                                                databc, cachebc)
                if settings.verbose >= 1:
                    print('change: node_id = %s, do_not_split = %s, feat_id = %s, split = %s'
                          % (node_id, do_not_split_node_id, feat_id, split))
                # This only samples a split criterion; it is not guaranteed to "change"
                # anything. If the sampler reports "do not split", the move is rejected.
                if not do_not_split_node_id:
                    nodes_subtree = self.get_nodes_subtree(node_id)
                    nodes_not_in_subtree = self.get_nodes_not_in_subtree(node_id)
                    if settings.debug == 1:
                        # subtree + complement must cover exactly the current tree
                        set1 = set(list(nodes_subtree) + list(nodes_not_in_subtree))
                        set2 = set(self.leaf_nodes + self.non_leaf_nodes)
                        assert set1 == set2
                    self.create_new_statistics(nodes_subtree, nodes_not_in_subtree, node_id, settings)
                    self.node_info_new[node_id] = (feat_id, split, idx_split_global)
                    self.evaluate_new_subtree(data, node_id, param, nodes_subtree, cache, settings,
                                              databc, cachebc)
                    log_acc_tmp, loglik_diff, logprior_diff = \
                            self.compute_log_acc_cs(nodes_subtree, node_id)
                    if settings.debug == 1:
                        self.check_if_same(log_acc_tmp, loglik_diff, logprior_diff)
                    # Adjust by the logprior difference at node_id itself.
                    log_acc = log_acc_tmp + self.logprior[node_id] - self.logprior_new[node_id]
                    log_r = np.log(np.random.rand(1))
                    if log_r <= log_acc:
                        self.node_info[node_id] = copy(self.node_info_new[node_id])
                        self.update_subtree(node_id, nodes_subtree, settings)
                        change = True
        elif step_id == 3:      # SWAP
            if self.inner_pc_pairs:
                node_id, child_id = random.choice(self.inner_pc_pairs)
                nodes_subtree = self.get_nodes_subtree(node_id)
                nodes_not_in_subtree = self.get_nodes_not_in_subtree(node_id)
                if settings.debug == 1:
                    # subtree + complement must cover exactly the current tree
                    set1 = set(list(nodes_subtree) + list(nodes_not_in_subtree))
                    set2 = set(self.leaf_nodes + self.non_leaf_nodes)
                    assert set1 == set2
                self.create_new_statistics(nodes_subtree, nodes_not_in_subtree, node_id, settings)
                # Proposal: exchange the split rules of the parent and child.
                self.node_info_new[node_id] = copy(self.node_info[child_id])
                self.node_info_new[child_id] = copy(self.node_info[node_id])
                if settings.verbose >= 1:
                    print('swap: node_id = %s, child_id = %s' % (node_id, child_id))
                    print('node_info[node_id] = %s, node_info[child_id] = %s'
                          % (self.node_info[node_id], self.node_info[child_id]))
                self.evaluate_new_subtree(data, node_id, param, nodes_subtree, cache, settings,
                                          databc, cachebc)
                log_acc, loglik_diff, logprior_diff = self.compute_log_acc_cs(nodes_subtree, node_id)
                if settings.debug == 1:
                    self.check_if_same(log_acc, loglik_diff, logprior_diff)
                log_r = np.log(np.random.rand(1))
                if log_r <= log_acc:
                    self.node_info[node_id] = copy(self.node_info_new[node_id])
                    self.node_info[child_id] = copy(self.node_info_new[child_id])
                    self.update_subtree(node_id, nodes_subtree, settings)
                    change = True
        if settings.verbose >= 1:
            print('trying move: step_id = %d, move = %s, log_acc = %s, log_r = %s'
                  % (step_id, STEP_NAMES[step_id], log_acc, log_r))
        if change:
            self.depth = max(get_depth(n_id) for n_id in self.leaf_nodes)
            self.loglik_current = sum(self.loglik[n_id] for n_id in self.leaf_nodes)
            if settings.verbose >= 1:
                print('accepted move: step_id = %d, move = %s' % (step_id, STEP_NAMES[step_id]))
                self.print_stuff()
        if settings.debug == 1:
            # Cross-check the incrementally maintained MCMC structures against a recompute.
            both_children_terminal, inner_pc_pairs = self.recompute_mcmc_data_structures()
            print('\nstats from recompute_mcmc_data_structures')
            print('both_children_terminal = %s' % both_children_terminal)
            print('inner_pc_pairs = %s' % inner_pc_pairs)
            assert sorted(both_children_terminal) == sorted(self.both_children_terminal)
            assert sorted(inner_pc_pairs) == sorted(self.inner_pc_pairs)
            if change and step_id == 1:
                grow_nodes_new = [n_id for n_id in self.leaf_nodes
                                  if not stop_split(self.train_ids[n_id], settings, data, cache)]
                print('grow_nodes_new = %s, grow_nodes_tmp = %s'
                      % (sorted(grow_nodes_new), sorted(grow_nodes_tmp)))
                assert sorted(grow_nodes_new) == sorted(grow_nodes_tmp)
        return (change, step_id)