def split_raw_data_by_survival(raw_data: RawData) -> Tuple[RawData, RawData]:
    '''Partition raw simulation data into surviving and dying agents.

    An agent counts as dying if its dead flag (``PATH_TO_DEAD``) is truthy
    at any timepoint in ``raw_data``; every other agent counts as
    surviving.

    Args:
        raw_data: Raw simulation data, keyed by time.

    Returns:
        A ``(survive_data, die_data)`` tuple. Both contain every timepoint
        of ``raw_data`` with an agents store at ``PATH_TO_AGENTS`` (empty
        when no agent of that group exists at that time). The first holds
        the agents that are never flagged dead; the second holds the
        agents that are flagged dead at some timepoint.
    '''
    # Pass 1: an agent belongs to the "die" group if it is ever flagged dead.
    ever_dead = {
        agent_id
        for snapshot in raw_data.values()
        for agent_id, agent_state in get_in(
            snapshot, PATH_TO_AGENTS).items()
        if get_in(agent_state, PATH_TO_DEAD, False)
    }

    # Pass 2: build identical empty skeletons so both outputs share the
    # same timepoints even when one group has no agents at some time.
    survive_data = RawData(dict())
    for timepoint in raw_data:
        assoc_path(survive_data, (timepoint, ) + PATH_TO_AGENTS, dict())
    die_data = copy.deepcopy(survive_data)

    # Pass 3: route each agent's state at each time into its group.
    for timepoint, snapshot in raw_data.items():
        agents_prefix = (timepoint, ) + PATH_TO_AGENTS
        for agent_id, agent_state in get_in(
                snapshot, PATH_TO_AGENTS).items():
            target = die_data if agent_id in ever_dead else survive_data
            assoc_path(target, agents_prefix + (agent_id, ), agent_state)
    return survive_data, die_data
def raw_data_to_end_expression_table(
        raw_data: RawData, paths_dict: Dict[str, Path]) -> pd.DataFrame:
    '''Create a table of end expression levels from raw simulation data.

    Args:
        raw_data: Raw simulation data, keyed by time. Only the data at the
            largest time key is used.
        paths_dict: Map from names to paths (relative to each agent's
            root) to protein counts. The names will be used as column
            headers in the returned table.

    Returns:
        Table with one row per agent present at the final simulation
        timepoint. There is one column per entry in ``paths_dict`` whose
        cells hold the protein concentration (count divided by the agent's
        volume, or 0 if the volume is missing or zero; a missing count is
        treated as 0), plus a ``VOLUME_KEY`` column holding each agent's
        volume (0 if missing).
    '''
    end_data = raw_data[max(raw_data.keys())]
    expression_data: Dict = {name: [] for name in paths_dict}
    expression_data[VOLUME_KEY] = []
    agents_data = get_in(end_data, AGENTS_PATH)
    for agent_data in agents_data.values():
        volume = get_in(agent_data, VOLUME_PATH, 0)
        expression_data[VOLUME_KEY].append(volume)
        for name, path in paths_dict.items():
            count = get_in(agent_data, path, 0)
            # Guard against division by zero for agents without a volume.
            concentration = count / volume if volume else 0
            expression_data[name].append(concentration)
    return pd.DataFrame(expression_data)
