Exemplo n.º 1
0
def add_nontriv_splits_attr(tm, all_taxa_bitmask):
    """Attach the sorted non-trivial splits of ``tm.tree`` to ``tm``.

    Every key of ``tm.tree.split_edges`` for which ``is_trivial_split``
    (evaluated against ``all_taxa_bitmask``) is false is kept; the kept
    splits are sorted ascending and stored on ``tm.splits`` as a tuple.
    """
    split_keys = tm.tree.split_edges.keys()
    kept = sorted(s for s in split_keys
                  if not is_trivial_split(s, all_taxa_bitmask))
    tm.splits = tuple(kept)
Exemplo n.º 2
0
def add_nontriv_splits_attr(tree, all_taxa_bitmask):
    """Store the canonically oriented non-trivial splits of ``tree`` on it.

    Each key of ``tree.split_edges`` that ``is_trivial_split`` rejects is
    skipped. The rest are oriented so that bit 0 is set: a split that already
    has bit 0 is kept, otherwise its complement within ``all_taxa_bitmask`` is
    used (assumes bit 0 belongs to the mask). The oriented splits are sorted
    and exposed as ``tree.splits`` (tuple) and ``tree.split_set`` (set).
    """
    normalized = []
    for split in tree.split_edges.keys():
        if is_trivial_split(split, all_taxa_bitmask):
            continue
        if not (split & 1):
            # Flip to the complementary side so both orientations compare equal.
            split = (~split) & all_taxa_bitmask
        normalized.append(split)
    normalized.sort()
    tree.splits = tuple(normalized)
    tree.split_set = set(normalized)
def add_nontriv_splits_attr(tree, all_taxa_bitmask):
    """Record ``tree``'s non-trivial splits, oriented to include bit 0.

    Splits flagged by ``is_trivial_split`` are dropped. A split without bit 0
    is replaced by its complement inside ``all_taxa_bitmask``. The sorted
    result is written to ``tree.splits`` (tuple) and ``tree.split_set`` (set).
    """
    def _oriented(split):
        # Splits lacking bit 0 are represented by their complement.
        if split & 1:
            return split
        return (~split) & all_taxa_bitmask

    candidates = tree.split_edges.keys()
    oriented = [_oriented(s) for s in candidates
                if not is_trivial_split(s, all_taxa_bitmask)]
    oriented.sort()
    tree.splits = tuple(oriented)
    tree.split_set = set(oriented)
Exemplo n.º 4
0
    def testSplits(self):
        """Compare dendropy split counts with PAUP*'s for every test case.

        For each ``(trees file, taxa file)`` entry of ``test_cases``,
        ``paup.get_split_distribution`` produces a reference distribution
        (unrooted, no burn-in). dendropy then counts splits over the same
        trees. The total tree counts must agree, every non-trivial dendropy
        split must appear in PAUP*'s distribution with the same count, and no
        PAUP* split may be left unmatched.
        """
        unrooted = True
        for tc in test_cases:
            tree_filepaths = [dendropy.tests.data_source_path(tc[0])]
            taxa_filepath = dendropy.tests.data_source_path(tc[1])
            paup_sd = paup.get_split_distribution(tree_filepaths,
                                                  taxa_filepath,
                                                  unrooted=unrooted,
                                                  burnin=0)
            taxa_block = paup_sd.taxa_block
            dp_sd = splits.SplitDistribution(taxa_block=taxa_block)
            dp_sd.ignore_edge_lengths = True
            dp_sd.ignore_node_ages = True
            dp_sd.unrooted = unrooted

            taxa_block.lock()
            for tree_filepath in tree_filepaths:
                # Close each tree file once it has been consumed rather than
                # leaving the handle for the garbage collector.
                tree_file = open(tree_filepath, "rU")
                try:
                    for tree in nexus.iterate_over_trees(tree_file,
                                                         taxa_block=taxa_block):
                        splits.encode_splits(tree)
                        dp_sd.count_splits_on_tree(tree)
                finally:
                    tree_file.close()

            self.assertEqual(dp_sd.total_trees_counted,
                             paup_sd.total_trees_counted)

            # SplitDistribution counts trivial splits, whereas PAUP* contree
            # does not, so the total numbers of splits cannot be compared
            # directly; only non-trivial splits are matched below.
            taxa_mask = taxa_block.all_taxa_bitmask()
            for split in dp_sd.splits:
                if not splits.is_trivial_split(split, taxa_mask):
                    self.assertTrue(split in paup_sd.splits)
                    self.assertEqual(dp_sd.split_counts[split],
                                     paup_sd.split_counts[split])
                    paup_sd.splits.remove(split)

            # Any splits remaining here were not found in dp_sd. Use the
            # unittest assertion (not a bare assert, which -O strips).
            self.assertEqual(len(paup_sd.splits), 0)
