class ChartGeometry:
    """
    List of attributes that can be mixed into a `Geometry` to get an object
    that can be used specifically as a `Chart`'s `Geometry`
    """
    # The chart that owns this geometry; queried for e.g. colors (see
    # `facecolor` below).
    chart = Field("""
        The chart object that defines this `NodeGeometry`'s behavior.
        """)
    # Label doubles as the hash key for this geometry (see `__hash__`).
    label = Field("""
        A string label to track this `ChartGeometry` instance.
        """)
    size = Field("""
        Size of this `ChartGeometry` --- with data-type determined by its
        owning `self.chart`.
        """)
    position = Field("""
        Position of this `ChartGeometry` --- with data-type determined by
        its owning `self.chart`.
        """)
    # When no explicit face color is given, the owning chart is asked
    # (see `CircularNetworkChart.get_color`).
    facecolor = LambdaField(
        """
        Color for this `ChartGeometry`'s face.
        You may provide a value, otherwise this `ChartGeometry` instance will
        query it's owning `self.chart`.
        """,
        lambda self: self.chart.get_color(self))

    def __hash__(self):
        """
        A hash for this `ChartGeometry` instance will be used as a key in
        mappings.
        """
        # Only the label is hashed: two geometries with equal labels are
        # interchangeable as mapping keys.
        return hash(self.label)
class CircuitAnalysisReport(Report):
    """
    Add some circuit analysis specific attributes to `Report`
    """
    # Provenance of the analyzed circuit; a plain dict is coerced to
    # `CircuitProvenance` by the `__as__` hook.
    provenance_model = Field("""
        Either a `class CircuitProvenance` instance or a dict providing values
        for the fields of `class CircuitProvenance`.
        """,
        __as__=CircuitProvenance)
    # NOTE(review): `{}` as `__default_value__` is a shared mutable default ---
    # confirm that `Field` copies it per instance rather than sharing one dict
    # across all reports.
    figures = Field("""
        A dict mapping label to an object with a `.graphic` and `.caption`
        attributes.
        """,
        __default_value__={})
    references = Field("""
        References of literature cited in this report.
        """,
        __type__=Mapping,
        __default_value__={})
    # Concatenation of the narrative sections, in IMRAD order.
    content = LambdaField(
        """
        All text as a single string.
        """,
        lambda self: (self.abstract
                      + self.introduction
                      + self.methods
                      + self.results
                      + self.discussion))

    @lazyfield
    def field_values(self):
        """
        Values of this report's fields, keyed as expected by the report
        template (circuit provenance nested under "circuit").
        """
        try:
            name_phenomenon = self.phenomenon.name
        except AttributeError:
            # `phenomenon` may be a plain string rather than an object
            # carrying a `.name`.
            name_phenomenon = make_name(self.phenomenon, separator="-")
        return\
            dict(
                circuit=OrderedDict((
                    ("animal", self.provenance_model.animal),
                    ("age", self.provenance_model.age),
                    ("brain_region", self.provenance_model.brain_region),
                    ("uri", self.provenance_model.uri),
                    ("references", self.references),
                    ("date_release", self.provenance_model.date_release),
                    # Authors are rendered as an enumerated, ';'-separated
                    # list: "1. A; 2. B; ...".
                    ("authors", '; '.join(
                        "{}. {}".format(i+1, a)
                        for i, a in enumerate(
                            self.provenance_model.authors))))),
                author=self.author,
                phenomenon=name_phenomenon,
                label=make_label(self.label, separator='-'),
                title=make_name(self.label, separator='-'),
                abstract=self.abstract,
                introduction=self.introduction,
                methods=self.methods,
                results=self.results,
                content=self.content,
                discussion=self.discussion)
class FlowGeometry(ChartGeometry):
    """
    Geometry to represent a connection from a begin node to an end node
    as a flow.
    """
    begin_node = Field("""
        The `NodeGeometry` instance where this `FlowGeometry` instance starts.
        """)
    end_node = Field("""
        The `NodeGeometry` instance where this `FlowGeometry` instance ends.
        """)
    size_begin = Field("""
        Size of this `FlowGeometry` where it starts.
        """)
    size_end = Field("""
        Size of this `FlowGeometry` where it ends.
        """)
    # Default label "(begin,end)"; a custom label may be passed at
    # initialization and takes precedence.
    label = LambdaField(
        """
        Label for this `FlowGeometry` is constructed from it's `begin_node`
        and `end_node`. A custom value may be provided at initialization.
        """,
        lambda self: "({},{})".format(self.begin_node.label,
                                      self.end_node.label))

    @lazyfield
    def size(self):
        """
        Size of this flow --- by convention the size where it begins.
        """
        return self.size_begin

    @lazyfield
    def identifier(self):
        """
        Identifier can be used as a key in a mapping providing features
        of this `Geometry`.
        """
        return (self.begin_node.identifier, self.end_node.identifier)

    @lazyfield
    def position(self):
        """
        Position of a flow in the chart --- the pair of its end-points'
        positions.
        """
        return (self.begin_node.position, self.end_node.position)
class DocElem(WithFields, AIBase):
    """
    An element of a document, arranged in a tree: each `DocElem` has a
    required parent and an optional sequence of children.
    """
    title = Field("""
        Title of this document element.
        """)
    label = LambdaField(
        """
        A single word tag for this document element.
        """,
        lambda self: make_label(self.title))

    @field
    def parent(self):
        """
        Parent `DocElem` that contains this one.
        """
        raise FieldIsRequired

    @field
    def children(self):
        """
        A sequence of `DocElem`s that are contained in this one.
        """
        return tuple()

    def save(self, record, path):
        """Save this `DocElem`"""
        # Delegate to a cooperating super-class `save`, if one exists,
        # nesting the output under a folder named by this element's label.
        saver = getattr(super(), "save", None)
        if saver is None:
            return path
        return saver(record, Path(path).joinpath(self.label))

    def __call__(self, adapter, model, *args, **kwargs):
        """
        Produce a `Record` for this `DocElem`, delegating to a cooperating
        super-class `__call__` when available and stamping the result with
        this element's title and label.
        """
        produce = getattr(super(), "__call__", None)
        if produce is None:
            return Record(title=self.title, label=self.label)
        record = produce(adapter, model, *args, **kwargs)
        return record.assign(title=self.title, label=self.label)
class CircularNetworkChart(NetworkChart):
    """
    Illustrate a network's nodes as islands along a circle's periphery,
    and its edges as rivers flowing between these islands.
    """
    NodeGeometryType = Field("""
        A callable that returns a node geometry
        """,
        __default_value__=NodeArcGeometry)
    FlowGeometryType = Field("""
        A callable that returns a flow geometry
        """,
        __default_value__=FlowArcGeometry)
    center = Field("""
        Position on the page where the center of this `Chart` should lie.
        """,
        __default_value__=np.array([0., 0.]))
    rotation = Field("""
        The overall angle in radians, by which the chart will be rotated.
        """,
        __default_value__=0.)
    link_data = Field("""
        A `pandas.Series` with a double leveled index (`begin_node`,
        'end_node``), with values providing weights / strengths of the links
        to be displayed as sizes of the flows.
        """)
    size = Field("""
        Size of the figure.
        """,
        __default_value__=12)
    height_figure = LambdaField("""
        Height of the figure.
        """,
        lambda self: self.size)
    width_figure = LambdaField("""
        Width of the figure.
        """,
        lambda self: self.size)
    radial_size_node = Field("""
        Radial size of a node --- will be the same for all nodes.
        """,
        __default_value__=0.1)
    spacing_factor = Field("""
        Fraction of space along the periphery that must be left blank, to
        space nodes.
        """,
        __default_value__=0.25)
    # Angular budget shared by all nodes: the full circle minus the blank
    # spacing fraction.
    unit_node_size = LambdaField(
        """
        Node size will be determined as a multiple of this unit size.
        """,
        lambda self: 2 * np.pi * (1. - self.spacing_factor))
    inner_outer_spacing = LambdaField(
        """
        Spacing from inner to outer circles.
        """,
        lambda self: 1. * self.radial_size_node)
    margin = Field("""
        Space (in units of the axes) around the geometries.
        """,
        __default_value__=0.5)
    node_flow_spacing_factor = Field("""
        A multiplicative factor ( > 1.) by which flows placed on a node
        geometry must be spaced.
        """,
        __default_value__=1.)
    fontsize = Field("""...""", __default_value__=24)
    axes_size = Field("""
        Axes will be scaled accordingly.
        """,
        __default_value__=1.)
    # NOTE(review): `{}` is a shared mutable default --- confirm `Field`
    # copies it per instance.
    color_map = Field("""
        Colors for the nodes. Please a provide a `Mapping`, like a dictionary,
        or pandas Series that maps node labels to the color value with which
        they should be painted.
        """,
        __default_value__={})

    def get_color(self, geometry, **kwargs):
        """
        Get colors for a geometry.

        Falls back to "green" when the geometry's identifier has no entry
        in `self.color_map`.
        """
        return self.color_map.get(geometry.identifier, "green")

    @lazyfield
    def outer_circle(self):
        """Outer circle, along which source-side node arcs are placed."""
        return Circle(label="outer-circle", radius=self.axes_size)

    @lazyfield
    def inner_circle(self):
        """Inner circle, along which target-side node arcs are placed."""
        return Circle(label="inner-circle",
                      radius=self.axes_size - self.inner_outer_spacing)

    @lazyfield
    def node_geometry_size(self):
        """
        Size of the geometry that will represent a node.
        In a pandas series of tuples (radial, angular)
        """
        # Each node gets an angular extent proportional to its weight;
        # radial extent is constant.
        return self.node_weight.apply(
            lambda weight: (pd.Series({
                "total": (self.radial_size_node,
                          self.unit_node_size * weight.total),
                "source": (self.radial_size_node,
                           self.unit_node_size * weight.source),
                "target": (self.radial_size_node,
                           self.unit_node_size * weight.target)})),
            axis=1)

    @staticmethod
    def _angular_size(dataframe_size):
        """
        Extract the angular component (second element of each
        (radial, angular) tuple) from a dataframe of geometry sizes.

        Handles frames both with and without a "total" column.
        """
        # BUG FIX: the original ended with an unreachable
        # `raise RuntimeError(...)` after both branches had returned ---
        # dead code, removed.
        try:
            return dataframe_size.apply(
                lambda row: pd.Series(dict(total=row.total[1],
                                           source=row.source[1],
                                           target=row.target[1])),
                axis=1)
        except AttributeError:
            # No `total` column on the rows.
            return dataframe_size.apply(
                lambda row: pd.Series(dict(source=row.source[1],
                                           target=row.target[1])),
                axis=1)

    @lazyfield
    def node_angular_size(self):
        """Angular sizes of the node geometries."""
        return self._angular_size(self.node_geometry_size)

    @lazyfield
    def node_position(self):
        """
        Positions where the nodes will be displayed.
        """
        number_nodes = self.node_weight.shape[0]
        # Blank space budget split evenly between consecutive nodes.
        spacing = 2. * np.pi * self.spacing_factor / number_nodes

        def _positions_angular():
            # Walk the periphery, yielding each node's angular mid-point.
            position_end = -spacing
            for size in self.node_geometry_size.total.values:
                position_start = position_end + spacing
                position_end = position_start + size[1]
                yield (position_start + position_end) / 2.

        positions_angular = pd.Series(list(_positions_angular()),
                                      index=self.node_geometry_size.index,
                                      name="angular")
        starts_source = (
            positions_angular - self.node_angular_size.total / 2.
        ).rename("start_source")
        positions_angular_source = (
            starts_source + self.node_angular_size.source / 2.
        ).rename("angular")
        positions_source = positions_angular_source.apply(
            lambda position_angular: (self.outer_circle.radius,
                                      position_angular))
        starts_target = (
            positions_angular_source + self.node_angular_size.source / 2.
        ).rename("start_target")
        positions_angular_target = (
            starts_target + self.node_angular_size.target / 2.
        ).rename("angular")
        positions_target = positions_angular_target.apply(
            lambda position_angular: (self.inner_circle.radius,
                                      position_angular))
        return pd.concat([positions_source, positions_target],
                         axis=1,
                         keys=["source", "target"])

    def point_at(self, radius, angle):
        """Cartesian point at polar (radius, angle), rotated and centered."""
        return self.center + np.array([
            radius * np.sin(angle + self.rotation),
            radius * np.cos(angle + self.rotation)])

    def arc(self, radius, angle_begin, angle_end, label=""):
        """Points along a circular arc of this chart."""
        if not label:
            label = "{}---{}".format(angle_begin, angle_end)
        return Arc(label=label,
                   center=self.center,
                   radius=radius,
                   rotation=self.rotation,
                   angle_begin=angle_begin,
                   angle_end=angle_end).points()

    def flow_curve(self, radius, angle_begin, angle_end, label=""):
        """
        Curve of a flow.

        The curve is an arc of an auxiliary circle that meets the chart's
        circle of the given `radius` tangentially at both end angles.
        """
        angle_mean = (angle_begin + angle_end) / 2.
        angle_min = np.minimum(angle_begin, angle_end)
        angle_off = angle_mean - angle_min
        angle_center = np.pi - 2 * angle_off
        angle_rotation = np.pi / 2. - angle_min
        # Distance from the chart center to the auxiliary arc's center.
        length = radius / np.cos(angle_off)
        radius_arc = radius * np.tan(angle_off)
        center_arc = self.center + np.array([
            length * np.sin(angle_mean + self.rotation),
            length * np.cos(angle_mean + self.rotation)])
        if not label:
            label = "{}==>{}".format(angle_begin, angle_end)
        # Arc orientation depends on the direction of travel.
        rotation = (
            self.rotation - angle_rotation
            if angle_begin < angle_end else
            self.rotation - angle_rotation - angle_center)
        angle_end = (
            -angle_center
            if angle_begin < angle_end else
            angle_center)
        return Arc(label=label,
                   center=center_arc,
                   radius=radius_arc,
                   rotation=rotation,
                   angle_begin=0.,
                   angle_end=angle_end).points()

    def get_flow_position(self, node_geometry, flow_geometry):
        """
        Where should flow geometry be situated on node geometry.

        Flows are packed onto a node in arrival order; the node's
        `flow_positions` list caches (flow, (start, end)) assignments so
        repeated queries for the same flow are stable.
        """
        assert node_geometry in (
            flow_geometry.begin_node, flow_geometry.end_node),\
            "{} not in ({}, {})".format(
                node_geometry.label,
                flow_geometry.begin_node.label,
                flow_geometry.end_node.label)
        flow_size = (
            flow_geometry.size_end
            if node_geometry == flow_geometry.end_node else
            flow_geometry.size_begin)
        spaced = lambda pos: self.node_flow_spacing_factor * pos
        start = spaced(
            node_geometry.position.angular - node_geometry.size.angular / 2)
        if not node_geometry.flow_positions:
            position = (start, start + flow_size)
            node_geometry.flow_positions.append((flow_geometry, position))
        else:
            for geometry, position in node_geometry.flow_positions:
                if geometry == flow_geometry:
                    # Already placed --- return the cached slot.
                    return position
                # Next free slot starts after the last placed flow.
                start = spaced(position[1])
            position = (start, start + flow_size)
            node_geometry.flow_positions.append((flow_geometry, position))
        return position

    @lazyfield
    def flow_geometry_size(self):
        """
        Angular sizes of the flow geometries at their source ("begin")
        and target ("end") nodes, one column each.
        """
        def _flow_sizes(node_type):
            # Each flow's share of its node's angular size, proportional
            # to its link weight over the node's total flow.
            # BUG FIX: an unreachable `raise RuntimeError(...)` followed
            # the if/else (both branches return) --- dead code, removed.
            assert node_type in ("source", "target")
            if node_type == "source":
                nodes = self.link_data.index.get_level_values("begin_node")
                return self.node_angular_size.source.loc[nodes].values * (
                    self.link_data
                    / self.node_flow.outgoing.loc[nodes].values
                ).rename("begin")
            else:
                nodes = self.link_data.index.get_level_values("end_node")
                return self.node_angular_size.target.loc[nodes].values * (
                    self.link_data
                    / self.node_flow.incoming.loc[nodes].values
                ).rename("end")

        return pd.concat([_flow_sizes("source"), _flow_sizes("target")],
                         axis=1)

    def draw(self, draw_diagonal=True, *args, **kwargs):
        """
        Draw this `Chart`.

        Arguments
        -----------
        draw_diagonal : when False, self-loop flows (begin == end node)
        are skipped.
        """
        plt.figure(figsize=(self.size, self.size))
        axes = plt.gca()
        s = self.axes_size + self.margin
        axes.set(xlim=(-s, s), ylim=(-s, s))
        # NOTE(review): the circles are drawn without the `axes` argument
        # the node/flow geometries receive --- confirm `Circle.draw`
        # targets the current axes.
        self.outer_circle.draw(*args, **kwargs)
        self.inner_circle.draw(*args, **kwargs)
        for node_geometry in self.source_geometries.values():
            node_geometry.draw(axes, *args, **kwargs)
        for node_geometry in self.target_geometries.values():
            node_geometry.draw(axes, *args, **kwargs)
        for flow_geometry in self.flow_geometries.values():
            if (flow_geometry.begin_node == flow_geometry.end_node
                    and not draw_diagonal):
                continue
            flow_geometry.draw(axes, *args, **kwargs)
