Beispiel #1
0
    def expansion(node, parent_state: SDFGState, parent_sdfg: SDFG):
        """Expand ``node`` into a nested SDFG that computes ``inp + 1``.

        The nested SDFG exposes two non-transient containers: ``_a``, a copy
        of the descriptor on the node's first in-edge, and ``_b``, a copy of
        the descriptor on the node's first out-edge. A map named
        "useless_map" over ``i = 0`` wraps a tasklet ``outp = inp + 1`` that
        reads ``_a`` (full-array memlet) and writes ``_b[0]``.

        :param node: The library node being expanded. Only its first in-edge
                     and first out-edge are inspected.
        :param parent_state: The state containing ``node``.
        :param parent_sdfg: The SDFG owning ``parent_state``. Used to look up
                            the descriptors of the connected data.
        :return: The newly built nested SDFG.
        """
        in_edge = parent_state.in_edges(node)[0]
        out_edge = parent_state.out_edges(node)[0]

        sdfg = dace.SDFG("nested")
        # Deep copies, so that resetting ``transient`` below cannot modify the
        # descriptors stored in the parent SDFG.
        sdfg.add_datadesc("_a",
                          copy.deepcopy(parent_sdfg.arrays[in_edge.data.data]))
        sdfg.add_datadesc(
            "_b", copy.deepcopy(parent_sdfg.arrays[out_edge.data.data]))
        # Inside the nested SDFG both containers are passed in from outside.
        sdfg.arrays["_a"].transient = False
        sdfg.arrays["_b"].transient = False
        state = sdfg.add_state()

        inp = state.add_access("_a")
        outp = state.add_access("_b")

        # Trivial map over the single index i = 0.
        me, mx = state.add_map("useless_map", {"i": "0"})

        tasklet = state.add_tasklet("add", {"inp"}, {"outp"}, "outp = inp + 1")

        state.add_edge(inp, None, me, None, sdfg.make_array_memlet("_a"))
        state.add_edge(me, None, tasklet, "inp", sdfg.make_array_memlet("_a"))

        state.add_edge(tasklet, "outp", mx, None, dace.Memlet("_b[0]"))
        state.add_edge(mx, None, outp, None, dace.Memlet("_b[0]"))
        # The scope edges above were added with ``None`` connectors.
        # Presumably this creates the matching map entry/exit connectors
        # (TODO confirm).
        sdfg.fill_scope_connectors()

        return sdfg
Beispiel #2
0
    def expressions():
        """Return the list of pattern graphs this transformation matches.

        There is a single pattern: a state holding only the pattern node
        ``MergeSourceSinkArrays._array1``.
        """
        # Matching
        #   o  o
        pattern_state = SDFGState()
        pattern_state.add_node(MergeSourceSinkArrays._array1)
        patterns = [pattern_state]
        return patterns
Beispiel #3
0
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[pm.PatternNode, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       permissive: bool = False):
        """Check that the matched map only produces single-element WCR outputs.

        Rules:

        * Every memlet leaving the map exit must carry a write-conflict
          resolution (``wcr``).
        * Outputs into scalars are always accepted.
        * Outputs into arrays or views are accepted only if every written
          subset has exactly one element.
        * Any other descriptor type rejects the match.

        :param graph: State containing the candidate map.
        :param candidate: Mapping from pattern nodes to node ids in ``graph``.
        :param expr_index: Unused.
        :param sdfg: SDFG owning ``graph``. Used to resolve data descriptors.
        :param permissive: Unused.
        :return: True if the transformation can be applied.
        """
        map_entry = graph.node(candidate[Reduction1Operation.map_entry])
        map_exit = graph.exit_node(map_entry)

        # Group written subsets by the descriptor they target.
        outputs = {}
        for *_, memlet in graph.out_edges(map_exit):
            if not memlet.wcr:
                return False
            outputs.setdefault(sdfg.arrays[memlet.data],
                               []).append(memlet.subset)

        for desc, accesses in outputs.items():
            if isinstance(desc, dace.data.Scalar):
                continue
            if not isinstance(desc, (dace.data.Array, dace.data.View)):
                return False
            if any(a.num_elements() != 1 for a in accesses):
                return False

        return True
Beispiel #4
0
def fusion(sdfg: dace.SDFG,
           graph: dace.SDFGState,
           subgraph: Union[SubgraphView, List[SubgraphView]] = None,
           **kwargs):
    """Fuse the outermost maps of one or more subgraphs with SubgraphFusion.

    A single ``SubgraphFusion`` instance is built from the first subgraph.
    Each keyword argument is set on it as a property, and the instance is
    then used to fuse the outermost maps of every subgraph in turn. When a
    subgraph is a ``SubgraphView``, the fused map entries/exits are removed
    from the view and the resulting global map entry is appended to it.

    :param sdfg: The SDFG containing ``graph``.
    :param graph: The state to operate on. Also used as the subgraph when
                  ``subgraph`` is falsy.
    :param subgraph: A subgraph or list of subgraphs whose maps to fuse.
    :param kwargs: Property names and values to set on the transformation.
    """
    subgraph = graph if not subgraph else subgraph
    if not isinstance(subgraph, list):
        subgraph = [subgraph]

    map_fusion = SubgraphFusion(subgraph[0])
    # ``prop_name`` avoids shadowing the ``property`` builtin.
    for prop_name, prop_value in kwargs.items():
        setattr(map_fusion, prop_name, prop_value)

    for sg in subgraph:
        map_entries = helpers.get_outermost_scope_maps(sdfg, graph, sg)
        # remove map_entries and their corresponding exits from the subgraph
        # already before applying transformation
        if isinstance(sg, SubgraphView):
            for map_entry in map_entries:
                sg.nodes().remove(map_entry)
                map_exit = graph.exit_node(map_entry)
                if map_exit in sg.nodes():
                    sg.nodes().remove(map_exit)
        print(f"Subgraph Fusion on map entries {map_entries}")
        map_fusion.fuse(sdfg, graph, map_entries)
        if isinstance(sg, SubgraphView):
            # Replace the removed maps with the single fused map entry.
            sg.nodes().append(map_fusion._global_map_entry)
