示例#1
0
def split_raw_data_by_survival(raw_data: RawData) -> Tuple[RawData, RawData]:
    '''Partition raw simulation data by whether each agent ever dies.

    An agent counts as dying if, at any timepoint, the value found at
    ``PATH_TO_DEAD`` in its data is truthy; every other agent counts as
    surviving.

    Args:
        raw_data: Raw simulation data, keyed by time.
    Returns:
        A ``(survive_data, die_data)`` tuple of raw data dictionaries.
        Both contain an (initially empty) agents store at every
        timepoint of ``raw_data``. The first holds the agents that are
        never flagged dead; the second holds the agents that are
        flagged dead at some timepoint.
    '''
    # Pass 1: gather the IDs of all agents flagged dead at any time.
    dying_ids = {
        agent_id
        for snapshot in raw_data.values()
        for agent_id, agent_state in get_in(
            snapshot, PATH_TO_AGENTS).items()
        if get_in(agent_state, PATH_TO_DEAD, False)
    }

    # Build an empty skeleton with an agents store per timepoint, then
    # deep-copy it so the two outputs share no nested dictionaries.
    survive_data = RawData(dict())
    for time in raw_data:
        assoc_path(survive_data, (time, ) + PATH_TO_AGENTS, dict())
    die_data = copy.deepcopy(survive_data)

    # Pass 2: route each agent's per-timepoint state into its bucket.
    for time, snapshot in raw_data.items():
        for agent_id, agent_state in get_in(
                snapshot, PATH_TO_AGENTS).items():
            target = die_data if agent_id in dying_ids else survive_data
            target_path = (time, ) + PATH_TO_AGENTS + (agent_id, )
            assoc_path(target, target_path, agent_state)

    return survive_data, die_data
示例#2
0
def raw_data_to_end_expression_table(
        raw_data: RawData, paths_dict: Dict[str, Path]) -> pd.DataFrame:
    '''Create a table of end expression levels from raw simulation data.

    Only the final timepoint (the largest key of ``raw_data``) is used.

    Args:
        raw_data: Raw simulation data, keyed by time.
        paths_dict: Map from names to paths (relative to each agent's
            data) to protein counts. The names will be used as column
            headers in the returned table. A count missing at its path
            is treated as 0.

    Returns:
        Table with one row for each agent present at the final
        timepoint, one column for each name in ``paths_dict``, and an
        additional ``VOLUME_KEY`` column holding each agent's volume
        (0 if missing). Each protein cell contains the count divided by
        the agent's volume, or 0 when the volume is 0 or missing.
    '''
    end_data = raw_data[max(raw_data.keys())]
    expression_data: Dict = {name: [] for name in paths_dict}
    expression_data[VOLUME_KEY] = []
    agents_data = get_in(end_data, AGENTS_PATH)
    for agent_data in agents_data.values():
        volume = get_in(agent_data, VOLUME_PATH, 0)
        expression_data[VOLUME_KEY].append(volume)
        for name, path in paths_dict.items():
            count = get_in(agent_data, path, 0)
            # Guard against division by zero for agents without a volume.
            concentration = count / volume if volume else 0
            expression_data[name].append(concentration)
    return pd.DataFrame(expression_data)
