def _build_2x2_fig(self):
    """Build a 2x2 figure: main plot with a left track and a bottom track.

    Panel layout (plotly axis names):
      * ``xaxis``/``yaxis``   -- left feature track (top-left)
      * ``xaxis2``/``yaxis2`` -- the core traces (top-right)
      * ``xaxis3``/``yaxis3`` -- bottom feature track (bottom-right)

    The left-track width and bottom-track height are scaled relative to the
    smaller of the two tracks' maximum plotted value, so the track with more
    content gets proportionally more room.

    Returns
    -------
    UnionDict
        figure with ``data`` (list of traces) and ``layout`` entries
    """
    if not self.traces:
        _ = self._build_fig(xaxis="x2", yaxis="y2")

    def _finite_max(values, current):
        """Return the larger of ``current`` and the largest non-missing value."""
        # converting to a float array turns None entries into NaN
        arr = numpy.array(values, dtype=float)
        finite = arr[numpy.logical_not(numpy.isnan(arr))]
        if finite.size == 0:
            # a trace consisting only of gaps contributes nothing; calling
            # ``max()`` on the empty array would raise ValueError
            return current
        return max(finite.max(), current)

    layout = UnionDict(
        {
            "xaxis": {"anchor": "y", "domain": [0.0, 0.099]},
            "xaxis2": {"anchor": "y2", "domain": [0.109, 1.0]},
            "xaxis3": {"anchor": "y3", "domain": [0.109, 1.0]},
            "yaxis": {"anchor": "x", "domain": [0.109, 1.0]},
            "yaxis2": {"anchor": "x2", "domain": [0.109, 1.0]},
            "yaxis3": {"anchor": "x3", "domain": [0.0, 0.099]},
        }
    )
    layout |= self.layout
    fig = UnionDict(data=[], layout=layout)

    # common settings
    ticks_off_kwargs = dict(_ticks_off)
    ticks_on_kwargs = dict(_ticks_on)

    # core traces and layout
    fig.data.extend(self.traces)
    fig.layout.xaxis2 |= dict(range=self.xrange, **ticks_off_kwargs)
    fig.layout.yaxis2 |= dict(range=self.yrange, **ticks_off_kwargs)

    # left_track traces
    seen_types = set()
    max_x = 0
    traces = []
    for trace in self.left_track.traces:
        traces.append(trace)
        max_x = _finite_max(trace.x, max_x)
        # only the first trace of each legend group is shown in the legend
        if trace.legendgroup in seen_types:
            trace.showlegend = False
        seen_types.add(trace.legendgroup)

    left_range = [0, int(max_x) + 1]

    # bottom_track traces
    max_y = 0
    for trace in self.bottom_track.traces:
        trace.xaxis = "x3"
        trace.yaxis = "y3"
        traces.append(trace)
        max_y = _finite_max(trace.y, max_y)
        if trace.legendgroup in seen_types:
            trace.showlegend = False
        seen_types.add(trace.legendgroup)

    bottom_range = [0, int(max_y) + 1]

    # add all traces
    fig.data.extend(traces)
    # configure axes for titles, limits, border and ticks
    fig.layout.yaxis |= dict(
        title=dict(text=self.ytitle), range=self.yrange, **ticks_on_kwargs
    )
    fig.layout.xaxis3 |= dict(
        title=dict(text=self.xtitle), range=self.xrange, **ticks_on_kwargs
    )

    # adjust row width of left plot for number of feature tracks
    # (both range upper bounds are >= 1, so no division by zero)
    min_range = min(left_range[1], bottom_range[1])
    left_prop = left_range[1] / min_range

    # first the top row
    xaxis_domain = list(layout.xaxis.domain)
    xaxis_domain[1] = left_prop * xaxis_domain[1]
    # NOTE(review): the title=None update appears to clear any title inherited
    # from self.layout before the empty title is merged; keep both updates
    # unless UnionDict's |= semantics are confirmed to make the first redundant
    fig.layout.xaxis |= dict(
        title=None, range=left_range, domain=xaxis_domain, **ticks_off_kwargs
    )
    fig.layout.xaxis |= dict(
        title={}, range=left_range, domain=xaxis_domain, **ticks_off_kwargs
    )
    space = 0.01
    fig.layout.xaxis2.domain = (xaxis_domain[1] + space, 1.0)
    fig.layout.xaxis3.domain = (xaxis_domain[1] + space, 1.0)

    # now the right column
    bottom_prop = bottom_range[1] / min_range
    yaxis_domain = list(layout.yaxis3.domain)
    yaxis_domain[1] = bottom_prop * yaxis_domain[1]
    fig.layout.yaxis3 |= dict(
        title={}, range=bottom_range, domain=yaxis_domain, **ticks_off_kwargs
    )

    # and bottom of the boxes above
    fig.layout.yaxis.domain = (yaxis_domain[1] + space, 1.0)
    fig.layout.yaxis2.domain = (yaxis_domain[1] + space, 1.0)

    return fig