class FlowArcGeometry(ChartGeometry, Polygon):
    """
    Geometry to represent a flow from a begin node to an end node.
    """
    chart = Field("""
        The network chart that defines this `NodeGeometry`'s behavior.
        """)
    begin_node = Field("""
        A `NodeGeometry` instance where this `FlowGeometry` starts.
        """)
    end_node = Field("""
        A `NodeGeometry` instance where this `FlowGeometry` ends.
        """)
    size_begin = Field("""
        Size at beginning.
        """)
    size_end = Field("""
        Size at end.
        """)
    label = LambdaField(
        """
        Label can be constructed from nodes.
        """,
        lambda self: (self.begin_node.label, self.end_node.label))

    @lazyfield
    def size(self):
        """Size of this flow --- by convention its size where it begins."""
        return self.size_begin

    @lazyfield
    def identifier(self):
        """
        Identifier can be used as a key in a mapping providing features
        for this `Geometry`.
        """
        return (self.begin_node.identifier, self.end_node.identifier)

    @lazyfield
    def position(self):
        """
        A flow lies between its begin and end nodes.
        """
        return (self.begin_node.position, self.end_node.position)

    @lazyfield
    def sides(self):
        """
        The four `Path`s bounding this flow's polygon: a base arc at the
        begin node, a forward side, a base arc at the end node, and a
        backward side closing the polygon.
        """
        arc_begin = self.chart.get_flow_position(self.begin_node, self)
        begin_base = Path(
            label=self.label,
            vertices=self.chart.arc(
                self.chart.outer_circle.radius,
                arc_begin[0],
                arc_begin[1]))
        if self.begin_node == self.end_node:
            # Self-loop: straight sides from the node edge down to the
            # inner circle and back.
            side_forward = Path(
                label=self.label,
                vertices=[
                    self.chart.point_at(
                        self.begin_node.shape.radial[0], arc_begin[1]),
                    # BUG FIX: was `self.inner_circle.radius` ---
                    # `FlowArcGeometry` has no `inner_circle`; the chart does.
                    self.chart.point_at(
                        self.chart.inner_circle.radius, arc_begin[1])])
            end_base = Path(
                label=self.label,
                vertices=self.chart.arc(
                    self.chart.inner_circle.radius,
                    arc_begin[1],
                    arc_begin[0]))
            # BUG FIX: the original nested one `point_at` call inside
            # another (passing a 2-D point where an angle is expected)
            # and produced a single-vertex path. Mirror `side_forward`
            # at the start angle instead. TODO confirm against the
            # intended self-loop shape.
            side_backward = Path(
                label=self.label,
                vertices=[
                    self.chart.point_at(
                        self.chart.inner_circle.radius, arc_begin[0]),
                    self.chart.point_at(
                        self.begin_node.shape.radial[0], arc_begin[0])])
        else:
            arc_end = self.chart.get_flow_position(self.end_node, self)
            side_forward = Path(
                label=self.label,
                vertices=self.chart.flow_curve(
                    self.chart.inner_circle.radius,
                    arc_begin[1],
                    arc_end[0]))
            end_base = Path(
                label=self.label,
                vertices=self.chart.arc(
                    self.chart.inner_circle.radius,
                    arc_end[0],
                    arc_end[1]))
            side_backward = Path(
                label=self.label,
                vertices=self.chart.flow_curve(
                    self.chart.inner_circle.radius,
                    arc_end[1],
                    arc_begin[0]))
        return [begin_base, side_forward, end_base, side_backward]

    @lazyfield
    def curve(self):
        """
        A curve along the middle of this `FlowGeometry` instance.
        """
        arc_begin = self.chart.get_flow_position(self.begin_node, self)
        angle_begin = (arc_begin[0] + arc_begin[1]) / 2
        if self.begin_node == self.end_node:
            # BUG FIX: the original listed an undefined name
            # `begin_vertex` as the first vertex --- a guaranteed
            # `NameError`. The two remaining points span the node edge
            # to the inner circle at the loop's mid-angle.
            return Path(
                label="{}_curve".format(self.label),
                vertices=[
                    self.chart.point_at(
                        self.begin_node.shape.radial[0], angle_begin),
                    self.chart.point_at(
                        self.chart.inner_circle.radius, angle_begin)])
        arc_end = self.chart.get_flow_position(self.end_node, self)
        angle_end = (arc_end[0] + arc_end[1]) / 2
        vertices = self.chart.flow_curve(
            self.chart.inner_circle.radius,
            angle_begin,
            angle_end)
        return Path(label="{}_curve".format(self.label), vertices=vertices)

    def draw(self, axes, *args, **kwargs):
        """
        Draw the `Polygon` associated with this `FlowGeometry` instance.
        """
        # NOTE(review): `axes` is not forwarded --- confirm `Polygon.draw`
        # targets the current axes itself.
        super().draw(*args, **kwargs)
        # The arrow-head code that used to follow was commented out; its
        # left-over computations were dead, except that
        # `self.facecolor[3] = 0.5` mutated the (possibly shared) face
        # color in place --- removed as an unintended side effect.
        return axes
