コード例 #1
0
def sample_tree(data, settings, param, cache, cache_tmp):
    """Grow a new TreeMCMC from its root, visiting nodes breadth-first.

    Starting from the root (id 0), each node taken from the FIFO frontier gets
    a split proposal from ``precomputed_proposal``.  A node the proposal
    refuses to split is flagged in ``do_not_split``; otherwise its child
    statistics are stored, both children are queued, and the MCMC bookkeeping
    lists ``both_children_terminal`` / ``inner_pc_pairs`` are updated.

    Returns the populated ``TreeMCMC`` instance.
    """
    tree = TreeMCMC(range(data['n_train']), param, settings, cache_tmp)
    frontier = [0]
    head = 0    # index of the next node to visit (FIFO order without pop(0))
    while head < len(frontier):
        node_id = frontier[head]
        head += 1
        tree.depth = max(tree.depth, get_depth(node_id))
        log_psplit = np.log(tree.compute_psplit(node_id, param))
        proposal = tree.precomputed_proposal(data, param, settings, cache, node_id,
                                             tree.train_ids[node_id], log_psplit)
        (skip_split, feat_id, split_value, idx_split_global, _log_sis_ratio,
         logprior_nodeid, ids_left, ids_right, node_cache,
         loglik_left, loglik_right) = proposal
        if skip_split:
            tree.do_not_split[node_id] = True
            continue

        tree.update_left_right_statistics(node_cache, node_id, logprior_nodeid,
                                          ids_left, ids_right, loglik_left, loglik_right,
                                          feat_id, split_value, idx_split_global,
                                          settings, param, data, cache)
        left_id, right_id = get_children_id(node_id)
        frontier.extend((left_id, right_id))

        # MCMC bookkeeping: this node now has two terminal children ...
        tree.both_children_terminal.append(node_id)
        parent_id = get_parent_id(node_id)
        if node_id != 0:
            if parent_id in tree.non_leaf_nodes:
                tree.inner_pc_pairs.append((parent_id, node_id))
            # ... and its parent no longer does (it may already have been
            # removed when the sibling was split, hence the ValueError guard).
            try:
                tree.both_children_terminal.remove(parent_id)
            except ValueError:
                pass
    if settings.debug:
        print('sampled new tree:')
        tree.print_tree()
    return tree
コード例 #2
0
ファイル: treemcmc.py プロジェクト: Sandy4321/pgbart
def sample_tree(data, settings, param, cache, cache_tmp):
    """Grow a new TreeMCMC from its root, visiting nodes breadth-first.

    Starting from the root (id 0), each node taken from the FIFO frontier gets
    a split proposal from ``precomputed_proposal``.  A node the proposal
    refuses to split is flagged in ``do_not_split``; otherwise its child
    statistics are stored, both children are queued, and the MCMC bookkeeping
    lists ``both_children_terminal`` / ``inner_pc_pairs`` are updated.

    Returns the populated ``TreeMCMC`` instance.
    """
    tree = TreeMCMC(range(data['n_train']), param, settings, cache_tmp)
    frontier = [0]
    head = 0    # index of the next node to visit (FIFO order without pop(0))
    while head < len(frontier):
        node_id = frontier[head]
        head += 1
        tree.depth = max(tree.depth, get_depth(node_id))
        log_psplit = np.log(tree.compute_psplit(node_id, param))
        proposal = tree.precomputed_proposal(data, param, settings, cache, node_id,
                                             tree.train_ids[node_id], log_psplit)
        (skip_split, feat_id, split_value, idx_split_global, _log_sis_ratio,
         logprior_nodeid, ids_left, ids_right, node_cache,
         loglik_left, loglik_right) = proposal
        if skip_split:
            tree.do_not_split[node_id] = True
            continue

        tree.update_left_right_statistics(node_cache, node_id, logprior_nodeid,
                                          ids_left, ids_right, loglik_left, loglik_right,
                                          feat_id, split_value, idx_split_global,
                                          settings, param, data, cache)
        left_id, right_id = get_children_id(node_id)
        frontier.extend((left_id, right_id))

        # MCMC bookkeeping: this node now has two terminal children ...
        tree.both_children_terminal.append(node_id)
        parent_id = get_parent_id(node_id)
        if node_id != 0:
            if parent_id in tree.non_leaf_nodes:
                tree.inner_pc_pairs.append((parent_id, node_id))
            # ... and its parent no longer does (it may already have been
            # removed when the sibling was split, hence the ValueError guard).
            try:
                tree.both_children_terminal.remove(parent_id)
            except ValueError:
                pass
    if settings.debug:
        print('sampled new tree:')
        tree.print_tree()
    return tree