class Drawable:
    """container object for Plotly figures

    Holds a list of Plotly trace dicts (``traces``) and a ``UnionDict``
    layout. When a figure is requested and ``traces`` is empty, the
    ``_build_fig`` method (not defined here, so supplied by subclasses) is
    called to populate them.
    """

    def __init__(
        self,
        title=None,
        traces=None,
        width=None,
        height=None,
        showlegend=True,
        visible_axes=True,
        layout=None,
        xtitle=None,
        ytitle=None,
    ):
        """
        Parameters
        ----------
        title : str or None
            figure title
        traces : list or None
            initial list of trace dicts
        width, height : int or None
            figure dimensions passed into the layout
        showlegend : bool
            whether the legend is displayed
        visible_axes : bool
            visibility of the x and y axes
        layout : dict or None
            merged over the default layout
        xtitle, ytitle : str or None
            axis titles applied when ``figure`` is built
        """
        self._traces = traces or []
        title = title if title is None else dict(text=title)
        self._default_layout = UnionDict(
            title=title,
            font=dict(family="Balto", size=14),
            width=width,
            height=height,
            autosize=False,
            showlegend=showlegend,
            xaxis=dict(visible=visible_axes),
            yaxis=dict(visible=visible_axes),
            hovermode="closest",
            plot_bgcolor=None,
            margin=dict(l=50, r=50, t=50, b=50, pad=4),
        )
        layout = layout or {}
        self.layout = UnionDict(self._default_layout)
        self.layout |= layout
        self.xtitle = xtitle
        self.ytitle = ytitle

    def _repr_html_(self):
        # displays via show() as a side effect; returns None
        self.show()

    @property
    def layout(self):
        """the figure layout, created empty on first access"""
        if not hasattr(self, "_layout"):
            self._layout = UnionDict()
        return self._layout

    @layout.setter
    def layout(self, value):
        # merges ``value`` into the existing layout rather than replacing it
        self.layout.update(value)

    @property
    def traces(self):
        """list of trace dicts"""
        return self._traces

    def get_trace_titles(self):
        """returns the ``name`` attribute of every trace, in order"""
        titles = [tr.name for tr in self.traces]
        return titles

    def pop_trace(self, title):
        """removes the trace with a matching title attribute

        Returns the removed trace, or None (after emitting a UserWarning) when
        no trace has that name.
        """
        import warnings

        try:
            index = self.get_trace_titles().index(title)
        except ValueError:
            # previously the warning object was created but never emitted
            warnings.warn(f"no trace with name {title}", UserWarning, stacklevel=2)
            return None

        return self.traces.pop(index)

    def remove_traces(self, names):
        """removes traces by name

        Parameters
        ----------
        names : str or iterable of str
            trace names
        """
        if not self.traces:
            self._build_fig()

        if isinstance(names, str):
            names = [names]
        for name in names:
            _ = self.pop_trace(name)

    def add_trace(self, trace):
        """appends ``trace`` to the traces"""
        self.traces.append(trace)

    def bound_to(self, obj):
        """returns obj with self bound to it"""
        return bind_drawable(obj, self)

    @property
    def figure(self):
        """the figure as a UnionDict with ``data`` and ``layout`` entries

        Builds the traces first if there are none, and applies ``xtitle`` /
        ``ytitle`` to the layout's x and y axes.
        """
        if not self.traces:
            self._build_fig()

        xtitle = self.xtitle if not self.xtitle else dict(text=self.xtitle)
        ytitle = self.ytitle if not self.ytitle else dict(text=self.ytitle)
        self.layout.xaxis.title = xtitle
        self.layout.yaxis.title = ytitle
        return UnionDict(data=self.traces, layout=self.layout)

    def iplot(self, *args, **kwargs):
        """displays the figure using plotly's offline ``iplot``"""
        from plotly.offline import iplot as _iplot

        _iplot(self.figure, *args, **kwargs)

    @extend_docstring_from(_show_)
    def show(self, renderer=None, **kwargs):
        _show_(self, renderer, **kwargs)

    def write(self, path, **kwargs):
        """writes static image file, suffix dictates format"""
        from plotly.io import write_image

        write_image(self.figure, path, **kwargs)

    def to_image(self, format="png", **kwargs):
        """creates static image in the given ``format`` (default "png")"""
        from plotly.io import to_image

        return to_image(self.figure, format=format, **kwargs)
