Ejemplo n.º 1
0
 def _get_buffer_size(self, state, loop_var, loop_axis):
     """ Computes how many elements a buffer along ``loop_axis`` must hold to
         cover every memlet returned by ``self._buffer_memlets([state])``.

         The begin and end of each memlet range on ``loop_axis`` are taken as
         offsets relative to ``loop_var``; the buffer size is the inclusive
         span between the smallest and largest such offset.

         :param state: The state whose buffered memlets are inspected.
         :param loop_var: Name of the loop iteration variable.
         :param loop_axis: Index of the subset dimension traversed by the loop.
         :return: ``max_offset - min_offset + 1``, or 0 if there are no
                  buffered memlets.
     """
     loop_sym = symbolic.symbol(loop_var)
     offsets = []
     for memlet in self._buffer_memlets([state]):
         range_begin, range_end, _ = memlet.subset.ranges[loop_axis]
         offsets.append(range_begin - loop_sym)
         offsets.append(range_end - loop_sym)

     # Derive min/max from the actual offsets instead of fixed sentinel
     # values, so offsets of any magnitude are handled and an empty memlet
     # set does not yield a negative size.
     if not offsets:
         return 0
     return max(offsets) - min(offsets) + 1
Ejemplo n.º 2
0
def _canonicalize_memlet(
    memlet: mm.Memlet,
    mapranges: List[Tuple[str, subsets.Range]]
) -> Tuple[symbolic.SymbolicType, ...]:
    """
    Turn a memlet subset expression (of a single element) into an expression
    that does not depend on the map symbol names.

    Map parameters are renamed positionally to ``__dace0``, ``__dace1``, ...
    in the order given by ``mapranges``.

    :param memlet: The memlet whose subset is canonicalized. Only the
                   beginning of each range is used, so the subset is expected
                   to describe a single element.
    :param mapranges: ``(parameter name, range)`` pairs of the enclosing map,
                      in parameter order.
    :return: One canonicalized begin expression per subset dimension.
    """
    repldict = {
        symbolic.symbol(param): symbolic.symbol('__dace%d' % i)
        for i, (param, _) in enumerate(mapranges)
    }

    return tuple(begin.subs(repldict)
                 for begin, _, _ in memlet.subset.ndrange())
Ejemplo n.º 3
0
def replace(subgraph: 'dace.sdfg.state.StateGraphView', name: str,
            new_name: str):
    """ Renames a symbol or array everywhere it occurs in a subgraph.

        Node properties are updated through ``replace_properties``. For every
        edge, the memlet's data name is swapped if it matches ``name``, and
        the subset, other subset and access count are rewritten with the
        symbolic replacement.

        :param subgraph: The graph or subgraph to perform the replacement in.
        :param name: The name to search for.
        :param new_name: The name (or expression) to substitute for ``name``.
    """
    old_sym = symbolic.symbol(name)
    if isinstance(new_name, str):
        new_sym = symbolic.pystr_to_symbolic(new_name)
    else:
        new_sym = new_name
    symrepl = {old_sym: new_sym}

    # Node properties first
    for node in subgraph.nodes():
        replace_properties(node, name, new_name)

    # Then every memlet in the subgraph
    for edge in subgraph.edges():
        memlet = edge.data
        if memlet.data == name:
            memlet.data = new_name
        memlet.subset = _replsym(memlet.subset, symrepl)
        memlet.other_subset = _replsym(memlet.other_subset, symrepl)
        memlet.num_accesses = _replsym(memlet.num_accesses, symrepl)
Ejemplo n.º 4
0
    def __call__(self, *args, **kwargs):
        """ Convenience function that parses, compiles, and runs a DaCe
            program.

            Positional arguments are bound to ``self.argnames`` and merged
            into the keyword arguments, together with the values of symbols
            that appear in the shapes of the arguments' data descriptors and
            of the compiled SDFG's arrays.

            :return: The result of calling the compiled binary object.
        """
        binaryobj = self.compile(*args)
        # Add named arguments to the call (positional values take precedence)
        kwargs.update(zip(self.argnames, args))
        # Update arguments with symbols in data shapes
        kwargs.update({
            sym: symbolic.symbol(sym).get()
            for arg in args
            for sym in (symbolic.symlist(arg.descriptor.shape)
                        if hasattr(arg, 'descriptor') else [])
        })
        # Update arguments with symbol values
        for aname in self.argnames:
            if aname not in binaryobj.sdfg.arrays:
                continue
            for sym in binaryobj.sdfg.arrays[aname].shape:
                if not symbolic.issymbolic(sym):
                    continue
                try:
                    kwargs[str(sym)] = sym.get()
                except UnboundLocalError:
                    # Symbol has no value set; leave it out of the call
                    # arguments rather than failing here.
                    pass

        return binaryobj(**kwargs)
Ejemplo n.º 5
0
def replace(subgraph: 'dace.sdfg.state.StateGraphView', name: str,
            new_name: str):
    """ Renames a symbol or array everywhere it occurs in a subgraph.

        Does nothing when ``name`` and ``new_name`` have the same string form.
        Otherwise node properties are rewritten via ``replace_properties``,
        and for every edge the memlet's data name, subsets and volume are
        updated. Subsets and the volume are only rewritten when their free
        symbols contain the replaced name.

        :param subgraph: The graph or subgraph to perform the replacement in.
        :param name: The name to search for.
        :param new_name: The name (or expression) to substitute for ``name``.
    """
    if str(name) == str(new_name):
        return

    symname = symbolic.symbol(name)
    replacement = (symbolic.pystr_to_symbolic(new_name)
                   if isinstance(new_name, str) else new_name)
    symrepl = {symname: replacement}

    for node in subgraph.nodes():
        replace_properties(node, symrepl, name, new_name)

    for edge in subgraph.edges():
        memlet = edge.data
        if memlet.data == name:
            memlet.data = new_name
        # Rewrite each subset only if it references ``name``
        for attr in ('subset', 'other_subset'):
            current = getattr(memlet, attr)
            if current is not None and name in current.free_symbols:
                setattr(memlet, attr, _replsym(current, symrepl))
        if symname in memlet.volume.free_symbols:
            memlet.volume = _replsym(memlet.volume, symrepl)
