Ejemplo n.º 1
0
def split_raw_data_by_survival(raw_data: RawData) -> Tuple[RawData, RawData]:
    '''Partition raw data into surviving agents and agents that die.

    An agent counts as dying if it is marked dead at any timepoint in
    ``raw_data``; every other agent counts as surviving.

    Args:
        raw_data: Raw simulation data
    Returns:
        Tuple of 2 raw data dictionaries with the same timepoints as
        ``raw_data``. The first contains all agents that are never
        marked dead (survive until division). The second contains all
        agents that are marked dead at some timepoint. Agent data is
        shared with ``raw_data``, not copied.
    '''
    # Collect every agent that is flagged dead at least once.
    dying_agents = {
        agent_id
        for timepoint_data in raw_data.values()
        for agent_id, agent_state in get_in(
            timepoint_data, PATH_TO_AGENTS).items()
        if get_in(agent_state, PATH_TO_DEAD, False)
    }

    # Give both outputs an (initially empty) agents dict at every
    # timepoint, so timepoints without any agent of a group still appear.
    survivors = RawData(dict())
    for timepoint in raw_data:
        assoc_path(survivors, (timepoint, ) + PATH_TO_AGENTS, dict())
    casualties = copy.deepcopy(survivors)

    # Route each agent's per-timepoint data into the matching output.
    for timepoint, timepoint_data in raw_data.items():
        timepoint_agents = get_in(timepoint_data, PATH_TO_AGENTS)
        for agent_id, agent_state in timepoint_agents.items():
            target = casualties if agent_id in dying_agents else survivors
            assoc_path(
                target,
                (timepoint, ) + PATH_TO_AGENTS + (agent_id, ),
                agent_state,
            )

    return survivors, casualties
Ejemplo n.º 2
0
def _calc_live_and_dead_finals(
    data: RawData,
    path_to_variable: Path,
    time_range: Tuple[float, float],
    agents: Iterable[str] = tuple(),
) -> Tuple[Dict[str, float], Dict[str, float]]:
    '''Get each agent's last in-range value of a variable, split by fate.

    Args:
        data: Raw simulation data, keyed by time.
        path_to_variable: Path within each agent's data to the variable
            to read.
        time_range: Window of times to consider. Each endpoint is a
            fraction of the final timepoint; endpoints are inclusive.
        agents: If non-empty, only these agents are considered.
            Otherwise, all agents are considered. Any iterable
            (including a generator) is accepted.

    Returns:
        Tuple of two dictionaries that map agent ID to the variable's
        value at the latest in-range timepoint where that agent appears.
        The first holds agents never marked dead within the window; the
        second holds agents marked dead at some in-range timepoint.
    '''
    # Materialize once: a generator is always truthy, and each ``in``
    # test against it would consume items, silently dropping agents.
    allowed_agents = set(agents)
    values: Dict[str, Dict[float, float]] = {}
    die = set()

    end_time = max(data.keys())
    start_bound = time_range[0] * end_time
    end_bound = time_range[1] * end_time
    for time, time_data in data.items():
        if time < start_bound or time > end_bound:
            continue
        agents_data = get_in(time_data, PATH_TO_AGENTS)
        for agent, agent_data in agents_data.items():
            if allowed_agents and agent not in allowed_agents:
                continue
            if get_in(agent_data, PATH_TO_DEAD, False):
                die.add(agent)
            values.setdefault(agent, {})[time] = get_in(
                agent_data, path_to_variable)

    live_finals = {}
    dead_finals = {}
    for agent, agent_values in values.items():
        # Every entry has at least one timepoint, since entries are only
        # created when a value is recorded.
        final_value = agent_values[max(agent_values)]
        if agent in die:
            dead_finals[agent] = final_value
        else:
            live_finals[agent] = final_value
    return live_finals, dead_finals