def _calc_live_and_dead_finals(
        data: RawData,
        path_to_variable: Path,
        time_range: Tuple[float, float],
        agents: Iterable[str] = tuple(),
) -> Tuple[Dict[str, float], Dict[str, float]]:
    '''Get each agent's last recorded value of a variable, split by fate.

    Args:
        data: Raw simulation data, keyed by time.
        path_to_variable: Path from an agent's root to the variable.
        time_range: ``(start, end)`` fractions of the final time. Only
            timepoints within ``[start * end_time, end * end_time]``
            (inclusive) are considered.
        agents: Agent IDs to restrict the analysis to. If empty, all
            agents are included. Any iterable is accepted (including a
            one-shot iterator); it is materialized once.

    Returns:
        ``(live_finals, dead_finals)``: maps from agent ID to the value of
        the variable at the latest considered timepoint for that agent.
        An agent goes into ``dead_finals`` if it is flagged dead at any
        considered timepoint, otherwise into ``live_finals``. No default
        is passed to ``get_in``, so a missing variable presumably yields
        ``None``.
    '''
    # Materialize once: repeated ``in`` tests against a generator would
    # consume it (and a generator is always truthy). A set also makes each
    # membership test O(1).
    agent_filter = set(agents)
    values: Dict[str, Dict[float, float]] = {}
    die = set()
    end_time = max(data.keys())
    start_time = time_range[0] * end_time
    stop_time = time_range[1] * end_time
    for time, time_data in data.items():
        if time < start_time or time > stop_time:
            continue
        agents_data = get_in(time_data, PATH_TO_AGENTS)
        for agent, agent_data in agents_data.items():
            if agent_filter and agent not in agent_filter:
                continue
            if get_in(agent_data, PATH_TO_DEAD, False):
                die.add(agent)
            values.setdefault(agent, {})[time] = get_in(
                agent_data, path_to_variable)

    live_finals: Dict[str, float] = {}
    dead_finals: Dict[str, float] = {}
    # Every agent in ``values`` has at least one recorded time.
    for agent, agent_values in values.items():
        final_value = agent_values[max(agent_values)]
        if agent in die:
            dead_finals[agent] = final_value
        else:
            live_finals[agent] = final_value
    return live_finals, dead_finals
def make_environment_section(
        data_and_configs: Sequence[DataTuple],
        _search_data: SearchData,
) -> dict:
    '''Plot field concentrations in a cross-section of the environment.

    Creates Figure 3B. For every replicate, the fields named in
    ``ENVIRONMENT_SECTION_FIELDS`` are extracted at each time in
    ``ENVIRONMENT_SECTION_TIMES`` and passed to
    ``get_enviro_sections_plot`` with ``section_location=0.5``. The
    resulting figure is saved to
    ``FIG_OUT_DIR/enviro_section.<FILE_EXTENSION>`` and then closed.

    Args:
        data_and_configs: Sequence of 2-tuples whose first element is a
            replicate's raw data keyed by time. Environment bounds are
            read from the first replicate's final timepoint.
        _search_data: Unused; accepted for interface compatibility.

    Returns:
        The statistics returned by ``get_enviro_sections_plot``.
    '''
    first_replicate = data_and_configs[0][0]
    t_final = max(first_replicate.keys())
    section_times = [float(time) for time in ENVIRONMENT_SECTION_TIMES]
    fields_ts: List[Dict[float, Dict[str, SerializedField]]] = []
    for replicate, _ in data_and_configs:
        fields_ts.append({
            time: {
                name: field
                for name, field in get_in(
                    replicate[time], FIELDS_PATH).items()
                if name in ENVIRONMENT_SECTION_FIELDS
            }
            for time in section_times
        })
    bounds = get_in(first_replicate[t_final], BOUNDS_PATH)
    fig, stats = get_enviro_sections_plot(
        fields_ts, bounds, section_location=0.5, fontsize=18)
    try:
        fig.savefig(
            os.path.join(FIG_OUT_DIR, 'enviro_section.{}'.format(
                FILE_EXTENSION)))
    finally:
        # The figure is not returned, so release it; otherwise matplotlib
        # keeps every figure from repeated calls alive.
        plt.close(fig)
    return stats
def _get_final_live_agents(
        data: RawData,
        time_range: Tuple[float, float] = (0, 1),
) -> List[str]:
    '''List the agents alive at the last timepoint of a time window.

    Args:
        data: Raw simulation data, keyed by time.
        time_range: Fractions of the simulation passed to
            ``filter_raw_data_by_time`` to select the window.

    Returns:
        IDs of the agents at the window's final timepoint whose dead flag
        (``PATH_TO_DEAD``) is falsy or absent, in store order.
    '''
    window = filter_raw_data_by_time(data, time_range)
    last_time = max(window.keys())
    final_agents = get_in(
        # Pylint doesn't recognize that the RawData NewType is a dict
        window[last_time],  # pylint: disable=unsubscriptable-object
        PATH_TO_AGENTS,
    )
    return [
        agent_id
        for agent_id, agent_state in final_agents.items()
        if not get_in(agent_state, PATH_TO_DEAD)
    ]