Ejemplo n.º 6
0
def replace_properties(node: Any, name: str, new_name: str):
    """ Replaces occurrences of a symbol or data name in a node's properties.

        Handles symbolic, data, range and shape properties, code properties
        (Python ASTs, and C++ code strings via a shadowing local variable),
        and the ``symbol_mapping`` dictionary property.

        :param node: The node whose properties are modified in place.
        :param name: Name to find.
        :param new_name: Name (or symbolic expression) to replace it with.
    """
    if str(name) == str(new_name):
        return
    symrepl = {
        symbolic.symbol(name):
        symbolic.symbol(new_name) if isinstance(new_name, str) else new_name
    }
    # Matches ``name`` as a whole identifier, including at the very start or
    # end of the code string (which a ``[^\w]name[^\w]`` pattern misses).
    # The name is escaped so it is always matched literally.
    name_pattern = re.compile(r'(?<!\w)%s(?!\w)' % re.escape(str(name)))

    for propclass, propval in node.properties():
        if propval is None:
            continue
        pname = propclass.attr_name
        if isinstance(propclass, properties.SymbolicProperty):
            setattr(node, pname, propval.subs(symrepl))
        elif isinstance(propclass, properties.DataProperty):
            if propval == name:
                setattr(node, pname, new_name)
        elif isinstance(propclass,
                        (properties.RangeProperty, properties.ShapeProperty)):
            setattr(node, pname, _replsym(list(propval), symrepl))
        elif isinstance(propclass, properties.CodeProperty):
            if isinstance(propval.code, str):
                # Skip code strings that never mention the name
                if not name_pattern.search(propval.code):
                    continue

                lang = propval.language
                if lang is dtypes.Language.CPP:  # Replace in C++ code
                    # Use local variables and shadowing to replace
                    replacement = 'auto %s = %s;\n' % (name, new_name)
                    propval.code = replacement + propval.code
                else:
                    warnings.warn(
                        'Replacement of %s with %s was not made '
                        'for string tasklet code of language %s' %
                        (name, new_name, lang))
            elif propval.code is not None:
                for stmt in propval.code:
                    ASTFindReplace({name: new_name}).visit(stmt)
        elif (isinstance(propclass, properties.DictProperty)
              and pname == 'symbol_mapping'):
            # Symbol mappings for nested SDFGs
            for symname, sym_mapping in propval.items():
                propval[symname] = sym_mapping.subs(symrepl)
Ejemplo n.º 7
0
def create_batch_gemm_sdfg(dtype, strides):
    """ Builds an SDFG named ``'einsum'`` that multiplies X by Y into Z using
        a single BLAS ``MatMul`` library node.

        :param dtype: Element data type of the X, Y and Z arrays.
        :param strides: Mapping with keys ``'BATCH'``, ``'sAM'``, ``'sAK'``,
                        ``'sAB'``, ``'sBK'``, ``'sBN'``, ``'sBB'``, ``'sCM'``,
                        ``'sCN'`` and ``'sCB'``. Symbolic entries are replaced
                        by a symbol of the same name; other entries are used
                        as-is. When ``strides['BATCH'] != 1`` every array gets
                        a leading batch dimension.
        :return: The constructed SDFG.
    """
    sdfg = SDFG('einsum')
    state = sdfg.add_state()
    M, K, N = (symbolic.symbol(s) for s in ('M', 'K', 'N'))

    def _resolve(key):
        value = strides[key]
        return symbolic.symbol(key) if symbolic.issymbolic(value) else value

    (BATCH, sAM, sAK, sAB, sBK, sBN, sBB, sCM, sCN, sCB) = map(
        _resolve, ('BATCH', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM',
                   'sCN', 'sCB'))

    batched = strides['BATCH'] != 1

    # (array name, matrix shape, strides for [batch, first, second] dims)
    layouts = (
        ('X', [M, K], [sAB, sAM, sAK]),
        ('Y', [K, N], [sBB, sBK, sBN]),
        ('Z', [M, N], [sCB, sCM, sCN]),
    )
    descs = {}
    for arr_name, mat_shape, arr_strides in layouts:
        _, descs[arr_name] = sdfg.add_array(
            arr_name,
            dtype=dtype,
            shape=[BATCH] + mat_shape if batched else mat_shape,
            strides=arr_strides if batched else arr_strides[1:])

    gX = state.add_read('X')
    gY = state.add_read('Y')
    gZ = state.add_write('Z')

    import dace.libraries.blas as blas  # Avoid import loop

    libnode = blas.MatMul('einsum_gemm')
    state.add_node(libnode)
    state.add_edge(gX, None, libnode, '_a',
                   Memlet.from_array(gX.data, descs['X']))
    state.add_edge(gY, None, libnode, '_b',
                   Memlet.from_array(gY.data, descs['Y']))
    state.add_edge(libnode, '_c', gZ, None,
                   Memlet.from_array(gZ.data, descs['Z']))

    return sdfg
Ejemplo n.º 8
0
    def _loop_range(
            itervar: str, inedges: List[gr.Edge],
            condition: sp.Expr) -> Optional[Tuple[sp.Expr, sp.Expr, sp.Expr]]:
        """
        Finds loop range from state machine.
        :param itervar: String representing the iteration variable.
        :param inedges: Incoming edges into guard state (length must be 2).
        :param condition: Condition as sympy expression.
        :return: A three-tuple of (start, end, stride) expressions, or None if
                 proper for-loop was not detected. ``end`` is inclusive.
        """
        itersym = symbolic.symbol(itervar)

        # Parse both assignments once. Exactly one of them must reference the
        # iteration variable (the increment); the other is the initializer.
        assign_a, assign_b = (
            symbolic.pystr_to_symbolic(edge.data.assignments[itervar])
            for edge in (inedges[0], inedges[1]))
        a_uses_iter = itersym in assign_a.free_symbols
        b_uses_iter = itersym in assign_b.free_symbols
        if a_uses_iter and not b_uses_iter:
            stride, start = assign_a - itersym, assign_b
        elif b_uses_iter and not a_uses_iter:
            stride, start = assign_b - itersym, assign_a
        else:
            return None

        # Find the end by matching the condition against each supported
        # comparison in turn; the offset turns strict bounds into the
        # inclusive end that this function returns.
        bound = sp.Wild('a')
        for pattern, offset in ((itersym < bound, -1), (itersym <= bound, 0),
                                (itersym > bound, 1), (itersym >= bound, 0)):
            match = condition.match(pattern)
            if match:
                return start, match[bound] + offset, stride

        return None  # No match found
Ejemplo n.º 9
0
def _do_memlets_correspond(
        memlet_a: mm.Memlet, memlet_b: mm.Memlet,
        mapranges_a: List[Tuple[str, subsets.Range]],
        mapranges_b: List[Tuple[str, subsets.Range]]) -> bool:
    """
    Returns True if the two memlets correspond to each other, disregarding
    symbols from equivalent maps.

    Parameters of the first map are renamed positionally to those of the
    second map before comparing. Both subsets are assumed to describe a
    single element, so only the beginning of each dimension is compared.

    :param memlet_a: Memlet inside the first map.
    :param memlet_b: Memlet inside the second map.
    :param mapranges_a: ``(parameter, range)`` pairs of the first map.
    :param mapranges_b: ``(parameter, range)`` pairs of the second map.
    :return: True if both subsets have the same number of dimensions and every
             (renamed) range beginning matches, False otherwise.
    """
    # Subsets of different dimensionality cannot correspond; zip() alone
    # would silently ignore the extra dimensions.
    if len(memlet_a.subset) != len(memlet_b.subset):
        return False

    # The parameter renaming is the same for every dimension: build it once.
    repldict = {
        symbolic.symbol(k1): symbolic.symbol(k2)
        for (k1, _), (k2, _) in zip(mapranges_a, mapranges_b)
    }
    for s1, s2 in zip(memlet_a.subset, memlet_b.subset):
        # Since there is one element in both subsets, we can check only
        # the beginning
        if s1[0].subs(repldict) != s2[0]:
            return False
    return True