def make_snapshots_figure(
        data: RawData,
        environment_config: EnvironmentConfig,
        name: str,
        fields: Sequence[str],
        agent_fill_color: Optional[str] = None,
        agent_alpha: float = 1,
        num_snapshots: int = NUM_SNAPSHOTS,
        snapshot_times: Optional[Tuple[float, ...]] = None,
        xlim: Tuple[float, float] = (10, 40),
        ylim: Tuple[float, float] = (10, 40)
        ) -> dict:
    '''Make a figure of snapshots.

    Args:
        data: The experiment data.
        environment_config: Environment parameters.
        name: Name of the output file (excluding file extension).
        fields: List of the names of fields to include. If empty, field
            data is dropped from the plotted data and the scale bar and
            grid are styled for a field-free figure.
        agent_fill_color: Fill color for agents.
        agent_alpha: Transparency for agents.
        num_snapshots: Number of snapshots.
        snapshot_times: Times to take snapshots at. If
            None, they are evenly spaced.
        xlim: Limits of x-axis.
        ylim: Limits of y-axis.

    Returns:
        The statistics returned by ``plot_snapshots``.
    '''
    snapshots_data = Analyzer.format_data_for_snapshots(
        data, environment_config)
    if not fields:
        # Drop field data from the formatted data that is actually passed
        # to ``plot_snapshots``. Filtering the time-keyed raw ``data``
        # after formatting would have no effect.
        # NOTE(review): assumes the formatted data is a dict with a
        # top-level 'fields' key — confirm against Analyzer.
        snapshots_data = {
            key: val
            for key, val in snapshots_data.items() if key != 'fields'
        }
    plot_config = {
        'out_dir': FIG_OUT_DIR,
        'filename': '{}.{}'.format(name, FILE_EXTENSION),
        'include_fields': fields,
        'field_label_size': 54,
        'default_font_size': 54,
        'agent_fill_color': agent_fill_color,
        'dead_color': (0, 0, 0.79),  # gray in HSV
        'agent_alpha': agent_alpha,
        'n_snapshots': num_snapshots,
        'snapshot_times': snapshot_times,
        'scale_bar_length': 10,
        # Fields are drawn on a dark background, so use light overlays.
        'scale_bar_color': 'white' if fields else 'black',
        'xlim': xlim,
        'ylim': ylim,
        'min_color': '#FFFFFF',
        'max_color': '#000000',
        'grid_color': 'white' if fields else '',
    }
    stats = plot_snapshots(snapshots_data, plot_config)
    return stats
Ejemplo n.º 4
0
 def test_multiple_proteins(self) -> None:
     '''Each mapped protein gets its own concentration column.

     Unmapped proteins (``protein3``) must not appear in the table.
     '''
     first_agent = self._make_agent_data(2, {
         'protein1': 1,
         'protein2': 2,
         'protein3': 0,
     })
     second_agent = self._make_agent_data(4, {
         'protein2': 3,
         'protein1': 8,
         'protein3': 0,
     })
     data = RawData({
         1: {
             'agents': {
                 'agent1': first_agent,
                 'agent2': second_agent,
             },
         },
     })
     protein_paths: Dict[str, Path] = {
         'protein1': ('counts', 'protein1'),
         'protein2': ('counts', 'protein2'),
     }
     table = raw_data_to_end_expression_table(data, protein_paths)
     # Row order is not part of the contract, so compare rows sorted by
     # volume; each row is (volume, protein1, protein2).
     observed_rows = sorted(zip(
         table['volume'].tolist(),
         table['protein1'].tolist(),
         table['protein2'].tolist(),
     ))
     assert observed_rows == [
         (2, 1 / 2, 2 / 2),
         (4, 8 / 4, 3 / 4),
     ]
     assert 'protein3' not in table.columns
Ejemplo n.º 5
0
 def test_get_end_time(self) -> None:
     '''Only the timepoint with the largest time is used.'''
     # Timepoints are inserted out of order on purpose: using the last
     # inserted key (1) instead of the largest (3) must fail this test.
     data = RawData({
         2: {
             'agents': {
                 'agent1': self._make_agent_data(2, {'protein': 0}),
                 'agent2': self._make_agent_data(4, {'protein': 0}),
             },
         },
         3: {
             'agents': {
                 'agent1': self._make_agent_data(2, {'protein': 1}),
                 'agent2': self._make_agent_data(4, {'protein': 4}),
             },
         },
         1: {
             'agents': {
                 'agent1': self._make_agent_data(2, {'protein': 0}),
                 'agent2': self._make_agent_data(4, {'protein': 0}),
             },
         },
     })
     name_to_path_map: Dict[str, Path] = {
         'protein': ('counts', 'protein'),
     }
     table = raw_data_to_end_expression_table(data, name_to_path_map)
     # Compare whole (volume, concentration) rows, sorted, so the row
     # count and each agent's pairing are checked, not just the set of
     # distinct values.
     rows = sorted(zip(
         table['volume'].tolist(),
         table['protein'].tolist(),
     ))
     assert rows == [(2, 1 / 2), (4, 4 / 4)]
