def __init__( self, graph: nx.DiGraph, query: SeqRecord, span_cost: SpanCost, seqdb: Dict[str, SeqRecord], container: AlignmentContainer, stats_repeat_window: Optional[int] = None, stats_window: Optional[int] = None, stats_hairpin_window: Optional[int] = None, edge_threshold: Optional[float] = None, stages: Optional[Tuple[str]] = None, ): if stats_repeat_window is None: stats_repeat_window = SequenceScoringConfig.stats_repeat_window if stats_window is None: stats_window = SequenceScoringConfig.stats_window if stats_hairpin_window is None: stats_hairpin_window = SequenceScoringConfig.stats_hairpin_window if stages is None: stages = SequenceScoringConfig.post_process_stages if edge_threshold is None: edge_threshold = SequenceScoringConfig.edge_threshold self.graph = graph self.graph_builder = AssemblyGraphBuilder(container, span_cost=span_cost) self.graph_builder.G = graph self.query = query self.seqdb = seqdb query_seq = str(query.seq) if is_circular(query): query_seq = query_seq + query_seq self.stats = DNAStats( query_seq, repeat_window=stats_repeat_window, stats_window=stats_window, hairpin_window=stats_hairpin_window, ) self.stats_single = DNAStats( str(query.seq), repeat_window=stats_repeat_window, stats_window=stats_window, hairpin_window=stats_hairpin_window, ) self.logged_msgs = [] # TODO: make a more sophisticated complexity function? # TODO: expose this to input parameters self.COMPLEXITY_THRESHOLD = SequenceScoringConfig.complexity_threshold self.logger = logger(self) self.span_cost = span_cost self.stages = stages self.edge_threshold = edge_threshold
def test_(): backbone = random_sequence(3000) repeat = random_sequence(30) complex_sequence = repeat + random_sequence(200) + repeat + random_sequence(1000) goal = backbone[1000:] + complex_sequence + backbone[:1000] stats = DNAStats(goal, 14, 20, 20) p = find_by_partitions_for_sequence( stats, cyclic=True, threshold=10, step_size=10, delta=30 ) print(p)
def test_find_opt_partition(delta, step_size): repeat = random_sequence(30) seq = ( random_sequence(1000) + repeat + random_sequence(20) + repeat + random_sequence(1000) ) stats = DNAStats(seq, 20, 20, 20) p, cmin = find_opt_partition(stats, 10, step_size=step_size, delta=delta) assert p > 1000 and p < 1000 + 30 + 20 + 30
def test_find_best_partitions(delta, step_size): repeat = random_sequence(30) seq = ( random_sequence(1000) + repeat + random_sequence(20) + repeat + random_sequence(1000) ) stats = DNAStats(seq, 20, 20, 20) p = find_best_partitions(stats, threshold=10, step_size=step_size, delta=delta) for _p in p: assert _p > 1000 and _p < 1000 + 30 + 20 + 30
def test_find_partitions_for_sequence(delta, step_size, cyclic): repeat = random_sequence(30) seq = ( random_sequence(1000) + repeat + random_sequence(20) + repeat + random_sequence(1000) ) stats = DNAStats(seq, 20, 20, 20) p = find_by_partitions_for_sequence( stats, cyclic, threshold=10, step_size=step_size, delta=delta ) passes = False for _p in p: if _p > 1000 and _p < 1000 + 30 + 20 + 30: passes = True assert passes
def plot_coverage_of_container(container, query): arrs = [] keys = [] for gkey, groups in container.groups_by_type().items(): keys.append(gkey) a = np.zeros(len(query)) for group in groups: for s in group.query_region.slices(): a[s] += 1 arrs.append(a) data = np.vstack(arrs) params = {"legend.fontsize": 8} plt.rcParams.update(params) fig = plt.figure(figsize=(10, 7.5)) fig.suptitle(query.name, fontsize=16) gs = GridSpec(3, 2, hspace=0.55) ax1 = fig.add_subplot(gs[0, :2]) ax2 = fig.add_subplot(gs[1, :2], sharex=ax1) ax3 = fig.add_subplot(gs[2, :2], sharex=ax1) ax1.set_title("Total Alignment Coverage") ax1.set_ylabel("Coverage") ax1.set_xlabel("bp") ax1.set_yscale("log") ax1.plot(np.sum(data, axis=0)) ax2.set_title("Alignment Coverage") ax2.set_ylabel("Coverage") ax2.set_xlabel("bp") ax2.set_yscale("log") ax2.plot(data.T) ax2.legend(keys, loc="center left", ncol=1, bbox_to_anchor=(1.0, 0.5)) stats = DNAStats( str(query.seq) + str(query.seq) + str(query.seq), 14, 20, 20) costs_arr = [] bp_arr = [] windows = [100, 500, 1000, 2000] for window in windows: costs = [] step = min(windows) x = np.arange(len(query) - window, len(query) * 2 + window) for i in x[::step]: costs.append(stats.cost(i, i + window)) x = x - len(query) y = np.repeat(costs, step) delta = -(y.shape[0] - x.shape[0]) if delta == 0: delta = None y = y[:delta] costs_arr.append(y) bp_arr.append(x) ax3.axhline( y=Config.SequenceScoringConfig.complexity_threshold, color="k", linestyle="-", ) ax3.set_yscale("log") ax3.set_xlabel("bp") ax3.set_ylabel("Complexity") ax3.set_title("Sequence Complexity ({})".format(window)) for x, y, l in zip(bp_arr, costs_arr, windows): ax3.plot(x, y, label=l) ax3.legend(title="window (bp)") ax1.set_xlim(0, len(query)) axes = [ax1, ax2, ax3] return fig, axes
class AssemblyGraphPostProcessor: """Pre-processing for assembly graphs. Evaluates: 1. synthesis complexity and weights corresponding edge 2. pcr product efficiency 3. (optional) optimal partitions for synthesis fragments """ # TODO: add post processing config def __init__( self, graph: nx.DiGraph, query: SeqRecord, span_cost: SpanCost, seqdb: Dict[str, SeqRecord], container: AlignmentContainer, stats_repeat_window: Optional[int] = None, stats_window: Optional[int] = None, stats_hairpin_window: Optional[int] = None, edge_threshold: Optional[float] = None, stages: Optional[Tuple[str]] = None, ): if stats_repeat_window is None: stats_repeat_window = SequenceScoringConfig.stats_repeat_window if stats_window is None: stats_window = SequenceScoringConfig.stats_window if stats_hairpin_window is None: stats_hairpin_window = SequenceScoringConfig.stats_hairpin_window if stages is None: stages = SequenceScoringConfig.post_process_stages if edge_threshold is None: edge_threshold = SequenceScoringConfig.edge_threshold self.graph = graph self.graph_builder = AssemblyGraphBuilder(container, span_cost=span_cost) self.graph_builder.G = graph self.query = query self.seqdb = seqdb query_seq = str(query.seq) if is_circular(query): query_seq = query_seq + query_seq self.stats = DNAStats( query_seq, repeat_window=stats_repeat_window, stats_window=stats_window, hairpin_window=stats_hairpin_window, ) self.stats_single = DNAStats( str(query.seq), repeat_window=stats_repeat_window, stats_window=stats_window, hairpin_window=stats_hairpin_window, ) self.logged_msgs = [] # TODO: make a more sophisticated complexity function? # TODO: expose this to input parameters self.COMPLEXITY_THRESHOLD = SequenceScoringConfig.complexity_threshold self.logger = logger(self) self.span_cost = span_cost self.stages = stages self.edge_threshold = edge_threshold @staticmethod def optimize_partition( signatures: np.ndarray, step: int, i: int = None, j: int = None ): """Optimize partition by minimizing the number of signatures in the given array. :param signatures: array of signatures :param step: step size :param i: :param j: :return: """ d = [] if i is None: i = 0 if j is None: j = signatures.shape[1] for x in range(i, j, step): m1 = np.empty(signatures.shape[1]) m2 = m1.copy() m1.fill(np.nan) m2.fill(np.nan) m1[:x] = np.random.uniform(1, 10) m2[x:] = np.random.uniform(1, 10) d += [m1, m2] d = np.vstack(d) z = np.tile(d, signatures.shape[0]) * signatures.flatten() partition_index = np.repeat( np.arange(0, signatures.shape[1], step), signatures.shape[0] * signatures.shape[1] * 2, ) a, b, c = np.unique(z, return_counts=True, return_index=True) i = b[np.where(c > 1)] a, c = np.unique(partition_index[i], return_counts=True) if len(c): arg = c.argmin() return a[arg], c[arg] def _edge_to_region(self, n1, n2): if n1.index == len(self.query) and is_circular(self.query): a = 0 b = n2.index else: a = n1.index b = n2.index return Region(a, b, len(self.query), cyclic=is_circular(self.query)) @staticmethod def _adj_eff(edata, e): edata["efficiency"] = e if e == 0: edata["cost"] = np.inf if edata["material"] is None: edata["cost"] = np.inf else: edata["cost"] = edata["material"] / edata["efficiency"] def _complexity_to_efficiency(self, edata: dict) -> bool: if not edata["complexity"] >= 0.0: raise ValueError("Complexity is not defined. {}".format(edata)) ratio = edata["complexity"] / self.COMPLEXITY_THRESHOLD if ratio >= 1.0: e = SequenceScoringConfig.not_synthesizable_efficiency / ratio self._adj_eff(edata, e) return True return False @staticmethod def _is_pcr_product(edata): return edata["type_def"].name in [ Constants.PCR_PRODUCT, Constants.PCR_PRODUCT_WITH_LEFT_PRIMER, Constants.PCR_PRODUCT_WITH_RIGHT_PRIMER, Constants.PCR_PRODUCT_WITH_PRIMERS, ] def score_long_pcr_products(self, n1, n2, edata): if self._is_pcr_product(edata): span = edata["span"] for a, b, c in SequenceScoringConfig.pcr_length_range_efficiency_multiplier: if a <= span < b: self._adj_eff(edata, edata["efficiency"] * c) add_edge_note(edata, "long_pcr_product", True) break def _score_misprimings_from_alignment(self, alignment): subject_key = alignment.subject_key subject = self.seqdb[subject_key] i = alignment.subject_region.a j = alignment.subject_region.b subject_seq = str(subject.seq) return count_misprimings_in_amplicon( subject_seq, i, j, min_primer_anneal=SequenceScoringConfig.mispriming_min_anneal, max_primer_anneal=SequenceScoringConfig.mispriming_max_anneal, cyclic=alignment.subject_region.cyclic, ) # TODO: select the best template... # TODO: speed up this process def score_primer_misprimings(self, n1, n2, edata): if self._is_pcr_product(edata): group = edata["group"] misprime_list = [] if isinstance(group, MultiPCRProductAlignmentGroup): prioritize_function = group.prioritize_groupings alignments = group.iter_templates() elif isinstance(group, AlignmentGroup): prioritize_function = group.prioritize_alignments alignments = group.alignments else: raise TypeError( "Group '{}' not supported by this function.".format(group.__class__) ) arr = [] for index, alignment in enumerate(alignments): if not ( "PCR" in alignment.type or alignment.type == Constants.FRAGMENT ): raise ValueError( "{} is not a valid type for a template".format(alignment.type) ) mispriming = self._score_misprimings_from_alignment(alignment) arr.append((mispriming, index, alignment)) if mispriming == 0: break arr.sort(key=lambda x: x[0]) indices = [a[1] for a in arr] prioritize_function(indices) score = arr[0][0] self._adj_eff( edata, edata["efficiency"] * SequenceScoringConfig.mispriming_penalty ** score, ) add_edge_note(edata, "num_misprimings", score) add_edge_note(edata, "n_templates_eval", len(misprime_list)) def _filter_partition_edges(self, edges: List[Edge]) -> List[Edge]: edges_to_partition = [] for n1, n2, edata in edges: min_size = ( edata["type_def"].min_size or MoleculeType.types[Constants.SHARED_SYNTHESIZED_FRAGMENT].min_size ) if ( edata["span"] > min_size * 2 and edata["complexity"] >= Config.SequenceScoringConfig.complexity_threshold ): edges_to_partition.append((n1, n2, edata)) return edges_to_partition def partition(self, edges: List[Edge]): tracker = self.logger.track( "INFO", desc="Partitioning sequences", total=3 ).enter() tracker.update(0, "{} highly complex sequences".format(len(edges))) edges_to_partition = self._filter_partition_edges(edges) cyclic = is_circular(self.query) partitions = find_by_partitions_for_sequence( self.stats_single, cyclic=cyclic, threshold=Config.SequenceScoringConfig.complexity_threshold, step_size=Config.SequenceScoringConfig.partition_step_size, delta=Config.SequenceScoringConfig.partition_overlap, ) tracker.update(1, "Partition: locations: {}".format(partitions)) add_gap_edge = partial(self.graph_builder.add_gap_edge, add_to_graph=False) add_overlap_edge = partial( self.graph_builder.add_overlap_edge, add_to_graph=True, validate_groups_present=False, groups=None, group_keys=None, ) new_edges = [] for n1, n2, edata in edges_to_partition: r = Region(n1.index, n2.index, len(self.query.seq), cyclic=cyclic) if n1.type == "B" and n2.type == "A": for p in partitions: if p in r: # TODO: overlap? find optimal partition for overlap? i4 = p i3 = p + Config.SequenceScoringConfig.partition_overlap n3 = AssemblyNode(i3, False, "BC", overhang=True) n4 = AssemblyNode(i4, False, "CA", overhang=True) e1 = add_gap_edge(n1, n3, r, origin=False) e2 = add_gap_edge(n1, n3, r, origin=True) if e1 is None and e2 is None: continue e3 = add_overlap_edge(n3, n4, r, origin=False) e4 = add_overlap_edge(n3, n4, r, origin=True) if e3 is None and e4 is None: continue e5 = add_gap_edge(n4, n2, r, origin=False) e6 = add_gap_edge(n4, n2, r, origin=True) if e5 is None and e6 is None: continue new_edges += [e1, e2, e3, e4, e5, e6] for e in new_edges: if e is not None: self.graph_builder.G.add_edge(e[0], e[1], **e[2]) edges = [] for n1, n2, edata in self.graph_builder.G.edges(data=True): if edata["cost"] is None: edges.append((n1, n2, edata)) tracker.update(2, "Partition: Added {} new edges".format(len(edges))) self.graph_builder.update_costs(edges) self.score_complexity_edges(list(self.graph_builder.G.edges(data=True))) tracker.exit() def remove_inefficient_edges(self): to_remove = [] for n1, n2, edata in self.graph_builder.G.edges(data=True): if edata["efficiency"] < self.edge_threshold: to_remove.append((n1, n2)) self.graph_builder.G.remove_edges_from(to_remove) self.logger.info("Removed {} inefficient edges".format(len(to_remove))) def _score_syn_dna( self, n1: AssemblyNode, n2: AssemblyNode, edata: dict ) -> List[Edge]: if edata["type_def"].synthesize: span = edata["span"] if span > 0: # TODO: cyclic may not always be true region = self._edge_to_region(n1, n2) a = region.a c = region.c if c < a: c += region.context_length assert c <= len(self.stats) complexity = self.stats.cost(a, c) if not complexity >= 0.0: pass edata["complexity"] = complexity if self._complexity_to_efficiency(edata): add_edge_note(edata, "highly_complex", True) return True return False def score_complexity_edges(self, edges: List[Edge] = None) -> List[Edge]: """Score synthetic edges.""" if edges is None: edges = self.graph.edges(data=True) complex_edges = [] for n1, n2, edata in self.logger.tqdm( edges, "INFO", desc="Scoring synthetic DNA" ): if self._score_syn_dna(n1, n2, edata): complex_edges.append((n1, n2, edata)) return complex_edges def update(self): self.logger.info("Post Processor: {}".format(self.stages)) bad_edges = [] self.logger.info("Scoring long PCR products") edges = list(self.graph.edges(data=True)) if SequenceScoringConfig.SCORE_LONG in self.stages: for n1, n2, edata in self.logger.tqdm( edges, "INFO", desc="Scoring long PCR products" ): self.score_long_pcr_products(n1, n2, edata) if SequenceScoringConfig.SCORE_MISPRIMINGS in self.stages: for n1, n2, edata in self.logger.tqdm( edges, "INFO", desc="Scoring primer misprimings" ): self.score_primer_misprimings(n1, n2, edata) if SequenceScoringConfig.SCORE_COMPLEXITY in self.stages: bad_edges += self.score_complexity_edges(edges) self.logger.info( "Found {} highly complex synthesis segments".format(len(bad_edges)) ) if SequenceScoringConfig.PARTITION in self.stages: self.partition(bad_edges) # TODO: reimplement `remove_inefficient_edges` # self.remove_inefficient_edges() # TODO: add logging to graph post processor # TODO: partition gaps def __call__(self): self.logger.info("Post processing graph for {}".format(self.query.name)) self.update()