def main_cli():

    description =  '%s %s ' % (_program_name, _program_version)
    usage = "%prog [options] <TREES FILE> [<TREES FILE> [<TREES FILE> [...]]"

    parser = OptionParser(usage=usage, add_help_option=True, version = _program_version, description=description)
    parser.add_option('-r','--reference',
                  dest='reference_tree_filepath',
                  default=None,
                  help="path to file containing the reference (true) tree")
    parser.add_option('-v', '--verbose',
                      action='store_false',
                      dest='quiet',
                      default=True,
                      help="Verbose mode")

    (opts, args) = parser.parse_args()

    ###################################################
    # Support file idiot checking

    sampled_filepaths = []
    missing = False
    for fpath in args:
        fpath = os.path.expanduser(os.path.expandvars(fpath))
        if not os.path.exists(fpath):
            sys.exit('Sampled trees file not found: "%s"' % fpath)
        sampled_filepaths.append(fpath)
    if not sampled_filepaths:
        sys.exit("Expecting arguments indicating files that contain sampled trees")

    sampled_file_objs = [open(f, "rU") for f in sampled_filepaths]

    ###################################################
    # Lots of other idiot-checking ...

    # target tree
    if opts.reference_tree_filepath is None:
        sys.exit("A reference tree must be specified (use -h to see all options)")
    reference_tree_filepath = os.path.expanduser(os.path.expandvars(opts.reference_tree_filepath))
    if not os.path.exists(reference_tree_filepath):
        sys.exit('Reference tree file not found: "%s"\n' % reference_tree_filepath)

    d = Dataset()
    ref_trees  = d.read_trees(open(reference_tree_filepath, 'ru'), schema="NEXUS")

    if len(ref_trees) != 1:
        sys.exit("Expecting one reference tree")
    ref_tree = ref_trees[0]
    splits.encode_splits(ref_tree)
    assert(len(d.taxa_blocks) == 1)
    taxa = d.taxa_blocks[0]


    ###################################################
    # Main work begins here: Count the splits

    start_time = datetime.datetime.now()

    comments = []
    tsum = treesum.TreeSummarizer()
    tsum.burnin = 0
    if opts.quiet:
        tsum.verbose = False
        tsum.write_message = None
    else:
        tsum.verbose = True
        tsum.write_message = sys.stderr.write




    _LOG.debug("### COUNTING SPLITS ###\n")
    split_distribution = splits.SplitDistribution(taxa_block=taxa)
    tree_source = MultiFileTreeIterator(filepaths=sampled_filepaths, core_iterator=nexus.iterate_over_trees)
    tsum.count_splits_on_trees(tree_source, split_distribution)

    report = []
    report.append("%d trees read from %d files." % (tsum.total_trees_read, len(sampled_filepaths)))
    report.append("%d trees ignored in total." % (tree_source.total_trees_ignored))
    report.append("%d trees considered in total for split support assessment." % (tsum.total_trees_counted))
    report.append("%d unique taxa across all trees." % len(split_distribution.taxa_block))
    num_splits, num_unique_splits, num_nt_splits, num_nt_unique_splits = split_distribution.splits_considered()
    report.append("%d unique splits out of %d total splits counted." % (num_unique_splits, num_splits))
    report.append("%d unique non-trivial splits out of %d total non-trivial splits counted." % (num_nt_unique_splits, num_nt_splits))

    _LOG.debug("\n".join(report))


    con_tree = treegen.star_tree(taxa)
    taxa_mask = taxa.all_taxa_bitmask()
    splits.encode_splits(con_tree)
    leaves = con_tree.leaf_nodes()

    to_leaf_dict = {}
    for leaf in leaves:
        to_leaf_dict[leaf.edge.clade_mask] = leaf
    unrooted = True
    n_read = float(tsum.total_trees_read)
    sp_list = []
    for split, count in split_distribution.split_counts.iteritems():
        freq = count/n_read
        if not splits.is_trivial_split(split, taxa_mask):
            m = split & taxa_mask
            if (m != taxa_mask) and ((m-1) & m): # if not root (i.e., all "1's") and not singleton (i.e., one "1")
                if unrooted:
                    c = (~m) & taxa_mask
                    if (c-1) & c: # not singleton (i.e., one "0")
                        if 1 & m:
                            k = c
                        else:
                            k = m
                        sp_list.append((freq, k, m))
                else:
                    sp_list.append((freq, m, m))
    sp_list.sort(reverse=True)

    root = con_tree.seed_node
    root_edge = root.edge

    curr_freq = 1.1
    curr_all_splits_list = []
    curr_compat_splits_list = []
    all_splits_by_freq = []
    compat_splits_by_freq = []

    # Now when we add splits in order, we will do a greedy, extended majority-rule consensus tree
    for freq, split_to_add, split_in_dict in sp_list:
        if abs(curr_freq-freq) > 0.000001:
            # dropping down to the next lowest freq
            curr_l = [freq, []]
            curr_all_splits_list = curr_l[1]
            all_splits_by_freq.append(curr_l)
            curr_l = [freq, []]
            curr_compat_splits_list = curr_l[1]
            compat_splits_by_freq.append(curr_l)
            curr_freq = freq

        curr_all_splits_list.append(split_to_add)

        if (split_to_add & root_edge.clade_mask) != split_to_add:
            continue
        lb = splits.lowest_bit_only(split_to_add)
        one_leaf = to_leaf_dict[lb]
        parent_node = one_leaf
        while (split_to_add & parent_node.edge.clade_mask) != split_to_add:
            parent_node = parent_node.parent_node
        if parent_node is None or parent_node.edge.clade_mask == split_to_add:
            continue # split is not in tree, or already in tree.

        new_node = trees.Node()
        new_node_children = []
        new_edge = new_node.edge
        new_edge.clade_mask = 0
        for child in parent_node.child_nodes():
            # might need to modify the following if rooted splits
            # are used
            cecm = child.edge.clade_mask
            if (cecm & split_to_add ):
                assert cecm != split_to_add
                new_edge.clade_mask |= cecm
                new_node_children.append(child)
        # Check to see if we have accumulated all of the bits that we
        #   needed, but none that we don't need.
        if new_edge.clade_mask == split_to_add:
            for child in new_node_children:
                parent_node.remove_child(child)
                new_node.add_child(child)
            parent_node.add_child(new_node)
            con_tree.split_edges[split_to_add] = new_edge
            curr_compat_splits_list.append(split_to_add)
    ref_set = set()
    for s in ref_tree.split_edges.iterkeys():
        m = s & taxa_mask
        if 1 & m:
            k = (~m) & taxa_mask
        else:
            k = m
        if not splits.is_trivial_split(k, taxa_mask):
            ref_set.add(k)

    all_set = set()
    compat_set = set()

    _LOG.debug("%d edges is the reference tree" % (len(ref_set)))

    print "freq\tcompatFP\tcompatFN\tcompatSD\tallFP\tallFN\tallSD"
    for all_el, compat_el in itertools.izip(all_splits_by_freq, compat_splits_by_freq):
        freq = all_el[0]
        all_sp = all_el[1]
        all_set.update(all_sp)
        all_fn = len(ref_set - all_set)
        all_fp = len(all_set - ref_set)
        compat_sp = compat_el[1]
        compat_set.update(compat_sp)
        compat_fn = len(ref_set - compat_set)
        compat_fp = len(compat_set - ref_set)

        print "%f\t%d\t%d\t%d\t%d\t%d\t%d" % (freq, compat_fp, compat_fn, compat_fp + compat_fn, all_fp, all_fn, all_fp + all_fn )