def layout(self):
    """Return ``self._layout``, first setting it to an empty UnionDict if absent."""
    try:
        return self._layout
    except AttributeError:
        self._layout = UnionDict()
        return self._layout
def add_trace(self, trace):
    """Store ``trace`` (converted to a UnionDict) at the end of ``self.traces``."""
    wrapped = UnionDict(trace)
    self.traces.append(wrapped)
def _build_2x1_fig(self):
    """2 rows, one column, dotplot and seq1 annotated

    The top panel (``xaxis``/``yaxis``) holds the core traces. The bottom
    panel (``yaxis2``, sharing ``xaxis``) holds the ``bottom_track`` traces.

    Returns
    -------
    UnionDict
        figure with ``data`` (list of traces) and ``layout`` entries
    """
    if not self.traces:
        _ = self._build_fig()

    def _finite_max(values, current):
        """Return the larger of ``current`` and the largest non-missing value."""
        # converting to a float array turns None entries into NaN
        arr = numpy.array(values, dtype=float)
        finite = arr[numpy.logical_not(numpy.isnan(arr))]
        if finite.size == 0:
            # a trace consisting only of gaps contributes nothing; calling
            # ``max()`` on the empty array would raise ValueError
            return current
        return max(finite.max(), current)

    layout = UnionDict(
        xaxis={"anchor": "y2", "domain": [0.0, 1.0]},
        yaxis={"anchor": "free", "domain": [0.1135, 1.0], "position": 0.0},
        yaxis2={"anchor": "x", "domain": [0.0, 0.0985]},
    )
    if self._overlaying:
        # NOTE(review): this mutates self.layout, so a second call moves the
        # already-emptied yaxis2 into yaxis3 -- confirm it runs once per layout
        self.layout.yaxis3 = self.layout.yaxis2
        self.layout.yaxis2 = {}
        self.layout.legend.x = 1.3
    layout |= dict(self.layout)
    fig = UnionDict(data=[], layout=layout)

    # common settings
    ticks_off_kwargs = dict(_ticks_off)
    ticks_on_kwargs = dict(_ticks_on)

    # core traces and layout
    fig.data.extend(self.traces)
    fig.layout.xaxis |= dict(
        title=dict(text=self.xtitle), range=self.xrange, **ticks_on_kwargs
    )
    fig.layout.yaxis |= dict(
        title=dict(text=self.ytitle), range=self.yrange, **ticks_on_kwargs
    )

    # bottom traces
    seen_types = set()
    max_y = 0
    traces = []
    for trace in self.bottom_track.traces:
        trace.yaxis = "y2"
        trace.xaxis = "x"
        traces.append(trace)
        max_y = _finite_max(trace.y, max_y)
        # only the first trace of each legend group is shown in the legend
        if trace.legendgroup in seen_types:
            trace.showlegend = False
        seen_types.add(trace.legendgroup)

    fig.data.extend(traces)
    fig.layout.yaxis2 |= dict(
        title={}, range=[0, int(max_y) + 1], **ticks_off_kwargs
    )
    return fig