Beispiel #5
0
def get_node_name_mapping(state: dace.SDFGState, node: dace.nodes.LibraryNode):
    """Map a library node's internal field names to the outer data names.

    Input connectors must be named ``IN_<field>`` and output connectors
    ``OUT_<field>``. Edges with no connector on the node side are ignored.
    A field that appears on several edges (e.g. as both input and output)
    must always refer to the same outer data.

    :param state: The state containing ``node``.
    :param node: The library node whose connectors are inspected.
    :return: Dict from internal field name (prefix stripped) to the outer
             data name (``edge.data.data``) of the connected memlet.
    :raises AssertionError: If a connector lacks the expected prefix, or if
                            one field is bound to different arrays.
    """
    name_mapping = dict()
    # (edges, attribute holding the node-side connector, required prefix)
    edge_groups = (
        (state.in_edges(node), "dst_conn", "IN_"),
        (state.out_edges(node), "src_conn", "OUT_"),
    )
    for edges, conn_attr, prefix in edge_groups:
        for edge in edges:
            conn = getattr(edge, conn_attr)
            if conn is None:
                continue
            assert conn.startswith(prefix)
            internal_name = conn[len(prefix):]
            outer_name = edge.data.data
            if internal_name not in name_mapping:
                name_mapping[internal_name] = outer_name
            else:
                msg = (f"input and output of field '{internal_name}' to node "
                       f"'{node.name}' refer to different arrays")
                assert name_mapping[internal_name] == outer_name, msg
    return name_mapping
Beispiel #6
0
    def __init__(self,
                 sdfg: SDFG,
                 state: SDFGState,
                 map_entry: nodes.MapEntry,
                 vec_len,
                 initial_constraints: Dict[Union[Tuple[nodes.Tasklet, str,
                                                       bool],
                                                 nodes.AccessNode],
                                           int] = None,
                 flags: VectorInferenceFlags = None):
        """
            Builds a vector inference graph for a Map to infer vectorizable Tasklet connectors
            and AccessNodes in polynomial time.

            Construction order: the internal graph is built first, then
            constraints are detected, and only then are ``initial_constraints``
            applied via ``set_constraint``.

            :param sdfg: The SDFG where the Map resides.
            :param state: The state where the Map resides.
            :param map_entry: The entry node of the Map.
            :param vec_len: The vector length that should be used when creating a `dtypes.vector`.
            :param initial_constraints: A dictionary mapping from a connector specified using `(node, name, is_input)`
                                        or an `AccessNode` to either `InferenceNode.Scalar` or `InferenceNode.Vector`.
                                        May be ``None`` (no initial constraints).
            :param flags: Additional flags to limit the vectorization.
        """
        super().__init__()
        self.sdfg = sdfg
        self.state = state

        # Map body only (entry and exit excluded)
        self.subgraph = state.scope_subgraph(map_entry,
                                             include_entry=False,
                                             include_exit=False)

        # Map body including the entry and exit nodes
        self.subgraph_with_scope = state.scope_subgraph(map_entry)

        self.map = map_entry.map

        # Infer connectors on the entire subgraph (including the entry and exit)
        self.inf: infer_types.TypeInferenceDict = infer_types.infer_connector_types(
            sdfg, state, self.subgraph_with_scope)

        # Use the innermost loop param
        self.param = self.map.params[-1]

        self.vec_len = vec_len

        # Stores a mapping from SDFG nodes/connectors to InferenceNode's
        # Used when constructing the internal inference graph
        # (unknown keys default to None)
        self.conn_to_node = DefaultDict[Union[Tuple[nodes.Tasklet, str, bool],
                                              nodes.AccessNode],
                                        InferenceNode](lambda: None)

        self.flags = flags

        self._build()
        self._detect_constraints()

        # User-supplied constraints are applied last, after the detected ones.
        if initial_constraints is not None:
            for n, t in initial_constraints.items():
                self.set_constraint(n, t)
Beispiel #7
0
def unwire_access_node(
    state: SDFGState,
    left: HorizontalExecutionLibraryNode,
    access: dace.nodes.AccessNode,
    right: HorizontalExecutionLibraryNode,
) -> None:
    """Disconnect ``access`` from ``right``.

    Every distinct edge from ``access`` to ``right`` is removed with
    ``state.remove_edge_and_connectors``. ``left`` is accepted for interface
    symmetry but is not used.

    :param state: The state containing the nodes.
    :param left: Unused.
    :param access: The access node whose outgoing edges to ``right`` are removed.
    :param right: The destination node of the edges to remove.
    """
    # Materialize (and de-duplicate) the edges before mutating the graph.
    for edge in set(state.edges_between(access, right)):
        state.remove_edge_and_connectors(edge)
Beispiel #8
0
    def can_be_applied(self,
                       graph: dace.SDFGState,
                       expr_index: int,
                       sdfg: dace.SDFG,
                       permissive: bool = False):
        """Check that the map writes single elements indexed by its parameters.

        Any failed condition returns False:

        * Every non-empty write into the map exit without WCR must target a
          subset that the map also reads (out-edges of the map entry) from
          the same descriptor.
        * Every written descriptor must be an Array or View.
        * Every written subset must have exactly one element, and its
          minimum-element indices must include at least one map parameter
          (compared as ``dace.symbol`` objects).

        :param graph: State containing ``self.map_entry``.
        :param expr_index: Unused.
        :param sdfg: SDFG owning ``graph``. Used to resolve data descriptors.
        :param permissive: Unused.
        :return: True if all conditions hold.
        """
        entry = self.map_entry
        exit_node = graph.exit_node(entry)
        map_symbols = [dace.symbol(p) for p in entry.map.params]

        # Subsets read inside the map, grouped by descriptor.
        read_subsets = {}
        for *_, memlet in graph.out_edges(entry):
            if not memlet.data:
                continue
            read_subsets.setdefault(sdfg.arrays[memlet.data],
                                    []).append(memlet.subset)

        # Subsets written by the map, grouped by descriptor.
        written_subsets = {}
        for *_, memlet in graph.in_edges(exit_node):
            if memlet.is_empty():
                continue
            desc = sdfg.arrays[memlet.data]
            if not memlet.wcr:
                # A plain write must match a subset read from the same data.
                candidates = read_subsets.get(desc, ())
                if not any(subset == memlet.subset for subset in candidates):
                    return False
            written_subsets.setdefault(desc, []).append(memlet.subset)

        for desc, subsets in written_subsets.items():
            if not isinstance(desc, (dace.data.Array, dace.data.View)):
                return False
            for subset in subsets:
                if subset.num_elements() != 1:
                    return False
                unmatched = set(map_symbols)
                unmatched.difference_update(subset.min_element())
                # None of the map parameters appears among the indices.
                if len(unmatched) == len(map_symbols):
                    return False

        return True