Exemplo n.º 6
0
def add_nontriv_splits_attr(tm, all_taxa_bitmask):
    """Set ``tm.splits`` to the sorted tuple of ``tm.tree``'s non-trivial splits.

    A split key from ``tm.tree.split_edges`` is non-trivial when
    ``is_trivial_split(key, all_taxa_bitmask)`` is false.
    """
    nontrivial = []
    for candidate in tm.tree.split_edges.keys():
        if not is_trivial_split(candidate, all_taxa_bitmask):
            nontrivial.append(candidate)
    tm.splits = tuple(sorted(nontrivial))
Exemplo n.º 7
0
    def select_trees_for_next_round(self, culled, curr_results):
        """Choose which result trees are carried over to the next round.

        ``curr_results`` is a non-empty list of tree models (objects exposing
        ``.tree`` and ``.score`` and sortable among themselves). ``culled`` is
        the dataset passed to ``find_best_conflicting`` and ``score_tree``.

        Steps:
          1. Drop duplicate topologies. After a reverse sort, the first copy
             of each topology is kept (the best-scoring one, per the sort
             order).
          2. Keep the first remaining tree as the ML estimate.
          3. For every non-trivial ML split, keep the best tree that lacks
             it. For splits that no tree lacks ("unanimous"), take the best
             tree from ``find_best_conflicting`` and add the rest to the
             candidate pool.
          4. Fill up to ``self.max_trees_carried_over`` trees with unique
             topologies from the pool. When the pool has more candidates
             than free slots, it is first ranked by
             ``split_diversity_multiplier * tree_split_rarity + score``.
          5. Re-score every retained tree with ``self.score_tree``.

        The caller's ``curr_results`` list is sorted in place. Returns the
        list of retained, re-scored tree models.
        """
        all_taxa_bitmask = curr_results[0].tree.seed_node.edge.clade_mask
        assert all_taxa_bitmask == ((1 << self.curr_n_taxa) - 1)
        ########################################
        # First, we make sure that there are not duplicate topologies
        # Because we reverse sort, we'll be retaining the tree with the
        #   best score
        #####
        curr_results.sort(reverse=True)
        set_of_split_sets = set()
        unique_topos = []
        for tm in curr_results:
            add_nontriv_splits_attr(tm, all_taxa_bitmask)
            if tm.splits not in set_of_split_sets:
                unique_topos.append(tm)
                set_of_split_sets.add(tm.splits)
        curr_results = unique_topos
        set_of_split_sets.clear()
        _LOG.info('There were %d unique result topologiesfor ntax = %d ' % (len(curr_results), self.curr_n_taxa))

        ########################################
        # The best remaining tree is the ML estimate and is always kept.
        #####
        ml_est = curr_results.pop(0)
        next_round_trees = [ml_est]

        ########################################
        # Now we identify best trees that LACK the splits in the ML tree
        #####
        ml_split_dict = ml_est.tree.split_edges
        unanimous_splits = []
        best_disagreeing_index_set = set()
        for split in ml_split_dict.iterkeys():
            if not is_trivial_split(split, all_taxa_bitmask):
                found = False
                for n, tm in enumerate(curr_results):
                    if split not in tm.tree.split_edges:
                        best_disagreeing_index_set.add(n)
                        found = True
                        break
                if not found:
                    unanimous_splits.append(split)

        ########################################
        # Now we add the trees that "must" be included because they do NOT
        #   have a split that is in the current ML tree. Popping indices in
        #   descending order keeps the remaining indices valid.
        #####
        for tree_ind in sorted(best_disagreeing_index_set, reverse=True):
            next_round_trees.append(curr_results.pop(tree_ind))

        ########################################
        # Now we have to augment our list of trees such that we have exemplar trees
        #   that conflict with every split in the ML tree
        # We'll do this by starting from a version of the ML tree that has been
        #   collapsed so that it does not conflict with the split
        #####
        _LOG.info('There were %d unanimous splits in the curr_results for ntax = %d ' % (len(unanimous_splits), self.curr_n_taxa))
        for split in unanimous_splits:
            best_conflicting = self.find_best_conflicting(starting_tree=ml_est, split=split, dataset=culled)
            for b in best_conflicting:
                add_nontriv_splits_attr(b, all_taxa_bitmask)

            best_conflicting.sort(reverse=True)
            next_round_trees.append(best_conflicting[0])
            curr_results.extend(best_conflicting[1:])

        ########################################
        # To keep the remaining trees in next_round_trees diverse we will
        #   try to add trees that maximize a score which is:
        #       lambda*tree_split_rarity + lnL
        #   where lambda is a tuning parameter and tree_split_rarity is:
        #       n_tree_times_splits = num_trees_in_next_round_trees * num_splits_per_tree
        #       split_occurrence = num_trees_in_next_round_trees_that_have_split
        #       tree_split_rarity = n_tree_times_splits - SUM split_occurrence
        #   in which the summation is taken over all splits in the tree
        #####
        max_len = self.max_trees_carried_over
        num_trees_to_add = max_len - len(next_round_trees)
        # Ranking only matters when there are more candidates than free
        # slots; otherwise every candidate is taken regardless of order.
        if num_trees_to_add < len(curr_results):
            split_count = {}
            n_tree_times_splits = 0
            for tm in next_round_trees:
                for k in tm.splits:
                    split_count[k] = split_count.get(k, 0) + 1
                    n_tree_times_splits += 1
            for tm in curr_results:
                tm.tree_split_rarity = n_tree_times_splits
                for split in tm.splits:
                    tm.tree_split_rarity -= split_count.get(split, 0)

            def split_diversity_cmp(x, y, lambda_mult=self.split_diversity_multiplier):
                # Also caches retention_score on each tree model compared.
                x.retention_score = lambda_mult * x.tree_split_rarity + x.score
                y.retention_score = lambda_mult * y.tree_split_rarity + y.score
                return cmp(x.retention_score, y.retention_score)
            curr_results.sort(cmp=split_diversity_cmp, reverse=True)

        ########################################
        # We are now going to try to add elements (in order) from curr_results
        #   until we run out of trees to add or we reach max_len, skipping
        #   topologies that are already retained.
        #####
        set_of_split_sets.clear()
        for tm in next_round_trees:
            set_of_split_sets.add(tm.splits)

        n_added = 0
        for tm in curr_results:
            if n_added >= num_trees_to_add:
                break
            if tm.splits not in set_of_split_sets:
                set_of_split_sets.add(tm.splits)
                next_round_trees.append(tm)
                n_added += 1
        _LOG.info('Added %d trees that were not "required" to guarantee that no splits were unanimous for ntax = %d' % (n_added, self.curr_n_taxa))

        ########################################
        # the trees can be hefty, so lets free unneeded memory
        #####
        del curr_results[:]

        ########################################
        # Finally, lets get a decent score for each tree before moving to the next round
        #   because the trees are big, we'll replace each element rather
        #   than allowing a duplicate list to be created.
        #####
        # NOTE(review): ``n`` is the index variable left over from the
        #   split-scanning loop above, so it holds whatever index that loop
        #   last visited. It is unbound (UnboundLocalError) if that loop body
        #   never ran, e.g. no non-trivial ML splits or no other results.
        #   Confirm which value score_tree expects here.
        for i, tm in enumerate(next_round_trees):
            next_round_trees[i] = self.score_tree(tm, culled, n, self.tree_scoring_stop_gen)

        _LOG.info('A total of %d trees were retained for ntax = %d lnL range from %f to %f' % (len(next_round_trees), self.curr_n_taxa, next_round_trees[0].score, next_round_trees[-1].score))

        return next_round_trees