def next_update(self, timestep, states):
    '''Compute colony-level statistics from the agents' locations.

    Agent locations (``boundary.location``) containing NaN are dropped,
    and an alpha shape (parameter ``self.parameters['alpha']``) is fitted
    to the remaining points. Each part of that shape is treated as one
    colony.

    Args:
        timestep: Unused here.
        states: Port states; ``states['agents']`` maps agent IDs to agent
            states.

    Returns:
        Update for the ``colony_global`` store containing per-colony lists
        (one entry per colony, in the same order) of ``surface_area``,
        ``major_axis``, ``minor_axis``, ``circumference``, and ``mass``.
    '''
    agents = states['agents']
    points = [agent['boundary']['location'] for agent in agents.values()]
    # Drop agents whose location contains NaN.
    points = [tuple(point) for point in points if np.all(~np.isnan(point))]
    alpha_shape = alphashape.alphashape(
        set(points), self.parameters['alpha'])
    if isinstance(alpha_shape, Polygon):
        shapes = [alpha_shape]
    elif isinstance(alpha_shape, (Point, LineString)):
        # We need at least 3 cells to form a colony polygon
        shapes = []
    else:
        assert isinstance(alpha_shape, (MultiPolygon, GeometryCollection))
        # Iterate the parts via ``.geoms``: multi-part geometries are not
        # directly iterable in Shapely 2.x, and ``.geoms`` is also
        # available in Shapely 1.x.
        shapes = list(alpha_shape.geoms)
    agent_colony_map = gen_agent_colony_map(agents, shapes)

    # Calculate colony surface areas
    areas = [shape.area for shape in shapes]

    # Calculate colony major and minor axes based on bounding
    # rectangles. Non-polygon parts (possible inside a
    # GeometryCollection) get axes of 0.
    major_axes = []
    minor_axes = []
    for shape in shapes:
        if isinstance(shape, Polygon):
            major, minor = major_minor_axes(shape)
        else:
            major, minor = 0, 0
        major_axes.append(major)
        minor_axes.append(minor)

    # Calculate colony circumference
    circumference = [shape.length for shape in shapes]

    # Calculate colony masses by summing the masses of the agents that
    # gen_agent_colony_map assigned to each colony. Agents without a
    # recorded mass contribute 0.
    mass = [0] * len(shapes)
    for agent_id, agent_state in agents.items():
        if agent_id not in agent_colony_map:
            # We ignore agents not in any colony
            continue
        colony_index = agent_colony_map[agent_id]
        mass[colony_index] += get_in(agent_state, ('boundary', 'mass'), 0)

    return {
        'colony_global': {
            'surface_area': areas,
            'major_axis': major_axes,
            'minor_axis': minor_axes,
            'circumference': circumference,
            'mass': mass,
        }
    }
def get_total_mass_timeseries(data: RawData) -> List[float]:
    '''Compute the total agent mass at each timepoint of a simulation.

    Args:
        data: Simulation data keyed by time.

    Returns:
        Total cell mass at each timepoint, ordered by increasing time.
    '''
    return [
        get_total_mass(get_in(data[timepoint], AGENTS_PATH))
        for timepoint in sorted(data)
    ]
def get_total_mass(agents_data: Dict) -> float:
    '''Sum the masses of a collection of agents.

    Args:
        agents_data: The agents store, mapping agent IDs to agent data.
            Every agent is expected to have a mass at ``MASS_PATH``.

    Returns:
        The sum of the agents' masses (``0.0`` when there are no agents).
    '''
    masses = [
        get_in(agent_state, MASS_PATH)
        for agent_state in agents_data.values()
    ]
    # Accumulate left to right from a float zero.
    running_total = 0.
    for agent_mass in masses:
        running_total += agent_mass
    return running_total