Beispiel #9
0
 def validate(self, sdfg: dace.SDFG, state: dace.SDFGState):
     """Validate this gearbox node and classify it as a pack or an unpack.

     The checks are:

     * ``self.size``, when it evaluates to a constant under
       ``sdfg.constants``, is at least 1.
     * The node has exactly one input edge and one output edge.
     * Both edges connect to streams.
     * One side's element type is a vector of the other's, or both sides
       are vectors whose lengths divide evenly (with a ratio above 1).

     :param sdfg: The SDFG containing ``state``.
     :param state: The state containing this node.
     :return: ``(in_edge, in_desc, out_edge, out_desc, is_pack, gear_factor)``.
              ``is_pack`` is True when the output side has the wider vector
              type. ``gear_factor`` is the ratio between the two widths.
     :raises ValueError: On an invalid constant size or a wrong edge count.
     :raises TypeError: If a side is not a stream or the types are incompatible.
     """
     try:
         size = dace.symbolic.evaluate(self.size, sdfg.constants)
         if size < 1:
             raise ValueError(f"Invalid size parameter for {self}: {size}")
     except TypeError:
         pass  # Not a constant
     in_edge = state.in_edges(self)
     if len(in_edge) != 1:
         raise ValueError(
             f"Expected only one input edge, found {len(in_edge)} edges.")
     out_edge = state.out_edges(self)
     if len(out_edge) != 1:
         raise ValueError(
             f"Expected only one output edge, found {len(out_edge)} edges.")
     in_edge = in_edge[0]
     in_desc = sdfg.arrays[in_edge.data.data]
     if not isinstance(in_desc, dace.data.Stream):
         raise TypeError(
             f"Expected input to be a stream, got {type(in_desc)}.")
     out_edge = out_edge[0]
     out_desc = sdfg.arrays[out_edge.data.data]
     if not isinstance(out_desc, dace.data.Stream):
         raise TypeError(
             f"Expected output to be a stream, got {type(out_desc)}.")
     # The type of one side must be a vector of the other, or a vector of the
     # same type with a vector size that is a multiple of the other
     if (isinstance(in_desc.dtype, dace.vector)
             and in_desc.dtype.base_type == out_desc.dtype):
         is_pack = False  # Is unpack
         gear_factor = in_desc.dtype.veclen
     elif (isinstance(out_desc.dtype, dace.vector)
           and out_desc.dtype.base_type == in_desc.dtype):
         is_pack = True
         gear_factor = out_desc.dtype.veclen
     elif (isinstance(in_desc.dtype, dace.vector)
           and isinstance(out_desc.dtype, dace.vector)
           and in_desc.veclen // out_desc.veclen > 1
           and in_desc.veclen % out_desc.veclen == 0):
         is_pack = False  # Is unpack
         gear_factor = in_desc.veclen // out_desc.veclen
     elif (isinstance(in_desc.dtype, dace.vector)
           and isinstance(out_desc.dtype, dace.vector)
           and out_desc.veclen // in_desc.veclen > 1
           and out_desc.veclen % in_desc.veclen == 0):
         is_pack = True
         gear_factor = out_desc.veclen // in_desc.veclen
     else:
         raise TypeError(
             f"Cannot gearbox between {in_desc.dtype} for {in_edge.dst_conn}"
             f" and {out_desc.dtype} for {out_edge.src_conn}.")
     return (in_edge, in_desc, out_edge, out_desc, is_pack, gear_factor)
Beispiel #10
0
    def apply(self, state: SDFGState, sdfg: SDFG):
        """Remove the candidate data returned by ``self._candidates``.

        Processing order:

        1. Outer memlet paths that leave the nested SDFG through a candidate
           connector are removed, along with the outer data they reference.
        2. Inside the nested SDFG, the memlet paths entering each candidate
           node are removed. Those nodes are presumably the access nodes of
           the candidates (TODO confirm).
        3. The candidate containers are removed from the nested SDFG.

        :param state: The state containing the nested SDFG node.
        :param sdfg: The SDFG owning ``state``.
        """
        nested = self.nsdfg
        cand_names, cand_accesses = self._candidates(nested)

        for edge in state.out_edges(nested):
            if edge.src_conn not in cand_names:
                continue
            state.remove_memlet_path(edge)
            sdfg.remove_data(edge.data.data, validate=False)

        for inner_state, inner_node in cand_accesses:
            for edge in inner_state.in_edges(inner_node):
                inner_state.remove_memlet_path(edge)

        for name in cand_names:
            nested.sdfg.remove_data(name, validate=False)