示例#3
0
def _calc_live_and_dead_finals(
    data: RawData,
    path_to_variable: Path,
    time_range: Tuple[float, float],
    agents: Iterable[str] = tuple(),
) -> Tuple[Dict[str, float], Dict[str, float]]:
    '''Get each agent's last in-range value of a variable, split by fate.

    Args:
        data: Raw simulation data, keyed by time.
        path_to_variable: Path from each agent's data to the variable.
        time_range: ``(start, end)`` fractions of the final timepoint;
            only timepoints in ``[start * end_time, end * end_time]``
            are considered.
        agents: If non-empty, only these agent IDs are considered.
            Any iterable (including a one-shot generator) is accepted.

    Returns:
        ``(live_finals, dead_finals)``: two maps from agent ID to the
        variable's value at that agent's latest in-range timepoint.
        An agent goes in ``dead_finals`` if it was flagged dead at any
        in-range timepoint, otherwise in ``live_finals``.
    '''
    # Materialize the filter once. A generator would always be truthy
    # and would be consumed by the first ``in`` test; a set also gives
    # O(1) membership checks.
    agent_filter = set(agents)
    values: Dict[str, Dict[float, float]] = {}
    die = set()

    end_time = max(data.keys())
    start_time = time_range[0] * end_time
    stop_time = time_range[1] * end_time
    for time, time_data in data.items():
        if time < start_time or time > stop_time:
            continue
        agents_data = get_in(time_data, PATH_TO_AGENTS)
        for agent, agent_data in agents_data.items():
            if agent_filter and agent not in agent_filter:
                continue
            agent_values = values.setdefault(agent, {})
            if get_in(agent_data, PATH_TO_DEAD, False):
                die.add(agent)
            agent_values[time] = get_in(agent_data, path_to_variable)

    live_finals = {}
    dead_finals = {}
    for agent, agent_values in values.items():
        if not agent_values:
            continue
        final_value = agent_values[max(agent_values.keys())]
        if agent in die:
            dead_finals[agent] = final_value
        else:
            live_finals[agent] = final_value
    return live_finals, dead_finals
def make_environment_section(
        data_and_configs: Sequence[DataTuple],
        _search_data: SearchData,
        ) -> dict:
    '''Plot field concentrations in a cross-section of the environment.

    Create Figure 3B. For every replicate, the fields named in
    ``ENVIRONMENT_SECTION_FIELDS`` are extracted at each time in
    ``ENVIRONMENT_SECTION_TIMES``. The bounds are taken from the first
    replicate's final timepoint. The figure is saved to ``FIG_OUT_DIR``.

    Returns:
        The statistics returned by ``get_enviro_sections_plot``.
    '''
    first_replicate = data_and_configs[0][0]
    t_final = max(first_replicate.keys())
    section_times = [float(t) for t in ENVIRONMENT_SECTION_TIMES]

    # One {time: {field_name: field}} mapping per replicate.
    fields_ts: List[Dict[float, Dict[str, SerializedField]]] = [
        {
            time: {
                name: field
                for name, field in get_in(
                    replicate[time], FIELDS_PATH).items()
                if name in ENVIRONMENT_SECTION_FIELDS
            }
            for time in section_times
        }
        for replicate, _ in data_and_configs
    ]

    bounds = get_in(first_replicate[t_final], BOUNDS_PATH)
    fig, stats = get_enviro_sections_plot(
        fields_ts, bounds, section_location=0.5, fontsize=18)
    out_path = os.path.join(
        FIG_OUT_DIR, 'enviro_section.{}'.format(FILE_EXTENSION))
    fig.savefig(out_path)
    return stats
示例#5
0
def _get_final_live_agents(
        data: RawData,
        time_range: Tuple[float, float] = (0, 1),
) -> List[str]:
    '''List agents not flagged dead at the last timepoint in a range.

    Args:
        data: Raw simulation data, keyed by time.
        time_range: Passed to ``filter_raw_data_by_time`` to restrict
            the timepoints considered.

    Returns:
        IDs of the agents at the latest remaining timepoint whose value
        at ``PATH_TO_DEAD`` is falsy (including missing).
    '''
    in_range = filter_raw_data_by_time(data, time_range)
    last_time = max(in_range.keys())
    # Pylint doesn't recognize that the RawData NewType is a dict
    final_snapshot = in_range[last_time]  # pylint: disable=unsubscriptable-object
    final_agents = get_in(final_snapshot, PATH_TO_AGENTS)
    return [
        agent_id
        for agent_id, agent_state in final_agents.items()
        if not get_in(agent_state, PATH_TO_DEAD)
    ]