コード例 #3
0
    def sample(self, data, settings, param, cache, databc, cachebc):
        """Propose one MCMC move on this tree and accept or reject it in place.

        The move type ``step_id`` is drawn according to ``settings.mcmc_type``:

        * ``'growprune'``: GROW (0) or PRUNE (1), uniformly.
        * ``'cgm'``: GROW (0), PRUNE (1), CHANGE (2) or SWAP (3) with
          probabilities 0.25 / 0.25 / 0.4 / 0.1 (via ``random_pick``).

        Each move computes a log acceptance ratio ``log_acc`` and is accepted
        when ``log(u) <= log_acc`` with ``u`` drawn by ``np.random.rand(1)``.
        Accepted moves mutate the tree's node lists and statistics directly.

        ``databc`` / ``cachebc`` are forwarded to ``sample_split_prior`` and
        ``evaluate_new_subtree``; ``databc`` is also passed to
        ``compute_left_right_statistics_parallel`` when
        ``settings.parallelize_compute_statistics`` is not 0.
        NOTE(review): presumably broadcast copies of ``data``/``cache`` for the
        parallel backend -- confirm with the caller.

        Returns:
            (change, step_id): ``change`` is True iff the proposed move was
            accepted; ``step_id`` is the index of the move that was attempted.

        Raises:
            Exception: if ``settings.mcmc_type`` is neither 'growprune' nor 'cgm'.
        """
        if settings.mcmc_type == 'growprune':
            step_id = random.randint(0, 1)  # only grow and prune moves permitted
        elif settings.mcmc_type == 'cgm':
            # grow / prune / change / swap with probabilities 0.25 / 0.25 / 0.4 / 0.1
            step_id = random_pick([0, 1, 2, 3], [0.25, 0.25, 0.4, 0.1])
        else:
            raise Exception('invalid mcmc_type')
        log_acc = -np.inf
        log_r = 0
        # Default to "rejected" so `change` is always bound, whatever path runs below.
        change = False
        # Leaves that are still eligible to be split.
        self.grow_nodes = [n_id for n_id in self.leaf_nodes
                           if not stop_split(self.train_ids[n_id], settings, data, cache)]
        grow_nodes = self.grow_nodes

        if step_id == 0:        # GROW: split a randomly chosen eligible leaf
            if not grow_nodes:
                change = False
            else:
                node_id = random.choice(grow_nodes)
                if settings.verbose >= 1:
                    print('grow_nodes = %s, chosen node_id = %s' % (grow_nodes, node_id))

                do_not_split_node_id, feat_id, split, idx_split_global, logprior_nodeid = \
                        self.sample_split_prior(data, param, settings, cache, node_id, databc, cachebc)

                if do_not_split_node_id:
                    # No usable split was sampled for this leaf: nothing to propose.
                    change = False
                else:
                    if settings.verbose >= 1:
                        print('grow: do_not_split = %s, feat_id = %s, split = %s'
                              % (do_not_split_node_id, feat_id, split))
                    train_ids = self.train_ids[node_id]

                    t0 = time.time()
                    if settings.parallelize_compute_statistics == 0:
                        (train_ids_left, train_ids_right, cache_tmp, loglik_left, loglik_right) = \
                                compute_left_right_statistics(data, param, cache, train_ids, feat_id, split, settings)
                    else:
                        (cache_tmp, loglik_left, loglik_right) = \
                                compute_left_right_statistics_parallel(data, param, cache, train_ids, feat_id, split, settings, databc)
                        # The parallel path does not return the child index sets,
                        # so partition train_ids locally on the sampled split.
                        cond = data['x_train'][train_ids, feat_id] <= split
                        train_ids_left = train_ids[cond]
                        train_ids_right = train_ids[~cond]

                    if settings.timer == 1:
                        settings.parallelize_compute_statistics_time += time.time() - t0

                    loglik = loglik_left + loglik_right

                    # Size of both_children_terminal after the grow: node_id joins it,
                    # and its parent leaves it only if the sibling is currently a leaf.
                    len_both_children_terminal_new = len(self.both_children_terminal)
                    if get_sibling_id(node_id) not in self.leaf_nodes:
                        len_both_children_terminal_new += 1

                    log_acc = self.compute_log_acc_g(node_id, param, len_both_children_terminal_new,
                                                     loglik, train_ids_left, train_ids_right, cache,
                                                     settings, data, grow_nodes)

                    log_r = np.log(np.random.rand(1))
                    if log_r <= log_acc:
                        self.update_left_right_statistics(cache_tmp, node_id, logprior_nodeid,
                                train_ids_left, train_ids_right, loglik_left, loglik_right,
                                feat_id, split, idx_split_global, settings, param, data, cache)

                        # MCMC specific data structure updates
                        self.both_children_terminal.append(node_id)
                        parent = get_parent_id(node_id)
                        if (node_id != 0) and (parent in self.non_leaf_nodes):
                            self.inner_pc_pairs.append((parent, node_id))
                        sibling = get_sibling_id(node_id)
                        if sibling in self.leaf_nodes:
                            self.both_children_terminal.remove(parent)
                        change = True
                    else:
                        change = False
        elif step_id == 1:      # PRUNE: collapse a node whose two children are leaves
            if not self.both_children_terminal:
                change = False      # nothing to prune here
            else:
                node_id = random.choice(self.both_children_terminal)
                feat_id = self.node_info[node_id][0]
                if settings.verbose >= 1:
                    print('prune: node_id = %s, feat_id = %s' % (node_id, feat_id))
                left, right = get_children_id(node_id)
                loglik = self.loglik[left] + self.loglik[right]
                len_both_children_new = len(self.both_children_terminal)
                # Grow candidates as they would be after the prune: node_id becomes
                # a leaf again and its children disappear.
                grow_nodes_tmp = grow_nodes[:]
                grow_nodes_tmp.append(node_id)
                try:
                    grow_nodes_tmp.remove(left)
                except ValueError:
                    pass
                try:
                    grow_nodes_tmp.remove(right)
                except ValueError:
                    pass
                log_acc = - self.compute_log_inv_acc_p(node_id, param, len_both_children_new,
                                                       loglik, grow_nodes_tmp, cache, settings, data)
                log_r = np.log(np.random.rand(1))
                if log_r <= log_acc:
                    self.remove_leaf_node_statistics(left, settings)
                    self.remove_leaf_node_statistics(right, settings)
                    self.leaf_nodes.append(node_id)
                    self.non_leaf_nodes.remove(node_id)
                    self.logprior[node_id] = np.log(self.compute_pnosplit(node_id, param))
                    # OK to set logprior as above since we know that a valid split exists
                    # MCMC specific data structure updates
                    self.both_children_terminal.remove(node_id)
                    parent = get_parent_id(node_id)
                    if (node_id != 0) and (parent in self.non_leaf_nodes):
                        self.inner_pc_pairs.remove((parent, node_id))
                    if node_id != 0:
                        sibling = get_sibling_id(node_id)
                        if sibling in self.leaf_nodes:
                            if settings.debug == 1:
                                assert(parent not in self.both_children_terminal)
                            self.both_children_terminal.append(parent)
                    change = True
                else:
                    change = False
        elif step_id == 2:      # CHANGE: resample the split rule of an internal node
            if not self.non_leaf_nodes:
                change = False
            else:
                node_id = random.choice(self.non_leaf_nodes)

                do_not_split_node_id, feat_id, split, idx_split_global, logprior_nodeid = \
                        self.sample_split_prior(data, param, settings, cache, node_id, databc, cachebc)
                if settings.verbose >= 1:
                    print('change: node_id = %s, do_not_split = %s, feat_id = %s, split = %s'
                          % (node_id, do_not_split_node_id, feat_id, split))
                # Note: this just samples a split criterion, not guaranteed to "change".
                # sample_split_prior may report that no split is available here; in
                # that case the move is simply rejected.
                if do_not_split_node_id:
                    change = False
                else:
                    nodes_subtree = self.get_nodes_subtree(node_id)
                    nodes_not_in_subtree = self.get_nodes_not_in_subtree(node_id)
                    if settings.debug == 1:
                        set1 = set(list(nodes_subtree) + list(nodes_not_in_subtree))
                        set2 = set(self.leaf_nodes + self.non_leaf_nodes)
                        assert(sorted(set1) == sorted(set2))
                    self.create_new_statistics(nodes_subtree, nodes_not_in_subtree, node_id, settings)
                    self.node_info_new[node_id] = (feat_id, split, idx_split_global)
                    self.evaluate_new_subtree(data, node_id, param, nodes_subtree, cache, settings, databc, cachebc)
                    # log_acc will be be modified below
                    log_acc_tmp, loglik_diff, logprior_diff = self.compute_log_acc_cs(nodes_subtree, node_id)
                    if settings.debug == 1:
                        self.check_if_same(log_acc_tmp, loglik_diff, logprior_diff)
                    # Adjust by node_id's own current-minus-proposed log prior.
                    # NOTE(review): presumably because the new rule was drawn from the
                    # prior, so that term cancels with the proposal -- confirm.
                    log_acc = log_acc_tmp + self.logprior[node_id] - self.logprior_new[node_id]
                    log_r = np.log(np.random.rand(1))
                    if log_r <= log_acc:
                        self.node_info[node_id] = copy(self.node_info_new[node_id])
                        self.update_subtree(node_id, nodes_subtree, settings)
                        change = True
                    else:
                        change = False
        elif step_id == 3:      # SWAP: exchange split rules of a parent/child pair
            if not self.inner_pc_pairs:
                change = False
            else:
                node_id, child_id = random.choice(self.inner_pc_pairs)
                nodes_subtree = self.get_nodes_subtree(node_id)
                nodes_not_in_subtree = self.get_nodes_not_in_subtree(node_id)
                if settings.debug == 1:
                    set1 = set(list(nodes_subtree) + list(nodes_not_in_subtree))
                    set2 = set(self.leaf_nodes + self.non_leaf_nodes)
                    assert(sorted(set1) == sorted(set2))
                self.create_new_statistics(nodes_subtree, nodes_not_in_subtree, node_id, settings)
                self.node_info_new[node_id] = copy(self.node_info[child_id])
                self.node_info_new[child_id] = copy(self.node_info[node_id])
                if settings.verbose >= 1:
                    print('swap: node_id = %s, child_id = %s' % (node_id, child_id))
                    print('node_info[node_id] = %s, node_info[child_id] = %s'
                          % (self.node_info[node_id], self.node_info[child_id]))
                self.evaluate_new_subtree(data, node_id, param, nodes_subtree, cache, settings, databc, cachebc)
                log_acc, loglik_diff, logprior_diff = self.compute_log_acc_cs(nodes_subtree, node_id)
                if settings.debug == 1:
                    self.check_if_same(log_acc, loglik_diff, logprior_diff)
                log_r = np.log(np.random.rand(1))
                if log_r <= log_acc:
                    self.node_info[node_id] = copy(self.node_info_new[node_id])
                    self.node_info[child_id] = copy(self.node_info_new[child_id])
                    self.update_subtree(node_id, nodes_subtree, settings)
                    change = True
                else:
                    change = False
        if settings.verbose >= 1:
            print('trying move: step_id = %d, move = %s, log_acc = %s, log_r = %s'
                  % (step_id, STEP_NAMES[step_id], log_acc, log_r))
        if change:
            # Refresh tree-level summaries after an accepted move.
            self.depth = max([get_depth(node_id) for node_id in
                              self.leaf_nodes])
            self.loglik_current = sum([self.loglik[node_id] for node_id in
                                       self.leaf_nodes])
            if settings.verbose >= 1:
                print('accepted move: step_id = %d, move = %s' % (step_id, STEP_NAMES[step_id]))
                self.print_stuff()
        if settings.debug == 1:
            # Cross-check the incrementally maintained MCMC structures against a
            # from-scratch recomputation.
            both_children_terminal, inner_pc_pairs = self.recompute_mcmc_data_structures()
            print('\nstats from recompute_mcmc_data_structures')
            print('both_children_terminal = %s' % both_children_terminal)
            print('inner_pc_pairs = %s' % inner_pc_pairs)
            assert(sorted(both_children_terminal) == sorted(self.both_children_terminal))
            assert(sorted(inner_pc_pairs) == sorted(self.inner_pc_pairs))
            grow_nodes_new = [n_id for n_id in self.leaf_nodes
                              if not stop_split(self.train_ids[n_id], settings, data, cache)]
            if change and (step_id == 1):
                print('grow_nodes_new = %s, grow_nodes_tmp = %s' % (sorted(grow_nodes_new), sorted(grow_nodes_tmp)))
                assert(sorted(grow_nodes_new) == sorted(grow_nodes_tmp))
        return (change, step_id)