Beispiel #11
0
def gemv_libnode(sdfg: SDFG,
                 state: SDFGState,
                 A,
                 B,
                 C,
                 alpha,
                 beta,
                 trans_a=False,
                 trans_b=False):
    """Insert a ``Gemm`` library node that reads ``A``/``B`` and writes ``C``.

    NOTE(review): despite its name, this function creates a ``Gemm`` node,
    not a GEMV node. Confirm this is intended.

    :param sdfg: The SDFG containing ``state``.
    :param state: The state to add the nodes to.
    :param A: Name of the data connected to the ``_a`` input.
    :param B: Name of the data connected to the ``_b`` input.
    :param C: Name of the data written through the ``_c`` output.
    :param alpha: Passed to ``Gemm`` as ``alpha``.
    :param beta: Passed to ``Gemm`` as ``beta``.
    :param trans_a: Passed to ``Gemm`` as ``transA``.
    :param trans_b: Passed to ``Gemm`` as ``transB``.
    :return: An empty list.
    """
    a_read = state.add_read(A)
    b_read = state.add_read(B)
    c_write = state.add_write(C)

    gemm_node = Gemm('gemm',
                     transA=trans_a,
                     transB=trans_b,
                     alpha=alpha,
                     beta=beta)
    state.add_node(gemm_node)

    for read_node, conn, name in ((a_read, '_a', A), (b_read, '_b', B)):
        state.add_edge(read_node, None, gemm_node, conn, mm.Memlet(name))
    state.add_edge(gemm_node, '_c', c_write, None, mm.Memlet(C))

    # TODO: Bring back C as input connector if beta is not 0
    # if beta != 0:
    #     C_in = state.add_read(C)
    #     state.add_edge(C_in, None, libnode, '_c', mm.Memlet(C))

    return []
Beispiel #12
0
    def op_repo_replacement(sdfg: SDFG, state: SDFGState, **kwargs):
        """Replace a frontend call with an ONNX library node of class ``cls``.

        Keyword arguments are split using ``dace_schema``:

        * Names in ``dace_schema.attributes`` become node attributes.
        * Names of schema inputs are read from the given arrays.
        * Names of schema outputs are written to the given arrays.
        * Any other keyword argument is ignored.

        :param sdfg: The SDFG containing ``state``.
        :param state: The state to add the node and its access nodes to.
        :param kwargs: Attribute values and input/output array names.
        :return: An empty list.
        """
        attrs = {
            name: value
            for name, value in kwargs.items() if name in dace_schema.attributes
        }
        onnx_node = cls(name=cls_name, **attrs)
        state.add_node(onnx_node)

        input_names = {p.name for p in dace_schema.inputs}
        output_names = {p.name for p in dace_schema.outputs}
        inputs = {
            name: arr_name
            for name, arr_name in kwargs.items() if name in input_names
        }
        outputs = {
            name: arr_name
            for name, arr_name in kwargs.items() if name in output_names
        }

        for inp, arr_name in inputs.items():
            read = state.add_read(arr_name)
            state.add_edge(read, None, onnx_node, inp,
                           sdfg.make_array_memlet(arr_name))

        for outp, arr_name in outputs.items():
            # Outputs are written by the node, so they need a write access node.
            write = state.add_write(arr_name)
            state.add_edge(onnx_node, outp, write, None,
                           sdfg.make_array_memlet(arr_name))
        return []
Beispiel #13
0
def count_matmul(node: MatMul, symbols: Dict[str, Any],
                 state: dace.SDFGState) -> int:
    """Count the arithmetic operations of a MatMul library node.

    Each element of the output C costs one multiply and one add per element
    of the contraction dimension K, which is taken from the last dimension of
    A's subset. Every dimension of C before the trailing M x N pair is
    treated as a batch dimension and multiplies the count.

    :param node: The MatMul node. Its ``_a``/``_b`` inputs and ``_c`` output
                 edges are looked up in ``state``.
    :param symbols: Symbol values passed to ``symeval`` to evaluate sizes.
    :param state: The state containing ``node``.
    :return: ``2 * batch * M * N * K``.
    :raises StopIteration: If the ``_a``, ``_b`` or ``_c`` edge is missing.
    """
    A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a')
    # B's sizes are not needed, but the lookup is kept so that a node
    # without a '_b' input still fails loudly.
    next(e for e in state.in_edges(node) if e.dst_conn == '_b')
    C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c')
    c_size = C_memlet.data.subset.size()
    result = 2  # Multiply, add
    # Batch: every dimension of C before the trailing M x N matrix
    # (previously only a single batch dimension of a 3-D C was counted).
    for batch_extent in c_size[:-2]:
        result *= symeval(batch_extent, symbols)
    # M*N
    result *= symeval(c_size[-2], symbols)
    result *= symeval(c_size[-1], symbols)
    # K
    result *= symeval(A_memlet.data.subset.size()[-1], symbols)
    return result
Beispiel #14
0
def expand_reduce(sdfg: dace.SDFG,
                  graph: dace.SDFGState,
                  subgraph: Union[SubgraphView, List[SubgraphView]] = None,
                  **kwargs):
    """Expand every applicable Reduce node in one or more subgraphs.

    For each subgraph, Reduce nodes that fail
    ``ReduceExpansion.can_be_applied`` are skipped with a printed warning.
    The remaining nodes are expanded by a ``ReduceExpansion`` instance that
    has each keyword argument set on it as a property. When the subgraph is
    a ``SubgraphView``, each expanded Reduce node is replaced in the view by
    the transformation's ``_reduce`` and ``_outer_entry`` nodes.

    :param sdfg: The SDFG containing ``graph``.
    :param graph: The state to operate on. Also used as the subgraph when
                  ``subgraph`` is falsy.
    :param subgraph: A subgraph or list of subgraphs to search for Reduce nodes.
    :param kwargs: Property names and values to set on the expansion.
    """
    subgraph = graph if not subgraph else subgraph
    if not isinstance(subgraph, list):
        subgraph = [subgraph]

    for sg in subgraph:
        reduce_nodes = []
        for node in sg.nodes():
            if not isinstance(node, stdlib.Reduce):
                continue
            rexp = ReduceExpansion(sdfg, sdfg.sdfg_id, sdfg.node_id(graph),
                                   {ReduceExpansion.reduce: graph.node_id(node)}, 0)
            if not rexp.can_be_applied(graph, 0, sdfg):
                print(f"WARNING: Cannot expand reduce node {node}: "
                      "can_be_applied() failed.")
                continue
            reduce_nodes.append(node)

        trafo_reduce = ReduceExpansion(sdfg, sdfg.sdfg_id, sdfg.node_id(graph), {}, 0)
        # ``prop_name`` avoids shadowing the ``property`` builtin.
        for prop_name, prop_value in kwargs.items():
            setattr(trafo_reduce, prop_name, prop_value)

        for reduce_node in reduce_nodes:
            trafo_reduce.expand(sdfg, graph, reduce_node)
            if isinstance(sg, SubgraphView):
                # Keep the view in sync with the nodes created by the expansion.
                sg.nodes().remove(reduce_node)
                sg.nodes().append(trafo_reduce._reduce)
                sg.nodes().append(trafo_reduce._outer_entry)