Ejemplo n.º 10
0
    def replace(self, repl_dict):
        """ Renames symbols in this memlet's volume, subset and other subset.

            Substitution happens in two phases: first every affected symbol
            is mapped to a unique ``__dacesym_``-prefixed placeholder, then
            each placeholder is mapped to its final value. Because of this,
            swapped replacements such as ``{a: b, b: a}`` do not interfere.
            Entries whose key and value have the same string form are
            ignored.

            :param repl_dict: A dict of string symbol names to symbols with
                              which to replace them.
        """
        to_placeholder = {}
        from_placeholder = {}
        for old, new in repl_dict.items():
            if str(old) == str(new):
                continue
            placeholder = symbolic.symbol('__dacesym_' + str(old))
            to_placeholder[symbolic.symbol(old)] = placeholder
            from_placeholder[placeholder] = new

        if not to_placeholder:
            return

        if self.volume is not None and symbolic.issymbolic(self.volume):
            self.volume = self.volume.subs(to_placeholder)
            self.volume = self.volume.subs(from_placeholder)
        # Subsets are rewritten in place
        for subset in (self.subset, self.other_subset):
            if subset is not None:
                subset.replace(to_placeholder)
                subset.replace(from_placeholder)
Ejemplo n.º 11
0
def vectorize_connector(sdfg: dace.SDFG, dfg: dace.SDFGState,
                        node: dace.nodes.Node, par: str, conn: str,
                        is_input: bool):
    """ Converts the type of connector ``conn`` on ``node`` into an SVE vector
        type, based on the memlets attached to that connector.

        Stream data turns the connector into a vector of unknown length
        (``-1``). Otherwise, the subset dimension in which the loop parameter
        ``par`` occurs is widened by ``util.SVE_LEN``, and the connector
        becomes a vector of ``util.SVE_LEN`` elements. Memlets that do not
        depend on ``par`` and have no write-conflict resolution leave the
        connector scalar.

        Every early ``return`` below exits the whole function, not just the
        current edge. Edges after the first one that triggers such a return
        are never inspected.

        :param sdfg: The SDFG holding the data descriptors (``sdfg.arrays``).
        :param dfg: The state containing ``node``.
        :param node: The node owning the connector.
        :param par: Name of the loop parameter being vectorized.
        :param conn: Name of the connector to vectorize.
        :param is_input: True to use ``node.in_connectors``, False to use
                         ``node.out_connectors``.
        :raise util.NotSupportedError: If ``par`` occurs in more than one
                                       dimension of a memlet subset.
    """
    edges = get_connector_edges(dfg, node, conn, is_input)
    connectors = node.in_connectors if is_input else node.out_connectors

    for edge in edges:
        if edge.data.data is None:
            # Empty memlets
            # NOTE(review): this returns (rather than continues), so any
            # remaining edges are skipped — confirm intended.
            return

        desc = sdfg.arrays[edge.data.data]

        if isinstance(desc, data.Stream):
            # Streams are treated differently in SVE, instead of pointers they become vectors of unknown size
            connectors[conn] = dace.dtypes.vector(connectors[conn].base_type,
                                                  -1)
            return

        if isinstance(connectors[conn],
                      (dace.dtypes.vector, dace.dtypes.pointer)):
            # No need for vectorization
            return

        subset = edge.data.subset

        # Find the single subset dimension in which the loop parameter occurs
        # (in the begin, end or step expression).
        sve_dim = None
        for i, rng in enumerate(subset):
            for expr in rng:
                if symbolic.symbol(par) in symbolic.pystr_to_symbolic(
                        expr).free_symbols:
                    if sve_dim is not None and sve_dim != i:
                        raise util.NotSupportedError(
                            'Non-vectorizable memlet (loop param occurs in more than one dimension)'
                        )
                    sve_dim = i

        if sve_dim is None and edge.data.wcr is None:
            # Should stay scalar
            return

        if sve_dim is not None:
            # Widen the vectorized dimension in place, keeping begin and step.
            # NOTE(review): if range ends are inclusive (as in ranges like
            # (0, N - 1, 1) elsewhere), begin + SVE_LEN spans SVE_LEN + 1
            # elements — confirm this is intended.
            sve_subset = subset[sve_dim]
            edge.data.subset[sve_dim] = (sve_subset[0],
                                         sve_subset[0] + util.SVE_LEN,
                                         sve_subset[2])

        # Keep the connector's existing element type if it has one, otherwise
        # fall back to the data descriptor's dtype.
        connectors[conn] = dace.dtypes.vector(
            connectors[conn].type or desc.dtype, util.SVE_LEN)
Ejemplo n.º 12
0
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[Any, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       strict=False):
        """ Checks whether the single-parameter map can be nested into the
            stencil along the dimension indexed by the map parameter.

            Requires that:
            - the map has exactly one parameter and no dynamic map inputs;
            - all of the map entry's outgoing edges lead to the stencil;
            - the map parameter appears in one consistent dimension of all
              three-dimensional memlets attached to the stencil;
            - the stencil shape along that dimension is 1.

            :return: True if the pattern matches, False otherwise.
        """
        map_entry: nodes.MapEntry = graph.node(candidate[NestK._map_entry])
        stencil: Stencil = graph.node(candidate[NestK._stencil])

        if len(map_entry.map.params) != 1:
            return False
        if sd.has_dynamic_map_inputs(graph, map_entry):
            return False
        pname = map_entry.map.params[0]  # Usually "k"

        # Everything leaving the map entry must feed the stencil directly
        if any(edge.dst != stencil for edge in graph.out_edges(map_entry)):
            return False

        k_dim = None
        for edge in graph.all_edges(stencil):
            if edge.data.data is None:  # Empty memlet
                continue
            # TODO: Use bitmap to verify lower-dimensional arrays
            if len(edge.data.subset) != 3:
                continue
            for dim, rng in enumerate(edge.data.subset.ndrange()):
                for expr in rng:
                    if pname not in map(str, expr.free_symbols):
                        continue
                    if k_dim is not None and k_dim != dim:
                        # k dimension must match in all memlets
                        return False
                    if str(expr) != pname and symbolic.issymbolic(
                            expr - symbolic.symbol(pname), sdfg.constants):
                        warnings.warn('k expression is nontrivial')
                    k_dim = dim

        # No nesting dimension found
        if k_dim is None:
            return False

        # Ensure the stencil shape is 1 for the found dimension
        if stencil.shape[k_dim] != 1:
            return False
        return True