def plot_expression_survival(data, path_to_variable, xlabel, time_range=(0, 1)):
    '''Create Expression Dotplot Colored by Survival

    Plot one dot for each cell along an axis to indicate that cell's
    average expression level for a specified protein. The dot color
    reflects whether the cell survived long enough to divide. Note that
    only the expression levels while the cell is alive are considered in
    the average. Cells with no recorded (non-``None``) value while alive
    inside the time range are omitted from the plot.

    Parameters:
        data (dict): The raw data emitted from the simulation.
        path_to_variable (tuple): Path from the agent root to the
            variable that holds the protein's expression level. We do
            not adjust for cell volume, so this should be a
            concentration.
        xlabel (str): Label for x-axis.
        time_range (tuple): Tuple of two :py:class:`float`s that are
            fractions of the total simulated time period. These
            fractions indicate the start and end points (inclusive) of
            the time range to consider when calculating average
            expression level.

    Returns:
        plt.Figure: The finished figure.
    '''
    expression_levels = dict()
    die = set()
    end_time = max(data.keys())
    for time, time_data in data.items():
        if (time < time_range[0] * end_time
                or time > time_range[1] * end_time):
            continue
        agents_data = get_in(time_data, PATH_TO_AGENTS)
        for agent, agent_data in agents_data.items():
            levels = expression_levels.setdefault(agent, [])
            value = get_in(agent_data, path_to_variable)
            if get_in(agent_data, PATH_TO_DEAD, False):
                die.add(agent)
            # Only count values while the cell is alive; each live value
            # is recorded exactly once so the mean is unweighted.
            elif value is not None:
                levels.append(value)

    live_averages = []
    dead_averages = []
    for agent, levels in expression_levels.items():
        if not levels:
            continue
        if agent in die:
            dead_averages.append(np.mean(levels))
        else:
            live_averages.append(np.mean(levels))

    fig, ax = plt.subplots(figsize=(6, 2))
    ax.scatter(
        live_averages, [0.1] * len(live_averages),
        label='Survive', color=LIVE_COLOR, alpha=ALPHA,
    )
    ax.scatter(
        dead_averages, [0.1] * len(dead_averages),
        label='Die', color=DEAD_COLOR, alpha=ALPHA,
    )
    ax.legend()
    ax.set_xlabel(xlabel)
    ax.set_ylim([0, 1.25])
    ax.get_yaxis().set_visible(False)
    for spine_name in ('left', 'top', 'right'):
        ax.spines[spine_name].set_visible(False)
    ax.spines['bottom'].set_position('zero')
    fig.tight_layout()
    return fig
