def get_trace_path(
    graph: Graph,
    trace: AttrMap,
    filter_func: Callable[[Operation], bool] = None,
    compact: bool = False,
) -> AttrMap:
    """Reconstruct the trace path for a single trace over ``graph``.

    The trace attributes are attached to the graph via ``graph.with_attrs``,
    path reconstruction runs with hooks that do nothing, and only the
    ``TraceKey.META`` keys plus ``TraceKey.PATH`` are kept from the
    resulting attribute map.

    Args:
        graph: Graph the trace was recorded on.
        trace: Trace attribute map to reconstruct the path from.
        filter_func: Optional predicate forwarded unchanged to
            ``reconstruct_trace_path_with_hook``.
        compact: Forwarded unchanged to ``reconstruct_trace_path_with_hook``.

    Returns:
        The graph's attribute map filtered down to META and PATH keys.
    """

    def _noop(_):
        # No per-tensor / per-op work is needed for a single trace.
        return None

    traced_graph = graph.with_attrs(trace)
    reconstruct_trace_path_with_hook(
        traced_graph,
        on_enter_output_tensor=_noop,
        on_enter_op=_noop,
        filter_func=filter_func,
        compact=compact,
    )
    kept_keys = TraceKey.META | {TraceKey.PATH}
    attr_map = traced_graph.attrs_to_map()
    return TraceKey.filter_key(kept_keys, attr_map)
def get_trace_path_intersection(
    *traces: AttrMap,
    graph: Graph,
    filter_func: Callable[[Operation], bool] = None,
    compact: bool = False,
) -> AttrMap:
    """Reconstruct the path shared by several traces of the same graph.

    The graph is seeded with the META attributes of the first trace. During
    path reconstruction:

    * each visited output tensor receives the POINT recorded by the traces,
      which must all agree on it;
    * each visited non-trivial op receives the intersection of the EDGE
      values recorded by every trace.

    Args:
        *traces: One or more trace attribute maps recorded on ``graph``.
        graph: Graph the traces were recorded on.
        filter_func: Optional predicate forwarded unchanged to
            ``reconstruct_trace_path_with_hook``.
        compact: If true, edges are intersected with ``np.bitwise_and``
            (presumably a packed/bitmask encoding -- TODO confirm); otherwise
            each edge is converted with ``TraceKey.to_array`` and the results
            are intersected with ``np.intersect1d``. Also forwarded to
            ``reconstruct_trace_path_with_hook``.

    Returns:
        The attribute map of the reconstructed graph.

    Raises:
        ValueError: If no traces are given, or if the traces disagree on the
            output point of a visited tensor.
    """
    if not traces:
        raise ValueError("get_trace_path_intersection requires at least one trace")
    first_trace = traces[0]

    def set_output_point(tensor: Tensor):
        points = [trace.tensors[tensor.name][TraceKey.POINT] for trace in traces]
        # Compare every point pairwise against the first one. Chaining `==`
        # through reduce() would compare a boolean array with the next point
        # for 3+ traces, and for a single trace would test the point's values
        # for truthiness (failing on any zero).
        if not all(np.array_equal(points[0], point) for point in points[1:]):
            raise ValueError(
                f"traces disagree on the output point of tensor {tensor.name!r}"
            )
        tensor.attrs[TraceKey.POINT] = points[0]

    def set_edge_intersection(op: Operation):
        # Ops reported trivial by TraceKey.is_trivial are left untouched.
        if TraceKey.is_trivial(op):
            return
        edges = [trace.ops[op.name][TraceKey.EDGE] for trace in traces]
        if compact:
            edge_intersection = reduce(np.bitwise_and, edges)
        else:
            edge_intersection = reduce(
                np.intersect1d, map(TraceKey.to_array, edges)
            )
        op.attrs[TraceKey.EDGE] = edge_intersection

    # Seed with only the META attributes of the first trace; the hooks above
    # write POINT / EDGE as reconstruction visits tensors and ops.
    new_graph = graph.with_attrs(TraceKey.filter_key(TraceKey.META, first_trace))
    reconstruct_trace_path_with_hook(
        new_graph,
        on_enter_output_tensor=set_output_point,
        on_enter_op=set_edge_intersection,
        filter_func=filter_func,
        compact=compact,
    )
    return new_graph.attrs_to_map()