def _build_fig(self, **kwargs):
    """Append the tree's traces to ``self.traces`` and update ``self.layout``.

    Edges are grouped by ``self._edge_mapping`` (edges without a mapping share
    the ``None`` group). Each group becomes one line trace, styled from
    ``self._edge_sets`` or the default black spline. A marker trace carries
    hover text for every edge. Tip-name, support and scale-bar annotations
    are written to the layout as configured.

    ``kwargs`` is accepted but not used here. Returns None.
    """
    grouped = {}
    tree = self.tree
    # marker-only trace whose hover text is each edge's name
    text = UnionDict(
        {
            "type": "scatter",
            "text": [],
            "x": [],
            "y": [],
            "hoverinfo": "text",
            "mode": "markers",
            "marker": {
                "symbol": "circle",
                "color": "black",
                "size": self._marker_size,
            },
            "showlegend": False,
        }
    )
    support_text = []
    get_edge_group = self._edge_mapping.get
    for edge in tree.preorder():
        key = get_edge_group(edge.name, None)
        if key not in grouped:
            grouped[key] = defaultdict(list)
        group = grouped[key]
        coords = edge.get_segment_to_parent()
        xs, ys = list(zip(*coords))
        # the None separator keeps consecutive edge segments from being joined
        group["x"].extend(xs + (None,))
        group["y"].extend(ys + (None,))
        edge_label = edge.value_and_coordinate("name", padding=0)
        text["x"].append(edge_label.x)
        text["y"].append(edge_label.y)
        text["text"].append(edge_label.text)
        if self.show_support:
            support = edge.support_text_coord(
                self.support_xshift,
                self.support_yshift,
                threshold=self.support_threshold,
            )
            if support is not None:
                support |= UnionDict(xref="x", yref="y", font=self.tip_font)
                support_text.append(support)

    traces = []
    for key in grouped:
        group = grouped[key]
        style = self._edge_sets.get(
            key,
            UnionDict(
                line=UnionDict(
                    width=self._line_width,
                    color=self._line_color,
                    shape="spline",
                    smoothing=1.3,
                )
            ),
        )
        trace = UnionDict(type="scatter", x=group["x"], y=group["y"], mode="lines")
        trace |= style
        # only styled groups with a legendgroup key get a legend entry
        if "legendgroup" not in style:
            trace["showlegend"] = False
        else:
            trace["name"] = style["legendgroup"]
        traces.append(trace)

    scale_shape, scale_text = self._get_scale_bar()
    traces.extend([text])
    self.traces.extend(traces)
    # NOTE(review): annotations are only reset when tips_as_text is true; when
    # it is false, support/scale text below is appended to whatever
    # annotations already exist -- confirm rebuilds cannot duplicate them
    if self.tips_as_text:
        self.layout.annotations = tuple(self._get_tip_name_annotations())

    if self.show_support and support_text:
        self.layout.annotations = self.layout.annotations + tuple(support_text)

    if scale_shape:
        # NOTE(review): reads key "shape" but writes "shapes", so any stored
        # shapes are discarded. Reading "shapes" instead would accumulate
        # scale bars on every rebuild -- confirm intent before changing
        self.layout.shapes = self.layout.get("shape", []) + [scale_shape]
        self.layout.annotations += (scale_text,)
    else:
        self.layout.pop("shapes", None)

    if isinstance(self.tree, CircularTreeGeometry):
        # must draw this square
        if self.layout.width and self.layout.height:
            dim = max(self.layout.width, self.layout.height)
        elif self.layout.width:
            dim = self.layout.width
        elif self.layout.height:
            dim = self.layout.height
        else:
            dim = 800
        self.layout.width = self.layout.height = dim

        # Span of tree along x-axis and Span of tree along y-axis
        x_diff = self.tree.max_x - self.tree.min_x
        y_diff = self.tree.max_y - self.tree.min_y

        # Maximum span
        max_span = max(x_diff, y_diff)

        # Use maximum span along both axes and pad the smaller one accordingly;
        # each axis ends up spanning 1.4 * max_span, centred on the tree
        axes_range = dict(
            xaxis=dict(
                range=[
                    self.tree.min_x - (1.4 * max_span - x_diff) / 2,
                    self.tree.max_x + (1.4 * max_span - x_diff) / 2,
                ]
            ),
            yaxis=dict(
                range=[
                    self.tree.min_y - (1.4 * max_span - y_diff) / 2,
                    self.tree.max_y + (1.4 * max_span - y_diff) / 2,
                ]
            ),
        )
        self.layout |= axes_range