示例#6
0
    def next_update(self, timestep, states):
        '''Compute per-colony geometry and mass from agent locations.

        Colonies are the polygon parts of the alpha shape built from
        all agent locations that contain no NaN, using
        ``self.parameters['alpha']``. Each output list is indexed by
        colony, in the same order as those parts.

        Args:
            timestep: Not used by this method.
            states: Must contain ``'agents'``, a map from agent ID to
                agent state with a ``['boundary']['location']`` entry.

        Returns:
            Update of the form ``{'colony_global': {...}}`` with
            per-colony lists for ``surface_area``, ``major_axis``,
            ``minor_axis``, ``circumference`` and ``mass``.
        '''
        agents = states['agents']
        points = [agent['boundary']['location'] for agent in agents.values()]
        # Drop any location containing a NaN coordinate.
        points = [tuple(point) for point in points if np.all(~np.isnan(point))]
        alpha_shape = alphashape.alphashape(set(points),
                                            self.parameters['alpha'])
        if isinstance(alpha_shape, Polygon):
            shapes = [alpha_shape]
        elif isinstance(alpha_shape, (Point, LineString)):
            # We need at least 3 cells to form a colony polygon
            shapes = []
        else:
            assert isinstance(alpha_shape, (MultiPolygon, GeometryCollection))
            # Multi-part geometries are not iterable in Shapely 2.x;
            # ``.geoms`` works in both Shapely 1.x and 2.x.
            shapes = list(alpha_shape.geoms)

        agent_colony_map = gen_agent_colony_map(agents, shapes)

        # Calculate colony surface areas
        areas = [shape.area for shape in shapes]

        # Calculate colony major and minor axes based on bounding
        # rectangles. Non-polygon parts (possible inside a
        # GeometryCollection) get axes of 0.
        major_axes = []
        minor_axes = []
        for shape in shapes:
            if isinstance(shape, Polygon):
                major, minor = major_minor_axes(shape)
                major_axes.append(major)
                minor_axes.append(minor)
            else:
                major_axes.append(0)
                minor_axes.append(0)

        # Calculate colony circumference
        circumference = [shape.length for shape in shapes]

        # Calculate colony masses by summing the masses of member agents
        mass = [0] * len(shapes)
        for agent_id, agent_state in agents.items():
            if agent_id not in agent_colony_map:
                # We ignore agents not in any colony
                continue
            colony_index = agent_colony_map[agent_id]
            agent_mass = get_in(agent_state, ('boundary', 'mass'), 0)
            mass[colony_index] += agent_mass

        return {
            'colony_global': {
                'surface_area': areas,
                'major_axis': major_axes,
                'minor_axis': minor_axes,
                'circumference': circumference,
                'mass': mass,
            }
        }
def get_total_mass_timeseries(data: RawData) -> List[float]:
    '''Compute the total agent mass at every timepoint of a simulation.

    Args:
        data: Simulation data, keyed by time.

    Returns:
        One total-mass value per timepoint, in ascending time order.
    '''
    return [
        get_total_mass(get_in(data[time], AGENTS_PATH))
        for time in sorted(data)
    ]
def get_total_mass(agents_data: Dict) -> float:
    '''Add up the masses of all agents in an agents store.

    Args:
        agents_data: Map from agent ID to that agent's simulation data;
            each agent's mass is read from ``MASS_PATH``.

    Returns:
        The summed mass. The sum starts at ``0.``, so an empty store
        yields ``0.0``.
    '''
    # Accumulate sequentially (rather than via ``sum``) so the float
    # rounding matches a plain left-to-right addition.
    running_total = 0.
    for state in agents_data.values():
        running_total = running_total + get_in(state, MASS_PATH)
    return running_total