class StructuredAnalysis(Analysis):
    """
    An analysis structured as individual components that each handle an
    independent responsibility.
    """
    author = Field("""
        An object describing the author.
        """,
        __default_value__=Author.anonymous)
    AdapterInterface = Field("""
        An class written as a subclass of `InterfaceMeta` that declares and
        documents the methods required of the adapter by this analysis.
        """,
        __type__=InterfaceMeta)
    default_adapter = Field("""
        the adapter to use if none is provided when the analysis is run
        """,
        __required__=False)
    abstract = LambdaField(
        """
        A short description of this analysis.
        """,
        lambda self: self.phenomenon.description)
    introduction = Field(
        """
        A scientific introduction to this analysis.
        """,
        __as__=Section.introduction,
        __default_value__=Section.introduction("Not Provided"))
    methods = Field("""
        A description of the algorithm / procedure used to compute the
        results, and the experimental measurement reported in this analysis.
        """,
        __as__=Section.methods,
        __default_value__=Section.methods("Not Provided."))
    measurement_parameters = Field("""
        An object providing a collection of parameters to measure with.
        This object may be of type:
        1. either `pandas.DataFrame`,
        2. or `adapter, model -> Collection<MeasurementParameters>`,
        3. or dmt.tk.Parameters.
        """,
        __as__=Parameters,
        __default_value__=nothing)
    sampling_methodology = Field("""
        A tag indicating whether this analysis will make measurements on
        random samples drawn from a relevant population of circuit
        constituents, or on the entire population. The circuit constituents
        population to be measured will be determined by a query.
        """,
        __default_value__="Not Provided")
    sample_size = Field("""
        Number of samples to measure for each set of the measurement
        parameters. This field will be relevant when the measurements are
        made on random samples. When the measurement is exhaustive, the
        whole population of (relevant) circuit constituents will be measured.
        """,
        __default_value__=20)
    sample_measurement = Field("""
        A callable that maps
        `(adapter, model, **parameters, **customizations) ==> measurement`
        where
        parameters : paramters for the measurements
        customizations : that specify the method used to make a measurement

        This field may also be implemented as a method in a subclass.
        """,
        __required__=False)
    measurement_collection = Field(
        """
        A callable that will collect measurements passed as an iterable.
        The default value assumes that the each measurement will return an
        elemental value such as integer, or floating point number.
        """,
        __default_value__=primitive_type_measurement_collection)
    plotter = Field("""
        A class instance or a module that has `plot` method that will be
        used to plot the results of this analysis. The plotter should know
        how to interpret the data provided. For example, the plotter will
        have to know which columns are the x-axis, and which the y-axis.
        The `Plotter` instance used by this `BrainCircuitAnalysis` instance
        should have those set as instance attributes.
        """,
        __required__=False)
    stats = Field("""
        An object that provides a statistical summary for the measurements
        made in this analysis. This object may be just a function that takes
        this analysis' measurements as an argument.
        """,
        __as__=Statistics,
        __default_value__=nothing)
    verdict = Field("""
        An object that provies a verdict on the measurements made in this
        analysis. This object may be just a function that takes this
        analysis' measurements as an argument.
        """,
        __default_value__=always_pass)
    results = Field("""
        A callable on relevant parameters that will return results for a
        run of this analysis.
        """,
        __as__=Section.results,
        __default_value__="Results are presented in the figure")
    conclusion = Field(
        """
        A callable on relevant parameters that will return conclusion for
        a run of this analysis.
        """,
        __as__=Section.conclusion,
        __default_value__=
        "Conclusion will be provided after a review of the results.")
    discussion = Field(
        """
        A callable on relevant parameters that will return conclusion for
        a run of this analysis.
        """,
        __as__=Section.discussion,
        __default_value__=
        "Conclusion will be provided after a review of the results.")
    reference_data = Field("""
        A pandas.DataFrame containing reference data to compare with the
        measurement made on a circuit model. Each dataset in the dataframe
        must be annotated with index level 'dataset', in addition to levels
        that make sense for the measurements.
        """,
        __default_value__=NOT_PROVIDED)
    report = Field("""
        A callable that will generate a report. The callable should be able
        to take arguments listed in `get_report(...)` method defined below.
        """,
        __default_value__=Report)
    reporter = Field("""
        A class or a module that will report the results of this analysis.
        It is up to the reporter what kind of report to generate. For
        example, the report can be a (interactive) webpage, or a static PDF.
        """,
        __default_value__=NOT_PROVIDED,
        __examples__=[Reporter(path_output_folder=os.getcwd())])

    Measurement = namedtuple("Measurement", ["method", "dataset", "data"])

    # TODO: The methods below are from HD's alpha StructuredAnalysis
    # This must be refactored...
    # TODO: this probably should not be public....
    def adapter_method(self, adapter=None):
        """
        Get the measurement method marked on the `AdapterInterface`.

        The adapter may expose the method either under the declared name or
        with a `get_` prefix.
        """
        measurement_name = self.AdapterInterface.__measurement__
        adapter = self.adapter if adapter is None else adapter
        # BUG FIX: the original returned from a `finally` block, which
        # swallows any exception and raised `UnboundLocalError` when both
        # lookups failed; let the second AttributeError propagate instead.
        try:
            return getattr(adapter, measurement_name)
        except AttributeError:
            return getattr(adapter, "get_{}".format(measurement_name))

    @lazyfield
    def label(self):
        """
        A label for this analysis, combining phenomenon and parameters.
        """
        return "{}_by_{}".format(
            self.phenomenon.label,
            '_'.join(self.names_measurement_parameters))

    # TODO: parallelize model measuring?
    def get_model_measurements(self, adapter, model, sample_size=None):
        """
        Get a statistical measurement.

        Arguments
        -----------
        sample_size : override for `self.sample_size`; must be an int
        when provided.
        """
        assert not sample_size or isinstance(sample_size, int),\
            "Expected int, received {}".format(type(sample_size))
        try:
            method = self.sample_measurement
            measurement_method =\
                lambda *args, **kwargs: method(adapter, *args, **kwargs)
        except AttributeError:
            # No `sample_measurement` provided; fall back to the adapter's
            # own measurement method.
            measurement_method = self.adapter_method(adapter)
        # BUG FIX: the `sample_size` argument used to be accepted,
        # validated, and then silently ignored in favor of
        # `self.sample_size`.
        size = sample_size if sample_size is not None else self.sample_size
        parameters = self._parameters.for_sampling(adapter, model, size=size)
        # TODO: test parameter order is preserved
        measurements = make_dataframe_hashable(
            pd.DataFrame(parameters).assign(**{
                self.phenomenon: [
                    measurement_method(model, **p)
                    for p in tqdm(parameters)]}))
        return measurements

    @property
    def number_measurement_parameters(self):
        """
        How many parameters are the measurements made with?
        For example, if the measurement parameters are region, layer,
        the number is two.

        The implementation below uses `self.measurement_parameters`; if you
        change the type of that component, you will have to override.
        """
        return self._parameters.values.shape[1]

    @property
    def names_measurement_parameters(self):
        """
        What are the names of the parameters that the measurements are
        made with?

        If measurement parameters cannot provide the variables (a.k.a
        parameter labels or tags), an empty list is returned.
        """
        # (An unreachable `return None` after the try/except was removed.)
        try:
            return self._parameters.variables
        except TypeError:
            return []

    @property
    def phenomenon(self):
        """
        The measured phenomenon: taken from `sample_measurement` when it
        declares one, else from the `AdapterInterface`, else NOT_PROVIDED.
        """
        try:
            return self.sample_measurement.phenomenon
        except AttributeError:
            try:
                return self.AdapterInterface.phenomenon
            except AttributeError:
                return NOT_PROVIDED

    @property
    def _parameters(self):
        """
        Measurement parameters as a `Parameters` instance, derived from
        `measurement_parameters` or, failing that, from `reference_data`.
        """
        if self.measurement_parameters is not NOT_PROVIDED:
            return\
                Parameters(self.measurement_parameters)\
                if not isinstance(self.measurement_parameters, Parameters)\
                else self.measurement_parameters
        elif self.reference_data is not NOT_PROVIDED:
            # Use the reference data's non-phenomenon columns as parameters.
            return\
                Parameters(
                    self.reference_data.drop(columns=[self.phenomenon]))
        else:
            raise ValueError("""
            {} has neither measurement_parameters nor reference_data
            provide one or the other
            """.format(self))

    def with_fields(self, **kwargs):
        """
        A copy of this analysis with the given fields overridden.
        """
        for field in self.get_fields():
            if field not in kwargs:
                try:
                    kwargs[field] = getattr(self, field)
                except AttributeError:
                    # Unset optional field --- leave it out of the copy.
                    pass
        return self.__class__(**kwargs)

    def validation(self, circuit_model, adapter=None, *args, **kwargs):
        """
        Validation of a model against reference data.
        """
        # NOTE(review): `reference_data` may be NOT_PROVIDED, which has no
        # `.empty` --- confirm callers always configure reference data.
        assert not self.reference_data.empty,\
            "Validation needs reference data."
        # BUG FIX: the branches were inverted --- `__call__` expects an
        # (adapter, model) tuple only when an adapter is explicitly given;
        # with no adapter it falls back to `default_adapter`.
        if adapter is None:
            return self(circuit_model)
        return self((adapter, circuit_model))

    def _get_report(self, measurements):
        """
        Build a `Report` from collected measurements, attaching figures,
        statistics and a verdict. Plotting failures degrade to a warning
        message carried in place of the figure.
        """
        try:
            fig = self.plotter(measurements, phenomenon=self.phenomenon)
        except Exception as e:
            import traceback
            # BUG FIX: a stray fourth quote (`""""`) prepended a literal
            # '"' to this message.
            fig = """
            Plotting failed: {}, {}
            returning raw measurments""".format(e, traceback.format_exc())
            warnings.warn(Warning(fig))
        try:
            len(fig)
        except TypeError:
            # A single figure --- wrap so `figures` is always a sequence.
            fig = [fig]
        # TODO: until field is fixed, this will raise for empty docstrings
        # TEMPORARY workaround:
        if self.sample_measurement.__doc__ is None:
            self.sample_measurement.__doc__ = ''
        report = Report(
            figures=fig,
            measurement=measurements,
            phenomenon=self.phenomenon,
            methods=self.sample_measurement.__doc__)
        report.stats = self.stats(report)
        report.verdict = self.verdict(report)
        return report

    def __call__(self, *models):
        """
        perform an analysis of 'models'

        Each model may be a bare model (measured with `default_adapter`)
        or an `(adapter, model)` tuple.
        """
        measurements = OrderedDict()
        if self.reference_data is not NOT_PROVIDED:
            measurements[
                _label(self.reference_data, default='reference_data')] =\
                make_dataframe_hashable(self.reference_data)
        for i, model in enumerate(models):
            if isinstance(model, tuple):
                adapter, model = model
            else:
                adapter = self.default_adapter
            measurements[_label(model, default='model', i=i)] =\
                self.get_model_measurements(adapter, model)
        report = self._get_report(measurements)
        if self.reporter is not NOT_PROVIDED:
            self.reporter.post(report)
        return report
class Report(WithFields):
    """
    Report base class.

    We follow the principle of IMRAD
    (Introduction, Methods, Results, and Discussion:
    https://en.wikipedia.org/wiki/IMRAD )
    """
    author = Field("""
        Author of this report.
        """,
        __default_value__=Author.anonymous)
    phenomenon = Field("""
        Label for the phenomenon that this report is about.
        """,
        __default_value__=NA)
    figures = Field("""
        A dict mapping label to an object with a `.graphic` and `.caption`
        attributes.
        """,
        __default_value__=NA)
    measurement = Field("""
        Measurement associated with this `Report`. This should be a
        dataframe, with a properly annotated index.
        """,
        __default_value__=NA)
    # Narrative sections are coerced through `paragraphs` by `__as__`.
    abstract = Field("""
        Provide an abstract for the report.
        """,
        __default_value__=NA,
        __as__=paragraphs)
    introduction = Field("""
        Provide the research question, and the tested hypothesis or the
        purpose of the research?
        """,
        __default_value__=NA,
        __as__=paragraphs)
    methods = Field("""
        Describe the algorithm / procedure used to compute the results or
        the experimental measurement presented in this `Report`. This
        `Field` will be used in the figure caption.
        """,
        __default_value__=NA,
        __as__=paragraphs)
    sections = Field("""
        An ordered list of report sections.
        """,
        __default_value__=NA)
    chapters = Field("""
        An ordered list of report chapters.
        """,
        __default_value__=NA)
    results = Field("""
        Answer to the research question, to be included in the figure
        caption.
        """,
        __default_value__=NA,
        __as__=paragraphs)
    discussion = Field("""
        A longer text describing how the results presented in the report
        fit in the existing knowledge about the topic. What might the
        answer imply and why does it matter? How does it fit in with what
        other researchers have found? What are the perspectives for future
        research?
        """,
        __default_value__=NA,
        __as__=paragraphs)
    references = Field("""
        References for this analysis report.
        """,
        __default_value__=NA)
    # Defaults to the phenomenon label; used to name saved artifacts.
    label = LambdaField(
        """
        Label for this report to save data.
        """,
        lambda self: self.phenomenon)

    @lazyfield
    def field_values(self):
        """
        Values of this report's fields, keyed for use in a report template.
        """
        try:
            name_phenomenon = self.phenomenon.name
        except AttributeError:
            # `phenomenon` may be a plain string rather than an object
            # carrying a `.name`.
            name_phenomenon = make_name(self.phenomenon, separator='_')
        return {
            "author": self.author,
            "phenomenon": name_phenomenon,
            "figures": self.figures,
            "introduction": self.introduction,
            "methods": self.methods,
            "results": self.results,
            "discussion": self.discussion,
            "references": self.references,
            "sections": self.sections,
            "chapters": self.chapters
        }
class NetworkFlows(WithFields):
    """
    Illustrate the strength of links in a network as flows.
    """
    title = Field(
        """
        Title to be displayed.
        If not provided, phenomenon for the data will be used.
        """,
        __default_value__="")
    node_variable = Field(
        """
        Name of the variable (i.e. column name) for the nodes.
        """,
        __examples__=["mtype"])
    pre_variable = LambdaField(
        """
        Variable (i.e. column name) providing values of the pre-nodes.
        """,
        lambda self: "pre_{}".format(self.node_variable))
    post_variable = LambdaField(
        """
        Variable (i.e. column name) providing values of the post-nodes.
        """,
        lambda self: "post_{}".format(self.node_variable))
    node_type = Field(
        """
        Type of node values.
        """,
        __default_value__=str)
    phenomenon = Field(
        """
        Variable (i.e. column name) providing the phenomenon whose value
        will be plotted as flows.
        """)

    @property
    def dataframe(self):
        """
        The last dataframe set on this instance, normalized by the setter
        to columns (begin_node, end_node, weight); `None` until one is set.
        """
        try:
            return self._dataframe
        except AttributeError:
            self._dataframe = None
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe):
        """Interpret dataframe as a geometric data."""
        dataframe = dataframe.reset_index()
        try:
            # Statistical summaries carry a (phenomenon, "mean") column.
            weight = dataframe[self.phenomenon]["mean"].values
        except KeyError:
            weight = dataframe[self.phenomenon].values
        self._dataframe =\
            pd.DataFrame({
                "begin_node": dataframe[self.pre_variable].values,
                "end_node": dataframe[self.post_variable].values,
                "weight": weight})

    @lazyfield
    def network_geom(self):
        """
        Geometry to hold the graphic.
        """
        return NetworkGeom(label=self.title)

    def _norm_efferent(self, dataframe, aggregator="sum"):
        """
        Efferent value is the exiting value.
        """
        efferent_values =\
            dataframe.groupby(self.pre_variable)\
                     .agg(aggregator)
        efferent_values.index.name = self.node_variable
        return efferent_values

    def _norm_afferent(self, dataframe, aggregator="sum"):
        """
        Afferent value is the entering value.
        """
        afferent_values =\
            dataframe.groupby(self.post_variable)\
                     .agg(aggregator)
        afferent_values.index.name = self.node_variable
        return afferent_values

    @lazyfield
    def node_weights(self):
        """
        Total out-going weight per node, normalized to sum to 1,
        as a `pandas.Series` indexed by node label.
        """
        assert self.dataframe is not None
        # BUG FIX: `self.dataframe` holds the normalized columns
        # (begin_node, end_node, weight) produced by the setter; the
        # original grouped by `self.pre_variable`, which no longer exists
        # here. Selecting one column also makes this a Series, which is
        # what the original's `.name = "weight"` and `link_weights`'s
        # label-indexing both assume.
        # NOTE(review): confirm `NetworkGeom.spawn_nodes` accepts a Series.
        weights =\
            self.dataframe\
                .groupby("begin_node")["weight"]\
                .agg("sum")
        weights.index.name = "label"
        weights.name = "weight"
        return weights / np.sum(weights)

    @lazyfield
    def link_weights(self):
        """
        Per-link weight, normalized by the begin node's total weight;
        a dataframe indexed by (begin_node, end_node).
        """
        assert self.dataframe is not None
        # BUG FIX: read the normalized columns written by the `dataframe`
        # setter instead of the raw input's column names, which raised
        # `KeyError` here.
        begin_nodes = self.dataframe["begin_node"].values
        end_nodes = self.dataframe["end_node"].values
        weight = self.dataframe["weight"].values
        return\
            pd.DataFrame(dict(
                begin_node=begin_nodes,
                end_node=end_nodes,
                weight=weight / self.node_weights[begin_nodes].values))\
              .set_index(["begin_node", "end_node"])

    def get_figure(self, dataframe, *args, caption="Caption not provided", **kwargs):
        """
        Plot the data.

        Arguments
        -----------
        dataframe :: <pre_variable, post_variable, phenomenon>
        """
        # BUG FIX: validate the *input* before normalizing --- the
        # original asserted on `self.dataframe`, whose columns have
        # already been renamed to (begin_node, end_node, weight) by the
        # setter and so always failed.
        flat = dataframe.reset_index()
        assert self.pre_variable in flat.columns,\
            "{} not in {}".format(self.pre_variable, flat.columns)
        assert self.post_variable in flat.columns,\
            "{} not in {}".format(self.post_variable, flat.columns)
        self.dataframe = dataframe
        # (Removed dead code: an unbound `NetworkGeom.draw(self.dataframe)`
        # call whose result was discarded, and unused pre/post node sets.)
        self.network_geom.spawn_nodes(self.node_weights)
        self.network_geom.spawn_flows(self.link_weights)