Ejemplo n.º 13
0
def test_memlet_volume_propagation_nsdfg():
    """ Propagates memlets through the SDFG from ``make_sdfg`` and checks the
        resulting memlet parameters and the loop state's execution count. """
    sdfg = make_sdfg()
    propagation.propagate_memlets_sdfg(sdfg)

    main_state = sdfg.nodes()[0]
    main_edges = main_state.edges()
    data_in_memlet, bound_stream_in_memlet, out_stream_memlet = (
        main_edges[idx].data for idx in range(3))

    memlet_check_parameters(data_in_memlet, 0, True, [(0, N - 1, 1)])
    memlet_check_parameters(bound_stream_in_memlet, 1, False, [(0, 0, 1)])
    memlet_check_parameters(out_stream_memlet, 0, True, [(0, 0, 1)])

    # Node 3 of the main state holds the nested SDFG; its third state is the
    # loop body whose execution count is checked.
    loop_state = main_state.nodes()[3].sdfg.nodes()[2]
    state_check_executions(loop_state, symbol('loop_bound'))
Ejemplo n.º 14
0
    def get_stride(self,
                   sdfg: 'dace.sdfg.SDFG',
                   map: 'dace.sdfg.nodes.Map',
                   dim: int = -1) -> 'dace.symbolic.SymExpr':
        """ Returns the stride of the underlying memory when traversing a Map.

            :param sdfg: The SDFG in which the memlet resides.
            :param map: The map in which the memlet resides.
            :param dim: The dimension that is incremented. By default it is the innermost.
            :return: The simplified difference in flattened memory offset
                     between two consecutive iterations of ``map`` along
                     ``dim``, or symbolic ``0`` if the memlet has no data.
        """
        if self.data is None:
            return symbolic.pystr_to_symbolic('0')

        # ``map`` shadows the builtin but is kept as a public parameter name.
        param = symbolic.symbol(map.params[dim])
        array = sdfg.arrays[self.data]

        # Flatten the subset to a 1D-offset (using the array strides) at some iteration
        curr_offset = self.subset.at([0] * len(array.strides), array.strides)

        # Substitute the param with the next (possibly strided) value
        next_offset = curr_offset.subs(param, param + map.range[dim][2])

        # The stride is the difference between both
        return (next_offset - curr_offset).simplify()