def plot_phylogeny(
        data: RawData,
        out: str = 'phylogeny.pdf',
        live_color: str = 'green',
        dead_color: str = 'black',
        ignore_color: str = 'lightgray',
        time_range: Tuple[float, float] = (0, 1)
) -> Tuple[TreeNode, pd.DataFrame]:
    '''Plot phylogenetic tree from an experiment.

    Args:
        data: The simulation data, keyed by time.
        out: Path to the output file. File type will be inferred from
            the file name.
        live_color: Color for nodes representing cells that survive
            until division.
        dead_color: Color for nodes representing cells that die.
        ignore_color: Color for nodes outside the time range considered.
        time_range: Tuple specifying the range of times to consider.
            Range values specified as fractions of the final timepoint.

    Returns:
        A ``(tree, df)`` tuple: the rendered tree, and a table with one
        row per agent present at some time inside ``time_range``. The
        ``agents`` column holds the agent IDs in sorted order and the
        ``survival`` column is 0 for agents flagged dead at any time and
        1 otherwise.
    '''
    agent_ids: Set[str] = set()
    dead_ids: Set[str] = set()
    in_time_range_ids: Set[str] = set()
    end_time = max(data.keys())
    for time, time_data in data.items():
        agents_data = get_in(time_data, AGENTS_PATH)
        assert agents_data is not None
        agent_ids |= set(agents_data.keys())
        if time_range[0] * end_time <= time <= time_range[1] * end_time:
            in_time_range_ids |= set(agents_data.keys())
        for agent_id, agent_data in agents_data.items():
            if get_in(agent_data, PATH_TO_DEAD, False):
                dead_ids.add(agent_id)

    # All agents are expected to descend from a single ancestor.
    trees = make_ete_trees(agent_ids)
    assert len(trees) == 1
    tree = trees[0]

    # Set style for overall figure
    tstyle = TreeStyle()
    tstyle.show_scale = False
    tstyle.show_leaf_name = False
    tstyle.scale = None
    tstyle.optimal_scale_level = 'full'  # Avoid artificial branches
    tstyle.mode = 'c'
    legend = {
        'Die': dead_color,
        'Survive': live_color,
        'Divided Before Antibiotics Appeared': ignore_color,
    }
    for label, color in legend.items():
        tstyle.legend.add_face(CircleFace(5, color), column=0)
        tstyle.legend.add_face(TextFace(' ' + label, ftype=FONT), column=1)

    # Set styles for each node
    for node in tree.traverse():
        nstyle = NodeStyle()
        nstyle['size'] = 5
        nstyle['vt_line_width'] = 1
        nstyle['hz_line_width'] = 1
        if node.name in in_time_range_ids:
            if node.name in dead_ids:
                nstyle['fgcolor'] = dead_color
            else:
                nstyle['fgcolor'] = live_color
        else:
            nstyle['fgcolor'] = ignore_color
        node.set_style(nstyle)
    tree.render(out, tree_style=tstyle, w=400)

    # Sort so the table's row order does not depend on set iteration
    # order, which varies between runs due to string hash randomization.
    agents_col = sorted(in_time_range_ids)
    survive_col = [0 if agent in dead_ids else 1 for agent in agents_col]
    df = pd.DataFrame({'agents': agents_col, 'survival': survive_col})
    return tree, df