Exemplo n.º 8
0
def main_cli():

    description = '%s %s ' % (_program_name, _program_version)
    usage = "%prog [options] <TREES FILE> [<TREES FILE> [<TREES FILE> [...]]"

    parser = OptionParser(usage=usage,
                          add_help_option=True,
                          version=_program_version,
                          description=description)
    parser.add_option('-r',
                      '--reference',
                      dest='reference_tree_filepath',
                      default=None,
                      help="path to file containing the reference (true) tree")
    parser.add_option('-v',
                      '--verbose',
                      action='store_false',
                      dest='quiet',
                      default=True,
                      help="Verbose mode")

    (opts, args) = parser.parse_args()

    ###################################################
    # Support file idiot checking

    sampled_filepaths = []
    missing = False
    for fpath in args:
        fpath = os.path.expanduser(os.path.expandvars(fpath))
        if not os.path.exists(fpath):
            sys.exit('Sampled trees file not found: "%s"' % fpath)
        sampled_filepaths.append(fpath)
    if not sampled_filepaths:
        sys.exit(
            "Expecting arguments indicating files that contain sampled trees")

    sampled_file_objs = [open(f, "rU") for f in sampled_filepaths]

    ###################################################
    # Lots of other idiot-checking ...

    # target tree
    if opts.reference_tree_filepath is None:
        sys.exit(
            "A reference tree must be specified (use -h to see all options)")
    reference_tree_filepath = os.path.expanduser(
        os.path.expandvars(opts.reference_tree_filepath))
    if not os.path.exists(reference_tree_filepath):
        sys.exit('Reference tree file not found: "%s"\n' %
                 reference_tree_filepath)

    d = Dataset()
    ref_trees = d.read_trees(open(reference_tree_filepath, 'ru'),
                             schema="NEXUS")

    if len(ref_trees) != 1:
        sys.exit("Expecting one reference tree")
    ref_tree = ref_trees[0]
    splits.encode_splits(ref_tree)
    assert (len(d.taxa_blocks) == 1)
    taxa = d.taxa_blocks[0]

    ###################################################
    # Main work begins here: Count the splits

    start_time = datetime.datetime.now()

    comments = []
    tsum = treesum.TreeSummarizer()
    tsum.burnin = 0
    if opts.quiet:
        tsum.verbose = False
        tsum.write_message = None
    else:
        tsum.verbose = True
        tsum.write_message = sys.stderr.write

    _LOG.debug("### COUNTING SPLITS ###\n")
    split_distribution = splits.SplitDistribution(taxa_block=taxa)
    tree_source = MultiFileTreeIterator(filepaths=sampled_filepaths,
                                        core_iterator=nexus.iterate_over_trees)
    tsum.count_splits_on_trees(tree_source, split_distribution)

    report = []
    report.append("%d trees read from %d files." %
                  (tsum.total_trees_read, len(sampled_filepaths)))
    report.append("%d trees ignored in total." %
                  (tree_source.total_trees_ignored))
    report.append(
        "%d trees considered in total for split support assessment." %
        (tsum.total_trees_counted))
    report.append("%d unique taxa across all trees." %
                  len(split_distribution.taxa_block))
    num_splits, num_unique_splits, num_nt_splits, num_nt_unique_splits = split_distribution.splits_considered(
    )
    report.append("%d unique splits out of %d total splits counted." %
                  (num_unique_splits, num_splits))
    report.append(
        "%d unique non-trivial splits out of %d total non-trivial splits counted."
        % (num_nt_unique_splits, num_nt_splits))

    _LOG.debug("\n".join(report))

    con_tree = treegen.star_tree(taxa)
    taxa_mask = taxa.all_taxa_bitmask()
    splits.encode_splits(con_tree)
    leaves = con_tree.leaf_nodes()

    to_leaf_dict = {}
    for leaf in leaves:
        to_leaf_dict[leaf.edge.clade_mask] = leaf
    unrooted = True
    n_read = float(tsum.total_trees_read)
    sp_list = []
    for split, count in split_distribution.split_counts.iteritems():
        freq = count / n_read
        if not splits.is_trivial_split(split, taxa_mask):
            m = split & taxa_mask
            if (m != taxa_mask) and (
                (m - 1) & m
            ):  # if not root (i.e., all "1's") and not singleton (i.e., one "1")
                if unrooted:
                    c = (~m) & taxa_mask
                    if (c - 1) & c:  # not singleton (i.e., one "0")
                        if 1 & m:
                            k = c
                        else:
                            k = m
                        sp_list.append((freq, k, m))
                else:
                    sp_list.append((freq, m, m))
    sp_list.sort(reverse=True)

    root = con_tree.seed_node
    root_edge = root.edge

    curr_freq = 1.1
    curr_all_splits_list = []
    curr_compat_splits_list = []
    all_splits_by_freq = []
    compat_splits_by_freq = []

    # Now when we add splits in order, we will do a greedy, extended majority-rule consensus tree
    for freq, split_to_add, split_in_dict in sp_list:
        if abs(curr_freq - freq) > 0.000001:
            # dropping down to the next lowest freq
            curr_l = [freq, []]
            curr_all_splits_list = curr_l[1]
            all_splits_by_freq.append(curr_l)
            curr_l = [freq, []]
            curr_compat_splits_list = curr_l[1]
            compat_splits_by_freq.append(curr_l)
            curr_freq = freq

        curr_all_splits_list.append(split_to_add)

        if (split_to_add & root_edge.clade_mask) != split_to_add:
            continue
        lb = splits.lowest_bit_only(split_to_add)
        one_leaf = to_leaf_dict[lb]
        parent_node = one_leaf
        while (split_to_add & parent_node.edge.clade_mask) != split_to_add:
            parent_node = parent_node.parent_node
        if parent_node is None or parent_node.edge.clade_mask == split_to_add:
            continue  # split is not in tree, or already in tree.

        new_node = trees.Node()
        new_node_children = []
        new_edge = new_node.edge
        new_edge.clade_mask = 0
        for child in parent_node.child_nodes():
            # might need to modify the following if rooted splits
            # are used
            cecm = child.edge.clade_mask
            if (cecm & split_to_add):
                assert cecm != split_to_add
                new_edge.clade_mask |= cecm
                new_node_children.append(child)
        # Check to see if we have accumulated all of the bits that we
        #   needed, but none that we don't need.
        if new_edge.clade_mask == split_to_add:
            for child in new_node_children:
                parent_node.remove_child(child)
                new_node.add_child(child)
            parent_node.add_child(new_node)
            con_tree.split_edges[split_to_add] = new_edge
            curr_compat_splits_list.append(split_to_add)
    ref_set = set()
    for s in ref_tree.split_edges.iterkeys():
        m = s & taxa_mask
        if 1 & m:
            k = (~m) & taxa_mask
        else:
            k = m
        if not splits.is_trivial_split(k, taxa_mask):
            ref_set.add(k)

    all_set = set()
    compat_set = set()

    _LOG.debug("%d edges is the reference tree" % (len(ref_set)))

    print "freq\tcompatFP\tcompatFN\tcompatSD\tallFP\tallFN\tallSD"
    for all_el, compat_el in itertools.izip(all_splits_by_freq,
                                            compat_splits_by_freq):
        freq = all_el[0]
        all_sp = all_el[1]
        all_set.update(all_sp)
        all_fn = len(ref_set - all_set)
        all_fp = len(all_set - ref_set)
        compat_sp = compat_el[1]
        compat_set.update(compat_sp)
        compat_fn = len(ref_set - compat_set)
        compat_fp = len(compat_set - ref_set)

        print "%f\t%d\t%d\t%d\t%d\t%d\t%d" % (freq, compat_fp, compat_fn,
                                              compat_fp + compat_fn, all_fp,
                                              all_fn, all_fp + all_fn)
