Example #1
0
def get_trace_path(
    graph: Graph,
    trace: AttrMap,
    filter_func: Callable[[Operation], bool] = None,
    compact: bool = False,
) -> AttrMap:
    """Return the path-related attributes obtained by replaying ``trace``.

    ``trace`` is attached to ``graph`` via ``graph.with_attrs`` and the result
    is handed to ``reconstruct_trace_path_with_hook`` with hooks that do
    nothing. Only the reconstruction's own writes to the annotated graph
    affect the result.

    Args:
        graph: Graph to annotate with ``trace``.
        trace: Attribute map to attach to the graph.
        filter_func: Op predicate forwarded unchanged to
            ``reconstruct_trace_path_with_hook`` (may be ``None``).
        compact: Forwarded unchanged to ``reconstruct_trace_path_with_hook``.

    Returns:
        The annotated graph's attribute map, restricted by
        ``TraceKey.filter_key`` to the ``TraceKey.META`` keys plus
        ``TraceKey.PATH``.
    """
    def ignore(_):
        # No per-tensor or per-op work is needed here; the hooks are no-ops.
        return None

    traced_graph = graph.with_attrs(trace)
    reconstruct_trace_path_with_hook(
        traced_graph,
        on_enter_output_tensor=ignore,
        on_enter_op=ignore,
        filter_func=filter_func,
        compact=compact,
    )
    kept_keys = TraceKey.META | {TraceKey.PATH}
    attrs = traced_graph.attrs_to_map()
    return TraceKey.filter_key(kept_keys, attrs)
Example #2
0
def get_trace_path_intersection(
    *traces: AttrMap,
    graph: Graph,
    filter_func: Callable[[Operation], bool] = None,
    compact: bool = False,
) -> AttrMap:
    """Reconstruct the trace path common to all of ``traces``.

    A copy of ``graph`` is annotated with the ``TraceKey.META`` attributes of
    the first trace. ``reconstruct_trace_path_with_hook`` then runs with two
    hooks:

    * on each output tensor, every trace must agree on the tensor's
      ``TraceKey.POINT``; that shared point is copied onto the tensor;
    * on each op not reported trivial by ``TraceKey.is_trivial``, the op's
      ``TraceKey.EDGE`` is set to the intersection of that op's edges across
      all traces.

    Args:
        *traces: One or more traces to intersect.
        graph: Graph to annotate (keyword-only).
        filter_func: Op predicate forwarded unchanged to
            ``reconstruct_trace_path_with_hook`` (may be ``None``).
        compact: If true, edges are combined with ``np.bitwise_and``
            (presumably a bitmask encoding — confirm); otherwise each edge is
            converted with ``TraceKey.to_array`` and the arrays are
            intersected with ``np.intersect1d``.

    Returns:
        The attribute map of the annotated graph.

    Raises:
        ValueError: If no traces are given.
        AssertionError: If the traces disagree on an output tensor's point.
    """
    if not traces:
        raise ValueError(
            "get_trace_path_intersection requires at least one trace")
    first_trace = traces[0]

    def set_output_point(tensor: Tensor):
        first_point = first_trace.tensors[tensor.name][TraceKey.POINT]
        # Compare each trace against the first one. Chaining `==` through
        # reduce() would compare a boolean mask with raw point values once
        # there are three or more traces, and with a single trace it would
        # reject any point containing a zero.
        assert all(
            np.array_equal(first_point,
                           trace.tensors[tensor.name][TraceKey.POINT])
            for trace in traces[1:]
        ), f"traces disagree on the output point of tensor {tensor.name!r}"
        tensor.attrs[TraceKey.POINT] = first_point

    def set_edge_intersection(op: Operation):
        if TraceKey.is_trivial(op):
            return
        edges = [trace.ops[op.name][TraceKey.EDGE] for trace in traces]
        if compact:
            edge_intersection = reduce(np.bitwise_and, edges)
        else:
            edge_intersection = reduce(np.intersect1d,
                                       map(TraceKey.to_array, edges))
        op.attrs[TraceKey.EDGE] = edge_intersection

    new_graph = graph.with_attrs(
        TraceKey.filter_key(TraceKey.META, first_trace))
    reconstruct_trace_path_with_hook(
        new_graph,
        on_enter_output_tensor=set_output_point,
        on_enter_op=set_edge_intersection,
        filter_func=filter_func,
        compact=compact,
    )
    return new_graph.attrs_to_map()