class ConnectomeAnalysesSuite(WithFields):
    """
    Analyze the connectome of a brain circuit model.
    """
    AdapterInterface = AdapterInterface

    sample_size = Field(
        """
        Number of individual sample measurements for each set of
        parameter values.
        """,
        __default_value__=100)
    path_reports = Field(
        """
        Location where the reports will be posted.
        """,
        __default_value__=os.path.join(os.getcwd(), "reports"))
    cell_mtypes = Field(
        """
        Cell type for which pathways will be analyzed.
        """,
        __default_value__=[])
    pre_synaptic_mtypes = LambdaField(
        """
        Pre synaptic cell type for which pathways will be analyzed.
        """,
        lambda self: self.cell_mtypes)
    post_synaptic_mtypes = LambdaField(
        """
        Post synaptic cell type for which pathways will be analyzed.
        """,
        lambda self: self.cell_mtypes)

    def pre_synaptic_cell_type(self, circuit_model, adapter):
        """
        Mtypes to analyze on the pre-synaptic side; all of the circuit's
        mtypes when none were configured.
        """
        return self.pre_synaptic_mtypes\
            if len(self.pre_synaptic_mtypes) > 0 else\
            adapter.get_mtypes(circuit_model)

    def post_synaptic_cell_type(self, circuit_model, adapter):
        """
        Mtypes to analyze on the post-synaptic side; all of the circuit's
        mtypes when none were configured.
        """
        return self.post_synaptic_mtypes\
            if len(self.post_synaptic_mtypes) > 0 else\
            adapter.get_mtypes(circuit_model)

    @lazyfield
    def parameters_pre_synaptic_cell_mtypes(self):
        """
        Measurement parameters over pre-synaptic mtypes, with the column
        ("pre_synaptic_cell", "mtype").
        """
        def _mtypes(adapter, circuit_model):
            mtypes =\
                self.pre_synaptic_mtypes\
                if len(self.pre_synaptic_mtypes) > 0 else\
                adapter.get_mtypes(circuit_model)
            return pd.DataFrame({("pre_synaptic_cell", "mtype"): mtypes})
        return\
            Parameters(_mtypes)

    @lazyfield
    def parameters_post_synaptic_cell_mtypes(self):
        """
        Measurement parameters over post-synaptic mtypes, with the column
        ("post_synaptic_cell", "mtype").
        """
        def _mtypes(adapter, circuit_model):
            mtypes =\
                self.post_synaptic_mtypes\
                if len(self.post_synaptic_mtypes) > 0 else\
                adapter.get_mtypes(circuit_model)
            return pd.DataFrame({("post_synaptic_cell", "mtype"): mtypes})
        return\
            Parameters(_mtypes)

    @staticmethod
    def _at(role_synaptic, as_tuple=True):
        """
        Create a renamer that tags a variable with its synaptic role,
        either as a tuple (role, variable) for multi-level columns, or as
        a flat "role_variable" string.
        """
        def _rename(variable):
            return\
                "{}_{}".format(role_synaptic, variable)\
                if not as_tuple else\
                (role_synaptic, variable)
        return _rename

    @staticmethod
    def get_soma_distance_bins(
            circuit_model,
            adapter,
            cell,
            cell_group,
            bin_size=100.,
            bin_mids=True):
        """
        Get binned distance of `cell`'s soma from soma of all the cells in
        `cell_group`.

        Returns bin mid-points when `bin_mids`, otherwise
        [bin_start, bin_size] pairs.
        """
        distance =\
            adapter.get_soma_distance(
                circuit_model, cell, cell_group)
        bin_starts =\
            bin_size * np.floor(distance / bin_size)
        return\
            [bin_start + bin_size / 2. for bin_start in bin_starts]\
            if bin_mids else\
            [[bin_start, bin_size] for bin_start in bin_starts]

    @staticmethod
    def get_random_cells(circuit_model, adapter, cell_type, number=1):
        """
        Sample up to `number` cells of the given `cell_type`; the whole
        group is returned when it has fewer than `number` cells.

        Arguments
        ------------
        cell_type :: pandas.Series<CellProperty>
        """
        group_cells =\
            adapter.get_cells(circuit_model, **cell_type)
        return\
            group_cells if group_cells.shape[0] < number\
            else group_cells.sample(n=number)

    @staticmethod
    def random_cell(circuit_model, adapter, cell_type):
        """Only one random cell"""
        return\
            ConnectomeAnalysesSuite.get_random_cells(
                circuit_model, adapter, cell_type, number=1
            ).iloc[0]

    @measurement_method("""
    Number of afferent connections are computed as a function of the
    post-synaptic cell type. Cell-type can be any defined by the values of
    a neuron's properties, for example layer, mtype, and etype. A group of
    cells are sampled with the given (post-synaptic) cell-type. For each
    of these post-synaptic cells, incoming connections from all other
    neurons are grouped by cell type of the neuron where they start (i.e.
    their pre-synaptic cell-type) and their soma-distance from the
    post-synaptic cell under consideration. Afferent connection count or
    in-degree of a post-synaptic cell is defined as the number of
    pre-synaptic cells in each of these groups.
    """)
    @count_number_calls(LOGGER)
    def number_connections_afferent_verbose(self,
            circuit_model,
            adapter,
            cell,
            cell_properties_groupby,
            by_soma_distance=True,
            bin_size_soma_distance=100.,
            with_cell_propreties_as_index=True):
        """
        Count connections afferent on `cell`, grouped by the given cell
        properties and, optionally, binned soma distance.

        Note: the parameter `with_cell_propreties_as_index` keeps its
        (misspelled) name for backward compatibility with callers.
        """
        variable = "number_connections_afferent"
        # Each afferent connection contributes a count of one.
        value = lambda cnxns: np.ones(cnxns.shape[0])

        def _soma_distance(other_cells):
            return\
                self.get_soma_distance_bins(
                    circuit_model, adapter,
                    cell, other_cells,
                    bin_size=bin_size_soma_distance)

        variables_groupby =\
            cell_properties_groupby + (
                ["soma_distance"] if by_soma_distance else [])
        connections =\
            adapter.get_afferent_connections(
                circuit_model, cell)
        columns_relevant =\
            cell_properties_groupby + (
                [variable, "soma_distance"]
                if by_soma_distance else [variable])
        # One row per afferent connection, carrying the pre-side cell's
        # properties and a unit count.
        cells_afferent =\
            adapter.get_cells(circuit_model)\
                   .loc[connections.pre_gid.values]\
                   .assign(**{variable: value(connections)})
        if by_soma_distance:
            cells_afferent =\
                cells_afferent.assign(soma_distance=_soma_distance)
        value_measurement =\
            cells_afferent[columns_relevant].groupby(variables_groupby)\
                                            .agg("sum")
        return\
            value_measurement[variable]\
            if with_cell_propreties_as_index else\
            value_measurement.reindex()

    def number_connections_afferent(self, by_soma_distance):
        """
        A single-cell, random-sampling measurement of afferent connection
        counts.
        """
        return\
            PathwayMeasurement(
                value=lambda connection: np.ones(connection.shape[0]),
                variable="number_connections_afferent",
                by_soma_distance=by_soma_distance,
                direction="AFF",
                specifiers_cell_type=["mtype"],
                sampling_methodology=terminology.sampling_methodology.random
            ).sample_one

    def number_connections_pathway(self, by_soma_distance):
        """
        An exhaustive, summarized measurement of pathway connection counts.

        NOTE(review): `variable` is "number_connections_afferent" here as
        well --- presumably intended; confirm before renaming.
        """
        return\
            PathwayMeasurement(
                value=lambda connection: np.ones(connection.shape[0]),
                variable="number_connections_afferent",
                by_soma_distance=by_soma_distance,
                direction="AFF",
                specifiers_cell_type=["mtype"],
                sampling_methodology=terminology.sampling_methodology.exhaustive,
                summaries="sum"
            ).summary

    def strength_connections_afferent(self, by_soma_distance):
        """
        A single-cell, random-sampling measurement of afferent connection
        strengths.
        """
        return\
            PathwayMeasurement(
                value=lambda cnxn: cnxn.strength.to_numpy(np.float64),
                by_soma_distance=by_soma_distance,
                direction="AFF",
                specifiers_cell_type=["mtype"],
                sampling_methodology=terminology.sampling_methodology.random,
                summaries="sum"
            ).sample_one

    def strength_pathways(self, by_soma_distance):
        """
        An exhaustive, summarized measurement of pathway strengths.
        """
        return\
            PathwayMeasurement(
                value=lambda cnxn: cnxn.strength.to_numpy(np.float64),
                variable="pathway_strength",
                by_soma_distance=by_soma_distance,
                direction="AFF",
                specifiers_cell_type=["mtype"],
                sampling_methodology=terminology.sampling_methodology.exhaustive,
                summaries="sum"
            ).summary

    @lazyfield
    def analysis_number_connections_afferent(self):
        """
        Analyze number of incoming connections.
        """
        return BrainCircuitAnalysis(
            introduction="""
            A circuit model should reproduce experimentally measured number
            of incoming connections of a cell. Here we analyze number of
            afferent connections to cells of a given (post-synaptic)
            cell-type, grouped by the cell-types of the post-synaptic
            cells. For example, if cell-type is defined by a cell's mtype,
            then given a pre-synaptic-mtype-->post-synaptic-mtype pathway,
            we analyze number of afferent connections incident upon the
            group of cells with the given post-synaptic mtype that
            originate from the group of all cells with the given
            pre-synaptic mtype.
            """,
            methods="""
            For each of pre-synaptic and post-synaptic cell type in the
            given pathway {}, cells were sampled. Connection probability
            was computed as the number of cell pairs that were connected
            to the total number of pairs.""".format(self.sample_size),
            phenomenon=Phenomenon(
                "Number Afferent Connections",
                """
                Probability that two neurons in a pathway are connected.
                While most of the interest will be in `mtype-->mtype`
                pathways, we can define a pathway as a any two group of
                cells, one on the afferent side, the other on the efferent
                side of a (possible) synapse. Given the pre-synaptic and
                post-synaptic cell types (groups), connection probability
                counts the fraction of connected pre-synaptic,
                post-synaptic pairs. Connection probability may be
                calculated as a function of the soma-distance between the
                cells in a pair, in which case the measured quantity will
                be vector-valued data such as a `pandas.Series`.
                """),
            AdapterInterface=self.AdapterInterface,
            measurement_parameters=self.parameters_post_synaptic_cell_mtypes,
            # `number_connections_afferent` takes only `by_soma_distance`
            # and already returns the measurement's `sample_one` --- the
            # original passed an extra positional argument and accessed
            # `.sample_one` a second time, raising a `TypeError`.
            sample_measurement=self.number_connections_afferent(
                by_soma_distance=True),
            measurement_collection=measurement.collection.series_type,
            plotter=MultiPlot(
                mvar=("post_synaptic_cell", "mtype"),
                plotter=LinePlot(
                    xvar="soma_distance",
                    xlabel="Soma Distance",
                    yvar="number_afferent_connections",
                    ylabel="Mean number of afferent connections",
                    gvar=("pre_synaptic_cell", "mtype"),
                    drawstyle="steps-mid")),
            report=CircuitAnalysisReport)

    @count_number_calls(LOGGER)
    def strength_connections_afferent_verbose(self,
            circuit_model,
            adapter,
            cell,
            cell_properties_groupby,
            by_soma_distance=True,
            bin_size_soma_distance=100.):
        """
        Strength of connections afferent on a single `cell`, grouped by
        the given cell properties and, optionally, binned soma distance.

        NOTE(review): renamed from `strength_connections_afferent`, which
        collided with --- and silently shadowed --- the
        `PathwayMeasurement`-based method of the same name defined above.
        The shadowing version raised a `NameError` on every call (see
        below), so no working caller could depend on the old name; the
        new name follows `number_connections_afferent_verbose`.
        """
        def _soma_distance(other_cells):
            return\
                self.get_soma_distance_bins(
                    circuit_model, adapter,
                    cell, other_cells,
                    bin_size=bin_size_soma_distance)

        connections =\
            adapter.get_afferent_connections(
                circuit_model, cell)
        variables_measurement =\
            dict(strength=connections.strength.to_numpy(np.float64),
                 soma_distance=_soma_distance)\
            if by_soma_distance else\
            dict(strength=connections.strength.to_numpy(np.float64))
        # Groupby variables must be defined before the relevant columns
        # (the original computed `columns_relevant` from
        # `variables_groupby` before assigning it, raising a `NameError`,
        # and would have listed `soma_distance` twice).
        variables_groupby =\
            cell_properties_groupby + (
                ["soma_distance"] if by_soma_distance else [])
        columns_relevant =\
            variables_groupby + ["strength"]
        cells_afferent =\
            adapter.get_cells(circuit_model)\
                   .loc[connections.pre_gid.values]
        return\
            cells_afferent.assign(**variables_measurement)[columns_relevant]\
                          .groupby(variables_groupby)\
                          .agg("sum")\
                          .strength

    @count_number_calls(LOGGER)
    def strength_afferent_connections(self,
            circuit_model,
            adapter,
            post_synaptic_cell_type,
            pre_synaptic_cell_type_specifier=None,
            by_soma_distance=True,
            bin_size=100,
            sampling_methodology=terminology.sampling_methodology.random):
        """
        Strength of afferent connections incident on the post-synaptic
        cells, originating from the pre-synaptic cells. Strength of
        connection is the number of synapses mediating that connection.

        Arguments
        --------------
        post_synaptic_cell_type :: An object describing the group of
        ~   post-synaptic cells to be investigated in these analyses.
        ~   Interpretation of the data in this object will be delegated
        ~   to the adapter used for the model analyzed.
        ~   Here are some guidelines for when this object is a dictionary.
        ~   Such a dictionary will have cell properties such as region,
        ~   layer, mtype, and etype as keys. Each key may be given either
        ~   a single value or an iterable of values. Phenomena must be
        ~   evaluated for each of these values and collected as a
        ~   pandas.DataFrame.
        """
        LOGGER.debug(
            """
            Strength afferent connections for post-synaptic cell type: {}
            """.format(post_synaptic_cell_type))
        if pre_synaptic_cell_type_specifier is None:
            pre_synaptic_cell_type_specifier =\
                list(post_synaptic_cell_type.keys())

        def _prefix_pre_synaptic(variable):
            # NOTE(review): the role tag here is "pre_synaptic_cell_type",
            # while sibling methods tag with "pre_synaptic_cell" ---
            # confirm which one downstream consumers expect.
            return\
                self._at("pre_synaptic_cell_type", as_tuple=True)(variable)\
                if variable in pre_synaptic_cell_type_specifier\
                else variable

        variables_groupby = [
            _prefix_pre_synaptic(variable)
            for variable in pre_synaptic_cell_type_specifier]
        if by_soma_distance:
            variables_groupby.append("soma_distance")

        post_synaptic_cell =\
            self.random_cell(
                circuit_model, adapter, post_synaptic_cell_type)

        def _soma_distance(pre_cells):
            return\
                self.get_soma_distance_bins(
                    circuit_model, adapter,
                    post_synaptic_cell, pre_cells,
                    bin_size=bin_size)

        connections =\
            adapter.get_afferent_connections(
                circuit_model, post_synaptic_cell)
        # Both branches must measure `strength`: the result below is
        # reduced to the `strength` column. (The original's `else` branch
        # measured `number_connections_afferent`, which raised a
        # `KeyError` whenever `by_soma_distance` was `False`.)
        variables_measurement =\
            dict(strength=connections.strength.to_numpy(dtype=np.float64),
                 soma_distance=_soma_distance)\
            if by_soma_distance else\
            dict(strength=connections.strength.to_numpy(dtype=np.float64))
        return\
            adapter.get_cells(circuit_model)\
                   .loc[connections.pre_gid.values]\
                   .assign(**variables_measurement)\
                   .rename(columns=_prefix_pre_synaptic)\
                   [variables_groupby + ["strength"]]\
                   .groupby(variables_groupby)\
                   .agg("sum")\
                   .strength

    def synapse_count(self, adapter, circuit_model, pathway):
        """
        Get synapse count ...
        """
        raise NotImplementedError

    def analysis_synapse_count(self,
            pre_synaptic_cell_type_specifiers,
            post_synaptic_cell_type_specifiers):
        """
        Analysis of number of synapses in pathways specified by values of
        pre and post synaptic cell properties.

        Arguments
        ----------------
        pre_synaptic_cell_type_specifiers :: cell properties
        post_synaptic_cell_type_specifiers :: cell properties
        """
        # NOTE(review): `self.pathways` is not defined in this class ---
        # presumably provided elsewhere; confirm before use.
        return BrainCircuitAnalysis(
            introduction="""
            Not provided.
            """,
            methods="""
            Not provided.
            """,
            phenomenon=Phenomenon(
                # Fixed user-facing typo: was "Sysapse count".
                "Synapse count",
                """
                Number of synapses in a pathway.
                """),
            AdapterInterface=self.AdapterInterface,
            measurement_parameters=self.pathways(
                pre_synaptic_cell_type_specifiers,
                post_synaptic_cell_type_specifiers),
            sample_measurement=self.synapse_count,
            measurement_collection=measurement.collection.primitive_type,
            plotter=HeatMap())
