def test_cuts(self):
    """Check the number of clusters produced by straight and balanced cuts.

    Every call runs on ``self.dendrogram``; the expected counts below are
    specific to that fixture.
    """
    def n_distinct(assignment):
        # Number of different cluster ids present in the assignment.
        return len(set(assignment))

    # Straight cut: default arguments, then an explicit cluster count.
    straight_default = cut_straight(self.dendrogram)
    self.assertEqual(n_distinct(straight_default), 2)
    straight_five = cut_straight(self.dendrogram, n_clusters=5)
    self.assertEqual(n_distinct(straight_five), 5)

    # Balanced cut: labels only, then labels plus the returned dendrogram.
    balanced = cut_balanced(self.dendrogram)
    self.assertEqual(n_distinct(balanced), 21)
    balanced, aggregated = cut_balanced(self.dendrogram, return_dendrogram=True)
    self.assertEqual(n_distinct(balanced), 21)
    self.assertTupleEqual(aggregated.shape, (20, 4))
def test_options(self):
    """Check cluster counts for non-default options of both cut functions.

    Each case is run on ``self.dendrogram``; the expected counts are
    specific to that fixture.
    """
    # (cut function, keyword options, expected number of distinct labels)
    scenarios = (
        (cut_straight, {'threshold': 0.5}, 7),
        (cut_straight, {'n_clusters': 3, 'threshold': 0.5}, 3),
        (cut_straight, {'sort_clusters': False}, 2),
        (cut_balanced, {'max_cluster_size': 2, 'sort_clusters': False}, 21),
        (cut_balanced, {'max_cluster_size': 10}, 5),
    )
    for cut, options, expected in scenarios:
        assignment = cut(self.dendrogram, **options)
        self.assertEqual(len(set(assignment)), expected)
def filter(self, n_clusters):
    """Cut the dendrogram into clusters and draw one molecule per cluster.

    The cluster id of every row is written to ``self.df['cluster_id']``.
    The first row of each cluster is then kept and its ``'mols'`` entry is
    passed to ``Draw.MolsToGridImage``, whose result is returned.

    Parameters
    ----------
    n_clusters :
        Forwarded to ``cut_straight`` as its second positional argument.
    """
    # Build the hierarchy lazily: only when no dendrogram attribute exists
    # yet (presumably ``_cluster`` is what sets it -- confirm in the class).
    if not hasattr(self, 'dendrogram'):
        self._cluster()

    self.df['cluster_id'] = cut_straight(self.dendrogram, n_clusters)
    representatives = self.df.drop_duplicates(subset='cluster_id', keep='first')
    return Draw.MolsToGridImage(representatives['mols'])
def cluster(self, method, n_clust=None, threshold=None):
    """Cut the dendrogram and store the cluster IDs in ``self.clusters``.

    Straight cuts can either set a defined number of clusters, or
    alternatively set a distance threshold. Cluster sizes can vary widely.

    Balanced cuts respect a maximum cluster size. The number of clusters is
    determined on the fly.

    Parameters
    ----------
    method :
        ``'straight'`` or ``'balanced'``.
    n_clust :
        For a straight cut, passed to ``cut_straight`` as its second
        positional argument (may be ``None``). For a balanced cut, the
        maximum cluster size (required).
    threshold :
        Passed to ``cut_straight`` as its third positional argument;
        only used by the straight cut.

    Raises
    ------
    ValueError
        If ``method`` is not one of the two supported values, if a straight
        cut receives both ``n_clust`` and ``threshold``, or if a balanced cut
        receives no ``n_clust``.
    """
    # Reject an unknown method up front. Previously this case only printed a
    # hint and returned, silently leaving self.clusters unset or stale.
    if method not in ('straight', 'balanced'):
        raise ValueError('Choose "straight" or "balanced"')

    if self.verbose:
        print(f'clustering with a {method} cut')

    if method == 'straight':
        if n_clust is not None and threshold is not None:
            raise ValueError('Straight cut takes only one of n_clusters or threshold, not both.')
        self.clusters = cut_straight(self.dendrogram, n_clust, threshold)
    else:
        if n_clust is None:
            raise ValueError('Must set maximum cluster size (n_clust) for balanced_cut')
        self.clusters = cut_balanced(self.dendrogram, n_clust)
