Пример #1
0
def plot_n_ovlps_per_read(reads: List[TRRead],
                          overlaps: List[Overlap]):
    """Show a histogram of the number of overlaps each read takes part in.

    Every overlap is credited to both of its reads (`a_read_id` and
    `b_read_id`). Reads in `reads` that appear in no overlap count as 0.
    """
    counts = Counter()
    for ovlp in overlaps:
        for rid in (ovlp.a_read_id, ovlp.b_read_id):
            counts[rid] += 1
    per_read = [counts[r.id] for r in reads]
    hist = make_hist(per_read, bin_size=1)
    layout = make_layout(x_title="# of overlaps per read",
                         y_title="Frequency",
                         x_range=(0, None))
    show_plot(hist, layout)
Пример #2
0
def plot_ulen_transition(read: TRRead):
    """Show positional distribution of unit lengths in `read`.

    Each unit is drawn as a horizontal segment from its start to its end
    position at a height equal to its length. A `None` point after every
    unit breaks the line so consecutive units are not connected.
    """
    assert len(read.units) > 0, "No units to show"
    xs, ys = [], []
    for unit in read.units:
        xs.extend((unit.start, unit.end, None))
        ys.extend((unit.length, unit.length, None))
    trace = make_scatter(x=tuple(xs), y=tuple(ys), mode="lines+markers")
    layout = make_layout(title=f"Read {read.id} (strand={read.strand})",
                         x_title="Start position",
                         y_title="Unit length [bp]",
                         x_range=(0, read.length))
    show_plot(trace, layout)
Пример #3
0
def plot_ulen_dist(reads: List[TRRead],
                   by: str = "total",
                   min_ulen: int = 1,
                   max_ulen: Optional[int] = None,
                   log_scale: bool = False):
    """Show a distribution of unit lengths in `reads`.

    positional arguments:
      @ reads : A list of TRRead objects.

    optional arguments:
      @ by             : Must be one of {"total", "count"}.
                         "total" weights each unit length by the number of
                         bases it accounts for (length * count); "count"
                         shows the number of units.
      @ [min|max]_ulen : Range of unit length to count (both inclusive).
                         `max_ulen=None` means the longest unit in `reads`.
      @ log_scale      : Use a logarithmic y-axis.
    """
    assert by in ("total", "count"), "`by` must be 'total' or 'count'"
    ulens = [unit.length for read in reads for unit in read.units]
    if max_ulen is None:
        # `max()` of an empty list raises an opaque ValueError; report the
        # problem the same way the other plotting functions do.
        assert len(ulens) > 0, "No units to show"
        max_ulen = max(ulens)
    ulen_counts = Counter(x for x in ulens if min_ulen <= x <= max_ulen)
    if by == "total":
        # Weight each unit length by the number of bases it accounts for
        ulen_counts = {
            ulen: ulen * count
            for ulen, count in ulen_counts.items()
        }
    layout = make_layout(
        width=1000,
        height=500,
        title=("Total bases for each unit length"
               if by == "total" else "Number of units for each unit length"),
        x_title="Unit length [bp]",
        y_title=("Total bases [bp]" if by == "total" else "Frequency"),
        x_range=(min_ulen, max_ulen))
    if log_scale:
        layout["yaxis_type"] = "log"
    # Unpack the (ulen, value) pairs explicitly so that an empty range gives
    # empty axes instead of calling make_scatter without any x/y.
    points = sorted(ulen_counts.items())
    xs, ys = zip(*points) if points else ((), ())
    show_plot([
        make_scatter(x=xs,
                     y=ys,
                     mode="lines",
                     name="Line graph<br>(for zooming out)",
                     show_legend=True),
        make_hist(ulen_counts,
                  name="Bar graph<br>(for zooming in)",
                  show_legend=True)
    ],
              layout=layout)