class Crosses(BasePlotter):
    """
    A plot that compares measurements of a phenomena across two datasets.
    """
    vvar = Field(
        """
        Variable (column) that provides the quantity to be compared.
        """)
    fmt = Field(
        """
        Point type to be plotted.
        """,
        __default_value__='o')
    xerr = LambdaField(
        """
        Column in the dataframe for plotting errors along X.
        """,
        lambda self: "{}_err".format(self.xvar))
    yerr = LambdaField(
        """
        Column in the dataframe for plotting errors along Y.
        """,
        lambda self: "{}_err".format(self.yvar))

    @staticmethod
    def _get_phenomenon(dataframe_long):
        """
        Get phenomenon to be plotted.
        Assumes a two-level column index whose first level names the
        phenomenon.
        """
        return dataframe_long.columns[0][0]

    def _get_plotting_data(self, dataframe_long):
        """
        Convert `dataframe_long` to wide-form plotting data: one row per
        index entry common to the two datasets, with columns
        <xvar, yvar, xerr, yerr>.
        """
        datasets =\
            dataframe_long.reset_index().dataset.unique()
        if len(datasets) != 2:
            # Exactly two datasets can be crossed against each other.
            # (The original message claimed "more than two datasets"
            # even when there were fewer than two.)
            raise TypeError(
                """
                Dataframe for plotting must have exactly two datasets,
                found {}: {}
                """.format(len(datasets), datasets))
        statistic = ["mean", "std"]
        dataframe_x =\
            dataframe_long.xs(self.xvar, level="dataset")[self.vvar]\
                          .dropna()
        if isinstance(dataframe_x, pd.DataFrame):
            dataframe_x = dataframe_x[statistic]
        # Align the second dataset on the first's index before dropping
        # missing entries.
        dataframe_y =\
            dataframe_long.xs(self.yvar, level="dataset")[self.vvar]\
                          .reindex(dataframe_x.index)\
                          .dropna()
        if isinstance(dataframe_y, pd.DataFrame):
            dataframe_y = dataframe_y[statistic]
        if dataframe_y.shape[0] != dataframe_x.shape[0]:
            LOGGER.warn(
                LOGGER.get_source_info(),
                """
                Dataframe for plotting had different number of elements
                for the two datasets to be plotted:
                {}: {}
                {}: {}
                """.format(
                    self.xvar, dataframe_x.shape[0],
                    self.yvar, dataframe_y.shape[0]))
            # Keep only the entries present in both datasets.
            dataframe_x =\
                dataframe_x.reindex(dataframe_y.index)\
                           .dropna()
        try:
            return\
                pd.DataFrame(
                    {self.xvar: dataframe_x["mean"].values,
                     self.yvar: dataframe_y["mean"].values,
                     self.xerr: dataframe_x["std"].values,
                     self.yerr: dataframe_y["std"].values},
                    index=dataframe_x.index)\
                  .reset_index()
        except KeyError:
            # Plain values, without summary statistics: no error bars.
            return\
                pd.DataFrame(
                    {self.xvar: dataframe_x.values,
                     self.yvar: dataframe_y.values,
                     self.xerr: 0.,
                     self.yerr: 0.},
                    index=dataframe_x.index)\
                  .reset_index()

    def _get_title(self, dataframe_long):
        """
        Get a title to display.
        """
        if self.title:
            return self.title
        phenomenon = self._get_phenomenon(dataframe_long)
        return ' '.join(
            word.capitalize() for word in phenomenon.split('_'))

    def get_dataframe(self, data):
        """
        Coerce `data` into a single long-form dataframe, concatenating
        summaries when a dict of measurements is given.
        """
        return data\
            if isinstance(data, (pd.Series, pd.DataFrame)) else\
            measurement.concat_as_summaries(data).reset_index()

    def get_figure(self,
            data,
            *args,
            caption="Caption not provided",
            **kwargs):
        """
        Plot the data.

        Arguments
        -----------
        data : A dict mapping dataset to dataframe.
        """
        dataframe_long = self.get_dataframe(data)
        dataframe_wide = self._get_plotting_data(dataframe_long)
        # NOTE(review): `size` was renamed `height` in seaborn >= 0.9 ---
        # confirm the pinned seaborn version before changing.
        graphic =\
            seaborn.FacetGrid(
                dataframe_wide,
                hue=self.gvar,
                size=self.height_figure,
                legend_out=True)
        graphic.map(
            plt.errorbar,
            self.xvar, self.yvar,
            self.xerr, self.yerr,
            fmt=self.fmt)
        graphic.add_legend()
        graphic.fig.suptitle(self._get_title(dataframe_long))
        # Diagonal through the extreme (value +/- error) points marks
        # the x == y line.
        limits = [
            np.maximum(
                np.nanmax(
                    dataframe_wide[self.xvar] + dataframe_wide[self.xerr]),
                np.nanmax(
                    dataframe_wide[self.yvar] + dataframe_wide[self.yerr])),
            np.minimum(
                np.nanmin(
                    dataframe_wide[self.xvar] - dataframe_wide[self.xerr]),
                np.nanmin(
                    dataframe_wide[self.yvar] - dataframe_wide[self.yerr]))]
        plt.plot(limits, limits, "k--")
        return Figure(
            graphic.set(
                xlabel=self.xlabel if self.xlabel else self.xvar,
                ylabel=self.ylabel if self.ylabel else self.yvar),
            caption=caption)

    def plot(self, data, *args, **kwargs):
        """
        Plot the data

        Arguments
        -----------
        data : A dict mapping dataset to dataframe.
        """
        return self.get_figure(data, *args, **kwargs)