def svg_dendrogram_top(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                       color, font_size, reorder, rotate_names):
    """Dendrogram as SVG image with root on top.

    Parameters
    ----------
    dendrogram :
        2-D array indexed as ``dendrogram[t, k]``: columns 0 and 1 hold the
        ids of the two nodes merged at step ``t``, column 2 the merge height.
    names :
        Leaf names indexed by leaf id, or ``None`` to draw no text.
    width, height :
        Size of the tree area, before ``scale`` and margins are applied.
    margin :
        Margin added on every side of the tree area.
    margin_text :
        Offset between the leaves and their names.
    scale :
        Multiplicative factor applied to ``width`` and ``height``.
    line_width :
        Stroke width of the branches.
    n_clusters :
        Forwarded to ``cut_straight``; branches joining two nodes of the same
        cluster are drawn in that cluster's colour from ``STANDARD_COLORS``.
    color :
        Stroke colour of branches joining nodes of different clusters.
    font_size :
        Font size of the names.
    reorder :
        Forwarded to ``get_index``, which returns the left-to-right leaf order.
    rotate_names :
        If true, names are rotated by 60 degrees around their anchor point.

    Returns
    -------
    str
        The SVG document.
    """
    text_rotated = """<text x="{}" y="{}" transform="rotate(60, {}, {})" font-size="{}">{}</text>"""
    text_plain = """<text x="{}" y="{}" font-size="{}">{}</text>"""
    path = """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""

    # scaling
    height *= scale
    width *= scale

    # positioning
    labels = cut_straight(dendrogram, n_clusters, return_dendrogram=False)
    index = get_index(dendrogram, reorder)
    n = len(index)
    unit_height = height / dendrogram[-1, 2]
    unit_width = width / n
    height_basis = margin + height
    position = {index[i]: (margin + i * unit_width, height_basis) for i in range(n)}
    label = dict(enumerate(labels))
    width += 2 * margin
    height += 2 * margin
    if names is not None:
        # Reserve vertical room below the leaves for the longest name.
        text_length = np.max(np.array([len(str(name)) for name in names]))
        height += text_length * font_size * .5 + margin_text

    parts = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">""".format(width, height)]

    # text
    if names is not None:
        for i in range(n):
            x, y = position[i]
            x -= margin_text
            y += margin_text
            # A raw '&' inside a text node makes the SVG malformed XML;
            # replace it with a space, as svg_dendrogram_left does.
            text = str(names[i]).replace('&', ' ')
            if rotate_names:
                parts.append(text_rotated.format(x, y, x, y, font_size, text))
            else:
                y += margin_text
                parts.append(text_plain.format(x, y, font_size, text))

    # tree
    for t in range(n - 1):
        i = int(dendrogram[t, 0])
        j = int(dendrogram[t, 1])
        x1, y1 = position.pop(i)
        x2, y2 = position.pop(j)
        l1 = label.pop(i)
        l2 = label.pop(j)
        if l1 == l2:
            line_color = STANDARD_COLORS[l1 % len(STANDARD_COLORS)]
        else:
            line_color = color
        x = .5 * (x1 + x2)
        y = height_basis - dendrogram[t, 2] * unit_height
        # Two vertical strokes up from the children, then the horizontal bar.
        parts.append(path.format(line_width, line_color, x1, y1, x1, y))
        parts.append(path.format(line_width, line_color, x2, y2, x2, y))
        parts.append(path.format(line_width, line_color, x1, y, x2, y))
        # The merged node gets id n + t and carries the first child's label.
        position[n + t] = (x, y)
        label[n + t] = l1

    parts.append('</svg>')
    return ''.join(parts)