Ejemplo n.º 6
0
def raw_data_to_end_expression_table(
        raw_data: RawData, paths_dict: Dict[str, Path]) -> pd.DataFrame:
    '''Create a table of end expression levels from raw simulation data.

    Args:
        raw_data: Raw simulation data, keyed by time. Only the timepoint
            with the largest key is used.
        paths_dict: Map from names to paths to protein counts within each
            agent's data. The names will be used as column headers in the
            returned table.

    Returns:
        Table with one column for each protein plus a ``VOLUME_KEY``
        column, and one row for each agent present at the final
        simulation timepoint. Each protein cell contains that agent's
        count divided by its volume. Missing counts are treated as 0,
        and agents with a missing or zero volume get a concentration of
        0. If the final timepoint has no agents, the table is empty but
        still has all the columns.
    '''
    end_data = raw_data[max(raw_data.keys())]
    expression_data: Dict = {name: [] for name in paths_dict}
    expression_data[VOLUME_KEY] = []
    # Fall back to no agents instead of crashing on ``None.values()``.
    # NOTE(review): assumes get_in returns None when the path is absent.
    agents_data = get_in(end_data, AGENTS_PATH) or {}
    for agent_data in agents_data.values():
        volume = get_in(agent_data, VOLUME_PATH, 0)
        expression_data[VOLUME_KEY].append(volume)
        for name, path in paths_dict.items():
            count = get_in(agent_data, path, 0)
            # Avoid dividing by zero for agents without a volume.
            concentration = count / volume if volume else 0
            expression_data[name].append(concentration)
    return pd.DataFrame(expression_data)
def get_experiment_data(
        args: argparse.Namespace,
        experiment_id: str,
        ) -> DataTuple:
    '''Get simulation data for an experiment.

    If ``args.data_path`` is set, retrieve the experiment data from a
    JSON file named ``<experiment_id>.json`` under ``args.data_path``.
    The file must have a ``data`` object (keyed by time) and an
    ``environment_config`` object. Otherwise, retrieve the data from
    MongoDB.

    Args:
        args: Parsed CLI args.
        experiment_id: ID of experiment.

    Returns: Tuple of simulation data and environment config.
    '''
    if not args.data_path:
        return Analyzer.get_data(args, experiment_id)

    path = os.path.join(
        args.data_path, '{}.json'.format(experiment_id))
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's
    # locale-dependent default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        loaded_file = json.load(f)
    # JSON object keys are always strings, so restore numeric times.
    data = RawData({
        float(time): value
        for time, value in loaded_file['data'].items()
    })
    config = EnvironmentConfig(loaded_file['environment_config'])
    return data, config