Beispiel #15
0
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[Any, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       strict=False):
        """Check whether a stencil inside a 1-D map can be nested along the
        dimension indexed by the map parameter.

        Requirements:

        * The map has exactly one parameter.
        * The map has no dynamic map inputs.
        * Every out-edge of the map entry leads to the matched stencil.
        * Among the stencil's 3-dimensional memlets, the map parameter
          appears in one and the same dimension.
        * The stencil's shape is 1 in that dimension.

        :param graph: State containing the candidate.
        :param candidate: Mapping from pattern nodes to node ids in ``graph``.
        :param expr_index: Unused.
        :param sdfg: SDFG owning ``graph``. Its constants are used when
                     deciding whether to warn about a nontrivial index.
        :param strict: Unused.
        :return: True if the transformation can be applied.
        """
        map_entry: nodes.MapEntry = graph.node(candidate[NestK._map_entry])
        stencil: Stencil = graph.node(candidate[NestK._stencil])

        if len(map_entry.map.params) != 1:
            return False
        if sd.has_dynamic_map_inputs(graph, map_entry):
            return False
        pname = map_entry.map.params[0]  # Usually "k"
        # Dimension in which ``pname`` appears. It must be the same across memlets.
        dim_index = None

        # The map must feed nothing but the stencil.
        for edge in graph.out_edges(map_entry):
            if edge.dst != stencil:
                return False

        for edge in graph.all_edges(stencil):
            if edge.data.data is None:  # Empty memlet
                continue
            # TODO: Use bitmap to verify lower-dimensional arrays
            # Memlets that are not 3-dimensional are not inspected.
            if len(edge.data.subset) == 3:
                for i, rng in enumerate(edge.data.subset.ndrange()):
                    # ``rng`` is presumably a (begin, end, step) triple — TODO confirm
                    for r in rng:
                        if pname in map(str, r.free_symbols):
                            if dim_index is not None and dim_index != i:
                                # k dimension must match in all memlets
                                return False
                            if str(r) != pname:
                                # Only a warning: emitted when r - pname still
                                # contains symbols that are not SDFG constants.
                                if symbolic.issymbolic(
                                        r - symbolic.symbol(pname),
                                        sdfg.constants):
                                    warnings.warn('k expression is nontrivial')
                            dim_index = i

        # No nesting dimension found
        if dim_index is None:
            return False

        # Ensure the stencil shape is 1 for the found dimension
        if stencil.shape[dim_index] != 1:
            return False

        return True
Beispiel #16
0
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[Any, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       strict=False):
        """Check whether two stencils linked by a transient can be fused.

        Requirements, checked in this order:

        * Both stencils have shapes of equal rank and equal extents.
        * The intermediate array has exactly two edges in ``graph``, is
          transient, and has at most one access node in the whole SDFG.
        * The second stencil reads the array through at most one access,
          and only at offset 0.
        * Both stencils use the same code language.

        :param graph: State containing the candidate.
        :param candidate: Mapping from pattern nodes to node ids in ``graph``.
        :param expr_index: Unused.
        :param sdfg: SDFG owning ``graph``.
        :param strict: Unused.
        :return: True if the transformation can be applied.
        """
        first = graph.node(candidate[StencilFusion._stencil_a])
        second = graph.node(candidate[StencilFusion._stencil_b])
        tmp: nodes.AccessNode = graph.node(candidate[StencilFusion._tmp_array])

        # Stencil shapes must agree in rank and extent.
        shape_a, shape_b = first.shape, second.shape
        if len(shape_a) != len(shape_b) or any(
                sa != sb for sa, sb in zip(shape_a, shape_b)):
            return False

        # The transient may only be used here, so that it can be removed.
        if len(graph.all_edges(tmp)) != 2:
            return False
        if not sdfg.arrays[tmp.data].transient:
            return False
        access_count = sum(
            1 for st in sdfg.nodes() for n in st.nodes()
            if isinstance(n, nodes.AccessNode) and n.data == tmp.data)
        if access_count > 1:
            return False

        # The second stencil may read the transient through one access only.
        read_edge = graph.out_edges(tmp)[0]
        offsets = second.accesses[read_edge.dst_conn][1]
        if len(offsets) > 1:
            return False

        # TODO: Remove check once stencils can be offset
        if any(off != 0 for off in offsets[0]):
            return False

        # Code languages must match
        if first.code.language != second.code.language:
            return False

        # TODO: Boundary condition matching checks

        return True