def plot_tags(data, plot_config):
    '''Plot snapshots of the simulation over time

    The snapshots depict the agents and the levels of tagged molecules
    in each agent by agent color intensity.

    Arguments:
        data (dict): A dictionary with the following keys:

            * **agents** (:py:class:`dict`): A mapping from times to
              dictionaries of agent data at that timepoint. Agent data
              dictionaries should have the same form as the hierarchy
              tree rooted at ``agents``.
            * **config** (:py:class:`dict`): The environmental
              configuration dictionary with the following keys:

                * **bounds** (:py:class:`tuple`): The dimensions of the
                  environment. Required.

        plot_config (dict): Accepts the following configuration options.
            Any options with a default is optional.

            * **n_snapshots** (:py:class:`int`): Number of snapshots to
              show per row (i.e. for each molecule). Defaults to 6.
            * **out_dir** (:py:class:`str`): Output directory, which is
              ``out`` by default.
            * **filename** (:py:class:`str`): Base name of output file.
              ``tags`` by default.
            * **agent_shape** (:py:class:`str`): Shape passed to
              ``plot_agents``. ``segment`` by default.
            * **tagged_molecules** (:py:class:`typing.Iterable`): The
              tagged molecules whose concentrations will be indicated by
              agent color. Each molecule should be specified as a
              :py:class:`tuple` of the path in the agent compartment
              to where the molecule's count can be found, with the last
              value being the molecule's count variable. Required and
              must be non-empty.
            * **tag_path_name_map** (:py:class:`dict`): Map from tag ID
              (path tuple) to the display name used in the row label.
              Tags not in the map are labeled with the tag ID itself.
            * **convert_to_concs** (:py:class:`bool`): if True, convert
              counts to concentrations. Defaults to True.
            * **background_color** (:py:class:`str`): use matplotlib
              colors, ``black`` by default
            * **tag_label_size** (:py:class:`float`): The font size for
              the tag name label. Defaults to 20.
            * **default_font_size** (:py:class:`float`): Font size for
              titles and axis labels. Defaults to 36.
            * **tag_colors** (:py:class:`dict`): Map from tag ID in
              tagged_molecules to a tuple (min_color, max_color). An
              entry is required for every tagged molecule.
            * **scale_bar_length** (:py:class:`float`): Length of scale
              bar.  Defaults to 1 (in units of micrometers). If 0, no
              bar plotted.
            * **scale_bar_color** (:py:class:`str`): Color of scale bar.
              Defaults to ``white``.
            * **xlim** (:py:class:`tuple` of :py:class:`float`): Tuple
              of lower and upper x-axis limits.
            * **ylim** (:py:class:`tuple` of :py:class:`float`): Tuple
              of lower and upper y-axis limits.
    '''
    check_plt_backend()

    n_snapshots = plot_config.get('n_snapshots', 6)
    out_dir = plot_config.get('out_dir', 'out')
    filename = plot_config.get('filename', 'tags')
    agent_shape = plot_config.get('agent_shape', 'segment')
    background_color = plot_config.get('background_color', 'black')
    # Materialize so one-shot iterables can be traversed repeatedly and
    # emptiness is detected for any container type (list, tuple, ...).
    tagged_molecules = list(plot_config['tagged_molecules'])
    tag_path_name_map = plot_config.get('tag_path_name_map', {})
    tag_label_size = plot_config.get('tag_label_size', 20)
    default_font_size = plot_config.get('default_font_size', 36)
    convert_to_concs = plot_config.get('convert_to_concs', True)
    tag_colors = plot_config.get('tag_colors', {})
    scale_bar_length = plot_config.get('scale_bar_length', 1)
    scale_bar_color = plot_config.get('scale_bar_color', 'white')
    xlim = plot_config.get('xlim')
    ylim = plot_config.get('ylim')

    if not tagged_molecules:
        raise ValueError('At least one molecule must be tagged.')

    # get data
    agents = data['agents']
    config = data.get('config', {})
    bounds = config['bounds']
    edge_length_x, edge_length_y = bounds

    # time steps that will be used
    time_vec = list(agents.keys())
    time_indices = np.round(
        np.linspace(0, len(time_vec) - 1, n_snapshots)
    ).astype(int)
    snapshot_times = [time_vec[i] for i in time_indices]

    # Range (min, max) of each tag's level over all times and agents, used
    # to normalize the color scale.
    tag_ranges = {}
    for time_data in agents.values():
        for agent_data in time_data.values():
            volume = agent_data.get('boundary', {}).get('volume', 0)
            for tag_id in tagged_molecules:
                level = get_value_from_path(agent_data, tag_id)
                if convert_to_concs:
                    level = level / volume if volume else 0
                if tag_id in tag_ranges:
                    low, high = tag_ranges[tag_id]
                    tag_ranges[tag_id] = [min(low, level), max(high, level)]
                else:
                    # add new tag
                    tag_ranges[tag_id] = [level, level]

    # make the figure
    n_rows = len(tagged_molecules)
    n_cols = n_snapshots + 1  # one column for the colorbar
    figsize = (12 * n_cols, 12 * n_rows)
    # Keep the pixel dimensions below 2**16 per side.
    max_dpi = min([2**16 // dim for dim in figsize]) - 1
    fig = plt.figure(figsize=figsize, dpi=min(max_dpi, 100))
    grid = plt.GridSpec(n_rows, n_cols, wspace=0.2, hspace=0.2)
    original_fontsize = plt.rcParams['font.size']
    plt.rcParams.update({'font.size': default_font_size})
    # try/finally so the global font size is restored and the figure is
    # released even if plotting or saving fails.
    try:
        # Add time axis across subplots
        super_spec = matplotlib.gridspec.SubplotSpec(
            grid,
            (n_rows - 1) * n_cols,
            (n_rows - 1) * n_cols + n_snapshots - 1,
        )
        grid_params = grid.get_subplot_params()
        snapshot_times_hrs = tuple(time / 60 / 60 for time in snapshot_times)
        if n_snapshots > 1:
            time_per_snapshot = (
                snapshot_times_hrs[-1] - snapshot_times_hrs[0]) / (
                    (n_snapshots - 1) * (grid_params.wspace + 1))
        else:
            time_per_snapshot = 1  # Arbitrary
        super_ax = fig.add_subplot(  # type: ignore
            super_spec,
            xticks=snapshot_times_hrs,
            xlim=(
                snapshot_times_hrs[0] - time_per_snapshot / 2,
                snapshot_times_hrs[-1] + time_per_snapshot / 2,
            ),
            yticks=[],
        )
        super_ax.set_xlabel(  # type: ignore
            'Time (hr)', labelpad=50)
        super_ax.xaxis.set_tick_params(width=2, length=8)
        for spine_name in ('top', 'right', 'left'):
            super_ax.spines[spine_name].set_visible(False)
        super_ax.spines['bottom'].set_linewidth(2)

        # plot tags
        for row_idx, tag_id in enumerate(tag_ranges):
            tag_name = tag_path_name_map.get(tag_id, tag_id)
            min_tag, max_tag = tag_ranges[tag_id]
            min_color, max_color = tag_colors[tag_id]
            min_rgb = matplotlib.colors.to_rgb(min_color)
            max_rgb = matplotlib.colors.to_rgb(max_color)
            # Linear ramp from min_color to max_color in each channel.
            colors_dict = {
                channel: [
                    [0, min_rgb[i], min_rgb[i]],
                    [1, max_rgb[i], max_rgb[i]],
                ]
                for i, channel in enumerate(('red', 'green', 'blue'))
            }
            cmap = matplotlib.colors.LinearSegmentedColormap(
                tag_id, segmentdata=colors_dict, N=512)
            norm = matplotlib.colors.Normalize(min_tag, max_tag)

            for col_idx, time in enumerate(snapshot_times):
                ax = init_axes(
                    fig, edge_length_x, edge_length_y, grid, row_idx,
                    col_idx, time, tag_name, tag_label_size,
                )
                ax.tick_params(
                    axis='both', which='both', bottom=False, top=False,
                    left=False, right=False,
                )
                ax.set_facecolor(background_color)

                agent_tag_colors = {}
                for agent_id, agent_data in agents[time].items():
                    # get current tag concentration, and determine color
                    level = get_value_from_path(agent_data, tag_id)
                    if convert_to_concs:
                        volume = get_in(
                            agent_data, ('boundary', 'volume'), 0)
                        level = level / volume if volume else 0
                    agent_rgb = cmap(norm(level))[:3]
                    agent_tag_colors[agent_id] = (
                        matplotlib.colors.rgb_to_hsv(agent_rgb))

                plot_agents(ax, agents[time], agent_tag_colors, agent_shape)
                if xlim:
                    ax.set_xlim(*xlim)
                if ylim:
                    ax.set_ylim(*ylim)

                # Scale bar in first snapshot of each row. Drawn before the
                # colorbar so that with a single snapshot it still goes on
                # the snapshot axes rather than the colorbar axes.
                if col_idx == 0 and scale_bar_length:
                    scale_bar = anchored_artists.AnchoredSizeBar(
                        ax.transData,
                        scale_bar_length,
                        f'${scale_bar_length} \\mu m$',
                        'lower left',
                        color=scale_bar_color,
                        frameon=False,
                        sep=scale_bar_length,
                        size_vertical=scale_bar_length / 20,
                    )
                    ax.add_artist(scale_bar)

                # colorbar in new column after final snapshot
                if col_idx == n_snapshots - 1:
                    cbar_ax = fig.add_subplot(grid[row_idx, col_idx + 1])
                    if row_idx == 0 and convert_to_concs:
                        cbar_ax.set_title(
                            'Concentration (counts/fL)', y=1.08)
                    cbar_ax.axis('off')
                    # A constant tag has no meaningful color scale.
                    if min_tag != max_tag:
                        divider = make_axes_locatable(cbar_ax)
                        cax = divider.append_axes(
                            "left", size="5%", pad=0.0)
                        mappable = matplotlib.cm.ScalarMappable(norm, cmap)
                        fig.colorbar(mappable, cax=cax, format='%.0f')

        fig_path = os.path.join(out_dir, filename)
        fig.subplots_adjust(wspace=0.7, hspace=0.1)
        fig.savefig(fig_path, bbox_inches='tight')
    finally:
        plt.close(fig)
        plt.rcParams.update({'font.size': original_fontsize})