Exemplo n.º 1
0
    def plot_coverage_of_container(container, query):
        """Plot alignment coverage and sequence complexity for *query*.

        Produces a three-panel figure: total alignment coverage, per-group-type
        alignment coverage, and windowed sequence complexity (computed on a
        tripled sequence so complexity windows can span the origin).

        :param container: alignment container providing ``groups_by_type()``
        :param query: query sequence record (sized; has ``.name`` and ``.seq``)
        :return: ``(fig, [ax1, ax2, ax3])``
        """
        # Coverage matrix: one row per group type, one column per bp.
        arrs = []
        keys = []
        for gkey, groups in container.groups_by_type().items():
            keys.append(gkey)
            a = np.zeros(len(query))
            for group in groups:
                for s in group.query_region.slices():
                    a[s] += 1
            arrs.append(a)
        data = np.vstack(arrs)

        params = {"legend.fontsize": 8}
        plt.rcParams.update(params)

        fig = plt.figure(figsize=(10, 7.5))
        fig.suptitle(query.name, fontsize=16)

        gs = GridSpec(3, 2, hspace=0.55)
        ax1 = fig.add_subplot(gs[0, :2])
        ax2 = fig.add_subplot(gs[1, :2], sharex=ax1)
        ax3 = fig.add_subplot(gs[2, :2], sharex=ax1)

        ax1.set_title("Total Alignment Coverage")
        ax1.set_ylabel("Coverage")
        ax1.set_xlabel("bp")
        ax1.set_yscale("log")
        ax1.plot(np.sum(data, axis=0))

        ax2.set_title("Alignment Coverage")
        ax2.set_ylabel("Coverage")
        ax2.set_xlabel("bp")
        ax2.set_yscale("log")
        ax2.plot(data.T)
        ax2.legend(keys, loc="center left", ncol=1, bbox_to_anchor=(1.0, 0.5))

        # Complexity is evaluated on a tripled sequence so windows that start
        # near the end may run past the origin.
        stats = DNAStats(
            str(query.seq) + str(query.seq) + str(query.seq), 14, 20, 20)
        costs_arr = []
        bp_arr = []
        windows = [100, 500, 1000, 2000]
        step = min(windows)  # loop-invariant; hoisted out of the window loop
        for window in windows:
            costs = []
            x = np.arange(len(query) - window, len(query) * 2 + window)
            for i in x[::step]:
                costs.append(stats.cost(i, i + window))

            # Shift coordinates back onto the original (untripled) sequence.
            x = x - len(query)
            # Expand the strided costs to per-bp resolution and trim the tail
            # so that y matches x in length.
            y = np.repeat(costs, step)
            delta = -(y.shape[0] - x.shape[0])
            if delta == 0:
                delta = None
            y = y[:delta]
            costs_arr.append(y)
            bp_arr.append(x)

        ax3.axhline(
            y=Config.SequenceScoringConfig.complexity_threshold,
            color="k",
            linestyle="-",
        )
        ax3.set_yscale("log")
        ax3.set_xlabel("bp")
        ax3.set_ylabel("Complexity")
        # BUG FIX: the title previously interpolated the stale loop variable
        # `window` (always the last window) even though every window is shown.
        ax3.set_title("Sequence Complexity (windows: {})".format(windows))

        for x, y, l in zip(bp_arr, costs_arr, windows):
            ax3.plot(x, y, label=l)
        # BUG FIX: the legend was rebuilt on every iteration of the plot loop;
        # building it once after all lines exist is sufficient.
        ax3.legend(title="window (bp)")

        ax1.set_xlim(0, len(query))

        axes = [ax1, ax2, ax3]

        return fig, axes