Beispiel #17
0
    def op_repo_replacement(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState,
                            **kwargs):
        """Replace a frontend call with an ONNX library node of class ``cls``.

        Keyword arguments are consumed in three passes:

        1. Schema attributes.
        2. Inputs, either non-variadic names or ``<name>__<i>`` variadic
           names whose base is a variadic input.
        3. Outputs, matched the same way against the schema outputs.

        Any leftover keyword argument is an error. Each input is connected
        from a read access node and each output to a write access node, and
        the matching connectors are added to the node.

        :param pv: Unused.
        :param sdfg: The SDFG containing ``state``.
        :param state: The state to add the node and its access nodes to.
        :param kwargs: Attribute values and input/output array names.
        :return: An empty list.
        :raises TypeError: If some keyword arguments match neither an
                           attribute, an input, nor an output.
        """
        attrs = {
            name: value
            for name, value in kwargs.items() if name in dace_schema.attributes
        }
        # remove used attrs
        kwargs = {k: v for k, v in kwargs.items() if k not in attrs}

        onnx_node = cls(name=cls_name, **attrs)
        state.add_node(onnx_node)

        input_names = dace_schema.non_variadic_inputs()
        variadic_inputs = dace_schema.variadic_inputs()

        output_names = dace_schema.non_variadic_outputs()
        variadic_outputs = dace_schema.variadic_outputs()

        inputs = {
            name: arr_name
            for name, arr_name in kwargs.items()
            if (name in input_names or
                # variadic params
                ("__" in name
                 and parse_variadic_param(name)[0] in variadic_inputs))
        }

        kwargs = {k: v for k, v in kwargs.items() if k not in inputs}

        outputs = {
            name: arr_name
            for name, arr_name in kwargs.items()
            if (name in output_names or
                # variadic params
                ("__" in name
                 and parse_variadic_param(name)[0] in variadic_outputs))
        }

        kwargs = {k: v for k, v in kwargs.items() if k not in outputs}

        if len(kwargs) > 0:
            raise TypeError(f"Unknown arguments {', '.join(kwargs)}")

        for inp, arr_name in inputs.items():
            read = state.add_read(arr_name)
            state.add_edge(read, None, onnx_node, inp,
                           sdfg.make_array_memlet(arr_name))
            onnx_node.add_in_connector(inp)

        for outp, arr_name in outputs.items():
            # Outputs are written by the node, so they need a write access node.
            write = state.add_write(arr_name)
            state.add_edge(onnx_node, outp, write, None,
                           sdfg.make_array_memlet(arr_name))
            onnx_node.add_out_connector(outp)
        return []
Beispiel #18
0
    def iter_edges(
            self,
            state: SDFGState) -> Iterator[Tuple[MultiConnectorEdge, bool]]:
        """ Returns an iterator over tuples of an edge and a boolean that indicates whether that edge is an input,
            ordered by the order required by the schema.
            This method assumes that this node has been validated.

            :param state: the state containing this node.
            :return: the input edges (paired with ``True``) in schema order,
                     followed by the output edges (paired with ``False``).
            :raises ValueError: if a connector's base name does not match
                                exactly one schema parameter.
        """
        in_edges: List[MultiConnectorEdge] = state.in_edges(self)
        out_edges: List[MultiConnectorEdge] = state.out_edges(self)

        def get_idx(parameters, name):
            # Sort key for connector ``name``: the position of its schema
            # parameter. A variadic connector ("<param>__<n>", split by
            # parse_variadic_param) is offset by n from <param>'s position.
            if '__' in name:
                name, number = parse_variadic_param(name)
            else:
                number = 0

            matched = [
                i for i, param in enumerate(parameters) if param.name == name
            ]

            # Validation should guarantee exactly one match; fail loudly if not.
            if len(matched) != 1:
                raise ValueError(
                    "Found {} connectors with name '{}', expected to find exactly one"
                    .format(len(matched), name))

            # add on the variadic parameter index
            return matched[0] + number

        sorted_in = sorted(
            in_edges,
            key=lambda edge: get_idx(self.schema.inputs, edge.dst_conn))
        sorted_out = sorted(
            out_edges,
            key=lambda edge: get_idx(self.schema.outputs, edge.src_conn))

        return itertools.chain(zip(sorted_in, itertools.repeat(True)),
                               zip(sorted_out, itertools.repeat(False)))
Beispiel #19
0
def create_zero_initialization(init_state: dace.SDFGState, array_name):
    """Add a mapped tasklet to ``init_state`` that writes 0 to all of ``array_name``.

    Each dimension k of the array is iterated by a map index ``i<k>`` over
    ``0:<extent>``. The tasklet writes through its ``val`` connector into a
    write access node for the array.

    :param init_state: The state to add the initialization to.
    :param array_name: Name of the array (in ``init_state.parent``) to zero.
    """
    shape = init_state.parent.arrays[array_name].shape

    write_node = init_state.add_write(array_name)

    index_names = [f"i{dim}" for dim in range(len(shape))]
    ranges = {idx: f"0:{extent!s}" for idx, extent in zip(index_names, shape)}
    element = ",".join(index_names)

    init_state.add_mapped_tasklet(
        name=array_name + "_init_tasklet",
        map_ranges=ranges,
        inputs={},
        code='val = 0',
        outputs={'val': dace.Memlet.simple(write_node.data, element)},
        output_nodes={array_name: write_node},
        external_edges=True)
Beispiel #20
0
 def can_be_applied(graph: dace.SDFGState,
                    candidate: Dict[pm.PatternNode, int],
                    expr_index: int,
                    sdfg: dace.SDFG,
                    strict: bool = False):
     """Return True if the matched map has more than one parameter.

     Map expansion only makes sense for N-dimensional maps with N > 1.

     :param graph: State containing the candidate.
     :param candidate: Mapping from pattern nodes to node ids in ``graph``.
     :param expr_index: Unused.
     :param sdfg: Unused.
     :param strict: Unused.
     """
     entry = graph.node(candidate[MapExpansion.map_entry])
     num_dims = entry.map.get_param_num()
     return num_dims > 1
Beispiel #21
0
def get_connector_edges(dfg: dace.SDFGState, node: nodes.Node, conn: str,
                        is_in_conn: bool) -> 'list[graph.Edge]':
    """Return the edges of ``node`` that attach to connector ``conn``.

    :param dfg: The state containing ``node``.
    :param node: The node whose edges are searched.
    :param conn: The connector name to match.
    :param is_in_conn: If True, match incoming edges whose ``dst_conn`` is
                       ``conn``. Otherwise, match outgoing edges whose
                       ``src_conn`` is ``conn``.
    :return: A list of matching edges (possibly empty), in the order
             ``dfg.all_edges(node)`` yields them.
    """
    # The return annotation is a list (the old one claimed a single Edge).
    # It is quoted so that no extra typing import is needed at runtime.
    if is_in_conn:
        return [
            e for e in dfg.all_edges(node)
            if e.dst == node and e.dst_conn == conn
        ]
    return [
        e for e in dfg.all_edges(node)
        if e.src == node and e.src_conn == conn
    ]