def plot_expression_survival(data, path_to_variable, xlabel,
                             time_range=(0, 1)):
    '''Create Expression Dotplot Colored by Survival

    Plot one dot for each cell along an axis to indicate that cell's
    average expression level for a specified protein. The dot color
    reflects whether the cell survived long enough to divide.

    Note that only the expression levels while the cell is alive are
    considered in the average, and each live timepoint contributes
    exactly once. Cells with no live, non-``None`` values in the time
    range are omitted from the plot.

    Parameters:
        data (dict): The raw data emitted from the simulation.
        path_to_variable (tuple): Path from the agent root to the
            variable that holds the protein's expression level. We do
            not adjust for cell volume, so this should be a
            concentration.
        xlabel (str): Label for x-axis.
        time_range (tuple): Tuple of two :py:class:`float`s that are
            fractions of the total simulated time period. These
            fractions indicate the start and end points (inclusive) of
            the time range to consider when calculating average
            expression level.

    Returns:
        plt.Figure: The finished figure.
    '''
    expression_levels = dict()
    die = set()
    end_time = max(data.keys())
    for time, time_data in data.items():
        if (time < time_range[0] * end_time
                or time > time_range[1] * end_time):
            continue
        agents_data = get_in(time_data, PATH_TO_AGENTS)
        for agent, agent_data in agents_data.items():
            levels = expression_levels.setdefault(agent, [])
            value = get_in(agent_data, path_to_variable)
            if get_in(agent_data, PATH_TO_DEAD, False):
                die.add(agent)
            # Only count values when cell is alive. Appending exactly
            # once here keeps live timepoints from being double-weighted
            # and keeps post-death values out of the average.
            elif value is not None:
                levels.append(value)

    live_averages = []
    dead_averages = []
    for agent, levels in expression_levels.items():
        if not levels:
            continue
        if agent in die:
            dead_averages.append(np.mean(levels))
        else:
            live_averages.append(np.mean(levels))

    fig, ax = plt.subplots(figsize=(6, 2))
    ax.scatter(
        live_averages,
        [0.1] * len(live_averages),
        label='Survive',
        color=LIVE_COLOR,
        alpha=ALPHA,
    )
    ax.scatter(
        dead_averages,
        [0.1] * len(dead_averages),
        label='Die',
        color=DEAD_COLOR,
        alpha=ALPHA,
    )
    ax.legend()
    ax.set_xlabel(xlabel)
    ax.set_ylim([0, 1.25])
    ax.get_yaxis().set_visible(False)
    for spine_name in ('left', 'top', 'right'):
        ax.spines[spine_name].set_visible(False)
    ax.spines['bottom'].set_position('zero')
    fig.tight_layout()
    return fig
def plot_phylogeny(
    data: RawData,
    out: str = 'phylogeny.pdf',
    live_color: str = 'green',
    dead_color: str = 'black',
    ignore_color: str = 'lightgray',
    time_range: Tuple[float, float] = (0, 1)
) -> Tuple[TreeNode, pd.DataFrame]:
    '''Plot phylogenetic tree from an experiment.

    Args:
        data: The simulation data.
        out: Path to the output file. File type will be inferred from
            the file name.
        live_color: Color for nodes representing cells that survive
            until division.
        dead_color: Color for nodes representing cells that die.
        ignore_color: Color for nodes outside the time range considered
            (labelled "Divided Before Antibiotics Appeared" in the
            legend).
        time_range: Tuple specifying the range of times to consider.
            Range values specified as fractions of the final
            timepoint.

    Returns:
        A ``(tree, df)`` tuple. ``tree`` is the tree built from every
        agent ID seen in ``data``. ``df`` has one row per agent present
        within the time range, sorted by agent ID, with columns
        ``agents`` (the agent ID) and ``survival`` (0 if the agent was
        flagged dead within the time range, else 1).
    '''
    agent_ids: Set[str] = set()
    dead_ids: Set[str] = set()
    in_time_range_ids: Set[str] = set()
    end_time = max(data.keys())
    for time, time_data in data.items():
        agents_data = get_in(time_data, AGENTS_PATH)
        assert agents_data is not None
        agent_ids |= set(agents_data.keys())

        if time_range[0] * end_time <= time <= time_range[1] * end_time:
            in_time_range_ids |= set(agents_data.keys())
            for agent_id, agent_data in agents_data.items():
                if get_in(agent_data, PATH_TO_DEAD, False):
                    dead_ids.add(agent_id)

    trees = make_ete_trees(agent_ids)
    assert len(trees) == 1
    tree = trees[0]

    # Set style for overall figure
    tstyle = TreeStyle()
    tstyle.show_scale = False
    tstyle.show_leaf_name = False
    tstyle.scale = None
    tstyle.optimal_scale_level = 'full'  # Avoid artificial branches
    tstyle.mode = 'c'
    legend = {
        'Die': dead_color,
        'Survive': live_color,
        'Divided Before Antibiotics Appeared': ignore_color,
    }
    for label, color in legend.items():
        tstyle.legend.add_face(CircleFace(5, color), column=0)
        tstyle.legend.add_face(TextFace(' ' + label, ftype=FONT), column=1)

    # Set styles for each node
    for node in tree.traverse():
        nstyle = NodeStyle()
        nstyle['size'] = 5
        nstyle['vt_line_width'] = 1
        nstyle['hz_line_width'] = 1
        if node.name in in_time_range_ids:
            if node.name in dead_ids:
                nstyle['fgcolor'] = dead_color
            else:
                nstyle['fgcolor'] = live_color
        else:
            nstyle['fgcolor'] = ignore_color
        node.set_style(nstyle)
    tree.render(out, tree_style=tstyle, w=400)

    # Sort so the table's row order is reproducible; iterating a set of
    # strings directly depends on per-process hash randomization.
    agents_col = sorted(in_time_range_ids)
    survive_col = [0 if agent in dead_ids else 1 for agent in agents_col]
    df = pd.DataFrame({'agents': agents_col, 'survival': survive_col})
    return tree, df
