def test_estimate_tree(num_edges):
    """Check estimate_tree on a complete graph with random edge logits.

    The result must have num_edges edges and touch every vertex. When the
    vertex count is below len(TREE_GENERATORS), the result is also compared
    against the highest-scoring tree among all enumerated spanning trees.
    """
    set_random_seed(0)
    num_vertices = 1 + num_edges
    grid = make_complete_graph(num_vertices)
    num_candidates = grid.shape[1]
    edge_logits = np.random.random([num_candidates]) - 0.5
    edges = estimate_tree(grid, edge_logits)

    # Check size: the right number of edges, and every vertex is covered.
    assert len(edges) == num_edges
    for vertex in range(num_vertices):
        assert any(vertex in edge for edge in edges)

    # Check optimality by brute force, where enumeration is available.
    edges = tuple(edges)
    if not V_is_enumerable(num_vertices):
        return
    candidates = get_spanning_trees(num_vertices)
    assert edges in candidates
    candidates = list(candidates)
    scores = [
        sum(edge_logits[find_complete_edge(u, v)] for (u, v) in candidate)
        for candidate in candidates
    ]
    best = candidates[np.argmax(scores)]
    assert edges == best


def V_is_enumerable(num_vertices):
    """Return True when spanning trees can be enumerated for this size."""
    return num_vertices < len(TREE_GENERATORS)
def plot_feature_overlap(df, cmap='binary', method='cluster'):
    """Plot feature-feature presence overlap of a pandas dataframe.

    A cell is treated as present when it compares equal to itself (NaN
    cells do not). Entry (i, j) of the overlap matrix is the number of rows
    in which features i and j are both present. Features are reordered so
    that related features sit next to each other, and the square root of
    the reordered matrix is drawn with pyplot.imshow.

    Args:
        df: A pandas dataframe.
        cmap: A matplotlib colormap.
        method: Method of clustering, one of 'cluster' or 'tree'.

    Raises:
        ValueError: If method is neither 'cluster' nor 'tree'.
    """
    V = len(df.columns)
    # DataFrame.as_matrix() was removed in pandas 1.0; .values is available
    # in both old and new pandas releases and yields the same ndarray.
    present = (df == df).values.astype(np.float32)
    overlap = np.dot(present.T, present)
    assert overlap.shape == (V, V)

    # Sort features to make blocks contiguous.
    if method == 'tree':
        # TODO(fritzo) Fix this to not look awful.
        grid = make_complete_graph(V)
        weights = np.empty(grid.shape[1], dtype=np.float32)
        # Each grid column is (edge id, vertex, vertex); weight every
        # candidate edge by the overlap of its two endpoints.
        for k, v1, v2 in grid.T:
            weights[k] = overlap[v1, v2]
        edges = estimate_tree(grid, weights)
        # Only the second value returned by order_vertices is needed here.
        _, order_inv = order_vertices(edges)
    elif method == 'cluster':
        distance = scipy.spatial.distance.pdist(overlap)
        clustering = scipy.cluster.hierarchy.complete(distance)
        order_inv = scipy.cluster.hierarchy.leaves_list(clustering)
    else:
        raise ValueError(method)

    # Apply the same permutation to rows and then columns.
    overlap = overlap[order_inv, :]
    overlap = overlap[:, order_inv]
    assert overlap.shape == (V, V)

    pyplot.imshow(overlap**0.5, cmap=cmap)
    pyplot.axis('off')
def __init__(self, model):
    """Build a TreeCat server from a trained model.

    Args:
        model: A dict with fields:
            tree: A TreeStructure.
            suffstats: A dict of sufficient statistics.
            edge_logits: A K-sized array of nonnormalized edge
                probabilities.
            config: A global config dict.
    """
    tree = model['tree']
    suffstats = model['suffstats']
    config = model['config']
    logger.info('TreeCatServer with %d features', tree.num_vertices)
    assert isinstance(tree, TreeStructure)
    ragged_index = suffstats['ragged_index']
    ServerBase.__init__(self, ragged_index)
    self._tree = tree
    self._config = config
    self._program = make_propagation_program(tree.tree_grid)

    # Cache the problem dimensions as a (V, E, M, R) tuple:
    # vertices, tree edges, latent clusters, size of the ragged data.
    num_vertices = self._tree.num_vertices
    num_tree_edges = num_vertices - 1
    num_clusters = self._config['model_num_clusters']
    ragged_size = ragged_index[-1]
    self._VEMR = (num_vertices, num_tree_edges, num_clusters, ragged_size)

    # Jeffreys priors added to the raw sufficient statistics.
    vert_prior = 0.5
    edge_prior = 0.5 / num_clusters
    feat_prior = 0.5 / num_clusters

    # Posterior marginals for single vertices, normalized per vertex.
    vert_probs = suffstats['vert_ss'].astype(np.float32) + vert_prior
    vert_probs /= vert_probs.sum(axis=1, keepdims=True)
    self._vert_probs = vert_probs

    # Posterior marginals for pairs of vertices, normalized per edge.
    edge_probs = suffstats['edge_ss'].astype(np.float32) + edge_prior
    edge_probs /= edge_probs.sum(axis=(1, 2), keepdims=True)
    self._edge_probs = edge_probs

    # Pairwise joint posterior with both endpoint marginals divided out,
    # i.e. the information not already carried by the individual factors.
    edge_trans = edge_probs.copy()
    for e, v1, v2 in tree.tree_grid.T:
        joint = edge_trans[e]  # A view, so the divisions write through.
        joint /= vert_probs[v1][:, np.newaxis]
        joint /= vert_probs[v2][np.newaxis, :]
    self._edge_trans = edge_trans

    # Conditional distribution of features given the latent state,
    # normalized separately within each vertex's slice of the ragged axis.
    feat_cond = suffstats['feat_ss'].astype(np.float32) + feat_prior
    for v in range(num_vertices):
        beg, end = ragged_index[v:v + 2]
        feat_cond[beg:end, :] /= feat_cond[beg:end, :].sum(
            axis=0, keepdims=True)
    self._feat_cond = feat_cond

    # These are used to inspect and visualize latent structure.
    self._edge_logits = model['edge_logits']
    self._estimated_tree = tuple(
        estimate_tree(self._tree.complete_grid, self._edge_logits))
    self._tree.gc()