def adjacency_clusters(adj, n_clusters=None, threshold=None):
    """Cluster the rows of an adjacency matrix with a Ward hierarchy and a straight cut.

    A fresh ``Ward`` model is fitted on ``adj`` and the resulting dendrogram
    is cut with ``cut_straight``. A progress message is printed first.

    Parameters
    ----------
    adj :
        Input handed to ``Ward().fit_transform``.
    n_clusters, threshold :
        Forwarded unchanged to ``cut_straight`` as keyword arguments.

    Returns
    -------
    The labels returned by ``cut_straight``.
    """
    print('Clustering molecules...')
    hierarchy = Ward().fit_transform(adj)
    return cut_straight(hierarchy, n_clusters=n_clusters, threshold=threshold)
def svg_dendrogram_left(dendrogram, names, width, height, margin, margin_text, scale, line_width, n_clusters,
                        color, colors, font_size, reorder):
    """Render a dendrogram as an SVG string, with the root on the left side.

    Leaves are stacked vertically on the right edge of the tree area; every
    merge is drawn as two horizontal strokes joined by a vertical bar. A merge
    between two nodes of the same cluster (clusters come from ``cut_straight``
    with ``n_clusters``) uses that cluster's entry in ``colors``; any other
    merge uses ``color``. When ``names`` is given, each name is written to
    the right of its leaf with ``'&'`` replaced by a space.
    """
    text_template = """<text x="{}" y="{}" font-size="{}">{}</text>"""
    path_template = """<path stroke-width="{}" stroke="{}" d="M {} {} {} {}" />"""

    # Zoom the drawing area before anything is laid out.
    height *= scale
    width *= scale

    # Layout: one row per leaf, merge heights mapped onto the x axis.
    cluster_ids = cut_straight(dendrogram, n_clusters, return_dendrogram=False)
    leaf_order = get_index(dendrogram, reorder)
    n_leaves = len(leaf_order)
    row_height = height / n_leaves
    x_per_unit = width / dendrogram[-1, 2]
    x_leaves = width + margin
    coords = {leaf: (x_leaves, margin + row * row_height) for row, leaf in enumerate(leaf_order)}
    cluster_of = dict(enumerate(cluster_ids))

    # Final canvas size: margins on both sides plus room for the longest name.
    width += 2 * margin
    height += 2 * margin
    if names is not None:
        longest = np.max(np.array([len(str(name)) for name in names]))
        width += longest * font_size * .5 + margin_text

    chunks = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">""".format(width, height)]

    # Leaf names.
    if names is not None:
        for leaf in range(n_leaves):
            x_leaf, y_leaf = coords[leaf]
            caption = str(names[leaf]).replace('&', ' ')
            chunks.append(text_template.format(x_leaf + margin_text, y_leaf + row_height / 3,
                                               font_size, caption))

    # Branches, one merge per dendrogram row.
    for step in range(n_leaves - 1):
        first = int(dendrogram[step, 0])
        second = int(dendrogram[step, 1])
        x1, y1 = coords.pop(first)
        x2, y2 = coords.pop(second)
        c1 = cluster_of.pop(first)
        c2 = cluster_of.pop(second)
        stroke = colors[c1 % len(colors)] if c1 == c2 else color
        y_mid = .5 * (y1 + y2)
        x_join = x_leaves - dendrogram[step, 2] * x_per_unit
        segments = (
            (x1, y1, x_join, y1),
            (x2, y2, x_join, y2),
            (x_join, y1, x_join, y2),
        )
        for x_from, y_from, x_to, y_to in segments:
            chunks.append(path_template.format(line_width, stroke, x_from, y_from, x_to, y_to))
        # The merged node takes id n_leaves + step and the first child's cluster.
        coords[n_leaves + step] = (x_join, y_mid)
        cluster_of[n_leaves + step] = c1

    chunks.append('</svg>')
    return ''.join(chunks)