Пример #4
0
def plot_overlaps_for_read(read_id: int,
                           overlaps: List[Overlap],
                           min_ovlp_len: int = 10000):
    """Scatter overlap length vs. sequence dissimilarity [%] for one read.

    Every overlap involving `read_id` is oriented (with `swap()` if needed)
    so that `read_id` is the A-read. Points are colored by the partner read.
    Guide shapes mark:
      - the `min_ovlp_len` threshold (red),
      - the read length (green),
      - the accepted region between the two (shaded),
      - the length of the overlap with the smallest diff,
      - the smallest diff itself (horizontal line).
    """
    oriented = []
    for o in overlaps:
        if o.a_read_id == read_id:
            cand = o
        elif o.b_read_id == read_id:
            cand = o.swap()
        else:
            continue
        if cand:
            oriented.append(cand)
    assert len(oriented) > 0, "No overlaps for the read"

    read_len = oriented[0].a_len
    lens = [o.length for o in oriented]
    diffs = [o.diff * 100 for o in oriented]
    lo, hi = min(diffs), max(diffs)
    best_len = lens[np.argmin(diffs)]

    trace = make_scatter(x=lens,
                         y=diffs,
                         col=[o.b_read_id for o in oriented],
                         marker_size=8)
    guides = [
        # min ovlp len threshold
        make_line(min_ovlp_len, lo, min_ovlp_len, hi, col="red"),
        # read length (contained)
        make_line(read_len, lo, read_len, hi, col="green"),
        # accepted ovlps
        make_rect(min_ovlp_len, lo, read_len, hi, opacity=0.1),
        # ovlp len of min diff
        make_line(best_len, lo, best_len, hi),
        # min ovlp diff
        make_line(min(lens), lo, max(lens), lo),
    ]
    show_plot(trace, make_layout(shapes=guides))
Пример #5
0
def plot_start_end(read_id: int,
                   overlaps: List[Overlap]):
    """Scatter where the overlaps start and end on read `read_id`.

    Each overlap involving `read_id` is oriented (with `swap()` if needed)
    so that `read_id` is the A-read. It is then drawn at (a_start, a_end),
    colored by the partner read, and labeled with the partner's interval and
    the overlap type.
    """
    # Convert overlaps so that a_read_id == read_id
    oriented = []
    for o in overlaps:
        if o.a_read_id == read_id:
            cand = o
        elif o.b_read_id == read_id:
            cand = o.swap()
        else:
            continue
        if cand:
            oriented.append(cand)
    assert len(oriented) > 0, "No overlaps for the read"

    read_len = oriented[0].a_len
    starts = tuple(o.a_start for o in oriented)
    ends = tuple(o.a_end for o in oriented)
    labels = tuple(
        f"read {o.b_read_id}<br>({o.b_start}, {o.b_end})<br>{o.type}"
        for o in oriented)
    partners = tuple(o.b_read_id for o in oriented)

    trace = make_scatter(starts, ends, text=labels, col=partners)
    # Small margins so points at start == 0 / end == read_len stay visible
    layout = make_layout(width=700,
                         height=700,
                         x_title="Start",
                         y_title="End",
                         x_range=(-read_len * 0.05, read_len),
                         y_range=(0, read_len * 1.05))
    show_plot(trace, layout)
Пример #6
0
def plot_self(read: TRRead, unit_dist_by: str, max_dist: Optional[float],
              max_slope_dev: float, plot_size: int):
    """Draw a self dot plot of `read` with its tandem repeats and units.

    The diagonal of the read against itself is drawn in grey. TR objects are
    built by `read_to_tr_obj` and unit objects by `read_to_unit_obj`. The
    y-axis is reversed, and its scale is anchored to the x-axis.
    `unit_dist_by == "repr"` requires `read.repr_units` to be set.
    """
    if unit_dist_by == "repr":
        assert read.repr_units is not None, "No representative units"
    diagonal = make_line(0, 0, read.length, read.length, width=2, col="grey")
    tr_traces, tr_shapes = read_to_tr_obj(read, max_slope_dev)
    unit_traces, unit_shapes = read_to_unit_obj(read, unit_dist_by, max_dist)

    all_traces = tr_traces + unit_traces
    all_shapes = [diagonal] + tr_shapes + unit_shapes
    title = f"Read {read.id} (strand={read.strand})"
    layout = make_layout(plot_size,
                         plot_size,
                         title=title,
                         x_range=(0, read.length),
                         y_range=(0, read.length),
                         x_grid=False,
                         y_grid=False,
                         y_reversed=True,
                         margin=dict(l=10, r=10, t=50, b=10),
                         shapes=all_shapes)
    layout["yaxis"]["scaleanchor"] = "x"
    show_plot(all_traces, layout)