def _plot_tags_value_ranges(agents, tagged_molecules, convert_to_concs):
    '''Find the ``[min, max]`` level of each tag over all agents and times.

    When ``convert_to_concs`` is true, each level is divided by the
    agent's ``boundary.volume`` (or replaced by 0 if that volume is
    missing or zero). Returns a dict mapping tag path to ``[min, max]``,
    which is empty when there are no agents at any timepoint.
    '''
    tag_ranges = {}
    for time_data in agents.values():
        for agent_data in time_data.values():
            volume = agent_data.get('boundary', {}).get('volume', 0)
            for tag_id in tagged_molecules:
                level = get_value_from_path(agent_data, tag_id)
                if convert_to_concs:
                    level = level / volume if volume else 0
                if tag_id in tag_ranges:
                    low, high = tag_ranges[tag_id]
                    tag_ranges[tag_id] = [min(low, level), max(high, level)]
                else:
                    # add new tag
                    tag_ranges[tag_id] = [level, level]
    return tag_ranges


def _plot_tags_colormap(name, min_color, max_color):
    '''Build a linear colormap running from ``min_color`` to ``max_color``.'''
    min_rgb = matplotlib.colors.to_rgb(min_color)
    max_rgb = matplotlib.colors.to_rgb(max_color)
    colors_dict = {
        channel: [
            [0, min_rgb[i], min_rgb[i]],
            [1, max_rgb[i], max_rgb[i]],
        ]
        for i, channel in enumerate(('red', 'green', 'blue'))
    }
    return matplotlib.colors.LinearSegmentedColormap(
        name, segmentdata=colors_dict, N=512)


def _plot_tags_time_axis(fig, grid, n_rows, n_cols, n_snapshots,
                         snapshot_times):
    '''Draw a shared "Time (hr)" axis spanning the bottom row of snapshots.'''
    super_spec = matplotlib.gridspec.SubplotSpec(
        grid,
        (n_rows - 1) * n_cols,
        (n_rows - 1) * n_cols + n_snapshots - 1,
    )
    grid_params = grid.get_subplot_params()
    # Times are divided by 3600 to label the axis in hours.
    snapshot_times_hrs = tuple(time / 60 / 60 for time in snapshot_times)
    if n_snapshots > 1:
        time_per_snapshot = (
            snapshot_times_hrs[-1] - snapshot_times_hrs[0]) / (
            (n_snapshots - 1) * (grid_params.wspace + 1))
    else:
        time_per_snapshot = 1  # Arbitrary
    # Half a snapshot's worth of time is added on each side of the limits.
    super_ax = fig.add_subplot(  # type: ignore
        super_spec,
        xticks=snapshot_times_hrs,
        xlim=(
            snapshot_times_hrs[0] - time_per_snapshot / 2,
            snapshot_times_hrs[-1] + time_per_snapshot / 2,
        ),
        yticks=[],
    )
    super_ax.set_xlabel(  # type: ignore
        'Time (hr)', labelpad=50)
    super_ax.xaxis.set_tick_params(width=2, length=8)
    for spine_name in ('top', 'right', 'left'):
        super_ax.spines[spine_name].set_visible(False)
    super_ax.spines['bottom'].set_linewidth(2)