class StructuredAnalysis(analysis.StructuredAnalysis):
    """
    A base class for all circuit analyses.
    """
    add_columns = Field(
        """
        A callable that adds columns to a measurement
        (a `pandas.DataFrame`)
        """,
        __default_value__=lambda adapter, circuit_model, measurement: measurement)
    figures = LambdaField(
        """
        An alias for `Field plotter`, which will be deprecated.
        A class instance or a module that has `plot` method that will be
        used to plot the results of this analysis. The plotter should know
        how to interpret the data provided. For example, the plotter will
        have to know which columns are the x-axis, and which the y-axis.
        The `Plotter` instance used by this `BrainCircuitAnalysis`
        instance should have those set as instance attributes.
        """,
        lambda self: self.plotter)
    sampling_methodology = Field(
        """
        A tag indicating whether this analysis will make measurements on
        random samples drawn from a relevant population of circuit
        constituents, or on the entire population. The circuit
        constituents population to be measured will be determined by a
        query.
        """,
        __default_value__=terminology.sampling_methodology.random)
    processing_methodology = Field(
        """
        How to make measurements?
        `batch` :: Process all the parameter sets as a batch.
        ~          A single measurement on all the parameter sets will be
        ~          dispatched to the plotter and attached to the report.
        ~          Thus a single report will be saved at the end of the
        ~          analysis run.
        `serial` :: Process a single parameter set at a time.
        ~           For each parameter set, make a measurement, generate a
        ~           figure and attach to the report.
        ~           Save the report and return a dict mapping parameter
        ~           set to its report.
        """,
        __default_value__=terminology.processing_methodology.batch)
    phenomenon = Field(
        """
        An object providing the phenomenon analyzed.
        """)
    reference_data = Field(
        """
        A pandas.DataFrame containing reference data to compare with the
        measurement made on a circuit model.
        Each dataset in the dataframe must be annotated with index level
        'dataset', in addition to levels that make sense for the
        measurements.
        """,
        __default_value__=NA)
    report = Field(
        """
        A callable to generate a report using fields used in
        `StructuredAnalysis.get_report(...) method.`
        """,
        __default_value__=CircuitAnalysisReport)

    @property
    def _has_reference_data(self):
        """
        Is there reference data to compare the measurement against?
        NOTE(review): assumes `self.reference_data` supports `len` ---
        the `NA` default may not; confirm.
        """
        return len(self.reference_data) > 0

    @lazyfield
    def label(self):
        """
        A label for this analysis.
        """
        # NOTE(review): the original carried unreachable code after this
        # `return` that built a "<phenomenon>-by-<parameters>" label; it
        # has been removed as dead code, preserving effective behavior.
        return self.phenomenon.label

    def _get_adapter_measurement_method(self, adapter):
        """
        Resolve the adapter method that makes this analysis' measurement,
        trying the declared name and then its `get_` prefixed form.
        """
        measurement_name =\
            self.AdapterInterface.__measurement__
        # The interface must declare the bare measurement name, without
        # accessor prefixes. (The original compared a 4-character slice
        # against the 3-character "get", which never matched a `get_`
        # prefix; test the actual prefix.)
        assert measurement_name[0] != '_'
        assert not measurement_name.startswith("get_")
        try:
            return\
                getattr(adapter, measurement_name)
        except AttributeError as error:
            try:
                return\
                    getattr(adapter, "get_{}".format(measurement_name))
            except AttributeError as error_get:
                raise AttributeError("""
                No adapter attribute (w/o prefix `get_`)to measure {}:
                \t{}
                \t{}.
                """.format(measurement_name, error, error_get))

    def get_measurement_method(self, adapter):
        """
        Makes sense for analysis of a single phenomenon.

        Some changes below provide backward compatibility.
        """
        if hasattr(self, "sample_measurement"):
            def _adapter_measurement_method(circuit_model, **kwargs):
                """
                Make sample measurement method behave as if it was
                defined on the adapter.

                Arguments
                ===============
                kwargs : keyword arguments containing keywords providing
                the parameter set to make the measurement, as well other
                arguments that may affect how the measurement will be
                made (for example, deterministic or stochastic, or the
                number of samples to measure for a single set of
                parameters.)
                """
                try:
                    return\
                        self.sample_measurement(
                            adapter, circuit_model, **kwargs)
                except (TypeError, AttributeError, KeyError)\
                        as error_adapter_model:
                    # Retry with the arguments swapped --- some
                    # measurements are written as (model, adapter).
                    # (The original retried with the SAME order,
                    # defeating the fallback.)
                    try:
                        return\
                            self.sample_measurement(
                                circuit_model, adapter, **kwargs)
                    except Exception as error_model_adapter:
                        raise TypeError("""
                        sample_measurement(...) failed with arguments
                        (model, adapter) and (adapter, model):
                        \t {}
                        \t {}
                        """.format(
                            error_adapter_model, error_model_adapter))

            try:
                _adapter_measurement_method.__method__ =\
                    paragraphs(self.sample_measurement.__method__)
            except AttributeError:
                _adapter_measurement_method.__method__ =\
                    "Measurement method description not provided."
            return\
                _adapter_measurement_method

        method =\
            self._get_adapter_measurement_method(adapter)
        if not hasattr(method, "__method__"):
            method.__method__ =\
                "Measurement method description not provided."
        # The original fell through to `raise RuntimeError` here and
        # never returned the resolved adapter method.
        return method

    @lazyfield
    def description_measurement(self):
        """
        This attribute will be NA if the method is implemented and
        described in the adapter.
        """
        try:
            return self.sample_measurement.__method__
        except AttributeError:
            return NA

    def parameter_sets(self, adapter, circuit_model):
        """
        Get parameter sets from self.Parameters
        """
        using_random_samples =\
            self.sampling_methodology ==\
            terminology.sampling_methodology.random
        return\
            self.measurement_parameters(
                adapter, circuit_model,
                sample_size=self.sample_size if using_random_samples else 1)

    def collect_serially(self,
            adapter,
            circuit_model,
            value_measurement,
            *args, **kwargs):
        """
        Compute the measurement, one parameter set at a time, yielding a
        `Record` per parameter set.
        """
        for p in tqdm(self.parameter_sets(adapter, circuit_model)):
            measured_value =\
                value_measurement(
                    circuit_model,
                    sampling_methodology=self.sampling_methodology,
                    **p, **kwargs)
            if isinstance(measured_value, pd.DataFrame):
                measured_value =\
                    measured_value[self.phenomenon.label]
            if not isinstance(measured_value, pd.Series):
                # Coerce array-likes, and scalars alike, to a series.
                try:
                    measured_value =\
                        pd.Series(
                            measured_value,
                            name=self.phenomenon.label)
                except ValueError:
                    measured_value =\
                        pd.Series(
                            [measured_value],
                            name=self.phenomenon.label)
            measured_value =\
                series_type_measurement_collection(
                    [(p, measured_value)]
                ).rename(columns={"value": self.phenomenon.label})
            data =\
                self.add_columns(
                    adapter, circuit_model,
                    measured_value.reset_index(
                    ).assign(dataset=adapter.get_label(circuit_model),
                    ).set_index(
                        ["dataset"] + measured_value.index.names))
            yield\
                Record(
                    parameter_set=p,
                    data=data,
                    method=value_measurement.__method__)

    @interfacemethod
    def get_label(adapter, circuit_model):
        """
        Get a label for the circuit model that can be used for
        documenting and saving results of an analysis in folders.
        """
        pass

    def collect(self,
            adapter,
            circuit_model,
            value_measurement,
            *args, **kwargs):
        """
        Collect a measurement on a circuit model

        Arguments
        ------------
        value_measurement: Mapping parameters -> value
        ~                  for the measurement to be collected.
        """
        v = value_measurement
        measurement =\
            self.measurement_collection(
                (p, v(circuit_model,
                      sampling_methodology=self.sampling_methodology,
                      **p, **kwargs))
                for p in tqdm(
                    self.parameter_sets(adapter, circuit_model))
            ).rename(
                columns={"value": self.phenomenon.label})
        dataset =\
            adapter.get_label(circuit_model)
        return\
            Record(
                data=pd.concat(
                    [measurement], keys=[dataset], names=["dataset"]),
                method=value_measurement.__method__)

    @lazyfield
    def description_reference_data(self):
        """
        TODO: return a string describing the reference data.
        """
        return NA

    def append_reference_data(self,
            measurement,
            reference_data={},
            **kwargs):
        """
        Append reference datasets.

        Arguments
        ===========
        reference_data :: dict mapping dataset label to an object with
        attributes <data :: DataFrame, citation :: String>
        """
        # NOTE: the `{}` default is read-only here, so sharing it across
        # calls is benign.
        reference_data =\
            reference_data\
            if len(reference_data) else\
            self.reference_data
        measurement_dict = {
            dataset: measurement.xs(dataset, level="dataset")
            for dataset in measurement.reset_index().dataset.unique()}

        def _get_data(dataset):
            # Accept bare dataframes as well as <data, citation> records.
            try:
                return dataset.data
            except AttributeError:
                return dataset

        measurement_dict.update({
            label: _get_data(dataset)
            for label, dataset in reference_data.items()})
        return measurement_dict

    def _with_reference_data(self,
            measurement,
            reference_data={}):
        """
        Deprecated
        """
        return self.append_reference_data(measurement, reference_data)

    @property
    def number_measurement_parameters(self):
        """
        How many parameters are the measurements made with?
        For example, if the measurement parameters are region, layer,
        the number is two.

        The implementation below uses the implementation of
        `self.measurement_parameters`. However, if you change the type of
        that component, you will have to override.
        """
        try:
            return self.measurement_parameters.number
        except Exception:
            # Narrowed from a bare `except:` so that system-exiting
            # exceptions are not swallowed.
            return np.nan

    @property
    def names_measurement_parameters(self):
        """
        What are the names of the parameters that the measurements are
        made with?

        If measurement parameters cannot provide the variables (a.k.a
        parameter labels or tags), an empty list is returned.
        """
        try:
            return self.measurement_parameters.variables
        except (TypeError, AttributeError):
            return []

    def get_figures(self,
            measurement_model,
            reference_data,
            caption=None):
        """
        Get a figure for the analysis of `circuit_model`.

        Arguments
        ----------
        `measurement_model`: measurement record whose `.data` will be
        plotted alongside `reference_data`.
        """
        plotting_data =\
            self.append_reference_data(
                measurement_model.data, reference_data)
        return\
            self.figures(
                plotting_data,
                caption=caption)

    def get_report(self,
            label,
            measurement,
            author=Author.anonymous,
            figures=None,
            reference_data=None,
            provenance_circuit={}):
        """
        Get a report for the given `measurement`.
        """
        reference_data =\
            reference_data if reference_data is not None\
            else self.reference_data
        try:
            reference_citations = {
                label: reference.citation
                for label, reference in reference_data.items()}
        except AttributeError:
            LOGGER.info(
                """
                Could not retrieve citations from reference data of type {}.
                """.format(type(reference_data)))
            reference_citations = {}
        return self.report(
            author=author,
            phenomenon=self.phenomenon.label,
            label=label,
            abstract=self.abstract,
            introduction=self.introduction(provenance_circuit)["content"],
            methods=self.methods(provenance_circuit)["content"],
            measurement=measurement,
            figures=figures,
            results=self.results(provenance_circuit)["content"],
            discussion=self.discussion(provenance_circuit)["content"],
            conclusion=self.conclusion(provenance_circuit)["content"],
            references=reference_citations,
            provenance_model=provenance_circuit)

    @interfacemethod
    def get_provenance(adapter, model, **kwargs):
        """
        Get a mapping providing provenance of circuit model.
        Following keys are expected:
        1. animal: String #animal species whose brain was modeled
        2. age: String #age of the animal individual at which brain was
        ~  modeled
        3. brain_region: String #that was modeled
        4. data_release: String #when the model was released
        5. label: String #to use in documentation
        6. uri: String #Universal Resource Identifier for the model
        7. authors: List[String] #who built the model
        """
        pass

    def __call__(self, adapter, model, author=Author.anonymous, **kwargs):
        """
        Make this `Analysis` masquerade as a function.
        """
        reference_data =\
            kwargs.pop(
                "reference_data",
                self.reference_data)
        provenance_circuit =\
            adapter.get_provenance(model)

        def _get_label(measurement_serial):
            # NOTE(review): this iterates the result of
            # `index_tree.as_unnested_dict(...)` unpacking (key, value)
            # pairs --- confirm it yields pairs rather than keys.
            return\
                '-'.join(
                    "{}_{}".format(make_label(key), make_label(value))
                    for key, value in index_tree.as_unnested_dict(
                        measurement_serial.parameter_set.items()))\
                .replace('{', '')\
                .replace('}', '')\
                .replace("'", "")

        # (The original bound this to a misspelled `value_meaurement` and
        # then referenced `value_measurement` in the serial branch,
        # raising a `NameError` on the serial path.)
        value_measurement =\
            self.get_measurement_method(adapter)

        def get_figures(measurement):
            plotting_data =\
                self.append_reference_data(
                    measurement.data, reference_data)
            return\
                self.figures(
                    plotting_data,
                    caption=measurement.method)

        if self.processing_methodology ==\
                terminology.processing_methodology.serial:
            # Lazily yield one report per parameter set.
            return (
                Record(
                    label=_get_label(measurement),
                    sub_report=self.get_report(
                        self.label,
                        measurement.data,
                        author=author,
                        figures=get_figures(measurement),
                        reference_data=reference_data,
                        provenance_circuit=provenance_circuit))
                for measurement in self.collect_serially(
                        adapter, model, value_measurement, **kwargs))

        measurement =\
            self.collect(
                adapter, model, value_measurement, **kwargs)
        report =\
            self.get_report(
                self.label,
                measurement.data,
                author=author,
                figures=get_figures(measurement),
                reference_data=reference_data,
                provenance_circuit=provenance_circuit)
        try:
            return self.reporter.post(report)
        except AttributeError:
            # No reporter configured: hand the report back to the caller.
            return report