Exemplo n.º 9
0
    def select_trees_for_next_round(self, culled, curr_results):
        """Choose which result trees are carried over to the next round.

        ``curr_results`` is a non-empty list of tree models (objects exposing
        ``.tree`` and ``.score`` and sortable among themselves). ``culled`` is
        the dataset passed to ``find_best_conflicting`` and ``score_tree``.

        Steps:
          1. Drop duplicate topologies. After a reverse sort, the first copy
             of each topology is kept (the best-scoring one, per the sort
             order).
          2. Keep the first remaining tree as the ML estimate.
          3. For every non-trivial ML split, keep the best tree that lacks
             it. For splits that no tree lacks ("unanimous"), take the best
             tree from ``find_best_conflicting`` and add the rest to the
             candidate pool.
          4. Fill up to ``self.max_trees_carried_over`` trees with unique
             topologies from the pool. When the pool has more candidates
             than free slots, it is first ranked by
             ``split_diversity_multiplier * tree_split_rarity + score``.
          5. Re-score every retained tree with ``self.score_tree``.

        The caller's ``curr_results`` list is sorted in place. Returns the
        list of retained, re-scored tree models.
        """
        all_taxa_bitmask = curr_results[0].tree.seed_node.edge.clade_mask
        assert all_taxa_bitmask == ((1 << self.curr_n_taxa) - 1)
        ########################################
        # First, we make sure that there are not duplicate topologies
        # Because we reverse sort, we'll be retaining the tree with the
        #   best score
        #####
        curr_results.sort(reverse=True)
        set_of_split_sets = set()
        unique_topos = []
        for tm in curr_results:
            add_nontriv_splits_attr(tm, all_taxa_bitmask)
            if tm.splits not in set_of_split_sets:
                unique_topos.append(tm)
                set_of_split_sets.add(tm.splits)
        curr_results = unique_topos
        set_of_split_sets.clear()
        _LOG.info('There were %d unique result topologiesfor ntax = %d ' %
                  (len(curr_results), self.curr_n_taxa))

        ########################################
        # The best remaining tree is the ML estimate and is always kept.
        #####
        ml_est = curr_results.pop(0)
        next_round_trees = [ml_est]

        ########################################
        # Now we identify best trees that LACK the splits in the ML tree
        #####
        ml_split_dict = ml_est.tree.split_edges
        unanimous_splits = []
        best_disagreeing_index_set = set()
        for split in ml_split_dict.iterkeys():
            if not is_trivial_split(split, all_taxa_bitmask):
                found = False
                for n, tm in enumerate(curr_results):
                    if split not in tm.tree.split_edges:
                        best_disagreeing_index_set.add(n)
                        found = True
                        break
                if not found:
                    unanimous_splits.append(split)

        ########################################
        # Now we add the trees that "must" be included because they do NOT
        #   have a split that is in the current ML tree. Popping indices in
        #   descending order keeps the remaining indices valid.
        #####
        for tree_ind in sorted(best_disagreeing_index_set, reverse=True):
            next_round_trees.append(curr_results.pop(tree_ind))

        ########################################
        # Now we have to augment our list of trees such that we have exemplar trees
        #   that conflict with every split in the ML tree
        # We'll do this by starting from a version of the ML tree that has been
        #   collapsed so that it does not conflict with the split
        #####
        _LOG.info(
            'There were %d unanimous splits in the curr_results for ntax = %d '
            % (len(unanimous_splits), self.curr_n_taxa))
        for split in unanimous_splits:
            best_conflicting = self.find_best_conflicting(starting_tree=ml_est,
                                                          split=split,
                                                          dataset=culled)
            for b in best_conflicting:
                add_nontriv_splits_attr(b, all_taxa_bitmask)

            best_conflicting.sort(reverse=True)
            next_round_trees.append(best_conflicting[0])
            curr_results.extend(best_conflicting[1:])

        ########################################
        # To keep the remaining trees in next_round_trees diverse we will
        #   try to add trees that maximize a score which is:
        #       lambda*tree_split_rarity + lnL
        #   where lambda is a tuning parameter and tree_split_rarity is:
        #       n_tree_times_splits = num_trees_in_next_round_trees * num_splits_per_tree
        #       split_occurrence = num_trees_in_next_round_trees_that_have_split
        #       tree_split_rarity = n_tree_times_splits - SUM split_occurrence
        #   in which the summation is taken over all splits in the tree
        #####
        max_len = self.max_trees_carried_over
        num_trees_to_add = max_len - len(next_round_trees)
        # Ranking only matters when there are more candidates than free
        # slots; otherwise every candidate is taken regardless of order.
        if num_trees_to_add < len(curr_results):
            split_count = {}
            n_tree_times_splits = 0
            for tm in next_round_trees:
                for k in tm.splits:
                    split_count[k] = split_count.get(k, 0) + 1
                    n_tree_times_splits += 1
            for tm in curr_results:
                tm.tree_split_rarity = n_tree_times_splits
                for split in tm.splits:
                    tm.tree_split_rarity -= split_count.get(split, 0)

            def split_diversity_cmp(x,
                                    y,
                                    lambda_mult=self.split_diversity_multiplier
                                    ):
                # Also caches retention_score on each tree model compared.
                x.retention_score = lambda_mult * x.tree_split_rarity + x.score
                y.retention_score = lambda_mult * y.tree_split_rarity + y.score
                return cmp(x.retention_score, y.retention_score)

            curr_results.sort(cmp=split_diversity_cmp, reverse=True)

        ########################################
        # We are now going to try to add elements (in order) from curr_results
        #   until we run out of trees to add or we reach max_len, skipping
        #   topologies that are already retained.
        #####
        set_of_split_sets.clear()
        for tm in next_round_trees:
            set_of_split_sets.add(tm.splits)

        n_added = 0
        for tm in curr_results:
            if n_added >= num_trees_to_add:
                break
            if tm.splits not in set_of_split_sets:
                set_of_split_sets.add(tm.splits)
                next_round_trees.append(tm)
                n_added += 1
        _LOG.info(
            'Added %d trees that were not "required" to guarantee that no splits were unanimous for ntax = %d'
            % (n_added, self.curr_n_taxa))

        ########################################
        # the trees can be hefty, so lets free unneeded memory
        #####
        del curr_results[:]

        ########################################
        # Finally, lets get a decent score for each tree before moving to the next round
        #   because the trees are big, we'll replace each element rather
        #   than allowing a duplicate list to be created.
        #####
        # NOTE(review): ``n`` is the index variable left over from the
        #   split-scanning loop above, so it holds whatever index that loop
        #   last visited. It is unbound (UnboundLocalError) if that loop body
        #   never ran, e.g. no non-trivial ML splits or no other results.
        #   Confirm which value score_tree expects here.
        for i, tm in enumerate(next_round_trees):
            next_round_trees[i] = self.score_tree(tm, culled, n,
                                                  self.tree_scoring_stop_gen)

        _LOG.info(
            'A total of %d trees were retained for ntax = %d lnL range from %f to %f'
            % (len(next_round_trees), self.curr_n_taxa,
               next_round_trees[0].score, next_round_trees[-1].score))

        return next_round_trees