Beispiel #22
0
    def is_constant(sdfg: dace.SDFG, state: dace.SDFGState, node) -> bool:
        """Return True if ``node`` is a source access node for a clean ONNX weight.

        :param sdfg: The SDFG. It is expected to carry ``_parent_onnx_model``.
        :param state: The state containing ``node``.
        :param node: The node to test.
        :return: True only if ``node`` has no incoming edges, is an
                 ``AccessNode``, and its data is in the model's
                 ``clean_weights``.
        """
        # Anything with incoming edges is computed, not constant.
        if state.in_edges(node):
            return False

        # the ONNX importer adds a _parent_onnx_model attribute to the sdfg
        return (isinstance(node, nd.AccessNode)
                and node.data in sdfg._parent_onnx_model.clean_weights)
Beispiel #23
0
def axpy_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, a, x, y, result):
    """Insert an ``Axpy`` library node reading ``x`` and ``y`` into ``result``.

    :param pv: Unused.
    :param sdfg: Unused.
    :param state: The state to add the nodes to.
    :param a: Passed to the node as its ``a`` argument.
    :param x: Name of the data connected to ``_x``.
    :param y: Name of the data connected to ``_y``.
    :param result: Name of the data written through ``_res``.
    :return: An empty list.
    """
    x_read = state.add_read(x)
    y_read = state.add_read(y)
    res_write = state.add_write(result)

    axpy_node = Axpy('axpy', a=a)
    state.add_node(axpy_node)

    for read_node, conn, name in ((x_read, '_x', x), (y_read, '_y', y)):
        state.add_edge(read_node, None, axpy_node, conn, mm.Memlet(name))
    state.add_edge(axpy_node, '_res', res_write, None, mm.Memlet(result))

    return []
Beispiel #24
0
def dot_libnode(sdfg: SDFG, state: SDFGState, x, y, result):
    """Insert a ``Dot`` library node computing ``x . y`` into ``result``.

    The node's ``n`` is the first dimension of ``x``'s descriptor.

    :param sdfg: The SDFG holding the descriptor of ``x``.
    :param state: The state to add the nodes to.
    :param x: Name of the data connected to ``_x``.
    :param y: Name of the data connected to ``_y``.
    :param result: Name of the data written through ``_result``.
    :return: An empty list.
    """
    readers = [state.add_read(name) for name in (x, y)]
    res_write = state.add_write(result)

    dot_node = Dot('dot', n=sdfg.arrays[x].shape[0])
    state.add_node(dot_node)

    for read_node, conn, name in zip(readers, ('_x', '_y'), (x, y)):
        state.add_edge(read_node, None, dot_node, conn, mm.Memlet(name))
    state.add_edge(dot_node, '_result', res_write, None, mm.Memlet(result))

    return []
Beispiel #25
0
def count_moved_data_state(state: dace.SDFGState, symbols: Dict[str,
                                                                Any]) -> int:
    """Sum the ``num_accesses`` of edges around the top-level nodes of ``state``.

    For each node in the outermost scope:

    * Code, library and Reduce nodes contribute their in- and out-edges.
    * Scope entry nodes contribute their own in-edges plus the out-edges of
      their exit node.

    An edge already counted is skipped. Each node's inputs and outputs are
    printed with ``iprint``.

    :param state: The state to analyze.
    :param symbols: Not used by this function.
    :return: The accumulated ``num_accesses`` values.
    """
    # (An unused ``state.scope_tree()`` computation was removed here.)
    sdict = state.scope_dict(node_to_children=True)
    result = 0

    edges_counted = set()

    for node in sdict[None]:
        node_result = 0
        if isinstance(node, (CodeNode, LibraryNode, Reduce)):
            inputs = sum(e.data.num_accesses for e in state.in_edges(node)
                         if e not in edges_counted)
            outputs = sum(e.data.num_accesses for e in state.out_edges(node)
                          if e not in edges_counted)
            # Do not count edges twice
            edges_counted |= set(state.all_edges(node))

            iprint(
                type(node).__name__, node, 'inputs:', inputs, 'outputs:',
                outputs)
            node_result += inputs + outputs
        elif isinstance(node, dace.nodes.EntryNode):
            # Gather inputs from entry node
            inputs = sum(e.data.num_accesses for e in state.in_edges(node)
                         if e not in edges_counted)
            # Do not count edges twice
            edges_counted |= set(state.in_edges(node))
            # Gather outputs from exit node
            exit_node = state.exit_nodes(node)[0]
            outputs = sum(e.data.num_accesses
                          for e in state.out_edges(exit_node)
                          if e not in edges_counted)
            edges_counted |= set(state.out_edges(exit_node))
            iprint('Scope',
                   type(node).__name__, node, 'inputs:', inputs, 'outputs:',
                   outputs)
            node_result += inputs + outputs
        result += node_result
    return result