def get_total_mass_plot(
    datasets: Dict[str, List[RawData]],
    colors: List[str],
    fontsize: float = 36,
    vlines: Iterable[Tuple[float, float, str, str]] = tuple(),
) -> Tuple[plt.Figure, dict]:
    '''Plot the total masses of colonies from groups of simulations.

    Each group's total mass over time is plotted as a curve on the
    resulting figure. The first timepoint of every replicate is
    excluded.

    Args:
        datasets: Map from the label to associate with a group of
            simulations to a list of the datasets in that group.
        colors: Colors to plot the groups in, matched to the groups of
            ``datasets`` by position (the i-th group uses
            ``colors[i]``).
        fontsize: Size of all text on figure.
        vlines: Tuple of vertical line specifiers. Each specifier is a
            tuple of the line position (in seconds; plotted in hours),
            label position as fraction of x range, color, and label.

    Returns:
        A tuple of the figure and a dictionary that maps from group
        label to tuples of the first, second, and third quartiles of
        that group's data.
    '''
    fig, ax = plt.subplots()
    quartiles = {}
    for i, (label, replicates) in enumerate(datasets.items()):
        filtered_replicates = []
        for replicate in replicates:
            # Exclude first timepoint, which is often wrong. Find it once
            # per replicate instead of once per key; ``default`` keeps an
            # empty replicate from raising.
            first_time = min(replicate.keys(), default=None)
            filtered = RawData({
                key: val
                for key, val in replicate.items()
                if key != first_time
            })
            filtered_replicates.append(filtered)
        label_quartiles = plot_total_mass(filtered_replicates, ax, label,
                                          colors[i], fontsize)
        quartiles[label] = label_quartiles
    for x, label_x, vline_color, vline_label in vlines:
        # Convert the line position from seconds to hours.
        ax.axvline(  # type: ignore
            x / 60 / 60, color=vline_color, linestyle='--')
        ax.text(  # type: ignore
            label_x,
            0.95,
            vline_label,
            fontsize=fontsize,
            transform=ax.transAxes)  # type: ignore
    ax.set_ylabel(  # type: ignore
        'Total Cell Mass (fg)', fontsize=fontsize)
    ax.set_xlabel('Time (hr)', fontsize=fontsize)  # type: ignore
    for spine_name in ('top', 'right'):
        ax.spines[spine_name].set_visible(False)  # type: ignore
    fig.tight_layout()
    return fig, quartiles
Ejemplo n.º 9
0
def filter_raw_data_by_time(raw_data: RawData,
                            time_range: Tuple[float, float]) -> RawData:
    '''Filter raw simulation data to the timepoints within a range

    Args:
        raw_data: Raw simulation data.
        time_range: Tuple of range endpoints. Each endpoint is a float
            between 0 and 1 (inclusive) that denotes a fraction of the
            total simulation time.
    Returns:
        A subset of the key-value pairs in ``raw_data``. Includes only
        those timepoints between the ``time_range`` endpoints
        (inclusive). Empty input yields an empty result.
    '''
    if not raw_data:
        # Nothing to filter; ``max()`` would raise on an empty sequence.
        return RawData({})
    floor, ceil = time_range
    end = max(raw_data.keys())
    lower_bound = floor * end
    upper_bound = ceil * end
    return RawData({
        time: time_data
        for time, time_data in raw_data.items()
        if lower_bound <= time <= upper_bound
    })
Ejemplo n.º 10
0
def get_total_mass_timeseries(data: RawData) -> List[float]:
    '''Build the total-mass timeseries of a simulation.

    Args:
        data: Data from the simulation, keyed by time.

    Returns:
        The value of ``get_total_mass`` for the agents at each timepoint,
        ordered by increasing time.
    '''
    return [
        get_total_mass(get_in(data[timepoint], AGENTS_PATH))
        for timepoint in sorted(data)
    ]
Ejemplo n.º 11
0
 def test_zeros(self) -> None:
     '''Zero volumes and zero counts both yield a concentration of 0.'''
     data = RawData({
         1: {
             'agents': {
                 # Nonzero count but zero volume: must not divide by 0.
                 'agent1': self._make_agent_data(0, {'protein': 1}),
                 'agent2': self._make_agent_data(4, {'protein': 0}),
                 'agent3': self._make_agent_data(0, {'protein': 0}),
             },
         },
     })
     name_to_path_map: Dict[str, Path] = {
         'protein': ('counts', 'protein'),
     }
     table = raw_data_to_end_expression_table(data, name_to_path_map)
     # Compare sorted (volume, concentration) rows. A set comparison
     # collapses duplicates, so it would not check that all three
     # agents are present.
     rows = sorted(zip(
         table['volume'].tolist(),
         table['protein'].tolist(),
     ))
     assert rows == [(0, 0), (0, 0), (4, 0)]