Ejemplo n.º 15
0
def infer_symbols_from_datadescriptor(sdfg: SDFG, args: Dict[str, Any],
                                      exclude: Optional[Set[str]] = None) -> \
        Dict[str, Any]:
    """
    Infers the values of SDFG symbols (not given as arguments) from the shapes
    and strides of input arguments (e.g., arrays).
    :param sdfg: The SDFG that is being called.
    :param args: A dictionary mapping from current argument names to their
                 values. This may also include symbols.
    :param exclude: An optional set of symbols to ignore on inference.
    :return: A dictionary mapping from symbol names that are not in ``args``
             to their inferred values.
    :raise ValueError: If symbol values are ambiguous.
    """
    # Prefix used to turn SDFG symbols into solver unknowns; it is stripped
    # again from the solution keys before returning.
    solve_prefix = '__SOLVE_'

    exclude = exclude or set()
    exclude = set(symbolic.symbol(s) for s in exclude)
    equations = []
    symbols = set()
    # Collect equations and symbols from arguments and shapes
    for arg_name, arg_val in args.items():
        if arg_name in sdfg.arrays:
            desc = sdfg.arrays[arg_name]
            if not hasattr(desc, 'shape') or not hasattr(arg_val, 'shape'):
                continue
            symbolic_values = list(desc.shape) + list(
                getattr(desc, 'strides', []))
            given_values = list(arg_val.shape)
            given_strides = []
            if hasattr(arg_val, 'strides'):
                # NumPy arrays use bytes in strides
                factor = getattr(arg_val, 'itemsize', 1)
                given_strides = [s // factor for s in arg_val.strides]
            given_values += given_strides

            for sym_dim, real_dim in zip(symbolic_values, given_values):
                repldict = {}
                for sym in symbolic.symlist(sym_dim).values():
                    newsym = symbolic.symbol(solve_prefix + str(sym))
                    if str(sym) in args:
                        exclude.add(newsym)
                    else:
                        symbols.add(newsym)
                        exclude.add(sym)
                    repldict[sym] = newsym

                # Replace symbols with __SOLVE_ symbols so as to allow
                # the same symbol in the called SDFG
                if repldict:
                    sym_dim = sym_dim.subs(repldict)

                equations.append(sym_dim - real_dim)

    if len(symbols) == 0:
        return {}

    # Solve for all at once
    results = sympy.solve(equations, *symbols, dict=True, exclude=exclude)
    if len(results) > 1:
        raise ValueError('Ambiguous values for symbols in inference. '
                         'Options: %s' % str(results))
    if len(results) == 0:
        raise ValueError('Cannot infer values for symbols in inference.')

    result = results[0]
    if not result:
        raise ValueError('Cannot infer values for symbols in inference.')

    # Remove the solver prefix to recover the original symbol names
    return {str(k)[len(solve_prefix):]: v for k, v in result.items()}
Ejemplo n.º 16
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """ Checks whether the first map can be fused with the second map.

            Fails when any of the following holds:
            - an output of the first map with write-conflict resolution is
              read by the second map;
            - an output of the first map does not go to an access node;
            - the data of such an access node appears in more than one access
              node of the state;
            - the map ranges cannot be matched by a permutation;
            - an input of the second map that reads intermediate data is not
              provided by an equal output subset of the first map (after
              renaming map parameters);
            - a non-intermediate input of the second map is reachable from an
              intermediate node.

            :return: True if the fusion is applicable, False otherwise.
        """
        first_map_exit = graph.nodes()[candidate[MapFusion._first_map_exit]]
        first_map_entry = graph.entry_node(first_map_exit)
        second_map_entry = graph.nodes()[candidate[
            MapFusion._second_map_entry]]

        for _in_e in graph.in_edges(first_map_exit):
            if _in_e.data.wcr is not None:
                for _out_e in graph.out_edges(second_map_entry):
                    if _out_e.data.data == _in_e.data.data:
                        # wcr is on a node that is used in the second map, quit
                        return False
        # Check whether there is a pattern map -> access -> map.
        intermediate_nodes = set()
        intermediate_data = set()
        for _, _, dst, _, _ in graph.out_edges(first_map_exit):
            if not isinstance(dst, nodes.AccessNode):
                return False
            intermediate_nodes.add(dst)
            intermediate_data.add(dst.data)

            # If array is used anywhere else in this state.
            num_occurrences = len([
                n for n in graph.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == dst.data
            ])
            if num_occurrences > 1:
                return False
        # Check map ranges
        perm = MapFusion.find_permutation(first_map_entry.map,
                                          second_map_entry.map)
        if perm is None:
            return False

        # Create a dict that maps parameters of the first map to those of the
        # second map.
        params_dict = {}
        for _index, _param in enumerate(first_map_entry.map.params):
            params_dict[_param] = second_map_entry.map.params[perm[_index]]

        def _remap(value):
            # Rename a first-map parameter symbol to its second-map
            # counterpart; any other value is returned unchanged. Bare
            # symbols not in params_dict are kept as-is (previously a
            # KeyError) to match how symbols inside tuples are treated.
            if (isinstance(value, symbolic.symbol)
                    and str(value) in params_dict):
                return symbolic.symbol(params_dict[str(value)])
            return value

        def _remap_subset(subset):
            # Apply _remap to every entry; list/tuple entries become tuples.
            expected = []
            for _tup in subset:
                if isinstance(_tup, (list, tuple)):
                    expected.append(tuple(_remap(_sym) for _sym in _tup))
                else:
                    expected.append(_remap(_tup))
            return expected

        out_memlets = [e.data for e in graph.in_edges(first_map_exit)]

        # Check that input set of second map is provided by the output set
        # of the first map, or other unrelated maps
        for _, _, _, _, second_memlet in graph.out_edges(second_map_entry):
            # Memlets that do not come from one of the intermediate arrays
            if second_memlet.data not in intermediate_data:
                # however, if intermediate_data eventually leads to
                # second_memlet.data, need to fail.
                if intermediate_nodes:
                    destination_node = graph.find_node(second_memlet.data)
                    # NOTE: Assumes graph has networkx version
                    for source_node in intermediate_nodes:
                        if destination_node in nx.descendants(
                                graph._nx, source_node):
                            return False
                continue

            provided = False
            for first_memlet in out_memlets:
                if first_memlet.data != second_memlet.data:
                    continue
                # If there is an equivalent subset, it is provided
                if (_remap_subset(first_memlet.subset) == list(
                        second_memlet.subset)):
                    provided = True
                    break

            # If none of the output memlets of the first map provide the info,
            # fail.
            if not provided:
                return False

        # Success
        return True
Ejemplo n.º 17
0
    def _construct_args(self, *args, **kwargs):
        """ Main function that controls argument construction for calling
            the C prototype of the SDFG.

            Organizes arguments first by `sdfg.arglist`, then data descriptors
            by alphabetical order, then symbols by alphabetical order.

            :return: A tuple of call arguments (array pointers and converted
                     scalars), also stored in ``self._lastargs``.
            :raise AttributeError: If both positional and keyword arguments
                                   are given.
            :raise KeyError: If a signature argument is missing from kwargs.
            :raise TypeError: If an array is passed for a scalar or vice versa.
            :raise UnboundLocalError: If an undefined symbol has no value.
        """

        if len(kwargs) > 0 and len(args) > 0:
            raise AttributeError(
                'Compiled SDFGs can only be called with either arguments ' +
                '(e.g. "program(a,b,c)") or keyword arguments ' +
                '("program(A=a,B=b)"), but not both')

        # Argument construction
        sig = self._sdfg.signature_arglist(with_types=False)
        typedict = self._sdfg.arglist()
        if len(kwargs) > 0:
            # Construct mapping from arguments to signature
            arglist = []
            argtypes = []
            argnames = []
            for a in sig:
                try:
                    arglist.append(kwargs[a])
                    argtypes.append(typedict[a])
                    argnames.append(a)
                except KeyError:
                    raise KeyError(
                        "Missing program argument \"{}\"".format(a)) from None
        elif len(args) > 0:
            arglist = list(args)
            argtypes = [typedict[s] for s in sig]
            argnames = sig
            # Positional arguments are not matched by name: clearing ``sig``
            # means no symbol below is filtered out as already passed.
            sig = []
        else:
            arglist = []
            argtypes = []
            argnames = []
            sig = []

        # Type checking
        for a, arg, atype in zip(argnames, arglist, argtypes):
            if not isinstance(arg, np.ndarray) and isinstance(atype, dt.Array):
                raise TypeError(
                    'Passing an object (type %s) to an array in argument "%s"'
                    % (type(arg).__name__, a))
            if isinstance(arg, np.ndarray) and not isinstance(atype, dt.Array):
                raise TypeError(
                    'Passing an array to a scalar (type %s) in argument "%s"' %
                    (atype.dtype.ctype, a))
            if not isinstance(atype, dt.Array) and not isinstance(
                    arg, atype.dtype.type):
                print('WARNING: Casting scalar argument "%s" from %s to %s' %
                      (a, type(arg).__name__, atype.dtype.type))

        # Retain only the element datatype for upcoming checks and casts
        argtypes = [t.dtype.type for t in argtypes]

        sdfg = self._sdfg

        # As in compilation, add symbols used in array sizes to parameters
        symparams = {}
        symtypes = {}
        for symname in sdfg.undefined_symbols(False):
            # Ignore arguments (as they may not be symbols but constants,
            # see below)
            if symname in sdfg.arg_types:
                continue
            symval = symbolic.symbol(symname)
            try:
                symparams[symname] = symval.get()
            except UnboundLocalError:
                try:
                    symparams[symname] = kwargs[symname]
                except KeyError:
                    raise UnboundLocalError(
                        'Unassigned symbol %s' % symname) from None
            # Record a type for every symbol that received a value (including
            # values taken from kwargs), so that the value and type lists
            # extended below stay aligned element by element.
            symtypes[symname] = symval.dtype.type

        arglist.extend(
            [symparams[k] for k in sorted(symparams.keys()) if k not in sig])
        argtypes.extend(
            [symtypes[k] for k in sorted(symtypes.keys()) if k not in sig])

        # Obtain SDFG constants
        constants = sdfg.constants

        # Remove symbolic constants from arguments
        callparams = tuple(
            (arg, atype) for arg, atype in zip(arglist, argtypes)
            if not symbolic.issymbolic(arg) or (
                hasattr(arg, 'name') and arg.name not in constants))

        # Replace symbols with their values
        callparams = tuple(
            (atype(symbolic.eval(arg)),
             atype) if symbolic.issymbolic(arg, constants) else (arg, atype)
            for arg, atype in callparams)

        # Replace arrays with their pointers
        newargs = tuple(
            (ctypes.c_void_p(arg.__array_interface__['data'][0]),
             atype) if (isinstance(arg, ndarray.ndarray)
                        or isinstance(arg, np.ndarray)) else (arg, atype)
            for arg, atype in callparams)

        newargs = tuple(types._FFI_CTYPES[atype](arg) if (
            atype in types._FFI_CTYPES
            and not isinstance(arg, ctypes.c_void_p)) else arg
                        for arg, atype in newargs)

        self._lastargs = newargs
        return self._lastargs
Ejemplo n.º 18
0
def find_for_loop(
    sdfg: sd.SDFG,
    guard: sd.SDFGState,
    entry: sd.SDFGState,
    itervar: Optional[str] = None
) -> Optional[Tuple[AnyStr, Tuple[symbolic.SymbolicType, symbolic.SymbolicType,
                                  symbolic.SymbolicType], Tuple[
                                      List[sd.SDFGState], sd.SDFGState]]]:
    """
    Finds loop range from state machine.
    :param sdfg: The SDFG containing ``guard`` and ``entry``.
    :param guard: State from which the outgoing edges detect whether to exit
                  the loop or not.
    :param entry: First state in the loop "body".
    :param itervar: Name of the iteration variable. If None, the first
                    assignment key of the guard's first incoming edge that is
                    assigned on *every* incoming edge of the guard is used.
    :return: (iteration variable, (start, end, stride),
              (start_states[], last_loop_state)), or None if proper
             for-loop was not detected. ``end`` is inclusive.
    """

    # Extract state transition edge information
    guard_inedges = sdfg.in_edges(guard)
    guard_to_entry = sdfg.edges_between(guard, entry)
    if not guard_inedges or not guard_to_entry:
        # Without edges into the guard or a guard->entry transition there is
        # no for-loop structure to detect.
        return None
    condition_edge = guard_to_entry[0]

    if itervar is None:
        # Every edge into the guard (init edges and the step edge) must assign
        # the iteration variable, so only such variables are candidates. The
        # key order of the first incoming edge decides between candidates.
        candidates = [
            name for name in guard_inedges[0].data.assignments.keys()
            if all(name in iedge.data.assignments for iedge in guard_inedges)
        ]
        if not candidates:
            return None
        itervar = candidates[0]
    condition = condition_edge.data.condition_sympy()

    # Find the stride edge. All in-edges to the guard except for the stride edge
    # should have exactly the same assignment, since a valid for loop can only
    # have one assignment.
    init_edges = []
    init_assignment = None
    step_edge = None
    itersym = symbolic.symbol(itervar)
    for iedge in guard_inedges:
        if itervar not in iedge.data.assignments:
            # An edge into the guard that does not set the iteration variable
            # cannot be part of a valid for-loop.
            return None
        assignment = iedge.data.assignments[itervar]
        if itersym in symbolic.pystr_to_symbolic(assignment).free_symbols:
            if step_edge is None:
                step_edge = iedge
            else:
                # More than one edge with the iteration variable as a free
                # symbol, which is not legal. Invalid for loop.
                return None
        else:
            if init_assignment is None:
                init_assignment = assignment
                init_edges.append(iedge)
            elif init_assignment != assignment:
                # More than one init assignment variations mean that this for
                # loop is not valid.
                return None
            else:
                init_edges.append(iedge)
    if step_edge is None or len(init_edges) == 0 or init_assignment is None:
        # Less than two assignment variations, can't be a valid for loop.
        return None

    # Get the init expression and the stride.
    start = symbolic.pystr_to_symbolic(init_assignment)
    stride = (symbolic.pystr_to_symbolic(step_edge.data.assignments[itervar]) -
              itersym)
    if itersym in stride.free_symbols:
        # Non-additive update (e.g., ``i = 2 * i``): there is no stride that
        # is independent of the iteration variable.
        return None

    # Get a list of the last states before the loop and a reference to the last
    # loop state.
    start_states = []
    for init_edge in init_edges:
        start_state = init_edge.src
        if start_state not in start_states:
            start_states.append(start_state)
    last_loop_state = step_edge.src

    # Find condition by matching expressions. The bound must not itself depend
    # on the iteration variable, otherwise ``end`` would not be a loop bound.
    end: Optional[symbolic.SymbolicType] = None
    a = sp.Wild('a', exclude=[itersym])
    match = condition.match(itersym < a)
    if match:
        end = match[a] - 1
    if end is None:
        match = condition.match(itersym <= a)
        if match:
            end = match[a]
    if end is None:
        match = condition.match(itersym > a)
        if match:
            end = match[a] + 1
    if end is None:
        match = condition.match(itersym >= a)
        if match:
            end = match[a]

    if end is None:  # No match found
        return None

    return itervar, (start, end, stride), (start_states, last_loop_state)
Ejemplo n.º 19
0
    def expansion(node, parent_state, parent_sdfg):
        """
        Expands the redistribution node into a C++ tasklet.

        The redistribution is looked up as
        ``parent_sdfg.rdistrarrays[node.redistr]``. The generated code posts
        one ``MPI_Isend`` per send on ranks where ``array_a``'s process grid is
        valid. On ranks where ``array_b``'s process grid is valid, it performs
        the local (self) copies with ``dace::CopyNDDynamic``, then one blocking
        ``MPI_Recv`` per receive. Finally it waits for the sends.

        :param node: The redistribution node being expanded.
        :param parent_state: The state containing ``node``.
        :param parent_sdfg: The SDFG containing ``parent_state``.
        :return: A C++ ``nodes.Tasklet`` with the node's name and connectors.
        """
        inp_buffer, out_buffer = node.validate(parent_sdfg, parent_state)
        redistr = parent_sdfg.rdistrarrays[node.redistr]
        array_a = parent_sdfg.subarrays[redistr.array_a]
        array_b = parent_sdfg.subarrays[redistr.array_b]

        # One index symbol per buffer dimension. In the generated code each is
        # bound (via inp_repl/out_repl) to the self-copy source/destination
        # coordinate for the current copy index.
        inp_symbols = [
            symbolic.symbol(f"__inp_s{i}")
            for i in range(len(inp_buffer.shape))
        ]
        out_symbols = [
            symbolic.symbol(f"__out_s{i}")
            for i in range(len(out_buffer.shape))
        ]
        inp_subset = subsets.Indices(inp_symbols)
        out_subset = subsets.Indices(out_symbols)
        inp_offset = cpp.cpp_offset_expr(inp_buffer, inp_subset)
        out_offset = cpp.cpp_offset_expr(out_buffer, out_subset)

        inp_ndims = len(inp_buffer.shape)
        out_ndims = len(out_buffer.shape)
        inp_repl = "".join(
            f"int {s} = __state->{node.redistr}_self_src[__idx * {inp_ndims} + {i}];\n"
            for i, s in enumerate(inp_symbols))
        out_repl = "".join(
            f"int {s} = __state->{node.redistr}_self_dst[__idx * {out_ndims} + {i}];\n"
            for i, s in enumerate(out_symbols))
        # Per dimension: copy extent, input stride, output stride
        copy_args = ", ".join(
            f"__state->{node.redistr}_self_size[__idx * {inp_ndims} + {i}], {istride}, {ostride}"
            for i, (istride, ostride) in enumerate(
                zip(inp_buffer.strides, out_buffer.strides)))

        # ``req``/``status`` are allocated on every rank, so they are also
        # freed on every rank (not only inside the array_a branch).
        code = f"""
            int myrank;
            MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
            MPI_Request* req = new MPI_Request[__state->{node._redistr}_sends];
            MPI_Status* status = new MPI_Status[__state->{node._redistr}_sends];
            MPI_Status recv_status;
            if (__state->{array_a.pgrid}_valid) {{
                for (auto __idx = 0; __idx < __state->{node._redistr}_sends; ++__idx) {{
                    // printf("({redistr.array_a} -> {redistr.array_b}) I am rank %d and I send to %d\\n", myrank, __state->{node._redistr}_dst_ranks[__idx]);
                    // fflush(stdout);
                    MPI_Isend(_inp_buffer, 1, __state->{node._redistr}_send_types[__idx], __state->{node._redistr}_dst_ranks[__idx], 0, MPI_COMM_WORLD, &req[__idx]);
                }}
            }}
            if (__state->{array_b.pgrid}_valid) {{
                for (auto __idx = 0; __idx < __state->{node._redistr}_self_copies; ++__idx) {{
                    // printf("({redistr.array_a} -> {redistr.array_b}) I am rank %d and I self-copy\\n", myrank);
                    // fflush(stdout);
                    {inp_repl}
                    {out_repl}
                    dace::CopyNDDynamic<{inp_buffer.dtype.ctype}, 1, false, {inp_ndims}>::Dynamic::Copy(
                        _inp_buffer + {inp_offset}, _out_buffer + {out_offset}, {copy_args}
                    );
                }}
                for (auto __idx = 0; __idx < __state->{node._redistr}_recvs; ++__idx) {{
                    // printf("({redistr.array_a} -> {redistr.array_b}) I am rank %d and I receive from %d\\n", myrank, __state->{node._redistr}_src_ranks[__idx]);
                    // fflush(stdout);
                    MPI_Recv(_out_buffer, 1, __state->{node._redistr}_recv_types[__idx], __state->{node._redistr}_src_ranks[__idx], 0, MPI_COMM_WORLD, &recv_status);
                }}
            }}
            if (__state->{array_a.pgrid}_valid) {{
                MPI_Waitall(__state->{node._redistr}_sends, req, status);
            }}
            delete[] req;
            delete[] status;
            // printf("I am rank %d and I finished the redistribution {redistr.array_a} -> {redistr.array_b}\\n", myrank);
            // fflush(stdout);
            
        """

        tasklet = nodes.Tasklet(node.name,
                                node.in_connectors,
                                node.out_connectors,
                                code,
                                language=dtypes.Language.CPP)
        return tasklet
Ejemplo n.º 20
0
def infer_symbols_from_shapes(sdfg: SDFG, args: Dict[str, Any],
                              exclude: Optional[Set[str]] = None) -> \
        Dict[str, Any]:
    """
    Infers the values of SDFG symbols (not given as arguments) from the shapes
    of input arguments (e.g., arrays).

    Only arguments that name a data descriptor in ``sdfg.arrays`` are
    considered, and only when both the descriptor and the given value have a
    ``shape`` attribute. Each dimension contributes the equation
    ``symbolic_dim - given_dim = 0``. All equations are solved at once with
    ``sympy.solve``.

    :param sdfg: The SDFG that is being called.
    :param args: A dictionary mapping from current argument names to their
                 values. This may also include symbols.
    :param exclude: An optional set of symbols to ignore on inference.
    :return: A dictionary mapping from symbol names that are not in ``args``
             to their inferred values.
    :raise ValueError: If symbol values are ambiguous or cannot be inferred.
    """
    # Symbols in descriptor shapes are renamed with this prefix before solving,
    # so that the same symbol name may also appear in the called SDFG. The
    # prefix is stripped again from the result keys.
    solve_prefix = '__SOLVE_'

    exclude = set(symbolic.symbol(s) for s in (exclude or set()))
    equations = []
    symbols = set()
    # Collect equations and symbols from arguments and shapes
    for arg_name, arg_val in args.items():
        if arg_name not in sdfg.arrays:
            continue
        desc = sdfg.arrays[arg_name]
        if not hasattr(desc, 'shape') or not hasattr(arg_val, 'shape'):
            continue

        for sym_dim, real_dim in zip(desc.shape, arg_val.shape):
            repldict = {}
            for sym in symbolic.symlist(sym_dim).values():
                newsym = symbolic.symbol(solve_prefix + str(sym))
                if str(sym) in args:
                    # Value is given explicitly: do not solve for this symbol
                    exclude.add(newsym)
                else:
                    symbols.add(newsym)
                    exclude.add(sym)
                repldict[sym] = newsym

            if repldict:
                sym_dim = sym_dim.subs(repldict)

            equations.append(sym_dim - real_dim)

    if len(symbols) == 0:
        return {}

    # Solve for all at once
    results = sympy.solve(equations, *symbols, dict=True, exclude=exclude)
    if len(results) > 1:
        raise ValueError('Ambiguous values for symbols in inference. '
                         'Options: %s' % str(results))
    if len(results) == 0:
        raise ValueError('Cannot infer values for symbols in inference.')

    result = results[0]
    if not result:
        raise ValueError('Cannot infer values for symbols in inference.')

    # Remove the solver prefix to recover the original symbol names
    return {str(k)[len(solve_prefix):]: v for k, v in result.items()}
Ejemplo n.º 21
0
    def can_be_applied(cls,
                       state: SDFGState,
                       candidate,
                       expr_index,
                       sdfg: SDFG,
                       strict=False) -> bool:
        """
        Checks whether the matched map can be vectorized for SVE.

        The map qualifies only if all of the following hold:

        * it is not already scheduled as an SVE map;
        * its body contains only tasklets and access nodes;
        * every connector type (as inferred by
          ``infer_types.infer_connector_types``) is in
          ``sve.util.TYPE_TO_SVE``, and all have the same base-type size;
        * no memlet stride depends on the innermost map parameter;
        * every WCR has a reduction type in ``sve.util.REDUCTION_TYPE_TO_SVE``,
          is not a non-atomic WCR indexed by the innermost parameter, and does
          not end in a vector-typed destination;
        * tasklet inputs come only from tasklets or non-stream access nodes;
        * vector inference succeeds.

        :param state: The state containing the candidate map.
        :param candidate: Pattern-match dictionary; its value for
                          ``cls.map_entry`` is passed to ``state.node``.
        :param expr_index: Index of the matched expression (not used here).
        :param sdfg: The SDFG containing ``state``.
        :param strict: Not used here.
        :return: True if the transformation can be applied.
        """
        map_entry = state.node(candidate[cls.map_entry])
        current_map = map_entry.map
        subgraph = state.scope_subgraph(map_entry)
        subgraph_contents = state.scope_subgraph(map_entry,
                                                 include_entry=False,
                                                 include_exit=False)

        # Prevent infinite repeats
        if current_map.schedule == dace.dtypes.ScheduleType.SVE_Map:
            return False

        # Infer all connector types for later checks (without modifying the graph)
        inferred = infer_types.infer_connector_types(sdfg, state, subgraph)

        ########################
        # Ensure only Tasklets and AccessNodes are within the map
        for node, _ in subgraph_contents.all_nodes_recursive():
            if not isinstance(node, (nodes.Tasklet, nodes.AccessNode)):
                return False

        ########################
        # Check for unsupported datatypes on the connectors (including on the Map itself)
        bit_widths = set()
        for node, _ in subgraph.all_nodes_recursive():
            for conn in node.in_connectors:
                t = inferred[(node, conn, True)]
                bit_widths.add(util.get_base_type(t).bytes)
                if t.type not in sve.util.TYPE_TO_SVE:
                    return False
            for conn in node.out_connectors:
                t = inferred[(node, conn, False)]
                bit_widths.add(util.get_base_type(t).bytes)
                if t.type not in sve.util.TYPE_TO_SVE:
                    return False

        # Multiple different bit widths occurring (messes up the predicates)
        if len(bit_widths) > 1:
            return False

        ########################
        # Check for unsupported memlets
        param_name = current_map.params[-1]
        # Symbol of the innermost map parameter (loop-invariant, computed once)
        param_sym = symbolic.symbol(param_name)
        for e, _ in subgraph.all_edges_recursive():
            # Check for unsupported strides
            # The only unsupported strides are the ones containing the innermost
            # loop param because they are not constant during a vector step
            if param_sym in e.data.get_stride(sdfg,
                                              map_entry.map).free_symbols:
                return False

            # Check for unsupported WCR
            if e.data.wcr is not None:
                # Unsupported reduction type
                reduction_type = dace.frontend.operations.detect_reduction_type(
                    e.data.wcr)
                if reduction_type not in sve.util.REDUCTION_TYPE_TO_SVE:
                    return False

                # Param in memlet during WCR is not supported
                if param_name in e.data.subset.free_symbols and e.data.wcr_nonatomic:
                    return False

                # vreduce is not supported. ``memlet_path`` returns a list of
                # edges, so the destination node is the ``dst`` of its last edge.
                last_edge = state.memlet_path(e)[-1]
                dst_node = last_edge.dst
                if isinstance(dst_node, nodes.Tasklet):
                    if isinstance(dst_node.in_connectors[last_edge.dst_conn],
                                  dtypes.vector):
                        return False
                elif isinstance(dst_node, nodes.AccessNode):
                    desc = dst_node.desc(sdfg)
                    if isinstance(desc, data.Scalar) and isinstance(
                            desc.dtype, dtypes.vector):
                        return False

        ########################
        # Check for invalid copies in the subgraph
        for node, _ in subgraph.all_nodes_recursive():
            if not isinstance(node, nodes.Tasklet):
                continue

            for e in state.in_edges(node):
                # Check for valid copies from other tasklets and/or streams
                if e.data.data is not None:
                    src_node = state.memlet_path(e)[0].src
                    if not isinstance(src_node,
                                      (nodes.Tasklet, nodes.AccessNode)):
                        # Make sure we only have Code->Code copies and from arrays
                        return False

                    if isinstance(src_node, nodes.AccessNode):
                        src_desc = src_node.desc(sdfg)
                        if isinstance(src_desc, dace.data.Stream):
                            # Stream pops are not implemented
                            return False

        # Run the vector inference algorithm to check if vectorization is
        # feasible (apply=False: only the success/failure outcome is used)
        try:
            vector_inference.infer_vectors(
                sdfg,
                state,
                map_entry,
                util.SVE_LEN,
                flags=vector_inference.VectorInferenceFlags.Allow_Stride,
                apply=False)
        except vector_inference.VectorInferenceException as ex:
            print(f'UserWarning: Vector inference failed! {ex}')
            return False

        return True
Ejemplo n.º 22
0
    def _construct_args(self, **kwargs):
        """ Main function that controls argument construction for calling
            the C prototype of the SDFG.

            Organizes arguments first by `sdfg.arglist`, then data descriptors
            by alphabetical order, then symbols by alphabetical order.

            :param kwargs: Program arguments by name. When non-empty, every
                           name in the SDFG signature must be present. Free
                           symbols with no value set on the symbol object are
                           also looked up here.
            :return: A tuple of ctypes values (NumPy arrays become
                     ``ctypes.c_void_p`` pointers); also stored in
                     ``self._lastargs``.
            :raise KeyError: If a signature argument is missing from kwargs.
            :raise TypeError: If an array is passed to a scalar argument or a
                              non-array object to an array argument.
            :raise UnboundLocalError: If a free symbol has no value.
        """

        # Argument construction
        sig = self._sdfg.signature_arglist(with_types=False)
        typedict = self._sdfg.arglist()
        arglist = []
        argtypes = []
        argnames = []
        if kwargs:
            # Construct mapping from arguments to signature
            for a in sig:
                if a not in kwargs:
                    raise KeyError("Missing program argument \"{}\"".format(a))
                arglist.append(kwargs[a])
                argtypes.append(typedict[a])
                argnames.append(a)
        else:
            sig = []

        # Type checking
        for a, arg, atype in zip(argnames, arglist, argtypes):
            if not isinstance(arg, np.ndarray) and isinstance(atype, dt.Array):
                raise TypeError(
                    'Passing an object (type %s) to an array in argument "%s"'
                    % (type(arg).__name__, a))
            if isinstance(arg, np.ndarray) and not isinstance(atype, dt.Array):
                raise TypeError(
                    'Passing an array to a scalar (type %s) in argument "%s"' %
                    (atype.dtype.ctype, a))
            if not isinstance(atype, dt.Array) and not isinstance(
                    atype.dtype, dace.callback) and not isinstance(
                        arg, atype.dtype.type):
                print('WARNING: Casting scalar argument "%s" from %s to %s' %
                      (a, type(arg).__name__, atype.dtype.type))

        # Replace arguments of callback type with the trampoline obtained
        # from their callback type.
        for index, (arg, argtype) in enumerate(zip(arglist, argtypes)):
            if isinstance(argtype.dtype, dace.callback):
                arglist[index] = argtype.dtype.get_trampoline(arg)

        # Retain only the element datatype for upcoming checks and casts
        argtypes = [t.dtype.as_ctypes() for t in argtypes]

        sdfg = self._sdfg

        # As in compilation, add symbols used in array sizes to parameters.
        # The type is recorded for every symbol (including those whose value
        # comes from kwargs) so that values and types stay aligned below.
        symparams = {}
        symtypes = {}
        for symname in sdfg.undefined_symbols(False):
            symval = symbolic.symbol(symname)
            try:
                symparams[symname] = symval.get()
            except UnboundLocalError:
                if symname not in kwargs:
                    raise UnboundLocalError(
                        'Unassigned symbol %s' % symname) from None
                symparams[symname] = kwargs[symname]
            symtypes[symname] = symval.dtype.as_ctypes()

        # Append symbol values and their types together, in sorted name order
        for symname in sorted(symparams):
            if symname not in sig:
                arglist.append(symparams[symname])
                argtypes.append(symtypes[symname])

        # Obtain SDFG constants
        constants = sdfg.constants

        # Remove symbolic constants from arguments
        callparams = tuple(
            (arg, atype) for arg, atype in zip(arglist, argtypes)
            if not symbolic.issymbolic(arg) or (
                hasattr(arg, 'name') and arg.name not in constants))

        # Replace symbols with their values
        callparams = tuple(
            (atype(symbolic.eval(arg)),
             atype) if symbolic.issymbolic(arg, constants) else (arg, atype)
            for arg, atype in callparams)

        # Replace arrays with their pointers
        newargs = tuple(
            (ctypes.c_void_p(arg.__array_interface__['data'][0]),
             atype) if isinstance(arg, np.ndarray) else (arg, atype)
            for arg, atype in callparams)

        # Wrap remaining plain Python values in their ctypes type
        newargs = tuple(
            atype(arg) if (not isinstance(arg, ctypes._SimpleCData)) else arg
            for arg, atype in newargs)

        self._lastargs = newargs
        return self._lastargs