def plot_tags(data, plot_config):
    '''Plot snapshots of the simulation over time

    The snapshots depict the agents and the levels of tagged molecules
    in each agent by agent color intensity. The figure is saved to
    ``out_dir/filename`` and then closed; nothing is returned.

    Arguments:
        data (dict): A dictionary with the following keys:

            * **agents** (:py:class:`dict`): A mapping from times to
              dictionaries of agent data at that timepoint. Agent data
              dictionaries should have the same form as the hierarchy
              tree rooted at ``agents``.
            * **config** (:py:class:`dict`): The environmental
              configuration dictionary with the following keys:

                * **bounds** (:py:class:`tuple`): The dimensions of the
                  environment. Required.

        plot_config (dict): Accepts the following configuration options.
            Any option with a default is optional.

            * **n_snapshots** (:py:class:`int`): Number of snapshots to
              show per row (i.e. for each molecule). Defaults to 6.
            * **out_dir** (:py:class:`str`): Output directory, which is
              ``out`` by default.
            * **filename** (:py:class:`str`): Base name of output file.
              ``tags`` by default.
            * **agent_shape** (:py:class:`str`): Shape passed to
              ``plot_agents``. ``segment`` by default.
            * **tagged_molecules** (:py:class:`typing.Iterable`): The
              tagged molecules whose concentrations will be indicated by
              agent color. Each molecule should be specified as a
              :py:class:`tuple` of the path in the agent compartment
              to where the molecule's count can be found, with the last
              value being the molecule's count variable. Required and
              must be non-empty.
            * **tag_path_name_map** (:py:class:`dict`): Map from tag path
              to the display name used in row labels. Tags not in the
              map are labelled by their path.
            * **convert_to_concs** (:py:class:`bool`): if True (the
              default), convert counts to concentrations.
            * **background_color** (:py:class:`str`): use matplotlib colors,
              ``black`` by default
            * **tag_label_size** (:py:class:`float`): The font size for
              the tag name label. Defaults to 20.
            * **default_font_size** (:py:class:`float`): Font size for
              titles and axis labels. Defaults to 36.
            * **tag_colors** (:py:class:`dict`): Map from tag ID in
              tagged_molecules to a tuple (min_color, max_color). Every
              plotted tag must have an entry.
            * **scale_bar_length** (:py:class:`float`): Length of scale
              bar.  Defaults to 1 (in units of micrometers). If 0, no
              bar plotted.
            * **scale_bar_color** (:py:class:`str`): Color of scale bar.
              ``white`` by default.
            * **xlim** (:py:class:`tuple` of :py:class:`float`): Tuple
              of lower and upper x-axis limits.
            * **ylim** (:py:class:`tuple` of :py:class:`float`): Tuple
              of lower and upper y-axis limits.

    Raises:
        ValueError: If no molecules are tagged.
    '''
    check_plt_backend()

    n_snapshots = plot_config.get('n_snapshots', 6)
    out_dir = plot_config.get('out_dir', 'out')
    filename = plot_config.get('filename', 'tags')
    agent_shape = plot_config.get('agent_shape', 'segment')
    background_color = plot_config.get('background_color', 'black')
    tagged_molecules = plot_config['tagged_molecules']
    tag_path_name_map = plot_config.get('tag_path_name_map', {})
    tag_label_size = plot_config.get('tag_label_size', 20)
    default_font_size = plot_config.get('default_font_size', 36)
    convert_to_concs = plot_config.get('convert_to_concs', True)
    tag_colors = plot_config.get('tag_colors', {})
    scale_bar_length = plot_config.get('scale_bar_length', 1)
    scale_bar_color = plot_config.get('scale_bar_color', 'white')
    xlim = plot_config.get('xlim')
    ylim = plot_config.get('ylim')

    # ``not`` also catches an empty tuple, which would otherwise crash
    # later with a ZeroDivisionError when sizing the figure.
    if not tagged_molecules:
        raise ValueError('At least one molecule must be tagged.')

    # get data
    agents = data['agents']
    config = data.get('config', {})
    bounds = config['bounds']
    edge_length_x, edge_length_y = bounds

    # time steps that will be used
    time_vec = list(agents.keys())
    time_indices = np.round(
        np.linspace(0, len(time_vec) - 1, n_snapshots)
    ).astype(int)
    snapshot_times = [time_vec[i] for i in time_indices]

    # get tag ids and range
    tag_ranges = _plot_tags_value_ranges(
        agents, tagged_molecules, convert_to_concs)

    # make the figure
    n_rows = len(tagged_molecules)
    n_cols = n_snapshots + 1  # one column for the colorbar
    figsize = (12 * n_cols, 12 * n_rows)
    max_dpi = min([2**16 // dim for dim in figsize]) - 1
    fig = plt.figure(figsize=figsize, dpi=min(max_dpi, 100))
    original_fontsize = plt.rcParams['font.size']
    # Ensure the global font size is restored and the figure is closed
    # even if plotting or saving fails part-way through.
    try:
        grid = plt.GridSpec(n_rows, n_cols, wspace=0.2, hspace=0.2)
        plt.rcParams.update({'font.size': default_font_size})

        # Add time axis across subplots
        _plot_tags_time_axis(
            fig, grid, n_rows, n_cols, n_snapshots, snapshot_times)

        # plot tags
        for row_idx, tag_id in enumerate(tag_ranges):
            tag_name = tag_path_name_map.get(tag_id, tag_id)
            min_tag, max_tag = tag_ranges[tag_id]
            min_color, max_color = tag_colors[tag_id]
            cmap = _plot_tags_colormap(tag_id, min_color, max_color)
            norm = matplotlib.colors.Normalize(min_tag, max_tag)

            for col_idx, time in enumerate(snapshot_times):
                ax = init_axes(
                    fig, edge_length_x, edge_length_y, grid,
                    row_idx, col_idx, time, tag_name, tag_label_size,
                )
                ax.tick_params(
                    axis='both', which='both', bottom=False, top=False,
                    left=False, right=False,
                )
                ax.set_facecolor(background_color)

                agent_tag_colors = {}
                for agent_id, agent_data in agents[time].items():
                    # get current tag concentration, and determine color
                    level = get_value_from_path(agent_data, tag_id)
                    if convert_to_concs:
                        volume = get_in(
                            agent_data, ('boundary', 'volume'), 0)
                        level = level / volume if volume else 0
                    agent_rgb = cmap(norm(level))[:3]
                    agent_tag_colors[agent_id] = (
                        matplotlib.colors.rgb_to_hsv(agent_rgb))

                plot_agents(ax, agents[time], agent_tag_colors, agent_shape)

                if xlim:
                    ax.set_xlim(*xlim)
                if ylim:
                    ax.set_ylim(*ylim)

                # colorbar in new column after final snapshot. Use a
                # separate axes variable so the snapshot axes ``ax`` is
                # still the target of the scale bar below (this matters
                # when there is only one snapshot per row).
                if col_idx == n_snapshots - 1:
                    cbar_ax = fig.add_subplot(grid[row_idx, col_idx + 1])
                    if row_idx == 0 and convert_to_concs:
                        cbar_ax.set_title(
                            'Concentration (counts/fL)', y=1.08)
                    cbar_ax.axis('off')
                    # A degenerate range cannot be shown as a colorbar.
                    if min_tag != max_tag:
                        divider = make_axes_locatable(cbar_ax)
                        cax = divider.append_axes(
                            "left", size="5%", pad=0.0)
                        mappable = matplotlib.cm.ScalarMappable(norm, cmap)
                        fig.colorbar(mappable, cax=cax, format='%.0f')

                # Scale bar in first snapshot of each row
                if col_idx == 0 and scale_bar_length:
                    scale_bar = anchored_artists.AnchoredSizeBar(
                        ax.transData, scale_bar_length,
                        f'${scale_bar_length} \\mu m$', 'lower left',
                        color=scale_bar_color,
                        frameon=False,
                        sep=scale_bar_length,
                        size_vertical=scale_bar_length / 20,
                    )
                    ax.add_artist(scale_bar)

        fig_path = os.path.join(out_dir, filename)
        fig.subplots_adjust(wspace=0.7, hspace=0.1)
        fig.savefig(fig_path, bbox_inches='tight')
    finally:
        plt.close(fig)
        plt.rcParams.update({'font.size': original_fontsize})