Ejemplo n.º 12
0
def _get_final_live_agents(
        data: RawData,
        time_range: Tuple[float, float] = (0, 1),
) -> List[str]:
    '''List the agents not marked dead at the last in-range timepoint.

    Args:
        data: Raw simulation data, keyed by time.
        time_range: Window of times to consider, as fractions of the
            final timepoint (passed to ``filter_raw_data_by_time``).

    Returns:
        IDs of the agents at the latest timepoint within ``time_range``
        that are not marked dead, in the order they appear in that
        timepoint's data.

    Raises:
        ValueError: If no timepoints fall within ``time_range``.
    '''
    data = filter_raw_data_by_time(data, time_range)
    max_time = max(data.keys())
    agents_data = get_in(
        # Pylint doesn't recognize that the RawData NewType is a dict
        data[max_time],  # pylint: disable=unsubscriptable-object
        PATH_TO_AGENTS,
    )
    # Treat a missing death flag as "alive", with the same explicit False
    # default used wherever else the flag is read.
    return [
        agent
        for agent, agent_data in agents_data.items()
        if not get_in(agent_data, PATH_TO_DEAD, False)
    ]
Ejemplo n.º 13
0
def plot_phylogeny(
    data: RawData,
    out: str = 'phylogeny.pdf',
    live_color: str = 'green',
    dead_color: str = 'black',
    ignore_color: str = 'lightgray',
    time_range: Tuple[float, float] = (0, 1)
) -> Tuple[TreeNode, pd.DataFrame]:
    '''Plot phylogenetic tree from an experiment.

    Args:
        data: The simulation data.
        out: Path to the output file. File type will be inferred from
            the file name.
        live_color: Color for nodes representing cells that survive
            until division.
        dead_color: Color for nodes representing cells that die.
        ignore_color: Color for nodes outside the time range considered.
        time_range: Tuple specifying the range of times to consider.
            Range values specified as fractions of the final
            timepoint (inclusive).

    Returns:
        A tuple of the tree and a table with one row per agent seen
        within ``time_range``, sorted by agent ID. The ``agents`` column
        holds the agent IDs. The ``survival`` column is 0 for agents
        marked dead within the range and 1 otherwise.
    '''
    agent_ids: Set[str] = set()
    dead_ids: Set[str] = set()
    in_time_range_ids: Set[str] = set()
    end_time = max(data.keys())
    for time, time_data in data.items():
        agents_data = get_in(time_data, AGENTS_PATH)
        assert agents_data is not None
        # The tree includes every agent ever seen, in range or not.
        agent_ids |= set(agents_data.keys())

        if time_range[0] * end_time <= time <= time_range[1] * end_time:
            in_time_range_ids |= set(agents_data.keys())
            for agent_id, agent_data in agents_data.items():
                if get_in(agent_data, PATH_TO_DEAD, False):
                    dead_ids.add(agent_id)

    trees = make_ete_trees(agent_ids)
    assert len(trees) == 1
    tree = trees[0]

    # Set style for overall figure
    tstyle = TreeStyle()
    tstyle.show_scale = False
    tstyle.show_leaf_name = False
    tstyle.scale = None
    tstyle.optimal_scale_level = 'full'  # Avoid artificial branches
    tstyle.mode = 'c'
    legend = {
        'Die': dead_color,
        'Survive': live_color,
        'Divided Before Antibiotics Appeared': ignore_color,
    }
    for label, color in legend.items():
        tstyle.legend.add_face(CircleFace(5, color), column=0)
        tstyle.legend.add_face(TextFace(' ' + label, ftype=FONT), column=1)

    # Set styles for each node: nodes outside the time range are greyed
    # out; in-range nodes are colored by whether the agent died.
    for node in tree.traverse():
        nstyle = NodeStyle()
        nstyle['size'] = 5
        nstyle['vt_line_width'] = 1
        nstyle['hz_line_width'] = 1
        if node.name in in_time_range_ids:
            if node.name in dead_ids:
                nstyle['fgcolor'] = dead_color
            else:
                nstyle['fgcolor'] = live_color
        else:
            nstyle['fgcolor'] = ignore_color
        node.set_style(nstyle)
    tree.render(out, tree_style=tstyle, w=400)
    survive_col = []
    agents_col = []
    # Sort so the table's row order doesn't depend on set iteration
    # order, which varies between runs under string hash randomization.
    for agent in sorted(in_time_range_ids):
        agents_col.append(agent)
        survive_col.append(0 if agent in dead_ids else 1)
    df = pd.DataFrame({'agents': agents_col, 'survival': survive_col})
    return tree, df