Пример #7
0
def plot_ulen_composition(reads: Union[TRRead, List[TRRead]],
                          by: str = "total"):
    """Show composition of unit lengths in each of `reads`.

    positional arguments:
      @ reads : A TRRead object or a list of TRRead objects.

    optional arguments:
      @ by : Must be one of {"total", "count"}. "total" plots unit length *
             count (bases) for each unit length; "count" plots the number of
             units.
    """
    assert by in ("total", "count"), "`by` must be 'total' or 'count'"
    read_list = [reads] if isinstance(reads, TRRead) else reads
    for read in read_list:
        assert len(read.units) > 0, f"Read {read.id}: no units to show"

    weight_by_length = by == "total"
    comps = {}
    for read in read_list:
        counts = sorted(Counter(u.length for u in read.units).items())
        comps[read.id] = [(ulen, ulen * n if weight_by_length else n)
                          for ulen, n in counts]

    traces = []
    for read in read_list:
        xs, ys = zip(*comps[read.id])
        traces.append(
            make_scatter(x=xs,
                         y=ys,
                         mode="lines+markers",
                         marker_size=6,
                         name=f"Read {read.id}",
                         show_legend=True))
    layout = make_layout(
        title="Composition of unit lengths",
        x_title="Unit length [bp]",
        y_title=("Total bases [bp]" if by == "total" else "Frequency"))
    show_plot(traces, layout)
Пример #8
0
def plot_vs(a_read: TRRead, b_read: TRRead, unit_dist_by: str,
            max_dist: Optional[float], plot_size: int):
    """Draw a dot plot of `a_read` (x-axis) against `b_read` (y-axis).

    Both reads must be synchronized. Axis objects are built by
    `reads_to_axis_obj` and matrix objects by `reads_to_matrix_obj`. The
    y-axis is reversed, and its scale is anchored to the x-axis.
    """
    assert a_read.synchronized and b_read.synchronized, \
        "Both reads must be synchronized"
    ax_traces, ax_shapes = reads_to_axis_obj(a_read, b_read)
    mat_traces, mat_shapes = reads_to_matrix_obj(a_read, b_read,
                                                 unit_dist_by, max_dist)
    x_label = f"Read {a_read.id} (strand={a_read.strand})"
    y_label = f"Read {b_read.id} (strand={b_read.strand})"
    # Extra room on the left of x = 0 for the axis objects
    x_lim = (-a_read.length * 0.05, a_read.length)
    y_lim = (0, b_read.length)
    all_traces = ax_traces + mat_traces
    all_shapes = ax_shapes + mat_shapes
    layout = make_layout(plot_size,
                         plot_size,
                         x_title=x_label,
                         y_title=y_label,
                         x_range=x_lim,
                         y_range=y_lim,
                         x_grid=False,
                         y_grid=False,
                         y_reversed=True,
                         x_zeroline=False,
                         y_zeroline=False,
                         margin=dict(l=10, r=10, t=50, b=10),
                         shapes=all_shapes)
    layout["yaxis"]["scaleanchor"] = "x"
    show_plot(all_traces, layout)