class CheetahReporter(Reporter):
    """
    Report with a cheetah template.

    Fills Cheetah HTML templates with a report's field values and writes
    the result to disk, recursing into the report's sections and chapters.
    """
    # Layout fields: used to build the `underline` / `endline` separator
    # strings substituted into the templates below.
    width_page = Field(
        """
        Width of the page on which the report will be displayed.
        """,
        __default_value__=0)
    under_line_type = Field(
        """
        Type of line to underline section titles.
        """,
        __default_value__='-')
    end_line_type = Field(
        """
        Type of line to demark sections or chapters in the report.
        """,
        __default_value__='=')
    # NOTE: the literal `{}` inside the templates below is a `str.format`
    # placeholder filled with `template_figures` by `get_html_template`;
    # the templates must therefore contain no other `{` or `}` characters.
    template_main = Field(
        """
        Template for Cheetah that will be used to create an HTML report
        for the report's main.
        """,
        __default_value__="""
        <html>
          <body>
            #if $title_main_report
            <h1>$title (<A HREF=$path_main_report>$title_main_report Analysis</A>)</h1>
            #else
            <h1>$title Analysis</h1>
            <p>$endline</p>

            <h2> Author </h2>
            <h3> $author_name </h3>
            <h3> Affiliation: $author_affiliation </h3>
            #end if
            <p>$endline</p>

            #if $circuit
            <h2>Circuit Analyzed</h2>
            <p>$underline</p>
            #for $key, $value in $circuit.items()
            <p><strong>$key</strong>: $value</p>
            <p>$underline</p>
            #end for
            <p>$endline</p>
            #end if

            <h2>Abstract</h2>
            <p>$underline</p>
            #for $line in $abstract
            <p>$line</p>
            #end for
            <p>$endline</p>

            <h2>Introduction</h2>
            <p>$underline</p>
            #for $line in $introduction
            <p>$line</p>
            #end for
            <p>$endline</p>

            <h2>Methods</h2>
            <p>$underline</p>
            #for $line in $methods
            <p>$line</p>
            #end for
            <p>$endline</p>

            <h2>Results</h2>
            <p>$underline</p>
            #for $line in $results
            <p>$line</p>
            #end for
            #if $images
            <h3>Figures</h3>
            <br>{}</br>
            #end if
            <p>$endline</p>

            #if $sections
            <h2>Sections</h2>
            <p>$underline</p>
            #for $index, $section in enumerate($sections.items())
            <p>$(index+1): <strong><A HREF=$section[1]>$section[0]</A></strong></p>
            #end for
            <p>$endline</p>
            #end if

            #if $chapters
            <h2>Chapters</h2>
            <p>$underline</p>
            #for $index, $chapter in enumerate($chapters.items())
            <p>$(index+1): <strong><A HREF=$chapter[1]>$chapter[0]</A></strong></p>
            #end for
            <p>$endline</p>
            #end if

            <h2>Discussion</h2>
            <p>$underline</p>
            #for $line in $discussion
            <p>$line</p>
            #end for
            <p>$endline</p>

            #if $references
            <h2>References</h2>
            <p>$underline</p>
            #for $label, $citation in $references.items()
            <p><strong>$label</strong>: $citation</p>
            #end for
            <p>$endline</p>
            #end if

            #if $title_main_report
            <p>Back to <A HREF=$path_main_report>$title_main_report Analysis</A></p>
            #end if
          </body>
        </html>
        """)
    template_figures = Field(
        """
        Template for figures...

        Rendered once per figure: the image followed by its caption lines.
        """,
        __default_value__="""
        #for $label_image, $location_image in $images.items()
        <img src=$location_image alt="apologies.png"/>
        <p>
        <strong>$label_image.capitalize():</strong>
        #for $line in $captions[$label_image]
        $line
        #end for
        </p>
        #end for
        """)
    # NOTE(review): the opening tag below reads `</body>` -- presumably it
    # was meant to be `<body>`; browsers tolerate it, and the string is
    # runtime data so it is left untouched here.
    template_section = Field(
        """
        Template for Cheetah that will be used to create an HTML report
        for a main report's sections.
        """,
        __default_value__="""
        <html>
          </body>
            <h2><strong><A HREF=$path_main_report>$title_main_report Analysis</A></strong></h2>
            #if $section_index
            <h3>$section_index: $title</h3>
            #else
            <h3>$title</h3>
            #end if
            <p>$underline</p>
            <h4>Figures</h4>
            <br>{}</br>
            #for $line in $content
            <p>$line</p>
            #end for
            <p>$endline</p>
          </body>
        </html>
        """)
    template = LambdaField(
        """
        The template to use...

        Defaults to `template_main`; `_post_sections` overrides it with
        `template_section` via `with_fields`.
        """,
        lambda self: self.template_main)

    def get_html_template(self, report):
        """
        HTML template for sections.

        Splices the figures template into the main template's `{}`
        placeholder when the report has figures.
        """
        if report.figures:
            return\
                self.template.format(self.template_figures)
        return self.template

    def _get_captions(self, report):
        """
        Create a figure caption.

        Returns a mapping of figure label --> caption.
        """
        return {
            label: figure.caption
            for label, figure in report.figures.items()}

    def dict_template(self,
            report,
            figure_locations,
            path_main_report=None,
            title_main_report=None,
            chapters=NA,
            chapter_index=None,
            sections=NA,
            section_index=None):
        """
        Fill in the template.

        Builds the `searchList` dict handed to the Cheetah `Template`:
        the report's own field values plus layout strings, author info,
        figure locations / captions, and navigation data for sections
        and chapters.
        """
        template_dict = {
            field: value
            for field, value in report.field_values.items()}
        LOGGER.debug(
            "CheetahReporter.dict_template(..)",
            "report field values",
            "{}".format(template_dict))
        # Sections / chapters should not repeat the circuit description --
        # it appears only in the main report.
        if chapter_index is not None or section_index is not None:
            template_dict["circuit"] = None

        def _make_name(label):
            # Normalize a figure label into an identifier usable as a
            # Cheetah variable / dict key.
            return\
                string_utils.make_name(
                    label,
                    separator='_',
                    keep_original_capitalization=True)

        template_dict.update(
            dict(
                underline=self.width_page * self.under_line_type,
                endline=self.width_page * self.end_line_type,
                author_name=report.author.name,
                author_affiliation=report.author.affiliation,
                images={
                    _make_name(label): location
                    for label, location in figure_locations.items()},
                captions={
                    _make_name(label): figure.caption
                    for label, figure in report.figures.items()},
                references=report.references,
                sections=sections,
                chapters=chapters))
        if path_main_report is not None:
            # This report is a section / chapter: link back to the main
            # report and number it.
            template_dict["title_main_report"] =\
                title_main_report if title_main_report is not None\
                else "Main Report"
            template_dict["path_main_report"] = path_main_report
            template_dict["title"] = make_name(report.label, separator='-')
            # Indices are 0-based internally, 1-based for display.
            if section_index is not None:
                template_dict["section_index"] = section_index + 1
            if chapter_index is not None:
                template_dict["chapter_index"] = chapter_index + 1
        else:
            template_dict["title_main_report"] = None
            template_dict["path_main_report"] = None
        return template_dict

    def _post_sections(self,
            report,
            output_uri,
            report_file_name,
            path_main_report,
            title_main_report):
        """
        Report sections of this report!

        Posts each section with the section template and returns a mapping
        of section name --> posted report file path, or `NA` when there
        are no sections.
        """
        if report.sections:
            section_reporter =\
                self.with_fields(
                    template=self.template_section)
            return OrderedDict(
                (make_name(section.label, separator='-'),
                 os.path.join(
                     section_reporter.post(
                         section,
                         path_output_folder=output_uri,
                         report_file_name=report_file_name,
                         path_main_report=path_main_report,
                         title_main_report=title_main_report,
                         section_index=index,
                         with_time_stamp=False),
                     report_file_name))
                for index, section in enumerate(report.sections))
        return NA

    def _post_chapters(self,
            report,
            output_uri,
            report_file_name,
            path_main_report,
            title_main_report):
        """
        Report chapters of this report!

        Like `_post_sections`, but chapters are full reports and use the
        main template.
        """
        if report.chapters:
            chapter_reporter =\
                self.with_fields(
                    template=self.template_main)
            return OrderedDict(
                (make_name(chapter.label, separator='-'),
                 os.path.join(
                     chapter_reporter.post(
                         chapter,
                         path_output_folder=output_uri,
                         report_file_name=report_file_name,
                         path_main_report=path_main_report,
                         title_main_report=title_main_report,
                         chapter_index=index,
                         with_time_stamp=False),
                     report_file_name))
                for index, chapter in enumerate(report.chapters))
        return NA

    def post(self,
            report,
            template=None,
            path_output_folder=None,
            output_subfolder=None,
            report_file_name="report.html",
            strict=False,
            with_time_stamp=True,
            path_main_report=None,
            title_main_report=None,
            section_index=None,
            chapter_index=None,
            *args, **kwargs):
        """
        Post the report.

        output_uri : Uniform resource identifier where the report should be
        ~            posted.
        strict : If `True`, a backup text report will not be generated.

        Returns the output location of the posted report -- a dict of
        locations when `report` is a generator of sub-reports.
        """
        # A generator of sub-reports (serial processing): post each one
        # under its own sub-folder and collect the locations.
        if isinstance(report, Generator):
            output_uri = {}
            if with_time_stamp:
                daytime = timestamp()
                time_stamp =\
                    os.path.join(
                        daytime.day,
                        daytime.time)
            else:
                time_stamp = None
            for subreport in report:
                if output_subfolder is not None:
                    folder_subreport =\
                        os.path.join(subreport.label, output_subfolder)
                else:
                    folder_subreport =\
                        subreport.label
                # NOTE(review): `with_time_stamp` is passed the path-like
                # `time_stamp` here rather than a boolean -- presumably the
                # recursive call reuses the already-computed stamp; confirm
                # against `get_output_location`.
                output_uri[subreport.label] =\
                    self.post(
                        subreport.sub_report,
                        template=template,
                        path_output_folder=path_output_folder,
                        output_subfolder=folder_subreport,
                        report_file_name=report_file_name,
                        strict=strict,
                        with_time_stamp=time_stamp,
                        path_main_report=path_main_report,
                        title_main_report=title_main_report,
                        section_index=section_index,
                        chapter_index=chapter_index,
                        *args, **kwargs)
            return output_uri

        LOGGER.debug(
            """.post(report={}, template={}, path_output_folder={},
            output_subfolder={}, report_file_name={}, strict={},
            with_time_stamp={}, path_main_report={}, title_main_report={},
            section_index={}, chapter_index={})""".format(
                report, template, path_output_folder, output_subfolder,
                report_file_name, strict, with_time_stamp, path_main_report,
                title_main_report, section_index, chapter_index))
        # A report cannot be both a section and a chapter.
        if section_index is not None and chapter_index is not None:
            raise TypeError(
                """
                `CheetahReporter.post(...)` cannot a report that is both
                a chapter and a section.
                """)
        output_uri =\
            self.get_output_location(
                report,
                path_output_folder=path_output_folder,
                output_subfolder=output_subfolder,
                with_time_stamp=with_time_stamp)
        LOGGER.status(
            LOGGER.get_source_info(),
            "Post report {} at".format(report.label),
            "\t {}".format(output_uri))
        base_file_name =\
            get_file_name_base(report_file_name)
        path_report_file =\
            os.path.join(
                output_uri,
                "{}.html".format(base_file_name))
        # Persist the report's artifacts next to the HTML file.
        folder_figures, locations_figures =\
            self._save_figures(report, output_uri)
        self._save_sections(report, output_uri)
        self._save_chapters(report, output_uri)
        self._save_measurement(report, output_uri)
        # Recursively post sections / chapters, linking them back to this
        # report's HTML file.
        sections =\
            self._post_sections(
                report, output_uri, report_file_name,
                path_main_report=path_report_file,
                title_main_report=report.label)
        chapters =\
            self._post_chapters(
                report, output_uri, report_file_name,
                path_main_report=path_report_file,
                title_main_report=report.label)
        template_html =\
            self.get_html_template(report)
        dict_template =\
            self.dict_template(
                report,
                locations_figures,
                path_main_report=path_main_report,
                title_main_report=title_main_report,
                chapters=chapters,
                chapter_index=chapter_index,
                sections=sections,
                section_index=section_index)
        LOGGER.debug(
            "FILLED TEMPLATE",
            '\n'.join(
                "{}: {}".format(k, v)
                for k, v in dict_template.items()))
        try:
            report_template_filled =\
                Template(template_html, searchList=dict_template)
            try:
                report_html = str(report_template_filled)
            except Exception as html_error:
                LOGGER.alert(
                    LOGGER.get_source_info(),
                    """
                    Failed to generate HTML for the report: {}
                    \t{}
                    """.format(type(html_error), html_error))
                raise html_error
            with open(path_report_file, "w") as file_output:
                file_output.write(report_html)
        except Exception as template_fill_error:
            LOGGER.warning(
                LOGGER.get_source_info(),
                """
                Error during filling the report template: {}
                \t{}""".format(
                    type(template_fill_error), template_fill_error))
            if strict:
                raise template_fill_error
            # Non-strict: fall back to a plain-text report.
            super()._save_text_report(report, output_uri, folder_figures)
        return output_uri