Beispiel #26
0
def create_einsum(state: dace.SDFGState,
                  map_ranges,
                  code,
                  inputs,
                  outputs=None,
                  wcr_outputs=None):
    """Add a mapped tasklet named "einsum_tasklet" to ``state``.

    Connector naming:

    * Each ``(access_node, range)`` in ``inputs`` becomes a tasklet input
      named ``<data>_inp``.
    * Each entry in ``outputs`` becomes an output named ``<data>_out``.
    * Each entry in ``wcr_outputs`` becomes an output named ``<data>_out``
      with a sum (``lambda x, y: x + y``) write-conflict resolution. These
      override a plain output with the same name.

    :param state: The state to add the tasklet to.
    :param map_ranges: Map ranges forwarded to ``add_mapped_tasklet``.
    :param code: Tasklet code.
    :param inputs: List of ``(access_node, range)`` pairs to read.
    :param outputs: Optional list of ``(access_node, range)`` pairs to write.
    :param wcr_outputs: Optional list of ``(access_node, range)`` pairs to
                        accumulate into.
    """
    outputs = outputs or []
    wcr_outputs = wcr_outputs or []

    input_nodes = {}
    input_memlets = {}
    for node, subset in inputs:
        input_nodes[node.data] = node
        input_memlets[node.data + "_inp"] = dace.Memlet.simple(node.data, subset)

    output_nodes = {node.data: node for node, _ in outputs + wcr_outputs}

    tasklet_outputs = {}
    for node, subset in outputs:
        tasklet_outputs[node.data + "_out"] = dace.Memlet.simple(node.data, subset)
    for node, subset in wcr_outputs:
        tasklet_outputs[node.data + "_out"] = dace.Memlet.simple(
            node.data, subset, wcr_str='lambda x, y: x + y')

    state.add_mapped_tasklet(name="einsum_tasklet",
                             input_nodes=input_nodes,
                             output_nodes=output_nodes,
                             map_ranges=map_ranges,
                             inputs=input_memlets,
                             code=code,
                             outputs=tasklet_outputs,
                             external_edges=True)
Beispiel #27
0
    def _iter_params_in_onnx_order(
            self,
            state: SDFGState,
            inputs: bool = False) -> List[MultiConnectorEdge]:
        """Return this node's input or output edges in ONNX schema order.

        If the last schema parameter is variadic, its connectors are expected
        to be named ``<name>__0``, ``<name>__1``, .... Only as many parameter
        names as there are edges are considered.

        :param state: The state containing this node.
        :param inputs: If True, order the in-edges by ``schema.inputs``.
                       Otherwise, order the out-edges by ``schema.outputs``.
        :return: The edges in schema order.
        :raises KeyError: If an expected connector has no edge.
        """
        parameters = list(
            self.schema.inputs if inputs else self.schema.outputs)
        # A schema may declare no parameters on one side. Guard the [-1]
        # lookup, which previously raised IndexError in that case.
        if parameters and parameters[-1].param_type == ONNXParameterType.Variadic:
            name = parameters[-1].name
            parameters = itertools.chain(
                [param.name for param in parameters[:-1]],
                (name + "__" + str(i) for i in itertools.count()))
        else:
            parameters = [param.name for param in parameters]

        edges = state.in_edges(self) if inputs else state.out_edges(self)
        parameters = list(itertools.islice(parameters, len(edges)))
        conn_to_edge = {
            edge.dst_conn if inputs else edge.src_conn: edge
            for edge in edges
        }

        return [conn_to_edge[name] for name in parameters]
Beispiel #28
0
def dot_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, x, y,
                result, acctype=None):
    """Insert a ``Dot`` library node computing ``x . y`` into ``result``.

    :param pv: Unused.
    :param sdfg: The SDFG holding the descriptor of ``x``. Its first
                 dimension becomes the node's ``n``.
    :param state: The state to add the nodes to.
    :param x: Name of the data connected to ``_x``.
    :param y: Name of the data connected to ``_y``.
    :param result: Name of the data written through ``_result``.
    :param acctype: Passed to ``Dot`` as ``accumulator_type``.
    :return: An empty list.
    """
    x_node = state.add_read(x)
    y_node = state.add_read(y)
    out_node = state.add_write(result)

    dot_node = Dot('dot',
                   n=sdfg.arrays[x].shape[0],
                   accumulator_type=acctype)
    state.add_node(dot_node)

    edge_specs = [
        (x_node, None, dot_node, '_x', x),
        (y_node, None, dot_node, '_y', y),
        (dot_node, '_result', out_node, None, result),
    ]
    for src, src_conn, dst, dst_conn, data_name in edge_specs:
        state.add_edge(src, src_conn, dst, dst_conn, mm.Memlet(data_name))

    return []
    def map_descriptor(state: dace.SDFGState,
                       map_entry: dace.nodes.MapEntry) -> str:
        """Build a ':'-separated descriptor of the tasklets fed by ``map_entry``.

        For every distinct tasklet reached by an out-edge of ``map_entry``,
        the label is split on "_", its last two parts are dropped, and the
        rest is re-joined with "_".

        :param state: The state containing ``map_entry``.
        :param map_entry: The map entry whose direct successors are inspected.
        :return: The shortened labels joined by ":", in the order in which
                 the tasklets are first reached from ``map_entry``.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order. The
        # previous set() made the descriptor order depend on node hashes.
        tasklets = dict.fromkeys(
            edge.dst for edge in state.out_edges(map_entry)
            if isinstance(edge.dst, dace.nodes.Tasklet))

        desc = ["_".join(t.label.split("_")[:-2]) for t in tasklets]
        return ":".join(desc)
Beispiel #30
0
def frag_fill(pv: ProgramVisitor, sdfg: dace.SDFG, state: dace.SDFGState,
              frag: str, fill: Any) -> List[str]:
    """Fill a WMMA fragment with a value.

    Adds a C++ tasklet calling ``wmma::fill_fragment(out, <fill>)``, whose
    ``out`` connector writes the whole of ``frag``. Then calls
    ``_include_mma(sdfg)``.

    Replacement functions receive the SDFG and the current state first,
    followed by the call's arguments. Here ``frag`` is the name of the array
    to fill and ``fill`` is substituted textually into the C++ call.

    NOTE: If a slice is used in the `frag` argument, the Python frontend
    automatically creates a new array for it, and uses the correct string as
    the argument.

    :param pv: Unused.
    :return: An empty list (the function has no return value).
    """
    frag_node = state.add_write(frag)
    fill_code = '\n      wmma::fill_fragment(out, %s);' % fill
    fill_tasklet = state.add_tasklet('fill', set(), {'out'}, fill_code,
                                     language=dace.Language.CPP)

    frag_memlet = dace.Memlet.from_array(frag, frag_node.desc(sdfg))
    state.add_edge(fill_tasklet, 'out', frag_node, None, frag_memlet)

    _include_mma(sdfg)

    return []