Пример #9
0
def draw_string_graph(sg: ig.Graph,
                      reads: Optional[Union[List[TRRead],
                                            Dict[int, TRRead]]] = None,
                      kk_maxiter: int = 100000,
                      node_size: int = 8,
                      edge_width_per_bp: int = 5000,
                      plot_size: int = 900):
    """Draw a string graph with a Kamada-Kawai layout.

    Every edge is drawn as a thin line plus a thicker segment over its last
    30% toward the target. The thicker segment's width is
    `length // edge_width_per_bp`, with a minimum of 3. Edges whose
    (source, target) pair occurs more than once are red; the others are black.
    Hovering an edge midpoint shows its length, plus its diff if the edge has
    a "diff" attribute. When `reads` is given, nodes are colored by the
    percentage of the read covered by its TR units; otherwise that value is 0.

    positional arguments:
      @ sg : String graph whose vertex names start with "<read_id>:".

    optional arguments:
      @ reads             : Reads (list or {read_id: read}) for node colors.
      @ kk_maxiter        : `maxiter` passed to `layout_kamada_kawai`.
      @ node_size         : Marker size of the nodes.
      @ edge_width_per_bp : Edge length [bp] per unit of head-segment width.
      @ plot_size         : Width and height of the figure.
    """
    def v_to_read_id(v: ig.Vertex) -> int:
        """Read ID encoded as the part of the vertex name before ':'."""
        return int(v["name"].split(':')[0])

    def cov_rate(read: TRRead) -> float:
        """Percentage of `read` covered by its TR units."""
        return sum([unit.length for unit in read.units]) / read.length * 100

    # [(x, y)] coordinates of the vertices, index == v.index
    coords = sg.layout_kamada_kawai(maxiter=kk_maxiter)

    traces = []
    # Edge traces (edges whose (source, target) pair occurs more than once
    # are red, the others black)
    n_edges = Counter([(e.source, e.target) for e in sg.es])
    e_to_headwidth_col = {
        e: (max(e["length"] // edge_width_per_bp, 3),
            "black" if n_edges[(e.source, e.target)] == 1 else "red")
        for e in sg.es
    }
    # Thin full-length line for every edge, one trace per color
    for col in ("black", "red"):
        traces.append(
            make_lines([(*coords[e.source], *coords[e.target])
                        for e in sg.es if e_to_headwidth_col[e][1] == col],
                       width=1,
                       col=col))
    # Thick segment over the last 30% of each edge toward its target
    # (indicates direction); one trace per unique (width, color) pair
    for width, col in set(e_to_headwidth_col.values()):
        traces.append(
            make_lines(
                [((0.3 * coords[e.source][0] + 0.7 * coords[e.target][0]),
                  (0.3 * coords[e.source][1] + 0.7 * coords[e.target][1]),
                  *coords[e.target])
                 for e in sg.es if e_to_headwidth_col[e] == (width, col)],
                width=width,
                col=col))
    # Hover text at edge midpoints; parallel edges share one text listing all
    edge_info = defaultdict(list)
    for e in sg.es:
        edge_info[(e.source, e.target)].append(
            f"{e['length']} bp, {e['diff'] * 100:.2f}% diff" if "diff" in
            e.attributes() else f"{e['length']} bp")
    # Skipped for an edgeless graph: unpacking `zip(*[])` into four names
    # would raise ValueError
    if sg.ecount() > 0:
        x, y, t, c = zip(*[((coords[e.source][0] + coords[e.target][0]) / 2,
                            (coords[e.source][1] + coords[e.target][1]) / 2,
                            '<br>'.join(edge_info[(e.source, e.target)]),
                            e_to_headwidth_col[e][1]) for e in sg.es])
        traces.append(make_scatter(x=x, y=y, text=t, col=c, marker_size=1))
    # Node trace with color by cover rate by TR units (0 without `reads`)
    if reads is None:
        cov_rates = [0.] * sg.vcount()
    else:
        reads_by_id = (reads if isinstance(reads, dict) else
                       {read.id: read
                        for read in reads})
        cov_rates = [cov_rate(reads_by_id[v_to_read_id(v)]) for v in sg.vs]
    # Same guard as for the edges: nothing to unpack without vertices
    if sg.vcount() > 0:
        x, y, t = zip(*[(*coords[v.index],
                         f"{v['name']}<br>{cov_rates[v.index]:.1f}% covered")
                        for v in sg.vs])
        traces.append(
            make_scatter(x=x,
                         y=y,
                         text=t,
                         col=cov_rates,
                         col_range=(50, 100),
                         col_scale='YlGnBu',
                         show_scale=False,
                         marker_size=node_size,
                         marker_width=node_size / 5))
    show_plot(
        traces,
        make_layout(width=plot_size,
                    height=plot_size,
                    x_grid=False,
                    y_grid=False,
                    x_zeroline=False,
                    y_zeroline=False,
                    x_show_tick_labels=False,
                    y_show_tick_labels=False,
                    margin=dict(l=0, r=0, b=0, t=0)))
def adaptive_filter_overlaps(overlaps: List[Overlap],
                             min_n_ovlp: int,
                             default_min_ovlp_len: int,
                             limit_min_ovlp_len: int,
                             filter_by_diff: bool = True,
                             plot: bool = False) -> List[Overlap]:
    """Filter overlaps by length and sequence dissimilarity by adaptively
    changing the thresholds for individual read, considering the number of
    overlaps at prefix and suffix of each read.

    For each read, the overlaps touching its prefix (a_start == 0) and those
    touching its suffix (a_end == a_len) are filtered separately, each with
    its own minimum length. That minimum is the length of the
    `min_n_ovlp`-th longest overlap, or of the shortest one if there are
    fewer. It is then capped at `default_min_ovlp_len` and floored at
    `limit_min_ovlp_len`.

    Overlaps of type "contains" are kept when they are at least
    `default_min_ovlp_len` long. If `filter_by_diff` is set and at least two
    overlaps remain, overlaps whose diff exceeds mean + stdev of those
    overlaps' diffs are dropped.

    positional arguments:
      @ overlaps             : Overlaps to filter.
      @ min_n_ovlp           : Number of longest overlaps per read boundary
                               that determines the adaptive min length.
      @ default_min_ovlp_len : Upper bound of the adaptive min length.
      @ limit_min_ovlp_len   : Lower bound of the adaptive min length.

    optional arguments:
      @ filter_by_diff : Also apply the adaptive diff threshold.
      @ plot           : Show histograms of the chosen thresholds.

    return:
      Sorted, deduplicated overlaps; each is passed through `swap()` unless
      its a_read_id < b_read_id.
    """

    # TODO: what is the difference with "best-N-overlaps" strategy?

    min_lens, max_diffs = [], []  # thresholds recorded for the plots

    def _filter_overlaps(_overlaps: List[Overlap]) -> List[Overlap]:
        """Filter overlaps with adaptive threshold of minimum overlap length
        ranging [`limit_min_ovlp_len`, `default_min_ovlp_len`]."""
        if not _overlaps:
            return _overlaps
        olens = [o.length for o in _overlaps]
        if len(_overlaps) < min_n_ovlp:
            nth_longest = min(olens)
        else:
            nth_longest = sorted(olens, reverse=True)[min_n_ovlp - 1]
        min_len = max(min(nth_longest, default_min_ovlp_len),
                      limit_min_ovlp_len)
        min_lens.append(min_len)
        return [o for o in _overlaps if o.length >= min_len]

    # Group overlaps per read, oriented so that the read is the A-read
    overlaps_per_read = defaultdict(list)
    for o in overlaps:
        overlaps_per_read[o.a_read_id].append(o)
        overlaps_per_read[o.b_read_id].append(o.swap())

    kept = []
    for read_overlaps in overlaps_per_read.values():
        # Filter overlaps with adaptive min overlap length threshold
        # NOTE: contained overlaps are counted twice in pre and suf
        pre = _filter_overlaps(
            [o for o in read_overlaps if o.a_start == 0])
        suf = _filter_overlaps(
            [o for o in read_overlaps if o.a_end == o.a_len])
        contains = [o for o in read_overlaps
                    if o.type == "contains"
                    and o.length >= default_min_ovlp_len]
        selected = sorted(set(pre + suf + contains))
        if filter_by_diff and len(selected) >= 2:
            # Filter overlaps with adaptive overlap seq diff threshold
            # NOTE: this is in order to exclude false contained overlaps
            diffs = [o.diff for o in selected]
            max_diff = mean(diffs) + stdev(diffs)  # TODO: rationale?
            max_diffs.append(max_diff * 100)
            selected = [o for o in selected if o.diff <= max_diff]
        kept += selected

    # Merge overlaps and remove duplicated overlaps
    filtered_overlaps = sorted(
        {o if o.a_read_id < o.b_read_id else o.swap() for o in kept})
    logger.info(f"{len(overlaps)} -> {len(filtered_overlaps)} overlaps")
    if plot:
        show_plot(
            make_hist(min_lens, bin_size=500),
            make_layout(x_title="Min overlap length at boundaries [bp]",
                        y_title="Frequency"))
        if filter_by_diff:
            show_plot(
                make_hist(max_diffs, bin_size=0.1),
                make_layout(x_title="Max sequence dissimilarity per read [%]",
                            y_title="Frequency"))
    return filtered_overlaps
Пример #11
0
def plot_cigar(mappings: List[ContigMapping],
               mutation_locations: List[int],
               read_length: int,
               true_fname: str,
               line_width: int = 1):
    """Show alignments between contigs and the true sequence and also mutations.

    Contigs are sorted by their start on the true sequence (x-axis) and laid
    out one after another along the y-axis. Lines are drawn as follows:
      - '=' runs use the default line color;
      - 'X', 'I' and any other CIGAR operations are red;
      - the jump from one contig's end to the next contig's start is yellow.
    Mutations are drawn as dots at y = 0. A dot is black when the position is
    covered and every CIGAR operation seen there is '='; otherwise it is red.

    positional arguments:
      @ mappings           : Cigar and true start/end position for each contig.
      @ mutation_locations : Positions of mutations in the true sequence,
                             given in coordinates that include a leading
                             boundary of `read_length` bp (subtracted here).
      @ read_length        : Length of the boundary region excluded from both
                             ends of the true sequence.
      @ true_fname         : FASTA file whose first record is the true
                             sequence.

    optional arguments:
      @ line_width : Width of the alignments in the plot.
    """
    def make_lines_for_contigs() -> Tuple[List[go.Scatter], int]:
        """Build match / non-match / gap line traces for all contigs.

        Also returns the final y offset, i.e. the summed contig-side length
        of all CIGARs.
        """
        match_lines, nonmatch_lines, gap_lines = [], [], []
        contig_start = 0
        prev = None
        for mapping in mappings:
            if prev is not None:
                # Jump on the true sequence from the previous contig's end to
                # this contig's start, at the current contig offset
                gap_lines.append((prev.end, contig_start,
                                  mapping.start, contig_start))
            _match_lines, _nonmatch_lines, contig_start = \
                make_lines_for_contig(mapping, contig_start)
            match_lines += _match_lines
            nonmatch_lines += _nonmatch_lines
            prev = mapping
        return ([
            make_lines(match_lines, width=line_width),
            make_lines(nonmatch_lines, width=line_width, col="red"),
            make_lines(gap_lines, width=line_width, col="yellow")
        ], contig_start)

    def make_lines_for_contig(
            mapping: ContigMapping,
            contig_start: int
    ) -> Tuple[List[Tuple[int, int, int, int]],
               List[Tuple[int, int, int, int]],
               int]:
        """Convert one contig's CIGAR into (x0, y0, x1, y1) line segments.

        Returns (match lines, non-match lines, contig position after the
        last CIGAR operation).
        """
        match_lines, nonmatch_lines = [], []
        contig_pos, true_pos = contig_start, mapping.start
        for length, op in mapping.cigar:
            if op in ('=', 'X'):
                (match_lines if op == '=' else nonmatch_lines) \
                    .append((true_pos,
                             contig_pos,
                             true_pos + length,
                             contig_pos + length))
                contig_pos += length
                true_pos += length
            elif op == 'I':
                # Insertion in the contig: vertical segment
                nonmatch_lines.append(
                    (true_pos, contig_pos, true_pos, contig_pos + length))
                contig_pos += length
            else:
                # Any other op consumes the true sequence only: horizontal
                nonmatch_lines.append(
                    (true_pos, contig_pos, true_pos + length, contig_pos))
                true_pos += length
        assert true_pos == mapping.end, "Invalid CIGAR"
        return match_lines, nonmatch_lines, contig_pos

    def make_dots_for_mutation_status():
        """Return a scatter of mutation positions colored by whether they
        are assembled ('=' in every CIGAR operation at that position)."""
        # Remove boundary regions
        locations = [pos - read_length for pos in mutation_locations]
        # Check if the mutation is assembled for each mutation
        mutation_status = {pos: None for pos in locations}
        for mapping in mappings:
            true_pos = mapping.start
            for op in mapping.cigar.flatten():
                if (true_pos in mutation_status
                        and mutation_status[true_pos] in (None, '=')):
                    mutation_status[true_pos] = op
                    # TODO: see bases around the mutation
                if op != 'I':
                    true_pos += 1
        return make_scatter(
            x=locations,
            y=[0] * len(locations),
            col=[
                'black' if mutation_status[pos] == '=' else 'red'
                for pos in locations
            ],
            marker_size=3)

    mappings = sorted(mappings, key=lambda x: x.start)
    traces_aln, contigs_length = make_lines_for_contigs()
    trace_mutation = make_dots_for_mutation_status()
    true_seq_length = load_fasta(true_fname)[0].length - 2 * read_length
    show_plot(
        traces_aln + [trace_mutation],
        make_layout(width=800,
                    height=800 * contigs_length / true_seq_length,
                    x_title="True sequence",
                    y_title="Contig",
                    x_range=(0, true_seq_length),
                    y_range=(0, contigs_length),
                    x_grid=False,
                    y_grid=False,
                    x_zeroline=False,
                    y_zeroline=False,
                    y_reversed=True,
                    anchor_axes=True,
                    margin=dict(l=10, r=10, t=50, b=10)))