class BasePlotter(ABCWithFields):
    """
    Abstract base class for plotting.

    Declares the common plotting configuration fields (axes variables,
    labels, fonts, sizes) and wires seaborn / matplotlib rc parameters.
    Subclasses implement `get_figure`.
    """
    title = Field(
        """
        Title to be displayed on the figure produced.
        """,
        __default_value__="")
    xvar = Field(
        """
        Column in the data-frame to be plotted along the x-axis.
        """)
    xlabel = LambdaField(
        """
        The label to be displayed along the x-axis.
        """,
        lambda self: self.xvar)
    yvar = Field(
        """
        Column in the data-frame to be plotted along the y-axis.
        """)
    ylabel = LambdaField(
        """
        The label to be displayed along the y-axis.
        """,
        lambda self: self.yvar)
    gvar = Field(
        """
        Column in the data-frame that will be used to group data.
        (x, y) data for the same value of gvar will be plotted with the
        same decorations.
        A default value of empty string will indicate that there is no
        variable to group the data by.
        """,
        __default_value__="")
    fvar = Field(
        """
        Facet Variable: Column in the dataframe that will be plotted on
        several faces.
        A default value of empty string will be interpreted as not set,
        and hence there will be only one face in the figure.
        """,
        __default_value__="")
    phenomenon = LambdaField(
        """
        Phenomenon studied.
        """,
        lambda self: self.yvar)
    name = LambdaField(
        """
        Name of the file to save the figure to.
        """,
        lambda self: self.phenomenon)
    number_columns = LambdaField(
        """
        Number of columns in the figure.
        """,
        lambda self: None if not self.fvar else 3)
    height_figure = Field(
        """
        Height of the figure.
        """,
        __default_value__=8.)
    aspect_ratio_figure = Field(
        """
        Aspect ratio width / height for the figure.
        """,
        __default_value__=golden_aspect_ratio)
    confidence_interval = Field(
        """
        float or “sd” or None, optional
        Size of confidence intervals to draw around estimated values.
        If “sd”, skip bootstrapping and draw the standard deviation of
        the observations.
        If None, no bootstrapping will be performed, and error bars will
        not be drawn.
        """,
        __default_value__="sd")
    order = Field(
        """
        Either a column in the measurement dataframe that the values
        should be ordered by, or a callable that consumes a dataframe
        and returns a dataframe.
        """,
        __default_value__=lambda dataframe: dataframe)
    context = Field(
        """
        Context in which plots will be produced.
        """,
        __examples__=["paper", "notebook"],
        __default_value__="paper")
    font = Field(
        """
        Font family.
        """,
        __default_value__="sans-serif")
    font_size = Field(
        """
        Size of the font to use.
        """,
        __default_value__=20)
    font_scale = Field(
        """
        Separate scaling factor to independently scale the size of the
        font elements.
        """,
        __default_value__=2.)
    title_size = Field(
        """
        Size of the title.
        """,
        __default_value__=30)
    axes_labelsize = Field(
        """
        Size of axes labels.
        """,
        __default_value__="xx-large")
    axes_titlesize = Field(
        """
        Size of axes title.
        """,
        __default_value__="xx-large")
    legend_text_size = Field(
        """
        Size of text in plot legend.
        """,
        __default_value__=32)
    legend_title_size = Field(
        """
        Size of the title of plot legend.
        """,
        __default_value__=42)
    xtick_labelsize = Field(
        """
        How large should the xticks be?
        """,
        __default_value__="small")
    ytick_labelsize = LambdaField(
        """
        How should the yticks be?
        """,
        lambda self: self.xtick_labelsize)
    rc = Field(
        """
        Dictionary of rc parameter mappings to override global values
        set above...
        """,
        __default_value__=NA)

    def rc_params(self):
        """
        Matplotlib rc parameters assembled from this plotter's fields.

        NOTE(review): "legend.fontsize" is hard-coded to "xx-large" here
        although a `legend_text_size` field exists -- confirm whether the
        field was meant to be used.
        """
        return {
            "font.size": self.font_size,
            "legend.fontsize": "xx-large",
            "axes.labelsize": self.axes_labelsize,
            "axes.titlesize": self.axes_titlesize,
            "xtick.labelsize": self.xtick_labelsize,
            "ytick.labelsize": self.ytick_labelsize}

    def _set_rc_params(self):
        """Push this plotter's rc parameters into the seaborn context."""
        seaborn.set_context(self.context, rc=self.rc_params())

    def __init__(self, *args, **kwargs):
        """Initialize fields, then configure the global seaborn style."""
        super().__init__(*args, **kwargs)
        # NOTE(review): `rc=None` ignores the `rc` Field above -- confirm
        # whether `rc=self.rc` was intended.
        seaborn.set(
            context=self.context,
            style="darkgrid",
            palette="deep",
            font=self.font,
            font_scale=self.font_scale,
            color_codes=True,
            rc=None)
        self._set_rc_params()

    @abstractmethod
    def get_figure(self, *args, **kwargs):
        """
        Every `Plotter` must implement.

        (Fixed a typo in the keyword-arguments parameter name, previously
        `**kwargds`.)
        """
        raise NotImplementedError

    def get_figures(self, *args, **kwargs):
        """
        Package figure into an OrderedDict.

        Returns a mapping of `self.name` --> figure.
        """
        return OrderedDict([
            (self.name, self.get_figure(*args, **kwargs))])

    def __call__(self, *args, **kwargs):
        """
        Make this class a callable, so that it can masquerade as a
        function!
        """
        self._set_rc_params()
        return self.get_figures(*args, **kwargs)
class HeatMap(BasePlotter):
    """
    Define the requirements and behavior of a heatmap.
    """
    measurement_type = Field(
        """
        Will this instance of `HeatMap` be called with a measurement type
        that will be a statistical summary or statistical samples?
        Heat map cannot be generated for samples.
        If measurement type is a samples, a summary will be made before
        plotting.
        """,
        __default_value__=terminology.measurement_type.summary)
    xvar = Field(
        """
        Variable (column in the data-frame to plot) that should vary
        along the x-axis of the heat-map.
        """)
    yvar = Field(
        """
        Variable (column in the data-frame to plot) that should vary
        along the y-axis of the heat-map.
        """)
    vvar = Field(
        """
        Variable (column in the data-frame to plot) that provides the
        intensity value of the heat-map cells.
        """)
    phenomenon = LambdaField(
        """
        Phenomenon whose measurement is plotted as a heatmap.
        """,
        lambda self: self.vvar)
    title = LambdaField(
        """
        Title for this plots.
        """,
        lambda self: self.vvar)
    aspect_ratio_figure = Field(
        """
        Aspect ratio width / height for the figure.
        """,
        __default_value__=1.)
    fill_missing_value = Field(
        """
        Some values in the heatmap matrix (which should be a square
        matrix) may be missing. This field provides the default value to
        impute the missing.
        """,
        __default_value__=0.)
    get_dataframe = LambdaField(
        """
        Call back to get a dataframe from the measurement.
        """,
        lambda self: self._get_dataframe_default)
    adjustments_plt = Field(
        """
        A function that will make adjustments to the plot.
        """,
        __default_value__=lambda *args, **kwargs: None)

    @staticmethod
    def _flattened(variable):
        """
        Flatten a possibly tuple variable.

        A string is returned unchanged; a tuple is joined with '_'.
        Raises `TypeError` for anything else.
        """
        if isinstance(variable, str):
            return variable
        if isinstance(variable, tuple):
            return '_'.join(variable)
        raise TypeError(
            """
            HeatMap dataframe variables should be:
            1. either a string
            2. or a tuple of strings
            Not:
            {}
            """.format(variable))

    def _get_dataframe_default(self, data):
        """
        Extract a flat (non-MultiIndex) dataframe from `data`.

        `data` may be a pandas object or a single-entry mapping of
        label --> dataframe; samples are summarized before plotting.
        """
        if not isinstance(data, (pandas.Series, pandas.DataFrame)):
            assert isinstance(data, Mapping) and len(data) == 1,\
                """
                Cannot decide which one to plot among more than one dataset:
                \t{}
                """.format(list(data.keys()))
            dataframe = list(data.values())[0]
            # `measurement` here is the measurement utilities module.
            if self.measurement_type != terminology.measurement_type.summary:
                dataframe = measurement.get_summary(dataframe)
        else:
            dataframe = data
        if (not isinstance(dataframe.columns, pandas.MultiIndex)
                and not isinstance(dataframe.index, pandas.MultiIndex)):
            return dataframe.reset_index()
        # Flatten MultiIndex columns into single '_'-joined names.
        dataframe = dataframe.reset_index()
        return pandas\
            .DataFrame(
                dataframe.values,
                columns=pandas.Index([
                    HeatMap._flattened(t)
                    for t in dataframe.columns.values]))

    def get_figure(self,
            data,
            *args,
            caption="Caption not provided",
            **kwargs):
        """
        Plot the figure.

        Arguments
        --------------
        `data`: A single pandas dataframe with an index of length 2, and
        only 1 column (only the zeroth column will be plotted.)
        """
        # NOTE(review): the pivot puts `xvar` values on the rows, which
        # seaborn draws along the y-axis, yet the y-axis is labeled with
        # `self.ylabel` -- confirm that the axes are not swapped.
        matrix =\
            pandas.pivot(
                self.get_dataframe(data),
                values=self._flattened(self.vvar),
                index=self._flattened(self.xvar),
                columns=self._flattened(self.yvar))\
            .fillna(self.fill_missing_value)
        with seaborn.plotting_context(
                self.context,
                font_scale=self.font_scale,
                rc=self.rc_params()):
            plt.figure()
            self.adjustments_plt()
            graphic = seaborn\
                .heatmap(
                    matrix,
                    cbar=True,
                    cmap="rainbow",
                    xticklabels=True,
                    yticklabels=True)
            plt.yticks(rotation=0)
            plt.ylabel(self.ylabel)
            plt.xticks(rotation=90)
            plt.xlabel(self.xlabel)
            plt.title(
                ' '.join(w.capitalize() for w in self.title.split('_')))
            # Removed an unreachable trailing `return None`: this branch
            # always returns the Figure.
            return Figure(graphic, caption=caption)
class NetworkChartPlot(WithFields):
    """
    A chart that will display network properties.
    """
    chart_type = Field(
        """
        A callable that takes a network's link-weight data, and a color
        map to generate a chart for the input network data.
        """,
        __default_value__=CircularNetworkChart)
    title = Field(
        """
        Title to be displayed.
        """,
        __default_value__="Network Flows.")
    node_variable = Field(
        """
        Name of the variable associated with the nodes. For example, a
        brain-circuit's connectome data may have an index with columns
        `pre_mtype` and `post_mtype`, in which case the `node_variable`
        must be set to 'mtype'.
        """,
        __examples__=["mtype"])
    source_variable = LambdaField(
        """
        Variable providing values of the beginning node of an edge.
        """,
        lambda self: "pre_{}".format(self.node_variable))
    target_variable = LambdaField(
        """
        Variable providing values of the ending node of an edge
        """,
        lambda self: "post_{}".format(self.node_variable))
    node_type = Field(
        """
        Data-type of the values of nodes provided in the input dataframe.
        """)
    phenomenon = Field(
        """
        Variable providing the phenomenon whose value will be plotted as
        edge flows.
        """)
    color_map = Field(
        """
        Color for geometries that will be draw in the chart.
        Colors will be accessed by geometry identifiers.
        """,
        __default_value__={})

    def chart(self, dataframe):
        """
        Build the chart object for `dataframe`'s link weights.
        """
        # Prefer a statistical-summary column if present, else the raw
        # phenomenon column.
        try:
            weights = dataframe[(self.phenomenon, "mean")].values
        except KeyError:
            weights = dataframe[self.phenomenon].values
        new_index_name = {
            self.source_variable: "begin_node",
            self.target_variable: "end_node"}
        # Fixed: was `pd.Series`, but this file refers to pandas as
        # `pandas` (see `HeatMap`), so `pd` would be an unbound name.
        link_weights = pandas.Series(
            weights,
            index=dataframe.index.rename([
                new_index_name[name]
                for name in dataframe.index.names]),
            name="weight")
        return self.chart_type(
            link_data=link_weights,
            color_map=self.color_map)

    def get_figure(self,
            dataframe,
            *args,
            caption="Caption not provided",
            **kwargs):
        """
        Plot the data.

        Arguments
        -----------
        dataframe :: <source_variable, source_variable, phenomenon>
        """
        chart = self.chart(dataframe)
        return Figure(
            chart.draw(*args, **kwargs),
            caption=caption)

    def plot(self, dataframe, *args, **kwargs):
        """
        Plot the dataframe
        """
        return self\
            .get_figure(
                dataframe,
                *args, **kwargs)

    def __call__(self, dataframe, *args, **kwargs):
        """
        Make this class a callable, so that it can masquerade as a
        function!
        """
        return self.plot(dataframe, *args, **kwargs)