class Dendrogram(Drawable):
    """Plotly drawing of a tree as a dendrogram.

    Traces are built lazily by ``_build_fig``. Most setters clear
    ``self._traces`` so the drawing is regenerated the next time a figure is
    requested.
    """

    def __init__(
        self,
        tree,
        style="square",
        label_pad=None,
        contemporaneous=None,
        show_support=True,
        threshold=1.0,
        *args,
        **kwargs,
    ):
        """
        Parameters
        ----------
        tree
            tree to draw; its ``children``, ``length`` and ``params``
            attributes are read here
        style : str
            one of "square", "circular", "angular" or "radial"
        label_pad : float or None
            padding between tips and their labels; computed lazily if None
        contemporaneous : bool or None
            if true, edges are positioned using ``frac_pos`` rather than
            branch lengths. When None and no ``length_attr`` is given, it is
            inferred from whether the first child edge has a length.
        show_support : bool
            display edge support values; forced off when the first child
            edge has no "support" param
        threshold : float
            support cut-off passed on when generating support text
        length_attr : str, optional
            (keyword, taken from kwargs) edge attribute used as branch length
            when not contemporaneous
        args, kwargs
            passed on to Drawable
        """
        length_attr = kwargs.pop("length_attr", None)
        user_length_attr = length_attr is not None
        super(Dendrogram, self).__init__(
            visible_axes=False, showlegend=False, *args, **kwargs
        )
        klass = {
            "square": SquareTreeGeometry,
            "circular": CircularTreeGeometry,
            "angular": AngularTreeGeometry,
            "radial": RadialTreeGeometry,
        }[style]
        if length_attr is None and not contemporaneous:
            contemporaneous = tree.children[0].length is None

        length_attr = "frac_pos" if contemporaneous else length_attr or "length"
        # previously a caller-supplied length_attr was computed above but then
        # dropped when not contemporaneous; now it is passed to the geometry
        if contemporaneous or user_length_attr:
            geometry_kwargs = UnionDict(length_attr=length_attr)
        else:
            geometry_kwargs = {}
        self.tree = klass(tree, **geometry_kwargs)
        self.tree.propagate_properties()
        self._label_pad = label_pad
        self._tip_font = UnionDict(size=12, family="Inconsolata, monospace")
        self._line_width = 1.25
        self._marker_size = 3
        self._line_color = "black"
        self._scale_bar = "bottom left"
        self._edge_sets = {}
        self._edge_mapping = {}
        self._contemporaneous = contemporaneous
        self._tips_as_text = True
        self._length_attr = self.tree._length
        self._tip_names = tuple(e.name for e in self.tree.tips())
        self._max_label_length = max(map(len, self._tip_names))
        if "support" not in self.tree.children[0].params:
            show_support = False
        self._show_support = show_support
        self._threshold = threshold
        self._support_xshift = None
        self._support_yshift = None
        self._default_layout.autosize = True
        # NOTE(review): re-applying the defaults here may overwrite keys from
        # a caller-supplied ``layout`` -- depends on UnionDict.update semantics
        self.layout = UnionDict(self._default_layout)

    @property
    def label_pad(self):
        """padding between tips and their labels

        When unset, defaults to a fraction (0.15 for circular, else 0.025) of
        the largest absolute x coordinate, or that fraction itself when
        contemporaneous.
        """
        default = 0.15 if isinstance(self.tree, CircularTreeGeometry) else 0.025
        if self._label_pad is None:
            if not self.contemporaneous:
                max_x = max(self.tree.max_x, abs(self.tree.min_x))
                self._label_pad = max_x * default
            else:
                self._label_pad = default
        return self._label_pad

    @label_pad.setter
    def label_pad(self, value):
        self._label_pad = value
        self._traces = []

    @property
    def support_xshift(self):
        """relative x position (in pixels) of support text. Can be negative or positive."""
        return self._support_xshift

    @support_xshift.setter
    def support_xshift(self, value):
        if value == self._support_xshift:
            return
        self._support_xshift = value
        self._traces = []

    @property
    def support_yshift(self):
        """relative y position (in pixels) of support text. Can be negative or positive."""
        return self._support_yshift

    @support_yshift.setter
    def support_yshift(self, value):
        if value == self._support_yshift:
            return
        self._support_yshift = value
        self._traces = []

    @property
    def contemporaneous(self):
        """whether edges are positioned by ``frac_pos`` instead of length"""
        return self._contemporaneous

    @contemporaneous.setter
    def contemporaneous(self, value):
        if not isinstance(value, bool):
            raise TypeError(f"contemporaneous must be a bool, not {type(value)}")
        if self._contemporaneous != value:
            klass = self.tree.__class__
            length_attr = "frac_pos" if value else self._length_attr
            self.tree = klass(self.tree, length_attr=length_attr)
            self.tree.propagate_properties()
            self._traces = []
            # axis ranges no longer match the rebuilt geometry
            self.layout.xaxis |= dict(range=None, autorange=True)
            self.layout.yaxis |= dict(range=None, autorange=True)
            if value:  # scale bar not needed
                self._scale_bar = False

        self._contemporaneous = value

    @property
    def tip_font(self):
        """font settings for tip names and support text"""
        return self._tip_font

    @tip_font.setter
    def tip_font(self, val):
        self._tip_font = val
        # annotations carrying the font are only generated in _build_fig, so
        # force a rebuild; otherwise a change after first render is ignored
        self._traces = []

    def _scale_label_pad(self):
        """returns the label pad scaled by maximum dist to tip"""
        return self.label_pad

    def _get_tip_name_annotations(self):
        """returns a list of annotations, one per tip, giving the tip name"""
        annotations = []
        for tip in self.tree.tips():
            anote = tip.value_and_coordinate(
                "name", padding=self.label_pad, max_attr_length=self._max_label_length
            )
            anote |= UnionDict(xref="x", yref="y", font=self.tip_font)
            annotations.append(anote)
        return annotations

    def _get_scale_bar(self):
        """returns (line shape, text annotation) for the scale bar

        Returns (None, None) when the scale bar is disabled or the tree is
        contemporaneous. The bar length is 10% of the tree's maximum x.
        """
        if not self.scale_bar or self.contemporaneous:
            return None, None

        x = self.tree.min_x if "left" in self.scale_bar else self.tree.max_x
        y = self.tree.min_y if "bottom" in self.scale_bar else self.tree.max_y
        scale = 0.1 * self.tree.max_x
        text = "{:.1e}".format(scale) if scale < 1e-2 else "{:.2f}".format(scale)
        shape = {
            "type": "line",
            "x0": x,
            "y0": y,
            "x1": x + scale,
            "y1": y,
            "line": {"color": self._line_color, "width": self._line_width},
        }
        annotation = UnionDict(
            x=x + (0.5 * scale),
            y=y,
            xref="x",
            yref="y",
            yshift=10,
            text=text,
            showarrow=False,
            ax=0,
            ay=0,
        )
        return shape, annotation

    def _build_fig(self, **kwargs):
        """Append the tree's traces to ``self.traces`` and update ``self.layout``.

        ``kwargs`` is accepted but not used here. Returns None.
        """
        grouped = {}
        tree = self.tree
        # marker-only trace whose hover text is each edge's name
        text = UnionDict(
            {
                "type": "scatter",
                "text": [],
                "x": [],
                "y": [],
                "hoverinfo": "text",
                "mode": "markers",
                "marker": {
                    "symbol": "circle",
                    "color": "black",
                    "size": self._marker_size,
                },
                "showlegend": False,
            }
        )
        support_text = []
        get_edge_group = self._edge_mapping.get
        for edge in tree.preorder():
            key = get_edge_group(edge.name, None)
            if key not in grouped:
                grouped[key] = defaultdict(list)
            group = grouped[key]
            coords = edge.get_segment_to_parent()
            xs, ys = list(zip(*coords))
            # the None separator keeps consecutive edge segments from joining
            group["x"].extend(xs + (None,))
            group["y"].extend(ys + (None,))
            edge_label = edge.value_and_coordinate("name", padding=0)
            text["x"].append(edge_label.x)
            text["y"].append(edge_label.y)
            text["text"].append(edge_label.text)
            if self.show_support:
                support = edge.support_text_coord(
                    self.support_xshift,
                    self.support_yshift,
                    threshold=self.support_threshold,
                )
                if support is not None:
                    support |= UnionDict(xref="x", yref="y", font=self.tip_font)
                    support_text.append(support)

        traces = []
        for key in grouped:
            group = grouped[key]
            style = self._edge_sets.get(
                key,
                UnionDict(
                    line=UnionDict(
                        width=self._line_width,
                        color=self._line_color,
                        shape="spline",
                        smoothing=1.3,
                    )
                ),
            )
            trace = UnionDict(type="scatter", x=group["x"], y=group["y"], mode="lines")
            trace |= style
            if "legendgroup" not in style:
                trace["showlegend"] = False
            else:
                trace["name"] = style["legendgroup"]
            traces.append(trace)

        scale_shape, scale_text = self._get_scale_bar()
        traces.extend([text])
        self.traces.extend(traces)
        if self.tips_as_text:
            self.layout.annotations = tuple(self._get_tip_name_annotations())

        if self.show_support and support_text:
            self.layout.annotations = self.layout.annotations + tuple(support_text)

        if scale_shape:
            # NOTE(review): reads "shape" but writes "shapes", so stored shapes
            # are discarded; reading "shapes" would accumulate scale bars on
            # every rebuild -- confirm intent before changing
            self.layout.shapes = self.layout.get("shape", []) + [scale_shape]
            self.layout.annotations += (scale_text,)
        else:
            self.layout.pop("shapes", None)

        if isinstance(self.tree, CircularTreeGeometry):
            # must draw this square
            if self.layout.width and self.layout.height:
                dim = max(self.layout.width, self.layout.height)
            elif self.layout.width:
                dim = self.layout.width
            elif self.layout.height:
                dim = self.layout.height
            else:
                dim = 800
            self.layout.width = self.layout.height = dim

            # Span of tree along x-axis and Span of tree along y-axis
            x_diff = self.tree.max_x - self.tree.min_x
            y_diff = self.tree.max_y - self.tree.min_y

            # Maximum span
            max_span = max(x_diff, y_diff)

            # Use maximum span along both axes and pad the smaller one
            # accordingly; each axis spans 1.4 * max_span, centred on the tree
            axes_range = dict(
                xaxis=dict(
                    range=[
                        self.tree.min_x - (1.4 * max_span - x_diff) / 2,
                        self.tree.max_x + (1.4 * max_span - x_diff) / 2,
                    ]
                ),
                yaxis=dict(
                    range=[
                        self.tree.min_y - (1.4 * max_span - y_diff) / 2,
                        self.tree.max_y + (1.4 * max_span - y_diff) / 2,
                    ]
                ),
            )
            self.layout |= axes_range

    def style_edges(self, edges, line, legendgroup=None, tip2=None, **kwargs):
        """adjust display layout for the edges

        Parameters
        ----------
        edges : str or series
            names of edges
        line : dict
            with plotly line style to applied to these edges
        legendgroup : str or None
            if str, a legend will be presented
        tip2 : str
            if provided, and edges is a str, passes edges (as tip1) and kwargs to get_edge_names
        kwargs
            keyword arguments passed onto get_edge_names

        Raises
        ------
        ValueError
            if any named edge is not in the tree
        """
        if tip2:
            assert isinstance(edges, str), "cannot use a series of edges and tip2"
            edges = self.get_edge_names(edges, tip2, **kwargs)

        if isinstance(edges, str):
            edges = [edges]
        edges = frozenset(edges)
        if not edges.issubset({edge.name for edge in self.tree.preorder()}):
            raise ValueError("edge not present in tree")
        style = UnionDict(width=self._line_width, color=self._line_color)
        style.update(line)
        self._edge_sets[edges] = UnionDict(legendgroup=legendgroup, line=style)
        mapping = {e: edges for e in edges}
        self._edge_mapping.update(mapping)
        if legendgroup:
            self.layout["showlegend"] = True

        # need to trigger recreation of figure
        self._traces = []

    def reorient(self, name, tip2=None, **kwargs):
        """change orientation of tree

        Parameters
        ----------
        name : str
            name of an edge in the tree. If name is a tip, its parent becomes
            the new root, otherwise the edge becomes the root.
        tip2 : str
            if provided, passes name (as tip1) and all other args to get_edge_names,
            but sets clade=False and stem=True
        kwargs
            keyword arguments passed onto get_edge_names
        """
        if tip2:
            kwargs.update(dict(stem=True, clade=False))
            edges = self.get_edge_names(name, tip2, **kwargs)
            name = edges[0]

        if name in self._tip_names:
            self.tree = self.tree.rooted_with_tip(name)
        else:
            self.tree = self.tree.rooted_at(name)

        self.tree.propagate_properties()
        self._traces = []

    def get_edge_names(self, tip1, tip2, outgroup=None, stem=False, clade=True):
        """
        Parameters
        ----------
        tip1 : str
            name of tip 1
        tip2 : str
            name of tip 2
        outgroup : str
            name of tip outside clade of interest
        stem : bool
            include name of stem to clade defined by tip1, tip2, outgroup
        clade : bool
            include names of edges within clade defined by tip1, tip2, outgroup

        Returns
        -------
        list of edge names
        """
        return self.tree.get_edge_names(
            tip1, tip2, stem=stem, clade=clade, outgroup_name=outgroup
        )

    @property
    def scale_bar(self):
        """where to place a scale bar"""
        return self._scale_bar

    @scale_bar.setter
    def scale_bar(self, value):
        if value is True:
            value = "bottom left"

        valid = {"bottom left", "bottom right", "top left", "top right", False, None}

        assert value in valid
        if value != self._scale_bar:
            self._traces = []
        self._scale_bar = value

    @property
    def tips_as_text(self):
        """displays tips as text"""
        return self._tips_as_text

    @tips_as_text.setter
    def tips_as_text(self, value):
        assert type(value) is bool
        if value == self._tips_as_text:
            return

        self._tips_as_text = value
        self._traces = []
        self.layout.annotations = ()

    @property
    def line_width(self):
        """width of dendrogram lines"""
        return self._line_width

    @line_width.setter
    def line_width(self, width):
        self._line_width = width
        if self.traces:
            # update existing traces in place; traces without a line are skipped
            setting = dict(width=width)
            for trace in self.traces:
                try:
                    trace["line"] |= setting
                except KeyError:
                    pass

    @property
    def marker(self):
        """size of the edge markers"""
        return self._marker_size

    @marker.setter
    def marker(self, size):
        self._marker_size = size
        if self.traces:
            setting = dict(size=size)
            for trace in self.traces:
                if trace.get("mode", None) == "markers":
                    trace["marker"] |= setting

    @property
    def show_support(self):
        """whether tree edge support entries are displayed"""
        return self._show_support

    @show_support.setter
    def show_support(self, value):
        assert type(value) is bool
        if value == self._show_support:
            return

        self._show_support = value
        self._traces = []
        self.layout.annotations = ()

    @property
    def support_threshold(self):
        """cutoff for displaying support"""
        return self._threshold

    @support_threshold.setter
    def support_threshold(self, value):
        assert 0 <= value <= 1, "Must be in [0, 1] interval"
        if value == self._threshold:
            return

        self._threshold = value
        self._traces = []
        self.layout.annotations = ()
def test_get_subattr(self):
    """_getsubattr_ resolves a key, optionally within a path of nested keys"""
    ud = UnionDict({"a": 1, "b": 2, "c": 3, "d": {"e": 5, "f": 6}})
    expected_nested = UnionDict({"e": 5, "f": 6})
    cases = [
        ([], "a", 1),
        ([], "d", expected_nested),
        (["d"], "e", 5),
    ]
    for path, key, want in cases:
        self.assertEqual(ud._getsubattr_(path, key), want)
def test_construct_from_kwargs(self):
    """building from keyword arguments gives attribute access to nested keys"""
    title_text = "Alignment Position"
    ud = UnionDict(width=600.0, xaxis={"title": {"text": title_text}})
    self.assertEqual(ud.xaxis.title.text, title_text)
def test_construction(self):
    """a deeply nested plain dict supports attribute access at every level"""
    source = dict(width=600.0, xaxis=dict(title=dict(text="Alignment Position")))
    nested = UnionDict(source)
    title = nested.xaxis.title
    self.assertEqual(title.text, "Alignment Position")