def __init__(self, ensemble):
    """Build a server backed by one TreeCatServer per ensemble member.

    Args:
        ensemble: A non-empty sequence of model dicts, each accepted by
            TreeCatServer. The first member's ragged index is used to
            initialize the base class.
    """
    logger.info('EnsembleServer of size %d', len(ensemble))
    assert ensemble
    ServerBase.__init__(self, ensemble[0]['suffstats']['ragged_index'])
    servers = []
    for model in ensemble:
        servers.append(TreeCatServer(model))
    self._ensemble = servers

    # These are used to inspect and visualize latent structure.
    # Average the per-member edge logits, accumulating in member order.
    head = self._ensemble[0]
    mean_logits = head.edge_logits.copy()
    for member in self._ensemble[1:]:
        mean_logits += member.edge_logits
    mean_logits /= len(self._ensemble)
    self._edge_logits = mean_logits

    complete_grid = self._ensemble[0]._tree.complete_grid
    self._estimated_tree = tuple(
        estimate_tree(complete_grid, self._edge_logits))
    self._ensemble[0]._tree.gc()
def estimate_tree(self):
    """Compute a maximum likelihood tree.

    Edge logits come from self.compute_edge_logits() and are passed,
    together with the tree's complete grid, to the module-level
    estimate_tree function.

    Returns:
        A pair (edges, edge_logits), where:
            edges: A list of (vertex, vertex) pairs.
            edge_logits: A [K]-shaped numpy array of edge logits.
    """
    num_rows = len(self._added_rows)
    logger.info('TreeCatTrainer.estimate_tree given %d rows', num_rows)
    grid = self._tree.complete_grid
    logits = self.compute_edge_logits()
    return estimate_tree(grid, logits), logits
def test_recover_structure(V, C):
    """Train on clean data from a generated tree and check it is recovered.

    A tree over V columns is generated, a clean dataset with C categories
    is drawn from it, and a model is trained with an all-zero tree prior.
    The trained model's tree must equal both the tree estimated from the
    model's edge logits and the generating tree.
    """
    set_random_seed(V + C * 10)
    num_rows = 200
    num_clusters = 2 * C
    num_complete_edges = V * (V - 1) // 2
    tree_prior = np.zeros(num_complete_edges, np.float32)
    tree = generate_tree(num_cols=V)
    dataset = generate_clean_dataset(tree, num_rows=num_rows, num_cats=C)
    table = dataset['table']
    config = make_config(model_num_clusters=num_clusters)
    model = train_model(table, tree_prior, config)

    # Compute three types of edges.
    expected_edges = tree.get_edges()
    optimal_edges = estimate_tree(tree.complete_grid, model['edge_logits'])
    actual_edges = model['tree'].get_edges()

    # Print debugging information.
    feature_names = list(map(str, range(V)))
    root = '0'
    # Collapse each feature's slice of the ragged data to one index per row.
    readable_data = np.zeros([num_rows, V], np.int8)
    for col in range(V):
        beg, end = table.ragged_index[col:col + 2]
        readable_data[:, col] = table.data[:, beg:end].argmax(axis=1)
    labeled_edges = [
        ('Expected:', expected_edges),
        ('Optimal:', optimal_edges),
        ('Actual:', actual_edges),
    ]
    with np_printoptions(precision=2, threshold=100, edgeitems=5):
        for label, edges in labeled_edges:
            print(label)
            print(print_tree(edges, feature_names, root))
        print('Correlation:')
        print(np.corrcoef(readable_data.T))
        print('Edge logits:')
        print(triangular_to_square(tree.complete_grid, model['edge_logits']))
        print('Data:')
        print(readable_data)
        print('Feature Sufficient Statistics:')
        print(model['suffstats']['feat_ss'])
        print('Edge Sufficient Statistics:')
        print(model['suffstats']['edge_ss'])

    # Check agreement.
    assert actual_edges == optimal_edges, 'Error in sample_tree'
    assert actual_edges == expected_edges, 'Error in likelihood'