コード例 #4
0
ファイル: treemcmc.py プロジェクト: Sandy4321/pgbart
 def sample(self, data, settings, param, cache):
     """Propose one MCMC move on this tree and accept or reject it in place.

     The move type ``step_id`` is drawn according to ``settings.mcmc_type``:
     ``'growprune'`` picks GROW (0) or PRUNE (1) uniformly; ``'cgm'`` picks
     uniformly among GROW (0), PRUNE (1), CHANGE (2) and SWAP (3).  Each move
     computes a log acceptance ratio ``log_acc`` and is accepted when
     ``log(u) <= log_acc`` with ``u`` drawn by ``np.random.rand(1)`` (a
     length-1 array).  Accepted moves mutate the tree's node lists and
     statistics directly.

     Returns:
         (change, step_id): ``change`` is True iff the proposed move was
         accepted; ``step_id`` is the index of the move that was attempted.

     Raises:
         Exception: if ``settings.mcmc_type`` is neither 'growprune' nor 'cgm'.
     """
     if settings.mcmc_type == 'growprune':
         step_id = random.randint(0, 1)  # only grow and prune moves permitted
     elif settings.mcmc_type == 'cgm':
         step_id = random.randint(0, 3)  # all 4 moves equally likely (or think of 50% grow/prune, 25% change, 25% swap)
     else:
         raise Exception('invalid mcmc_type')
     log_acc = -np.inf
     log_r = 0
     # Leaves that are still eligible to be split (stop_split is False for them).
     self.grow_nodes = [n_id for n_id in self.leaf_nodes \
                 if not stop_split(self.train_ids[n_id], settings, data, cache)]
     grow_nodes = self.grow_nodes
     if step_id == 0:        # GROW step
         if not grow_nodes:
             change = False
         else:
             node_id = random.choice(grow_nodes)
             if settings.verbose >= 1:
                 print 'grow_nodes = %s, chosen node_id = %s' % (grow_nodes, node_id)
             do_not_split_node_id, feat_id, split, idx_split_global, logprior_nodeid = \
                     self.sample_split_prior(data, param, settings, cache, node_id)
             # node_id was taken from grow_nodes, so a split is presumably always
             # available here. NOTE(review): this check is removed under `python -O`.
             assert not do_not_split_node_id
             if settings.verbose >= 1:
                 print 'grow: do_not_split = %s, feat_id = %s, split = %s' \
                         % (do_not_split_node_id, feat_id, split)
             train_ids = self.train_ids[node_id]
             (train_ids_left, train_ids_right, cache_tmp, loglik_left, loglik_right) = \
                 compute_left_right_statistics(data, param, cache, train_ids, \
                     feat_id, split, settings)
             loglik = loglik_left + loglik_right
             # Size of both_children_terminal after the grow: node_id joins it, and
             # its parent leaves it only if the sibling is currently a leaf.
             len_both_children_terminal_new = len(self.both_children_terminal)
             if get_sibling_id(node_id) not in self.leaf_nodes:
                 len_both_children_terminal_new += 1
             log_acc = self.compute_log_acc_g(node_id, param, len_both_children_terminal_new, \
                         loglik, train_ids_left, train_ids_right, cache, settings, data, grow_nodes)
             log_r = np.log(np.random.rand(1))
             if log_r <= log_acc:
                 self.update_left_right_statistics(cache_tmp, node_id, logprior_nodeid, \
                         train_ids_left, train_ids_right, loglik_left, loglik_right, \
                         feat_id, split, idx_split_global, settings, param, data, cache)
                 # MCMC specific data structure updates
                 self.both_children_terminal.append(node_id)
                 parent = get_parent_id(node_id)
                 if (node_id != 0) and (parent in self.non_leaf_nodes):
                     self.inner_pc_pairs.append((parent, node_id))
                 sibling = get_sibling_id(node_id)
                 # The parent only had two terminal children while node_id was a leaf.
                 if sibling in self.leaf_nodes:
                     self.both_children_terminal.remove(parent)
                 change = True
             else:
                 change = False
     elif step_id == 1:      # PRUNE step
         if not self.both_children_terminal:
             change = False      # nothing to prune here
         else:
             node_id = random.choice(self.both_children_terminal)
             feat_id = self.node_info[node_id][0]   # only used for the verbose message
             if settings.verbose >= 1:
                 print 'prune: node_id = %s, feat_id = %s' % (node_id, feat_id)
             left, right = get_children_id(node_id)
             loglik = self.loglik[left] + self.loglik[right]
             len_both_children_new = len(self.both_children_terminal)
             # Grow candidates as they would be after the prune: node_id becomes a
             # leaf again and its two children disappear.  The debug block at the
             # end asserts this matches a fresh recomputation.
             grow_nodes_tmp = grow_nodes[:]
             grow_nodes_tmp.append(node_id)
             try:
                 grow_nodes_tmp.remove(left)
             except ValueError:
                 pass
             try:
                 grow_nodes_tmp.remove(right)
             except ValueError:
                 pass
             # Negated value of compute_log_inv_acc_p (by its name, the log of the
             # inverse of the prune acceptance ratio).
             log_acc = - self.compute_log_inv_acc_p(node_id, param, len_both_children_new, \
                             loglik, grow_nodes_tmp, cache, settings, data)
             log_r = np.log(np.random.rand(1))
             if log_r <= log_acc:
                 self.remove_leaf_node_statistics(left, settings)
                 self.remove_leaf_node_statistics(right, settings)
                 self.leaf_nodes.append(node_id)
                 self.non_leaf_nodes.remove(node_id)
                 self.logprior[node_id] = np.log(self.compute_pnosplit(node_id, param))
                 # OK to set logprior as above since we know that a valid split exists
                 # MCMC specific data structure updates
                 self.both_children_terminal.remove(node_id)
                 parent = get_parent_id(node_id)
                 if (node_id != 0) and (parent in self.non_leaf_nodes):
                     self.inner_pc_pairs.remove((parent, node_id))
                 if node_id != 0:
                     sibling = get_sibling_id(node_id)
                     # With node_id now a leaf, the parent regains two terminal
                     # children if the sibling is a leaf too.
                     if sibling in self.leaf_nodes:
                         if settings.debug == 1:
                             assert(parent not in self.both_children_terminal)
                         self.both_children_terminal.append(parent)
                 change = True
             else:
                 change = False
     elif step_id == 2:      # CHANGE
         if not self.non_leaf_nodes:
             change = False
         else:
             node_id = random.choice(self.non_leaf_nodes)
             do_not_split_node_id, feat_id, split, idx_split_global, logprior_nodeid = \
                     self.sample_split_prior(data, param, settings, cache, node_id)
             if settings.verbose >= 1:
                 print 'change: node_id = %s, do_not_split = %s, feat_id = %s, split = %s' \
                         % (node_id, do_not_split_node_id, feat_id, split)
             # Note: this just samples a split criterion, not guaranteed to "change"
             assert(not do_not_split_node_id)
             nodes_subtree = self.get_nodes_subtree(node_id)
             nodes_not_in_subtree = self.get_nodes_not_in_subtree(node_id)
             if settings.debug == 1:
                 set1 = set(list(nodes_subtree) + list(nodes_not_in_subtree))
                 set2 = set(self.leaf_nodes + self.non_leaf_nodes)
                 assert(sorted(set1) == sorted(set2))
             # Stage the proposed rule in the *_new structures and re-evaluate the
             # subtree under it; node_info is only overwritten on acceptance.
             self.create_new_statistics(nodes_subtree, nodes_not_in_subtree, node_id, settings)
             self.node_info_new[node_id] = (feat_id, split, idx_split_global)
             self.evaluate_new_subtree(data, node_id, param, nodes_subtree, cache, settings)
             # log_acc will be be modified below
             log_acc_tmp, loglik_diff, logprior_diff = self.compute_log_acc_cs(nodes_subtree, node_id)
             if settings.debug == 1:
                 self.check_if_same(log_acc_tmp, loglik_diff, logprior_diff)
             # Adjust by node_id's own current-minus-proposed log prior.
             # NOTE(review): presumably because the new rule was drawn from the
             # prior, so that term cancels with the proposal -- confirm.
             log_acc = log_acc_tmp + self.logprior[node_id] - self.logprior_new[node_id]
             log_r = np.log(np.random.rand(1))
             if log_r <= log_acc:
                 self.node_info[node_id] = copy(self.node_info_new[node_id])
                 self.update_subtree(node_id, nodes_subtree, settings)
                 change = True
             else:
                 change = False
     elif step_id == 3:      # SWAP
         if not self.inner_pc_pairs:
             change = False
         else:
             node_id, child_id = random.choice(self.inner_pc_pairs)
             nodes_subtree = self.get_nodes_subtree(node_id)
             nodes_not_in_subtree = self.get_nodes_not_in_subtree(node_id)
             if settings.debug == 1:
                 set1 = set(list(nodes_subtree) + list(nodes_not_in_subtree))
                 set2 = set(self.leaf_nodes + self.non_leaf_nodes)
                 assert(sorted(set1) == sorted(set2))
             # Stage the parent/child rule exchange in node_info_new and
             # re-evaluate the parent's subtree under it.
             self.create_new_statistics(nodes_subtree, nodes_not_in_subtree, node_id, settings)
             self.node_info_new[node_id] = copy(self.node_info[child_id])
             self.node_info_new[child_id] = copy(self.node_info[node_id])
             if settings.verbose >= 1:
                 print 'swap: node_id = %s, child_id = %s' % (node_id, child_id)
                 print 'node_info[node_id] = %s, node_info[child_id] = %s' \
                         % (self.node_info[node_id], self.node_info[child_id])
             self.evaluate_new_subtree(data, node_id, param, nodes_subtree, cache, settings)
             log_acc, loglik_diff, logprior_diff = self.compute_log_acc_cs(nodes_subtree, node_id)
             if settings.debug == 1:
                 self.check_if_same(log_acc, loglik_diff, logprior_diff)
             log_r = np.log(np.random.rand(1))
             if log_r <= log_acc:
                 self.node_info[node_id] = copy(self.node_info_new[node_id])
                 self.node_info[child_id] = copy(self.node_info_new[child_id])
                 self.update_subtree(node_id, nodes_subtree, settings)
                 change = True
             else:
                 change = False
     if settings.verbose >= 1:
         print 'trying move: step_id = %d, move = %s, log_acc = %s, log_r = %s' \
                 % (step_id, STEP_NAMES[step_id], log_acc, log_r)
     if change:
         # Refresh tree-level summaries (depth, total leaf log-likelihood) after
         # an accepted move.
         self.depth = max([get_depth(node_id) for node_id in \
                 self.leaf_nodes])
         self.loglik_current = sum([self.loglik[node_id] for node_id in \
                 self.leaf_nodes])
         if settings.verbose >= 1:
             print 'accepted move: step_id = %d, move = %s' % (step_id, STEP_NAMES[step_id])
             self.print_stuff()
     if settings.debug == 1:
         # Cross-check the incrementally maintained MCMC structures against a
         # from-scratch recomputation.
         both_children_terminal, inner_pc_pairs = self.recompute_mcmc_data_structures()
         print '\nstats from recompute_mcmc_data_structures'
         print 'both_children_terminal = %s' % both_children_terminal
         print 'inner_pc_pairs = %s' % inner_pc_pairs
         assert(sorted(both_children_terminal) == sorted(self.both_children_terminal))
         assert(sorted(inner_pc_pairs) == sorted(self.inner_pc_pairs))
         grow_nodes_new = [n_id for n_id in self.leaf_nodes \
                 if not stop_split(self.train_ids[n_id], settings, data, cache)]
         # grow_nodes_tmp only exists for PRUNE, which is the only case checked.
         if change and (step_id == 1):
             print 'grow_nodes_new = %s, grow_nodes_tmp = %s' % (sorted(grow_nodes_new), sorted(grow_nodes_tmp))
             assert(sorted(grow_nodes_new) == sorted(grow_nodes_tmp))
     return (change, step_id)