Exemplo n.º 2
0
class AssemblyGraphPostProcessor:
    """Pre-processing for assembly graphs. Evaluates:

    1. synthesis complexity and weights corresponding edge
    2. pcr product efficiency
    3. (optional) optimal partitions for synthesis fragments
    """

    # TODO: add post processing config

    def __init__(
        self,
        graph: nx.DiGraph,
        query: SeqRecord,
        span_cost: SpanCost,
        seqdb: Dict[str, SeqRecord],
        container: AlignmentContainer,
        stats_repeat_window: Optional[int] = None,
        stats_window: Optional[int] = None,
        stats_hairpin_window: Optional[int] = None,
        edge_threshold: Optional[float] = None,
        stages: Optional[Tuple[str]] = None,
    ):
        """Initialize the post processor.

        :param graph: assembly graph to post-process
        :param query: query sequence record
        :param span_cost: span cost model shared with the graph builder
        :param seqdb: mapping of sequence keys to sequence records
        :param container: alignment container used by the graph builder
        :param stats_repeat_window: repeat window for DNAStats
            (defaults to ``SequenceScoringConfig.stats_repeat_window``)
        :param stats_window: stats window for DNAStats
            (defaults to ``SequenceScoringConfig.stats_window``)
        :param stats_hairpin_window: hairpin window for DNAStats
            (defaults to ``SequenceScoringConfig.stats_hairpin_window``)
        :param edge_threshold: efficiency cutoff for edge removal
            (defaults to ``SequenceScoringConfig.edge_threshold``)
        :param stages: post-processing stages to run
            (defaults to ``SequenceScoringConfig.post_process_stages``)
        """

        def fallback(value, default):
            # Defaults apply only when the argument was omitted; compare with
            # `is None` (not truthiness) so explicit zero values are honored.
            return default if value is None else value

        stats_repeat_window = fallback(
            stats_repeat_window, SequenceScoringConfig.stats_repeat_window
        )
        stats_window = fallback(stats_window, SequenceScoringConfig.stats_window)
        stats_hairpin_window = fallback(
            stats_hairpin_window, SequenceScoringConfig.stats_hairpin_window
        )
        stages = fallback(stages, SequenceScoringConfig.post_process_stages)
        edge_threshold = fallback(
            edge_threshold, SequenceScoringConfig.edge_threshold
        )

        self.graph = graph
        self.graph_builder = AssemblyGraphBuilder(container, span_cost=span_cost)
        self.graph_builder.G = graph
        self.query = query
        self.seqdb = seqdb

        # Circular queries are scored on a doubled sequence so that scoring
        # windows can wrap around the origin.
        query_seq = str(query.seq)
        if is_circular(query):
            query_seq = query_seq + query_seq
        self.stats = DNAStats(
            query_seq,
            repeat_window=stats_repeat_window,
            stats_window=stats_window,
            hairpin_window=stats_hairpin_window,
        )
        # Single-copy stats (never doubled), used for partitioning.
        self.stats_single = DNAStats(
            str(query.seq),
            repeat_window=stats_repeat_window,
            stats_window=stats_window,
            hairpin_window=stats_hairpin_window,
        )
        self.logged_msgs = []
        # TODO: make a more sophisticated complexity function?
        # TODO: expose this to input parameters
        self.COMPLEXITY_THRESHOLD = SequenceScoringConfig.complexity_threshold
        self.logger = logger(self)
        self.span_cost = span_cost
        self.stages = stages
        self.edge_threshold = edge_threshold

    @staticmethod
    def optimize_partition(
        signatures: np.ndarray, step: int, i: int = None, j: int = None
    ):
        """Find the split column of *signatures* with the fewest duplicated
        signature collisions across the split.

        For each candidate split ``x`` in ``range(i, j, step)`` two NaN masks
        are built that tag the columns left and right of ``x`` with distinct
        random scalars; multiplying them against the flattened signatures lets
        duplicated (shared) signatures be detected as repeated values.

        :param signatures: array of signatures, one row per sequence
        :param step: step size between candidate split columns
        :param i: first candidate split column (defaults to 0)
        :param j: end of the candidate range (defaults to the column count)
        :return: ``(split_position, collision_count)`` for the split with the
            fewest collisions, or ``None`` if no signature is duplicated at
            any candidate split
        """
        d = []

        if i is None:
            i = 0
        if j is None:
            j = signatures.shape[1]

        for x in range(i, j, step):
            # m1 tags columns left of x, m2 the remainder; NaN elsewhere so
            # untagged positions never compare equal during np.unique below.
            m1 = np.empty(signatures.shape[1])
            m2 = m1.copy()
            m1.fill(np.nan)
            m2.fill(np.nan)

            # Fresh random scalar per mask so different splits cannot collide.
            m1[:x] = np.random.uniform(1, 10)
            m2[x:] = np.random.uniform(1, 10)

            d += [m1, m2]
        d = np.vstack(d)
        z = np.tile(d, signatures.shape[0]) * signatures.flatten()

        # Maps every flattened entry of z back to its candidate split column.
        # NOTE(review): built from the full 0..shape[1] range — appears to
        # assume i == 0 and j == signatures.shape[1]; confirm for other args.
        partition_index = np.repeat(
            np.arange(0, signatures.shape[1], step),
            signatures.shape[0] * signatures.shape[1] * 2,
        )

        # With both flags, np.unique returns (values, first_indices, counts).
        a, b, c = np.unique(z, return_counts=True, return_index=True)
        i = b[np.where(c > 1)]  # NOTE: shadows the `i` parameter from here on
        a, c = np.unique(partition_index[i], return_counts=True)
        if len(c):
            arg = c.argmin()
            return a[arg], c[arg]

    def _edge_to_region(self, n1, n2):
        """Convert an edge's node pair into a ``Region`` on the query.

        On a circular query, a start index equal to ``len(self.query)`` is the
        origin wrapped around once, so it maps back to position 0.
        """
        cyclic = is_circular(self.query)
        start = n1.index
        if cyclic and start == len(self.query):
            start = 0
        return Region(start, n2.index, len(self.query), cyclic=cyclic)

    @staticmethod
    def _adj_eff(edata, e):
        """Set an edge's efficiency and recompute its cost in place.

        :param edata: edge data dict (mutated); reads ``material``
        :param e: new efficiency value

        BUG FIX: previously a zero efficiency set ``cost`` to inf but then
        fell through to ``material / efficiency``, raising ZeroDivisionError
        for plain floats (or clobbering the inf). Zero efficiency and missing
        material now both short-circuit to an infinite cost.
        """
        edata["efficiency"] = e
        if e == 0 or edata["material"] is None:
            edata["cost"] = np.inf
        else:
            edata["cost"] = edata["material"] / edata["efficiency"]

    def _complexity_to_efficiency(self, edata: dict) -> bool:
        """Downgrade an edge's efficiency when its complexity exceeds the
        threshold.

        :param edata: edge data dict with a ``complexity`` entry (mutated)
        :return: True if the edge was penalized, False otherwise
        :raises ValueError: if complexity is undefined (fails ``>= 0``, which
            also catches NaN)
        """
        complexity = edata["complexity"]
        # `not (x >= 0)` intentionally trips on NaN, unlike `x < 0`.
        if not complexity >= 0.0:
            raise ValueError("Complexity is not defined. {}".format(edata))
        ratio = complexity / self.COMPLEXITY_THRESHOLD
        if ratio < 1.0:
            return False
        self._adj_eff(
            edata, SequenceScoringConfig.not_synthesizable_efficiency / ratio
        )
        return True

    @staticmethod
    def _is_pcr_product(edata):
        """Return True if the edge's type is any flavor of PCR product."""
        pcr_types = (
            Constants.PCR_PRODUCT,
            Constants.PCR_PRODUCT_WITH_LEFT_PRIMER,
            Constants.PCR_PRODUCT_WITH_RIGHT_PRIMER,
            Constants.PCR_PRODUCT_WITH_PRIMERS,
        )
        return edata["type_def"].name in pcr_types

    def score_long_pcr_products(self, n1, n2, edata):
        """Penalize PCR-product edges whose span falls in a configured range.

        Multiplies the edge efficiency by the factor of the first matching
        ``[low, high)`` length range and annotates the edge.
        """
        if not self._is_pcr_product(edata):
            return
        span = edata["span"]
        ranges = SequenceScoringConfig.pcr_length_range_efficiency_multiplier
        for low, high, multiplier in ranges:
            if low <= span < high:
                self._adj_eff(edata, edata["efficiency"] * multiplier)
                add_edge_note(edata, "long_pcr_product", True)
                break

    def _score_misprimings_from_alignment(self, alignment):
        """Count misprimings for the amplicon described by *alignment*.

        Looks up the alignment's subject sequence in ``self.seqdb`` and
        delegates to :func:`count_misprimings_in_amplicon` using the
        configured primer-annealing bounds.
        """
        subject = self.seqdb[alignment.subject_key]
        region = alignment.subject_region
        return count_misprimings_in_amplicon(
            str(subject.seq),
            region.a,
            region.b,
            min_primer_anneal=SequenceScoringConfig.mispriming_min_anneal,
            max_primer_anneal=SequenceScoringConfig.mispriming_max_anneal,
            cyclic=region.cyclic,
        )

    # TODO: select the best template...
    # TODO: speed up this process
    def score_primer_misprimings(self, n1, n2, edata):
        """Penalize a PCR-product edge by its best template's mispriming count.

        Evaluates candidate templates (stopping early at the first one with
        zero misprimings), reprioritizes the group by ascending mispriming
        count, and multiplies the edge efficiency by
        ``mispriming_penalty ** best_score``.

        :raises TypeError: if the edge's group type is unsupported
        :raises ValueError: if a template alignment has a non-template type
        """
        if not self._is_pcr_product(edata):
            return

        group = edata["group"]
        if isinstance(group, MultiPCRProductAlignmentGroup):
            prioritize_function = group.prioritize_groupings
            alignments = group.iter_templates()
        elif isinstance(group, AlignmentGroup):
            prioritize_function = group.prioritize_alignments
            alignments = group.alignments
        else:
            raise TypeError(
                "Group '{}' not supported by this function.".format(group.__class__)
            )

        arr = []
        for index, alignment in enumerate(alignments):
            if not (
                "PCR" in alignment.type or alignment.type == Constants.FRAGMENT
            ):
                raise ValueError(
                    "{} is not a valid type for a template".format(alignment.type)
                )
            mispriming = self._score_misprimings_from_alignment(alignment)
            arr.append((mispriming, index, alignment))
            # Early exit: a template with zero misprimings cannot be beaten.
            if mispriming == 0:
                break
        if not arr:
            # Robustness: no templates to evaluate; leave the edge untouched
            # (previously this raised IndexError on arr[0]).
            return
        arr.sort(key=lambda x: x[0])
        indices = [a[1] for a in arr]
        prioritize_function(indices)
        score = arr[0][0]

        self._adj_eff(
            edata,
            edata["efficiency"] * SequenceScoringConfig.mispriming_penalty ** score,
        )
        add_edge_note(edata, "num_misprimings", score)
        # BUG FIX: previously reported len() of an always-empty `misprime_list`
        # (never appended to), so this note was always 0.
        add_edge_note(edata, "n_templates_eval", len(arr))

    def _filter_partition_edges(self, edges: List[Edge]) -> List[Edge]:
        """Select edges that are long and complex enough to partition.

        An edge qualifies when its span exceeds twice its minimum fragment
        size (falling back to the shared-synthesized-fragment minimum) and
        its complexity is at or above the configured threshold.
        """
        default_min_size = MoleculeType.types[
            Constants.SHARED_SYNTHESIZED_FRAGMENT
        ].min_size
        threshold = Config.SequenceScoringConfig.complexity_threshold
        selected = []
        for n1, n2, edata in edges:
            min_size = edata["type_def"].min_size or default_min_size
            if edata["span"] > min_size * 2 and edata["complexity"] >= threshold:
                selected.append((n1, n2, edata))
        return selected

    def partition(self, edges: List[Edge]):
        """Split highly complex synthesis edges by inserting partition nodes.

        For each qualifying "B" -> "A" edge that spans a partition location,
        adds two overhang nodes (types ``BC``/``CA``) plus gap and overlap
        edges so the region can be built as two overlapping pieces, then
        recomputes costs for the newly added edges and rescores complexity.

        :param edges: candidate (n1, n2, edata) tuples flagged as complex
        """
        tracker = self.logger.track(
            "INFO", desc="Partitioning sequences", total=3
        ).enter()
        tracker.update(0, "{} highly complex sequences".format(len(edges)))

        edges_to_partition = self._filter_partition_edges(edges)

        cyclic = is_circular(self.query)
        # Candidate cut locations, computed on the single (non-doubled) stats.
        partitions = find_by_partitions_for_sequence(
            self.stats_single,
            cyclic=cyclic,
            threshold=Config.SequenceScoringConfig.complexity_threshold,
            step_size=Config.SequenceScoringConfig.partition_step_size,
            delta=Config.SequenceScoringConfig.partition_overlap,
        )
        tracker.update(1, "Partition: locations: {}".format(partitions))
        # NOTE(review): gap edges are collected and added to the graph below,
        # but overlap edges are added immediately (add_to_graph=True) —
        # confirm this asymmetry is intentional.
        add_gap_edge = partial(self.graph_builder.add_gap_edge, add_to_graph=False)
        add_overlap_edge = partial(
            self.graph_builder.add_overlap_edge,
            add_to_graph=True,
            validate_groups_present=False,
            groups=None,
            group_keys=None,
        )

        new_edges = []
        for n1, n2, edata in edges_to_partition:
            r = Region(n1.index, n2.index, len(self.query.seq), cyclic=cyclic)

            if n1.type == "B" and n2.type == "A":

                for p in partitions:
                    if p in r:
                        # TODO: overlap? find optimal partition for overlap?
                        # The two new pieces overlap by `partition_overlap` bp.
                        i4 = p
                        i3 = p + Config.SequenceScoringConfig.partition_overlap

                        n3 = AssemblyNode(i3, False, "BC", overhang=True)
                        n4 = AssemblyNode(i4, False, "CA", overhang=True)
                        # Each segment needs at least one valid edge (with or
                        # without spanning the origin); otherwise skip this cut.
                        e1 = add_gap_edge(n1, n3, r, origin=False)
                        e2 = add_gap_edge(n1, n3, r, origin=True)
                        if e1 is None and e2 is None:
                            continue
                        e3 = add_overlap_edge(n3, n4, r, origin=False)
                        e4 = add_overlap_edge(n3, n4, r, origin=True)
                        if e3 is None and e4 is None:
                            continue
                        e5 = add_gap_edge(n4, n2, r, origin=False)
                        e6 = add_gap_edge(n4, n2, r, origin=True)
                        if e5 is None and e6 is None:
                            continue
                        new_edges += [e1, e2, e3, e4, e5, e6]
        # Individual entries may be None (only one of each origin pair is
        # required above); filter them here.
        for e in new_edges:
            if e is not None:
                self.graph_builder.G.add_edge(e[0], e[1], **e[2])
        # Newly added edges have no cost yet; collect and compute them now.
        edges = []
        for n1, n2, edata in self.graph_builder.G.edges(data=True):
            if edata["cost"] is None:
                edges.append((n1, n2, edata))
        tracker.update(2, "Partition: Added {} new edges".format(len(edges)))
        self.graph_builder.update_costs(edges)
        self.score_complexity_edges(list(self.graph_builder.G.edges(data=True)))
        tracker.exit()

    def remove_inefficient_edges(self):
        """Remove every edge whose efficiency is below ``edge_threshold``."""
        to_remove = [
            (n1, n2)
            for n1, n2, edata in self.graph_builder.G.edges(data=True)
            if edata["efficiency"] < self.edge_threshold
        ]
        self.graph_builder.G.remove_edges_from(to_remove)
        self.logger.info("Removed {} inefficient edges".format(len(to_remove)))

    def _score_syn_dna(
        self, n1: AssemblyNode, n2: AssemblyNode, edata: dict
    ) -> bool:
        """Score the sequence complexity of a synthesis edge.

        Stores the complexity on the edge and, if it exceeds the threshold,
        downgrades the edge efficiency and annotates it as highly complex.

        :param n1: source assembly node
        :param n2: target assembly node
        :param edata: edge data dict (mutated: ``complexity`` is set)
        :return: True if the edge was flagged as highly complex.

        BUG FIX: the return annotation previously claimed ``List[Edge]``
        although the method returns a bool. A dead ``if ...: pass`` guard on
        the complexity value was removed; ``_complexity_to_efficiency``
        already raises ValueError on an undefined complexity.
        """
        if edata["type_def"].synthesize:
            span = edata["span"]
            if span > 0:
                # TODO: cyclic may not always be true
                region = self._edge_to_region(n1, n2)
                a = region.a
                c = region.c
                if c < a:
                    # Wraps past the origin; self.stats holds a doubled
                    # sequence for circular queries, so extend past the end.
                    c += region.context_length
                assert c <= len(self.stats)
                edata["complexity"] = self.stats.cost(a, c)
                if self._complexity_to_efficiency(edata):
                    add_edge_note(edata, "highly_complex", True)
                    return True
        return False

    def score_complexity_edges(self, edges: List[Edge] = None) -> List[Edge]:
        """Score synthetic edges, returning those flagged as highly complex.

        :param edges: (n1, n2, edata) tuples; defaults to all graph edges
        :return: the subset of edges flagged by ``_score_syn_dna``
        """
        if edges is None:
            edges = self.graph.edges(data=True)
        progress = self.logger.tqdm(edges, "INFO", desc="Scoring synthetic DNA")
        return [
            (n1, n2, edata)
            for n1, n2, edata in progress
            if self._score_syn_dna(n1, n2, edata)
        ]

    def update(self):
        """Run the configured post-processing stages over the graph edges."""
        self.logger.info("Post Processor: {}".format(self.stages))
        bad_edges = []

        self.logger.info("Scoring long PCR products")
        edges = list(self.graph.edges(data=True))

        if SequenceScoringConfig.SCORE_LONG in self.stages:
            progress = self.logger.tqdm(
                edges, "INFO", desc="Scoring long PCR products"
            )
            for n1, n2, edata in progress:
                self.score_long_pcr_products(n1, n2, edata)
        if SequenceScoringConfig.SCORE_MISPRIMINGS in self.stages:
            progress = self.logger.tqdm(
                edges, "INFO", desc="Scoring primer misprimings"
            )
            for n1, n2, edata in progress:
                self.score_primer_misprimings(n1, n2, edata)
        if SequenceScoringConfig.SCORE_COMPLEXITY in self.stages:
            bad_edges.extend(self.score_complexity_edges(edges))
        self.logger.info(
            "Found {} highly complex synthesis segments".format(len(bad_edges))
        )

        # Partitioning only acts on the edges flagged by complexity scoring.
        if SequenceScoringConfig.PARTITION in self.stages:
            self.partition(bad_edges)

        # TODO: reimplement `remove_inefficient_edges`
        # self.remove_inefficient_edges()

    # TODO: add logging to graph post processor
    # TODO: partition gaps
    def __call__(self):
        """Log the query being processed and run all configured stages."""
        self.logger.info("Post processing graph for {}".format(self.query.name))
        self.update()